diff --git a/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py b/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py index 2cd0560f..77242b75 100644 --- a/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py +++ b/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py @@ -14,7 +14,7 @@ from ...tools.typing import PathType from ...multisegmentationextractor import MultiSegmentationExtractor -from ...segmentationextractor import SegmentationExtractor, _image_mask_extractor +from ...segmentationextractor import SegmentationExtractor, convert_pixel_masks_to_image_masks class Suite2pSegmentationExtractor(SegmentationExtractor): @@ -180,7 +180,7 @@ def __init__( image_mean_name = "meanImg" if channel_name == "chan1" else f"meanImg_chan2" self._image_mean = self.options[image_mean_name] if image_mean_name in self.options else None roi_indices = list(range(self.get_num_rois())) - self._image_masks = _image_mask_extractor( + self._image_masks = convert_pixel_masks_to_image_masks( self.get_roi_pixel_masks(), roi_indices, self.get_image_size(), diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 998d7cde..e74fc369 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -429,13 +429,13 @@ def convert_image_masks_to_pixel_masks(image_masks: np.ndarray) -> list: Columns 1 and 2 are the row and column coordinates of the pixel, while the third column represents the weight of the pixel. """ - pixel_mask_list = [] + pixel_masks = [] for i in range(image_masks.shape[2]): image_mask = image_masks[:, :, i] locs = np.where(image_mask > 0) pix_values = image_mask[image_mask > 0] - pixel_mask_list.append(np.vstack((locs[0], locs[1], pix_values)).T) - return pixel_mask_list + pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T) + return pixel_masks def convert_pixel_masks_to_image_masks(pixel_masks: list[np.ndarray], image_shape: tuple) -> np.ndarray: @@ -457,16 +457,16 @@ def convert_pixel_masks_to_image_masks(pixel_masks: list[np.ndarray], image_shap image_masks = np.zeros(shape=shape) for i, pixel_mask in enumerate(pixel_masks): for row, column, wt in pixel_mask: - image_masks[row, column, i] = wt + image_masks[int(row), int(column), i] = wt return image_masks def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.ndarray: """Calculate the default ROI locations from given image masks. - This function takes a 3D numpy array of image masks and computes the median - coordinates of the maximum values in each 2D mask. The result is a 2D numpy - array where each column represents the (x, y) coordinates of the ROI for + This function takes a 3D numpy array of image masks and computes the coordinates (row, column) + of the maximum values in each 2D mask. In the case of a tie, the integer median of the coordinates is used. + The result is a 2D numpy array where each column represents the (row, column) coordinates of the ROI for each mask. Parameters @@ -478,12 +478,12 @@ def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.nd ------- np.ndarray A 2D numpy array of shape (2, num_rois) where each column contains the - (x, y) coordinates of the ROI for each mask. + (row, column) coordinates of the ROI for each mask. """ num_rois = image_masks.shape[2] roi_locations = np.zeros([2, num_rois], dtype="int") for i in range(num_rois): image_mask = image_masks[:, :, i] max_value_indices = np.where(image_mask == np.amax(image_mask)) - roi_locations[:, i] = np.array([np.median(max_value_indices[0]), np.median(max_value_indices[1])]).T + roi_locations[:, i] = np.array([int(np.median(max_value_indices[0])), int(np.median(max_value_indices[1]))]).T return roi_locations diff --git a/tests/test_minimal/test_segmentation_extractor_functions.py b/tests/test_minimal/test_segmentation_extractor_functions.py new file mode 100644 index 00000000..30403236 --- /dev/null +++ b/tests/test_minimal/test_segmentation_extractor_functions.py @@ -0,0 +1,132 @@ +import pytest +import numpy as np + +from roiextractors.segmentationextractor import ( + convert_image_masks_to_pixel_masks, + convert_pixel_masks_to_image_masks, + get_default_roi_locations_from_image_masks, +) + + +@pytest.fixture(scope="module") +def rng(): + seed = 1728084845 # int(datetime.datetime.now().timestamp()) at the time of writing + return np.random.default_rng(seed=seed) + + +@pytest.fixture(scope="function") +def image_masks(rng): + return rng.random((3, 3, 3)) + + +def test_convert_image_masks_to_pixel_masks(image_masks): + pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks) + for i, pixel_mask in enumerate(pixel_masks): + assert pixel_mask.shape == (image_masks.shape[0] * image_masks.shape[1], 3) + for row, column, wt in pixel_mask: + assert row == int(row) + assert column == int(column) + assert image_masks[int(row), int(column), i] == wt + + +def test_convert_image_masks_to_pixel_masks_with_zeros(image_masks): + image_masks[0, 0, 0] = 0 + pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks) + assert pixel_masks[0].shape == (image_masks.shape[0] * image_masks.shape[1] - 1, 3) + for i, pixel_mask in enumerate(pixel_masks): + for row, column, wt in pixel_mask: + assert row == int(row) + assert column == int(column) + assert image_masks[int(row), int(column), i] == wt + + +def test_convert_image_masks_to_pixel_masks_all_zeros(image_masks): + image_masks = np.zeros(image_masks.shape) + pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks) + for pixel_mask in pixel_masks: + assert pixel_mask.shape == (0, 3) + + +def test_convert_pixel_masks_to_image_masks(image_masks): + pixel_masks = [] + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + locs = np.where(image_mask > 0) + pix_values = image_mask[image_mask > 0] + pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T) + + image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2]) + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + indices = np.ndindex(image_mask.shape) + for row, column in indices: + pixel_mask_mask = np.logical_and(pixel_masks[i][:, 0] == row, pixel_masks[i][:, 1] == column) + assert image_mask[row, column] == pixel_masks[i][pixel_mask_mask, 2] + + +def test_convert_pixel_masks_to_image_masks_with_zeros(image_masks): + pixel_masks = [] + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + locs = np.where(image_mask > 0) + pix_values = image_mask[image_mask > 0] + pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T) + + pixel_masks[0] = pixel_masks[0][1:] + image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2]) + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + indices = np.ndindex(image_mask.shape) + for row, column in indices: + pixel_mask_mask = np.logical_and(pixel_masks[i][:, 0] == row, pixel_masks[i][:, 1] == column) + if i == 0 and row == 0 and column == 0: + assert np.all(np.logical_not(pixel_mask_mask)) + else: + assert image_mask[row, column] == pixel_masks[i][pixel_mask_mask, 2] + + +def test_convert_pixel_masks_to_image_masks_all_zeros(image_masks): + pixel_masks = [np.zeros((0, 0)) for _ in range(image_masks.shape[2])] + output_image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2]) + assert output_image_masks.shape == image_masks.shape + for image_mask in output_image_masks: + assert np.all(image_mask == 0) + + +def test_convert_masks_roundtrip(image_masks): + pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks) + output_image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2]) + np.testing.assert_array_equal(image_masks, output_image_masks) + + +def test_get_default_roi_locations_from_image_masks(): + image_masks = np.zeros((3, 3, 3)) + image_masks[0, 0, 0] = 1 + image_masks[1, 1, 1] = 1 + image_masks[2, 2, 2] = 1 + roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks) + expected_roi_locations = np.array([[0, 0], [1, 1], [2, 2]]).T + np.testing.assert_array_equal(roi_locations, expected_roi_locations) + + +def test_get_default_roi_locations_from_image_masks_tie1(): + image_masks = np.zeros((3, 3, 3)) + image_masks[0, 0, 0] = 1 + image_masks[0, 1, 0] = 1 + image_masks[1, 1, 1] = 1 + image_masks[2, 2, 2] = 1 + roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks) + expected_roi_locations = np.array([[0, 0], [1, 1], [2, 2]]).T + np.testing.assert_array_equal(roi_locations, expected_roi_locations) + + +def test_get_default_roi_locations_from_image_masks_tie2(): + image_masks = np.zeros((3, 3, 3)) + image_masks[0, 0, 0] = 1 + image_masks[0, 1, 0] = 1 + image_masks[1, 1, 0] = 1 + image_masks[1, 1, 1] = 1 + image_masks[2, 2, 2] = 1 + roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks) + expected_roi_locations = np.array([[0, 1], [1, 1], [2, 2]]).T + np.testing.assert_array_equal(roi_locations, expected_roi_locations)