Skip to content

Commit

Permalink
generalizing spatial bins with rtree
Browse files Browse the repository at this point in the history
  • Loading branch information
fcollman committed Feb 18, 2024
1 parent e6398e8 commit afbc23d
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions python/neuroglancer/write_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import struct
from collections import defaultdict
from collections.abc import Sequence
from itertools import product
from typing import Literal, NamedTuple, Optional, SupportsInt, Union, cast
import math
import numpy as np
import rtree

try:
import tensorstore as ts
Expand Down Expand Up @@ -57,7 +59,6 @@ class Annotation(NamedTuple):
id: int
encoded: bytes
relationships: Sequence[Sequence[int]]
geometry: Sequence[float]


_PROPERTY_DTYPES: dict[
Expand Down Expand Up @@ -289,6 +290,9 @@ def __init__(
shape=(self.rank,), fill_value=float("-inf"), dtype=np.float32
)
self.related_annotations = [{} for _ in self.relationships]
p = rtree.index.Property()
p.dimension = self.rank
self.rtree = rtree.index.Index(properties=p)

def get_chunk_index(self, coords):
return tuple((coords // self.chunk_size).astype(np.int32))
Expand Down Expand Up @@ -420,9 +424,15 @@ def _add_obj(
id = len(self.annotations)

annotation = Annotation(
id=id, encoded=encoded.tobytes(), relationships=related_ids, geometry=coords
id=id, encoded=encoded.tobytes(), relationships=related_ids
)

spatial_points = coords[: n_spatial_coords * self.rank]
spatial_points = np.reshape(spatial_points, (self.rank, n_spatial_coords))
lower_bound = np.min(spatial_points, axis=1)
upper_bound = np.max(spatial_points, axis=1)
self.rtree.insert(id, tuple(lower_bound) + tuple(upper_bound), obj=annotation)

# for i in range(int(n_spatial_coords)):
# chunk_index = self.get_chunk_index(
# np.array(coords[i * self.rank : (i + 1) * self.rank])
Expand Down Expand Up @@ -550,11 +560,24 @@ def write(self, path: Union[str, pathlib.Path]):
os.makedirs(os.path.join(path, f"rel_{relationship}"), exist_ok=True)
os.makedirs(os.path.join(path, "by_id"), exist_ok=True)
os.makedirs(os.path.join(path, "spatial0"), exist_ok=True)
for ann in self.annotations:
if self.annotation_type == "point":
# get the first self.rank elements of the geometry array
chunk_index = self.get_chunk_index(ann.geometry - self.lower_bound)
self.annotations_by_chunk[chunk_index].append(ann)

# Generate all combinations of coordinates
coordinates = product(*(range(n) for n in num_chunks))

# Iterate over the grid
for cell in coordinates:
# Query the rtree index for annotations in the current chunk
lower_bound = self.lower_bound + np.array(cell) * self.chunk_size
upper_bound = lower_bound + self.chunk_size
coords = np.concatenate((lower_bound, upper_bound))
chunk_annotations = self.rtree.intersection(tuple(coords), objects="raw")
self.annotations_by_chunk[cell] = list(chunk_annotations)

# for ann in self.annotations:
# if self.annotation_type == "point":
# # get the first self.rank elements of the geometry array
# chunk_index = self.get_chunk_index(ann.geometry - self.lower_bound)
# self.annotations_by_chunk[chunk_index].append(ann)

total_chunks = len(self.annotations_by_chunk)
spatial_sharding_spec = choose_output_spec(
Expand Down

0 comments on commit afbc23d

Please sign in to comment.