From 0a223a5c306af72c4087cfab4b5cba4b45786791 Mon Sep 17 00:00:00 2001 From: Ben Dichter Date: Tue, 20 Feb 2024 15:34:25 -0500 Subject: [PATCH] clean: (#273) * clean: typos fix docstrings rmv unused import * rmv unused imports * add typehints * moved get_video_shape from extraction_tools to numpyextractors where it is used * reordered documentation and names to match usage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make get_video_shape a staticmethod --------- Co-authored-by: pauladkisson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/roiextractors/__init__.py | 2 +- src/roiextractors/extraction_tools.py | 28 ++----------------- .../hdf5imagingextractor.py | 5 +--- .../memmapextractors/memmapextractors.py | 8 ++---- .../memmapextractors/numpymemampextractor.py | 7 +---- .../numpyextractors/numpyextractors.py | 25 +++++++++++++++-- .../extractors/nwbextractors/nwbextractors.py | 2 -- .../sbximagingextractor.py | 1 - .../suite2p/suite2psegmentationextractor.py | 3 +- src/roiextractors/multiimagingextractor.py | 6 ++-- src/roiextractors/segmentationextractor.py | 20 ++++++++----- src/roiextractors/testing.py | 2 ++ tests/test_internals/test_extraction_tools.py | 2 +- tests/test_scanimage_utils.py | 1 - tests/test_suite2psegmentationextractor.py | 1 - 15 files changed, 51 insertions(+), 62 deletions(-) diff --git a/src/roiextractors/__init__.py b/src/roiextractors/__init__.py index 9b517e50..a708f1d2 100644 --- a/src/roiextractors/__init__.py +++ b/src/roiextractors/__init__.py @@ -1,7 +1,7 @@ """Python-based module for extracting from, converting between, and handling recorded and optical imaging data from several file formats.""" # Keeping __version__ accessible only to maintain backcompatability. -# Modern appraoch (Python >= 3.8) is to use importlib +# Modern approach (Python >= 3.8) is to use importlib try: from importlib.metadata import version diff --git a/src/roiextractors/extraction_tools.py b/src/roiextractors/extraction_tools.py index 704be5db..6f8a5874 100644 --- a/src/roiextractors/extraction_tools.py +++ b/src/roiextractors/extraction_tools.py @@ -149,7 +149,7 @@ def _validate_video_structure(self) -> None: "each property axis should be unique value between 0 and 3 (inclusive)" ) - axis_values = set((self.rows_axis, self.columns_axis, self.channels_axis, self.frame_axis)) + axis_values = {self.rows_axis, self.columns_axis, self.channels_axis, self.frame_axis} axis_values_are_not_unique = len(axis_values) != 4 if axis_values_are_not_unique: raise ValueError(exception_message) @@ -273,7 +273,7 @@ def read_numpy_memmap_video( return video_memap -def _pixel_mask_extractor(image_mask_, _roi_ids): +def _pixel_mask_extractor(image_mask_, _roi_ids) -> list: """Convert image mask to pixel mask. Pixel masks are an alternative data format for storage of image masks which relies on the sparsity of the images. @@ -302,7 +302,7 @@ def _pixel_mask_extractor(image_mask_, _roi_ids): return pixel_mask_list -def _image_mask_extractor(pixel_mask, _roi_ids, image_shape): +def _image_mask_extractor(pixel_mask, _roi_ids, image_shape) -> np.ndarray: """Convert a pixel mask to image mask. Parameters @@ -326,28 +326,6 @@ def _image_mask_extractor(pixel_mask, _roi_ids, image_shape): return image_mask -def get_video_shape(video): - """Get the shape of a video (num_channels, num_frames, size_x, size_y). - - Parameters - ---------- - video: numpy.ndarray - The video to get the shape of. - - Returns - ------- - video_shape: tuple - The shape of the video (num_channels, num_frames, size_x, size_y). - """ - if len(video.shape) == 3: - # 1 channel - num_channels = 1 - num_frames, size_x, size_y = video.shape - else: - num_channels, num_frames, size_x, size_y = video.shape - return num_channels, num_frames, size_x, size_y - - def check_get_frames_args(func): """Check the arguments of the get_frames function. diff --git a/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py b/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py index 6f13d1da..5ce5dbf7 100644 --- a/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py +++ b/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py @@ -13,10 +13,7 @@ import numpy as np from ...extraction_tools import PathType, FloatType, ArrayType -from ...extraction_tools import ( - get_video_shape, - write_to_h5_dataset_format, -) +from ...extraction_tools import write_to_h5_dataset_format from ...imagingextractor import ImagingExtractor from lazy_ops import DatasetView diff --git a/src/roiextractors/extractors/memmapextractors/memmapextractors.py b/src/roiextractors/extractors/memmapextractors/memmapextractors.py index 62e44e69..88b7e52d 100644 --- a/src/roiextractors/extractors/memmapextractors/memmapextractors.py +++ b/src/roiextractors/extractors/memmapextractors/memmapextractors.py @@ -13,13 +13,9 @@ from tqdm import tqdm from ...imagingextractor import ImagingExtractor -from typing import Tuple, Dict, Optional +from typing import Tuple, Optional -from ...extraction_tools import ( - PathType, - DtypeType, - NumpyArray, -) +from ...extraction_tools import PathType, DtypeType class MemmapImagingExtractor(ImagingExtractor): diff --git a/src/roiextractors/extractors/memmapextractors/numpymemampextractor.py b/src/roiextractors/extractors/memmapextractors/numpymemampextractor.py index c3636132..bd25b33a 100644 --- a/src/roiextractors/extractors/memmapextractors/numpymemampextractor.py +++ b/src/roiextractors/extractors/memmapextractors/numpymemampextractor.py @@ -8,14 +8,9 @@ import os from pathlib import Path -from typing import Tuple, Dict -import numpy as np -from tqdm import tqdm - -from ...imagingextractor import ImagingExtractor -from typing import Tuple, Dict from roiextractors.extraction_tools import read_numpy_memmap_video, VideoStructure, DtypeType, PathType + from .memmapextractors import MemmapImagingExtractor diff --git a/src/roiextractors/extractors/numpyextractors/numpyextractors.py b/src/roiextractors/extractors/numpyextractors/numpyextractors.py index 0ed3c0df..e5ea0c67 100644 --- a/src/roiextractors/extractors/numpyextractors/numpyextractors.py +++ b/src/roiextractors/extractors/numpyextractors/numpyextractors.py @@ -14,7 +14,6 @@ import numpy as np from ...extraction_tools import PathType, FloatType, ArrayType -from ...extraction_tools import get_video_shape from ...imagingextractor import ImagingExtractor from ...segmentationextractor import SegmentationExtractor @@ -78,7 +77,7 @@ def __init__( self._num_rows, self._num_columns, self._num_channels, - ) = get_video_shape(self._video) + ) = self.get_video_shape(self._video) if len(self._video.shape) == 3: # check if this converts to np.ndarray @@ -91,6 +90,28 @@ def __init__( else: self._channel_names = [f"channel_{ch}" for ch in range(self._num_channels)] + @staticmethod + def get_video_shape(video) -> Tuple[int, int, int, int]: + """Get the shape of a video (num_frames, num_rows, num_columns, num_channels). + + Parameters + ---------- + video: numpy.ndarray + The video to get the shape of. + + Returns + ------- + video_shape: tuple + The shape of the video (num_frames, num_rows, num_columns, num_channels). + """ + if len(video.shape) == 3: + # 1 channel + num_channels = 1 + num_frames, num_rows, num_columns = video.shape + else: + num_frames, num_rows, num_columns, num_channels = video.shape + return num_frames, num_rows, num_columns, num_channels + def get_frames(self, frame_idxs=None, channel: Optional[int] = 0) -> np.ndarray: if frame_idxs is None: frame_idxs = [frame for frame in range(self.get_num_frames())] diff --git a/src/roiextractors/extractors/nwbextractors/nwbextractors.py b/src/roiextractors/extractors/nwbextractors/nwbextractors.py index 8f73285f..61ff9f44 100644 --- a/src/roiextractors/extractors/nwbextractors/nwbextractors.py +++ b/src/roiextractors/extractors/nwbextractors/nwbextractors.py @@ -26,8 +26,6 @@ FloatType, IntType, ArrayType, - check_get_frames_args, - check_get_videos_args, raise_multi_channel_or_depth_not_implemented, ) from ...imagingextractor import ImagingExtractor diff --git a/src/roiextractors/extractors/sbximagingextractor/sbximagingextractor.py b/src/roiextractors/extractors/sbximagingextractor/sbximagingextractor.py index 008b35e4..e013fb75 100644 --- a/src/roiextractors/extractors/sbximagingextractor/sbximagingextractor.py +++ b/src/roiextractors/extractors/sbximagingextractor/sbximagingextractor.py @@ -9,7 +9,6 @@ from multiprocessing.sharedctypes import Value import os from pathlib import Path -from warnings import warn from typing import Tuple, Optional import numpy as np diff --git a/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py b/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py index b7ade533..cfbd6570 100644 --- a/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py +++ b/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py @@ -10,7 +10,6 @@ from pathlib import Path from typing import Optional from warnings import warn -import os import numpy as np from ...extraction_tools import PathType @@ -34,7 +33,7 @@ def get_available_channels(cls, folder_path: PathType): Parameters ---------- - file_path : PathType + folder_path : PathType Path to Suite2p output path. Returns diff --git a/src/roiextractors/multiimagingextractor.py b/src/roiextractors/multiimagingextractor.py index a0a09bd4..294daa24 100644 --- a/src/roiextractors/multiimagingextractor.py +++ b/src/roiextractors/multiimagingextractor.py @@ -11,7 +11,7 @@ import numpy as np -from .extraction_tools import ArrayType, NumpyArray, check_get_frames_args +from .extraction_tools import ArrayType, NumpyArray from .imagingextractor import ImagingExtractor @@ -81,7 +81,7 @@ def _check_consistency_between_imaging_extractors(self): len(unique_values) == 1 ), f"{property_message} is not consistent over the files (found {unique_values})." - def _get_times(self): + def _get_times(self) -> np.ndarray: """Get all the times from the imaging extractors and combine them into a single array. Returns @@ -200,7 +200,7 @@ def get_video( return video - def get_image_size(self) -> Tuple: + def get_image_size(self) -> Tuple[int, int]: return self._imaging_extractors[0].get_image_size() def get_num_frames(self) -> int: diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index a1f310d3..2eb8d010 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -181,7 +181,13 @@ def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int """ return FrameSliceSegmentationExtractor(parent_segmentation=self, start_frame=start_frame, end_frame=end_frame) - def get_traces(self, roi_ids=None, start_frame=None, end_frame=None, name="raw"): + def get_traces( + self, + roi_ids: ArrayType = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + name: str = "raw", + ): """Get the traces of each ROI specified by roi_ids. Parameters @@ -237,7 +243,7 @@ def get_images_dict(self): """ return dict(mean=self._image_mean, correlation=self._image_correlation) - def get_image(self, name="correlation"): + def get_image(self, name: str = "correlation") -> ArrayType: """Get specific images: mean or correlation. Parameters @@ -253,7 +259,7 @@ def get_image(self, name="correlation"): raise ValueError(f"could not find {name} image, enter one of {list(self.get_images_dict().keys())}") return self.get_images_dict().get(name) - def get_sampling_frequency(self): + def get_sampling_frequency(self) -> float: """Get the sampling frequency in Hz. Returns @@ -266,7 +272,7 @@ def get_sampling_frequency(self): return self._sampling_frequency - def get_num_rois(self): + def get_num_rois(self) -> int: """Get total number of Regions of Interest (ROIs) in the acquired images. Returns @@ -278,7 +284,7 @@ def get_num_rois(self): if trace is not None and len(trace.shape) > 0: return trace.shape[1] - def get_channel_names(self): + def get_channel_names(self) -> List[str]: """Get names of channels in the pipeline. Returns @@ -288,7 +294,7 @@ def get_channel_names(self): """ return self._channel_names - def get_num_channels(self): + def get_num_channels(self) -> int: """Get number of channels in the pipeline. Returns @@ -342,7 +348,7 @@ def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, A Parameters ---------- - frame_indices: int or array-like + frames: int or array-like The frame or frames to be converted to times Returns diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index d975e30d..c82bc0f8 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -74,6 +74,8 @@ def generate_dummy_imaging_extractor( sampling frequency of the video, by default 30. dtype : DtypeType, optional dtype of the video, by default "uint16". + channel_names : list, optional + list of channel names. Returns ------- diff --git a/tests/test_internals/test_extraction_tools.py b/tests/test_internals/test_extraction_tools.py index 44e73ff4..01f82881 100644 --- a/tests/test_internals/test_extraction_tools.py +++ b/tests/test_internals/test_extraction_tools.py @@ -95,7 +95,7 @@ def test_invalid_structure_with_repeated_axis(self): "^Invalid structure: (.*?) each property axis should be unique value between 0 and 3 (inclusive)?" ) with self.assertRaisesRegex(ValueError, reg_expression): - video_structure = VideoStructure( + VideoStructure( num_rows=self.num_rows, num_columns=self.num_columns, num_channels=self.num_channels, diff --git a/tests/test_scanimage_utils.py b/tests/test_scanimage_utils.py index 3d1e60c0..619a2327 100644 --- a/tests/test_scanimage_utils.py +++ b/tests/test_scanimage_utils.py @@ -1,6 +1,5 @@ import pytest from numpy.testing import assert_array_equal -from ScanImageTiffReader import ScanImageTiffReader from roiextractors.extractors.tiffimagingextractors.scanimagetiff_utils import ( _get_scanimage_reader, extract_extra_metadata, diff --git a/tests/test_suite2psegmentationextractor.py b/tests/test_suite2psegmentationextractor.py index b38265a2..62c9647e 100644 --- a/tests/test_suite2psegmentationextractor.py +++ b/tests/test_suite2psegmentationextractor.py @@ -7,7 +7,6 @@ from numpy.testing import assert_array_equal from roiextractors import Suite2pSegmentationExtractor -from roiextractors.extraction_tools import _image_mask_extractor from tests.setup_paths import OPHYS_DATA_PATH