diff --git a/src/traverseCraft/agent.py b/src/traverseCraft/agent.py index 063515f..13cbf0c 100644 --- a/src/traverseCraft/agent.py +++ b/src/traverseCraft/agent.py @@ -589,7 +589,7 @@ def __init__(self, world, agentName:str, agentColor:str="blue", startNodeId = No if startNodeId is not None: self._currentNode = self._worldObj.getNode(startNodeId) else: - self._currentNode = self._worldObj.root + self._currentNode = self._worldObj.nodeMap[list(self._worldObj.nodeMap.keys())[0]] self._graphRoot = self._currentNode self._worldObj.changeNodeColor(self._graphRoot.id, self._agentColor) # ~~~~~~~~~~ Base Heat Map Color ~~~~~~~~~~ # diff --git a/src/traverseCraft/world.py b/src/traverseCraft/world.py index aa912a6..8b9ae13 100644 --- a/src/traverseCraft/world.py +++ b/src/traverseCraft/world.py @@ -956,7 +956,6 @@ class CreateGraphWorld: - _canvas (Canvas): The canvas object for drawing the world. - nodeMap (dict): Dictionary mapping node IDs to canvas objects. - _visited (dict): Dictionary tracking visited nodes. - - root: The root of the Graph data structure(used to populate the graph). - _agent: The agent in the world. - _nodeObj (dict): Dictionary mapping node IDs to node objects. - _nodeTextObj (dict): Dictionary mapping node IDs to node label objects. @@ -1021,7 +1020,7 @@ def __init__(self, worldName: str, worldInfo: dict, radius: int = 20, fontSize:i if("vals" not in self._worldInfo): self._worldInfo['vals'] = None self._visited = {nodeId: False for nodeId in self._worldInfo["position"].keys()} - self.root = self._generateGraphDS(self._worldInfo["adj"], self._graphRootId, None, self._worldInfo["edges"], self._worldInfo['vals'], visited=self._visited) + self.__generateGraphDS(adj = self._worldInfo["adj"], edges = self._worldInfo["edges"], values = self._worldInfo['vals'], visited = self._visited) # ~~~~~ Agent information ~~~~~ # self._agent = None self._nodeObj = {} @@ -1171,6 +1170,11 @@ def summary(self): summary.add_row([nodeId, node._heatMapValue]) return str(summary) + def __generateGraphDS(self, adj, edges=None, values=None, visited={}): + for nodeId in adj: + if nodeId not in self.nodeMap: + self.nodeMap[nodeId] = self._generateGraphDS(adj, nodeId, None, edges, values, visited) + def _generateGraphDS(self, adj, rootId, parentId=None, edges=None, values=None, visited={}): if rootId not in adj: raise ValueError(f"Root ID {rootId} not found in adjacency list") @@ -1216,8 +1220,12 @@ def constructWorld(self): self._constructWorld() def _constructWorld(self): - self._drawEdges(self.root) - self._drawNodes(self.root) + for nodeId in self.nodeMap: + if self._visited[nodeId]: + self._drawEdges(self.nodeMap[nodeId]) + for nodeId in self.nodeMap: + if not self._visited[nodeId]: + self._drawNodes(self.nodeMap[nodeId]) self._addStartButton() def _drawEdges(self, node):