Skip to content

Commit

Permalink
Merge pull request #190 from catalystneuro/testing_images_return_types
Browse files Browse the repository at this point in the history
[refactor] Testing SegmentationExtractor images
  • Loading branch information
CodyCBakerPhD authored Aug 12, 2022
2 parents e0aecbd + 7207769 commit 8a7729e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat
### Improvements
* Add `frame_to_time` to `SegmentationExtractor`, `get_roi_ids` is now a class method. [PR #187](https://github.com/catalystneuro/roiextractors/pull/187)
* Add `set_times` to `SegmentationExtractor`. [PR #188](https://github.com/catalystneuro/roiextractors/pull/188)
* Updated the test for segmentation images to check all images for the given segmentation extractors. [PR #190](https://github.com/catalystneuro/roiextractors/pull/190)

### Fixes

Expand Down
40 changes: 32 additions & 8 deletions src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def check_segmentations_equal(
) == set(
segmentation_extractor2.get_roi_pixel_masks(roi_ids=segmentation_extractor1.get_roi_ids()[:1])[0].flatten()
)
assert_array_equal(segmentation_extractor1.get_image(), segmentation_extractor2.get_image())

check_segmentations_images(segmentation_extractor1, segmentation_extractor2)

assert_array_equal(segmentation_extractor1.get_accepted_list(), segmentation_extractor2.get_accepted_list())
assert_array_equal(segmentation_extractor1.get_rejected_list(), segmentation_extractor2.get_rejected_list())
assert_array_equal(segmentation_extractor1.get_roi_locations(), segmentation_extractor2.get_roi_locations())
Expand All @@ -216,6 +218,28 @@ def check_segmentations_equal(
)


def check_segmentations_images(
segmentation_extractor1: SegmentationExtractor,
segmentation_extractor2: SegmentationExtractor,
):
"""
Check that the segmentation images are equal for the given segmentation extractors.
"""
images_in_extractor1 = segmentation_extractor1.get_images_dict()
images_in_extractor2 = segmentation_extractor2.get_images_dict()

assert len(images_in_extractor1) == len(images_in_extractor2)

image_names_are_equal = all(image_name in images_in_extractor1.keys() for image_name in images_in_extractor2.keys())
assert image_names_are_equal, "The names of segmentation images in the segmentation extractors are not the same."

for image_name in images_in_extractor1.keys():
assert_array_equal(
images_in_extractor1[image_name],
images_in_extractor2[image_name],
), f"The segmentation images for {image_name} are not equal."


def check_segmentation_return_types(seg: SegmentationExtractor):
"""
Parameters
Expand Down Expand Up @@ -252,12 +276,13 @@ def check_segmentation_return_types(seg: SegmentationExtractor):
element_dtypes=floattype,
shape_max=(np.prod(seg.get_image_size()), 3),
)
_assert_iterable_complete(
seg.get_image(),
dtypes=(np.ndarray, NoneType),
element_dtypes=floattype,
shape_max=(*seg.get_image_size(),),
)
for image_name in seg.get_images_dict():
_assert_iterable_complete(
seg.get_image(image_name),
dtypes=(np.ndarray, NoneType),
element_dtypes=floattype,
shape_max=(*seg.get_image_size(),),
)
_assert_iterable_complete(
seg.get_accepted_list(),
dtypes=(list, NoneType),
Expand Down Expand Up @@ -285,7 +310,6 @@ def check_segmentation_return_types(seg: SegmentationExtractor):
assert isinstance(seg.get_traces_dict(), dict)
assert isinstance(seg.get_images_dict(), dict)
assert {"raw", "dff", "neuropil", "deconvolved"} == set(seg.get_traces_dict().keys())
assert {"mean", "correlation"} == set(seg.get_images_dict().keys())


def check_imaging_equal(
Expand Down

0 comments on commit 8a7729e

Please sign in to comment.