Skip to content

Commit

Permalink
added docstrings to testing
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Sep 5, 2023
1 parent c3529d6 commit 295f0c1
Showing 1 changed file with 58 additions and 27 deletions.
85 changes: 58 additions & 27 deletions src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@


def generate_dummy_video(size: Tuple[int], dtype: DtypeType = "uint16"):
"""Generate a dummy video of a given size and dtype.
Parameters
----------
size : Tuple[int]
Size of the video to generate.
dtype : DtypeType, optional
Dtype of the video to generate, by default "uint16".
Returns
-------
video : np.ndarray
A dummy video of the given size and dtype.
"""
dtype = np.dtype(dtype)
number_of_bytes = dtype.itemsize

Expand All @@ -39,6 +53,30 @@ def generate_dummy_imaging_extractor(
sampling_frequency: float = 30,
dtype: DtypeType = "uint16",
):
"""Generate a dummy imaging extractor for testing.
The imaging extractor is built by feeding random data into the `NumpyImagingExtractor`.
Parameters
----------
num_frames : int, optional
number of frames in the video, by default 30.
num_rows : int, optional
number of rows in the video, by default 10.
num_columns : int, optional
number of columns in the video, by default 10.
num_channels : int, optional
number of channels in the video, by default 1.
sampling_frequency : float, optional
sampling frequency of the video, by default 30.
dtype : DtypeType, optional
dtype of the video, by default "uint16".
Returns
-------
ImagingExtractor
An imaging extractor with random data fed into `NumpyImagingExtractor`.
"""
channel_names = [f"channel_num_{num}" for num in range(num_channels)]

size = (num_frames, num_rows, num_columns, num_channels)
Expand All @@ -64,13 +102,10 @@ def generate_dummy_segmentation_extractor(
has_neuropil_signal: bool = True,
rejected_list: Optional[list] = None,
) -> SegmentationExtractor:
"""
A dummy segmentation extractor for testing. The segmentation extractor is built by feeding random data into the
`NumpySegmentationExtractor`.
"""Generate a dummy segmentation extractor for testing.
Note that this dummy example is meant to be a mock object with the right shape, structure and objects but does not
contain meaningful content. That is, the image masks matrices are not plausible image mask for a roi, the raw signal
is not a meaningful biological signal and is not related appropriately to the deconvolved signal , etc.
The segmentation extractor is built by feeding random data into the
`NumpySegmentationExtractor`.
Parameters
----------
Expand Down Expand Up @@ -101,8 +136,13 @@ def generate_dummy_segmentation_extractor(
-------
SegmentationExtractor
A segmentation extractor with random data fed into `NumpySegmentationExtractor`
"""
Notes
-----
Note that this dummy example is meant to be a mock object with the right shape, structure and objects but does not
contain meaningful content. That is, the image masks matrices are not plausible image mask for a roi, the raw signal
is not a meaningful biological signal and is not related appropriately to the deconvolved signal , etc.
"""
# Create dummy image masks
image_masks = np.random.rand(num_rows, num_columns, num_rois)
movie_dims = (num_rows, num_columns)
Expand Down Expand Up @@ -150,20 +190,23 @@ def generate_dummy_segmentation_extractor(


def _assert_iterable_shape(iterable, shape):
"""Assert that the iterable has the given shape. If the iterable is a numpy array, the shape is checked directly."""
ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable)
for ar_shape, given_shape in zip(ar.shape, shape):
if isinstance(given_shape, int):
assert ar_shape == given_shape, f"Expected {given_shape}, received {ar_shape}!"


def _assert_iterable_shape_max(iterable, shape_max):
"""Assert that the iterable has a shape less than or equal to the given maximum shape."""
ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable)
for ar_shape, given_shape in zip(ar.shape, shape_max):
if isinstance(given_shape, int):
assert ar_shape <= given_shape


def _assert_iterable_element_dtypes(iterable, dtypes):
"""Assert that the iterable has elements of the given dtypes."""
if isinstance(iterable, Iterable) and not isinstance(iterable, str):
for iter in iterable:
_assert_iterable_element_dtypes(iter, dtypes)
Expand All @@ -172,6 +215,7 @@ def _assert_iterable_element_dtypes(iterable, dtypes):


def _assert_iterable_complete(iterable, dtypes=None, element_dtypes=None, shape=None, shape_max=None):
"""Assert that the iterable is complete, i.e. it is not None and has the given dtypes, element_dtypes, shape and shape_max."""
assert isinstance(iterable, dtypes), f"iterable {type(iterable)} is none of the types {dtypes}"
if not isinstance(iterable, NoneType):
if shape is not None:
Expand All @@ -185,6 +229,7 @@ def _assert_iterable_complete(iterable, dtypes=None, element_dtypes=None, shape=
def check_segmentations_equal(
segmentation_extractor1: SegmentationExtractor, segmentation_extractor2: SegmentationExtractor
):
"""Check that two segmentation extractors have equal fields."""
check_segmentation_return_types(segmentation_extractor1)
check_segmentation_return_types(segmentation_extractor2)
# assert equality:
Expand Down Expand Up @@ -224,9 +269,7 @@ def check_segmentations_images(
segmentation_extractor1: SegmentationExtractor,
segmentation_extractor2: SegmentationExtractor,
):
"""
Check that the segmentation images are equal for the given segmentation extractors.
"""
"""Check that the segmentation images are equal for the given segmentation extractors."""
images_in_extractor1 = segmentation_extractor1.get_images_dict()
images_in_extractor2 = segmentation_extractor2.get_images_dict()

Expand All @@ -243,11 +286,7 @@ def check_segmentations_images(


def check_segmentation_return_types(seg: SegmentationExtractor):
"""
Parameters
----------
seg:SegmentationExtractor
"""
"""Check that the return types of the segmentation extractor are correct."""
assert isinstance(seg.get_num_rois(), int)
assert isinstance(seg.get_num_frames(), int)
assert isinstance(seg.get_num_channels(), int)
Expand Down Expand Up @@ -317,6 +356,7 @@ def check_segmentation_return_types(seg: SegmentationExtractor):
def check_imaging_equal(
imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor, exclude_channel_comparison: bool = False
):
"""Check that two imaging extractors have equal fields."""
# assert equality:
assert imaging_extractor1.get_num_frames() == imaging_extractor2.get_num_frames()
assert imaging_extractor1.get_num_channels() == imaging_extractor2.get_num_channels()
Expand All @@ -337,15 +377,10 @@ def check_imaging_equal(


def assert_get_frames_return_shape(imaging_extractor: ImagingExtractor):
"""Utiliy to check whether an ImagingExtractor get_frames function behaves as expected. We aim for the function to
behave as numpy slicing and indexing as much as possible
"""Check whether an ImagingExtractor get_frames function behaves as expected.
Parameters
----------
imaging_extractor : ImagingExtractor
An image extractor
We aim for the function to behave as numpy slicing and indexing as much as possible.
"""

image_size = imaging_extractor.get_image_size()

frame_idxs = 0
Expand All @@ -369,11 +404,7 @@ def assert_get_frames_return_shape(imaging_extractor: ImagingExtractor):


def check_imaging_return_types(img_ex: ImagingExtractor):
"""
Parameters
----------
img_ex:ImagingExtractor
"""
"""Check that the return types of the imaging extractor are correct."""
assert isinstance(img_ex.get_num_frames(), inttype)
assert isinstance(img_ex.get_num_channels(), inttype)
assert isinstance(img_ex.get_sampling_frequency(), floattype)
Expand Down

0 comments on commit 295f0c1

Please sign in to comment.