diff --git a/pathfinding3d/core/grid.py b/pathfinding3d/core/grid.py index 915ff2d..334c099 100644 --- a/pathfinding3d/core/grid.py +++ b/pathfinding3d/core/grid.py @@ -1,22 +1,19 @@ import math -from typing import List, Optional +from typing import List, Optional, Union + +import numpy as np from .diagonal_movement import DiagonalMovement from .node import GridNode -try: - import numpy as np - - USE_NUMPY = True -except ImportError: - USE_NUMPY = False +MatrixType = Optional[Union[List[List[List[int]]], np.ndarray]] def build_nodes( width: int, height: int, depth: int, - matrix: Optional[List] = None, + matrix: MatrixType = None, inverse: bool = False, grid_id: Optional[int] = None, ) -> List[List[List[GridNode]]]: @@ -32,7 +29,7 @@ def build_nodes( The height of the grid. depth : int The depth of the grid. - matrix : list, optional + matrix : MatrixType A 3D array of values (numbers or objects specifying weight) that determine how nodes are connected and if they are walkable. If no matrix is given, all nodes will be walkable. @@ -44,13 +41,11 @@ def build_nodes( Returns ------- - list + List A list of list of lists containing the nodes in the grid. """ nodes: List = [] - use_matrix = (isinstance(matrix, (tuple, list))) or ( - USE_NUMPY and isinstance(matrix, np.ndarray) and matrix.size > 0 - ) + use_matrix = matrix is not None for x in range(width): nodes.append([]) @@ -66,11 +61,7 @@ def build_nodes( weight = int(matrix[x][y][z]) if use_matrix else 1 walkable = weight <= 0 if inverse else weight >= 1 - nodes[x][y].append( - GridNode( - x=x, y=y, z=z, walkable=walkable, weight=weight, grid_id=grid_id - ) - ) + nodes[x][y].append(GridNode(x=x, y=y, z=z, walkable=walkable, weight=weight, grid_id=grid_id)) return nodes @@ -80,7 +71,7 @@ def __init__( width: int = 0, height: int = 0, depth: int = 0, - matrix: Optional[List] = None, + matrix: MatrixType = None, grid_id: Optional[int] = None, inverse: bool = False, ): @@ -95,7 +86,7 @@ def __init__( The height of the grid. depth : int, optional The depth of the grid. - matrix : list, optional + matrix : MatrixType A 3D array of values (numbers or objects specifying weight) that determine how nodes are connected and if they are walkable. If no matrix is given, all nodes will be walkable. @@ -103,21 +94,27 @@ def __init__( If true, all values in the matrix that are not 0 will be considered walkable. Otherwise all values that are 0 will be considered walkable. """ - self.width = width - self.height = height - self.depth = depth - if isinstance(matrix, (tuple, list)) or ( - USE_NUMPY and isinstance(matrix, np.ndarray) and (matrix.size > 0) - ): - self.width = len(matrix) - self.height = len(matrix[0]) if self.width > 0 else 0 - self.depth = len(matrix[0][0]) if self.height > 0 else 0 - if self.width > 0 and self.height > 0 and self.depth > 0: - self.nodes = build_nodes( - self.width, self.height, self.depth, matrix, inverse, grid_id - ) - else: - self.nodes = [[[]]] + self.width, self.height, self.depth = self._validate_dimensions(width, height, depth, matrix) + self.nodes = ( + build_nodes(self.width, self.height, self.depth, matrix, inverse, grid_id) + if self.is_valid_grid() + else [[[]]] + ) + + def _validate_dimensions(self, width: int, height: int, depth: int, matrix: MatrixType) -> tuple: + if matrix is not None: + if not ( + isinstance(matrix, (list, np.ndarray)) + and len(matrix) > 0 + and len(matrix[0]) > 0 + and len(matrix[0][0]) > 0 + ): + raise ValueError("Provided matrix is not a 3D structure or is empty.") + return len(matrix), len(matrix[0]), len(matrix[0][0]) + return width, height, depth + + def is_valid_grid(self) -> bool: + return self.width > 0 and self.height > 0 and self.depth > 0 def node(self, x: int, y: int, z: int) -> Optional[GridNode]: """ @@ -179,9 +176,7 @@ def walkable(self, x: int, y: int, z: int) -> bool: """ return self.inside(x, y, z) and self.nodes[x][y][z].walkable - def calc_cost( - self, node_a: GridNode, node_b: GridNode, weighted: bool = False - ) -> float: + def calc_cost(self, node_a: GridNode, node_b: GridNode, weighted: bool = False) -> float: """ Get the distance between current node and the neighbor (cost) diff --git a/pathfinding3d/finder/a_star.py b/pathfinding3d/finder/a_star.py index ca25dbe..a054799 100644 --- a/pathfinding3d/finder/a_star.py +++ b/pathfinding3d/finder/a_star.py @@ -86,8 +86,7 @@ def check_neighbors( """ # pop node with minimum 'f' value - node = heapq.nsmallest(1, open_list)[0] - open_list.remove(node) + node = heapq.heappop(open_list) node.closed = True # if reached the end position, construct the path and return it diff --git a/test/test_connect_grids.py b/test/test_connect_grids.py index 2f91e26..2c4454d 100644 --- a/test/test_connect_grids.py +++ b/test/test_connect_grids.py @@ -4,9 +4,9 @@ PATH = [ (2, 0, 0, 0), - (2, 1, 0, 0), - (2, 2, 0, 0), - (2, 2, 1, 0), + (2, 0, 1, 0), + (2, 0, 2, 0), + (2, 1, 2, 0), (2, 2, 2, 0), # move to grid 1 (2, 2, 2, 1),