Skip to content

Commit

Permalink
Replace query_in_roi with query_{nodes,edges}_in_roi
Browse files Browse the repository at this point in the history
  • Loading branch information
funkey committed Oct 3, 2024
1 parent d25696d commit c98dbbd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 20 deletions.
20 changes: 4 additions & 16 deletions spatial_graph/spatial_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,11 @@ def add_edges(self, edges, **kwargs):
self._edge_rtree.insert_lines(edges, starts, ends)
super().add_edges(edges, **kwargs)

def query_in_roi(self, roi, edge_inclusion=None):
nodes = self._node_rtree.search(roi[0], roi[1])
def query_nodes_in_roi(self, roi):
return self._node_rtree.search(roi[0], roi[1])

if not edge_inclusion:
return nodes

if edge_inclusion not in SpatialGraph.edge_inclusion_values:
raise ValueError("edge_inclusion has to be in {edge_inclusion_values}")

edges = self.edges_by_nodes(nodes)

if edge_inclusion == "incident":
return nodes, edges
elif edge_inclusion == "leaving":
return nodes, [] # TODO
elif edge_inclusion == "entering":
return nodes, [] # TODO
def query_edges_in_roi(self, roi):
return self._edge_rtree.search(roi[0], roi[1])

def query_nearest_nodes(self, point, k, return_distances=False):
return self._node_rtree.nearest(point, k, return_distances)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_spatial_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ def test_roi_query():
score=np.array([0.2, 0.3, 0.4], dtype="float32"),
)

nodes, edges = graph.query_in_roi(
np.array([[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]]), edge_inclusion="incident"
)
nodes = graph.query_nodes_in_roi(np.array([[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]]))
edges = graph.query_edges_in_roi(np.array([[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]]))

assert list(sorted(nodes)) == [1, 2]
np.testing.assert_array_equal(edges, [[1, 2], [1, 5], [2, 1]])
np.testing.assert_array_equal(edges, [[1, 2], [5, 1]])

def test_delete():
graph = sg.SpatialGraph(
Expand Down

0 comments on commit c98dbbd

Please sign in to comment.