diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9e0a9ae3..f4c37649 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,9 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat
 * Add `frame_to_time` to `SegmentationExtractor`, `get_roi_ids` is now a class method. [PR #187](https://github.com/catalystneuro/roiextractors/pull/187)
 * Add `set_times` to `SegmentationExtractor`. [PR #188](https://github.com/catalystneuro/roiextractors/pull/188)
 * Updated the test for segmentation images to check all images for the given segmentation extractors. [PR #190](https://github.com/catalystneuro/roiextractors/pull/190)
+* Refactored the `NwbSegmentationExtractor` to be more flexible with segmentation images and to keep up
+  with the trace name changes introduced in [catalystneuro/neuroconv#41](https://github.com/catalystneuro/neuroconv/pull/41).
+  [PR #191](https://github.com/catalystneuro/roiextractors/pull/191)
 
 ### Fixes
 
diff --git a/src/roiextractors/extractors/nwbextractors/nwbextractors.py b/src/roiextractors/extractors/nwbextractors/nwbextractors.py
index 4cc5a2e2..e6861090 100644
--- a/src/roiextractors/extractors/nwbextractors/nwbextractors.py
+++ b/src/roiextractors/extractors/nwbextractors/nwbextractors.py
@@ -248,7 +248,7 @@ def __init__(self, file_path: PathType):
             .nwb file location
         """
         check_nwb_install()
-        SegmentationExtractor.__init__(self)
+        super().__init__()
         file_path = Path(file_path)
         if not file_path.is_file():
             raise Exception("file does not exist")
@@ -260,61 +260,53 @@ def __init__(self, file_path: PathType):
 
         self._io = NWBHDF5IO(str(file_path), mode="r")
         self.nwbfile = self._io.read()
+        assert "ophys" in self.nwbfile.processing, "Ophys processing module is not in nwbfile."
         ophys = self.nwbfile.processing.get("ophys")
-        if ophys is None:
-            raise Exception("could not find ophys processing module in nwbfile")
-        else:
-            # Extract roi_response:
-            fluorescence = None
-            dfof = None
-            any_roi_response_series_found = False
-            if "Fluorescence" in ophys.data_interfaces:
-                fluorescence = ophys.data_interfaces["Fluorescence"]
-            if "DfOverF" in ophys.data_interfaces:
-                dfof = ophys.data_interfaces["DfOverF"]
-            if fluorescence is None and dfof is None:
-                raise Exception("could not find Fluorescence/DfOverF module in nwbfile")
-            for trace_name in ["RoiResponseSeries", "Dff", "Neuropil", "Deconvolved"]:
-                trace_name_segext = "raw" if trace_name == "RoiResponseSeries" else trace_name.lower()
-                container = dfof if trace_name == "Dff" else fluorescence
-                if container is not None and trace_name in container.roi_response_series:
-                    any_roi_response_series_found = True
-                    setattr(
-                        self,
-                        f"_roi_response_{trace_name_segext}",
-                        DatasetView(container.roi_response_series[trace_name].data).lazy_transpose(),
-                    )
-                    if self._sampling_frequency is None:
-                        self._sampling_frequency = container.roi_response_series[trace_name].rate
-            if not any_roi_response_series_found:
-                raise Exception(
-                    "could not find any of 'RoiResponseSeries'/'Dff'/'Neuropil'/'Deconvolved'"
-                    "named RoiResponseSeries in nwbfile"
+
+        # Extract roi_responses:
+        fluorescence = None
+        df_over_f = None
+        any_roi_response_series_found = False
+        if "Fluorescence" in ophys.data_interfaces:
+            fluorescence = ophys.data_interfaces["Fluorescence"]
+        if "DfOverF" in ophys.data_interfaces:
+            df_over_f = ophys.data_interfaces["DfOverF"]
+        if fluorescence is None and df_over_f is None:
+            raise Exception("Could not find Fluorescence/DfOverF module in nwbfile.")
+        for trace_name in self.get_traces_dict().keys():
+            trace_name_segext = "RoiResponseSeries" if trace_name in ["raw", "dff"] else trace_name.capitalize()
+            container = df_over_f if trace_name == "dff" else fluorescence
+            if container is not None and trace_name_segext in container.roi_response_series:
+                any_roi_response_series_found = True
+                setattr(
+                    self,
+                    f"_roi_response_{trace_name}",
+                    DatasetView(container.roi_response_series[trace_name_segext].data).lazy_transpose(),
                 )
-            # Extract image_mask/background:
-            if "ImageSegmentation" in ophys.data_interfaces:
-                image_seg = ophys.data_interfaces["ImageSegmentation"]
-                if "PlaneSegmentation" in image_seg.plane_segmentations:  # this requirement in nwbfile is enforced
-                    ps = image_seg.plane_segmentations["PlaneSegmentation"]
-                    if "image_mask" in ps.colnames:
-                        self._image_masks = DatasetView(ps["image_mask"].data).lazy_transpose([2, 1, 0])
-                    else:
-                        raise Exception("could not find any image_masks in nwbfile")
-                    if "RoiCentroid" in ps.colnames:
-                        self._roi_locs = ps["RoiCentroid"]
-                    if "Accepted" in ps.colnames:
-                        self._accepted_list = ps["Accepted"].data[:]
-                    if "Rejected" in ps.colnames:
-                        self._rejected_list = ps["Rejected"].data[:]
-                else:
-                    raise Exception("could not find any PlaneSegmentation in nwbfile")
-            # Extracting stores images as GrayscaleImages:
-            if "SegmentationImages" in ophys.data_interfaces:
-                images_container = ophys.data_interfaces["SegmentationImages"]
-                if "correlation" in images_container.images:
-                    self._image_correlation = images_container.images["correlation"].data[()].T
-                if "mean" in images_container.images:
-                    self._image_mean = images_container.images["mean"].data[()].T
+                if self._sampling_frequency is None:
+                    self._sampling_frequency = container.roi_response_series[trace_name_segext].rate
+        if not any_roi_response_series_found:
+            raise Exception(
+                "could not find any of 'RoiResponseSeries'/'Dff'/'Neuropil'/'Deconvolved' "
+                "named RoiResponseSeries in nwbfile"
+            )
+        # Extract image_mask/background:
+        if "ImageSegmentation" in ophys.data_interfaces:
+            image_seg = ophys.data_interfaces["ImageSegmentation"]
+            assert len(image_seg.plane_segmentations), "Could not find any PlaneSegmentation in nwbfile."
+            if "PlaneSegmentation" in image_seg.plane_segmentations:  # this requirement in nwbfile is enforced
+                ps = image_seg.plane_segmentations["PlaneSegmentation"]
+                assert "image_mask" in ps.colnames, "Could not find any image_masks in nwbfile."
+                self._image_masks = DatasetView(ps["image_mask"].data).lazy_transpose([2, 1, 0])
+                self._roi_locs = ps["RoiCentroid"] if "RoiCentroid" in ps.colnames else None
+                self._accepted_list = ps["Accepted"].data[:] if "Accepted" in ps.colnames else None
+                self._rejected_list = ps["Rejected"].data[:] if "Rejected" in ps.colnames else None
+
+        # Extracting stored images as GrayscaleImages:
+        self._segmentation_images = None
+        if "SegmentationImages" in ophys.data_interfaces:
+            images_container = ophys.data_interfaces["SegmentationImages"]
+            self._segmentation_images = images_container.images
         # Imaging plane:
         if "ImagingPlane" in self.nwbfile.imaging_planes:
             imaging_plane = self.nwbfile.imaging_planes["ImagingPlane"]
@@ -335,6 +327,15 @@ def get_rejected_list(self):
         if len(rej_list) > 0:
             return rej_list
 
+    def get_images_dict(self):
+        images_dict = super().get_images_dict()
+        if self._segmentation_images is not None:
+            images_dict.update(
+                (image_name, image_data[:].T) for image_name, image_data in self._segmentation_images.items()
+            )
+
+        return images_dict
+
     @property
     def roi_locations(self):
         if self._roi_locs is not None:
diff --git a/src/roiextractors/extractors/schnitzerextractor/extractsegmentationextractor.py b/src/roiextractors/extractors/schnitzerextractor/extractsegmentationextractor.py
index b4bfd6a0..8376f76e 100644
--- a/src/roiextractors/extractors/schnitzerextractor/extractsegmentationextractor.py
+++ b/src/roiextractors/extractors/schnitzerextractor/extractsegmentationextractor.py
@@ -240,7 +240,8 @@ def get_images_dict(self):
         images_dict: dict
             dictionary with key, values representing different types of Images
         """
-        images_dict = dict(
+        images_dict = super().get_images_dict()
+        images_dict.update(
             summary_image=self._info_struct["summary_image"][:],
             f_per_pixel=self._info_struct["F_per_pixel"][:],
             max_image=self._info_struct["max_image"][:],
diff --git a/tests/test_extractsegmentationextractor.py b/tests/test_extractsegmentationextractor.py
index 0a053460..06497ddb 100644
--- a/tests/test_extractsegmentationextractor.py
+++ b/tests/test_extractsegmentationextractor.py
@@ -199,7 +199,10 @@ def test_extractor_get_images_dict(self):
         )[:]
 
         images_dict = self.extractor.get_images_dict()
-        self.assertEqual(len(images_dict), 3)
+        self.assertEqual(len(images_dict), 5)
+
+        self.assertEqual(images_dict["correlation"], None)
+        self.assertEqual(images_dict["mean"], None)
 
         self.assertEqual(images_dict["summary_image"].shape, summary_image.shape)
         self.assertEqual(images_dict["max_image"].shape, max_image.shape)
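For review context, a minimal usage sketch of the refactored image handling (not part of the patch). It assumes an NWB file at the placeholder path "segmentation.nwb" that contains an ophys processing module; per the diff above, `get_images_dict` now starts from the base-class dict (where `mean`/`correlation` default to `None`) and overlays any images stored under `ophys/SegmentationImages`.

# Hypothetical usage sketch; "segmentation.nwb" is a placeholder path, not a file shipped with the repo.
from roiextractors import NwbSegmentationExtractor

extractor = NwbSegmentationExtractor(file_path="segmentation.nwb")

# Merge of the base-class defaults with any stored SegmentationImages,
# as implemented by the new NwbSegmentationExtractor.get_images_dict above.
images_dict = extractor.get_images_dict()
for image_name, image in images_dict.items():
    # Images that were not stored in the file remain None.
    print(image_name, None if image is None else image.shape)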