Skip to content

Commit

Permalink
Add option to use scipy optimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
matham committed May 23, 2024
1 parent a50f4aa commit e1ca972
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 47 deletions.
84 changes: 81 additions & 3 deletions brainglobe_utils/cells/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import numpy as np
from numba import njit, objmode
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm


Expand Down Expand Up @@ -438,6 +439,7 @@ def match_cells(
other: List[Cell],
threshold: float = np.inf,
pre_match: bool = True,
use_scipy: bool = False,
) -> Tuple[List[int], List[Tuple[int, int]], List[int]]:
"""
Given two lists of cells. It finds a pairing of cells from `cells` and
Expand Down Expand Up @@ -477,8 +479,14 @@ def match_cells(
If True, we will (interenally) first efficiently find all the pairs of
`cells` and `others` which are each at the same position in space. Then
we run the optimization to find the best matching on the remaining.
This will significantly speed up the matching, if there are pairs of
cells on top of each other in each set.
use_scipy : bool, optional. Defaults to False.
Whether to use scipy `linear_sum_assignment` to find the optimal
matching. Otherwise, we use our own implementation.
See `match_points` for consideration details.
Returns
-------
Expand Down Expand Up @@ -509,7 +517,7 @@ def match_cells(
__progress_update.updater = progress.update
# for each index corresponding to c1, returns the index in c2 that matches
try:
assignment = match_points(c1, c2, threshold, pre_match)
assignment = match_points(c1, c2, threshold, pre_match, use_scipy)
progress.close()
finally:
__progress_update.updater = None
Expand Down Expand Up @@ -802,11 +810,62 @@ def _optimize_pairs(
return matches


def _optimize_pairs_scipy(
pos1: np.ndarray,
pos2: np.ndarray,
) -> np.ndarray:
"""
Implements `match_points` using scipy's `linear_sum_assignment`.
The function has a memory cost of N*M*k*8 bytes. When `K=3` and `N=M`,
approximately 1GB is required for `N=M=6689` points. For `N=M=26.8k`
points, we need 16GB. For `N=M=75.7k` points, we need 128GB.
Parameters
----------
pos1 : np.ndarray
2D array of NxK.
pos2 : np.ndarray
2D array of MxK.
The relationship N <= M must be true and K must be the same for both.
Returns
-------
matches : np.ndarray
1D array of length N, where index i in matches corresponds
to index i in `pos1` and its value is the index in pos2
of its best match.
"""
# we don't check for boundary conditions, just assert because it should be
# checked by caller (match_points)
n_rows = pos1.shape[0]
n_cols = pos2.shape[0]
assert len(pos1.shape) == 2
assert len(pos2.shape) == 2
assert pos1.shape[1] == pos2.shape[1]
assert n_rows <= n_cols

# Mxk -> M1K
pos1 = pos1[:, np.newaxis, :]
# Nxk -> 1NK
pos2 = pos2[np.newaxis, :, :]
# dist is MNK
dist = pos1 - pos2
# cost is MN
cost = np.sqrt(np.sum(np.square(dist), axis=2))
# result is sorted by rows
rows, cols = linear_sum_assignment(cost)
# M <= N, so cols and rows is size M
return cols


def match_points(
pos1: np.ndarray,
pos2: np.ndarray,
threshold: float = np.inf,
pre_match: bool = True,
use_scipy: bool = False,
) -> np.ndarray:
"""
Given two arrays, each a list of position. For each point in `pos1` it
Expand Down Expand Up @@ -849,6 +908,19 @@ def match_points(
If True, it'll significantly speed up the matching, if there are pairs
of points on top of each other across the input lists.
use_scipy : bool, optional. Defaults to False.
Whether to use scipy `linear_sum_assignment` to find the optimal
matching. Otherwise, we use our own implementation.
Our implementation is very memory efficient. Using scipy, we have e.g.
a memory requirement of approximately 1GB for 6.7k points (`N=M=6.7k`),
16GB for 26.8k points, and 128GB for 75.7k points.
If `pre_match` is used, and it eliminates many zero-distance points,
using numpy is feasable. Otherwise, use our implementation.
Note: When using scipy, we *don't* take the threshold into account
until after the matching is complete.
Returns
-------
Expand All @@ -873,7 +945,10 @@ def match_points(

if not pre_match:
# do optimization on full inputs
return _optimize_pairs(pos1, pos2, threshold)
if use_scipy:
return _optimize_pairs_scipy(pos1, pos2)
else:
return _optimize_pairs(pos1, pos2, threshold)

# extract the indices of zero-pairs and remaining points
unpaired1_indices, unpaired2_indices, paired_indices = (
Expand All @@ -894,7 +969,10 @@ def match_points(
pos2 = pos2[unpaired2_indices]
n_rows = pos1.shape[0]

matches = _optimize_pairs(pos1, pos2, threshold)
if use_scipy:
matches = _optimize_pairs_scipy(pos1, pos2)
else:
matches = _optimize_pairs(pos1, pos2, threshold)

# map extracted
full_matches = np.empty(n_rows + len(paired_indices), dtype=np.int64)
Expand Down
Loading

0 comments on commit e1ca972

Please sign in to comment.