diff --git a/examples/03_view_map.py b/examples/03_view_map.py index 7f91782..d054edd 100644 --- a/examples/03_view_map.py +++ b/examples/03_view_map.py @@ -42,7 +42,7 @@ elif isinstance(node, tuple): path.append([node[0], node[1], node[2]]) print(f"path: {path}") - + # visualize path in open3d if USE_OPEN3D: @@ -51,16 +51,16 @@ xyz_pt = np.stack(obstacle_indices, axis=-1).astype(float) colors = np.zeros((xyz_pt.shape[0], 3)) colors[:, 2] = obstacle_indices[2] / np.max(obstacle_indices[2]) - + # Prepare start and end colors start_color = np.array([[1.0, 0, 0]]) # Red - end_color = np.array([[0, 1.0, 0]]) # Green + end_color = np.array([[0, 1.0, 0]]) # Green path_colors = np.full((len(path) - 2, 3), [0.7, 0.7, 0.7]) # Grey for the path # Combine points and colors xyz_pt = np.concatenate((xyz_pt, [start_pt], [end_pt], path[1:-1])) colors = np.concatenate((colors, start_color, end_color, path_colors)) - + # Create and visualize the point cloud pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(xyz_pt) diff --git a/pathfinding3d/core/heap.py b/pathfinding3d/core/heap.py new file mode 100644 index 0000000..b5d819a --- /dev/null +++ b/pathfinding3d/core/heap.py @@ -0,0 +1,143 @@ +""" +Simple heap with ordering and removal. +Inspired from https://github.com/brean/python-pathfinding/pull/54 +Original author: https://github.com/peterchenadded +""" +import heapq +from typing import Callable, Union + +from .grid import Grid +from .node import GridNode +from .world import World + + +class SimpleHeap: + """ + A simple implementation of a heap data structure optimized for pathfinding. + It maintains an open list of nodes, a status for each node, and a function to retrieve nodes. + """ + + def __init__(self, node: GridNode, grid: Union[Grid, World]): + """ + Initializes the SimpleHeap with a given node and grid. + + Parameters + ---------- + node : GridNode + The initial node to be added to the heap. This node should have an 'f' attribute representing its cost. + grid : Union[Grid, World] + The grid in which the nodes are located. + """ + + self.grid = grid + self._get_node_tuple = self._determine_node_retrieval_function() + self._get_node = self._determine_node_function() + self.open_list = [self._get_node_tuple(node, 0)] + self.removed_node_tuples = set() + self.heap_order = {} + self.number_pushed = 0 + + def _determine_node_retrieval_function(self) -> Callable: + """ + Determines the node retrieval function based on the type of grid. + + Returns + ------- + function + A function that takes a node tuple and returns the corresponding node. + + Raises + ------ + ValueError + If the grid is not of type Grid or World. + """ + if isinstance(self.grid, Grid): + return lambda node, heap_order: (node.f, heap_order, *node.identifier) + + if isinstance(self.grid, World): + return lambda node, heap_order: (node.f, heap_order, *node.identifier) + + raise ValueError("Unsupported grid type") + + def _determine_node_function(self) -> Callable: + """ + Determines the node function based on the type of grid. + + Returns + ------- + function + A function that takes a node tuple and returns the corresponding node. + + Raises + ------ + ValueError + If the grid is not of type Grid or World. + """ + + if isinstance(self.grid, Grid): + return lambda node_tuple: self.grid.node(*node_tuple[2:]) + + if isinstance(self.grid, World): + return lambda node_tuple: self.grid.grids[node_tuple[5]].node(*node_tuple[2:5]) + + raise ValueError("Unsupported grid type") + + def pop_node(self) -> GridNode: + """ + Pops the node with the lowest cost from the heap. + + Returns + ------- + GridNode + The node with the lowest cost. + """ + node_tuple = heapq.heappop(self.open_list) + while node_tuple in self.removed_node_tuples: + node_tuple = heapq.heappop(self.open_list) + + return self._get_node(node_tuple) + + def push_node(self, node: GridNode): + """ + Pushes a node to the heap. + + Parameters + ---------- + node : GridNode + The node to be pushed to the heap. + """ + self.number_pushed = self.number_pushed + 1 + node_tuple = self._get_node_tuple(node, self.number_pushed) + + self.heap_order[node.identifier] = self.number_pushed + + heapq.heappush(self.open_list, node_tuple) + + def remove_node(self, node: GridNode, old_f: float): + """ + Remove the node from the heap. + + This just stores it in a set and we just ignore the node if it does + get popped from the heap. + + Parameters + ---------- + node : GridNode + The node to be removed from the heap. + old_f: float + The old cost of the node. + """ + heap_order = self.heap_order[node.identifier] + node_tuple = self._get_node_tuple(node, heap_order) + self.removed_node_tuples.add(node_tuple) + + def __len__(self) -> int: + """ + Returns the length of the heap. + + Returns + ------- + int + The length of the heap. + """ + return len(self.open_list) diff --git a/pathfinding3d/core/node.py b/pathfinding3d/core/node.py index f347b95..bf66f24 100644 --- a/pathfinding3d/core/node.py +++ b/pathfinding3d/core/node.py @@ -1,11 +1,12 @@ import dataclasses -from typing import List, Optional +from typing import List, Optional, Tuple @dataclasses.dataclass class Node: - def __post_init__(self): - # values used in the finder + __slots__ = ["h", "g", "f", "opened", "closed", "parent", "retain_count", "tested"] + + def __init__(self): self.cleanup() def __lt__(self, other: "Node") -> bool: @@ -63,6 +64,15 @@ class GridNode(Node): connections: Optional[List] = None + identifier: Optional[Tuple] = None + + def __post_init__(self): + super().__init__() + # for heap + self.identifier: Tuple = ( + (self.x, self.y, self.z) if self.grid_id is None else (self.x, self.y, self.z, self.grid_id) + ) + def __iter__(self): yield self.x yield self.y diff --git a/pathfinding3d/finder/a_star.py b/pathfinding3d/finder/a_star.py index a054799..9008220 100644 --- a/pathfinding3d/finder/a_star.py +++ b/pathfinding3d/finder/a_star.py @@ -1,4 +1,3 @@ -import heapq # used for the so colled "open list" that stores known nodes from typing import Callable, List, Optional, Tuple, Union from ..core.diagonal_movement import DiagonalMovement @@ -86,7 +85,7 @@ def check_neighbors( """ # pop node with minimum 'f' value - node = heapq.heappop(open_list) + node = open_list.pop_node() node.closed = True # if reached the end position, construct the path and return it diff --git a/pathfinding3d/finder/bi_a_star.py b/pathfinding3d/finder/bi_a_star.py index b1931e3..7798939 100644 --- a/pathfinding3d/finder/bi_a_star.py +++ b/pathfinding3d/finder/bi_a_star.py @@ -3,6 +3,7 @@ from ..core.diagonal_movement import DiagonalMovement from ..core.grid import Grid +from ..core.heap import SimpleHeap from ..core.node import GridNode from .a_star import AStarFinder from .finder import BY_END, BY_START, MAX_RUNS, TIME_LIMIT @@ -71,12 +72,12 @@ def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, i self.start_time = time.time() # execution time limitation self.runs = 0 # count number of iterations - start_open_list = [start] + start_open_list = SimpleHeap(start, grid) start.g = 0 start.f = 0 start.opened = BY_START - end_open_list = [end] + end_open_list = SimpleHeap(end, grid) end.g = 0 end.f = 0 end.opened = BY_END diff --git a/pathfinding3d/finder/breadth_first.py b/pathfinding3d/finder/breadth_first.py index 7485665..7209786 100644 --- a/pathfinding3d/finder/breadth_first.py +++ b/pathfinding3d/finder/breadth_first.py @@ -74,7 +74,7 @@ def check_neighbors( List[GridNode] path """ - node = open_list.pop(0) + node = open_list.pop_node() node.closed = True if node == end: @@ -85,6 +85,6 @@ def check_neighbors( if neighbor.closed or neighbor.opened: continue - open_list.append(neighbor) + open_list.push_node(neighbor) neighbor.opened = True neighbor.parent = node diff --git a/pathfinding3d/finder/finder.py b/pathfinding3d/finder/finder.py index 70aac2e..bffa077 100644 --- a/pathfinding3d/finder/finder.py +++ b/pathfinding3d/finder/finder.py @@ -1,9 +1,9 @@ -import heapq # used for the so colled "open list" that stores known nodes import time # for time limitation from typing import Callable, List, Optional, Tuple, Union from ..core.diagonal_movement import DiagonalMovement from ..core.grid import Grid +from ..core.heap import SimpleHeap from ..core.node import GridNode # max. amount of tries we iterate until we abort the search @@ -180,6 +180,7 @@ def process_node( ng = parent.g + grid.calc_cost(parent, node, self.weighted) if not node.opened or ng < node.g: + old_f = node.f node.g = ng node.h = node.h or self.apply_heuristic(node, end) # f is the estimated total cost from start to goal @@ -187,14 +188,14 @@ def process_node( node.parent = parent if not node.opened: - heapq.heappush(open_list, node) + open_list.push_node(node) node.opened = open_value else: # the node can be reached with smaller cost. # Since its f value has been updated, we have to # update its position in the open list - open_list.remove(node) - heapq.heappush(open_list, node) + open_list.remove_node(node, old_f) + open_list.push_node(node) def check_neighbors( self, @@ -251,7 +252,7 @@ def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, i self.runs = 0 # count number of iterations start.opened = True - open_list = [start] + open_list = SimpleHeap(start, grid) while len(open_list) > 0: self.runs += 1 diff --git a/pathfinding3d/finder/msp.py b/pathfinding3d/finder/msp.py index 51e6e11..d565a6a 100644 --- a/pathfinding3d/finder/msp.py +++ b/pathfinding3d/finder/msp.py @@ -1,10 +1,10 @@ -import heapq import time from collections import deque, namedtuple from typing import List, Tuple from ..core import heuristic from ..core.grid import Grid +from ..core.heap import SimpleHeap from ..core.node import GridNode from ..finder.finder import Finder @@ -62,23 +62,20 @@ def itertree(self, grid: Grid, start: GridNode): start.opened = True - open_list = [start] + open_list = SimpleHeap(start, grid) while len(open_list) > 0: self.runs += 1 self.keep_running() - node = heapq.nsmallest(1, open_list)[0] - open_list.remove(node) + node = open_list.pop_node() node.closed = True yield node neighbors = self.find_neighbors(grid, node) for neighbor in neighbors: if not neighbor.closed: - self.process_node( - grid, neighbor, node, end, open_list, open_value=True - ) + self.process_node(grid, neighbor, node, end, open_list, open_value=True) def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, int]: """ diff --git a/test/test_connect_grids.py b/test/test_connect_grids.py index 2c4454d..2f91e26 100644 --- a/test/test_connect_grids.py +++ b/test/test_connect_grids.py @@ -4,9 +4,9 @@ PATH = [ (2, 0, 0, 0), - (2, 0, 1, 0), - (2, 0, 2, 0), - (2, 1, 2, 0), + (2, 1, 0, 0), + (2, 2, 0, 0), + (2, 2, 1, 0), (2, 2, 2, 0), # move to grid 1 (2, 2, 2, 1), diff --git a/test/test_heap.py b/test/test_heap.py new file mode 100644 index 0000000..6d821c3 --- /dev/null +++ b/test/test_heap.py @@ -0,0 +1,26 @@ +from pathfinding3d.core.grid import Grid +from pathfinding3d.core.heap import SimpleHeap + + +def test_heap(): + grid = Grid(width=10, height=10, depth=10) + start = grid.node(0, 0, 0) + open_list = SimpleHeap(start, grid) + + # Test pop + assert open_list.pop_node() == start + assert len(open_list) == 0 + + # Test push + open_list.push_node(grid.node(1, 1, 1)) + open_list.push_node(grid.node(1, 1, 2)) + open_list.push_node(grid.node(1, 1, 3)) + + # Test removal and pop + assert len(open_list) == 3 + open_list.remove_node(grid.node(1, 1, 2), 0) + assert len(open_list) == 3 + + assert open_list.pop_node() == grid.node(1, 1, 1) + assert open_list.pop_node() == grid.node(1, 1, 3) + assert len(open_list) == 0 diff --git a/test/test_path.py b/test/test_path.py index fc838c6..f429883 100644 --- a/test/test_path.py +++ b/test/test_path.py @@ -13,7 +13,6 @@ from pathfinding3d.finder.ida_star import IDAStarFinder from pathfinding3d.finder.msp import MinimumSpanningTree -# test scenarios from Pathfinding.JS finders = [ AStarFinder, BestFirst, @@ -31,7 +30,6 @@ DijkstraFinder, MinimumSpanningTree, ] -TIME_LIMIT = 10 # give it a 10 second limit SIMPLE_MATRIX = np.zeros((5, 5, 5)) SIMPLE_MATRIX[0, 0, 0] = 1 @@ -60,7 +58,7 @@ def test_path(): """ - test scenarios defined in json file + test if we can find a path """ grid = Grid(matrix=SIMPLE_MATRIX) start = grid.node(0, 0, 0)