Skip to content

Commit

Permalink
added tests for assert tools
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Oct 1, 2024
1 parent 898bd64 commit af443be
Show file tree
Hide file tree
Showing 2 changed files with 465 additions and 92 deletions.
236 changes: 144 additions & 92 deletions src/roiextractors/tools/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from roiextractors import NumpyImagingExtractor, NumpySegmentationExtractor

from roiextractors.tools.typing import DtypeType, NoneType, FloatType, IntType
from roiextractors.tools.typing import DtypeType, ArrayType


def generate_mock_video(size: Tuple[int], dtype: DtypeType = "uint16", seed: int = 0):
Expand Down Expand Up @@ -86,6 +86,79 @@ def generate_mock_imaging_extractor(
return imaging_extractor


def assert_imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor):
"""Assert that two ImagingExtractor objects are equal by comparing their attributes and data.
Parameters
----------
imaging_extractor1 : ImagingExtractor
The first ImagingExtractor object to compare.
imaging_extractor2 : ImagingExtractor
The second ImagingExtractor object to compare.
Raises
------
AssertionError
If any of the following attributes or data do not match between the two ImagingExtractor objects:
- Image size
- Number of frames
- Sampling frequency
- Data type (dtype)
- Video data
- Time points (_times)
"""
assert (
imaging_extractor1.get_image_size() == imaging_extractor2.get_image_size()
), "ImagingExtractors are not equal: image_sizes do not match."
assert (
imaging_extractor1.get_num_frames() == imaging_extractor2.get_num_frames()
), "ImagingExtractors are not equal: num_frames do not match."
assert np.isclose(
imaging_extractor1.get_sampling_frequency(), imaging_extractor2.get_sampling_frequency()
), "ImagingExtractors are not equal: sampling_frequencies do not match."
assert (
imaging_extractor1.get_dtype() == imaging_extractor2.get_dtype()
), "ImagingExtractors are not equal: dtypes do not match."
assert_array_equal(
imaging_extractor1.get_video(),
imaging_extractor2.get_video(),
err_msg="ImagingExtractors are not equal: videos do not match.",
)
assert_array_equal(
imaging_extractor1._times,
imaging_extractor2._times,
err_msg="ImagingExtractors are not equal: _times do not match.",
)


def imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor) -> bool:
"""Return True if two ImagingExtractors are equal, False otherwise.
Parameters
----------
imaging_extractor1 : ImagingExtractor
The first ImagingExtractor object to compare.
imaging_extractor2 : ImagingExtractor
The second ImagingExtractor object to compare.
Returns
-------
bool
True if all of the following fields match between the two ImagingExtractor objects:
- Image size
- Number of frames
- Sampling frequency
- Data type (dtype)
- Video data
- Time points (_times)
"""
try:
assert_imaging_equal(imaging_extractor1, imaging_extractor2)
return True
except AssertionError:
return False


def generate_mock_segmentation_extractor(
num_rois: int = 10,
num_frames: int = 30,
Expand All @@ -97,6 +170,12 @@ def generate_mock_segmentation_extractor(
roi_response_names: List[str] = ["raw", "dff", "deconvolved", "denoised"],
background_response_names: List[str] = ["background"],
rejected_roi_ids: Optional[list] = None,
roi_locations: Optional[ArrayType] = None,
image_masks: Optional[ArrayType] = None,
roi_response_traces: Optional[dict] = None,
background_image_masks: Optional[ArrayType] = None,
background_response_traces: Optional[dict] = None,
summary_images: Optional[dict] = None,
seed: int = 0,
) -> NumpySegmentationExtractor:
"""Generate a mock segmentation extractor for testing.
Expand Down Expand Up @@ -126,6 +205,8 @@ def generate_mock_segmentation_extractor(
names of background response traces, by default ["background"].
rejected_roi_ids: Optional[list], optional
A list of rejected rois, None by default.
roi_locations : Optional[ArrayType], optional
A 2D array of shape (2, num_rois) containing the locations of the rois, None by default.
seed : int, default 0
seed for the random number generator, by default 0.
Expand All @@ -143,37 +224,78 @@ def generate_mock_segmentation_extractor(
rng = np.random.default_rng(seed)

# Create dummy image masks
image_masks = rng.random((num_rows, num_columns, num_rois))
background_image_masks = rng.random((num_rows, num_columns, num_background_components))
if image_masks is None:
image_masks = rng.random((num_rows, num_columns, num_rois))
else:
assert image_masks.shape == (
num_rows,
num_columns,
num_rois,
), f"image_masks should have shape (num_rows, num_columns, num_rois) but got {image_masks.shape}."
if background_image_masks is None:
background_image_masks = rng.random((num_rows, num_columns, num_background_components))
else:
assert background_image_masks.shape == (
num_rows,
num_columns,
num_background_components,
), f"background_image_masks should have shape (num_rows, num_columns, num_background_components) but got {background_image_masks.shape}."

# Create signals
roi_response_traces = {name: rng.random((num_frames, num_rois)) for name in roi_response_names}
background_response_traces = {
name: rng.random((num_frames, num_background_components)) for name in background_response_names
}
if roi_response_traces is None:
roi_response_traces = {name: rng.random((num_frames, num_rois)) for name in roi_response_names}
else:
for name, trace in roi_response_traces.items():
assert trace.shape == (
num_frames,
num_rois,
), f"roi_response_traces[{name}] should have shape (num_frames, num_rois) but got {trace.shape}."
if background_response_traces is None:
background_response_traces = {
name: rng.random((num_frames, num_background_components)) for name in background_response_names
}
else:
for name, trace in background_response_traces.items():
assert trace.shape == (
num_frames,
num_background_components,
), f"background_response_traces[{name}] should have shape (num_frames, num_background_components) but got {trace.shape}."

# Summary images
summary_images = {name: rng.random((num_rows, num_columns)) for name in summary_image_names}
if summary_images is None:
summary_images = {name: rng.random((num_rows, num_columns)) for name in summary_image_names}
else:
for name, image in summary_images.items():
assert image.shape == (
num_rows,
num_columns,
), f"summary_images[{name}] should have shape (num_rows, num_columns) but got {image.shape}."

# Rois
roi_ids = [id for id in range(num_rois)]
roi_locations_rows = rng.integers(low=0, high=num_rows, size=num_rois)
roi_locations_columns = rng.integers(low=0, high=num_columns, size=num_rois)
roi_locations = np.vstack((roi_locations_rows, roi_locations_columns))
if roi_locations is None:
roi_locations_rows = rng.integers(low=0, high=num_rows, size=num_rois)
roi_locations_columns = rng.integers(low=0, high=num_columns, size=num_rois)
roi_locations = np.vstack((roi_locations_rows, roi_locations_columns))
else:
assert roi_locations.shape == (
2,
num_rois,
), f"roi_locations should have shape (2, num_rois) but got {roi_locations.shape}."
background_ids = [i for i in range(num_background_components)]

rejected_roi_ids = rejected_roi_ids if rejected_roi_ids else None

accepeted_list = roi_ids
accepted_roi_ids = roi_ids
if rejected_roi_ids is not None:
accepeted_list = list(set(accepeted_list).difference(rejected_roi_ids))
accepted_roi_ids = list(set(accepted_roi_ids).difference(rejected_roi_ids))

dummy_segmentation_extractor = NumpySegmentationExtractor(
image_masks=image_masks,
roi_response_traces=roi_response_traces,
sampling_frequency=sampling_frequency,
roi_ids=roi_ids,
accepted_roi_ids=accepeted_list,
accepted_roi_ids=accepted_roi_ids,
rejected_roi_ids=rejected_roi_ids,
roi_locations=roi_locations,
summary_images=summary_images,
Expand Down Expand Up @@ -261,11 +383,14 @@ def assert_segmentation_equal(
segmentation_extractor2.get_roi_image_masks(),
err_msg="SegmentationExtractors are not equal: roi_image_masks do not match.",
)
assert_array_equal(
segmentation_extractor1.get_roi_pixel_masks(),
segmentation_extractor2.get_roi_pixel_masks(),
err_msg="SegmentationExtractors are not equal: roi_pixel_masks do not match.",
)
for pixel_mask1, pixel_mask2 in zip(
segmentation_extractor1.get_roi_pixel_masks(), segmentation_extractor2.get_roi_pixel_masks()
):
assert_array_equal(
pixel_mask1,
pixel_mask2,
err_msg="SegmentationExtractors are not equal: roi_pixel_masks do not match.",
)
roi_response_traces1 = segmentation_extractor1.get_roi_response_traces()
roi_response_traces2 = segmentation_extractor2.get_roi_response_traces()
for name, trace1 in roi_response_traces1.items():
Expand Down Expand Up @@ -369,76 +494,3 @@ def segmentation_equal(
return True
except AssertionError:
return False


def assert_imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor):
"""Assert that two ImagingExtractor objects are equal by comparing their attributes and data.
Parameters
----------
imaging_extractor1 : ImagingExtractor
The first ImagingExtractor object to compare.
imaging_extractor2 : ImagingExtractor
The second ImagingExtractor object to compare.
Raises
------
AssertionError
If any of the following attributes or data do not match between the two ImagingExtractor objects:
- Image size
- Number of frames
- Sampling frequency
- Data type (dtype)
- Video data
- Time points (_times)
"""
assert (
imaging_extractor1.get_image_size() == imaging_extractor2.get_image_size()
), "ImagingExtractors are not equal: image_sizes do not match."
assert (
imaging_extractor1.get_num_frames() == imaging_extractor2.get_num_frames()
), "ImagingExtractors are not equal: num_frames do not match."
assert np.isclose(
imaging_extractor1.get_sampling_frequency(), imaging_extractor2.get_sampling_frequency()
), "ImagingExtractors are not equal: sampling_frequencies do not match."
assert (
imaging_extractor1.get_dtype() == imaging_extractor2.get_dtype()
), "ImagingExtractors are not equal: dtypes do not match."
assert_array_equal(
imaging_extractor1.get_video(),
imaging_extractor2.get_video(),
err_msg="ImagingExtractors are not equal: videos do not match.",
)
assert_array_equal(
imaging_extractor1._times,
imaging_extractor2._times,
err_msg="ImagingExtractors are not equal: _times do not match.",
)


def imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor) -> bool:
"""Return True if two ImagingExtractors are equal, False otherwise.
Parameters
----------
imaging_extractor1 : ImagingExtractor
The first ImagingExtractor object to compare.
imaging_extractor2 : ImagingExtractor
The second ImagingExtractor object to compare.
Returns
-------
bool
True if all of the following fields match between the two ImagingExtractor objects:
- Image size
- Number of frames
- Sampling frequency
- Data type (dtype)
- Video data
- Time points (_times)
"""
try:
assert_imaging_equal(imaging_extractor1, imaging_extractor2)
return True
except AssertionError:
return False
Loading

0 comments on commit af443be

Please sign in to comment.