From 30279d2a75b3b880969f3caff5b7b7db91c0b27e Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Tue, 24 Sep 2024 14:50:09 -0700 Subject: [PATCH] refactored input validation for get_video and get_frames into their own private methods --- .../numpyextractors/numpyextractors.py | 12 +--- src/roiextractors/imagingextractor.py | 71 +++++++++++++++---- 2 files changed, 57 insertions(+), 26 deletions(-) diff --git a/src/roiextractors/extractors/numpyextractors/numpyextractors.py b/src/roiextractors/extractors/numpyextractors/numpyextractors.py index be20a0b0..2b738995 100644 --- a/src/roiextractors/extractors/numpyextractors/numpyextractors.py +++ b/src/roiextractors/extractors/numpyextractors/numpyextractors.py @@ -65,17 +65,7 @@ def __init__(self, timeseries: Union[PathType, np.ndarray], sampling_frequency: self._dtype = self._video.dtype def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: - num_frames = self.get_num_frames() - start_frame = start_frame if start_frame is not None else 0 - end_frame = end_frame if end_frame is not None else num_frames - assert 0 <= start_frame < num_frames, f"'start_frame' must be in [0, {num_frames}) but got {start_frame}" - assert 0 < end_frame <= num_frames, f"'end_frame' must be in (0, {num_frames}] but got {end_frame}" - assert ( - start_frame <= end_frame - ), f"'start_frame' ({start_frame}) must be less than or equal to 'end_frame' ({end_frame})" - assert isinstance(start_frame, IntType), "'start_frame' must be an integer" - assert isinstance(end_frame, IntType), "'end_frame' must be an integer" - + start_frame, end_frame = self._validate_get_video_arguments(start_frame=start_frame, end_frame=end_frame) return self._video[start_frame:end_frame, ...] def get_image_size(self) -> Tuple[int, int]: diff --git a/src/roiextractors/imagingextractor.py b/src/roiextractors/imagingextractor.py index 5aadf571..52207923 100644 --- a/src/roiextractors/imagingextractor.py +++ b/src/roiextractors/imagingextractor.py @@ -77,9 +77,9 @@ def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] Parameters ---------- start_frame: int, optional - Start frame index (inclusive). + Start frame index (inclusive). By default, it is set to 0. end_frame: int, optional - End frame index (exclusive). + End frame index (exclusive). By default, it is set to the number of frames. Returns ------- @@ -102,6 +102,38 @@ def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] """ pass + def _validate_get_video_arguments( + self, start_frame: Optional[int] = None, end_frame: Optional[int] = None + ) -> Tuple[int, int]: + """Validate the start_frame and end_frame arguments for the get_video method. + + Parameters + ---------- + start_frame: int, optional + Start frame index (inclusive). By default, it is set to 0. + end_frame: int, optional + End frame index (exclusive). By default, it is set to the number of frames. + + Returns + ------- + start_frame: int + Start frame index (inclusive). + end_frame: int + End frame index (exclusive). + """ + num_frames = self.get_num_frames() + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else num_frames + assert 0 <= start_frame < num_frames, f"'start_frame' must be in [0, {num_frames}) but got {start_frame}" + assert 0 < end_frame <= num_frames, f"'end_frame' must be in (0, {num_frames}] but got {end_frame}" + assert ( + start_frame <= end_frame + ), f"'start_frame' ({start_frame}) must be less than or equal to 'end_frame' ({end_frame})" + # python 3.9 doesn't support get_instance on a Union of types, so we use get_args + assert isinstance(start_frame, get_args(IntType)), "'start_frame' must be an integer" + assert isinstance(end_frame, get_args(IntType)), "'end_frame' must be an integer" + return start_frame, end_frame + def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: """Get specific video frames from indices (not necessarily continuous). @@ -115,14 +147,33 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: frames: numpy.ndarray The video frames. """ + start_frame, end_frame = self._validate_get_frames_arguments(frame_idxs=frame_idxs) + relative_indices = np.array(frame_idxs) - start_frame + return self.get_video(start_frame=start_frame, end_frame=end_frame)[relative_indices, ...] + + def _validate_get_frames_arguments(self, frame_idxs: ArrayType) -> Tuple[int, int]: + """Validate the frame_idxs argument for the get_frames method. + + Parameters + ---------- + frame_idxs: array-like + Indices of frames to return. + + Returns + ------- + start_frame: int + Start frame index (inclusive). + end_frame: int + End frame index (exclusive). + """ start_frame = min(frame_idxs) end_frame = max(frame_idxs) + 1 assert start_frame >= 0, f"All 'frame_idxs' must be greater than or equal to zero but received {start_frame}." assert ( end_frame <= self.get_num_frames() ), f"All 'frame_idxs' must be less than the number of frames ({self.get_num_frames()}) but received {end_frame}." - relative_indices = np.array(frame_idxs) - start_frame - return self.get_video(start_frame=start_frame, end_frame=end_frame)[relative_indices, ...] + + return start_frame, end_frame def frame_to_time(self, frames: ArrayType) -> Union[FloatType, np.ndarray]: """Convert user-inputted frame indices to times with units of seconds. @@ -298,17 +349,7 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: return self._parent_imaging.get_frames(frame_idxs=mapped_frame_idxs) def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: - num_frames = self.get_num_frames() - start_frame = start_frame if start_frame is not None else 0 - end_frame = end_frame if end_frame is not None else num_frames - assert 0 <= start_frame < num_frames, f"'start_frame' must be in [0, {num_frames}) but got {start_frame}" - assert 0 < end_frame <= num_frames, f"'end_frame' must be in (0, {num_frames}] but got {end_frame}" - assert ( - start_frame <= end_frame - ), f"'start_frame' ({start_frame}) must be less than or equal to 'end_frame' ({end_frame})" - # python 3.9 doesn't support get_instance on a Union of types, so we use get_args - assert isinstance(start_frame, get_args(IntType)), "'start_frame' must be an integer" - assert isinstance(end_frame, get_args(IntType)), "'end_frame' must be an integer" + start_frame, end_frame = self._validate_get_video_arguments(start_frame=start_frame, end_frame=end_frame) start_frame_shifted = start_frame + self._start_frame end_frame_shifted = end_frame + self._start_frame