Skip to content

Commit

Permalink
refactored input validation for get_video and get_frames into their o…
Browse files Browse the repository at this point in the history
…wn private methods
  • Loading branch information
pauladkisson committed Sep 24, 2024
1 parent ae8d6e6 commit 30279d2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
12 changes: 1 addition & 11 deletions src/roiextractors/extractors/numpyextractors/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,7 @@ def __init__(self, timeseries: Union[PathType, np.ndarray], sampling_frequency:
self._dtype = self._video.dtype

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
num_frames = self.get_num_frames()
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else num_frames
assert 0 <= start_frame < num_frames, f"'start_frame' must be in [0, {num_frames}) but got {start_frame}"
assert 0 < end_frame <= num_frames, f"'end_frame' must be in (0, {num_frames}] but got {end_frame}"
assert (
start_frame <= end_frame
), f"'start_frame' ({start_frame}) must be less than or equal to 'end_frame' ({end_frame})"
assert isinstance(start_frame, IntType), "'start_frame' must be an integer"
assert isinstance(end_frame, IntType), "'end_frame' must be an integer"

start_frame, end_frame = self._validate_get_video_arguments(start_frame=start_frame, end_frame=end_frame)
return self._video[start_frame:end_frame, ...]

def get_image_size(self) -> Tuple[int, int]:
Expand Down
71 changes: 56 additions & 15 deletions src/roiextractors/imagingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int]
Parameters
----------
start_frame: int, optional
Start frame index (inclusive).
Start frame index (inclusive). By default, it is set to 0.
end_frame: int, optional
End frame index (exclusive).
End frame index (exclusive). By default, it is set to the number of frames.
Returns
-------
Expand All @@ -102,6 +102,38 @@ def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int]
"""
pass

def _validate_get_video_arguments(
self, start_frame: Optional[int] = None, end_frame: Optional[int] = None
) -> Tuple[int, int]:
"""Validate the start_frame and end_frame arguments for the get_video method.
Parameters
----------
start_frame: int, optional
Start frame index (inclusive). By default, it is set to 0.
end_frame: int, optional
End frame index (exclusive). By default, it is set to the number of frames.
Returns
-------
start_frame: int
Start frame index (inclusive).
end_frame: int
End frame index (exclusive).
"""
num_frames = self.get_num_frames()
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else num_frames
assert 0 <= start_frame < num_frames, f"'start_frame' must be in [0, {num_frames}) but got {start_frame}"
assert 0 < end_frame <= num_frames, f"'end_frame' must be in (0, {num_frames}] but got {end_frame}"
assert (
start_frame <= end_frame
), f"'start_frame' ({start_frame}) must be less than or equal to 'end_frame' ({end_frame})"
# python 3.9 doesn't support get_instance on a Union of types, so we use get_args
assert isinstance(start_frame, get_args(IntType)), "'start_frame' must be an integer"
assert isinstance(end_frame, get_args(IntType)), "'end_frame' must be an integer"
return start_frame, end_frame

def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
"""Get specific video frames from indices (not necessarily continuous).
Expand All @@ -115,14 +147,33 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
frames: numpy.ndarray
The video frames.
"""
start_frame, end_frame = self._validate_get_frames_arguments(frame_idxs=frame_idxs)
relative_indices = np.array(frame_idxs) - start_frame
return self.get_video(start_frame=start_frame, end_frame=end_frame)[relative_indices, ...]

def _validate_get_frames_arguments(self, frame_idxs: ArrayType) -> Tuple[int, int]:
"""Validate the frame_idxs argument for the get_frames method.
Parameters
----------
frame_idxs: array-like
Indices of frames to return.
Returns
-------
start_frame: int
Start frame index (inclusive).
end_frame: int
End frame index (exclusive).
"""
start_frame = min(frame_idxs)
end_frame = max(frame_idxs) + 1
assert start_frame >= 0, f"All 'frame_idxs' must be greater than or equal to zero but received {start_frame}."
assert (
end_frame <= self.get_num_frames()
), f"All 'frame_idxs' must be less than the number of frames ({self.get_num_frames()}) but received {end_frame}."
relative_indices = np.array(frame_idxs) - start_frame
return self.get_video(start_frame=start_frame, end_frame=end_frame)[relative_indices, ...]

return start_frame, end_frame

def frame_to_time(self, frames: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert user-inputted frame indices to times with units of seconds.
Expand Down Expand Up @@ -298,17 +349,7 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
return self._parent_imaging.get_frames(frame_idxs=mapped_frame_idxs)

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
num_frames = self.get_num_frames()
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else num_frames
assert 0 <= start_frame < num_frames, f"'start_frame' must be in [0, {num_frames}) but got {start_frame}"
assert 0 < end_frame <= num_frames, f"'end_frame' must be in (0, {num_frames}] but got {end_frame}"
assert (
start_frame <= end_frame
), f"'start_frame' ({start_frame}) must be less than or equal to 'end_frame' ({end_frame})"
# python 3.9 doesn't support get_instance on a Union of types, so we use get_args
assert isinstance(start_frame, get_args(IntType)), "'start_frame' must be an integer"
assert isinstance(end_frame, get_args(IntType)), "'end_frame' must be an integer"
start_frame, end_frame = self._validate_get_video_arguments(start_frame=start_frame, end_frame=end_frame)

start_frame_shifted = start_frame + self._start_frame
end_frame_shifted = end_frame + self._start_frame
Expand Down

0 comments on commit 30279d2

Please sign in to comment.