From b383d7be8412f79af32d71c3b8fc3288e1f26626 Mon Sep 17 00:00:00 2001 From: Paul Adkisson Date: Wed, 25 Sep 2024 02:38:42 +1000 Subject: [PATCH] Added depth_slice for volumetricImagingExtractors (#363) * added DepthSliceImagingExtractor from szonjas add_depth_slice branch * updated get_image_size and get_video to match volumetricimagingextractor * simplified extractor * moved input handling to depth_slice method instead of init * expanded tests * added a bit of clarity to the docstrings * removed unnecessary Union * updated changelog * added double depth slice test * added error for frame slice --- CHANGELOG.md | 1 + .../volumetricimagingextractor.py | 56 +++++++++++++++++++ tests/test_volumetricimagingextractor.py | 44 +++++++++++++++ 3 files changed, 101 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21f9cb39..2f4bfe33 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### Features * Added a seed to dummy generators [#361](https://github.com/catalystneuro/roiextractors/pull/361) +* Added depth_slice for VolumetricImagingExtractors [PR #363](https://github.com/catalystneuro/roiextractors/pull/363) ### Fixes * Added specific error message for single-frame scanimage data [PR #360](https://github.com/catalystneuro/roiextractors/pull/360) diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py index 44f84968..893ae039 100644 --- a/src/roiextractors/volumetricimagingextractor.py +++ b/src/roiextractors/volumetricimagingextractor.py @@ -168,3 +168,59 @@ def get_num_channels(self) -> int: def get_dtype(self) -> DtypeType: return self._imaging_extractors[0].get_dtype() + + def depth_slice(self, start_plane: Optional[int] = None, end_plane: Optional[int] = None): + """Return a new VolumetricImagingExtractor ranging from the start_plane to the end_plane.""" + start_plane = start_plane if start_plane is not None else 0 + end_plane = end_plane if end_plane is not None else self._num_planes + assert ( + 0 <= start_plane < self._num_planes + ), f"'start_plane' ({start_plane}) must be greater than 0 and smaller than the number of planes ({self._num_planes})." + assert ( + start_plane < end_plane <= self._num_planes + ), f"'end_plane' ({end_plane}) must be greater than 'start_plane' ({start_plane}) and smaller than or equal to the number of planes ({self._num_planes})." + + return DepthSliceVolumetricImagingExtractor(parent_extractor=self, start_plane=start_plane, end_plane=end_plane) + + def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None): + """Return a new VolumetricImagingExtractor with a subset of frames.""" + raise NotImplementedError( + "frame_slice is not implemented for VolumetricImagingExtractor due to conflicts with get_video()." + ) + + +class DepthSliceVolumetricImagingExtractor(VolumetricImagingExtractor): + """Class to get a lazy depth slice. + + This class can only be used for volumetric imaging data. + Do not use this class directly but use `.depth_slice(...)` on a VolumetricImagingExtractor object. + """ + + extractor_name = "DepthSliceVolumetricImagingExtractor" + installed = True + is_writable = True + installation_mesg = "" + + def __init__( + self, + parent_extractor: VolumetricImagingExtractor, + start_plane: Optional[int] = None, + end_plane: Optional[int] = None, + ): + """Initialize a VolumetricImagingExtractor whose plane(s) subset the parent. + + Subset is exclusive on the right bound, that is, the plane indices of this VolumetricImagingExtractor range over + [0, ..., end_plane-start_plane-1]. + + Parameters + ---------- + parent_extractor : VolumetricImagingExtractor + The VolumetricImagingExtractor object to subset the planes of. + start_plane : int, optional + The left bound of the depth to subset. + The default is the first plane of the parent. + end_plane : int, optional + The right bound of the depth, exclusively, to subset. + The default is the last plane of the parent. + """ + super().__init__(imaging_extractors=parent_extractor._imaging_extractors[start_plane:end_plane]) diff --git a/tests/test_volumetricimagingextractor.py b/tests/test_volumetricimagingextractor.py index 62424dae..9119f0f8 100644 --- a/tests/test_volumetricimagingextractor.py +++ b/tests/test_volumetricimagingextractor.py @@ -119,3 +119,47 @@ def test_get_dtype(dtype): imaging_extractors = [generate_dummy_imaging_extractor(dtype=dtype)] volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) assert volumetric_imaging_extractor.get_dtype() == dtype + + +@pytest.mark.parametrize("start_plane, end_plane", [(None, None), (0, 1), (1, 2)]) +def test_depth_slice(volumetric_imaging_extractor, start_plane, end_plane): + start_plane = start_plane or 0 + end_plane = end_plane or volumetric_imaging_extractor.get_num_planes() + sliced_extractor = volumetric_imaging_extractor.depth_slice(start_plane=start_plane, end_plane=end_plane) + + assert sliced_extractor.get_num_planes() == end_plane - start_plane + assert sliced_extractor.get_image_size() == ( + *volumetric_imaging_extractor.get_image_size()[:2], + end_plane - start_plane, + ) + video = volumetric_imaging_extractor.get_video() + sliced_video = sliced_extractor.get_video() + assert np.all(video[..., start_plane:end_plane] == sliced_video) + frames = volumetric_imaging_extractor.get_frames(frame_idxs=[0, 1, 2]) + sliced_frames = sliced_extractor.get_frames(frame_idxs=[0, 1, 2]) + assert np.all(frames[..., start_plane:end_plane] == sliced_frames) + + +@pytest.mark.parametrize("start_plane, end_plane", [(0, -1), (1, 0), (0, 4)]) +def test_depth_slice_invalid(volumetric_imaging_extractor, start_plane, end_plane): + with pytest.raises(AssertionError): + volumetric_imaging_extractor.depth_slice(start_plane=start_plane, end_plane=end_plane) + + +def test_depth_slice_twice(volumetric_imaging_extractor): + sliced_extractor = volumetric_imaging_extractor.depth_slice(start_plane=0, end_plane=2) + twice_sliced_extractor = sliced_extractor.depth_slice(start_plane=0, end_plane=1) + + assert twice_sliced_extractor.get_num_planes() == 1 + assert twice_sliced_extractor.get_image_size() == (*volumetric_imaging_extractor.get_image_size()[:2], 1) + video = volumetric_imaging_extractor.get_video() + sliced_video = twice_sliced_extractor.get_video() + assert np.all(video[..., :1] == sliced_video) + frames = volumetric_imaging_extractor.get_frames(frame_idxs=[0, 1, 2]) + sliced_frames = twice_sliced_extractor.get_frames(frame_idxs=[0, 1, 2]) + assert np.all(frames[..., :1] == sliced_frames) + + +def test_frame_slice(volumetric_imaging_extractor): + with pytest.raises(NotImplementedError): + volumetric_imaging_extractor.frame_slice(start_frame=0, end_frame=1)