Skip to content

Commit

Permalink
fix the graph world creation for more than 1 component.
Browse files Browse the repository at this point in the history
  • Loading branch information
srajan-kiyotaka committed Jun 20, 2024
1 parent 3af52d6 commit b98aca2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/traverseCraft/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def __init__(self, world, agentName:str, agentColor:str="blue", startNodeId = No
if startNodeId is not None:
self._currentNode = self._worldObj.getNode(startNodeId)
else:
self._currentNode = self._worldObj.root
self._currentNode = self._worldObj.nodeMap[list(self._worldObj.nodeMap.keys())[0]]
self._graphRoot = self._currentNode
self._worldObj.changeNodeColor(self._graphRoot.id, self._agentColor)
# ~~~~~~~~~~ Base Heat Map Color ~~~~~~~~~~ #
Expand Down
16 changes: 12 additions & 4 deletions src/traverseCraft/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,6 @@ class CreateGraphWorld:
- _canvas (Canvas): The canvas object for drawing the world.
- nodeMap (dict): Dictionary mapping node IDs to canvas objects.
- _visited (dict): Dictionary tracking visited nodes.
- root: The root of the Graph data structure(used to populate the graph).
- _agent: The agent in the world.
- _nodeObj (dict): Dictionary mapping node IDs to node objects.
- _nodeTextObj (dict): Dictionary mapping node IDs to node label objects.
Expand Down Expand Up @@ -1021,7 +1020,7 @@ def __init__(self, worldName: str, worldInfo: dict, radius: int = 20, fontSize:i
if("vals" not in self._worldInfo):
self._worldInfo['vals'] = None
self._visited = {nodeId: False for nodeId in self._worldInfo["position"].keys()}
self.root = self._generateGraphDS(self._worldInfo["adj"], self._graphRootId, None, self._worldInfo["edges"], self._worldInfo['vals'], visited=self._visited)
self.__generateGraphDS(adj = self._worldInfo["adj"], edges = self._worldInfo["edges"], values = self._worldInfo['vals'], visited = self._visited)
# ~~~~~ Agent information ~~~~~ #
self._agent = None
self._nodeObj = {}
Expand Down Expand Up @@ -1171,6 +1170,11 @@ def summary(self):
summary.add_row([nodeId, node._heatMapValue])
return str(summary)

def __generateGraphDS(self, adj, edges=None, values=None, visited={}):
for nodeId in adj:
if nodeId not in self.nodeMap:
self.nodeMap[nodeId] = self._generateGraphDS(adj, nodeId, None, edges, values, visited)

def _generateGraphDS(self, adj, rootId, parentId=None, edges=None, values=None, visited={}):
if rootId not in adj:
raise ValueError(f"Root ID {rootId} not found in adjacency list")
Expand Down Expand Up @@ -1216,8 +1220,12 @@ def constructWorld(self):
self._constructWorld()

def _constructWorld(self):
self._drawEdges(self.root)
self._drawNodes(self.root)
for nodeId in self.nodeMap:
if self._visited[nodeId]:
self._drawEdges(self.nodeMap[nodeId])
for nodeId in self.nodeMap:
if not self._visited[nodeId]:
self._drawNodes(self.nodeMap[nodeId])
self._addStartButton()

def _drawEdges(self, node):
Expand Down

0 comments on commit b98aca2

Please sign in to comment.