diff --git a/tweakwcs/matchutils.py b/tweakwcs/matchutils.py index 19bc7d0..ad4aa04 100644 --- a/tweakwcs/matchutils.py +++ b/tweakwcs/matchutils.py @@ -16,6 +16,7 @@ from astropy.utils.exceptions import AstropyDeprecationWarning from stsci.stimage import xyxymatch +from scipy import spatial from . import __version__ # noqa: F401 @@ -279,8 +280,20 @@ def __call__(self, refcat, imcat, tp_pscale=1.0, tp_units=None, **kwargs): def _xy_2dhist(imgxy, refxy, r): # This code replaces the C version (arrxyzero) from carrutils.c # It is about 5-8 times slower than the C version. - dx = np.subtract.outer(imgxy[:, 0], refxy[:, 0]).ravel() - dy = np.subtract.outer(imgxy[:, 1], refxy[:, 1]).ravel() + + # trim to only pairs within (r+0.5) * np.sqrt(2) using a kdtree + # to avoid computing differences for many widely separated pairs. + kdtree = spatial.KDTree(refxy) + neighbors = kdtree.query_ball_point(imgxy, (r + 0.5) * np.sqrt(2)) + lens = [len(n) for n in neighbors] + mi = np.repeat(np.arange(imgxy.shape[0]), lens) + if len(mi) > 0: + mr = np.concatenate([n for n in neighbors if len(n) > 0]) + else: + mr = mi.copy() + + dx = imgxy[mi, 0] - refxy[mr, 0] + dy = imgxy[mi, 1] - refxy[mr, 1] idx = np.where((dx < r + 0.5) & (dx >= -r - 0.5) & (dy < r + 0.5) & (dy >= -r - 0.5)) r = int(np.ceil(r))