Skip to content

Commit

Permalink
Merge pull request #191 from catalystneuro/refactor_nwb_segmentation_…
Browse files Browse the repository at this point in the history
…extractor

[refactor] `NwbSegmentationExtractor`
  • Loading branch information
CodyCBakerPhD authored Aug 15, 2022
2 parents 8a7729e + 5f5dfb3 commit 3f08c50
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 56 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat
* Add `frame_to_time` to `SegmentationExtractor`, `get_roi_ids` is now a class method. [PR #187](https://github.com/catalystneuro/roiextractors/pull/187)
* Add `set_times` to `SegmentationExtractor`. [PR #188](https://github.com/catalystneuro/roiextractors/pull/188)
* Updated the test for segmentation images to check all images for the given segmentation extractors. [PR #190](https://github.com/catalystneuro/roiextractors/pull/190)
* Refactored the `NwbSegmentationExtractor` to be more flexible with segmentation images and keep up
with the change in [catalystneuro/neuoroconv#41](https://github.com/catalystneuro/neuroconv/pull/41)
of trace names. [PR #191](https://github.com/catalystneuro/roiextractors/pull/191)

### Fixes

Expand Down
109 changes: 55 additions & 54 deletions src/roiextractors/extractors/nwbextractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def __init__(self, file_path: PathType):
.nwb file location
"""
check_nwb_install()
SegmentationExtractor.__init__(self)
super().__init__()
file_path = Path(file_path)
if not file_path.is_file():
raise Exception("file does not exist")
Expand All @@ -260,61 +260,53 @@ def __init__(self, file_path: PathType):
self._io = NWBHDF5IO(str(file_path), mode="r")
self.nwbfile = self._io.read()

assert "ophys" in self.nwbfile.processing, "Ophys processing module is not in nwbfile."
ophys = self.nwbfile.processing.get("ophys")
if ophys is None:
raise Exception("could not find ophys processing module in nwbfile")
else:
# Extract roi_response:
fluorescence = None
dfof = None
any_roi_response_series_found = False
if "Fluorescence" in ophys.data_interfaces:
fluorescence = ophys.data_interfaces["Fluorescence"]
if "DfOverF" in ophys.data_interfaces:
dfof = ophys.data_interfaces["DfOverF"]
if fluorescence is None and dfof is None:
raise Exception("could not find Fluorescence/DfOverF module in nwbfile")
for trace_name in ["RoiResponseSeries", "Dff", "Neuropil", "Deconvolved"]:
trace_name_segext = "raw" if trace_name == "RoiResponseSeries" else trace_name.lower()
container = dfof if trace_name == "Dff" else fluorescence
if container is not None and trace_name in container.roi_response_series:
any_roi_response_series_found = True
setattr(
self,
f"_roi_response_{trace_name_segext}",
DatasetView(container.roi_response_series[trace_name].data).lazy_transpose(),
)
if self._sampling_frequency is None:
self._sampling_frequency = container.roi_response_series[trace_name].rate
if not any_roi_response_series_found:
raise Exception(
"could not find any of 'RoiResponseSeries'/'Dff'/'Neuropil'/'Deconvolved'"
"named RoiResponseSeries in nwbfile"

# Extract roi_responses:
fluorescence = None
df_over_f = None
any_roi_response_series_found = False
if "Fluorescence" in ophys.data_interfaces:
fluorescence = ophys.data_interfaces["Fluorescence"]
if "DfOverF" in ophys.data_interfaces:
df_over_f = ophys.data_interfaces["DfOverF"]
if fluorescence is None and df_over_f is None:
raise Exception("Could not find Fluorescence/DfOverF module in nwbfile.")
for trace_name in self.get_traces_dict().keys():
trace_name_segext = "RoiResponseSeries" if trace_name in ["raw", "dff"] else trace_name.capitalize()
container = df_over_f if trace_name == "dff" else fluorescence
if container is not None and trace_name_segext in container.roi_response_series:
any_roi_response_series_found = True
setattr(
self,
f"_roi_response_{trace_name}",
DatasetView(container.roi_response_series[trace_name_segext].data).lazy_transpose(),
)
# Extract image_mask/background:
if "ImageSegmentation" in ophys.data_interfaces:
image_seg = ophys.data_interfaces["ImageSegmentation"]
if "PlaneSegmentation" in image_seg.plane_segmentations: # this requirement in nwbfile is enforced
ps = image_seg.plane_segmentations["PlaneSegmentation"]
if "image_mask" in ps.colnames:
self._image_masks = DatasetView(ps["image_mask"].data).lazy_transpose([2, 1, 0])
else:
raise Exception("could not find any image_masks in nwbfile")
if "RoiCentroid" in ps.colnames:
self._roi_locs = ps["RoiCentroid"]
if "Accepted" in ps.colnames:
self._accepted_list = ps["Accepted"].data[:]
if "Rejected" in ps.colnames:
self._rejected_list = ps["Rejected"].data[:]
else:
raise Exception("could not find any PlaneSegmentation in nwbfile")
# Extracting stores images as GrayscaleImages:
if "SegmentationImages" in ophys.data_interfaces:
images_container = ophys.data_interfaces["SegmentationImages"]
if "correlation" in images_container.images:
self._image_correlation = images_container.images["correlation"].data[()].T
if "mean" in images_container.images:
self._image_mean = images_container.images["mean"].data[()].T
if self._sampling_frequency is None:
self._sampling_frequency = container.roi_response_series[trace_name_segext].rate
if not any_roi_response_series_found:
raise Exception(
"could not find any of 'RoiResponseSeries'/'Dff'/'Neuropil'/'Deconvolved'"
"named RoiResponseSeries in nwbfile"
)
# Extract image_mask/background:
if "ImageSegmentation" in ophys.data_interfaces:
image_seg = ophys.data_interfaces["ImageSegmentation"]
assert len(image_seg.plane_segmentations), "Could not find any PlaneSegmentation in nwbfile."
if "PlaneSegmentation" in image_seg.plane_segmentations: # this requirement in nwbfile is enforced
ps = image_seg.plane_segmentations["PlaneSegmentation"]
assert "image_mask" in ps.colnames, "Could not find any image_masks in nwbfile."
self._image_masks = DatasetView(ps["image_mask"].data).lazy_transpose([2, 1, 0])
self._roi_locs = ps["RoiCentroid"] if "RoiCentroid" in ps.colnames else None
self._accepted_list = ps["Accepted"].data[:] if "Accepted" in ps.colnames else None
self._rejected_list = ps["Rejected"].data[:] if "Rejected" in ps.colnames else None

# Extracting stored images as GrayscaleImages:
self._segmentation_images = None
if "SegmentationImages" in ophys.data_interfaces:
images_container = ophys.data_interfaces["SegmentationImages"]
self._segmentation_images = images_container.images
# Imaging plane:
if "ImagingPlane" in self.nwbfile.imaging_planes:
imaging_plane = self.nwbfile.imaging_planes["ImagingPlane"]
Expand All @@ -335,6 +327,15 @@ def get_rejected_list(self):
if len(rej_list) > 0:
return rej_list

def get_images_dict(self):
images_dict = super().get_images_dict()
if self._segmentation_images is not None:
images_dict.update(
(image_name, image_data[:].T) for image_name, image_data in self._segmentation_images.items()
)

return images_dict

@property
def roi_locations(self):
if self._roi_locs is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def get_images_dict(self):
images_dict: dict
dictionary with key, values representing different types of Images
"""
images_dict = dict(
images_dict = super().get_images_dict()
images_dict.update(
summary_image=self._info_struct["summary_image"][:],
f_per_pixel=self._info_struct["F_per_pixel"][:],
max_image=self._info_struct["max_image"][:],
Expand Down
5 changes: 4 additions & 1 deletion tests/test_extractsegmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ def test_extractor_get_images_dict(self):
)[:]

images_dict = self.extractor.get_images_dict()
self.assertEqual(len(images_dict), 3)
self.assertEqual(len(images_dict), 5)

self.assertEqual(images_dict["correlation"], None)
self.assertEqual(images_dict["mean"], None)

self.assertEqual(images_dict["summary_image"].shape, summary_image.shape)
self.assertEqual(images_dict["max_image"].shape, max_image.shape)
Expand Down

0 comments on commit 3f08c50

Please sign in to comment.