Skip to content

Commit

Permalink
Fix HDF5 and NWB get_frames behavior (#174)
Browse files Browse the repository at this point in the history
* for HDF5:
* fix get_frames. It was previously assuming contiguous frames
* improve get_video. It was previously reading frame-by-frame, which is not optimal for HDF5

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added integer handling to default get_frames method

* added squeeze logic to get_frames

* specify axis for squeeze

* inverted if so that squeeze occurs in the int section

* removed np.newaxis so that bruker tests expect squeezing behavior on single frame

* removed np.newaxis so that micromanager tests expect squeezing behavior on single frame

* added squeeze logic to multiimaging
 extractor

* reverted imagingextractor changes

* reverted multiimagingextractor and testing.py changes

* reverted test_brukertiffimagingextractor changes

* reverted test_micromanagertiffimagingextractor changes

* reverted test_brukertiffimagingextractor changes

* added get_frames to hdf5imagingextractor

* added test for non-cts frames

* fixed nwb get_frames

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Heberto Mayorquin <[email protected]>
Co-authored-by: Paul Adkisson <[email protected]>
  • Loading branch information
4 people authored Feb 1, 2024
1 parent f49355c commit 4c75116
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,25 +122,22 @@ def __del__(self):
self._file.close()

def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0):
# Fancy indexing is non performant for h5.py with long frame lists
if frame_idxs is not None:
slice_start = np.min(frame_idxs)
slice_stop = min(np.max(frame_idxs) + 1, self.get_num_frames())
else:
slice_start = 0
slice_stop = self.get_num_frames()

frames = self._video.lazy_slice[slice_start:slice_stop, :, :, channel].dsetread()
squeeze_data = False
if isinstance(frame_idxs, int):
squeeze_data = True
frame_idxs = [frame_idxs]
elif isinstance(frame_idxs, np.ndarray):
frame_idxs = frame_idxs.tolist()
frames = self._video.lazy_slice[frame_idxs, :, :, channel].dsetread()
if squeeze_data:
frames = frames.squeeze()

return frames

def get_video(self, start_frame=None, end_frame=None, channel: Optional[int] = 0) -> np.ndarray:
return self._video.lazy_slice[start_frame:end_frame, :, :, channel].dsetread()

def get_image_size(self) -> Tuple[int, int]:
return (self._num_rows, self._num_cols)
return self._num_rows, self._num_cols

def get_num_frames(self):
return self._num_frames
Expand Down
18 changes: 7 additions & 11 deletions src/roiextractors/extractors/nwbextractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,14 @@ def make_nwb_metadata(
)

def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0):
# Fancy indexing is non performant for h5.py with long frame lists
if frame_idxs is not None:
slice_start = np.min(frame_idxs)
slice_stop = min(np.max(frame_idxs) + 1, self.get_num_frames())
else:
slice_start = 0
slice_stop = self.get_num_frames()

data = self.photon_series.data
frames = data[slice_start:slice_stop, ...].transpose([0, 2, 1])

squeeze_data = False
if isinstance(frame_idxs, int):
squeeze_data = True
frame_idxs = [frame_idxs]
elif isinstance(frame_idxs, np.ndarray):
frame_idxs = frame_idxs.tolist()
frames = self.photon_series.data[frame_idxs].transpose([0, 2, 1])
if squeeze_data:
frames = frames.squeeze()
return frames

Expand Down
5 changes: 5 additions & 0 deletions src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ def assert_get_frames_return_shape(imaging_extractor: ImagingExtractor):
assert_msg = "get_frames does not work correctly with frame_idxs=np.arrray([0, 1])"
assert frames_with_array.shape == (2, image_size[0], image_size[1]), assert_msg

frame_idxs = [0, 2]
frames_with_array = imaging_extractor.get_frames(frame_idxs=frame_idxs, channel=0)
assert_msg = "get_frames does not work correctly with frame_idxs=[0, 2]"
assert frames_with_array.shape == (2, image_size[0], image_size[1]), assert_msg


def check_imaging_return_types(img_ex: ImagingExtractor):
"""Check that the return types of the imaging extractor are correct."""
Expand Down

0 comments on commit 4c75116

Please sign in to comment.