Skip to content

Commit

Permalink
added multi-plane extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Sep 18, 2023
1 parent 243c168 commit 8e407e8
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Specialized extractor for reading TIFF files produced via ScanImage.
"""
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, List
from warnings import warn
import numpy as np
from pprint import pprint
Expand Down Expand Up @@ -62,6 +62,137 @@ def parse_metadata(metadata):
return metadata_parsed


class ScanImageTiffMultiPlaneImagingExtractor(ImagingExtractor):
"""Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage."""

extractor_name = "ScanImageTiffMultiPlaneImaging"
is_writable = True
mode = "file"

def __init__(
self,
file_path: PathType,
sampling_frequency: float,
channel: Optional[int] = 0,
num_channels: Optional[int] = 1,
num_planes: Optional[int] = 1,
frames_per_slice: Optional[int] = 1,
channel_names: Optional[list] = None,
) -> None:
super().__init__()
self.file_path = Path(file_path)
self._sampling_frequency = sampling_frequency
self.metadata = extract_extra_metadata(file_path)
self.channel = channel
self._num_channels = num_channels
self._num_planes = num_planes
self._channel_names = channel_names
if channel >= num_channels:
raise ValueError(f"Channel index ({channel}) exceeds number of channels ({num_channels}).")
if frames_per_slice != 1:
raise NotImplementedError(
"Extractor cannot handle multiple frames per slice. Please raise an issue to request this feature: "
"https://github.com/catalystneuro/roiextractors/issues "
)
imaging_extractors = []
for plane in range(num_planes):
imaging_extractor = ScanImageTiffImagingExtractor(
file_path=file_path,
sampling_frequency=sampling_frequency,
channel=channel,
num_channels=num_channels,
plane=plane,
num_planes=num_planes,
channel_names=channel_names,
)
imaging_extractors.append(imaging_extractor)
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors)
# self._check_consistency_between_imaging_extractors(imaging_extractors)
self._num_planes = len(imaging_extractors)
assert all(
imaging_extractor.get_num_planes() == self._num_planes for imaging_extractor in imaging_extractors
), "All imaging extractors must have the same number of planes."
self._imaging_extractors = imaging_extractors

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
"""Get the video frames.
Parameters
----------
start_frame: int, optional
Start frame index (inclusive).
end_frame: int, optional
End frame index (exclusive).
Returns
-------
video: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else self.get_num_frames()

video = np.zeros((end_frame - start_frame, *self.get_image_size(), self.get_num_planes()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
video[..., i] = imaging_extractor.get_video(start_frame, end_frame)
return video

def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
"""Get specific video frames from indices (not necessarily continuous).
Parameters
----------
frame_idxs: array-like
Indices of frames to return.
Returns
-------
frames: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]

if not all(np.diff(frame_idxs) == 1):
return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs])
else:
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)

def _get_single_frame(self, frame: int) -> np.ndarray:
"""Get a single frame of data from the TIFF file.
Parameters
----------
frame : int
The index of the frame to retrieve.
Returns
-------
frame: numpy.ndarray
The 3D frame of data (num_rows, num_columns, num_planes).
"""
return self.get_video(start_frame=frame, end_frame=frame + 1)[0]

def get_image_size(self) -> Tuple:
return self._imaging_extractors[0].get_image_size()

def get_num_planes(self) -> int:
return self._num_planes

def get_num_frames(self) -> int:
return self._num_frames

def get_sampling_frequency(self) -> float:
return self._imaging_extractors[0].get_sampling_frequency()

def get_channel_names(self) -> list:
return self._imaging_extractors[0].get_channel_names()

def get_num_channels(self) -> int:
return self._imaging_extractors[0].get_num_channels()


class ScanImageTiffImagingExtractor(ImagingExtractor):
"""Specialized extractor for reading TIFF files produced via ScanImage."""

Expand Down Expand Up @@ -89,8 +220,7 @@ def __init__(
[channel_1_plane_1_frame_1, channel_2_plane_1_frame_1, channel_1_plane_2_frame_1, channel_2_plane_2_frame_1,
channel_1_plane_1_frame_2, channel_2_plane_1_frame_2, channel_1_plane_2_frame_2, channel_2_plane_2_frame_2, ...
channel_1_plane_1_frame_N, channel_2_plane_1_frame_N, channel_1_plane_2_frame_N, channel_2_plane_2_frame_N]
This file structure is sliced lazily using ScanImageTiffReader with the appropriate logic for specified
channels/frames.
This file structured is accessed by ScanImageTiffImagingExtractor for a single channel and plane.
Parameters
----------
Expand Down
3 changes: 2 additions & 1 deletion tests/temp_test_scanimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from roiextractors.extractors.tiffimagingextractors.scanimagetiffimagingextractor import (
extract_extra_metadata,
parse_metadata,
ScanImageTiffMultiPlaneImagingExtractor,
)


Expand All @@ -18,7 +19,7 @@ def main():

metadata = extract_extra_metadata(example_holo)
metadata_parsed = parse_metadata(metadata)
extractor = ScanImageTiffImagingExtractor(file_path=example_holo, **metadata_parsed)
extractor = ScanImageTiffMultiPlaneImagingExtractor(file_path=example_holo, **metadata_parsed)
print("Example holographic file loads!")

metadata = extract_extra_metadata(example_single)
Expand Down

0 comments on commit 8e407e8

Please sign in to comment.