From 2c66eeaa87976ccc179f09a0d332edabc3579c07 Mon Sep 17 00:00:00 2001 From: Martin Weigert Date: Tue, 15 Oct 2024 16:41:02 +0200 Subject: [PATCH] retrurn idx in nms; Add scores (#19) --- spotiflow/lib/point_nms.cpp | 2 +- spotiflow/utils/peaks.py | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/spotiflow/lib/point_nms.cpp b/spotiflow/lib/point_nms.cpp index d86d6fb..d49ab4a 100644 --- a/spotiflow/lib/point_nms.cpp +++ b/spotiflow/lib/point_nms.cpp @@ -145,7 +145,7 @@ static PyObject *c_point_nms_2d(PyObject *self, PyObject *args) const long k = ret_matches[j].first; const float dist = ret_matches[j].second; if ((k != i) && (dist < min_distance_squared)) { - // std::cout << "suppressed: " << k << " "<< *(float *)PyArray_GETPTR2(points, k, 0) << " (y,x) = " << *(float *)PyArray_GETPTR2(points, k, 0) << " distance " << dist << std::endl; + // std::cout << "suppressed: " << k << " "<< *(float *)PyArray_GETPTR2(points, k, 0) << " " << *(float *)PyArray_GETPTR2(points, k, 1) << " distance " << dist << std::endl; suppressed[k] = true; } } diff --git a/spotiflow/utils/peaks.py b/spotiflow/utils/peaks.py index 68d3b0b..c4017d9 100644 --- a/spotiflow/utils/peaks.py +++ b/spotiflow/utils/peaks.py @@ -60,13 +60,14 @@ def nms_points_2d( if not scores.ndim == 1: raise ValueError("scores must be a array of shape (N,)") - idx = np.argsort(scores, kind="stable") + idx = np.argsort(scores, kind="stable")[::-1] points = points[idx] scores = scores[idx] - + points = np.ascontiguousarray(points, dtype=np.float32) inds = c_point_nms_2d(points, np.float32(min_distance)) - return points[inds].copy() + inds = idx[inds] + return inds def nms_points_3d( points: np.ndarray, scores: np.ndarray = None, min_distance: int = 2 @@ -100,13 +101,14 @@ def nms_points_3d( if not scores.ndim == 1: raise ValueError("scores must be a array of shape (N,)") - idx = np.argsort(scores, kind="stable") + idx = np.argsort(scores, kind="stable")[::-1] points = points[idx] scores = scores[idx] points = np.ascontiguousarray(points, dtype=np.float32) inds = c_point_nms_3d(points, np.float32(min_distance)) - return points[inds].copy() + inds = idx[inds] + return inds def maximum_filter_2d(image: np.ndarray, kernel_size: int = 3) -> np.ndarray: @@ -398,6 +400,7 @@ def local_peaks( exclude_border=True, threshold_abs=None, threshold_rel=None, + use_score:bool=False ): if not image.ndim in [2, 3]: raise ValueError("Image must be 2D") @@ -427,8 +430,14 @@ def local_peaks( coord = np.nonzero(mask) coord = np.stack(coord, axis=1) - points = nms_fun(coord, min_distance=min_distance) - return points + if use_score: + scores = image[mask] if mask.sum() > 0 else None + else: + scores = None + + idx = nms_fun(coord, scores=scores, min_distance=min_distance) + coord = coord[idx].copy() + return coord