From 2db01340bf892a946f9a7d6d88e3b276fc2cd058 Mon Sep 17 00:00:00 2001
From: Brandon Rodriguez <brodriguez8774@gmail.com>
Date: Sat, 16 Nov 2019 16:49:04 -0500
Subject: [PATCH] Update nodes to track which graph they belong to

---
 resources/graphs/basic_graph/components.py  |  1 +
 resources/graphs/basic_graph/graph.py       |  2 ++
 tests/resources/graphs/basic_graph/graph.py | 12 ++++++++++++
 tests/resources/graphs/basic_graph/node.py  |  1 +
 4 files changed, 16 insertions(+)

diff --git a/resources/graphs/basic_graph/components.py b/resources/graphs/basic_graph/components.py
index 9094ed0..3fff6e5 100644
--- a/resources/graphs/basic_graph/components.py
+++ b/resources/graphs/basic_graph/components.py
@@ -228,6 +228,7 @@ class BasicNode():
         self._name = str(name)
         self._edges = {}
         self._connected_nodes = {}
+        self._graph = None
 
         # Define expected class types (should all be of "Basic" type).
         # This is necessary for inheritance, or else child classes will only have access to parent functions.
diff --git a/resources/graphs/basic_graph/graph.py b/resources/graphs/basic_graph/graph.py
index 346d55e..70dd6fc 100644
--- a/resources/graphs/basic_graph/graph.py
+++ b/resources/graphs/basic_graph/graph.py
@@ -760,6 +760,7 @@ class BasicGraph():
 
         new_node = self._node_type(node_name)
         self._nodes[new_node._name] = new_node
+        new_node._graph = self
 
         return new_node
 
@@ -780,6 +781,7 @@ class BasicGraph():
 
         # Add new_node to graph.
         self._nodes[new_node.get_name()] = new_node
+        new_node._graph = self
 
         # Check all node connections. For any that don't exist in graph, add those as well.
         for new_connected_node in new_node.get_connected_nodes().values():
diff --git a/tests/resources/graphs/basic_graph/graph.py b/tests/resources/graphs/basic_graph/graph.py
index 42654fa..4764e68 100644
--- a/tests/resources/graphs/basic_graph/graph.py
+++ b/tests/resources/graphs/basic_graph/graph.py
@@ -619,6 +619,7 @@ class TestBasicGraph(unittest.TestCase):
             # Test with no nodes.
             node_1 = self.test_graph.create_node()
             self.assertTrue(isinstance(node_1, BasicNode))
+            self.assertEqual(node_1._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 1)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
@@ -628,6 +629,7 @@ class TestBasicGraph(unittest.TestCase):
             # Test with one node.
             node_2 = self.test_graph.create_node()
             self.assertTrue(isinstance(node_2, BasicNode))
+            self.assertEqual(node_2._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 2)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
@@ -638,6 +640,7 @@ class TestBasicGraph(unittest.TestCase):
             # Test with two nodes.
             node_3 = self.test_graph.create_node()
             self.assertTrue(isinstance(node_3, BasicNode))
+            self.assertEqual(node_3._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 3)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
@@ -664,6 +667,7 @@ class TestBasicGraph(unittest.TestCase):
             # Test with no nodes.
             node_1 = self.test_graph.create_node(node_name='Node 1')
             self.assertTrue(isinstance(node_1, BasicNode))
+            self.assertEqual(node_1._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 1)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
@@ -673,6 +677,7 @@ class TestBasicGraph(unittest.TestCase):
             # Test with one node.
             node_2 = self.test_graph.create_node(node_name='Node 2')
             self.assertTrue(isinstance(node_2, BasicNode))
+            self.assertEqual(node_2._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 2)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
@@ -683,6 +688,7 @@ class TestBasicGraph(unittest.TestCase):
             # Test with two nodes.
             node_3 = self.test_graph.create_node(node_name='Node 3')
             self.assertTrue(isinstance(node_3, BasicNode))
+            self.assertEqual(node_3._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 3)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
@@ -703,9 +709,13 @@ class TestBasicGraph(unittest.TestCase):
             # Verify start state.
             self.assertEqual(self.test_graph.get_node_count(), 0)
             self.assertEqual(self.test_graph.get_all_nodes(), {})
+            self.assertIsNone(node_1._graph)
+            self.assertIsNone(node_2._graph)
+            self.assertIsNone(node_3._graph)
 
             # Test with no nodes.
             self.test_graph.add_node(node_1)
+            self.assertEqual(node_1._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 1)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
@@ -713,6 +723,7 @@ class TestBasicGraph(unittest.TestCase):
 
             # Test with one node.
             self.test_graph.add_node(node_2)
+            self.assertEqual(node_2._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 2)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
@@ -721,6 +732,7 @@ class TestBasicGraph(unittest.TestCase):
 
             # Test with two nodes.
             self.test_graph.add_node(node_3)
+            self.assertEqual(node_3._graph, self.test_graph)
             self.assertEqual(self.test_graph.get_node_count(), 3)
             self.assertEqual(self.test_graph.get_all_nodes(), {
                 node_1.get_name(): node_1,
diff --git a/tests/resources/graphs/basic_graph/node.py b/tests/resources/graphs/basic_graph/node.py
index 578dfd4..614a8eb 100644
--- a/tests/resources/graphs/basic_graph/node.py
+++ b/tests/resources/graphs/basic_graph/node.py
@@ -26,6 +26,7 @@ class TestBasicNode(unittest.TestCase):
         self.assertEqual(self.test_node.get_edge_count(), 0)
         self.assertEqual(self.test_node.get_edges(), {})
         self.assertEqual(self.test_node.get_connected_nodes(), {})
+        self.assertIsNone(self.test_node._graph)
 
     #region Information Function Tests
 
-- 
GitLab