Skip to content

Commit

Permalink
reset testing file
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Oct 4, 2023
1 parent 3aa2ac2 commit 58dfd93
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,19 @@ def generate_dummy_segmentation_extractor(


def _assert_iterable_shape(iterable, shape):
"""Assert that the iterable has the given shape. If the iterable is a numpy array, the shape is checked directly."""
"""Assert that the iterable has the given shape.
If the iterable is a numpy array, the shape is checked directly.
"""
ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable)
for ar_shape, given_shape in zip(ar.shape, shape):
if isinstance(given_shape, int):
assert ar_shape == given_shape, f"Expected {given_shape}, received {ar_shape}!"


def _assert_iterable_shape_max(iterable, shape_max):
"""Assert that the iterable has a shape less than or equal to the given maximum shape."""
"""Assert that the iterable has a shape less than or equal to the given
maximum shape."""
ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable)
for ar_shape, given_shape in zip(ar.shape, shape_max):
if isinstance(given_shape, int):
Expand All @@ -216,7 +220,8 @@ def _assert_iterable_element_dtypes(iterable, dtypes):


def _assert_iterable_complete(iterable, dtypes=None, element_dtypes=None, shape=None, shape_max=None):
"""Assert that the iterable is complete, i.e. it is not None and has the given dtypes, element_dtypes, shape and shape_max."""
"""Assert that the iterable is complete, i.e. it is not None and has the
given dtypes, element_dtypes, shape and shape_max."""
assert isinstance(iterable, dtypes), f"iterable {type(iterable)} is none of the types {dtypes}"
if not isinstance(iterable, NoneType):
if shape is not None:
Expand Down Expand Up @@ -270,7 +275,8 @@ def check_segmentations_images(
segmentation_extractor1: SegmentationExtractor,
segmentation_extractor2: SegmentationExtractor,
):
"""Check that the segmentation images are equal for the given segmentation extractors."""
"""Check that the segmentation images are equal for the given segmentation
extractors."""
images_in_extractor1 = segmentation_extractor1.get_images_dict()
images_in_extractor2 = segmentation_extractor2.get_images_dict()

Expand All @@ -287,7 +293,8 @@ def check_segmentations_images(


def check_segmentation_return_types(seg: SegmentationExtractor):
"""Check that the return types of the segmentation extractor are correct."""
"""Check that the return types of the segmentation extractor are
correct."""
assert isinstance(seg.get_num_rois(), int)
assert isinstance(seg.get_num_frames(), int)
assert isinstance(seg.get_num_channels(), int)
Expand Down Expand Up @@ -378,9 +385,11 @@ def check_imaging_equal(


def assert_get_frames_return_shape(imaging_extractor: ImagingExtractor):
"""Check whether an ImagingExtractor get_frames function behaves as expected.
"""Check whether an ImagingExtractor get_frames function behaves as
expected.
We aim for the function to behave as numpy slicing and indexing as much as possible.
We aim for the function to behave as numpy slicing and indexing as
much as possible.
"""
image_size = imaging_extractor.get_image_size()

Expand Down

0 comments on commit 58dfd93

Please sign in to comment.