Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VolumetricImagingExtractor #248

Merged
merged 13 commits into from
Oct 25, 2023
Merged
2 changes: 2 additions & 0 deletions src/roiextractors/extractorlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .extractors.miniscopeimagingextractor import MiniscopeImagingExtractor
from .multisegmentationextractor import MultiSegmentationExtractor
from .multiimagingextractor import MultiImagingExtractor
from .volumetricimagingextractor import VolumetricImagingExtractor

imaging_extractor_full_list = [
NumpyImagingExtractor,
Expand All @@ -39,6 +40,7 @@
SbxImagingExtractor,
NumpyMemmapImagingExtractor,
MemmapImagingExtractor,
VolumetricImagingExtractor,
]

segmentation_extractor_full_list = [
Expand Down
4 changes: 3 additions & 1 deletion src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def generate_dummy_imaging_extractor(
num_channels: int = 1,
sampling_frequency: float = 30,
dtype: DtypeType = "uint16",
channel_names: Optional[list] = None,
):
"""Generate a dummy imaging extractor for testing.

Expand All @@ -78,7 +79,8 @@ def generate_dummy_imaging_extractor(
ImagingExtractor
An imaging extractor with random data fed into `NumpyImagingExtractor`.
"""
channel_names = [f"channel_num_{num}" for num in range(num_channels)]
if channel_names is None:
channel_names = [f"channel_num_{num}" for num in range(num_channels)]

size = (num_frames, num_rows, num_columns, num_channels)
video = generate_dummy_video(size=size, dtype=dtype)
Expand Down
169 changes: 169 additions & 0 deletions src/roiextractors/volumetricimagingextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""Base class definition for volumetric imaging extractors."""

from typing import Tuple, List, Iterable, Optional
import numpy as np

from .extraction_tools import ArrayType, DtypeType
from .imagingextractor import ImagingExtractor


class VolumetricImagingExtractor(ImagingExtractor):
"""Class to combine multiple ImagingExtractor objects by depth plane."""

extractor_name = "VolumetricImaging"
installed = True
installatiuon_mesage = ""

def __init__(self, imaging_extractors: List[ImagingExtractor]):
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors.

Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
"""
super().__init__()
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors)
self._check_consistency_between_imaging_extractors(imaging_extractors)
self._imaging_extractors = imaging_extractors
self._num_planes = len(imaging_extractors)

def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]):
"""Check that essential properties are consistent between extractors so that they can be combined appropriately.

Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects

Raises
------
AssertionError
If any of the properties are not consistent between extractors.

Notes
-----
This method checks the following properties:
- sampling frequency
- image size
- number of channels
- channel names
- data type
- num_frames
"""
properties_to_check = dict(
get_sampling_frequency="The sampling frequency",
get_image_size="The size of a frame",
get_num_channels="The number of channels",
get_channel_names="The name of the channels",
get_dtype="The data type",
get_num_frames="The number of frames",
)
for method, property_message in properties_to_check.items():
values = [getattr(extractor, method)() for extractor in imaging_extractors]
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values)
assert (
len(unique_values) == 1
), f"{property_message} is not consistent over the files (found {unique_values})."

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
"""Get the video frames.

Parameters
----------
start_frame: int, optional
Start frame index (inclusive).
end_frame: int, optional
End frame index (exclusive).

Returns
-------
video: numpy.ndarray
The 3D video frames (num_frames, num_rows, num_columns, num_planes).
"""
if start_frame is None:
start_frame = 0
elif start_frame < 0:
start_frame = self.get_num_frames() + start_frame
elif start_frame >= self.get_num_frames():
raise ValueError(
f"start_frame {start_frame} is greater than or equal to the number of frames {self.get_num_frames()}"
)
if end_frame is None:
end_frame = self.get_num_frames()
elif end_frame < 0:
end_frame = self.get_num_frames() + end_frame
elif end_frame > self.get_num_frames():
raise ValueError(f"end_frame {end_frame} is greater than the number of frames {self.get_num_frames()}")
if end_frame <= start_frame:
raise ValueError(f"end_frame {end_frame} is less than or equal to start_frame {start_frame}")

video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
video[..., i] = imaging_extractor.get_video(start_frame, end_frame)
return video

def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
"""Get specific video frames from indices (not necessarily continuous).

Parameters
----------
frame_idxs: array-like
Indices of frames to return.

Returns
-------
frames: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]
for frame_idx in frame_idxs:
if frame_idx < -1 * self.get_num_frames() or frame_idx >= self.get_num_frames():
raise ValueError(f"frame_idx {frame_idx} is out of bounds")

# Note np.all([]) returns True so not all(np.diff(frame_idxs) == 1) returns False if frame_idxs is a single int
if not all(np.diff(frame_idxs) == 1):
frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
frames[..., i] = imaging_extractor.get_frames(frame_idxs)
return frames
else:
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)

def get_image_size(self) -> Tuple:
"""Get the size of a single frame.

Returns
-------
image_size: tuple
The size of a single frame (num_rows, num_columns, num_planes).
"""
image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes())
return image_size

def get_num_planes(self) -> int:
"""Get the number of depth planes.

Returns
-------
_num_planes: int
The number of depth planes.
"""
return self._num_planes

def get_num_frames(self) -> int:
return self._imaging_extractors[0].get_num_frames()

def get_sampling_frequency(self) -> float:
return self._imaging_extractors[0].get_sampling_frequency()

def get_channel_names(self) -> list:
return self._imaging_extractors[0].get_channel_names()

def get_num_channels(self) -> int:
return self._imaging_extractors[0].get_num_channels()

def get_dtype(self) -> DtypeType:
return self._imaging_extractors[0].get_dtype()
121 changes: 121 additions & 0 deletions tests/test_volumetricimagingextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
import numpy as np
from roiextractors.testing import generate_dummy_imaging_extractor
from roiextractors import VolumetricImagingExtractor

num_frames = 10


@pytest.fixture(scope="module", params=[1, 2])
def imaging_extractors(request):
num_channels = request.param
return [generate_dummy_imaging_extractor(num_channels=num_channels, num_frames=num_frames) for _ in range(3)]


@pytest.fixture(scope="module")
def volumetric_imaging_extractor(imaging_extractors):
return VolumetricImagingExtractor(imaging_extractors)


@pytest.mark.parametrize(
"params",
[
[dict(sampling_frequency=1), dict(sampling_frequency=2)],
[dict(num_rows=1), dict(num_rows=2)],
[dict(num_channels=1), dict(num_channels=2)],
[dict(channel_names=["a"], num_channels=1), dict(channel_names=["b"], num_channels=1)],
[dict(dtype=np.int16), dict(dtype=np.float32)],
[dict(num_frames=1), dict(num_frames=2)],
],
)
def test_check_consistency_between_imaging_extractors(params):
imaging_extractors = [generate_dummy_imaging_extractor(**param) for param in params]
with pytest.raises(AssertionError):
VolumetricImagingExtractor(imaging_extractors=imaging_extractors)


@pytest.mark.parametrize("start_frame, end_frame", [(None, None), (0, num_frames), (3, 7), (-2, -1)])
def test_get_video(volumetric_imaging_extractor, start_frame, end_frame):
video = volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame)
expected_video = []
for extractor in volumetric_imaging_extractor._imaging_extractors:
expected_video.append(extractor.get_video(start_frame=start_frame, end_frame=end_frame))
expected_video = np.array(expected_video)
expected_video = np.moveaxis(expected_video, 0, -1)
assert np.all(video == expected_video)


@pytest.mark.parametrize("start_frame, end_frame", [(num_frames + 1, None), (None, num_frames + 1), (2, 1)])
def test_get_video_invalid(volumetric_imaging_extractor, start_frame, end_frame):
with pytest.raises(ValueError):
volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame)


@pytest.mark.parametrize("frame_idxs", [0, [0, 1, 2], [0, num_frames - 1], [-3, -1]])
def test_get_frames(volumetric_imaging_extractor, frame_idxs):
frames = volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs)
expected_frames = []
for extractor in volumetric_imaging_extractor._imaging_extractors:
expected_frames.append(extractor.get_frames(frame_idxs=frame_idxs))
expected_frames = np.array(expected_frames)
expected_frames = np.moveaxis(expected_frames, 0, -1)
assert np.all(frames == expected_frames)


@pytest.mark.parametrize("frame_idxs", [num_frames, [0, num_frames], [-num_frames - 1, -1]])
def test_get_frames_invalid(volumetric_imaging_extractor, frame_idxs):
with pytest.raises(ValueError):
volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs)


@pytest.mark.parametrize("num_rows, num_columns, num_planes", [(1, 2, 3), (2, 1, 3), (3, 2, 1)])
def test_get_image_size(num_rows, num_columns, num_planes):
imaging_extractors = [
generate_dummy_imaging_extractor(num_rows=num_rows, num_columns=num_columns) for _ in range(num_planes)
]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_image_size() == (num_rows, num_columns, num_planes)


@pytest.mark.parametrize("num_planes", [1, 2, 3])
def test_get_num_planes(num_planes):
imaging_extractors = [generate_dummy_imaging_extractor() for _ in range(num_planes)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_num_planes() == num_planes


@pytest.mark.parametrize("num_frames", [1, 2, 3])
def test_get_num_frames(num_frames):
imaging_extractors = [generate_dummy_imaging_extractor(num_frames=num_frames)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_num_frames() == num_frames


@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_get_sampling_frequency(sampling_frequency):
imaging_extractors = [generate_dummy_imaging_extractor(sampling_frequency=sampling_frequency)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_sampling_frequency() == sampling_frequency


@pytest.mark.parametrize("channel_names", [["Channel 1"], [" Channel 1 ", "Channel 2"]])
def test_get_channel_names(channel_names):
imaging_extractors = [
generate_dummy_imaging_extractor(channel_names=channel_names, num_channels=len(channel_names))
]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_channel_names() == channel_names


@pytest.mark.parametrize("num_channels", [1, 2, 3])
def test_get_num_channels(num_channels):
imaging_extractors = [generate_dummy_imaging_extractor(num_channels=num_channels)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_num_channels() == num_channels


@pytest.mark.parametrize("dtype", [np.float64, np.int16, np.uint8])
def test_get_dtype(dtype):
imaging_extractors = [generate_dummy_imaging_extractor(dtype=dtype)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_dtype() == dtype
Loading