Skip to content

Commit

Permalink
clean: (#273)
Browse files Browse the repository at this point in the history
* clean:
typos
fix docstrings
rmv unused import

* rmv unused imports

* add typehints

* moved get_video_shape from extraction_tools to numpyextractors where it is used

* reordered documentation and names to match usage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make get_video_shape a staticmethod

---------

Co-authored-by: pauladkisson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 20, 2024
1 parent eb07c0a commit 0a223a5
Show file tree
Hide file tree
Showing 15 changed files with 51 additions and 62 deletions.
2 changes: 1 addition & 1 deletion src/roiextractors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Python-based module for extracting from, converting between, and handling recorded and optical imaging data from several file formats."""

# Keeping __version__ accessible only to maintain backcompatability.
# Modern appraoch (Python >= 3.8) is to use importlib
# Modern approach (Python >= 3.8) is to use importlib
try:
from importlib.metadata import version

Expand Down
28 changes: 3 additions & 25 deletions src/roiextractors/extraction_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _validate_video_structure(self) -> None:
"each property axis should be unique value between 0 and 3 (inclusive)"
)

axis_values = set((self.rows_axis, self.columns_axis, self.channels_axis, self.frame_axis))
axis_values = {self.rows_axis, self.columns_axis, self.channels_axis, self.frame_axis}
axis_values_are_not_unique = len(axis_values) != 4
if axis_values_are_not_unique:
raise ValueError(exception_message)
Expand Down Expand Up @@ -273,7 +273,7 @@ def read_numpy_memmap_video(
return video_memap


def _pixel_mask_extractor(image_mask_, _roi_ids):
def _pixel_mask_extractor(image_mask_, _roi_ids) -> list:
"""Convert image mask to pixel mask.
Pixel masks are an alternative data format for storage of image masks which relies on the sparsity of the images.
Expand Down Expand Up @@ -302,7 +302,7 @@ def _pixel_mask_extractor(image_mask_, _roi_ids):
return pixel_mask_list


def _image_mask_extractor(pixel_mask, _roi_ids, image_shape):
def _image_mask_extractor(pixel_mask, _roi_ids, image_shape) -> np.ndarray:
"""Convert a pixel mask to image mask.
Parameters
Expand All @@ -326,28 +326,6 @@ def _image_mask_extractor(pixel_mask, _roi_ids, image_shape):
return image_mask


def get_video_shape(video):
"""Get the shape of a video (num_channels, num_frames, size_x, size_y).
Parameters
----------
video: numpy.ndarray
The video to get the shape of.
Returns
-------
video_shape: tuple
The shape of the video (num_channels, num_frames, size_x, size_y).
"""
if len(video.shape) == 3:
# 1 channel
num_channels = 1
num_frames, size_x, size_y = video.shape
else:
num_channels, num_frames, size_x, size_y = video.shape
return num_channels, num_frames, size_x, size_y


def check_get_frames_args(func):
"""Check the arguments of the get_frames function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
import numpy as np

from ...extraction_tools import PathType, FloatType, ArrayType
from ...extraction_tools import (
get_video_shape,
write_to_h5_dataset_format,
)
from ...extraction_tools import write_to_h5_dataset_format
from ...imagingextractor import ImagingExtractor
from lazy_ops import DatasetView

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,9 @@
from tqdm import tqdm

from ...imagingextractor import ImagingExtractor
from typing import Tuple, Dict, Optional
from typing import Tuple, Optional

from ...extraction_tools import (
PathType,
DtypeType,
NumpyArray,
)
from ...extraction_tools import PathType, DtypeType


class MemmapImagingExtractor(ImagingExtractor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,9 @@

import os
from pathlib import Path
from typing import Tuple, Dict

import numpy as np
from tqdm import tqdm

from ...imagingextractor import ImagingExtractor
from typing import Tuple, Dict
from roiextractors.extraction_tools import read_numpy_memmap_video, VideoStructure, DtypeType, PathType

from .memmapextractors import MemmapImagingExtractor


Expand Down
25 changes: 23 additions & 2 deletions src/roiextractors/extractors/numpyextractors/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import numpy as np

from ...extraction_tools import PathType, FloatType, ArrayType
from ...extraction_tools import get_video_shape
from ...imagingextractor import ImagingExtractor
from ...segmentationextractor import SegmentationExtractor

Expand Down Expand Up @@ -78,7 +77,7 @@ def __init__(
self._num_rows,
self._num_columns,
self._num_channels,
) = get_video_shape(self._video)
) = self.get_video_shape(self._video)

if len(self._video.shape) == 3:
# check if this converts to np.ndarray
Expand All @@ -91,6 +90,28 @@ def __init__(
else:
self._channel_names = [f"channel_{ch}" for ch in range(self._num_channels)]

@staticmethod
def get_video_shape(video) -> Tuple[int, int, int, int]:
"""Get the shape of a video (num_frames, num_rows, num_columns, num_channels).
Parameters
----------
video: numpy.ndarray
The video to get the shape of.
Returns
-------
video_shape: tuple
The shape of the video (num_frames, num_rows, num_columns, num_channels).
"""
if len(video.shape) == 3:
# 1 channel
num_channels = 1
num_frames, num_rows, num_columns = video.shape
else:
num_frames, num_rows, num_columns, num_channels = video.shape
return num_frames, num_rows, num_columns, num_channels

def get_frames(self, frame_idxs=None, channel: Optional[int] = 0) -> np.ndarray:
if frame_idxs is None:
frame_idxs = [frame for frame in range(self.get_num_frames())]
Expand Down
2 changes: 0 additions & 2 deletions src/roiextractors/extractors/nwbextractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
FloatType,
IntType,
ArrayType,
check_get_frames_args,
check_get_videos_args,
raise_multi_channel_or_depth_not_implemented,
)
from ...imagingextractor import ImagingExtractor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from multiprocessing.sharedctypes import Value
import os
from pathlib import Path
from warnings import warn
from typing import Tuple, Optional

import numpy as np
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pathlib import Path
from typing import Optional
from warnings import warn
import os
import numpy as np

from ...extraction_tools import PathType
Expand All @@ -34,7 +33,7 @@ def get_available_channels(cls, folder_path: PathType):
Parameters
----------
file_path : PathType
folder_path : PathType
Path to Suite2p output path.
Returns
Expand Down
6 changes: 3 additions & 3 deletions src/roiextractors/multiimagingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np

from .extraction_tools import ArrayType, NumpyArray, check_get_frames_args
from .extraction_tools import ArrayType, NumpyArray
from .imagingextractor import ImagingExtractor


Expand Down Expand Up @@ -81,7 +81,7 @@ def _check_consistency_between_imaging_extractors(self):
len(unique_values) == 1
), f"{property_message} is not consistent over the files (found {unique_values})."

def _get_times(self):
def _get_times(self) -> np.ndarray:
"""Get all the times from the imaging extractors and combine them into a single array.
Returns
Expand Down Expand Up @@ -200,7 +200,7 @@ def get_video(

return video

def get_image_size(self) -> Tuple:
def get_image_size(self) -> Tuple[int, int]:
return self._imaging_extractors[0].get_image_size()

def get_num_frames(self) -> int:
Expand Down
20 changes: 13 additions & 7 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,13 @@ def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int
"""
return FrameSliceSegmentationExtractor(parent_segmentation=self, start_frame=start_frame, end_frame=end_frame)

def get_traces(self, roi_ids=None, start_frame=None, end_frame=None, name="raw"):
def get_traces(
self,
roi_ids: ArrayType = None,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
name: str = "raw",
):
"""Get the traces of each ROI specified by roi_ids.
Parameters
Expand Down Expand Up @@ -237,7 +243,7 @@ def get_images_dict(self):
"""
return dict(mean=self._image_mean, correlation=self._image_correlation)

def get_image(self, name="correlation"):
def get_image(self, name: str = "correlation") -> ArrayType:
"""Get specific images: mean or correlation.
Parameters
Expand All @@ -253,7 +259,7 @@ def get_image(self, name="correlation"):
raise ValueError(f"could not find {name} image, enter one of {list(self.get_images_dict().keys())}")
return self.get_images_dict().get(name)

def get_sampling_frequency(self):
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency in Hz.
Returns
Expand All @@ -266,7 +272,7 @@ def get_sampling_frequency(self):

return self._sampling_frequency

def get_num_rois(self):
def get_num_rois(self) -> int:
"""Get total number of Regions of Interest (ROIs) in the acquired images.
Returns
Expand All @@ -278,7 +284,7 @@ def get_num_rois(self):
if trace is not None and len(trace.shape) > 0:
return trace.shape[1]

def get_channel_names(self):
def get_channel_names(self) -> List[str]:
"""Get names of channels in the pipeline.
Returns
Expand All @@ -288,7 +294,7 @@ def get_channel_names(self):
"""
return self._channel_names

def get_num_channels(self):
def get_num_channels(self) -> int:
"""Get number of channels in the pipeline.
Returns
Expand Down Expand Up @@ -342,7 +348,7 @@ def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, A
Parameters
----------
frame_indices: int or array-like
frames: int or array-like
The frame or frames to be converted to times
Returns
Expand Down
2 changes: 2 additions & 0 deletions src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def generate_dummy_imaging_extractor(
sampling frequency of the video, by default 30.
dtype : DtypeType, optional
dtype of the video, by default "uint16".
channel_names : list, optional
list of channel names.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion tests/test_internals/test_extraction_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_invalid_structure_with_repeated_axis(self):
"^Invalid structure: (.*?) each property axis should be unique value between 0 and 3 (inclusive)?"
)
with self.assertRaisesRegex(ValueError, reg_expression):
video_structure = VideoStructure(
VideoStructure(
num_rows=self.num_rows,
num_columns=self.num_columns,
num_channels=self.num_channels,
Expand Down
1 change: 0 additions & 1 deletion tests/test_scanimage_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
from numpy.testing import assert_array_equal
from ScanImageTiffReader import ScanImageTiffReader
from roiextractors.extractors.tiffimagingextractors.scanimagetiff_utils import (
_get_scanimage_reader,
extract_extra_metadata,
Expand Down
1 change: 0 additions & 1 deletion tests/test_suite2psegmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from numpy.testing import assert_array_equal

from roiextractors import Suite2pSegmentationExtractor
from roiextractors.extraction_tools import _image_mask_extractor
from tests.setup_paths import OPHYS_DATA_PATH


Expand Down

0 comments on commit 0a223a5

Please sign in to comment.