Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VolumetricImagingExtractor #248

Merged
merged 13 commits into from
Oct 25, 2023
148 changes: 148 additions & 0 deletions src/roiextractors/volumetricimagingextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import Tuple, List, Iterable, Optional
import numpy as np

from .extraction_tools import ArrayType, DtypeType
from .imagingextractor import ImagingExtractor


class VolumetricImagingExtractor(ImagingExtractor):
"""Class to combine multiple ImagingExtractor objects by depth plane."""

extractor_name = "VolumetricImaging"
installed = True
installatiuon_mesage = ""

def __init__(self, imaging_extractors: List[ImagingExtractor]):
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors.

Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
"""
super().__init__()
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors)
self._check_consistency_between_imaging_extractors(imaging_extractors)
self._imaging_extractors = imaging_extractors
self._num_planes = len(imaging_extractors)

def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]):
"""Check that essential properties are consistent between extractors so that they can be combined appropriately.

Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects

Raises
------
AssertionError
If any of the properties are not consistent between extractors.

Notes
-----
This method checks the following properties:
- sampling frequency
- image size
- number of channels
- channel names
- data type
- num_frames
"""
properties_to_check = dict(
get_sampling_frequency="The sampling frequency",
get_image_size="The size of a frame",
get_num_channels="The number of channels",
get_channel_names="The name of the channels",
get_dtype="The data type",
get_num_frames="The number of frames",
)
for method, property_message in properties_to_check.items():
values = [getattr(extractor, method)() for extractor in imaging_extractors]
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values)
assert (
len(unique_values) == 1
), f"{property_message} is not consistent over the files (found {unique_values})."

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
"""Get the video frames.

Parameters
----------
start_frame: int, optional
Start frame index (inclusive).
end_frame: int, optional
End frame index (exclusive).

Returns
-------
video: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
pauladkisson marked this conversation as resolved.
Show resolved Hide resolved
"""
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else self.get_num_frames()

video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
video[..., i] = imaging_extractor.get_video(start_frame, end_frame)
return video

def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
"""Get specific video frames from indices (not necessarily continuous).

Parameters
----------
frame_idxs: array-like
Indices of frames to return.

Returns
-------
frames: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]

if not all(np.diff(frame_idxs) == 1):
frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
frames[..., i] = imaging_extractor.get_frames(frame_idxs)
else:
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)

def get_image_size(self) -> Tuple:
"""Get the size of a single frame.

Returns
-------
image_size: tuple
The size of a single frame (num_rows, num_columns, num_planes).
"""
image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes())
return image_size

def get_num_planes(self) -> int:
"""Get the number of depth planes.

Returns
-------
_num_planes: int
The number of depth planes.
"""
return self._num_planes

def get_num_frames(self) -> int:
return self._imaging_extractors[0].get_num_frames()

def get_sampling_frequency(self) -> float:
return self._imaging_extractors[0].get_sampling_frequency()

def get_channel_names(self) -> list:
return self._imaging_extractors[0].get_channel_names()

def get_num_channels(self) -> int:
return self._imaging_extractors[0].get_num_channels()

def get_dtype(self) -> DtypeType:
return self._imaging_extractors[0].get_dtype()
Loading