diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 606af67..64be3ab 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -23,6 +23,7 @@ import numpy as np from numba import njit, objmode +from scipy.optimize import linear_sum_assignment from tqdm import tqdm @@ -438,6 +439,7 @@ def match_cells( other: List[Cell], threshold: float = np.inf, pre_match: bool = True, + use_scipy: bool = False, ) -> Tuple[List[int], List[Tuple[int, int]], List[int]]: """ Given two lists of cells. It finds a pairing of cells from `cells` and @@ -477,8 +479,14 @@ def match_cells( If True, we will (interenally) first efficiently find all the pairs of `cells` and `others` which are each at the same position in space. Then we run the optimization to find the best matching on the remaining. + This will significantly speed up the matching, if there are pairs of cells on top of each other in each set. + use_scipy : bool, optional. Defaults to False. + Whether to use scipy `linear_sum_assignment` to find the optimal + matching. Otherwise, we use our own implementation. + + See `match_points` for consideration details. Returns ------- @@ -509,7 +517,7 @@ def match_cells( __progress_update.updater = progress.update # for each index corresponding to c1, returns the index in c2 that matches try: - assignment = match_points(c1, c2, threshold, pre_match) + assignment = match_points(c1, c2, threshold, pre_match, use_scipy) progress.close() finally: __progress_update.updater = None @@ -802,11 +810,62 @@ def _optimize_pairs( return matches +def _optimize_pairs_scipy( + pos1: np.ndarray, + pos2: np.ndarray, +) -> np.ndarray: + """ + Implements `match_points` using scipy's `linear_sum_assignment`. + + The function has a memory cost of N*M*k*8 bytes. When `K=3` and `N=M`, + approximately 1GB is required for `N=M=6689` points. For `N=M=26.8k` + points, we need 16GB. For `N=M=75.7k` points, we need 128GB. + + Parameters + ---------- + pos1 : np.ndarray + 2D array of NxK. + pos2 : np.ndarray + 2D array of MxK. + + The relationship N <= M must be true and K must be the same for both. + + Returns + ------- + matches : np.ndarray + 1D array of length N, where index i in matches corresponds + to index i in `pos1` and its value is the index in pos2 + of its best match. + """ + # we don't check for boundary conditions, just assert because it should be + # checked by caller (match_points) + n_rows = pos1.shape[0] + n_cols = pos2.shape[0] + assert len(pos1.shape) == 2 + assert len(pos2.shape) == 2 + assert pos1.shape[1] == pos2.shape[1] + assert n_rows <= n_cols + + # Mxk -> M1K + pos1 = pos1[:, np.newaxis, :] + # Nxk -> 1NK + pos2 = pos2[np.newaxis, :, :] + # dist is MNK + dist = pos1 - pos2 + # cost is MN + cost = np.sqrt(np.sum(np.square(dist), axis=2)) + # result is sorted by rows + rows, cols = linear_sum_assignment(cost) + # M <= N, so cols and rows is size M + return cols + + def match_points( pos1: np.ndarray, pos2: np.ndarray, threshold: float = np.inf, pre_match: bool = True, + use_scipy: bool = False, ) -> np.ndarray: """ Given two arrays, each a list of position. For each point in `pos1` it @@ -849,6 +908,19 @@ def match_points( If True, it'll significantly speed up the matching, if there are pairs of points on top of each other across the input lists. + use_scipy : bool, optional. Defaults to False. + Whether to use scipy `linear_sum_assignment` to find the optimal + matching. Otherwise, we use our own implementation. + + Our implementation is very memory efficient. Using scipy, we have e.g. + a memory requirement of approximately 1GB for 6.7k points (`N=M=6.7k`), + 16GB for 26.8k points, and 128GB for 75.7k points. + + If `pre_match` is used, and it eliminates many zero-distance points, + using numpy is feasable. Otherwise, use our implementation. + + Note: When using scipy, we *don't* take the threshold into account + until after the matching is complete. Returns ------- @@ -873,7 +945,10 @@ def match_points( if not pre_match: # do optimization on full inputs - return _optimize_pairs(pos1, pos2, threshold) + if use_scipy: + return _optimize_pairs_scipy(pos1, pos2) + else: + return _optimize_pairs(pos1, pos2, threshold) # extract the indices of zero-pairs and remaining points unpaired1_indices, unpaired2_indices, paired_indices = ( @@ -894,7 +969,10 @@ def match_points( pos2 = pos2[unpaired2_indices] n_rows = pos1.shape[0] - matches = _optimize_pairs(pos1, pos2, threshold) + if use_scipy: + matches = _optimize_pairs_scipy(pos1, pos2) + else: + matches = _optimize_pairs(pos1, pos2, threshold) # map extracted full_matches = np.empty(n_rows + len(paired_indices), dtype=np.int64) diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py index f0ba023..add83bd 100644 --- a/tests/tests/test_cells/test_matches.py +++ b/tests/tests/test_cells/test_matches.py @@ -19,76 +19,85 @@ def as_cell(x: List[float]): return cells +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_equal_size(pre_match): +def test_cell_matches_equal_size(pre_match, use_scipy): a = as_cell([10, 20, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert not b_ assert [[0, 0], [1, 1], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_larger_other(pre_match): +def test_cell_matches_larger_other(pre_match, use_scipy): a = as_cell([1, 12, 100, 80]) b = as_cell([5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert b_ == [2] assert [[0, 0], [1, 1], [2, 4], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42, 41]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert b_ == [3] assert [[0, 1], [1, 0], [2, 2], [3, 4]] == ab +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_larger_cells(pre_match): +def test_cell_matches_larger_cells(pre_match, use_scipy): a = as_cell([5, 15, 25, 35, 100]) b = as_cell([1, 12, 100, 80]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert a_ == [2] assert not b_ assert [[0, 0], [1, 1], [3, 3], [4, 2]] == ab +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_threshold(pre_match): +def test_cell_matches_threshold(pre_match, use_scipy): a = as_cell([10, 12, 100, 80]) b = as_cell([0, 5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert b_ == [0, 3] assert [[0, 1], [1, 2], [2, 5], [3, 4]] == ab a_, ab, b_ = match_cells( - a, b, threshold=math.sqrt(3) * 11, pre_match=pre_match + a, + b, + threshold=math.sqrt(3) * 11, + pre_match=pre_match, + use_scipy=use_scipy, ) assert a_ == [3] assert b_ == [0, 3, 4] assert [[0, 1], [1, 2], [2, 5]] == ab +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_global_optimum_with_threshold_original_pr(pre_match): +def test_global_optimum_with_threshold_original_pr(pre_match, use_scipy): cells1 = [ Cell((0, 0, 0), Cell.UNKNOWN), Cell((12, 0, 0), Cell.UNKNOWN), @@ -101,7 +110,11 @@ def test_global_optimum_with_threshold_original_pr(pre_match): # without threshold, the global optimum pars points (0, 10), (12, 22) at a # global cost of 20. The other pairing would have cost of 24 missing_c1, good_matches, missing_c2 = match_cells( - cells1, cells2, threshold=np.inf, pre_match=pre_match + cells1, + cells2, + threshold=np.inf, + pre_match=pre_match, + use_scipy=use_scipy, ) assert not missing_c1 assert not missing_c2 @@ -110,79 +123,115 @@ def test_global_optimum_with_threshold_original_pr(pre_match): # with threshold, the previous pairing should not be considered good. # Instead, only (12, 10) is a good match. So while total cost is 24, # we only care about the cost of 2 during the matching algorithm + # BUT, only when not using scipy. With scipy it doesn't account for + # threshold during the matching, so it'll do it after missing_c1, good_matches, missing_c2 = match_cells( - cells1, cells2, threshold=5, pre_match=pre_match + cells1, cells2, threshold=5, pre_match=pre_match, use_scipy=use_scipy ) - # before we added the threshold to match_points, the following applies - # assert missing_c1 == [0, 1] - # assert missing_c2 == [0, 1] - # assert not good_matches - # with threshold in match_points, this is true - as wanted - assert missing_c1 == [ - 0, - ] - assert missing_c2 == [ - 1, - ] - assert good_matches == [[1, 0]] - - + # before we added the threshold to match_points, the following applies to + # both scipy and our own. After the fix, this is only True for scipy + if use_scipy: + assert missing_c1 == [0, 1] + assert missing_c2 == [0, 1] + assert not good_matches + else: + # with threshold in match_points, this is true - as wanted + assert missing_c1 == [ + 0, + ] + assert missing_c2 == [ + 1, + ] + assert good_matches == [[1, 0]] + + +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_rows_greater_than_cols(pre_match): +def test_rows_greater_than_cols(pre_match, use_scipy): with pytest.raises(ValueError): - match_points(np.zeros((5, 3)), np.zeros((4, 3)), pre_match=pre_match) + match_points( + np.zeros((5, 3)), + np.zeros((4, 3)), + pre_match=pre_match, + use_scipy=use_scipy, + ) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_unequal_inputs_shape(pre_match): +def test_unequal_inputs_shape(pre_match, use_scipy): with pytest.raises(ValueError): - match_points(np.zeros((5, 3)), np.zeros((5, 2)), pre_match=pre_match) + match_points( + np.zeros((5, 3)), + np.zeros((5, 2)), + pre_match=pre_match, + use_scipy=use_scipy, + ) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_bad_input_shape(pre_match): +def test_bad_input_shape(pre_match, use_scipy): # has to be 2 dims with pytest.raises(ValueError): - match_points(np.zeros(5), np.zeros(5), pre_match=pre_match) + match_points( + np.zeros(5), np.zeros(5), pre_match=pre_match, use_scipy=use_scipy + ) with pytest.raises(ValueError): match_points( - np.zeros((5, 4, 6)), np.zeros((5, 4, 6)), pre_match=pre_match + np.zeros((5, 4, 6)), + np.zeros((5, 4, 6)), + pre_match=pre_match, + use_scipy=use_scipy, ) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_progress_already_running(pre_match): +def test_progress_already_running(pre_match, use_scipy): a = as_cell([10, 12]) b = as_cell([10, 12]) cell_utils.__progress_update.updater = 1 try: with pytest.raises(TypeError): - match_cells(a, b, pre_match=pre_match) + match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) finally: cell_utils.__progress_update.updater = None +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_distance_too_large(pre_match): +def test_distance_too_large(pre_match, use_scipy): a = np.array([[1, 2, 3]]).T b = np.array([[1, 2, np.inf]]).T with pytest.raises(ValueError): - match_points(a, b, pre_match=pre_match) + match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_contains_identical_points(pre_match): +def test_contains_identical_points(pre_match, use_scipy): a = np.array([[1, 10], [5, 7], [22, 12]]) b = np.array([[5, 7], [7, 1], [21, 10]]) - matching = match_points(a, b, pre_match=pre_match) + matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) assert np.array_equal(matching, [1, 0, 2]) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_only_identical_points(pre_match): +def test_contains_only_identical_points(pre_match, use_scipy): a = np.array([[1, 2, 3]]).T b = np.array([[2, 3, 5, 1]]).T - matching = match_points(a, b, pre_match=pre_match) + matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) assert np.array_equal(matching, [3, 0, 1]) + + +@pytest.mark.parametrize("use_scipy", [True, False]) +@pytest.mark.parametrize("pre_match", [True, False]) +def test_contains_no_identical_points(pre_match, use_scipy): + a = np.array([[1, 10], [5, 7], [22, 12]]) + b = np.array([[6, 7], [7, 1], [21, 10]]) + matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + assert np.array_equal(matching, [1, 0, 2])