diff --git a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py index bf0014f2..2a7f6a1b 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py @@ -6,12 +6,14 @@ Specialized extractor for reading TIFF files produced via ScanImage. """ from pathlib import Path -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Iterable from warnings import warn import numpy as np from pprint import pprint -from ...extraction_tools import PathType, FloatType, ArrayType, get_package +from roiextractors.extraction_tools import DtypeType + +from ...extraction_tools import PathType, FloatType, ArrayType, DtypeType, get_package from ...imagingextractor import ImagingExtractor @@ -62,58 +64,63 @@ def parse_metadata(metadata): return metadata_parsed -class ScanImageTiffMultiPlaneImagingExtractor(ImagingExtractor): - """Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage.""" +class MultiPlaneImagingExtractor(ImagingExtractor): + """Class to combine multiple ImagingExtractor objects by depth plane.""" - extractor_name = "ScanImageTiffMultiPlaneImaging" - is_writable = True - mode = "file" + extractor_name = "MultiPlaneImaging" + installed = True + installatiuon_mesage = "" - def __init__( - self, - file_path: PathType, - sampling_frequency: float, - channel: Optional[int] = 0, - num_channels: Optional[int] = 1, - num_planes: Optional[int] = 1, - frames_per_slice: Optional[int] = 1, - channel_names: Optional[list] = None, - ) -> None: + def __init__(self, imaging_extractors: List[ImagingExtractor]): + """Initialize a MultiPlaneImagingExtractor object from a list of ImagingExtractors. + + Parameters + ---------- + imaging_extractors: list of ImagingExtractor + list of imaging extractor objects + """ super().__init__() - self.file_path = Path(file_path) - self._sampling_frequency = sampling_frequency - self.metadata = extract_extra_metadata(file_path) - self.channel = channel - self._num_channels = num_channels - self._num_planes = num_planes - self._channel_names = channel_names - if channel >= num_channels: - raise ValueError(f"Channel index ({channel}) exceeds number of channels ({num_channels}).") - if frames_per_slice != 1: - raise NotImplementedError( - "Extractor cannot handle multiple frames per slice. Please raise an issue to request this feature: " - "https://github.com/catalystneuro/roiextractors/issues " - ) - imaging_extractors = [] - for plane in range(num_planes): - imaging_extractor = ScanImageTiffImagingExtractor( - file_path=file_path, - sampling_frequency=sampling_frequency, - channel=channel, - num_channels=num_channels, - plane=plane, - num_planes=num_planes, - channel_names=channel_names, - ) - imaging_extractors.append(imaging_extractor) assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument" assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors) - # self._check_consistency_between_imaging_extractors(imaging_extractors) - self._num_planes = len(imaging_extractors) - assert all( - imaging_extractor.get_num_planes() == self._num_planes for imaging_extractor in imaging_extractors - ), "All imaging extractors must have the same number of planes." + self._check_consistency_between_imaging_extractors(imaging_extractors) self._imaging_extractors = imaging_extractors + self._num_planes = len(imaging_extractors) + + def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]): + """Check that essential properties are consistent between extractors so that they can be combined appropriately. + + Parameters + ---------- + imaging_extractors: list of ImagingExtractor + list of imaging extractor objects + + Raises + ------ + AssertionError + If any of the properties are not consistent between extractors. + + Notes + ----- + This method checks the following properties: + - sampling frequency + - image size + - number of channels + - channel names + - data type + """ + properties_to_check = dict( + get_sampling_frequency="The sampling frequency", + get_image_size="The size of a frame", + get_num_channels="The number of channels", + get_channel_names="The name of the channels", + get_dtype="The data type.", + ) + for method, property_message in properties_to_check.items(): + values = [getattr(extractor, method)() for extractor in imaging_extractors] + unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values) + assert ( + len(unique_values) == 1 + ), f"{property_message} is not consistent over the files (found {unique_values})." def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: """Get the video frames. @@ -155,29 +162,23 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: frame_idxs = [frame_idxs] if not all(np.diff(frame_idxs) == 1): - return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs]) + frames = np.zeros((len(frame_idxs), *self.get_image_size(), self.get_num_planes()), self.get_dtype()) + for i, imaging_extractor in enumerate(self._imaging_extractors): + frames[..., i] = imaging_extractor.get_frames(frame_idxs) else: return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) - def _get_single_frame(self, frame: int) -> np.ndarray: - """Get a single frame of data from the TIFF file. + def get_image_size(self) -> Tuple: + return self._imaging_extractors[0].get_image_size() - Parameters - ---------- - frame : int - The index of the frame to retrieve. + def get_num_planes(self) -> int: + """Get the number of depth planes. Returns ------- - frame: numpy.ndarray - The 3D frame of data (num_rows, num_columns, num_planes). + _num_planes: int + The number of depth planes. """ - return self.get_video(start_frame=frame, end_frame=frame + 1)[0] - - def get_image_size(self) -> Tuple: - return self._imaging_extractors[0].get_image_size() - - def get_num_planes(self) -> int: return self._num_planes def get_num_frames(self) -> int: @@ -192,6 +193,54 @@ def get_channel_names(self) -> list: def get_num_channels(self) -> int: return self._imaging_extractors[0].get_num_channels() + def get_dtype(self) -> DtypeType: + return self._imaging_extractors[0].get_dtype() + + +class ScanImageTiffMultiPlaneImagingExtractor(MultiPlaneImagingExtractor): + """Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage.""" + + extractor_name = "ScanImageTiffMultiPlaneImaging" + is_writable = True + mode = "file" + + def __init__( + self, + file_path: PathType, + sampling_frequency: float, + channel: Optional[int] = 0, + num_channels: Optional[int] = 1, + num_planes: Optional[int] = 1, + frames_per_slice: Optional[int] = 1, + channel_names: Optional[list] = None, + ) -> None: + self.file_path = Path(file_path) + self.metadata = extract_extra_metadata(file_path) + self.channel = channel + if channel >= num_channels: + raise ValueError(f"Channel index ({channel}) exceeds number of channels ({num_channels}).") + if frames_per_slice != 1: + raise NotImplementedError( + "Extractor cannot handle multiple frames per slice. Please raise an issue to request this feature: " + "https://github.com/catalystneuro/roiextractors/issues " + ) + imaging_extractors = [] + for plane in range(num_planes): + imaging_extractor = ScanImageTiffImagingExtractor( + file_path=file_path, + sampling_frequency=sampling_frequency, + channel=channel, + num_channels=num_channels, + plane=plane, + num_planes=num_planes, + channel_names=channel_names, + ) + imaging_extractors.append(imaging_extractor) + super().__init__(imaging_extractors=imaging_extractors) + assert all( + imaging_extractor.get_num_planes() == self._num_planes for imaging_extractor in imaging_extractors + ), "All imaging extractors must have the same number of planes." + class ScanImageTiffImagingExtractor(ImagingExtractor): """Specialized extractor for reading TIFF files produced via ScanImage.""" @@ -292,9 +341,9 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: frames: numpy.ndarray The video frames. """ - self.check_frame_inputs(frame_idxs[-1]) if isinstance(frame_idxs, int): frame_idxs = [frame_idxs] + self.check_frame_inputs(frame_idxs[-1]) if not all(np.diff(frame_idxs) == 1): return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs]) @@ -370,6 +419,9 @@ def get_channel_names(self) -> list: def get_num_planes(self) -> int: return self._num_planes + def get_dtype(self) -> DtypeType: + return self.get_frames(0).dtype + def check_frame_inputs(self, frame) -> None: if frame >= self._num_frames: raise ValueError(f"Frame index ({frame}) exceeds number of frames ({self._num_frames}).")