diff --git a/tests/test_graph_agent.py b/tests/test_graph_agent.py new file mode 100644 index 0000000..a739844 --- /dev/null +++ b/tests/test_graph_agent.py @@ -0,0 +1,121 @@ +import unittest +from tkinter import Tk +from unittest.mock import patch +from traverseCraft.world import CreateGraphWorld +from traverseCraft.agent import GraphAgent +from traverseCraft.dataStructures import GraphNode + +class TestGraphAgent(unittest.TestCase): + def setUp(self): + # Sample Graph world information + graphWorldInfo = { + 'adj': { + 'A': ['B', 'C'], + 'B': ['D', 'E'], + 'C': ['F'], + 'D': [], + 'E': ['H', 'A'], + 'F': ['G'], + 'G': ['H', 'C'], + 'H': ['D', 'E'] + }, + 'position': { + 'A': (300, 100), + 'B': (150, 200), + 'C': (450, 200), + 'D': (100, 300), + 'E': (200, 300), + 'F': (300, 300), + 'G': (400, 400), + 'H': (150, 400) + + }, + 'goals': ['G'] + } + # Create the Graph world + self.graphWorld = CreateGraphWorld("Graph World Test", graphWorldInfo) + + # Construct the world + self.graphWorld.constructWorld() + + # Initialize the Graph agent + self.agent = GraphAgent(agentName="Test Graph Agent", world=self.graphWorld, heatMapColor="#EF4040") + + # Link the agent with the world + self.graphWorld.setAgent(self.agent) + + def tearDown(self): + # Destroy the root window after each test + try: + self.graphWorld._root.destroy() + self.agent = None + except Exception as e: + print(f"Exception in tearDown: {e}") + + def test_initialization(self): + # Test initialization of the Graph agent + self.assertEqual(self.agent._worldObj, self.graphWorld) + self.assertEqual(self.agent._worldID, "GRAPHWORLD") + self.assertEqual(self.agent._agentName, "Test Graph Agent") + self.assertEqual(self.agent._agentColor, "blue") + self.assertTrue(self.agent._heatMapView) + self.assertNotEqual(self.agent._heatMapColor, "#FFA732") + self.assertEqual(self.agent._heatGradient, 0.05) + + def test_set_algorithm_callback(self): + # Define a sample callback function + def sample_callback(): + pass + + # Set the callback function + self.agent.setAlgorithmCallBack(sample_callback) + + # Verify that the callback function is set + self.assertEqual(self.agent.algorithmCallBack, sample_callback) + + def test_run_algorithm(self): + # Define a sample algorithm callback function + def sample_algorithm(): + pass + + # Set the algorithm callback function + self.agent.setAlgorithmCallBack(sample_algorithm) + + # Run the algorithm + self.agent.runAlgorithm() + + # Verify that the algorithm ran successfully (no errors raised) + self.assertTrue(True) + + def test_check_goal_state(self): + goal_node = self.graphWorld.getNode("G") + non_goal_node = self.graphWorld.getNode("A") + self.assertTrue(self.agent.checkGoalState(goal_node)) + self.assertFalse(self.agent.checkGoalState(non_goal_node)) + + def test_set_start_state(self): + self.agent.setStartState("B") + self.assertEqual(self.agent._currentNode.id, "B") + self.assertEqual(self.agent._graphRoot.id, "B") + self.assertEqual(self.graphWorld.root.id, "B") + + def test_set_start_state_invalid(self): + with self.assertRaises(ValueError): + self.agent.setStartState("Z") # Assuming "Z" is not a valid node ID + + def test_move_agent(self): + # Get pointers to nodes 'B', 'D', and 'E' using self.graphWorld.getNode() + node_B = self.graphWorld.getNode('B') + node_C = self.graphWorld.getNode('C') + node_D = self.graphWorld.getNode('D') + node_E = self.graphWorld.getNode('E') + + # Move the agent + self.assertFalse(self.agent.moveAgent(None)) # Move to None should fail + self.assertTrue(self.agent.moveAgent(node_B)) # Move to node 'B' should succeed + self.assertTrue(self.agent.moveAgent(node_C)) # Move to node 'C' should succeed + self.assertTrue(self.agent.moveAgent(node_D)) # Move to node 'D' should succeed + self.assertTrue(self.agent.moveAgent(node_E)) # Move to node 'E' should succeed + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_graph_world.py b/tests/test_graph_world.py index e69de29..568ed4b 100644 --- a/tests/test_graph_world.py +++ b/tests/test_graph_world.py @@ -0,0 +1,248 @@ +import unittest +from tkinter import Tk +from unittest.mock import patch +from traverseCraft.world import CreateGraphWorld +from traverseCraft.dataStructures import GraphNode + +class TestGraphWorld(unittest.TestCase): + + def setUp(self): + # Sample valid Graph information + self.graphWorldInfo = { + 'adj': { + 'A': ['B', 'C'], + 'B': ['D', 'E'], + 'C': ['F'], + 'D': [], + 'E': ['H', 'A'], + 'F': ['G'], + 'G': ['H', 'C'], + 'H': ['D', 'E'] + }, + 'position': { + 'A': (300, 100), + 'B': (150, 200), + 'C': (450, 200), + 'D': (100, 300), + 'E': (200, 300), + 'F': (300, 300), + 'G': (400, 400), + 'H': (150, 400) + + }, + 'goals': ['G'] +} + + # def tearDown(self): + # try: + # self.grid_world._root.destroy() + # except Exception as e: + # print(f"Exception in tearDown: {e}") + + def test_successful_initialization(self): + self.graphWorld = CreateGraphWorld("Graph World Test", self.graphWorldInfo) + self.assertEqual(self.graphWorld._worldName, "Graph World Test") + self.assertEqual(self.graphWorld._goalIds, ['G']) + self.assertEqual(self.graphWorld._position['A'], (300, 100)) + self.assertEqual(self.graphWorld._nodeColor, "gray") + self.assertEqual(self.graphWorld._goalColor, "green") + self.assertEqual(self.graphWorld._fontSize, 12) + self.assertTrue(self.graphWorld._fontBold) + self.assertTrue(self.graphWorld._fontItalic) + self.assertEqual(self.graphWorld._lineThickness, 2) + self.assertEqual(self.graphWorld._arrowShape, (10, 12, 5)) + self.assertEqual(self.graphWorld._buttonBgColor, "#7FC7D9") + self.assertEqual(self.graphWorld._buttonFgColor, "#332941") + self.assertEqual(self.graphWorld._textFont, "Helvetica") + self.assertEqual(self.graphWorld._textSize, 24) + self.assertEqual(self.graphWorld._textWeight, "bold") + self.assertEqual(self.graphWorld._buttonText, "Start Agent") + try: + self.graphWorld._root.destroy() + except Exception as e: + print(f"Exception in tearDown: {e}") + + + def test_missing_goals_key(self): + invalidInfo = self.graphWorldInfo.copy() + del invalidInfo['goals'] + with self.assertRaises(ValueError): + CreateGraphWorld("Graph World Test", invalidInfo) + + def test_missing_adj_key(self): + invalidInfo = self.graphWorldInfo.copy() + del invalidInfo['adj'] + with self.assertRaises(ValueError): + CreateGraphWorld("Graph World Test", invalidInfo) + + def test_missing_position_key(self): + invalidInfo = self.graphWorldInfo.copy() + del invalidInfo['position'] + with self.assertRaises(ValueError): + CreateGraphWorld("Graph World Test", invalidInfo) + + def test_default_parameters(self): + graphWorld = CreateGraphWorld("Graph World Test", self.graphWorldInfo) + self.assertEqual(graphWorld._radius, 20) + self.assertEqual(graphWorld._fontSize, 12) + self.assertEqual(graphWorld._fontBold, True) + self.assertEqual(graphWorld._fontItalic, True) + self.assertEqual(graphWorld._nodeColor, "gray") + self.assertEqual(graphWorld._goalColor, "green") + self.assertEqual(graphWorld._lineThickness, 2) + self.assertEqual(graphWorld._arrowShape, (10, 12, 5)) + self.assertEqual(graphWorld._buttonBgColor, "#7FC7D9") + self.assertEqual(graphWorld._buttonFgColor, "#332941") + self.assertEqual(graphWorld._textFont, "Helvetica") + self.assertEqual(graphWorld._textSize, 24) + self.assertEqual(graphWorld._textWeight, "bold") + self.assertEqual(graphWorld._buttonText, "Start Agent") + try: + graphWorld._root.destroy() + except Exception as e: + print(f"Exception in tearDown: {e}") + + def test_custom_parameters(self): + graphWorld = CreateGraphWorld( + "Custom Graph World", self.graphWorldInfo, radius=25, fontSize=14, fontBold=False, + fontItalic=False, nodeColor="blue", goalColor="orange", + width=800, height=600, lineThickness=3, arrowShape=(5, 10, 3), buttonBgColor="black", + buttonFgColor="white", textFont="Arial", textSize=18, textWeight="normal", buttonText="Run" + ) + self.assertEqual(graphWorld._radius, 25) + self.assertEqual(graphWorld._fontSize, 14) + self.assertEqual(graphWorld._fontBold, False) + self.assertEqual(graphWorld._fontItalic, False) + self.assertEqual(graphWorld._nodeColor, "blue") + self.assertEqual(graphWorld._goalColor, "orange") + self.assertEqual(graphWorld._width, 800) + self.assertEqual(graphWorld._height, 600) + self.assertEqual(graphWorld._lineThickness, 3) + self.assertEqual(graphWorld._arrowShape, (5, 10, 3)) + self.assertEqual(graphWorld._buttonBgColor, "black") + self.assertEqual(graphWorld._buttonFgColor, "white") + self.assertEqual(graphWorld._textFont, "Arial") + self.assertEqual(graphWorld._textSize, 18) + self.assertEqual(graphWorld._textWeight, "normal") + self.assertEqual(graphWorld._buttonText, "Run") + try: + graphWorld._root.destroy() + except Exception as e: + print(f"Exception in tearDown: {e}") + + def test_valid_graph_format(self): + graphWorld = CreateGraphWorld("Graph World Test", self.graphWorldInfo) + self.assertTrue(graphWorld._check_graph_format(self.graphWorldInfo)[0]) + try: + graphWorld._root.destroy() + except Exception as e: + print(f"Exception in tearDown: {e}") + + def test_missing_keys(self): + missing_keys = { + 'adj': { + 'A': ['B', 'C'], + 'B': ['D', 'E'], + 'C': ['F'], + 'D': [], + 'E': ['H'], + 'F': ['G'], + 'G': [], + 'H': [] + }, + 'position': { + 'A': (300, 100), + 'B': (150, 200), + 'C': (450, 200), + 'D': (100, 300), + 'E': (200, 300), + 'F': (300, 300), + 'G': (400, 400) + }, + 'goals': ['G'] + } + with self.assertRaises(ValueError): + CreateGraphWorld("Graph World Test", missing_keys) + + + def test_invalid_node_coordinates(self): + invalid_coords = [ + {'adj': {'A': ['B'], 'B': ['C'], 'C': []}, 'position': {'A': (0, 0), 'B': (100, '100'), 'C': (200, 200)}, 'root': 'A', 'goals': ['C']}, + {'adj': {'A': ['B'], 'B': ['C'], 'C': []}, 'position': {'A': (0, 0), 'B': (100,), 'C': (200, 200)}, 'root': 'A', 'goals': ['C']} + ] + for info in invalid_coords: + with self.assertRaises(ValueError): + CreateGraphWorld("Graph World Test", info) + + def test_invalid_goal_nodes(self): + invalid_goals = [ + {'adj': {'A': ['B'], 'B': ['C'], 'C': []}, 'position': {'A': (0, 0), 'B': (100, 100), 'C': (200, 200)}, 'root': 'A','goals': ['D']}, + {'adj': {'A': ['B'], 'B': ['C'], 'C': []}, 'position': {'A': (0, 0), 'B': (100, 100), 'C': (200, 200)}, 'root': 'A','goals': [1]} + ] + for info in invalid_goals: + with self.assertRaises(ValueError): + CreateGraphWorld("Graph World Test", info) + + def test_missing_nodes(self): + missing_nodes = [ + {'adj': {'A': ['B'], 'B': ['C'], 'D': []}, 'position': {'A': (0, 0), 'B': (100, 100), 'C': (200, 200)}, 'root': 'A', 'goals': ['C']}, + {'adj': {'A': ['B'], 'B': ['C'], 'C': []}, 'position': {'A': (0, 0), 'B': (100, 100)}, 'root': 'A', 'goals': ['C']} + ] + for info in missing_nodes: + with self.assertRaises(ValueError): + CreateGraphWorld("Graph World Test", info) + + def test_node_map_pointers(self): + # Create the Graph world + graphWorld = CreateGraphWorld("Graph World Test", self.graphWorldInfo) + + # Assert that all values in nodeMap are instances of GraphNode class + for node_id, node_obj in graphWorld.nodeMap.items(): + self.assertIsInstance(node_obj, GraphNode) + + try: + graphWorld._root.destroy() + except Exception as e: + print(f"Exception in tearDown: {e}") + + + def test_get_node(self): + # Create the Graph world + graphWorld = CreateGraphWorld("Graph World Test", self.graphWorldInfo) + + # Test getNode method + node_A = graphWorld.getNode('A') + node_B = graphWorld.getNode('B') + node_C = graphWorld.getNode('C') + + # Assert nodes are retrieved correctly + self.assertEqual(node_A.id, 'A') + self.assertEqual(node_B.id, 'B') + self.assertEqual(node_C.id, 'C') + + try: + graphWorld._root.destroy() + except Exception as e: + print(f"Exception in tearDown: {e}") + + @patch.object(Tk, 'mainloop', lambda self: None) + def test_show_world(self): + graphWorld = CreateGraphWorld("Graph World Test", self.graphWorldInfo) + + # Test showWorld method + graphWorld.showWorld() + + # Verify that the _root window is initialized after calling showWorld + self.assertIsInstance(graphWorld._root, Tk) + + # Verify that the _root window is visible + self.assertTrue(graphWorld._root.winfo_exists()) + + try: + graphWorld._root.destroy() + except Exception as e: + print(f"Exception in tearDown: {e}") + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_tree_agent.py b/tests/test_tree_agent.py index 287c80f..b1a3243 100644 --- a/tests/test_tree_agent.py +++ b/tests/test_tree_agent.py @@ -5,11 +5,6 @@ from traverseCraft.agent import TreeAgent from traverseCraft.dataStructures import TreeNode -import unittest -import time -from traverseCraft.world import CreateTreeWorld -from traverseCraft.agent import TreeAgent - class TestTreeAgent(unittest.TestCase): def setUp(self): # Sample tree world information diff --git a/tests/test_tree_world.py b/tests/test_tree_world.py index 26e1bc7..b3e556a 100644 --- a/tests/test_tree_world.py +++ b/tests/test_tree_world.py @@ -160,9 +160,8 @@ def test_missing_keys(self): 'root': 'A', 'goals': ['G'] } - for info in missing_keys: - with self.assertRaises(ValueError): - CreateTreeWorld("Tree World Test", info) + with self.assertRaises(ValueError): + CreateTreeWorld("Tree World Test", missing_keys) def test_invalid_node_coordinates(self): diff --git a/traverseCraft/world.py b/traverseCraft/world.py index e76acbb..fa67c1c 100644 --- a/traverseCraft/world.py +++ b/traverseCraft/world.py @@ -918,7 +918,6 @@ class CreateGraphWorld: - fontBold (bool): Whether to use bold font for the node labels. Default is True. - fontItalic (bool): Whether to use italic font for the node labels. Default is True. - nodeColor (str): The color of the nodes. Default is "gray". - - rootColor (str): The color of the root node. Default is "red". - goalColor (str): The color of the goal nodes. Default is "green". - width (int): The width of the world visualization canvas. Default is SCREEN_WIDTH. - height (int): The height of the world visualization canvas. Default is SCREEN_HEIGHT. @@ -954,7 +953,7 @@ class CreateGraphWorld: - _canvas (Canvas): The canvas object for drawing the world. - nodeMap (dict): Dictionary mapping node IDs to canvas objects. - _visited (dict): Dictionary tracking visited nodes. - - root: The root of the tree data structure. + - root: The root of the Graph data structure(used to populate the graph). - _agent: The agent in the world. - _nodeObj (dict): Dictionary mapping node IDs to node objects. - _nodeTextObj (dict): Dictionary mapping node IDs to node label objects. @@ -966,7 +965,7 @@ class CreateGraphWorld: - _textWeight (str): The font weight of the button text. """ worldID = "GRAPHWORLD" - def __init__(self, worldName: str, worldInfo: dict, radius: int = 20, fontSize:int=12, fontBold:bool = True, fontItalic:bool = True, nodeColor: str = "gray", rootColor: str="red", goalColor: str="green", width: int = SCREEN_WIDTH, height: int = SCREEN_HEIGHT, lineThickness: int =2, arrowShape: tuple = (10, 12, 5), buttonBgColor:str="#7FC7D9", buttonFgColor:str="#332941", textFont:str="Helvetica", textSize:int=24, textWeight:str="bold", buttonText:str="Start Agent", logoPath:str=None): + def __init__(self, worldName: str, worldInfo: dict, radius: int = 20, fontSize:int=12, fontBold:bool = True, fontItalic:bool = True, nodeColor: str = "gray", goalColor: str="green", width: int = SCREEN_WIDTH, height: int = SCREEN_HEIGHT, lineThickness: int =2, arrowShape: tuple = (10, 12, 5), buttonBgColor:str="#7FC7D9", buttonFgColor:str="#332941", textFont:str="Helvetica", textSize:int=24, textWeight:str="bold", buttonText:str="Start Agent", logoPath:str=None): """ Initializes the Graph World.