Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update heap #5

Merged
merged 6 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/03_view_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
elif isinstance(node, tuple):
path.append([node[0], node[1], node[2]])
print(f"path: {path}")


# visualize path in open3d
if USE_OPEN3D:
Expand All @@ -51,16 +51,16 @@
xyz_pt = np.stack(obstacle_indices, axis=-1).astype(float)
colors = np.zeros((xyz_pt.shape[0], 3))
colors[:, 2] = obstacle_indices[2] / np.max(obstacle_indices[2])

# Prepare start and end colors
start_color = np.array([[1.0, 0, 0]]) # Red
end_color = np.array([[0, 1.0, 0]]) # Green
end_color = np.array([[0, 1.0, 0]]) # Green
path_colors = np.full((len(path) - 2, 3), [0.7, 0.7, 0.7]) # Grey for the path

# Combine points and colors
xyz_pt = np.concatenate((xyz_pt, [start_pt], [end_pt], path[1:-1]))
colors = np.concatenate((colors, start_color, end_color, path_colors))

# Create and visualize the point cloud
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(xyz_pt)
Expand Down
143 changes: 143 additions & 0 deletions pathfinding3d/core/heap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
Simple heap with ordering and removal.
Inspired from https://github.com/brean/python-pathfinding/pull/54
Original author: https://github.com/peterchenadded
"""
import heapq
from typing import Callable, Union

from .grid import Grid
from .node import GridNode
from .world import World


class SimpleHeap:
"""
A simple implementation of a heap data structure optimized for pathfinding.
It maintains an open list of nodes, a status for each node, and a function to retrieve nodes.
"""

def __init__(self, node: GridNode, grid: Union[Grid, World]):
"""
Initializes the SimpleHeap with a given node and grid.

Parameters
----------
node : GridNode
The initial node to be added to the heap. This node should have an 'f' attribute representing its cost.
grid : Union[Grid, World]
The grid in which the nodes are located.
"""

self.grid = grid
self._get_node_tuple = self._determine_node_retrieval_function()
self._get_node = self._determine_node_function()
self.open_list = [self._get_node_tuple(node, 0)]
self.removed_node_tuples = set()
self.heap_order = {}
self.number_pushed = 0

def _determine_node_retrieval_function(self) -> Callable:
"""
Determines the node retrieval function based on the type of grid.

Returns
-------
function
A function that takes a node tuple and returns the corresponding node.

Raises
------
ValueError
If the grid is not of type Grid or World.
"""
if isinstance(self.grid, Grid):
return lambda node, heap_order: (node.f, heap_order, *node.identifier)

if isinstance(self.grid, World):
return lambda node, heap_order: (node.f, heap_order, *node.identifier)

raise ValueError("Unsupported grid type")

Check warning on line 60 in pathfinding3d/core/heap.py

View check run for this annotation

Codecov / codecov/patch

pathfinding3d/core/heap.py#L60

Added line #L60 was not covered by tests

def _determine_node_function(self) -> Callable:
"""
Determines the node function based on the type of grid.

Returns
-------
function
A function that takes a node tuple and returns the corresponding node.

Raises
------
ValueError
If the grid is not of type Grid or World.
"""

if isinstance(self.grid, Grid):
return lambda node_tuple: self.grid.node(*node_tuple[2:])

if isinstance(self.grid, World):
return lambda node_tuple: self.grid.grids[node_tuple[5]].node(*node_tuple[2:5])

raise ValueError("Unsupported grid type")

Check warning on line 83 in pathfinding3d/core/heap.py

View check run for this annotation

Codecov / codecov/patch

pathfinding3d/core/heap.py#L83

Added line #L83 was not covered by tests

def pop_node(self) -> GridNode:
"""
Pops the node with the lowest cost from the heap.

Returns
-------
GridNode
The node with the lowest cost.
"""
node_tuple = heapq.heappop(self.open_list)
while node_tuple in self.removed_node_tuples:
node_tuple = heapq.heappop(self.open_list)

return self._get_node(node_tuple)

def push_node(self, node: GridNode):
"""
Pushes a node to the heap.

Parameters
----------
node : GridNode
The node to be pushed to the heap.
"""
self.number_pushed = self.number_pushed + 1
node_tuple = self._get_node_tuple(node, self.number_pushed)

self.heap_order[node.identifier] = self.number_pushed

heapq.heappush(self.open_list, node_tuple)

def remove_node(self, node: GridNode, old_f: float):
"""
Remove the node from the heap.

This just stores it in a set and we just ignore the node if it does
get popped from the heap.

Parameters
----------
node : GridNode
The node to be removed from the heap.
old_f: float
The old cost of the node.
"""
heap_order = self.heap_order[node.identifier]
node_tuple = self._get_node_tuple(node, heap_order)
self.removed_node_tuples.add(node_tuple)

def __len__(self) -> int:
"""
Returns the length of the heap.

Returns
-------
int
The length of the heap.
"""
return len(self.open_list)
16 changes: 13 additions & 3 deletions pathfinding3d/core/node.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dataclasses
from typing import List, Optional
from typing import List, Optional, Tuple


@dataclasses.dataclass
class Node:
def __post_init__(self):
# values used in the finder
__slots__ = ["h", "g", "f", "opened", "closed", "parent", "retain_count", "tested"]

def __init__(self):
self.cleanup()

def __lt__(self, other: "Node") -> bool:
Expand Down Expand Up @@ -63,6 +64,15 @@ class GridNode(Node):

connections: Optional[List] = None

identifier: Optional[Tuple] = None

def __post_init__(self):
super().__init__()
# for heap
self.identifier: Tuple = (
(self.x, self.y, self.z) if self.grid_id is None else (self.x, self.y, self.z, self.grid_id)
)

def __iter__(self):
yield self.x
yield self.y
Expand Down
3 changes: 1 addition & 2 deletions pathfinding3d/finder/a_star.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import heapq # used for the so colled "open list" that stores known nodes
from typing import Callable, List, Optional, Tuple, Union

from ..core.diagonal_movement import DiagonalMovement
Expand Down Expand Up @@ -86,7 +85,7 @@ def check_neighbors(
"""

# pop node with minimum 'f' value
node = heapq.heappop(open_list)
node = open_list.pop_node()
node.closed = True

# if reached the end position, construct the path and return it
Expand Down
5 changes: 3 additions & 2 deletions pathfinding3d/finder/bi_a_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ..core.diagonal_movement import DiagonalMovement
from ..core.grid import Grid
from ..core.heap import SimpleHeap
from ..core.node import GridNode
from .a_star import AStarFinder
from .finder import BY_END, BY_START, MAX_RUNS, TIME_LIMIT
Expand Down Expand Up @@ -71,12 +72,12 @@ def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, i
self.start_time = time.time() # execution time limitation
self.runs = 0 # count number of iterations

start_open_list = [start]
start_open_list = SimpleHeap(start, grid)
start.g = 0
start.f = 0
start.opened = BY_START

end_open_list = [end]
end_open_list = SimpleHeap(end, grid)
end.g = 0
end.f = 0
end.opened = BY_END
Expand Down
4 changes: 2 additions & 2 deletions pathfinding3d/finder/breadth_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def check_neighbors(
List[GridNode]
path
"""
node = open_list.pop(0)
node = open_list.pop_node()
node.closed = True

if node == end:
Expand All @@ -85,6 +85,6 @@ def check_neighbors(
if neighbor.closed or neighbor.opened:
continue

open_list.append(neighbor)
open_list.push_node(neighbor)
neighbor.opened = True
neighbor.parent = node
11 changes: 6 additions & 5 deletions pathfinding3d/finder/finder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import heapq # used for the so colled "open list" that stores known nodes
import time # for time limitation
from typing import Callable, List, Optional, Tuple, Union

from ..core.diagonal_movement import DiagonalMovement
from ..core.grid import Grid
from ..core.heap import SimpleHeap
from ..core.node import GridNode

# max. amount of tries we iterate until we abort the search
Expand Down Expand Up @@ -180,21 +180,22 @@ def process_node(
ng = parent.g + grid.calc_cost(parent, node, self.weighted)

if not node.opened or ng < node.g:
old_f = node.f
node.g = ng
node.h = node.h or self.apply_heuristic(node, end)
# f is the estimated total cost from start to goal
node.f = node.g + node.h
node.parent = parent

if not node.opened:
heapq.heappush(open_list, node)
open_list.push_node(node)
node.opened = open_value
else:
# the node can be reached with smaller cost.
# Since its f value has been updated, we have to
# update its position in the open list
open_list.remove(node)
heapq.heappush(open_list, node)
open_list.remove_node(node, old_f)
open_list.push_node(node)

def check_neighbors(
self,
Expand Down Expand Up @@ -251,7 +252,7 @@ def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, i
self.runs = 0 # count number of iterations
start.opened = True

open_list = [start]
open_list = SimpleHeap(start, grid)

while len(open_list) > 0:
self.runs += 1
Expand Down
11 changes: 4 additions & 7 deletions pathfinding3d/finder/msp.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import heapq
import time
from collections import deque, namedtuple
from typing import List, Tuple

from ..core import heuristic
from ..core.grid import Grid
from ..core.heap import SimpleHeap
from ..core.node import GridNode
from ..finder.finder import Finder

Expand Down Expand Up @@ -62,23 +62,20 @@ def itertree(self, grid: Grid, start: GridNode):

start.opened = True

open_list = [start]
open_list = SimpleHeap(start, grid)

while len(open_list) > 0:
self.runs += 1
self.keep_running()

node = heapq.nsmallest(1, open_list)[0]
open_list.remove(node)
node = open_list.pop_node()
node.closed = True
yield node

neighbors = self.find_neighbors(grid, node)
for neighbor in neighbors:
if not neighbor.closed:
self.process_node(
grid, neighbor, node, end, open_list, open_value=True
)
self.process_node(grid, neighbor, node, end, open_list, open_value=True)

def find_path(self, start: GridNode, end: GridNode, grid: Grid) -> Tuple[List, int]:
"""
Expand Down
6 changes: 3 additions & 3 deletions test/test_connect_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

PATH = [
(2, 0, 0, 0),
(2, 0, 1, 0),
(2, 0, 2, 0),
(2, 1, 2, 0),
(2, 1, 0, 0),
(2, 2, 0, 0),
(2, 2, 1, 0),
(2, 2, 2, 0),
# move to grid 1
(2, 2, 2, 1),
Expand Down
26 changes: 26 additions & 0 deletions test/test_heap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathfinding3d.core.grid import Grid
from pathfinding3d.core.heap import SimpleHeap


def test_heap():
grid = Grid(width=10, height=10, depth=10)
start = grid.node(0, 0, 0)
open_list = SimpleHeap(start, grid)

# Test pop
assert open_list.pop_node() == start
assert len(open_list) == 0

# Test push
open_list.push_node(grid.node(1, 1, 1))
open_list.push_node(grid.node(1, 1, 2))
open_list.push_node(grid.node(1, 1, 3))

# Test removal and pop
assert len(open_list) == 3
open_list.remove_node(grid.node(1, 1, 2), 0)
assert len(open_list) == 3

assert open_list.pop_node() == grid.node(1, 1, 1)
assert open_list.pop_node() == grid.node(1, 1, 3)
assert len(open_list) == 0
Loading