From e84b6ad974a60621365d979f070d59ea82d4e242 Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Mon, 18 Sep 2023 13:15:30 -0700 Subject: [PATCH] refactored to take plane and channel at __init__ rather than in get_video --- .../scanimagetiffimagingextractor.py | 85 ++++++++++++------- 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py index 963867e4..1954a9ba 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py @@ -73,7 +73,9 @@ def __init__( self, file_path: PathType, sampling_frequency: float, + channel: Optional[int] = 0, num_channels: Optional[int] = 1, + plane: Optional[int] = 0, num_planes: Optional[int] = 1, frames_per_slice: Optional[int] = 1, channel_names: Optional[list] = None, @@ -96,27 +98,36 @@ def __init__( Path to the TIFF file. sampling_frequency : float Sampling frequency of each plane (scanVolumeRate) in Hz. + channel : int, optional + Index of the optical channel for this extractor (default=0). num_channels : int, optional Number of active channels that were acquired (default=1). + plane : int, optional + Index of the depth plane for this extractor (default=0). num_planes : int, optional Number of depth planes that were acquired (default=1). frames_per_slice : int, optional Number of frames per depth plane that were acquired (default=1). channel_names : list, optional - Names of the channels that were acquired (default=None). + Names of the channels (default=None). """ super().__init__() self.file_path = Path(file_path) self._sampling_frequency = sampling_frequency self.metadata = extract_extra_metadata(file_path) + self.channel = channel self._num_channels = num_channels + self.plane = plane self._num_planes = num_planes + if channel >= num_channels: + raise ValueError(f"Channel index ({channel}) exceeds number of channels ({num_channels}).") + if plane >= num_planes: + raise ValueError(f"Plane index ({plane}) exceeds number of planes ({num_planes}).") if frames_per_slice != 1: raise NotImplementedError( "Extractor cannot handle multiple frames per slice. Please raise an issue to request this feature: " "https://github.com/catalystneuro/roiextractors/issues " ) - self._channel_names = channel_names valid_suffixes = [".tiff", ".tif", ".TIFF", ".TIF"] if self.file_path.suffix not in valid_suffixes: @@ -137,60 +148,50 @@ def __init__( "https://github.com/catalystneuro/roiextractors/issues " ) - def get_frames(self, frame_idxs: ArrayType, channel: int = 0, plane: int = 0) -> np.ndarray: + def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: """Get specific video frames from indices (not necessarily continuous). Parameters ---------- frame_idxs: array-like Indices of frames to return. - channel: int, optional - Channel index. - plane: int, optional - Plane index. Returns ------- frames: numpy.ndarray The video frames. """ - self.check_frame_inputs(frame_idxs[-1], channel, plane) + self.check_frame_inputs(frame_idxs[-1]) if isinstance(frame_idxs, int): frame_idxs = [frame_idxs] if not all(np.diff(frame_idxs) == 1): return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs]) else: - return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1, channel=channel) + return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) # Data accessed through an open ScanImageTiffReader io gets scrambled if there are multiple calls. # Thus, open fresh io in context each time something is needed. - def _get_single_frame(self, frame: int, channel: Optional[int] = 0, plane: Optional[int] = 0) -> np.ndarray: + def _get_single_frame(self, frame: int) -> np.ndarray: """Get a single frame of data from the TIFF file. Parameters ---------- frame : int The index of the frame to retrieve. - channel : int, optional - The index of the channel to retrieve. - plane : int, optional - The index of the plane to retrieve. Returns ------- frame: numpy.ndarray The frame of data. """ - self.check_frame_inputs(frame, channel, plane) + self.check_frame_inputs(frame) ScanImageTiffReader = _get_scanimage_reader() - raw_index = (frame * self._num_planes * self._num_channels) + (plane * self._num_channels) + channel + raw_index = self.frame_to_raw_index(frame) with ScanImageTiffReader(str(self.file_path)) as io: return io.data(beg=raw_index, end=raw_index + 1) - def get_video( - self, start_frame=None, end_frame=None, channel: Optional[int] = 0, plane: Optional[int] = 0 - ) -> np.ndarray: + def get_video(self, start_frame=None, end_frame=None) -> np.ndarray: """Get the video frames. Parameters @@ -199,10 +200,6 @@ def get_video( Start frame index (inclusive). end_frame: int, optional End frame index (exclusive). - channel: int, optional - Channel index. - plane: int, optional - Plane index. Returns ------- @@ -213,15 +210,15 @@ def get_video( start_frame = 0 if end_frame is None: end_frame = self._num_frames - self.check_frame_inputs(end_frame - 1, channel, plane) + self.check_frame_inputs(end_frame - 1) ScanImageTiffReader = _get_scanimage_reader() - raw_start = (start_frame * self._num_planes * self._num_channels) + (plane * self._num_channels) + channel - raw_end = (end_frame * self._num_planes * self._num_channels) + (plane * self._num_channels) + channel + raw_start = self.frame_to_raw_index(start_frame) + raw_end = self.frame_to_raw_index(end_frame) raw_end = np.min([raw_end, self._total_num_frames]) with ScanImageTiffReader(filename=str(self.file_path)) as io: raw_video = io.data(beg=raw_start, end=raw_end) - video = raw_video[channel :: self._num_channels] - video = video[plane :: self._num_planes] + video = raw_video[self.channel :: self._num_channels] + video = video[self.plane :: self._num_planes] return video def get_image_size(self) -> Tuple[int, int]: @@ -242,10 +239,32 @@ def get_channel_names(self) -> list: def get_num_planes(self) -> int: return self._num_planes - def check_frame_inputs(self, frame, channel, plane) -> None: + def check_frame_inputs(self, frame) -> None: if frame >= self._num_frames: raise ValueError(f"Frame index ({frame}) exceeds number of frames ({self._num_frames}).") - if channel >= self._num_channels: - raise ValueError(f"Channel index ({channel}) exceeds number of channels ({self._num_channels}).") - if plane >= self._num_planes: - raise ValueError(f"Plane index ({plane}) exceeds number of planes ({self._num_planes}).") + + def frame_to_raw_index(self, frame): + """Convert a frame index to the raw index in the TIFF file. + + Parameters + ---------- + frame : int + The index of the frame to retrieve. + + Returns + ------- + raw_index: int + The raw index of the frame in the TIFF file. + + Notes + ----- + The underlying data is stored in a round-robin format collapsed into 3 dimensions (frames, rows, columns). + I.e. the first frame of each channel and each plane is stored, and then the second frame of each channel and + each plane, etc. + Ex. for 2 channels and 2 planes: + [channel_1_plane_1_frame_1, channel_2_plane_1_frame_1, channel_1_plane_2_frame_1, channel_2_plane_2_frame_1, + channel_1_plane_1_frame_2, channel_2_plane_1_frame_2, channel_1_plane_2_frame_2, channel_2_plane_2_frame_2, ... + channel_1_plane_1_frame_N, channel_2_plane_1_frame_N, channel_1_plane_2_frame_N, channel_2_plane_2_frame_N] + """ + raw_index = (frame * self._num_planes * self._num_channels) + (self.plane * self._num_channels) + self.channel + return raw_index