diff --git a/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py b/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py index d56fa5e3..6f13d1da 100644 --- a/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py +++ b/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py @@ -122,25 +122,22 @@ def __del__(self): self._file.close() def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0): - # Fancy indexing is non performant for h5.py with long frame lists - if frame_idxs is not None: - slice_start = np.min(frame_idxs) - slice_stop = min(np.max(frame_idxs) + 1, self.get_num_frames()) - else: - slice_start = 0 - slice_stop = self.get_num_frames() - - frames = self._video.lazy_slice[slice_start:slice_stop, :, :, channel].dsetread() + squeeze_data = False if isinstance(frame_idxs, int): + squeeze_data = True + frame_idxs = [frame_idxs] + elif isinstance(frame_idxs, np.ndarray): + frame_idxs = frame_idxs.tolist() + frames = self._video.lazy_slice[frame_idxs, :, :, channel].dsetread() + if squeeze_data: frames = frames.squeeze() - return frames def get_video(self, start_frame=None, end_frame=None, channel: Optional[int] = 0) -> np.ndarray: return self._video.lazy_slice[start_frame:end_frame, :, :, channel].dsetread() def get_image_size(self) -> Tuple[int, int]: - return (self._num_rows, self._num_cols) + return self._num_rows, self._num_cols def get_num_frames(self): return self._num_frames diff --git a/src/roiextractors/extractors/nwbextractors/nwbextractors.py b/src/roiextractors/extractors/nwbextractors/nwbextractors.py index daf271a5..8f73285f 100644 --- a/src/roiextractors/extractors/nwbextractors/nwbextractors.py +++ b/src/roiextractors/extractors/nwbextractors/nwbextractors.py @@ -181,18 +181,14 @@ def make_nwb_metadata( ) def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0): - # Fancy indexing is non performant for h5.py with long frame lists - if frame_idxs is not None: - slice_start = np.min(frame_idxs) - slice_stop = min(np.max(frame_idxs) + 1, self.get_num_frames()) - else: - slice_start = 0 - slice_stop = self.get_num_frames() - - data = self.photon_series.data - frames = data[slice_start:slice_stop, ...].transpose([0, 2, 1]) - + squeeze_data = False if isinstance(frame_idxs, int): + squeeze_data = True + frame_idxs = [frame_idxs] + elif isinstance(frame_idxs, np.ndarray): + frame_idxs = frame_idxs.tolist() + frames = self.photon_series.data[frame_idxs].transpose([0, 2, 1]) + if squeeze_data: frames = frames.squeeze() return frames diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index 941881a4..d975e30d 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -406,6 +406,11 @@ def assert_get_frames_return_shape(imaging_extractor: ImagingExtractor): assert_msg = "get_frames does not work correctly with frame_idxs=np.arrray([0, 1])" assert frames_with_array.shape == (2, image_size[0], image_size[1]), assert_msg + frame_idxs = [0, 2] + frames_with_array = imaging_extractor.get_frames(frame_idxs=frame_idxs, channel=0) + assert_msg = "get_frames does not work correctly with frame_idxs=[0, 2]" + assert frames_with_array.shape == (2, image_size[0], image_size[1]), assert_msg + def check_imaging_return_types(img_ex: ImagingExtractor): """Check that the return types of the imaging extractor are correct."""