diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 46a08ec..2062d0c 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -23,7 +23,6 @@ import numpy as np from numba import njit, objmode -from scipy.optimize import linear_sum_assignment from tqdm import tqdm @@ -439,7 +438,6 @@ def match_cells( other: List[Cell], threshold: float = np.inf, pre_match: bool = True, - use_scipy: bool = False, ) -> Tuple[List[int], List[Tuple[int, int]], List[int]]: """ Given two lists of cells. It finds a pairing of cells from `cells` and @@ -482,11 +480,6 @@ def match_cells( This will significantly speed up the matching, if there are pairs of cells on top of each other in each set. - use_scipy : bool, optional. Defaults to False. - Whether to use scipy `linear_sum_assignment` to find the optimal - matching. Otherwise, we use our own implementation. - - See `match_points` for consideration details. Returns ------- @@ -513,15 +506,12 @@ def match_cells( if flip: c1, c2 = c2, c1 - progress = None - if not use_scipy: - # with scipy we don't have callbacks so no updates - progress = tqdm(desc="Matching cells", total=len(c1), unit="cells") - __progress_update.updater = progress.update + progress = tqdm(desc="Matching cells", total=len(c1), unit="cells") + __progress_update.updater = progress.update # for each index corresponding to c1, returns the index in c2 that matches try: - assignment = match_points(c1, c2, threshold, pre_match, use_scipy) + assignment = match_points(c1, c2, threshold, pre_match) finally: __progress_update.updater = None @@ -816,62 +806,11 @@ def _optimize_pairs( return matches -def _optimize_pairs_scipy( - pos1: np.ndarray, - pos2: np.ndarray, -) -> np.ndarray: - """ - Implements `match_points` using scipy's `linear_sum_assignment`. - - The function has a memory cost of N*M*k*8 bytes. When `K=3` and `N=M`, - approximately 1GB is required for `N=M=6689` points. For `N=M=26.8k` - points, we need 16GB. For `N=M=75.7k` points, we need 128GB. - - Parameters - ---------- - pos1 : np.ndarray - 2D array of NxK. - pos2 : np.ndarray - 2D array of MxK. - - The relationship N <= M must be true and K must be the same for both. - - Returns - ------- - matches : np.ndarray - 1D array of length N, where index i in matches corresponds - to index i in `pos1` and its value is the index in pos2 - of its best match. - """ - # we don't check for boundary conditions, just assert because it should be - # checked by caller (match_points) - n_rows = pos1.shape[0] - n_cols = pos2.shape[0] - assert len(pos1.shape) == 2 - assert len(pos2.shape) == 2 - assert pos1.shape[1] == pos2.shape[1] - assert n_rows <= n_cols - - # Mxk -> M1K - pos1 = pos1[:, np.newaxis, :] - # Nxk -> 1NK - pos2 = pos2[np.newaxis, :, :] - # dist is MNK - dist = pos1 - pos2 - # cost is MN - cost = np.sqrt(np.sum(np.square(dist), axis=2)) - # result is sorted by rows - rows, cols = linear_sum_assignment(cost) - # M <= N, so cols and rows is size M - return cols - - def match_points( pos1: np.ndarray, pos2: np.ndarray, threshold: float = np.inf, pre_match: bool = True, - use_scipy: bool = False, ) -> np.ndarray: """ Given two arrays, each a list of position. For each point in `pos1` it @@ -914,19 +853,6 @@ def match_points( If True, it'll significantly speed up the matching, if there are pairs of points on top of each other across the input lists. - use_scipy : bool, optional. Defaults to False. - Whether to use scipy `linear_sum_assignment` to find the optimal - matching. Otherwise, we use our own implementation. - - Our implementation is very memory efficient. Using scipy, we have e.g. - a memory requirement of approximately 1GB for 6.7k points (`N=M=6.7k`), - 16GB for 26.8k points, and 128GB for 75.7k points. - - If `pre_match` is used, and it eliminates many zero-distance points, - using numpy is feasable. Otherwise, use our implementation. - - Note: When using scipy, we *don't* take the threshold into account - until after the matching is complete. Returns ------- @@ -951,10 +877,7 @@ def match_points( if not pre_match: # do optimization on full inputs - if use_scipy: - return _optimize_pairs_scipy(pos1, pos2) - else: - return _optimize_pairs(pos1, pos2, threshold) + return _optimize_pairs(pos1, pos2, threshold) # extract the indices of zero-pairs and remaining points unpaired1_indices, unpaired2_indices, paired_indices = ( @@ -975,10 +898,7 @@ def match_points( pos2 = pos2[unpaired2_indices] n_rows = pos1.shape[0] - if use_scipy: - matches = _optimize_pairs_scipy(pos1, pos2) - else: - matches = _optimize_pairs(pos1, pos2, threshold) + matches = _optimize_pairs(pos1, pos2, threshold) # map extracted full_matches = np.empty(n_rows + len(paired_indices), dtype=np.int64) diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py index 6bb7657..e14d447 100644 --- a/tests/tests/test_cells/test_matches.py +++ b/tests/tests/test_cells/test_matches.py @@ -3,13 +3,16 @@ import numpy as np import pytest +from scipy.optimize import linear_sum_assignment import brainglobe_utils.cells.cells as cell_utils from brainglobe_utils.cells.cells import ( Cell, + analyze_point_matches, from_numpy_pos, match_cells, match_points, + to_numpy_pos, ) from brainglobe_utils.IO.cells import get_cells @@ -36,7 +39,7 @@ def cells_and_other_cells(test_data_registry): "cellfinder/cells-z-1000-1050.xml" ) other_cell_data_path = test_data_registry.fetch( - "cellfinder/other_cells-z-1000-1050.xml" + "cellfinder/other-cells-z-1000-1050.xml" ) cell_data = get_cells(cell_data_path) other_cell_data = get_cells(other_cell_data_path) @@ -49,73 +52,107 @@ def as_cell(x: List[float]): return cells -@pytest.mark.xfail def test_cell_matching_regression(cells_and_other_cells): cells, other_cells = cells_and_other_cells - # TODO implement cell matching regression test here, then remove xfail - assert False + np_cells = to_numpy_pos(cells) + np_other = to_numpy_pos(other_cells) + + # only run matching on unpaired to reduce computation + unpaired1_indices, unpaired2_indices, paired_indices = ( + cell_utils._find_identical_points(np_cells, np_other) + ) + np_cells = np_cells[unpaired1_indices] + np_other = np_other[unpaired2_indices] + + # happens to be true for this dataset + assert len(np_cells) < len(np_other), "must be true to pass to match" + + # get matches + matches = match_points(np_cells, np_other, pre_match=False) + missing_cells, good, missing_other = analyze_point_matches( + np_cells, np_other, matches + ) + good = np.array(good) + assert not len(missing_cells), "all cells must be matched" + + # get cost + a = np_cells[good[:, 0], :] + b = np_other[good[:, 1], :] + cost_our = np.sum(np.sqrt(np.sum(np.square(a - b), axis=1))) + + # get scipy cost + # Mxk -> M1K + pos1 = np_cells[:, np.newaxis, :] + # Nxk -> 1NK + pos2 = np_other[np.newaxis, :, :] + # dist is MNK + dist = pos1 - pos2 + # cost is MN + cost_mat = np.sqrt(np.sum(np.square(dist), axis=2)) + # result is sorted by rows + rows, cols = linear_sum_assignment(cost_mat) + + cost_scipy = cost_mat[rows, cols].sum() + + assert np.isclose(cost_scipy, cost_our) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_equal_size(pre_match, use_scipy): +def test_cell_matches_equal_size(pre_match): a = as_cell([10, 20, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 0], [1, 1], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_larger_other(pre_match, use_scipy): +def test_cell_matches_larger_other(pre_match): a = as_cell([1, 12, 100, 80]) b = as_cell([5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [2] assert [[0, 0], [1, 1], [2, 4], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42, 41]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [3] assert [[0, 1], [1, 0], [2, 2], [3, 4]] == ab -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_larger_cells(pre_match, use_scipy): +def test_cell_matches_larger_cells(pre_match): a = as_cell([5, 15, 25, 35, 100]) b = as_cell([1, 12, 100, 80]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert a_ == [2] assert not b_ assert [[0, 0], [1, 1], [3, 3], [4, 2]] == ab -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_threshold(pre_match, use_scipy): +def test_cell_matches_threshold(pre_match): a = as_cell([10, 12, 100, 80]) b = as_cell([0, 5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [0, 3] assert [[0, 1], [1, 2], [2, 5], [3, 4]] == ab @@ -125,16 +162,14 @@ def test_cell_matches_threshold(pre_match, use_scipy): b, threshold=math.sqrt(3) * 11, pre_match=pre_match, - use_scipy=use_scipy, ) assert a_ == [3] assert b_ == [0, 3, 4] assert [[0, 1], [1, 2], [2, 5]] == ab -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_global_optimum_with_threshold_original_pr(pre_match, use_scipy): +def test_global_optimum_with_threshold_original_pr(pre_match): cells1 = [ Cell((0, 0, 0), Cell.UNKNOWN), Cell((12, 0, 0), Cell.UNKNOWN), @@ -151,7 +186,6 @@ def test_global_optimum_with_threshold_original_pr(pre_match, use_scipy): cells2, threshold=np.inf, pre_match=pre_match, - use_scipy=use_scipy, ) assert not missing_c1 assert not missing_c2 @@ -160,115 +194,92 @@ def test_global_optimum_with_threshold_original_pr(pre_match, use_scipy): # with threshold, the previous pairing should not be considered good. # Instead, only (12, 10) is a good match. So while total cost is 24, # we only care about the cost of 2 during the matching algorithm - # BUT, only when not using scipy. With scipy it doesn't account for - # threshold during the matching, so it'll do it after missing_c1, good_matches, missing_c2 = match_cells( - cells1, cells2, threshold=5, pre_match=pre_match, use_scipy=use_scipy + cells1, cells2, threshold=5, pre_match=pre_match ) - # before we added the threshold to match_points, the following applies to - # both scipy and our own. After the fix, this is only True for scipy - if use_scipy: - assert missing_c1 == [0, 1] - assert missing_c2 == [0, 1] - assert not good_matches - else: - # with threshold in match_points, this is true - as wanted - assert missing_c1 == [ - 0, - ] - assert missing_c2 == [ - 1, - ] - assert good_matches == [[1, 0]] - - -@pytest.mark.parametrize("use_scipy", [True, False]) + assert missing_c1 == [ + 0, + ] + assert missing_c2 == [ + 1, + ] + assert good_matches == [[1, 0]] + + @pytest.mark.parametrize("pre_match", [True, False]) -def test_rows_greater_than_cols(pre_match, use_scipy): +def test_rows_greater_than_cols(pre_match): with pytest.raises(ValueError): match_points( np.zeros((5, 3)), np.zeros((4, 3)), pre_match=pre_match, - use_scipy=use_scipy, ) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_unequal_inputs_shape(pre_match, use_scipy): +def test_unequal_inputs_shape(pre_match): with pytest.raises(ValueError): match_points( np.zeros((5, 3)), np.zeros((5, 2)), pre_match=pre_match, - use_scipy=use_scipy, ) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_bad_input_shape(pre_match, use_scipy): +def test_bad_input_shape(pre_match): # has to be 2 dims with pytest.raises(ValueError): - match_points( - np.zeros(5), np.zeros(5), pre_match=pre_match, use_scipy=use_scipy - ) + match_points(np.zeros(5), np.zeros(5), pre_match=pre_match) with pytest.raises(ValueError): match_points( np.zeros((5, 4, 6)), np.zeros((5, 4, 6)), pre_match=pre_match, - use_scipy=use_scipy, ) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_progress_already_running(pre_match, use_scipy): +def test_progress_already_running(pre_match): a = as_cell([10, 12]) b = as_cell([10, 12]) cell_utils.__progress_update.updater = 1 try: with pytest.raises(TypeError): - match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + match_cells(a, b, pre_match=pre_match) finally: cell_utils.__progress_update.updater = None -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_distance_too_large(pre_match, use_scipy): +def test_distance_too_large(pre_match): a = np.array([[1, 2, 3]]).T b = np.array([[1, 2, np.inf]]).T with pytest.raises(ValueError): - match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + match_points(a, b, pre_match=pre_match) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_contains_identical_points(pre_match, use_scipy): +def test_contains_identical_points(pre_match): a = np.array([[1, 10], [5, 7], [22, 12]]) b = np.array([[5, 7], [7, 1], [21, 10]]) - matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + matching = match_points(a, b, pre_match=pre_match) assert np.array_equal(matching, [1, 0, 2]) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_contains_only_identical_points(pre_match, use_scipy): +def test_contains_only_identical_points(pre_match): a = np.array([[1, 2, 3]]).T b = np.array([[2, 3, 5, 1]]).T - matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + matching = match_points(a, b, pre_match=pre_match) assert np.array_equal(matching, [3, 0, 1]) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_contains_no_identical_points(pre_match, use_scipy): +def test_contains_no_identical_points(pre_match): a = np.array([[1, 10], [5, 7], [22, 12]]) b = np.array([[6, 7], [7, 1], [21, 10]]) - matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + matching = match_points(a, b, pre_match=pre_match) assert np.array_equal(matching, [1, 0, 2])