Skip to content

Commit

Permalink
Remove scipy as an option and fill in scipy as a test.
Browse files Browse the repository at this point in the history
  • Loading branch information
matham committed Jun 3, 2024
1 parent 4e73c40 commit baea2ad
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 155 deletions.
90 changes: 5 additions & 85 deletions brainglobe_utils/cells/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import numpy as np
from numba import njit, objmode
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm


Expand Down Expand Up @@ -439,7 +438,6 @@ def match_cells(
other: List[Cell],
threshold: float = np.inf,
pre_match: bool = True,
use_scipy: bool = False,
) -> Tuple[List[int], List[Tuple[int, int]], List[int]]:
"""
Given two lists of cells. It finds a pairing of cells from `cells` and
Expand Down Expand Up @@ -482,11 +480,6 @@ def match_cells(
This will significantly speed up the matching, if there are pairs of
cells on top of each other in each set.
use_scipy : bool, optional. Defaults to False.
Whether to use scipy `linear_sum_assignment` to find the optimal
matching. Otherwise, we use our own implementation.
See `match_points` for consideration details.
Returns
-------
Expand All @@ -513,15 +506,12 @@ def match_cells(
if flip:
c1, c2 = c2, c1

progress = None
if not use_scipy:
# with scipy we don't have callbacks so no updates
progress = tqdm(desc="Matching cells", total=len(c1), unit="cells")
__progress_update.updater = progress.update
progress = tqdm(desc="Matching cells", total=len(c1), unit="cells")
__progress_update.updater = progress.update

# for each index corresponding to c1, returns the index in c2 that matches
try:
assignment = match_points(c1, c2, threshold, pre_match, use_scipy)
assignment = match_points(c1, c2, threshold, pre_match)
finally:
__progress_update.updater = None

Expand Down Expand Up @@ -816,62 +806,11 @@ def _optimize_pairs(
return matches


def _optimize_pairs_scipy(
pos1: np.ndarray,
pos2: np.ndarray,
) -> np.ndarray:
"""
Implements `match_points` using scipy's `linear_sum_assignment`.
The function has a memory cost of N*M*k*8 bytes. When `K=3` and `N=M`,
approximately 1GB is required for `N=M=6689` points. For `N=M=26.8k`
points, we need 16GB. For `N=M=75.7k` points, we need 128GB.
Parameters
----------
pos1 : np.ndarray
2D array of NxK.
pos2 : np.ndarray
2D array of MxK.
The relationship N <= M must be true and K must be the same for both.
Returns
-------
matches : np.ndarray
1D array of length N, where index i in matches corresponds
to index i in `pos1` and its value is the index in pos2
of its best match.
"""
# we don't check for boundary conditions, just assert because it should be
# checked by caller (match_points)
n_rows = pos1.shape[0]
n_cols = pos2.shape[0]
assert len(pos1.shape) == 2
assert len(pos2.shape) == 2
assert pos1.shape[1] == pos2.shape[1]
assert n_rows <= n_cols

# Mxk -> M1K
pos1 = pos1[:, np.newaxis, :]
# Nxk -> 1NK
pos2 = pos2[np.newaxis, :, :]
# dist is MNK
dist = pos1 - pos2
# cost is MN
cost = np.sqrt(np.sum(np.square(dist), axis=2))
# result is sorted by rows
rows, cols = linear_sum_assignment(cost)
# M <= N, so cols and rows is size M
return cols


def match_points(
pos1: np.ndarray,
pos2: np.ndarray,
threshold: float = np.inf,
pre_match: bool = True,
use_scipy: bool = False,
) -> np.ndarray:
"""
Given two arrays, each a list of position. For each point in `pos1` it
Expand Down Expand Up @@ -914,19 +853,6 @@ def match_points(
If True, it'll significantly speed up the matching, if there are pairs
of points on top of each other across the input lists.
use_scipy : bool, optional. Defaults to False.
Whether to use scipy `linear_sum_assignment` to find the optimal
matching. Otherwise, we use our own implementation.
Our implementation is very memory efficient. Using scipy, we have e.g.
a memory requirement of approximately 1GB for 6.7k points (`N=M=6.7k`),
16GB for 26.8k points, and 128GB for 75.7k points.
If `pre_match` is used, and it eliminates many zero-distance points,
using numpy is feasable. Otherwise, use our implementation.
Note: When using scipy, we *don't* take the threshold into account
until after the matching is complete.
Returns
-------
Expand All @@ -951,10 +877,7 @@ def match_points(

if not pre_match:
# do optimization on full inputs
if use_scipy:
return _optimize_pairs_scipy(pos1, pos2)
else:
return _optimize_pairs(pos1, pos2, threshold)
return _optimize_pairs(pos1, pos2, threshold)

# extract the indices of zero-pairs and remaining points
unpaired1_indices, unpaired2_indices, paired_indices = (
Expand All @@ -975,10 +898,7 @@ def match_points(
pos2 = pos2[unpaired2_indices]
n_rows = pos1.shape[0]

if use_scipy:
matches = _optimize_pairs_scipy(pos1, pos2)
else:
matches = _optimize_pairs(pos1, pos2, threshold)
matches = _optimize_pairs(pos1, pos2, threshold)

# map extracted
full_matches = np.empty(n_rows + len(paired_indices), dtype=np.int64)
Expand Down
Loading

0 comments on commit baea2ad

Please sign in to comment.