Skip to content

Commit

Permalink
voxel reassign docs
Browse files Browse the repository at this point in the history
  • Loading branch information
aelefebv committed Oct 3, 2024
1 parent 5766d60 commit 310bc06
Showing 1 changed file with 201 additions and 0 deletions.
201 changes: 201 additions & 0 deletions nellie/tracking/voxel_reassignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,73 @@


class VoxelReassigner:
"""
A class for voxel reassignment across time points using forward and backward flow interpolation.
Attributes
----------
im_info : ImInfo
An object containing image metadata and memory-mapped image data.
num_t : int
Number of timepoints in the image.
flow_interpolator_fw : FlowInterpolator
Flow interpolator for forward timepoint matching.
flow_interpolator_bw : FlowInterpolator
Flow interpolator for backward timepoint matching.
running_matches : list
List of running matches for voxel reassignment between timepoints.
voxel_matches_path : str or None
Path to save the voxel matches array.
branch_label_memmap : np.ndarray or None
Memory-mapped data for relabeled branches.
obj_label_memmap : np.ndarray or None
Memory-mapped data for object labels.
reassigned_branch_memmap : np.ndarray or None
Memory-mapped data for reassigned branches.
reassigned_obj_memmap : np.ndarray or None
Memory-mapped data for reassigned object labels.
viewer : Any
Optional viewer (e.g., for visualization purposes).
Methods
-------
_match_forward(flow_interpolator, vox_prev, vox_next, t)
Matches voxels forward using flow interpolation.
_match_backward(flow_interpolator, vox_next, vox_prev, t)
Matches voxels backward using flow interpolation.
_match_voxels_to_centroids(coords_real, coords_interpx)
Matches voxels to centroids using nearest neighbor search.
_assign_unique_matches(vox_prev_matches, vox_next_matches, distances)
Assigns unique matches between timepoint voxels based on minimum distances.
_distance_threshold(vox_prev_matched, vox_next_matched)
Filters voxel matches by applying a distance threshold.
match_voxels(vox_prev, vox_next, t)
Matches voxels between two consecutive timepoints using forward and backward interpolation.
_get_t()
Gets the number of timepoints in the dataset.
_allocate_memory()
Allocates memory for voxel reassignment data, including memory-mapped arrays.
_run_frame(t, all_mask_coords, reassigned_memmap)
Runs the voxel reassignment process for a single timepoint.
_run_reassignment(label_type)
Runs the voxel reassignment process for all frames, for either branch or object labels.
run()
Main method to execute voxel reassignment for both branch and object labels.
"""
def __init__(self, im_info: ImInfo, num_t=None,
viewer=None):
"""
Initializes the VoxelReassigner class with image metadata and timepoints.
Parameters
----------
im_info : ImInfo
Image metadata and memory-mapped data.
num_t : int, optional
Number of timepoints in the dataset. If None, it is inferred from the image metadata (default is None).
viewer : Any, optional
Optional viewer for visualization purposes (default is None).
"""
self.im_info = im_info

if self.im_info.no_t:
Expand All @@ -35,6 +100,25 @@ def __init__(self, im_info: ImInfo, num_t=None,
self.viewer = viewer

def _match_forward(self, flow_interpolator, vox_prev, vox_next, t):
"""
Matches voxels forward in time using flow interpolation.
Parameters
----------
flow_interpolator : FlowInterpolator
Flow interpolator for forward voxel matching.
vox_prev : np.ndarray
Voxels from the previous timepoint.
vox_next : np.ndarray
Voxels from the next timepoint.
t : int
Current timepoint index.
Returns
-------
tuple
Arrays of matched voxels from the previous and next timepoints and valid distances between them.
"""
vectors_interpx_prev = flow_interpolator.interpolate_coord(vox_prev, t)
if vectors_interpx_prev is None:
return [], [], []
Expand All @@ -60,6 +144,25 @@ def _match_forward(self, flow_interpolator, vox_prev, vox_next, t):
return vox_prev_matched_valid, vox_next_matched_valid, distances_valid

def _match_backward(self, flow_interpolator, vox_next, vox_prev, t):
"""
Matches voxels backward in time using flow interpolation.
Parameters
----------
flow_interpolator : FlowInterpolator
Flow interpolator for backward voxel matching.
vox_next : np.ndarray
Voxels from the next timepoint.
vox_prev : np.ndarray
Voxels from the previous timepoint.
t : int
Current timepoint index.
Returns
-------
tuple
Arrays of matched voxels from the previous and next timepoints and valid distances between them.
"""
# interpolate flow vectors to all voxels in t1 from centroids derived from t0 centroids + t0 flow vectors
vectors_interpx_prev = flow_interpolator.interpolate_coord(vox_next, t)
if vectors_interpx_prev is None:
Expand All @@ -85,13 +188,45 @@ def _match_backward(self, flow_interpolator, vox_next, vox_prev, t):
return vox_prev_matched_valid, vox_next_matched_valid, distances_valid

def _match_voxels_to_centroids(self, coords_real, coords_interpx):
"""
Matches real voxel coordinates to interpolated centroids using nearest neighbor search.
Parameters
----------
coords_real : np.ndarray
Real voxel coordinates.
coords_interpx : np.ndarray
Interpolated centroid coordinates.
Returns
-------
tuple
Arrays of distances and indices of matched centroids.
"""
coords_interpx = np.array(coords_interpx) * self.flow_interpolator_fw.scaling
coords_real = np.array(coords_real) * self.flow_interpolator_fw.scaling
tree = cKDTree(coords_real)
dist, idx = tree.query(coords_interpx, k=1, workers=-1)
return dist, idx

def _assign_unique_matches(self, vox_prev_matches, vox_next_matches, distances):
"""
Assigns unique voxel matches based on the minimum distance criteria.
Parameters
----------
vox_prev_matches : np.ndarray
Array of matched voxels from the previous timepoint.
vox_next_matches : np.ndarray
Array of matched voxels from the next timepoint.
distances : np.ndarray
Array of distances between matched voxels.
Returns
-------
tuple
Arrays of uniquely matched voxels for the previous and next timepoints.
"""
# create a dict where the key is a voxel in t1, and the value is a list of distances and t0 voxels matched to it
vox_next_dict = {}
for match_idx, match_next in enumerate(vox_next_matches):
Expand Down Expand Up @@ -137,6 +272,21 @@ def _assign_unique_matches(self, vox_prev_matches, vox_next_matches, distances):
return vox_prev_matches_final, vox_next_matches_final

def _distance_threshold(self, vox_prev_matched, vox_next_matched):
"""
Filters voxel matches by applying a distance threshold.
Parameters
----------
vox_prev_matched : np.ndarray
Array of matched voxels from the previous timepoint.
vox_next_matched : np.ndarray
Array of matched voxels from the next timepoint.
Returns
-------
tuple
Arrays of valid voxel matches and corresponding distances.
"""
distances = np.linalg.norm((vox_prev_matched - vox_next_matched) * self.flow_interpolator_fw.scaling, axis=1)
distance_mask = distances < self.flow_interpolator_fw.max_distance_um
vox_prev_matched_valid = vox_prev_matched[distance_mask]
Expand All @@ -145,6 +295,23 @@ def _distance_threshold(self, vox_prev_matched, vox_next_matched):
return vox_prev_matched_valid, vox_next_matched_valid, distances_valid

def match_voxels(self, vox_prev, vox_next, t):
"""
Matches voxels between two consecutive timepoints using both forward and backward interpolation.
Parameters
----------
vox_prev : np.ndarray
Voxels from the previous timepoint.
vox_next : np.ndarray
Voxels from the next timepoint.
t : int
Current timepoint index.
Returns
-------
tuple
Arrays of matched voxels from the previous and next timepoints.
"""
# forward interpolation:
# from t0 voxels and interpolated flow, get t1 centroids.
# match nearby t1 voxels to t1 centroids, which are linked to t0 voxels.
Expand Down Expand Up @@ -201,6 +368,9 @@ def match_voxels(self, vox_prev, vox_next, t):
return np.array(vox_prev_matches_unique), np.array(vox_next_matches_unique)

def _get_t(self):
"""
Gets the number of timepoints from the image metadata or sets it if not provided.
"""
if self.num_t is None:
if self.im_info.no_t:
self.num_t = 1
Expand All @@ -210,6 +380,9 @@ def _get_t(self):
return

def _allocate_memory(self):
"""
Allocates memory for voxel reassignment, including initializing memory-mapped arrays for branch and object labels.
"""
logger.debug('Allocating memory for voxel reassignment.')
self.voxel_matches_path = self.im_info.pipeline_paths['voxel_matches']

Expand All @@ -230,6 +403,23 @@ def _allocate_memory(self):
return_memmap=True)

def _run_frame(self, t, all_mask_coords, reassigned_memmap):
"""
Reassigns voxels in a single timepoint based on voxel matches with the previous timepoint.
Parameters
----------
t : int
Current timepoint index.
all_mask_coords : list
List of voxel coordinates for each timepoint.
reassigned_memmap : np.ndarray
Memory-mapped array for the reassigned labels.
Returns
-------
bool
Returns True if no matches are found, otherwise False.
"""
logger.info(f'Reassigning pixels in frame {t + 1} of {self.num_t - 1}')

vox_prev = all_mask_coords[t]
Expand All @@ -250,6 +440,14 @@ def _run_frame(self, t, all_mask_coords, reassigned_memmap):
return False

def _run_reassignment(self, label_type):
"""
Runs voxel reassignment for all frames based on the specified label type (either 'branch' or 'obj').
Parameters
----------
label_type : str
The label type, either 'branch' or 'obj'.
"""
# todo, be able to specify which frame to start at.
if label_type == 'branch':
label_memmap = self.branch_label_memmap
Expand All @@ -272,6 +470,9 @@ def _run_reassignment(self, label_type):
break

def run(self):
"""
Main method to execute voxel reassignment for both branch and object labels.
"""
if self.im_info.no_t:
return
self._get_t()
Expand Down

0 comments on commit 310bc06

Please sign in to comment.