diff --git a/resources/graphs/state_machine/components.py b/resources/graphs/state_machine/components.py
index 8dc3753b352e2d8accd196bff346c6df03c11851..02154e9da71f07b78b308539a7082af72986755c 100644
--- a/resources/graphs/state_machine/components.py
+++ b/resources/graphs/state_machine/components.py
@@ -138,8 +138,12 @@ class StateMachineNode(DirectedNode):
         """
         if not self.is_initial:
             self.is_initial = True
+            if self._graph is not None:
+                self._graph._initial_states[self.get_name()] = self
         else:
             self.is_initial = False
+            if self._graph is not None:
+                self._graph._initial_states.pop(self.get_name())
 
     def toggle_node_state_final(self):
         """
@@ -147,5 +151,9 @@ class StateMachineNode(DirectedNode):
         """
         if not self.is_final:
             self.is_final = True
+            if self._graph is not None:
+                self._graph._final_states[self.get_name()] = self
         else:
             self.is_final = False
+            if self._graph is not None:
+                self._graph._final_states.pop(self.get_name())
diff --git a/resources/graphs/state_machine/graph.py b/resources/graphs/state_machine/graph.py
index b8cf5f0a83a2ee0501f150568400f8ed928bf767..dc2d2d06f28c11402bd409c75f3b29c7747252e5 100644
--- a/resources/graphs/state_machine/graph.py
+++ b/resources/graphs/state_machine/graph.py
@@ -40,6 +40,10 @@ class StateMachineGraph(DirectedGraph):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+        # Lists of start and final states.
+        self._initial_states = {}
+        self._final_states = {}
+
         # Define expected class types (should all be of "State Machine" type).
         # This is necessary for inheritance, or else child classes will only have access to parent functions.
         self._edge_type = StateMachineEdge
diff --git a/tests/resources/graphs/state_machine/graph.py b/tests/resources/graphs/state_machine/graph.py
index 5ab6de17208bfc96ddb6ab7ecf12939f6c0a2304..fee94ca5708890ce9e6b19a398b21e3259d6f3f7 100644
--- a/tests/resources/graphs/state_machine/graph.py
+++ b/tests/resources/graphs/state_machine/graph.py
@@ -76,3 +76,151 @@ class TestDirectedGraph(unittest.TestCase):
         self.assertEqual(edge_3.state_change_conditions, ['c'])
 
         # Works with 0, 1, and 2 connections. Assuming works with all further n connections.
+
+    def test__toggle_node_state_initial(self):
+        # Create node to toggle.
+        node_1 = self.test_graph.create_node()
+        node_2 = self.test_graph.create_node()
+        node_3 = self.test_graph.create_node()
+
+        with self.subTest('Adding initial/start states.'):
+            # Test initial values.
+            self.assertFalse(node_1.is_initial)
+            self.assertFalse(node_2.is_initial)
+            self.assertFalse(node_3.is_initial)
+            self.assertEqual(self.test_graph._initial_states, {})
+
+            # Test add with 0 initial states.
+            node_1.toggle_node_state_initial()
+            self.assertTrue(node_1.is_initial)
+            self.assertFalse(node_2.is_initial)
+            self.assertFalse(node_3.is_initial)
+            self.assertEqual(self.test_graph._initial_states, {node_1.get_name(): node_1})
+
+            # Test add with 1 initial states.
+            node_2.toggle_node_state_initial()
+            self.assertTrue(node_1.is_initial)
+            self.assertTrue(node_2.is_initial)
+            self.assertFalse(node_3.is_initial)
+            self.assertEqual(self.test_graph._initial_states, {node_1.get_name(): node_1, node_2.get_name(): node_2})
+
+            # Test add with 2 initial states.
+            node_3.toggle_node_state_initial()
+            self.assertTrue(node_1.is_initial)
+            self.assertTrue(node_2.is_initial)
+            self.assertTrue(node_3.is_initial)
+            self.assertEqual(self.test_graph._initial_states, {
+                node_1.get_name(): node_1,
+                node_2.get_name(): node_2,
+                node_3.get_name(): node_3,
+            })
+
+            # Works with 0, 1, and 2 initial/start states. Assuming works with all further n initial/start states.
+
+        with self.subTest('Removing initial/start states.'):
+            # Test initial values.
+            self.assertTrue(node_1.is_initial)
+            self.assertTrue(node_2.is_initial)
+            self.assertTrue(node_3.is_initial)
+            self.assertEqual(self.test_graph._initial_states, {
+                node_1.get_name(): node_1,
+                node_2.get_name(): node_2,
+                node_3.get_name(): node_3,
+            })
+
+            # Test remove with 3 initial states.
+            node_1.toggle_node_state_initial()
+            self.assertFalse(node_1.is_initial)
+            self.assertTrue(node_2.is_initial)
+            self.assertTrue(node_3.is_initial)
+            self.assertEqual(self.test_graph._initial_states, {node_2.get_name(): node_2, node_3.get_name(): node_3})
+
+            # Test remove with 2 initial states.
+            node_3.toggle_node_state_initial()
+            self.assertFalse(node_1.is_initial)
+            self.assertTrue(node_2.is_initial)
+            self.assertFalse(node_3.is_initial)
+            self.assertEqual(self.test_graph._initial_states, {node_2.get_name(): node_2})
+
+            # Test remove with 1 initial states.
+            node_2.toggle_node_state_initial()
+            self.assertFalse(node_1.is_initial)
+            self.assertFalse(node_2.is_initial)
+            self.assertFalse(node_3.is_initial)
+            self.assertEqual(self.test_graph._initial_states, {})
+
+            # Works with 1, 2, and 3 initial/start states. Assuming works with all further n initial/start states.
+
+    def test__toggle_node_state_final(self):
+        # Create node to toggle.
+        node_1 = self.test_graph.create_node()
+        node_2 = self.test_graph.create_node()
+        node_3 = self.test_graph.create_node()
+
+        with self.subTest('Adding final states.'):
+            # Test initial values.
+            self.assertFalse(node_1.is_final)
+            self.assertFalse(node_2.is_final)
+            self.assertFalse(node_3.is_final)
+            self.assertEqual(self.test_graph._final_states, {})
+
+            # Test add with 0 initial states.
+            node_1.toggle_node_state_final()
+            self.assertTrue(node_1.is_final)
+            self.assertFalse(node_2.is_final)
+            self.assertFalse(node_3.is_final)
+            self.assertEqual(self.test_graph._final_states, {node_1.get_name(): node_1})
+
+            # Test add with 1 initial states.
+            node_2.toggle_node_state_final()
+            self.assertTrue(node_1.is_final)
+            self.assertTrue(node_2.is_final)
+            self.assertFalse(node_3.is_final)
+            self.assertEqual(self.test_graph._final_states, {node_1.get_name(): node_1, node_2.get_name(): node_2})
+
+            # Test add with 2 initial states.
+            node_3.toggle_node_state_final()
+            self.assertTrue(node_1.is_final)
+            self.assertTrue(node_2.is_final)
+            self.assertTrue(node_3.is_final)
+            self.assertEqual(self.test_graph._final_states, {
+                node_1.get_name(): node_1,
+                node_2.get_name(): node_2,
+                node_3.get_name(): node_3,
+            })
+
+            # Works with 0, 1, and 2 final states. Assuming works with all further n final states.
+
+        with self.subTest('Removing final states.'):
+            # Test initial values.
+            self.assertTrue(node_1.is_final)
+            self.assertTrue(node_2.is_final)
+            self.assertTrue(node_3.is_final)
+            self.assertEqual(self.test_graph._final_states, {
+                node_1.get_name(): node_1,
+                node_2.get_name(): node_2,
+                node_3.get_name(): node_3,
+            })
+
+            # Test remove with 3 initial states.
+            node_1.toggle_node_state_final()
+            self.assertFalse(node_1.is_final)
+            self.assertTrue(node_2.is_final)
+            self.assertTrue(node_3.is_final)
+            self.assertEqual(self.test_graph._final_states, {node_2.get_name(): node_2, node_3.get_name(): node_3})
+
+            # Test remove with 2 initial states.
+            node_3.toggle_node_state_final()
+            self.assertFalse(node_1.is_final)
+            self.assertTrue(node_2.is_final)
+            self.assertFalse(node_3.is_final)
+            self.assertEqual(self.test_graph._final_states, {node_2.get_name(): node_2})
+
+            # Test remove with 1 initial states.
+            node_2.toggle_node_state_final()
+            self.assertFalse(node_1.is_final)
+            self.assertFalse(node_2.is_final)
+            self.assertFalse(node_3.is_final)
+            self.assertEqual(self.test_graph._final_states, {})
+
+            # Works with 1, 2, and 3 final states. Assuming works with all further n final states.