Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Minian segmentation extractor #368

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### Features
* Added a seed to dummy generators [#361](https://github.com/catalystneuro/roiextractors/pull/361)
* Added depth_slice for VolumetricImagingExtractors [PR #363](https://github.com/catalystneuro/roiextractors/pull/363)
* Added MinianSegmentationExtractor: [PR #368](https://github.com/catalystneuro/roiextractors/pull/368)

### Fixes
* Added specific error message for single-frame scanimage data [PR #360](https://github.com/catalystneuro/roiextractors/pull/360)
Expand Down
2 changes: 2 additions & 0 deletions src/roiextractors/extractorlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .extractors.inscopixextractors import InscopixImagingExtractor
from .extractors.memmapextractors import NumpyMemmapImagingExtractor
from .extractors.memmapextractors import MemmapImagingExtractor
from .extractors.minian import MinianSegmentationExtractor
from .extractors.miniscopeimagingextractor import MiniscopeImagingExtractor
from .multisegmentationextractor import MultiSegmentationExtractor
from .multiimagingextractor import MultiImagingExtractor
Expand Down Expand Up @@ -62,6 +63,7 @@
ExtractSegmentationExtractor,
SimaSegmentationExtractor,
CaimanSegmentationExtractor,
MinianSegmentationExtractor,
]

imaging_extractor_dict = {imaging_class.extractor_name: imaging_class for imaging_class in imaging_extractor_full_list}
Expand Down
14 changes: 14 additions & 0 deletions src/roiextractors/extractors/minian/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""A Segmentation Extractor for Minian.

Modules
-------
miniansegmentationextractor
A Segmentation Extractor for Minian.

Classes
-------
MinianSegmentationExtractor
A class for extracting segmentation from Minian output.
"""

from .miniansegmentationextractor import MinianSegmentationExtractor
204 changes: 204 additions & 0 deletions src/roiextractors/extractors/minian/miniansegmentationextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""A SegmentationExtractor for Minian.

Classes
-------
MinianSegmentationExtractor
A class for extracting segmentation from Minian output.
"""

from pathlib import Path

import zarr
import warnings
import numpy as np
import pandas as pd

from ...extraction_tools import PathType
from ...segmentationextractor import SegmentationExtractor


class MinianSegmentationExtractor(SegmentationExtractor):
"""A SegmentationExtractor for Minian.

This class inherits from the SegmentationExtractor class, having all
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure this docstring is very helpful, I think it should be oriented to explain things to final users and not implementation details.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, sure! I just wanted to be consistent with all the other extractors, but I guess a more detailed docstring won't hurt.

its functionality specifically applied to the dataset output from
the 'Minian' ROI segmentation method.

Users can extract key information such as ROI traces, image masks,
and timestamps from the output of the Minian pipeline.

Key features:
- Extracts fluorescence traces (denoised, baseline, neuropil, deconvolved) for each ROI.
- Retrieves ROI masks and background components.
- Provides access to timestamps corresponding to calcium traces.
- Retrieves maximum projection image.

Parameters
----------
folder_path: str
Path to the folder containing Minian .zarr output files.

"""

extractor_name = "MinianSegmentation"
is_writable = True
mode = "file"

def __init__(self, folder_path: PathType):
"""Initialize a MinianSegmentationExtractor instance.

Parameters
----------
folder_path: str
The location of the folder containing minian .zarr output.
"""
SegmentationExtractor.__init__(self)
self.folder_path = folder_path
self._roi_response_denoised = self._read_trace_from_zarr_filed(field="C")
self._roi_response_baseline = self._read_trace_from_zarr_filed(field="b0")
self._roi_response_neuropil = self._read_trace_from_zarr_filed(field="f")
self._roi_response_deconvolved = self._read_trace_from_zarr_filed(field="S")
self._image_maximum_projection = np.array(self._read_zarr_group("/max_proj.zarr/max_proj"))
self._image_masks = self._read_roi_image_mask_from_zarr_filed()
self._background_image_masks = self._read_background_image_mask_from_zarr_filed()
self._times = self._read_timestamps_from_csv()

def _read_zarr_group(self, zarr_group=""):
"""Read the zarr.

Returns
-------
zarr.open
The zarr object specified by self.folder_path.
"""
if zarr_group not in zarr.open(self.folder_path, mode="r"):
warnings.warn(f"Group '{zarr_group}' not found in the Zarr store.", UserWarning)
return None
else:
return zarr.open(str(self.folder_path) + f"/{zarr_group}", "r")

def _read_roi_image_mask_from_zarr_filed(self):
"""Read the image masks from the zarr output.

Returns
-------
image_masks: numpy.ndarray
The image masks for each ROI.
"""
dataset = self._read_zarr_group("/A.zarr")
if dataset is None or "A" not in dataset:
return None
else:
return np.transpose(dataset["A"], (1, 2, 0))

def _read_background_image_mask_from_zarr_filed(self):
"""Read the image masks from the zarr output.

Returns
-------
image_masks: numpy.ndarray
The image masks for each background components.
"""
dataset = self._read_zarr_group("/b.zarr")
if dataset is None or "b" not in dataset:
return None
else:
return np.expand_dims(dataset["b"], axis=2)

def _read_trace_from_zarr_filed(self, field):
"""Read the traces specified by the field from the zarr object.

Parameters
----------
field: str
The field to read from the zarr object.

Returns
-------
trace: numpy.ndarray
The traces specified by the field.
"""
dataset = self._read_zarr_group(f"/{field}.zarr")

if dataset is None or field not in dataset:
return None
elif dataset[field].ndim == 2:
return np.transpose(dataset[field])
elif dataset[field].ndim == 1:
return np.expand_dims(dataset[field], axis=1)

def _read_timestamps_from_csv(self):
"""Extract timestamps corresponding to frame numbers of the stored denoised trace

Returns
-------
np.ndarray
The timestamps of the denoised trace.
"""
csv_file = self.folder_path / "timeStamps.csv"
df = pd.read_csv(csv_file)
frame_numbers = self._read_zarr_group("/C.zarr/frame")
filtered_df = df[df["Frame Number"].isin(frame_numbers)] * 1e-3

return filtered_df["Time Stamp (ms)"].to_numpy()

def get_image_size(self):
dataset = self._read_zarr_group("/A.zarr")
height = dataset["height"].shape[0]
width = dataset["width"].shape[0]
return (height, width)

def get_accepted_list(self) -> list:
"""Get a list of accepted ROI ids.

Returns
-------
accepted_list: list
List of accepted ROI ids.
"""
return list(range(self.get_num_rois()))

def get_rejected_list(self) -> list:
"""Get a list of rejected ROI ids.

Returns
-------
rejected_list: list
List of rejected ROI ids.
"""
return list()

def get_roi_ids(self) -> list:
dataset = self._read_zarr_group("/A.zarr")
return list(dataset["unit_id"])

def get_traces_dict(self) -> dict:
"""Get traces as a dictionary with key as the name of the ROiResponseSeries.

Returns
-------
_roi_response_dict: dict
dictionary with key, values representing different types of RoiResponseSeries:
Raw Fluorescence, DeltaFOverF, Denoised, Neuropil, Deconvolved, Background, etc.
"""
return dict(
denoised=self._roi_response_denoised,
baseline=self._roi_response_baseline,
neuropil=self._roi_response_neuropil,
deconvolved=self._roi_response_deconvolved,
)

def get_images_dict(self) -> dict:
"""Get images as a dictionary with key as the name of the ROIResponseSeries.

Returns
-------
_roi_image_dict: dict
dictionary with key, values representing different types of Images used in segmentation:
Mean, Correlation image
"""
return dict(
mean=self._image_mean,
correlation=self._image_correlation,
maximum_projection=self._image_maximum_projection,
)
114 changes: 114 additions & 0 deletions tests/test_miniansegmentationextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import shutil
import tempfile
from pathlib import Path

import numpy as np
import zarr
from hdmf.testing import TestCase
from numpy.testing import assert_array_equal

from roiextractors import MinianSegmentationExtractor
from tests.setup_paths import OPHYS_DATA_PATH


class TestMinianSegmentationExtractor(TestCase):
@classmethod
def setUpClass(cls):
folder_path = str(OPHYS_DATA_PATH / "segmentation_datasets" / "minian")

cls.folder_path = Path(folder_path)
extractor = MinianSegmentationExtractor(folder_path=cls.folder_path)
cls.extractor = extractor

cls.test_dir = Path(tempfile.mkdtemp())

# denoised traces
dataset = zarr.open(folder_path + "/C.zarr")
cls.denoised_traces = np.transpose(dataset["C"])
cls.num_frames = len(dataset["frame"][:])
# deconvolved traces
dataset = zarr.open(folder_path + "/S.zarr")
cls.deconvolved_traces = np.transpose(dataset["S"])
# baseline traces
dataset = zarr.open(folder_path + "/b0.zarr")
cls.baseline_traces = np.transpose(dataset["b0"])
# neuropil trace
dataset = zarr.open(folder_path + "/f.zarr")
cls.neuropil_trace = np.expand_dims(dataset["f"], axis=1)

# ROIs masks
dataset = zarr.open(folder_path + "/A.zarr")
cls.image_masks = np.transpose(dataset["A"], (1, 2, 0))
cls.image_size = (dataset["height"].shape[0], dataset["width"].shape[0])
cls.num_rois = dataset["unit_id"].shape[0]
# background mask
dataset = zarr.open(folder_path + "/b.zarr")
cls.background_image_mask = np.expand_dims(dataset["b"], axis=2)
# summary image: maximum projection
cls.maximum_projection_image = np.array(zarr.open(folder_path + "/max_proj.zarr/max_proj"))

@classmethod
def tearDownClass(cls):
# remove the temporary directory and its contents
shutil.rmtree(cls.test_dir)

def test_incomplete_extractor_load(self):
"""Check extractor can be initialized when not all traces are available."""
# temporary directory for testing assertion when some of the files are missing
folders_to_copy = [
"A.zarr",
"C.zarr",
"b0.zarr",
"b.zarr",
"f.zarr",
"max_proj.zarr",
".zgroup",
"timeStamps.csv",
]
self.test_dir.mkdir(exist_ok=True)

for folder in folders_to_copy:
src = Path(self.folder_path) / folder
dst = self.test_dir / folder
if src.is_dir():
shutil.copytree(src, dst, dirs_exist_ok=True)
else:
shutil.copy(src, dst)

extractor = MinianSegmentationExtractor(folder_path=self.test_dir)
traces_dict = extractor.get_traces_dict()
self.assertEqual(traces_dict["deconvolved"], None)

def test_image_size(self):
self.assertEqual(self.extractor.get_image_size(), self.image_size)

def test_num_frames(self):
self.assertEqual(self.extractor.get_num_frames(), self.num_frames)

def test_frame_to_time(self):
self.assertEqual(self.extractor.frame_to_time(frames=[0]), 0.328)

def test_num_channels(self):
self.assertEqual(self.extractor.get_num_channels(), 1)

def test_num_rois(self):
self.assertEqual(self.extractor.get_num_rois(), self.num_rois)

def test_extractor_denoised_traces(self):
assert_array_equal(self.extractor.get_traces(name="denoised"), self.denoised_traces)

def test_extractor_neuropil_trace(self):
assert_array_equal(self.extractor.get_traces(name="neuropil"), self.neuropil_trace)

def test_extractor_image_masks(self):
"""Test that the image masks are correctly extracted."""
assert_array_equal(self.extractor.get_roi_image_masks(), self.image_masks)

def test_extractor_background_image_masks(self):
"""Test that the image masks are correctly extracted."""
assert_array_equal(self.extractor.get_background_image_masks(), self.background_image_mask)

def test_maximum_projection_image(self):
"""Test that the mean image is correctly loaded from the extractor."""
images_dict = self.extractor.get_images_dict()
assert_array_equal(images_dict["maximum_projection"], self.maximum_projection_image)
Loading