Skip to content

Commit

Permalink
refactored out generic MultiPlaneImagingExtractor into its own class
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Sep 19, 2023
1 parent 8e407e8 commit b04baab
Showing 1 changed file with 116 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
Specialized extractor for reading TIFF files produced via ScanImage.
"""
from pathlib import Path
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Iterable
from warnings import warn
import numpy as np
from pprint import pprint

from ...extraction_tools import PathType, FloatType, ArrayType, get_package
from roiextractors.extraction_tools import DtypeType

from ...extraction_tools import PathType, FloatType, ArrayType, DtypeType, get_package
from ...imagingextractor import ImagingExtractor


Expand Down Expand Up @@ -62,58 +64,63 @@ def parse_metadata(metadata):
return metadata_parsed


class ScanImageTiffMultiPlaneImagingExtractor(ImagingExtractor):
"""Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage."""
class MultiPlaneImagingExtractor(ImagingExtractor):
"""Class to combine multiple ImagingExtractor objects by depth plane."""

extractor_name = "ScanImageTiffMultiPlaneImaging"
is_writable = True
mode = "file"
extractor_name = "MultiPlaneImaging"
installed = True
installatiuon_mesage = ""

def __init__(
self,
file_path: PathType,
sampling_frequency: float,
channel: Optional[int] = 0,
num_channels: Optional[int] = 1,
num_planes: Optional[int] = 1,
frames_per_slice: Optional[int] = 1,
channel_names: Optional[list] = None,
) -> None:
def __init__(self, imaging_extractors: List[ImagingExtractor]):
"""Initialize a MultiPlaneImagingExtractor object from a list of ImagingExtractors.
Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
"""
super().__init__()
self.file_path = Path(file_path)
self._sampling_frequency = sampling_frequency
self.metadata = extract_extra_metadata(file_path)
self.channel = channel
self._num_channels = num_channels
self._num_planes = num_planes
self._channel_names = channel_names
if channel >= num_channels:
raise ValueError(f"Channel index ({channel}) exceeds number of channels ({num_channels}).")
if frames_per_slice != 1:
raise NotImplementedError(
"Extractor cannot handle multiple frames per slice. Please raise an issue to request this feature: "
"https://github.com/catalystneuro/roiextractors/issues "
)
imaging_extractors = []
for plane in range(num_planes):
imaging_extractor = ScanImageTiffImagingExtractor(
file_path=file_path,
sampling_frequency=sampling_frequency,
channel=channel,
num_channels=num_channels,
plane=plane,
num_planes=num_planes,
channel_names=channel_names,
)
imaging_extractors.append(imaging_extractor)
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors)
# self._check_consistency_between_imaging_extractors(imaging_extractors)
self._num_planes = len(imaging_extractors)
assert all(
imaging_extractor.get_num_planes() == self._num_planes for imaging_extractor in imaging_extractors
), "All imaging extractors must have the same number of planes."
self._check_consistency_between_imaging_extractors(imaging_extractors)
self._imaging_extractors = imaging_extractors
self._num_planes = len(imaging_extractors)

def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]):
"""Check that essential properties are consistent between extractors so that they can be combined appropriately.
Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
Raises
------
AssertionError
If any of the properties are not consistent between extractors.
Notes
-----
This method checks the following properties:
- sampling frequency
- image size
- number of channels
- channel names
- data type
"""
properties_to_check = dict(
get_sampling_frequency="The sampling frequency",
get_image_size="The size of a frame",
get_num_channels="The number of channels",
get_channel_names="The name of the channels",
get_dtype="The data type.",
)
for method, property_message in properties_to_check.items():
values = [getattr(extractor, method)() for extractor in imaging_extractors]
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values)
assert (
len(unique_values) == 1
), f"{property_message} is not consistent over the files (found {unique_values})."

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
"""Get the video frames.
Expand Down Expand Up @@ -155,29 +162,23 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
frame_idxs = [frame_idxs]

if not all(np.diff(frame_idxs) == 1):
return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs])
frames = np.zeros((len(frame_idxs), *self.get_image_size(), self.get_num_planes()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
frames[..., i] = imaging_extractor.get_frames(frame_idxs)
else:
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)

def _get_single_frame(self, frame: int) -> np.ndarray:
"""Get a single frame of data from the TIFF file.
def get_image_size(self) -> Tuple:
return self._imaging_extractors[0].get_image_size()

Parameters
----------
frame : int
The index of the frame to retrieve.
def get_num_planes(self) -> int:
"""Get the number of depth planes.
Returns
-------
frame: numpy.ndarray
The 3D frame of data (num_rows, num_columns, num_planes).
_num_planes: int
The number of depth planes.
"""
return self.get_video(start_frame=frame, end_frame=frame + 1)[0]

def get_image_size(self) -> Tuple:
return self._imaging_extractors[0].get_image_size()

def get_num_planes(self) -> int:
return self._num_planes

def get_num_frames(self) -> int:
Expand All @@ -192,6 +193,54 @@ def get_channel_names(self) -> list:
def get_num_channels(self) -> int:
return self._imaging_extractors[0].get_num_channels()

def get_dtype(self) -> DtypeType:
return self._imaging_extractors[0].get_dtype()


class ScanImageTiffMultiPlaneImagingExtractor(MultiPlaneImagingExtractor):
"""Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage."""

extractor_name = "ScanImageTiffMultiPlaneImaging"
is_writable = True
mode = "file"

def __init__(
self,
file_path: PathType,
sampling_frequency: float,
channel: Optional[int] = 0,
num_channels: Optional[int] = 1,
num_planes: Optional[int] = 1,
frames_per_slice: Optional[int] = 1,
channel_names: Optional[list] = None,
) -> None:
self.file_path = Path(file_path)
self.metadata = extract_extra_metadata(file_path)
self.channel = channel
if channel >= num_channels:
raise ValueError(f"Channel index ({channel}) exceeds number of channels ({num_channels}).")
if frames_per_slice != 1:
raise NotImplementedError(
"Extractor cannot handle multiple frames per slice. Please raise an issue to request this feature: "
"https://github.com/catalystneuro/roiextractors/issues "
)
imaging_extractors = []
for plane in range(num_planes):
imaging_extractor = ScanImageTiffImagingExtractor(
file_path=file_path,
sampling_frequency=sampling_frequency,
channel=channel,
num_channels=num_channels,
plane=plane,
num_planes=num_planes,
channel_names=channel_names,
)
imaging_extractors.append(imaging_extractor)
super().__init__(imaging_extractors=imaging_extractors)
assert all(
imaging_extractor.get_num_planes() == self._num_planes for imaging_extractor in imaging_extractors
), "All imaging extractors must have the same number of planes."


class ScanImageTiffImagingExtractor(ImagingExtractor):
"""Specialized extractor for reading TIFF files produced via ScanImage."""
Expand Down Expand Up @@ -292,9 +341,9 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
frames: numpy.ndarray
The video frames.
"""
self.check_frame_inputs(frame_idxs[-1])
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]
self.check_frame_inputs(frame_idxs[-1])

if not all(np.diff(frame_idxs) == 1):
return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs])
Expand Down Expand Up @@ -370,6 +419,9 @@ def get_channel_names(self) -> list:
def get_num_planes(self) -> int:
return self._num_planes

def get_dtype(self) -> DtypeType:
return self.get_frames(0).dtype

def check_frame_inputs(self, frame) -> None:
if frame >= self._num_frames:
raise ValueError(f"Frame index ({frame}) exceeds number of frames ({self._num_frames}).")
Expand Down

0 comments on commit b04baab

Please sign in to comment.