Skip to content

Commit

Permalink
Merge pull request #5 from harisankar95/update-heap
Browse files Browse the repository at this point in the history
Update heap
  • Loading branch information
harisankar95 authored Jan 20, 2024
2 parents 868ef4d + 9faee0c commit 8588618
Show file tree
Hide file tree
Showing 11 changed files with 206 additions and 31 deletions.
8 changes: 4 additions & 4 deletions examples/03_view_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
elif isinstance(node, tuple):
path.append([node[0], node[1], node[2]])
print(f"path: {path}")


# visualize path in open3d
if USE_OPEN3D:
Expand All @@ -51,16 +51,16 @@
xyz_pt = np.stack(obstacle_indices, axis=-1).astype(float)
colors = np.zeros((xyz_pt.shape[0], 3))
colors[:, 2] = obstacle_indices[2] / np.max(obstacle_indices[2])

# Prepare start and end colors
start_color = np.array([[1.0, 0, 0]]) # Red
end_color = np.array([[0, 1.0, 0]]) # Green
end_color = np.array([[0, 1.0, 0]]) # Green
path_colors = np.full((len(path) - 2, 3), [0.7, 0.7, 0.7]) # Grey for the path

# Combine points and colors
xyz_pt = np.concatenate((xyz_pt, [start_pt], [end_pt], path[1:-1]))
colors = np.concatenate((colors, start_color, end_color, path_colors))

# Create and visualize the point cloud
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(xyz_pt)
Expand Down
143 changes: 143 additions & 0 deletions pathfinding3d/core/heap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
Simple heap with ordering and removal.
Inspired from https://github.com/brean/python-pathfinding/pull/54
Original author: https://github.com/peterchenadded
"""
import heapq
from typing import Callable, Union

from .grid import Grid
from .node import GridNode
from .world import World


class SimpleHeap:
"""
A simple implementation of a heap data structure optimized for pathfinding.
It maintains an open list of nodes, a status for each node, and a function to retrieve nodes.
"""

def __init__(self, node: GridNode, grid: Union[Grid, World]):
"""
Initializes the SimpleHeap with a given node and grid.
Parameters
----------
node : GridNode
The initial node to be added to the heap. This node should have an 'f' attribute representing its cost.
grid : Union[Grid, World]
The grid in which the nodes are located.
"""

self.grid = grid
self._get_node_tuple = self._determine_node_retrieval_function()
self._get_node = self._determine_node_function()
self.open_list = [self._get_node_tuple(node, 0)]
self.removed_node_tuples = set()
self.heap_order = {}
self.number_pushed = 0

def _determine_node_retrieval_function(self) -> Callable:
"""
Determines the node retrieval function based on the type of grid.
Returns
-------
function
A function that takes a node tuple and returns the corresponding node.
Raises
------
ValueError
If the grid is not of type Grid or World.
"""
if isinstance(self.grid, Grid):
return lambda node, heap_order: (node.f, heap_order, *node.identifier)

if isinstance(self.grid, World):
return lambda node, heap_order: (node.f, heap_order, *node.identifier)

raise ValueError("Unsupported grid type")

def _determine_node_function(self) -> Callable:
"""
Determines the node function based on the type of grid.
Returns
-------
function
A function that takes a node tuple and returns the corresponding node.
Raises
------
ValueError
If the grid is not of type Grid or World.
"""

if isinstance(self.grid, Grid):
return lambda node_tuple: self.grid.node(*node_tuple[2:])

if isinstance(self.grid, World):
return lambda node_tuple: self.grid.grids[node_tuple[5]].node(*node_tuple[2:5])

raise ValueError("Unsupported grid type")

def pop_node(self) -> GridNode:
"""
Pops the node with the lowest cost from the heap.
Returns
-------
GridNode
The node with the lowest cost.
"""
node_tuple = heapq.heappop(self.open_list)
while node_tuple in self.removed_node_tuples:
node_tuple = heapq.heappop(self.open_list)

return self._get_node(node_tuple)

def push_node(self, node: GridNode):
"""
Pushes a node to the heap.
Parameters
----------
node : GridNode
The node to be pushed to the heap.
"""
self.number_pushed = self.number_pushed + 1
node_tuple = self._get_node_tuple(node, self.number_pushed)

self.heap_order[node.identifier] = self.number_pushed

heapq.heappush(self.open_list, node_tuple)

def remove_node(self, node: GridNode, old_f: float):
"""
Remove the node from the heap.
This just stores it in a set and we just ignore the node if it does
get popped from the heap.
Parameters
----------
node : GridNode
The node to be removed from the heap.
old_f: float
The old cost of the node.
"""
heap_order = self.heap_order[node.identifier]
node_tuple = self._get_node_tuple(node, heap_order)
self.removed_node_tuples.add(node_tuple)

def __len__(self) -> int:
"""
Returns the length of the heap.
Returns
-------
int
The length of the heap.
"""
return len(self.open_list)
16 changes: 13 additions & 3 deletions pathfinding3d/core/node.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dataclasses
from typing import List, Optional
from typing import List, Optional, Tuple


@dataclasses.dataclass
class Node:
def __post_init__(self):
# values used in the finder
__slots__ = ["h", "g", "f", "opened", "closed", "parent", "retain_count", "tested"]

def __init__(self):
self.cleanup()

def __lt__(self, other: "Node") -> bool:
Expand Down Expand Up @@ -63,6 +64,15 @@ class GridNode(Node):

connections: Optional[List] = None

identifier: Optional[Tuple] = None

def __post_init__(self):
super().__init__()
# for heap
self.identifier: Tuple = (
(self.x, self.y, self.z) if self.grid_id is None else (self.x, self.y, self.z, self.grid_id)
)

def __iter__(self):
yield self.x
yield self.y
Expand Down
3 changes: 1 addition & 2 deletions pathfinding3d/finder/a_star.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import heapq # used for the so colled "open list" that stores known nodes
from typing import Callable, List, Optional, Tuple, Union

from ..core.diagonal_movement import DiagonalMovement
Expand Down Expand Up @@ -86,7 +85,7 @@ def check_neighbors(
"""

# pop node with minimum 'f' value
node = heapq.heappop(open_list)
node = open_list.pop_node()
node.closed = True

# if reached the end position, construct the path and return it
Expand Down
5 changes: 3 additions & 2 deletions pathfinding3d/finder/bi_a_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ..core.diagonal_movement import DiagonalMovement
from ..core.grid import Grid
from ..core.heap import SimpleHeap
from ..core.node import GridNode
from .a_star import AStarFinder
from .finder import BY_END, BY_START, MAX_RUNS, TIME_LIMIT
Expand Down Expand Up @@ -71,12 +72,12 @@ def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, i
self.start_time = time.time() # execution time limitation
self.runs = 0 # count number of iterations

start_open_list = [start]
start_open_list = SimpleHeap(start, grid)
start.g = 0
start.f = 0
start.opened = BY_START

end_open_list = [end]
end_open_list = SimpleHeap(end, grid)
end.g = 0
end.f = 0
end.opened = BY_END
Expand Down
4 changes: 2 additions & 2 deletions pathfinding3d/finder/breadth_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def check_neighbors(
List[GridNode]
path
"""
node = open_list.pop(0)
node = open_list.pop_node()
node.closed = True

if node == end:
Expand All @@ -85,6 +85,6 @@ def check_neighbors(
if neighbor.closed or neighbor.opened:
continue

open_list.append(neighbor)
open_list.push_node(neighbor)
neighbor.opened = True
neighbor.parent = node
11 changes: 6 additions & 5 deletions pathfinding3d/finder/finder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import heapq # used for the so colled "open list" that stores known nodes
import time # for time limitation
from typing import Callable, List, Optional, Tuple, Union

from ..core.diagonal_movement import DiagonalMovement
from ..core.grid import Grid
from ..core.heap import SimpleHeap
from ..core.node import GridNode

# max. amount of tries we iterate until we abort the search
Expand Down Expand Up @@ -180,21 +180,22 @@ def process_node(
ng = parent.g + grid.calc_cost(parent, node, self.weighted)

if not node.opened or ng < node.g:
old_f = node.f
node.g = ng
node.h = node.h or self.apply_heuristic(node, end)
# f is the estimated total cost from start to goal
node.f = node.g + node.h
node.parent = parent

if not node.opened:
heapq.heappush(open_list, node)
open_list.push_node(node)
node.opened = open_value
else:
# the node can be reached with smaller cost.
# Since its f value has been updated, we have to
# update its position in the open list
open_list.remove(node)
heapq.heappush(open_list, node)
open_list.remove_node(node, old_f)
open_list.push_node(node)

def check_neighbors(
self,
Expand Down Expand Up @@ -251,7 +252,7 @@ def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, i
self.runs = 0 # count number of iterations
start.opened = True

open_list = [start]
open_list = SimpleHeap(start, grid)

while len(open_list) > 0:
self.runs += 1
Expand Down
11 changes: 4 additions & 7 deletions pathfinding3d/finder/msp.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import heapq
import time
from collections import deque, namedtuple
from typing import List, Tuple

from ..core import heuristic
from ..core.grid import Grid
from ..core.heap import SimpleHeap
from ..core.node import GridNode
from ..finder.finder import Finder

Expand Down Expand Up @@ -62,23 +62,20 @@ def itertree(self, grid: Grid, start: GridNode):

start.opened = True

open_list = [start]
open_list = SimpleHeap(start, grid)

while len(open_list) > 0:
self.runs += 1
self.keep_running()

node = heapq.nsmallest(1, open_list)[0]
open_list.remove(node)
node = open_list.pop_node()
node.closed = True
yield node

neighbors = self.find_neighbors(grid, node)
for neighbor in neighbors:
if not neighbor.closed:
self.process_node(
grid, neighbor, node, end, open_list, open_value=True
)
self.process_node(grid, neighbor, node, end, open_list, open_value=True)

def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, int]:
"""
Expand Down
6 changes: 3 additions & 3 deletions test/test_connect_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

PATH = [
(2, 0, 0, 0),
(2, 0, 1, 0),
(2, 0, 2, 0),
(2, 1, 2, 0),
(2, 1, 0, 0),
(2, 2, 0, 0),
(2, 2, 1, 0),
(2, 2, 2, 0),
# move to grid 1
(2, 2, 2, 1),
Expand Down
26 changes: 26 additions & 0 deletions test/test_heap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathfinding3d.core.grid import Grid
from pathfinding3d.core.heap import SimpleHeap


def test_heap():
grid = Grid(width=10, height=10, depth=10)
start = grid.node(0, 0, 0)
open_list = SimpleHeap(start, grid)

# Test pop
assert open_list.pop_node() == start
assert len(open_list) == 0

# Test push
open_list.push_node(grid.node(1, 1, 1))
open_list.push_node(grid.node(1, 1, 2))
open_list.push_node(grid.node(1, 1, 3))

# Test removal and pop
assert len(open_list) == 3
open_list.remove_node(grid.node(1, 1, 2), 0)
assert len(open_list) == 3

assert open_list.pop_node() == grid.node(1, 1, 1)
assert open_list.pop_node() == grid.node(1, 1, 3)
assert len(open_list) == 0
Loading

0 comments on commit 8588618

Please sign in to comment.