From 83bb74550e7ad20e77361cddc7bbf225cb5495ca Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Mon, 20 May 2024 22:17:29 -0400 Subject: [PATCH] We don't need max. We can just set to threshold. --- brainglobe_utils/cells/cells.py | 21 +++++---------------- tests/tests/test_cells/test_matches.py | 7 ++++++- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 75cdc28..58616ba 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -581,8 +581,10 @@ def match_points( # based on https://en.wikipedia.org/wiki/Hungarian_algorithm pos1 = pos1.astype(np.float64) pos2 = pos2.astype(np.float64) + # numba pre-checks that arrays are at least 2-dims. Us checking would be + # too late and never invoked - if len(pos1.shape) != 2 or len(pos2.shape) != 2: + if pos1.ndim != 2 or pos2.ndim != 2: raise ValueError("The input arrays must have exactly 2 dimensions") n_rows = pos1.shape[0] @@ -594,20 +596,7 @@ def match_points( if pos1.shape[1] != pos2.shape[1]: raise ValueError("The two inputs have different number of columns") - inf_dist = 0 have_threshold = threshold != np.inf - # If we use a threshold, find the largest enclosing (hyper) cube and use - # the distance between two opposing corners as the maximum distance we - # can ever see. Use that as dist of points further than threshold - if have_threshold: - # for each col, find the range of points and pick greatest col - largest_side = 0 - for i in range(pos1.shape[1]): - bottom = min(np.min(pos1[:, i]), np.min(pos2[:, i])) - top = max(np.max(pos1[:, i]), np.max(pos2[:, i])) - largest_side = max(largest_side, top - bottom) - # make cube using the largest col range - inf_dist = math.sqrt(pos1.shape[1]) * (largest_side + 1) potentials_rows = np.zeros(n_rows) potentials_cols = np.zeros(n_cols + 1) @@ -643,8 +632,8 @@ def match_points( raise ValueError( "The distance between point is too large" ) - if have_threshold and dist >= threshold: - dist = inf_dist + if have_threshold and dist > threshold: + dist = threshold cur = ( dist diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py index 8543f8c..a7a1aa3 100644 --- a/tests/tests/test_cells/test_matches.py +++ b/tests/tests/test_cells/test_matches.py @@ -131,7 +131,12 @@ def test_unequal_inputs_shape(): def test_bad_input_shape(): - with pytest.raises(ValueError): + # we want to check that a 1-dim array is not accepted. But, numba checks + # the inputs for at least 2-dims because it knows we access the 2dn dim. + # So we have no chance to raise an error ourself. So check numba's error + import numba.core.errors + + with pytest.raises(numba.core.errors.TypingError): match_points(np.zeros(5), np.zeros(5)) with pytest.raises(ValueError):