diff --git a/resources/graphs/network_flow_graph/components.py b/resources/graphs/network_flow_graph/components.py index deda0f7328c8407dff88e141a63a4b6205f78743..0a37bc466c50ea44356a831cc93f6561181089ca 100644 --- a/resources/graphs/network_flow_graph/components.py +++ b/resources/graphs/network_flow_graph/components.py @@ -105,3 +105,38 @@ class NetworkFlowNode(DirectedWeightedNode): # This is necessary for inheritance, or else child classes will only have access to parent functions. self._edge_type = NetworkFlowEdge self._node_type = NetworkFlowNode + + def get_highest_capacity(self, incoming_only=False, outgoing_only=False): + """ + Finds node connection with highest capacity out of all neighbors. + :param incoming_only: Bool indicating if only incoming edges should be examined. + :param outgoing_only: Bool indicating if only outgoing edges should be examined. + :return: Node with highest capacity connection | None if no connections. + """ + # Validate args. + if incoming_only and outgoing_only: + raise ValueError('Only one of "incoming_only" or "outgoing_only" should be True. Found both True.') + + capacity_node = None + capacity_edge = None + + # Get edge list based on args. + if incoming_only: + # Only check incoming edges. + edge_list = self.get_incoming_edges() + elif outgoing_only: + # Only check outgoing edges. + edge_list = self.get_outgoing_edges() + else: + # Get all edges. + edge_list = self.get_edges().values() + + # Loop through all edge connections. + for edge_connection in edge_list: + # Check if connection has higher capacity than current best. + if capacity_node is None or edge_connection.get_capacity() > capacity_edge.get_capacity(): + capacity_edge = edge_connection + capacity_node = edge_connection.get_partner_node(self) + + # Return highest found node. + return capacity_node diff --git a/tests/resources/graphs/network_flow_graph/node.py b/tests/resources/graphs/network_flow_graph/node.py index 4c743d6cd49b8e6e8384b57ce17e02e39f5e1933..0353130eea388aaad7ed1c543c40843419802276 100644 --- a/tests/resources/graphs/network_flow_graph/node.py +++ b/tests/resources/graphs/network_flow_graph/node.py @@ -22,3 +22,89 @@ class TestNetworkFlowNode(unittest.TestCase): def test__node_initialization(self): self.assertEqual(self.test_node._edge_type, NetworkFlowEdge) self.assertEqual(self.test_node._node_type, NetworkFlowNode) + + def test__get_highest_capacity(self): + # Create additional nodes. + node_1 = NetworkFlowNode(1) + node_2 = NetworkFlowNode(2) + node_3 = NetworkFlowNode(3) + node_4 = NetworkFlowNode(4) + + with self.subTest('With all edge directions.'): + # Test with no connections. + self.assertIsNone(self.test_node.get_highest_capacity()) + + # Test with one connection. + edge_1 = self.test_node.connect_node(node_1) + edge_1.set_capacity(15) + self.assertEqual(self.test_node.get_highest_capacity(), node_1) + + # Test with two connections. + edge_2 = self.test_node.connect_node(node_2) + edge_2.set_capacity(12) + self.assertEqual(self.test_node.get_highest_capacity(), node_1) + + # Test with three connections. + edge_3 = self.test_node.connect_node(node_3) + edge_3.set_capacity(13) + self.assertEqual(self.test_node.get_highest_capacity(), node_1) + + # Test with a new higher capacity, to ensure it wasn't just taking the first valid node. + edge_4 = self.test_node.connect_node(node_4) + edge_4.set_capacity(20) + self.assertEqual(self.test_node.get_highest_capacity(), node_4) + + # Reset edge connections for following subtests. + self.test_node.disconnect_node(node_identifier=node_1) + self.test_node.disconnect_node(node_identifier=node_2) + self.test_node.disconnect_node(node_identifier=node_3) + self.test_node.disconnect_node(node_identifier=node_4) + + with self.subTest('With only incoming edges.'): + # Test with no connections. + self.assertIsNone(self.test_node.get_highest_capacity(incoming_only=True)) + + # Test with incoming connection. + edge_1 = node_1.connect_node(self.test_node) + edge_1.set_capacity(10) + self.assertEqual(self.test_node.get_highest_capacity(incoming_only=True), node_1) + + # Test with two incoming connections. + edge_2 = node_2.connect_node(self.test_node) + edge_2.set_capacity(12) + self.assertEqual(self.test_node.get_highest_capacity(incoming_only=True), node_2) + + # Test with outgoing connection. + edge_3 = self.test_node.connect_node(node_3) + edge_3.set_capacity(15) + self.assertEqual(self.test_node.get_highest_capacity(incoming_only=True), node_2) + + # Reset edge connections for following subtests. + self.test_node.disconnect_node(node_identifier=node_1) + self.test_node.disconnect_node(node_identifier=node_2) + self.test_node.disconnect_node(node_identifier=node_3) + + with self.subTest('With only outgoing edges.'): + # Test with no connections. + self.assertIsNone(self.test_node.get_highest_capacity(outgoing_only=True)) + + # Test with outgoing connection. + edge_1 = self.test_node.connect_node(node_1) + edge_1.set_capacity(10) + self.assertEqual(self.test_node.get_highest_capacity(outgoing_only=True), node_1) + + # Test with two outgoing connections. + edge_2 = self.test_node.connect_node(node_2) + edge_2.set_capacity(12) + self.assertEqual(self.test_node.get_highest_capacity(outgoing_only=True), node_2) + + # Test with incoming connection. + edge_3 = node_3.connect_node(self.test_node) + edge_3.set_capacity(15) + self.assertEqual(self.test_node.get_highest_capacity(outgoing_only=True), node_2) + + # Reset edge connections for following subtests. + self.test_node.disconnect_node(node_identifier=node_1) + self.test_node.disconnect_node(node_identifier=node_2) + self.test_node.disconnect_node(node_identifier=node_3) +