From 8e407e8fb573e570ecb92f2d6329e308f3d2409c Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Mon, 18 Sep 2023 15:50:45 -0700 Subject: [PATCH] added multi-plane extractor --- .../scanimagetiffimagingextractor.py | 136 +++++++++++++++++- tests/temp_test_scanimage.py | 3 +- 2 files changed, 135 insertions(+), 4 deletions(-) diff --git a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py index 134c9ba0..bf0014f2 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py @@ -6,7 +6,7 @@ Specialized extractor for reading TIFF files produced via ScanImage. """ from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, List from warnings import warn import numpy as np from pprint import pprint @@ -62,6 +62,137 @@ def parse_metadata(metadata): return metadata_parsed +class ScanImageTiffMultiPlaneImagingExtractor(ImagingExtractor): + """Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage.""" + + extractor_name = "ScanImageTiffMultiPlaneImaging" + is_writable = True + mode = "file" + + def __init__( + self, + file_path: PathType, + sampling_frequency: float, + channel: Optional[int] = 0, + num_channels: Optional[int] = 1, + num_planes: Optional[int] = 1, + frames_per_slice: Optional[int] = 1, + channel_names: Optional[list] = None, + ) -> None: + super().__init__() + self.file_path = Path(file_path) + self._sampling_frequency = sampling_frequency + self.metadata = extract_extra_metadata(file_path) + self.channel = channel + self._num_channels = num_channels + self._num_planes = num_planes + self._channel_names = channel_names + if channel >= num_channels: + raise ValueError(f"Channel index ({channel}) exceeds number of channels ({num_channels}).") + if frames_per_slice != 1: + raise NotImplementedError( + "Extractor cannot handle multiple frames per slice. Please raise an issue to request this feature: " + "https://github.com/catalystneuro/roiextractors/issues " + ) + imaging_extractors = [] + for plane in range(num_planes): + imaging_extractor = ScanImageTiffImagingExtractor( + file_path=file_path, + sampling_frequency=sampling_frequency, + channel=channel, + num_channels=num_channels, + plane=plane, + num_planes=num_planes, + channel_names=channel_names, + ) + imaging_extractors.append(imaging_extractor) + assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument" + assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors) + # self._check_consistency_between_imaging_extractors(imaging_extractors) + self._num_planes = len(imaging_extractors) + assert all( + imaging_extractor.get_num_planes() == self._num_planes for imaging_extractor in imaging_extractors + ), "All imaging extractors must have the same number of planes." + self._imaging_extractors = imaging_extractors + + def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: + """Get the video frames. + + Parameters + ---------- + start_frame: int, optional + Start frame index (inclusive). + end_frame: int, optional + End frame index (exclusive). + + Returns + ------- + video: numpy.ndarray + The 3D video frames (num_rows, num_columns, num_planes). + """ + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else self.get_num_frames() + + video = np.zeros((end_frame - start_frame, *self.get_image_size(), self.get_num_planes()), self.get_dtype()) + for i, imaging_extractor in enumerate(self._imaging_extractors): + video[..., i] = imaging_extractor.get_video(start_frame, end_frame) + return video + + def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: + """Get specific video frames from indices (not necessarily continuous). + + Parameters + ---------- + frame_idxs: array-like + Indices of frames to return. + + Returns + ------- + frames: numpy.ndarray + The 3D video frames (num_rows, num_columns, num_planes). + """ + if isinstance(frame_idxs, int): + frame_idxs = [frame_idxs] + + if not all(np.diff(frame_idxs) == 1): + return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs]) + else: + return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) + + def _get_single_frame(self, frame: int) -> np.ndarray: + """Get a single frame of data from the TIFF file. + + Parameters + ---------- + frame : int + The index of the frame to retrieve. + + Returns + ------- + frame: numpy.ndarray + The 3D frame of data (num_rows, num_columns, num_planes). + """ + return self.get_video(start_frame=frame, end_frame=frame + 1)[0] + + def get_image_size(self) -> Tuple: + return self._imaging_extractors[0].get_image_size() + + def get_num_planes(self) -> int: + return self._num_planes + + def get_num_frames(self) -> int: + return self._num_frames + + def get_sampling_frequency(self) -> float: + return self._imaging_extractors[0].get_sampling_frequency() + + def get_channel_names(self) -> list: + return self._imaging_extractors[0].get_channel_names() + + def get_num_channels(self) -> int: + return self._imaging_extractors[0].get_num_channels() + + class ScanImageTiffImagingExtractor(ImagingExtractor): """Specialized extractor for reading TIFF files produced via ScanImage.""" @@ -89,8 +220,7 @@ def __init__( [channel_1_plane_1_frame_1, channel_2_plane_1_frame_1, channel_1_plane_2_frame_1, channel_2_plane_2_frame_1, channel_1_plane_1_frame_2, channel_2_plane_1_frame_2, channel_1_plane_2_frame_2, channel_2_plane_2_frame_2, ... channel_1_plane_1_frame_N, channel_2_plane_1_frame_N, channel_1_plane_2_frame_N, channel_2_plane_2_frame_N] - This file structure is sliced lazily using ScanImageTiffReader with the appropriate logic for specified - channels/frames. + This file structured is accessed by ScanImageTiffImagingExtractor for a single channel and plane. Parameters ---------- diff --git a/tests/temp_test_scanimage.py b/tests/temp_test_scanimage.py index 81e2e4f3..1bc9580c 100644 --- a/tests/temp_test_scanimage.py +++ b/tests/temp_test_scanimage.py @@ -3,6 +3,7 @@ from roiextractors.extractors.tiffimagingextractors.scanimagetiffimagingextractor import ( extract_extra_metadata, parse_metadata, + ScanImageTiffMultiPlaneImagingExtractor, ) @@ -18,7 +19,7 @@ def main(): metadata = extract_extra_metadata(example_holo) metadata_parsed = parse_metadata(metadata) - extractor = ScanImageTiffImagingExtractor(file_path=example_holo, **metadata_parsed) + extractor = ScanImageTiffMultiPlaneImagingExtractor(file_path=example_holo, **metadata_parsed) print("Example holographic file loads!") metadata = extract_extra_metadata(example_single)