diff --git a/resources/graphs/basic_graph/components.py b/resources/graphs/basic_graph/components.py index 9094ed0f12231d83fcc632fbdb4cb64b77c8d2c9..3fff6e5e01d43eb155d7229470cb6dbb70ebf79b 100644 --- a/resources/graphs/basic_graph/components.py +++ b/resources/graphs/basic_graph/components.py @@ -228,6 +228,7 @@ class BasicNode(): self._name = str(name) self._edges = {} self._connected_nodes = {} + self._graph = None # Define expected class types (should all be of "Basic" type). # This is necessary for inheritance, or else child classes will only have access to parent functions. diff --git a/resources/graphs/basic_graph/graph.py b/resources/graphs/basic_graph/graph.py index 346d55e2f52c2de421dd5151e3363ceb5d7728c0..70dd6fc3b5e6735e9ddb78a64af7e7acf81835c5 100644 --- a/resources/graphs/basic_graph/graph.py +++ b/resources/graphs/basic_graph/graph.py @@ -760,6 +760,7 @@ class BasicGraph(): new_node = self._node_type(node_name) self._nodes[new_node._name] = new_node + new_node._graph = self return new_node @@ -780,6 +781,7 @@ class BasicGraph(): # Add new_node to graph. self._nodes[new_node.get_name()] = new_node + new_node._graph = self # Check all node connections. For any that don't exist in graph, add those as well. for new_connected_node in new_node.get_connected_nodes().values(): diff --git a/tests/resources/graphs/basic_graph/graph.py b/tests/resources/graphs/basic_graph/graph.py index 42654fa02f97b59386ad153355f80308fec92677..4764e68609fa1c42791cdbec352b0ca6993f07de 100644 --- a/tests/resources/graphs/basic_graph/graph.py +++ b/tests/resources/graphs/basic_graph/graph.py @@ -619,6 +619,7 @@ class TestBasicGraph(unittest.TestCase): # Test with no nodes. node_1 = self.test_graph.create_node() self.assertTrue(isinstance(node_1, BasicNode)) + self.assertEqual(node_1._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 1) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, @@ -628,6 +629,7 @@ class TestBasicGraph(unittest.TestCase): # Test with one node. node_2 = self.test_graph.create_node() self.assertTrue(isinstance(node_2, BasicNode)) + self.assertEqual(node_2._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 2) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, @@ -638,6 +640,7 @@ class TestBasicGraph(unittest.TestCase): # Test with two nodes. node_3 = self.test_graph.create_node() self.assertTrue(isinstance(node_3, BasicNode)) + self.assertEqual(node_3._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 3) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, @@ -664,6 +667,7 @@ class TestBasicGraph(unittest.TestCase): # Test with no nodes. node_1 = self.test_graph.create_node(node_name='Node 1') self.assertTrue(isinstance(node_1, BasicNode)) + self.assertEqual(node_1._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 1) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, @@ -673,6 +677,7 @@ class TestBasicGraph(unittest.TestCase): # Test with one node. node_2 = self.test_graph.create_node(node_name='Node 2') self.assertTrue(isinstance(node_2, BasicNode)) + self.assertEqual(node_2._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 2) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, @@ -683,6 +688,7 @@ class TestBasicGraph(unittest.TestCase): # Test with two nodes. node_3 = self.test_graph.create_node(node_name='Node 3') self.assertTrue(isinstance(node_3, BasicNode)) + self.assertEqual(node_3._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 3) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, @@ -703,9 +709,13 @@ class TestBasicGraph(unittest.TestCase): # Verify start state. self.assertEqual(self.test_graph.get_node_count(), 0) self.assertEqual(self.test_graph.get_all_nodes(), {}) + self.assertIsNone(node_1._graph) + self.assertIsNone(node_2._graph) + self.assertIsNone(node_3._graph) # Test with no nodes. self.test_graph.add_node(node_1) + self.assertEqual(node_1._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 1) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, @@ -713,6 +723,7 @@ class TestBasicGraph(unittest.TestCase): # Test with one node. self.test_graph.add_node(node_2) + self.assertEqual(node_2._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 2) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, @@ -721,6 +732,7 @@ class TestBasicGraph(unittest.TestCase): # Test with two nodes. self.test_graph.add_node(node_3) + self.assertEqual(node_3._graph, self.test_graph) self.assertEqual(self.test_graph.get_node_count(), 3) self.assertEqual(self.test_graph.get_all_nodes(), { node_1.get_name(): node_1, diff --git a/tests/resources/graphs/basic_graph/node.py b/tests/resources/graphs/basic_graph/node.py index 578dfd4774422092f2bf928c32b82f39c921d1da..614a8eba2a14e0d785d15dca315dbb6b0f79f5ed 100644 --- a/tests/resources/graphs/basic_graph/node.py +++ b/tests/resources/graphs/basic_graph/node.py @@ -26,6 +26,7 @@ class TestBasicNode(unittest.TestCase): self.assertEqual(self.test_node.get_edge_count(), 0) self.assertEqual(self.test_node.get_edges(), {}) self.assertEqual(self.test_node.get_connected_nodes(), {}) + self.assertIsNone(self.test_node._graph) #region Information Function Tests