diff --git a/main.py b/main.py
index 1a0453beb5f69f6a3f94676c2e83675636fac565..54d8a70c5d8647bdd996b82999250b2d9ba87135 100644
--- a/main.py
+++ b/main.py
@@ -21,7 +21,7 @@ from resources.simplex import Simplex
 logger = init_logging.get_logger(__name__)
 
 
-def create_network_graph():
+def create_network_graph(small_flow=True):
     """
     Creates new instance of network problem, using custom "Graph Library" Network Flow class.
     :return: A "Network Flow Graph" class instance of the problem to solve.
@@ -29,9 +29,14 @@ def create_network_graph():
     # Create new graph.
     graph = NetworkFlowGraph()
 
+    if small_flow:
+        file_name = './resources/json_files/small_flow.json'
+    else:
+        file_name = './resources/json_files/network_flow_values.json'
+
     # Read in JSON data. We assume data is valid and correct, so we don't validate.
     # Open file.
-    with open('./resources/json_files/network_flow_values.json') as json_file:
+    with open(file_name) as json_file:
         # Parse JSON data into Python format.
         json_data = json.load(json_file)
 
@@ -46,9 +51,15 @@ def create_network_graph():
     graph.nodes.create('t')
 
     # Manually set y coordinates so graph is consistent.
-    top_row = ['A', 'D']
-    mid_row = ['s', 'B', 'E', 't']
-    bot_row = ['C', 'F']
+    if small_flow:
+        top_row = ['A', 'C']
+        mid_row = ['s', 't']
+        bot_row = ['B', 'D']
+    else:
+        top_row = ['A', 'D']
+        mid_row = ['s', 'B', 'E', 't']
+        bot_row = ['C', 'F']
+
     for node in graph.nodes.all().values():
         # Set coord based on row.
         if node.get_name() in top_row:
@@ -59,8 +70,12 @@ def create_network_graph():
             node.set_y_coord(10)
 
     # Manually set x coordinates to give nodes more breathing room.
-    mid_left = ['A', 'B', 'C']
-    mid_right = ['D', 'E', 'F']
+    if small_flow:
+        mid_left = ['A', 'B']
+        mid_right = ['C', 'D']
+    else:
+        mid_left = ['A', 'B', 'C']
+        mid_right = ['D', 'E', 'F']
     for node in graph.nodes.all().values():
         # Set coord based on col.
         if node.get_name() in mid_left: