Skip to content

Commit

Permalink
Add support for pre-extracting zero distance pairs.
Browse files Browse the repository at this point in the history
  • Loading branch information
matham committed May 23, 2024
1 parent 1bf364a commit a50f4aa
Show file tree
Hide file tree
Showing 2 changed files with 336 additions and 82 deletions.
324 changes: 276 additions & 48 deletions brainglobe_utils/cells/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,10 @@ def from_numpy_pos(pos: np.ndarray, cell_type: int) -> List[Cell]:


def match_cells(
cells: List[Cell], other: List[Cell], threshold: float = np.inf
cells: List[Cell],
other: List[Cell],
threshold: float = np.inf,
pre_match: bool = True,
) -> Tuple[List[int], List[Tuple[int, int]], List[int]]:
"""
Given two lists of cells. It finds a pairing of cells from `cells` and
Expand Down Expand Up @@ -470,6 +473,12 @@ def match_cells(
The threshold to use to remove bad matches. Any match pair whose
distance is greater than the threshold will be exluded from the
matching.
pre_match : bool, optional. Defaults to True.
If True, we will (interenally) first efficiently find all the pairs of
`cells` and `others` which are each at the same position in space. Then
we run the optimization to find the best matching on the remaining.
This will significantly speed up the matching, if there are pairs of
cells on top of each other in each set.
Returns
-------
Expand Down Expand Up @@ -500,7 +509,7 @@ def match_cells(
__progress_update.updater = progress.update
# for each index corresponding to c1, returns the index in c2 that matches
try:
assignment = match_points(c1, c2, threshold)
assignment = match_points(c1, c2, threshold, pre_match)
progress.close()
finally:
__progress_update.updater = None
Expand All @@ -526,75 +535,189 @@ def match_cells(
__progress_update.updater = None


def __compare_progress():
def __compare_progress(n: int = 1) -> None:
"""Updates the progress bar by `n`, if there's one set."""
if __progress_update.updater is not None:
__progress_update.updater()
__progress_update.updater(n)


@njit
def match_points(
pos1: np.ndarray, pos2: np.ndarray, threshold: float = np.inf
) -> np.ndarray:
def _find_pairs_sorted(
pos1: np.ndarray, pos2: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Given two arrays, each a list of position. For each point in `pos1` it
finds a point in `pos2` such that the distance between the assigned
matches across all `pos1` is minimized.
Given two sorted arrays, returns all the pairs in the arrays (without
replacement) that are at a `np.isclose` distance of each other.
E.g.::
This is computed in O(N) time.
>>> pos1 = np.array([[20, 10, 30, 40]]).T
>>> pos2 = np.array([[5, 15, 25, 35, 50]]).T
>>> matches = match_points(pos1, pos2)
>>> matches
array([1, 0, 2, 3])
Parameters
----------
pos1 : Sorted (1st axis) np.ndarray of shape `MxK`.
pos2 : Sorted (1st axis) np.ndarray of shape `NxK`.
Returns
-------
tuple :
used_mask_1 : Bool np.ndarray of size `M`.
It's True at indices of the rows in `pos1` used up in the pairing.
used_mask_2 : Bool np.ndarray of size `N`.
It's True at indices of the rows in `pos2` used up in the pairing.
paired_indices: np.ndarray of shape `Rx2`.
Each row is a pair of indices to `pos1` and `pos2`, respectively.
Indicating a pair that is `close` to each other.
"""
# mask of pos1/pos2 for the elements used in a identical par
used_mask_1 = np.zeros(pos1.shape[0], dtype=np.bool_)
used_mask_2 = np.zeros(pos2.shape[0], dtype=np.bool_)
n_cols = pos1.shape[1]

# the pos1/pos2 indices for each pair - at most this many pairs
max_n = min(pos1.shape[0], pos2.shape[0])
paired_indices = np.zeros((max_n, 2), dtype=np.int64)

# how many pairs found
used_n = 0
# next index to check for pair for pos1/2
pos1_i = 0
pos2_i = 0

# do this in O(N), until we reach end of either array
while pos1_i < max_n and pos2_i < max_n:
# are the two points the same
same = True
for i in range(n_cols):
same = same and np.isclose(pos1[pos1_i, i], pos2[pos2_i, i])

# they match
if same:
used_mask_1[pos1_i] = True
used_mask_2[pos2_i] = True
paired_indices[used_n, 0] = pos1_i
paired_indices[used_n, 1] = pos2_i
used_n += 1
pos1_i += 1
pos2_i += 1
else:
# the points are not the same in at least one dim, which is less?
one_is_less = True
for i in range(n_cols):
# for dims until this one (if any), they are the same
if pos1[pos1_i, i] < pos2[pos2_i, i]: # first is less
break
elif pos1[pos1_i, i] > pos2[pos2_i, i]: # second is less
one_is_less = False
break
# they were the same in this axis as well
else: # pragma: no cover
assert False, "at least in one dim it should be different"

if one_is_less:
# first is less than second, advance first by one
pos1_i += 1
else:
# second is less than first, advance second by one
pos2_i += 1

return used_mask_1, used_mask_2, paired_indices[:used_n, :]


def _find_identical_points(
pos1: np.ndarray, pos2: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Given two arrays, returns the set of pairs of points in the arrays
(without replacement) that are at a `np.isclose` distance of each other.
This is computed in O(NlogN) time, dominated by sorting (internally).
Parameters
----------
pos1 : np.ndarray of shape `MxK`.
pos2 : np.ndarray of shape `NxK`.
Returns
-------
tuple :
unpaired1_indices : np.ndarray of size `M - R`.
Array of indices of `pos1` that were not used in any pairs
unpaired2_indices : np.ndarray of size `N - R`.
Array of indices of `pos2` that were not used in any pairs.
paired_indices: np.ndarray of shape `Rx2`.
Each row is a pair of indices to `pos1` and `pos2`, respectively.
Indicating a pair that is `close` to each other.
"""
# sort pos1 and pos2 rows, ordered by columns 1..N.
# lexsort uses rows as keys and sorts the keys from last to first.
# So flip order of rows and then transpose
indices = np.lexsort(np.flip(pos1, axis=1).transpose())
pos1_sorted = pos1[indices, :]
# original pos1 indices of the sorted elements
orig1_indices = np.arange(len(pos1), dtype=np.int64)[indices]

indices = np.lexsort(np.flip(pos2, axis=1).transpose())
pos2_sorted = pos2[indices, :]
orig2_indices = np.arange(len(pos2), dtype=np.int64)[indices]

# get the zero distance pairs
used_mask_1, used_mask_2, paired_indices = _find_pairs_sorted(
pos1_sorted, pos2_sorted
)

# convert the indices back to the original unsorted indices
unpaired1_indices = orig1_indices[np.logical_not(used_mask_1)]
unpaired2_indices = orig2_indices[np.logical_not(used_mask_2)]
paired_indices[:, 0] = orig1_indices[paired_indices[:, 0]]
paired_indices[:, 1] = orig2_indices[paired_indices[:, 1]]

return unpaired1_indices, unpaired2_indices, paired_indices


@njit
def _optimize_pairs(
pos1: np.ndarray,
pos2: np.ndarray,
threshold: float,
) -> np.ndarray:
"""
Implements `match_points` using
https://en.wikipedia.org/wiki/Hungarian_algorithm.
Parameters
----------
pos1 : np.ndarray
2D array of NxK. Where N is number of positions and K is the number
of dimensions (e.g. 3 for x, y, z).
2D array of NxK.
pos2 : np.ndarray
2D array of MxK. Where M is number of positions and K is the number
of dimensions (e.g. 3 for x, y, z).
2D array of MxK.
The relationship N <= M must be true.
The relationship N <= M must be true and K must be the same for both.
threshold : float, optional. Defaults to np.inf.
The threshold to use to consider a pair a bad match. Any match pair
whose distance is greater or equal to the threshold will be considered
to be at great distance to each other.
to be at threshold distance (i.e. max distance).
It'll still show up in the matching, but it will have the least
priority for a match because that match will not reduce the overall
cost across all points.
Use `analyze_point_matches` subsequently to remove the "bad" matches.
Returns
-------
matches : np.ndarray
1D array of length N. Each index i in matches corresponds
to index i in `pos1`. The value of index i in matches is the index
j in pos2 that is the best match for that pos1.
I.e. the match is (pos1[i], pos2[matches[i]]).
1D array of length N, where index i in matches corresponds
to index i in `pos1` and its value is the index in pos2
of its best match.
"""
# based on https://en.wikipedia.org/wiki/Hungarian_algorithm
pos1 = pos1.astype(np.float64)
pos2 = pos2.astype(np.float64)
# numba pre-checks that arrays are at least 2-dims. Us checking would be
# too late and never invoked

if pos1.ndim != 2 or pos2.ndim != 2:
raise ValueError("The input arrays must have exactly 2 dimensions")

# we don't check for boundary conditions, just assert because it should be
# checked by caller (match_points)
n_rows = pos1.shape[0]
n_cols = pos2.shape[0]
if n_rows > n_cols:
raise ValueError(
"The length of pos1 must be less than or equal to length of pos2"
)
if pos1.shape[1] != pos2.shape[1]:
raise ValueError("The two inputs have different number of columns")
assert len(pos1.shape) == 2
assert len(pos2.shape) == 2
assert pos1.shape[1] == pos2.shape[1]
assert n_rows <= n_cols

pos1 = pos1.astype(np.float64)
pos2 = pos2.astype(np.float64)

have_threshold = threshold != np.inf

Expand Down Expand Up @@ -626,9 +749,10 @@ def match_points(
for col_i in range(n_cols):
if not col_used[col_i]:
# use sqrt to match threshold which is in actual distance
dist = np.sqrt(
np.sum(np.square(pos1[row_cur, :] - pos2[col_i, :]))
)
dist = 0.0
for i in range(pos1.shape[1]):
dist += math.pow(pos1[row_cur, i] - pos2[col_i, i], 2)
dist = math.sqrt(dist)
if dist == np.inf:
raise ValueError(
"The distance between point is too large"
Expand Down Expand Up @@ -670,14 +794,118 @@ def match_points(
__compare_progress()

# compute match from assignment
matches = np.empty(n_rows, dtype=np.int_)
matches = np.empty(n_rows, dtype=np.int64)
for i in range(n_cols):
if assignment_row[i] != -1:
matches[assignment_row[i]] = i

return matches


def match_points(
pos1: np.ndarray,
pos2: np.ndarray,
threshold: float = np.inf,
pre_match: bool = True,
) -> np.ndarray:
"""
Given two arrays, each a list of position. For each point in `pos1` it
finds a point in `pos2` such that the distance between the assigned
matches across all `pos1` is minimized.
E.g.::
>>> pos1 = np.array([[20, 10, 30, 40]]).T
>>> pos2 = np.array([[5, 15, 25, 35, 50]]).T
>>> matches = match_points(pos1, pos2)
>>> matches
array([1, 0, 2, 3])
Parameters
----------
pos1 : np.ndarray
2D array of NxK. Where N is number of positions and K is the number
of dimensions (e.g. 3 for x, y, z).
pos2 : np.ndarray
2D array of MxK. Where M is number of positions and K is the number
of dimensions (e.g. 3 for x, y, z).
The relationship N <= M must be true.
threshold : float, optional. Defaults to np.inf.
The threshold to use to consider a pair a bad match. Any match pair
whose distance is greater or equal to the threshold will be considered
to be at threshold distance (i.e. the max distance).
It'll still show up in the matching, but it will have the least
priority for a match because that match will not reduce the overall
cost across all points.
Use `analyze_point_matches` with the same threshold subsequently to
remove the "bad" matches.
pre_match : bool, optional. Defaults to True.
If True, we will (interenally) first efficiently find all the pairs of
`pos1` and `pos2` which are each at the same position in space. Then
we run the optimization to find the best matching on the remaining.
If True, it'll significantly speed up the matching, if there are pairs
of points on top of each other across the input lists.
Returns
-------
matches : np.ndarray
1D array of length N. Each index i in matches corresponds
to index i in `pos1`. The value of index i in matches is the index
j in pos2 that is the best match for that pos1.
I.e. the match is (pos1[i], pos2[matches[i]]).
"""
if len(pos1.shape) != 2 or len(pos2.shape) != 2:
raise ValueError("The input arrays must have exactly 2 dimensions")

n_rows = pos1.shape[0]
n_cols = pos2.shape[0]
if n_rows > n_cols:
raise ValueError(
"The length of pos1 must be less than or equal to length of pos2"
)
if pos1.shape[1] != pos2.shape[1]:
raise ValueError("The two inputs have different number of columns")

if not pre_match:
# do optimization on full inputs
return _optimize_pairs(pos1, pos2, threshold)

# extract the indices of zero-pairs and remaining points
unpaired1_indices, unpaired2_indices, paired_indices = (
_find_identical_points(pos1, pos2)
)
# the number of zero pairs found are done!
__compare_progress(len(paired_indices))

# if everything was zero pairs, we're done!
if not len(unpaired1_indices):
# sort by pos1 and then return the corresponding pos2 indices matching
# those pos1 points
pos1_sorted_indices = np.argsort(paired_indices[:, 0])
return paired_indices[pos1_sorted_indices, 1]

# extract remaining pos1/post2 and run optimization
pos1 = pos1[unpaired1_indices]
pos2 = pos2[unpaired2_indices]
n_rows = pos1.shape[0]

matches = _optimize_pairs(pos1, pos2, threshold)

# map extracted
full_matches = np.empty(n_rows + len(paired_indices), dtype=np.int64)
# set pos1 optimized matches to their pos2 indices
full_matches[unpaired1_indices] = unpaired2_indices[matches]
# set the zero pairs pos1 to corresponding pos2 match indices
full_matches[paired_indices[:, 0]] = paired_indices[:, 1]

return full_matches


@njit
def analyze_point_matches(
pos1: np.ndarray,
Expand Down
Loading

0 comments on commit a50f4aa

Please sign in to comment.