Skip to content

Commit

Permalink
Merge pull request #10 from harisankar95/theta-star
Browse files Browse the repository at this point in the history
Theta star
  • Loading branch information
harisankar95 authored Jan 27, 2024
2 parents f67478b + 4d45495 commit 282a94b
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 17 deletions.
1 change: 1 addition & 0 deletions pathfinding3d/core/diagonal_movement.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
class DiagonalMovement:
"""Enum for diagonal movement"""

always = 1
never = 2
if_at_most_one_obstacle = 3
Expand Down
1 change: 1 addition & 0 deletions pathfinding3d/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Grid:
"""
A grid represents the map (as 3d-list of nodes).
"""

def __init__(
self,
width: int = 0,
Expand Down
1 change: 1 addition & 0 deletions pathfinding3d/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Node:
tested : bool
Used for IDA* and Jump-Point-Search.
"""

__slots__ = ["h", "g", "f", "opened", "closed", "parent", "retain_count", "tested"]

def __init__(self):
Expand Down
82 changes: 82 additions & 0 deletions pathfinding3d/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,85 @@ def smoothen_path(grid: Grid, path: List[Coords], use_raytrace: bool = False) ->

new_path.append(path[-1])
return new_path


def line_of_sight(grid: Grid, node1: GridNode, node2: GridNode) -> bool:
"""
Check if there is a line of sight between two nodes using Bresenham's algorithm
Parameters
----------
grid : Grid
The grid on which the nodes exist
node1 : GridNode
The first node
node2 : GridNode
The second node
Returns
-------
bool
True if there is a line of sight between the two nodes, False otherwise
"""
x0, y0, z0 = node1.x, node1.y, node1.z
x1, y1, z1 = node2.x, node2.y, node2.z

dx = abs(x1 - x0)
dy = abs(y1 - y0)
dz = abs(z1 - z0)
sx = 1 if x0 < x1 else -1
sy = 1 if y0 < y1 else -1
sz = 1 if z0 < z1 else -1

# Driving axis is X-axis
if dx >= dy and dx >= dz:
err_1 = 2 * dy - dx
err_2 = 2 * dz - dx
while x0 != x1:
x0 += sx
if err_1 > 0:
y0 += sy
err_1 -= 2 * dx
if err_2 > 0:
z0 += sz
err_2 -= 2 * dx
err_1 += 2 * dy
err_2 += 2 * dz
if not grid.walkable(x0, y0, z0):
return False

# Driving axis is Y-axis
elif dy >= dx and dy >= dz:
err_1 = 2 * dx - dy
err_2 = 2 * dz - dy
while y0 != y1:
y0 += sy
if err_1 > 0:
x0 += sx
err_1 -= 2 * dy
if err_2 > 0:
z0 += sz
err_2 -= 2 * dy
err_1 += 2 * dx
err_2 += 2 * dz
if not grid.walkable(x0, y0, z0):
return False

# Driving axis is Z-axis
else:
err_1 = 2 * dy - dz
err_2 = 2 * dx - dz
while z0 != z1:
z0 += sz
if err_1 > 0:
y0 += sy
err_1 -= 2 * dz
if err_2 > 0:
x0 += sx
err_2 -= 2 * dz
err_1 += 2 * dy
err_2 += 2 * dx
if not grid.walkable(x0, y0, z0):
return False

return True
1 change: 1 addition & 0 deletions pathfinding3d/core/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class World:
"""
A world connects grids but can have multiple grids.
"""

def __init__(self, grids: Dict[int, Grid]):
"""
Initialize a new world.
Expand Down
1 change: 1 addition & 0 deletions pathfinding3d/finder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
"finder",
"ida_star",
"msp",
"theta_star",
]
100 changes: 100 additions & 0 deletions pathfinding3d/finder/theta_star.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
from typing import Callable, List, Union

from pathfinding3d.core.diagonal_movement import DiagonalMovement
from pathfinding3d.core.grid import Grid
from pathfinding3d.core.node import GridNode
from pathfinding3d.core.util import line_of_sight
from pathfinding3d.finder.a_star import AStarFinder
from pathfinding3d.finder.finder import MAX_RUNS, TIME_LIMIT


class ThetaStarFinder(AStarFinder):
def __init__(
self,
heuristic: Callable = None,
weight: int = 1,
diagonal_movement: int = DiagonalMovement.always,
time_limit: float = TIME_LIMIT,
max_runs: Union[int, float] = MAX_RUNS,
):
"""
Find shortest path using Theta* algorithm
Diagonal movement is forced to always.
Parameters
----------
heuristic : Callable
heuristic used to calculate distance of 2 points
weight : int
weight for the edges
diagonal_movement : int
if diagonal movement is allowed
(see enum in diagonal_movement)
time_limit : float
max. runtime in seconds
max_runs : int
max. amount of tries until we abort the search
(optional, only if we enter huge grids and have time constrains)
<=0 means there are no constrains and the code might run on any
large map.
"""

if diagonal_movement != DiagonalMovement.always:
logging.warning("Diagonal movement is forced to always for Theta*")
diagonal_movement = DiagonalMovement.always

super().__init__(
heuristic=heuristic,
weight=weight,
diagonal_movement=diagonal_movement,
time_limit=time_limit,
max_runs=max_runs,
)

def process_node(
self,
grid: Grid,
node: GridNode,
parent: GridNode,
end: GridNode,
open_list: List,
open_value: int = 1,
):
"""
Check if we can reach the grandparent node directly from the current node
and if so, skip the parent.
Parameters
----------
grid : Grid
grid that stores all possible steps/tiles as 3D-list
node : GridNode
the node we like to test
parent : GridNode
the parent node (of the current node we like to test)
end : GridNode
the end point to calculate the cost of the path
open_list : List
the list that keeps track of our current path
open_value : bool
needed if we like to set the open list to something
else than True (used for bi-directional algorithms)
"""
# Check for line of sight to the grandparent
if parent and parent.parent and line_of_sight(grid, node, parent.parent):
ng = parent.parent.g + grid.calc_cost(parent.parent, node, self.weighted)
if not node.opened or ng < node.g:
old_f = node.f
node.g = ng
node.h = node.h or self.apply_heuristic(node, end)
node.f = node.g + node.h
node.parent = parent.parent
if not node.opened:
open_list.push_node(node)
node.opened = open_value
else:
open_list.remove_node(node, old_f)
open_list.push_node(node)
else:
super().process_node(grid, node, parent, end, open_list)
13 changes: 0 additions & 13 deletions test/test_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,3 @@ def test_check_neighbors_raises_exception():

with pytest.raises(NotImplementedError):
finder.check_neighbors(start, end, grid, open_list)


def test_msp():
"""
Test that the minimum spanning tree finder returns all nodes.
"""
matrix = np.array(np.ones((3, 3, 3)))
grid = Grid(matrix=matrix)

start = grid.node(0, 0, 0)

finder = MinimumSpanningTree()
assert finder.tree(grid, start).sort() == [node for row in grid.nodes for col in row for node in col].sort()
24 changes: 23 additions & 1 deletion test/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathfinding3d.core.diagonal_movement import DiagonalMovement
from pathfinding3d.core.grid import Grid
from pathfinding3d.core.node import GridNode
from pathfinding3d.core.util import expand_path
from pathfinding3d.finder.a_star import AStarFinder
from pathfinding3d.finder.best_first import BestFirst
from pathfinding3d.finder.bi_a_star import BiAStarFinder
Expand All @@ -12,6 +13,7 @@
from pathfinding3d.finder.finder import ExecutionRunsException, ExecutionTimeException
from pathfinding3d.finder.ida_star import IDAStarFinder
from pathfinding3d.finder.msp import MinimumSpanningTree
from pathfinding3d.finder.theta_star import ThetaStarFinder

finders = [
AStarFinder,
Expand All @@ -21,6 +23,7 @@
IDAStarFinder,
BreadthFirstFinder,
MinimumSpanningTree,
ThetaStarFinder,
]
TIME_LIMIT = 10 # give it a 10 second limit.

Expand All @@ -29,6 +32,7 @@
BiAStarFinder,
DijkstraFinder,
MinimumSpanningTree,
ThetaStarFinder,
]

SIMPLE_MATRIX = np.zeros((5, 5, 5))
Expand Down Expand Up @@ -64,6 +68,8 @@ def test_path():
start = grid.node(0, 0, 0)
end = grid.node(4, 4, 0)
for find in finders:
if find == ThetaStarFinder:
continue
grid.cleanup()
finder = find(time_limit=TIME_LIMIT)
path_, runs = finder.find_path(start, end, grid)
Expand All @@ -84,6 +90,8 @@ def test_weighted_path():
start = grid.node(0, 0, 0)
end = grid.node(4, 4, 0)
for find in weighted_finders:
if find == ThetaStarFinder:
continue
grid.cleanup()
finder = find(time_limit=TIME_LIMIT)
path_, runs = finder.find_path(start, end, grid)
Expand Down Expand Up @@ -114,10 +122,11 @@ def test_path_diagonal():
path.append((node.x, node.y, node.z))
elif isinstance(node, tuple):
path.append((node[0], node[1], node[2]))

print(find.__name__)
print(f"path: {path}")
print(f"length: {len(path)}, runs: {runs}")
if find == ThetaStarFinder:
path = expand_path(path)
assert len(path) == 5


Expand Down Expand Up @@ -149,3 +158,16 @@ def test_time():
print(f"path: {path}")
msg = f"{finder.__class__.__name__} took too long"
assert finder.runs == 1, msg


def test_msp():
"""
Test that the minimum spanning tree finder returns all nodes.
"""
matrix = np.array(np.ones((3, 3, 3)))
grid = Grid(matrix=matrix)

start = grid.node(0, 0, 0)

finder = MinimumSpanningTree()
assert finder.tree(grid, start).sort() == [node for row in grid.nodes for col in row for node in col].sort()
56 changes: 53 additions & 3 deletions test/test_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import pytest

from pathfinding3d.core.grid import Grid
from pathfinding3d.core.util import bresenham, expand_path, raytrace, smoothen_path
from pathfinding3d.core.node import GridNode
from pathfinding3d.core.util import (
bresenham,
expand_path,
line_of_sight,
raytrace,
smoothen_path,
)


def test_bresenham():
Expand Down Expand Up @@ -80,10 +89,10 @@ def test_expand_path():
test expand_path function
"""
# Test with empty path
assert expand_path([]) == []
assert not expand_path([])

# Test with one point path
assert expand_path([[0, 0, 0]]) == []
assert not expand_path([[0, 0, 0]])

# Test with two points path
assert expand_path([[0, 0, 0], [1, 1, 1]]) == [[0, 0, 0], [1, 1, 1]]
Expand All @@ -96,3 +105,44 @@ def test_expand_path():
[3, 2, 2],
[4, 2, 2],
]


@pytest.fixture
def grid():
# Create a 5x5x5 grid with all nodes walkable
grid = Grid(5, 5, 5)
for x in range(5):
for y in range(5):
for z in range(5):
grid.nodes[x][y][z].walkable = True
return grid


def test_line_of_sight_self(grid):
"""
test line_of_sight function with self
"""
# Test with self
node = GridNode(0, 0, 0)
assert line_of_sight(grid, node, node)


def test_line_of_sight_clear(grid):
"""
test line_of_sight function with clear line of sight
"""
# Test with clear line of sight
node1 = GridNode(0, 0, 0)
node2 = GridNode(4, 4, 4)
assert line_of_sight(grid, node1, node2)


def test_line_of_sight_obstacle(grid):
"""
test line_of_sight function with obstacle
"""
# Test with obstacle
node1 = GridNode(0, 0, 0)
node2 = GridNode(4, 4, 4)
grid.nodes[2][2][2].walkable = False
assert not line_of_sight(grid, node1, node2)

0 comments on commit 282a94b

Please sign in to comment.