Skip to content

Commit

Permalink
retrurn idx in nms; Add scores (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
maweigert authored Oct 15, 2024
1 parent 87b3f4b commit 2c66eea
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion spotiflow/lib/point_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ static PyObject *c_point_nms_2d(PyObject *self, PyObject *args)
const long k = ret_matches[j].first;
const float dist = ret_matches[j].second;
if ((k != i) && (dist < min_distance_squared)) {
// std::cout << "suppressed: " << k << " "<< *(float *)PyArray_GETPTR2(points, k, 0) << " (y,x) = " << *(float *)PyArray_GETPTR2(points, k, 0) << " distance " << dist << std::endl;
// std::cout << "suppressed: " << k << " "<< *(float *)PyArray_GETPTR2(points, k, 0) << " " << *(float *)PyArray_GETPTR2(points, k, 1) << " distance " << dist << std::endl;
suppressed[k] = true;
}
}
Expand Down
23 changes: 16 additions & 7 deletions spotiflow/utils/peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ def nms_points_2d(
if not scores.ndim == 1:
raise ValueError("scores must be a array of shape (N,)")

idx = np.argsort(scores, kind="stable")
idx = np.argsort(scores, kind="stable")[::-1]
points = points[idx]
scores = scores[idx]

points = np.ascontiguousarray(points, dtype=np.float32)
inds = c_point_nms_2d(points, np.float32(min_distance))
return points[inds].copy()
inds = idx[inds]
return inds

def nms_points_3d(
points: np.ndarray, scores: np.ndarray = None, min_distance: int = 2
Expand Down Expand Up @@ -100,13 +101,14 @@ def nms_points_3d(
if not scores.ndim == 1:
raise ValueError("scores must be a array of shape (N,)")

idx = np.argsort(scores, kind="stable")
idx = np.argsort(scores, kind="stable")[::-1]
points = points[idx]
scores = scores[idx]

points = np.ascontiguousarray(points, dtype=np.float32)
inds = c_point_nms_3d(points, np.float32(min_distance))
return points[inds].copy()
inds = idx[inds]
return inds

def maximum_filter_2d(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:

Expand Down Expand Up @@ -398,6 +400,7 @@ def local_peaks(
exclude_border=True,
threshold_abs=None,
threshold_rel=None,
use_score:bool=False
):
if not image.ndim in [2, 3]:
raise ValueError("Image must be 2D")
Expand Down Expand Up @@ -427,8 +430,14 @@ def local_peaks(
coord = np.nonzero(mask)
coord = np.stack(coord, axis=1)

points = nms_fun(coord, min_distance=min_distance)
return points
if use_score:
scores = image[mask] if mask.sum() > 0 else None
else:
scores = None

idx = nms_fun(coord, scores=scores, min_distance=min_distance)
coord = coord[idx].copy()
return coord



Expand Down

0 comments on commit 2c66eea

Please sign in to comment.