Skip to content

Commit

Permalink
labeling docs
Browse files Browse the repository at this point in the history
  • Loading branch information
aelefebv committed Oct 3, 2024
1 parent 79197b5 commit ebbbd1a
Showing 1 changed file with 151 additions and 0 deletions.
151 changes: 151 additions & 0 deletions nellie/segmentation/labelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,82 @@


class Label:
"""
A class for semantic and instance segmentation of microscopy images using thresholding and signal-to-noise ratio (SNR) techniques.
Attributes
----------
im_info : ImInfo
An object containing image metadata and memory-mapped image data.
num_t : int
Number of timepoints in the image.
threshold : float or None
Intensity threshold for segmenting objects.
snr_cleaning : bool
Flag to enable or disable signal-to-noise ratio (SNR) based cleaning of segmented objects.
otsu_thresh_intensity : bool
Whether to apply Otsu's thresholding method to segment objects based on intensity.
im_memmap : np.ndarray or None
Memory-mapped original image data.
frangi_memmap : np.ndarray or None
Memory-mapped Frangi-filtered image data.
max_label_num : int
Maximum label number used for segmented objects.
min_z_radius_um : float
Minimum radius for Z-axis objects based on Z resolution, used for filtering objects in the Z dimension.
semantic_mask_memmap : np.ndarray or None
Memory-mapped mask for semantic segmentation.
instance_label_memmap : np.ndarray or None
Memory-mapped mask for instance segmentation.
shape : tuple
Shape of the segmented image.
debug : dict
Debugging information for tracking segmentation steps.
viewer : object or None
Viewer object for displaying status during processing.
Methods
-------
_get_t()
Determines the number of timepoints to process.
_allocate_memory()
Allocates memory for the original image, Frangi-filtered image, and instance segmentation masks.
_get_labels(frame)
Generates binary labels for segmented objects in a single frame based on thresholding.
_get_subtraction_mask(original_frame, labels_frame)
Creates a mask by subtracting labeled regions from the original frame.
_get_object_snrs(original_frame, labels_frame)
Calculates the signal-to-noise ratios (SNR) of segmented objects and removes objects with low SNR.
_run_frame(t)
Runs segmentation for a single timepoint in the image.
_run_segmentation()
Runs the full segmentation process for all timepoints in the image.
run()
Main method to execute the full segmentation process over the image data.
"""
def __init__(self, im_info: ImInfo,
num_t=None,
threshold=None,
snr_cleaning=False, otsu_thresh_intensity=False,
viewer=None):
"""
Initializes the Label object with image metadata and segmentation parameters.
Parameters
----------
im_info : ImInfo
An instance of the ImInfo class, containing metadata and paths for the image file.
num_t : int, optional
Number of timepoints to process. If None, defaults to the number of timepoints in the image.
threshold : float or None, optional
Intensity threshold for segmenting objects (default is None).
snr_cleaning : bool, optional
Whether to apply SNR-based cleaning to the segmented objects (default is False).
otsu_thresh_intensity : bool, optional
Whether to apply Otsu's method for intensity thresholding (default is False).
viewer : object or None, optional
Viewer object for displaying status during processing (default is None).
"""
self.im_info = im_info
self.num_t = num_t
if num_t is None and not self.im_info.no_t:
Expand All @@ -35,6 +106,11 @@ def __init__(self, im_info: ImInfo,
self.viewer = viewer

def _get_t(self):
"""
Determines the number of timepoints to process.
If `num_t` is not set and the image contains a temporal dimension, it sets `num_t` to the number of timepoints.
"""
if self.num_t is None:
if self.im_info.no_t:
self.num_t = 1
Expand All @@ -44,6 +120,11 @@ def _get_t(self):
return

def _allocate_memory(self):
"""
Allocates memory for the original image, Frangi-filtered image, and instance segmentation masks.
This method creates memory-mapped arrays for the original image data, Frangi-filtered data, and instance label data.
"""
logger.debug('Allocating memory for semantic segmentation.')
self.im_memmap = self.im_info.get_memmap(self.im_info.im_path)
self.frangi_memmap = self.im_info.get_memmap(self.im_info.pipeline_paths['im_preprocessed'])
Expand All @@ -56,6 +137,21 @@ def _allocate_memory(self):
return_memmap=True)

def _get_labels(self, frame):
"""
Generates binary labels for segmented objects in a single frame based on thresholding.
Uses triangle and Otsu thresholding to generate a mask, then labels the mask using connected component labeling.
Parameters
----------
frame : np.ndarray
The input frame to be segmented.
Returns
-------
tuple
A tuple containing the mask and the labeled objects.
"""
ndim = 2 if self.im_info.no_z else 3
footprint = ndi.generate_binary_structure(ndim, 1)

Expand All @@ -82,11 +178,41 @@ def _get_labels(self, frame):
return mask, labels

def _get_subtraction_mask(self, original_frame, labels_frame):
"""
Creates a mask by subtracting labeled regions from the original frame.
Parameters
----------
original_frame : np.ndarray
The original image data.
labels_frame : np.ndarray
The labeled objects in the frame.
Returns
-------
np.ndarray
The subtraction mask where labeled objects are removed.
"""
subtraction_mask = original_frame.copy()
subtraction_mask[labels_frame > 0] = 0
return subtraction_mask

def _get_object_snrs(self, original_frame, labels_frame):
"""
Calculates the signal-to-noise ratios (SNR) of segmented objects and removes objects with low SNR.
Parameters
----------
original_frame : np.ndarray
The original image data.
labels_frame : np.ndarray
Labeled objects in the frame.
Returns
-------
np.ndarray
The labels of objects that meet the SNR threshold.
"""
logger.debug('Calculating object SNRs.')
subtraction_mask = self._get_subtraction_mask(original_frame, labels_frame)
unique_labels = xp.unique(labels_frame)
Expand Down Expand Up @@ -123,6 +249,21 @@ def _get_object_snrs(self, original_frame, labels_frame):
return labels_frame

def _run_frame(self, t):
"""
Runs segmentation for a single timepoint in the image.
Applies thresholding and optional SNR-based cleaning to segment objects in a single timepoint.
Parameters
----------
t : int
Timepoint index.
Returns
-------
np.ndarray
The labeled objects for the given timepoint.
"""
logger.info(f'Running semantic segmentation, volume {t}/{self.num_t - 1}')
original_in_mem = xp.asarray(self.im_memmap[t, ...])
frangi_in_mem = xp.asarray(self.frangi_memmap[t, ...])
Expand All @@ -142,6 +283,11 @@ def _run_frame(self, t):
return labels

def _run_segmentation(self):
"""
Runs the full segmentation process for all timepoints in the image.
Segments each timepoint sequentially, applying thresholding, labeling, and optional SNR cleaning.
"""
for t in range(self.num_t):
if self.viewer is not None:
self.viewer.status = f'Extracting organelles. Frame: {t + 1} of {self.num_t}.'
Expand All @@ -156,6 +302,11 @@ def _run_segmentation(self):
self.instance_label_memmap.flush()

def run(self):
"""
Main method to execute the full segmentation process over the image data.
This method allocates necessary memory, segments each timepoint, and applies labeling.
"""
logger.info('Running semantic segmentation.')
self._get_t()
self._allocate_memory()
Expand Down

0 comments on commit ebbbd1a

Please sign in to comment.