From 006b11a511859484382806aea801ee80eefaa604 Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Mon, 30 Sep 2024 10:00:28 -0700 Subject: [PATCH] refactored commonalites into basextractor --- src/roiextractors/baseextractor.py | 116 +++++++++++++++++++++ src/roiextractors/imagingextractor.py | 110 +------------------ src/roiextractors/segmentationextractor.py | 38 +------ 3 files changed, 122 insertions(+), 142 deletions(-) create mode 100644 src/roiextractors/baseextractor.py diff --git a/src/roiextractors/baseextractor.py b/src/roiextractors/baseextractor.py new file mode 100644 index 00000000..6a66f70b --- /dev/null +++ b/src/roiextractors/baseextractor.py @@ -0,0 +1,116 @@ +from abc import ABC, abstractmethod +from typing import Union, Tuple +from copy import deepcopy +import numpy as np +from .extraction_tools import ArrayType, FloatType + + +class BaseExtractor(ABC): + + def __init__(self): + self._times = None + + @abstractmethod + def get_image_size(self) -> Tuple[int, int]: + """Get the size of each image in the recording (num_rows, num_columns). + + Returns + ------- + image_size: tuple + Size of each image (num_rows, num_columns). + """ + pass + + @abstractmethod + def get_num_frames(self) -> int: + """Get the number of frames in the recording. + + Returns + ------- + num_frames: int + Number of frames in the recording. + """ + pass + + @abstractmethod + def get_sampling_frequency(self) -> float: + """Get the sampling frequency of the recording in Hz. + + Returns + ------- + sampling_frequency: float + Sampling frequency of the recording in Hz. + """ + pass + + def frame_to_time(self, frames: ArrayType) -> Union[FloatType, np.ndarray]: + """Convert user-inputted frame indices to times with units of seconds. + + Parameters + ---------- + frames: array-like + The frame or frames to be converted to times. + + Returns + ------- + times: float or array-like + The corresponding times in seconds. + """ + # Default implementation + frames = np.asarray(frames) + if self._times is None: + return frames / self.get_sampling_frequency() + else: + return self._times[frames] + + def time_to_frame(self, times: ArrayType) -> Union[FloatType, np.ndarray]: + """Convert a user-inputted times (in seconds) to a frame indices. + + Parameters + ---------- + times: array-like + The times (in seconds) to be converted to frame indices. + + Returns + ------- + frames: float or array-like + The corresponding frame indices. + """ + # Default implementation + times = np.asarray(times) + if self._times is None: + return np.round(times * self.get_sampling_frequency()).astype("int64") + else: + return np.searchsorted(self._times, times).astype("int64") + + def set_times(self, times: ArrayType) -> None: + """Set the recording times (in seconds) for each frame. + + Parameters + ---------- + times: array-like + The times in seconds for each frame + """ + assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!" + self._times = np.array(times).astype("float64") + + def has_time_vector(self) -> bool: + """Detect if the ImagingExtractor has a time vector set or not. + + Returns + ------- + has_times: bool + True if the ImagingExtractor has a time vector set, otherwise False. + """ + return self._times is not None + + def copy_times(self, extractor) -> None: + """Copy times from another extractor. + + Parameters + ---------- + extractor + The extractor from which the times will be copied. + """ + if extractor._times is not None: + self.set_times(deepcopy(extractor._times)) diff --git a/src/roiextractors/imagingextractor.py b/src/roiextractors/imagingextractor.py index 52207923..2bb2438c 100644 --- a/src/roiextractors/imagingextractor.py +++ b/src/roiextractors/imagingextractor.py @@ -8,16 +8,17 @@ Class to get a lazy frame slice. """ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Union, Optional, Tuple, get_args from copy import deepcopy import numpy as np +from .baseextractor import BaseExtractor from .extraction_tools import ArrayType, PathType, DtypeType, FloatType, IntType -class ImagingExtractor(ABC): +class ImagingExtractor(BaseExtractor): """Abstract class that contains all the meta-data and input data from the imaging data.""" def __init__(self, *args, **kwargs) -> None: @@ -26,39 +27,6 @@ def __init__(self, *args, **kwargs) -> None: self._kwargs = kwargs self._times = None - @abstractmethod - def get_image_size(self) -> Tuple[int, int]: - """Get the size of the video (num_rows, num_columns). - - Returns - ------- - image_size: tuple - Size of the video (num_rows, num_columns). - """ - pass - - @abstractmethod - def get_num_frames(self) -> int: - """Get the number of frames in the video. - - Returns - ------- - num_frames: int - Number of frames in the video. - """ - pass - - @abstractmethod - def get_sampling_frequency(self) -> float: - """Get the sampling frequency in Hz. - - Returns - ------- - sampling_frequency: float - Sampling frequency in Hz. - """ - pass - @abstractmethod def get_dtype(self) -> DtypeType: """Get the data type of the video. @@ -175,78 +143,6 @@ def _validate_get_frames_arguments(self, frame_idxs: ArrayType) -> Tuple[int, in return start_frame, end_frame - def frame_to_time(self, frames: ArrayType) -> Union[FloatType, np.ndarray]: - """Convert user-inputted frame indices to times with units of seconds. - - Parameters - ---------- - frames: array-like - The frame or frames to be converted to times. - - Returns - ------- - times: float or array-like - The corresponding times in seconds. - """ - # Default implementation - frames = np.asarray(frames) - if self._times is None: - return frames / self.get_sampling_frequency() - else: - return self._times[frames] - - def time_to_frame(self, times: ArrayType) -> Union[FloatType, np.ndarray]: - """Convert a user-inputted times (in seconds) to a frame indices. - - Parameters - ---------- - times: array-like - The times (in seconds) to be converted to frame indices. - - Returns - ------- - frames: float or array-like - The corresponding frame indices. - """ - # Default implementation - times = np.asarray(times) - if self._times is None: - return np.round(times * self.get_sampling_frequency()).astype("int64") - else: - return np.searchsorted(self._times, times).astype("int64") - - def set_times(self, times: ArrayType) -> None: - """Set the recording times (in seconds) for each frame. - - Parameters - ---------- - times: array-like - The times in seconds for each frame - """ - assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!" - self._times = np.array(times).astype("float64") - - def has_time_vector(self) -> bool: - """Detect if the ImagingExtractor has a time vector set or not. - - Returns - ------- - has_times: bool - True if the ImagingExtractor has a time vector set, otherwise False. - """ - return self._times is not None - - def copy_times(self, extractor) -> None: - """Copy times from another extractor. - - Parameters - ---------- - extractor - The extractor from which the times will be copied. - """ - if extractor._times is not None: - self.set_times(deepcopy(extractor._times)) - def __eq__(self, imaging_extractor2): image_size_equal = self.get_image_size() == imaging_extractor2.get_image_size() num_frames_equal = self.get_num_frames() == imaging_extractor2.get_num_frames() diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 8e567c5e..1bd694ed 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -10,17 +10,18 @@ Class to get a lazy frame slice. """ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Union, Optional, Tuple, Iterable, List, get_args import numpy as np from numpy.typing import ArrayLike +from .baseextractor import BaseExtractor from .extraction_tools import ArrayType, IntType, FloatType from .extraction_tools import _pixel_mask_extractor -class SegmentationExtractor(ABC): +class SegmentationExtractor(BaseExtractor): """Abstract segmentation extractor class. An abstract class that contains all the meta-data and output data from @@ -35,39 +36,6 @@ def __init__(self): """Create a new SegmentationExtractor for a specific data format (unique to each child SegmentationExtractor).""" self._times = None - @abstractmethod - def get_image_size(self) -> ArrayType: - """Get frame size of movie (height, width). - - Returns - ------- - no_rois: array_like - 2-D array: image height x image width - """ - pass - - @abstractmethod - def get_num_frames(self) -> int: - """Get the number of frames in the recording (duration of recording). - - Returns - ------- - num_frames: int - Number of frames in the recording. - """ - pass - - @abstractmethod - def get_sampling_frequency(self) -> float: - """Get the sampling frequency in Hz. - - Returns - ------- - sampling_frequency: float - Sampling frequency of the recording in Hz. - """ - pass - @abstractmethod def get_roi_ids(self) -> list: """Get the list of ROI ids.