Skip to content

Commit

Permalink
switched from local MultiPlaneIE to rebased VolumetricIE
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Sep 29, 2023
1 parent 53acb1d commit aac88dd
Showing 1 changed file with 2 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ...extraction_tools import PathType, FloatType, ArrayType, DtypeType, get_package
from ...imagingextractor import ImagingExtractor
from ...volumetricimagingextractor import VolumetricImagingExtractor
from .scanimagetiff_utils import (
extract_extra_metadata,
parse_metadata,
Expand All @@ -20,150 +21,7 @@
)


class MultiPlaneImagingExtractor(ImagingExtractor):
"""Class to combine multiple ImagingExtractor objects by depth plane."""

extractor_name = "MultiPlaneImaging"
installed = True
installatiuon_mesage = ""

def __init__(self, imaging_extractors: List[ImagingExtractor]):
"""Initialize a MultiPlaneImagingExtractor object from a list of ImagingExtractors.
Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
"""
super().__init__()
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors)
self._check_consistency_between_imaging_extractors(imaging_extractors)
self._imaging_extractors = imaging_extractors
self._num_planes = len(imaging_extractors)

# TODO: Add consistency check for channel_names when API is standardized
def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]):
"""Check that essential properties are consistent between extractors so that they can be combined appropriately.
Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
Raises
------
AssertionError
If any of the properties are not consistent between extractors.
Notes
-----
This method checks the following properties:
- sampling frequency
- image size
- number of channels
- channel names
- data type
- num_frames
"""
properties_to_check = dict(
get_sampling_frequency="The sampling frequency",
get_image_size="The size of a frame",
get_num_channels="The number of channels",
get_dtype="The data type",
get_num_frames="The number of frames",
)
for method, property_message in properties_to_check.items():
values = [getattr(extractor, method)() for extractor in imaging_extractors]
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values)
assert (
len(unique_values) == 1
), f"{property_message} is not consistent over the files (found {unique_values})."

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
"""Get the video frames.
Parameters
----------
start_frame: int, optional
Start frame index (inclusive).
end_frame: int, optional
End frame index (exclusive).
Returns
-------
video: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else self.get_num_frames()

video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
video[..., i] = imaging_extractor.get_video(start_frame, end_frame)
return video

def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
"""Get specific video frames from indices (not necessarily continuous).
Parameters
----------
frame_idxs: array-like
Indices of frames to return.
Returns
-------
frames: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]

if not all(np.diff(frame_idxs) == 1):
frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
frames[..., i] = imaging_extractor.get_frames(frame_idxs)
else:
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)

def get_image_size(self) -> Tuple:
"""Get the size of a single frame.
Returns
-------
image_size: tuple
The size of a single frame (num_rows, num_columns, num_planes).
"""
image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes())
return image_size

def get_num_planes(self) -> int:
"""Get the number of depth planes.
Returns
-------
_num_planes: int
The number of depth planes.
"""
return self._num_planes

def get_num_frames(self) -> int:
return self._imaging_extractors[0].get_num_frames()

def get_sampling_frequency(self) -> float:
return self._imaging_extractors[0].get_sampling_frequency()

def get_channel_names(self) -> list:
return self._imaging_extractors[0].get_channel_names()

def get_num_channels(self) -> int:
return self._imaging_extractors[0].get_num_channels()

def get_dtype(self) -> DtypeType:
return self._imaging_extractors[0].get_dtype()


class ScanImageTiffMultiPlaneImagingExtractor(MultiPlaneImagingExtractor):
class ScanImageTiffMultiPlaneImagingExtractor(VolumetricImagingExtractor):
"""Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage."""

extractor_name = "ScanImageTiffMultiPlaneImaging"
Expand Down

0 comments on commit aac88dd

Please sign in to comment.