diff --git a/CHANGELOG.md b/CHANGELOG.md index c3bdebbf..9e0a9ae3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat ### Improvements * Add `frame_to_time` to `SegmentationExtractor`, `get_roi_ids` is now a class method. [PR #187](https://github.com/catalystneuro/roiextractors/pull/187) * Add `set_times` to `SegmentationExtractor`. [PR #188](https://github.com/catalystneuro/roiextractors/pull/188) +* Updated the test for segmentation images to check all images for the given segmentation extractors. [PR #190](https://github.com/catalystneuro/roiextractors/pull/190) ### Fixes diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index 6780ccb6..cf1b9470 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -203,7 +203,9 @@ def check_segmentations_equal( ) == set( segmentation_extractor2.get_roi_pixel_masks(roi_ids=segmentation_extractor1.get_roi_ids()[:1])[0].flatten() ) - assert_array_equal(segmentation_extractor1.get_image(), segmentation_extractor2.get_image()) + + check_segmentations_images(segmentation_extractor1, segmentation_extractor2) + assert_array_equal(segmentation_extractor1.get_accepted_list(), segmentation_extractor2.get_accepted_list()) assert_array_equal(segmentation_extractor1.get_rejected_list(), segmentation_extractor2.get_rejected_list()) assert_array_equal(segmentation_extractor1.get_roi_locations(), segmentation_extractor2.get_roi_locations()) @@ -216,6 +218,28 @@ def check_segmentations_equal( ) +def check_segmentations_images( + segmentation_extractor1: SegmentationExtractor, + segmentation_extractor2: SegmentationExtractor, +): + """ + Check that the segmentation images are equal for the given segmentation extractors. + """ + images_in_extractor1 = segmentation_extractor1.get_images_dict() + images_in_extractor2 = segmentation_extractor2.get_images_dict() + + assert len(images_in_extractor1) == len(images_in_extractor2) + + image_names_are_equal = all(image_name in images_in_extractor1.keys() for image_name in images_in_extractor2.keys()) + assert image_names_are_equal, "The names of segmentation images in the segmentation extractors are not the same." + + for image_name in images_in_extractor1.keys(): + assert_array_equal( + images_in_extractor1[image_name], + images_in_extractor2[image_name], + ), f"The segmentation images for {image_name} are not equal." + + def check_segmentation_return_types(seg: SegmentationExtractor): """ Parameters @@ -252,12 +276,13 @@ def check_segmentation_return_types(seg: SegmentationExtractor): element_dtypes=floattype, shape_max=(np.prod(seg.get_image_size()), 3), ) - _assert_iterable_complete( - seg.get_image(), - dtypes=(np.ndarray, NoneType), - element_dtypes=floattype, - shape_max=(*seg.get_image_size(),), - ) + for image_name in seg.get_images_dict(): + _assert_iterable_complete( + seg.get_image(image_name), + dtypes=(np.ndarray, NoneType), + element_dtypes=floattype, + shape_max=(*seg.get_image_size(),), + ) _assert_iterable_complete( seg.get_accepted_list(), dtypes=(list, NoneType), @@ -285,7 +310,6 @@ def check_segmentation_return_types(seg: SegmentationExtractor): assert isinstance(seg.get_traces_dict(), dict) assert isinstance(seg.get_images_dict(), dict) assert {"raw", "dff", "neuropil", "deconvolved"} == set(seg.get_traces_dict().keys()) - assert {"mean", "correlation"} == set(seg.get_images_dict().keys()) def check_imaging_equal(