Skip to content

Commit

Permalink
We don't need max. We can just set to threshold.
Browse files Browse the repository at this point in the history
  • Loading branch information
matham committed May 21, 2024
1 parent ccd70d9 commit 83bb745
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
21 changes: 5 additions & 16 deletions brainglobe_utils/cells/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,8 +581,10 @@ def match_points(
# based on https://en.wikipedia.org/wiki/Hungarian_algorithm
pos1 = pos1.astype(np.float64)
pos2 = pos2.astype(np.float64)
# numba pre-checks that arrays are at least 2-dims. Us checking would be
# too late and never invoked

if len(pos1.shape) != 2 or len(pos2.shape) != 2:
if pos1.ndim != 2 or pos2.ndim != 2:
raise ValueError("The input arrays must have exactly 2 dimensions")

n_rows = pos1.shape[0]
Expand All @@ -594,20 +596,7 @@ def match_points(
if pos1.shape[1] != pos2.shape[1]:
raise ValueError("The two inputs have different number of columns")

inf_dist = 0
have_threshold = threshold != np.inf
# If we use a threshold, find the largest enclosing (hyper) cube and use
# the distance between two opposing corners as the maximum distance we
# can ever see. Use that as dist of points further than threshold
if have_threshold:
# for each col, find the range of points and pick greatest col
largest_side = 0
for i in range(pos1.shape[1]):
bottom = min(np.min(pos1[:, i]), np.min(pos2[:, i]))
top = max(np.max(pos1[:, i]), np.max(pos2[:, i]))
largest_side = max(largest_side, top - bottom)
# make cube using the largest col range
inf_dist = math.sqrt(pos1.shape[1]) * (largest_side + 1)

potentials_rows = np.zeros(n_rows)
potentials_cols = np.zeros(n_cols + 1)
Expand Down Expand Up @@ -643,8 +632,8 @@ def match_points(
raise ValueError(
"The distance between point is too large"
)
if have_threshold and dist >= threshold:
dist = inf_dist
if have_threshold and dist > threshold:
dist = threshold

cur = (
dist
Expand Down
7 changes: 6 additions & 1 deletion tests/tests/test_cells/test_matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,12 @@ def test_unequal_inputs_shape():


def test_bad_input_shape():
with pytest.raises(ValueError):
# we want to check that a 1-dim array is not accepted. But, numba checks
# the inputs for at least 2-dims because it knows we access the 2dn dim.
# So we have no chance to raise an error ourself. So check numba's error
import numba.core.errors

with pytest.raises(numba.core.errors.TypingError):
match_points(np.zeros(5), np.zeros(5))

with pytest.raises(ValueError):
Expand Down

0 comments on commit 83bb745

Please sign in to comment.