Skip to content

Commit

Permalink
refactored commonalites into basextractor
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Sep 30, 2024
1 parent 2dd85af commit 006b11a
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 142 deletions.
116 changes: 116 additions & 0 deletions src/roiextractors/baseextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from abc import ABC, abstractmethod
from typing import Union, Tuple
from copy import deepcopy
import numpy as np
from .extraction_tools import ArrayType, FloatType


class BaseExtractor(ABC):

def __init__(self):
self._times = None

@abstractmethod
def get_image_size(self) -> Tuple[int, int]:
"""Get the size of each image in the recording (num_rows, num_columns).
Returns
-------
image_size: tuple
Size of each image (num_rows, num_columns).
"""
pass

@abstractmethod
def get_num_frames(self) -> int:
"""Get the number of frames in the recording.
Returns
-------
num_frames: int
Number of frames in the recording.
"""
pass

@abstractmethod
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency of the recording in Hz.
Returns
-------
sampling_frequency: float
Sampling frequency of the recording in Hz.
"""
pass

def frame_to_time(self, frames: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert user-inputted frame indices to times with units of seconds.
Parameters
----------
frames: array-like
The frame or frames to be converted to times.
Returns
-------
times: float or array-like
The corresponding times in seconds.
"""
# Default implementation
frames = np.asarray(frames)
if self._times is None:
return frames / self.get_sampling_frequency()
else:
return self._times[frames]

def time_to_frame(self, times: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert a user-inputted times (in seconds) to a frame indices.
Parameters
----------
times: array-like
The times (in seconds) to be converted to frame indices.
Returns
-------
frames: float or array-like
The corresponding frame indices.
"""
# Default implementation
times = np.asarray(times)
if self._times is None:
return np.round(times * self.get_sampling_frequency()).astype("int64")
else:
return np.searchsorted(self._times, times).astype("int64")

def set_times(self, times: ArrayType) -> None:
"""Set the recording times (in seconds) for each frame.
Parameters
----------
times: array-like
The times in seconds for each frame
"""
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times).astype("float64")

def has_time_vector(self) -> bool:
"""Detect if the ImagingExtractor has a time vector set or not.
Returns
-------
has_times: bool
True if the ImagingExtractor has a time vector set, otherwise False.
"""
return self._times is not None

def copy_times(self, extractor) -> None:
"""Copy times from another extractor.
Parameters
----------
extractor
The extractor from which the times will be copied.
"""
if extractor._times is not None:
self.set_times(deepcopy(extractor._times))
110 changes: 3 additions & 107 deletions src/roiextractors/imagingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
Class to get a lazy frame slice.
"""

from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import Union, Optional, Tuple, get_args
from copy import deepcopy

import numpy as np

from .baseextractor import BaseExtractor
from .extraction_tools import ArrayType, PathType, DtypeType, FloatType, IntType


class ImagingExtractor(ABC):
class ImagingExtractor(BaseExtractor):
"""Abstract class that contains all the meta-data and input data from the imaging data."""

def __init__(self, *args, **kwargs) -> None:
Expand All @@ -26,39 +27,6 @@ def __init__(self, *args, **kwargs) -> None:
self._kwargs = kwargs
self._times = None

@abstractmethod
def get_image_size(self) -> Tuple[int, int]:
"""Get the size of the video (num_rows, num_columns).
Returns
-------
image_size: tuple
Size of the video (num_rows, num_columns).
"""
pass

@abstractmethod
def get_num_frames(self) -> int:
"""Get the number of frames in the video.
Returns
-------
num_frames: int
Number of frames in the video.
"""
pass

@abstractmethod
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency in Hz.
Returns
-------
sampling_frequency: float
Sampling frequency in Hz.
"""
pass

@abstractmethod
def get_dtype(self) -> DtypeType:
"""Get the data type of the video.
Expand Down Expand Up @@ -175,78 +143,6 @@ def _validate_get_frames_arguments(self, frame_idxs: ArrayType) -> Tuple[int, in

return start_frame, end_frame

def frame_to_time(self, frames: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert user-inputted frame indices to times with units of seconds.
Parameters
----------
frames: array-like
The frame or frames to be converted to times.
Returns
-------
times: float or array-like
The corresponding times in seconds.
"""
# Default implementation
frames = np.asarray(frames)
if self._times is None:
return frames / self.get_sampling_frequency()
else:
return self._times[frames]

def time_to_frame(self, times: ArrayType) -> Union[FloatType, np.ndarray]:
"""Convert a user-inputted times (in seconds) to a frame indices.
Parameters
----------
times: array-like
The times (in seconds) to be converted to frame indices.
Returns
-------
frames: float or array-like
The corresponding frame indices.
"""
# Default implementation
times = np.asarray(times)
if self._times is None:
return np.round(times * self.get_sampling_frequency()).astype("int64")
else:
return np.searchsorted(self._times, times).astype("int64")

def set_times(self, times: ArrayType) -> None:
"""Set the recording times (in seconds) for each frame.
Parameters
----------
times: array-like
The times in seconds for each frame
"""
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times).astype("float64")

def has_time_vector(self) -> bool:
"""Detect if the ImagingExtractor has a time vector set or not.
Returns
-------
has_times: bool
True if the ImagingExtractor has a time vector set, otherwise False.
"""
return self._times is not None

def copy_times(self, extractor) -> None:
"""Copy times from another extractor.
Parameters
----------
extractor
The extractor from which the times will be copied.
"""
if extractor._times is not None:
self.set_times(deepcopy(extractor._times))

def __eq__(self, imaging_extractor2):
image_size_equal = self.get_image_size() == imaging_extractor2.get_image_size()
num_frames_equal = self.get_num_frames() == imaging_extractor2.get_num_frames()
Expand Down
38 changes: 3 additions & 35 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
Class to get a lazy frame slice.
"""

from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import Union, Optional, Tuple, Iterable, List, get_args

import numpy as np
from numpy.typing import ArrayLike

from .baseextractor import BaseExtractor
from .extraction_tools import ArrayType, IntType, FloatType
from .extraction_tools import _pixel_mask_extractor


class SegmentationExtractor(ABC):
class SegmentationExtractor(BaseExtractor):
"""Abstract segmentation extractor class.
An abstract class that contains all the meta-data and output data from
Expand All @@ -35,39 +36,6 @@ def __init__(self):
"""Create a new SegmentationExtractor for a specific data format (unique to each child SegmentationExtractor)."""
self._times = None

@abstractmethod
def get_image_size(self) -> ArrayType:
"""Get frame size of movie (height, width).
Returns
-------
no_rois: array_like
2-D array: image height x image width
"""
pass

@abstractmethod
def get_num_frames(self) -> int:
"""Get the number of frames in the recording (duration of recording).
Returns
-------
num_frames: int
Number of frames in the recording.
"""
pass

@abstractmethod
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency in Hz.
Returns
-------
sampling_frequency: float
Sampling frequency of the recording in Hz.
"""
pass

@abstractmethod
def get_roi_ids(self) -> list:
"""Get the list of ROI ids.
Expand Down

0 comments on commit 006b11a

Please sign in to comment.