From 4c75116f3bff79a08f7419fb42a61b2ceb49856f Mon Sep 17 00:00:00 2001 From: Ben Dichter Date: Thu, 1 Feb 2024 15:52:54 -0500 Subject: [PATCH] Fix HDF5 and NWB get_frames behavior (#174) * for HDF5: * fix get_frames. It was previously assuming contiguous frames * improve get_video. It was previously reading frame-by-frame, which is not optimal for HDF5 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added integer handling to default get_frames method * added squeeze logic to get_frames * specify axis for squeeze * inverted if so that squeeze occurs in the int section * removed np.newaxis so that bruker tests expect squeezing behavior on single frame * removed np.newaxis so that micromanager tests expect squeezing behavior on single frame * added squeeze logic to multiimaging extractor * reverted imagingextractor changes * reverted multiimagingextractor and testing.py changes * reverted test_brukertiffimagingextractor changes * reverted test_micromanagertiffimagingextractor changes * reverted test_brukertiffimagingextractor changes * added get_frames to hdf5imagingextractor * added test for non-cts frames * fixed nwb get_frames --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Heberto Mayorquin Co-authored-by: Paul Adkisson --- .../hdf5imagingextractor.py | 19 ++++++++----------- .../extractors/nwbextractors/nwbextractors.py | 18 +++++++----------- src/roiextractors/testing.py | 5 +++++ 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py b/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py index d56fa5e3..6f13d1da 100644 --- a/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py +++ b/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py @@ -122,25 +122,22 @@ def __del__(self): self._file.close() def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0): - # Fancy indexing is non performant for h5.py with long frame lists - if frame_idxs is not None: - slice_start = np.min(frame_idxs) - slice_stop = min(np.max(frame_idxs) + 1, self.get_num_frames()) - else: - slice_start = 0 - slice_stop = self.get_num_frames() - - frames = self._video.lazy_slice[slice_start:slice_stop, :, :, channel].dsetread() + squeeze_data = False if isinstance(frame_idxs, int): + squeeze_data = True + frame_idxs = [frame_idxs] + elif isinstance(frame_idxs, np.ndarray): + frame_idxs = frame_idxs.tolist() + frames = self._video.lazy_slice[frame_idxs, :, :, channel].dsetread() + if squeeze_data: frames = frames.squeeze() - return frames def get_video(self, start_frame=None, end_frame=None, channel: Optional[int] = 0) -> np.ndarray: return self._video.lazy_slice[start_frame:end_frame, :, :, channel].dsetread() def get_image_size(self) -> Tuple[int, int]: - return (self._num_rows, self._num_cols) + return self._num_rows, self._num_cols def get_num_frames(self): return self._num_frames diff --git a/src/roiextractors/extractors/nwbextractors/nwbextractors.py b/src/roiextractors/extractors/nwbextractors/nwbextractors.py index daf271a5..8f73285f 100644 --- a/src/roiextractors/extractors/nwbextractors/nwbextractors.py +++ b/src/roiextractors/extractors/nwbextractors/nwbextractors.py @@ -181,18 +181,14 @@ def make_nwb_metadata( ) def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0): - # Fancy indexing is non performant for h5.py with long frame lists - if frame_idxs is not None: - slice_start = np.min(frame_idxs) - slice_stop = min(np.max(frame_idxs) + 1, self.get_num_frames()) - else: - slice_start = 0 - slice_stop = self.get_num_frames() - - data = self.photon_series.data - frames = data[slice_start:slice_stop, ...].transpose([0, 2, 1]) - + squeeze_data = False if isinstance(frame_idxs, int): + squeeze_data = True + frame_idxs = [frame_idxs] + elif isinstance(frame_idxs, np.ndarray): + frame_idxs = frame_idxs.tolist() + frames = self.photon_series.data[frame_idxs].transpose([0, 2, 1]) + if squeeze_data: frames = frames.squeeze() return frames diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index 941881a4..d975e30d 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -406,6 +406,11 @@ def assert_get_frames_return_shape(imaging_extractor: ImagingExtractor): assert_msg = "get_frames does not work correctly with frame_idxs=np.arrray([0, 1])" assert frames_with_array.shape == (2, image_size[0], image_size[1]), assert_msg + frame_idxs = [0, 2] + frames_with_array = imaging_extractor.get_frames(frame_idxs=frame_idxs, channel=0) + assert_msg = "get_frames does not work correctly with frame_idxs=[0, 2]" + assert frames_with_array.shape == (2, image_size[0], image_size[1]), assert_msg + def check_imaging_return_types(img_ex: ImagingExtractor): """Check that the return types of the imaging extractor are correct."""