diff --git a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py index 0ad12b43..db0e3209 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py @@ -12,6 +12,7 @@ from ...extraction_tools import PathType, FloatType, ArrayType, DtypeType, get_package from ...imagingextractor import ImagingExtractor +from ...volumetricimagingextractor import VolumetricImagingExtractor from .scanimagetiff_utils import ( extract_extra_metadata, parse_metadata, @@ -20,150 +21,7 @@ ) -class MultiPlaneImagingExtractor(ImagingExtractor): - """Class to combine multiple ImagingExtractor objects by depth plane.""" - - extractor_name = "MultiPlaneImaging" - installed = True - installatiuon_mesage = "" - - def __init__(self, imaging_extractors: List[ImagingExtractor]): - """Initialize a MultiPlaneImagingExtractor object from a list of ImagingExtractors. - - Parameters - ---------- - imaging_extractors: list of ImagingExtractor - list of imaging extractor objects - """ - super().__init__() - assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument" - assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors) - self._check_consistency_between_imaging_extractors(imaging_extractors) - self._imaging_extractors = imaging_extractors - self._num_planes = len(imaging_extractors) - - # TODO: Add consistency check for channel_names when API is standardized - def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]): - """Check that essential properties are consistent between extractors so that they can be combined appropriately. - - Parameters - ---------- - imaging_extractors: list of ImagingExtractor - list of imaging extractor objects - - Raises - ------ - AssertionError - If any of the properties are not consistent between extractors. - - Notes - ----- - This method checks the following properties: - - sampling frequency - - image size - - number of channels - - channel names - - data type - - num_frames - """ - properties_to_check = dict( - get_sampling_frequency="The sampling frequency", - get_image_size="The size of a frame", - get_num_channels="The number of channels", - get_dtype="The data type", - get_num_frames="The number of frames", - ) - for method, property_message in properties_to_check.items(): - values = [getattr(extractor, method)() for extractor in imaging_extractors] - unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values) - assert ( - len(unique_values) == 1 - ), f"{property_message} is not consistent over the files (found {unique_values})." - - def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: - """Get the video frames. - - Parameters - ---------- - start_frame: int, optional - Start frame index (inclusive). - end_frame: int, optional - End frame index (exclusive). - - Returns - ------- - video: numpy.ndarray - The 3D video frames (num_rows, num_columns, num_planes). - """ - start_frame = start_frame if start_frame is not None else 0 - end_frame = end_frame if end_frame is not None else self.get_num_frames() - - video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype()) - for i, imaging_extractor in enumerate(self._imaging_extractors): - video[..., i] = imaging_extractor.get_video(start_frame, end_frame) - return video - - def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: - """Get specific video frames from indices (not necessarily continuous). - - Parameters - ---------- - frame_idxs: array-like - Indices of frames to return. - - Returns - ------- - frames: numpy.ndarray - The 3D video frames (num_rows, num_columns, num_planes). - """ - if isinstance(frame_idxs, int): - frame_idxs = [frame_idxs] - - if not all(np.diff(frame_idxs) == 1): - frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype()) - for i, imaging_extractor in enumerate(self._imaging_extractors): - frames[..., i] = imaging_extractor.get_frames(frame_idxs) - else: - return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) - - def get_image_size(self) -> Tuple: - """Get the size of a single frame. - - Returns - ------- - image_size: tuple - The size of a single frame (num_rows, num_columns, num_planes). - """ - image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes()) - return image_size - - def get_num_planes(self) -> int: - """Get the number of depth planes. - - Returns - ------- - _num_planes: int - The number of depth planes. - """ - return self._num_planes - - def get_num_frames(self) -> int: - return self._imaging_extractors[0].get_num_frames() - - def get_sampling_frequency(self) -> float: - return self._imaging_extractors[0].get_sampling_frequency() - - def get_channel_names(self) -> list: - return self._imaging_extractors[0].get_channel_names() - - def get_num_channels(self) -> int: - return self._imaging_extractors[0].get_num_channels() - - def get_dtype(self) -> DtypeType: - return self._imaging_extractors[0].get_dtype() - - -class ScanImageTiffMultiPlaneImagingExtractor(MultiPlaneImagingExtractor): +class ScanImageTiffMultiPlaneImagingExtractor(VolumetricImagingExtractor): """Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage.""" extractor_name = "ScanImageTiffMultiPlaneImaging"