Skip to content

Commit

Permalink
refactored to take plane and channel at __init__ rather than in get_v…
Browse files Browse the repository at this point in the history
…ideo
  • Loading branch information
pauladkisson committed Sep 18, 2023
1 parent 4819be2 commit e84b6ad
Showing 1 changed file with 52 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def __init__(
self,
file_path: PathType,
sampling_frequency: float,
channel: Optional[int] = 0,
num_channels: Optional[int] = 1,
plane: Optional[int] = 0,
num_planes: Optional[int] = 1,
frames_per_slice: Optional[int] = 1,
channel_names: Optional[list] = None,
Expand All @@ -96,27 +98,36 @@ def __init__(
Path to the TIFF file.
sampling_frequency : float
Sampling frequency of each plane (scanVolumeRate) in Hz.
channel : int, optional
Index of the optical channel for this extractor (default=0).
num_channels : int, optional
Number of active channels that were acquired (default=1).
plane : int, optional
Index of the depth plane for this extractor (default=0).
num_planes : int, optional
Number of depth planes that were acquired (default=1).
frames_per_slice : int, optional
Number of frames per depth plane that were acquired (default=1).
channel_names : list, optional
Names of the channels that were acquired (default=None).
Names of the channels (default=None).
"""
super().__init__()
self.file_path = Path(file_path)
self._sampling_frequency = sampling_frequency
self.metadata = extract_extra_metadata(file_path)
self.channel = channel
self._num_channels = num_channels
self.plane = plane
self._num_planes = num_planes
if channel >= num_channels:
raise ValueError(f"Channel index ({channel}) exceeds number of channels ({num_channels}).")
if plane >= num_planes:
raise ValueError(f"Plane index ({plane}) exceeds number of planes ({num_planes}).")
if frames_per_slice != 1:
raise NotImplementedError(
"Extractor cannot handle multiple frames per slice. Please raise an issue to request this feature: "
"https://github.com/catalystneuro/roiextractors/issues "
)
self._channel_names = channel_names

valid_suffixes = [".tiff", ".tif", ".TIFF", ".TIF"]
if self.file_path.suffix not in valid_suffixes:
Expand All @@ -137,60 +148,50 @@ def __init__(
"https://github.com/catalystneuro/roiextractors/issues "
)

def get_frames(self, frame_idxs: ArrayType, channel: int = 0, plane: int = 0) -> np.ndarray:
def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
"""Get specific video frames from indices (not necessarily continuous).
Parameters
----------
frame_idxs: array-like
Indices of frames to return.
channel: int, optional
Channel index.
plane: int, optional
Plane index.
Returns
-------
frames: numpy.ndarray
The video frames.
"""
self.check_frame_inputs(frame_idxs[-1], channel, plane)
self.check_frame_inputs(frame_idxs[-1])
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]

if not all(np.diff(frame_idxs) == 1):
return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs])
else:
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1, channel=channel)
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)

# Data accessed through an open ScanImageTiffReader io gets scrambled if there are multiple calls.
# Thus, open fresh io in context each time something is needed.
def _get_single_frame(self, frame: int, channel: Optional[int] = 0, plane: Optional[int] = 0) -> np.ndarray:
def _get_single_frame(self, frame: int) -> np.ndarray:
"""Get a single frame of data from the TIFF file.
Parameters
----------
frame : int
The index of the frame to retrieve.
channel : int, optional
The index of the channel to retrieve.
plane : int, optional
The index of the plane to retrieve.
Returns
-------
frame: numpy.ndarray
The frame of data.
"""
self.check_frame_inputs(frame, channel, plane)
self.check_frame_inputs(frame)
ScanImageTiffReader = _get_scanimage_reader()
raw_index = (frame * self._num_planes * self._num_channels) + (plane * self._num_channels) + channel
raw_index = self.frame_to_raw_index(frame)
with ScanImageTiffReader(str(self.file_path)) as io:
return io.data(beg=raw_index, end=raw_index + 1)

def get_video(
self, start_frame=None, end_frame=None, channel: Optional[int] = 0, plane: Optional[int] = 0
) -> np.ndarray:
def get_video(self, start_frame=None, end_frame=None) -> np.ndarray:
"""Get the video frames.
Parameters
Expand All @@ -199,10 +200,6 @@ def get_video(
Start frame index (inclusive).
end_frame: int, optional
End frame index (exclusive).
channel: int, optional
Channel index.
plane: int, optional
Plane index.
Returns
-------
Expand All @@ -213,15 +210,15 @@ def get_video(
start_frame = 0
if end_frame is None:
end_frame = self._num_frames
self.check_frame_inputs(end_frame - 1, channel, plane)
self.check_frame_inputs(end_frame - 1)
ScanImageTiffReader = _get_scanimage_reader()
raw_start = (start_frame * self._num_planes * self._num_channels) + (plane * self._num_channels) + channel
raw_end = (end_frame * self._num_planes * self._num_channels) + (plane * self._num_channels) + channel
raw_start = self.frame_to_raw_index(start_frame)
raw_end = self.frame_to_raw_index(end_frame)
raw_end = np.min([raw_end, self._total_num_frames])
with ScanImageTiffReader(filename=str(self.file_path)) as io:
raw_video = io.data(beg=raw_start, end=raw_end)
video = raw_video[channel :: self._num_channels]
video = video[plane :: self._num_planes]
video = raw_video[self.channel :: self._num_channels]
video = video[self.plane :: self._num_planes]
return video

def get_image_size(self) -> Tuple[int, int]:
Expand All @@ -242,10 +239,32 @@ def get_channel_names(self) -> list:
def get_num_planes(self) -> int:
return self._num_planes

def check_frame_inputs(self, frame, channel, plane) -> None:
def check_frame_inputs(self, frame) -> None:
if frame >= self._num_frames:
raise ValueError(f"Frame index ({frame}) exceeds number of frames ({self._num_frames}).")
if channel >= self._num_channels:
raise ValueError(f"Channel index ({channel}) exceeds number of channels ({self._num_channels}).")
if plane >= self._num_planes:
raise ValueError(f"Plane index ({plane}) exceeds number of planes ({self._num_planes}).")

def frame_to_raw_index(self, frame):
"""Convert a frame index to the raw index in the TIFF file.
Parameters
----------
frame : int
The index of the frame to retrieve.
Returns
-------
raw_index: int
The raw index of the frame in the TIFF file.
Notes
-----
The underlying data is stored in a round-robin format collapsed into 3 dimensions (frames, rows, columns).
I.e. the first frame of each channel and each plane is stored, and then the second frame of each channel and
each plane, etc.
Ex. for 2 channels and 2 planes:
[channel_1_plane_1_frame_1, channel_2_plane_1_frame_1, channel_1_plane_2_frame_1, channel_2_plane_2_frame_1,
channel_1_plane_1_frame_2, channel_2_plane_1_frame_2, channel_1_plane_2_frame_2, channel_2_plane_2_frame_2, ...
channel_1_plane_1_frame_N, channel_2_plane_1_frame_N, channel_1_plane_2_frame_N, channel_2_plane_2_frame_N]
"""
raw_index = (frame * self._num_planes * self._num_channels) + (self.plane * self._num_channels) + self.channel
return raw_index

0 comments on commit e84b6ad

Please sign in to comment.