diff --git a/src/traccuracy/matchers/_compute_overlap.py b/src/traccuracy/matchers/_compute_overlap.py index ee904736..d85b860e 100644 --- a/src/traccuracy/matchers/_compute_overlap.py +++ b/src/traccuracy/matchers/_compute_overlap.py @@ -19,18 +19,24 @@ def _union_slice(a: Tuple[slice], b: Tuple[slice]): return tuple(slice(start, stop) for start, stop in zip(starts, stops)) -def get_labels_with_overlap(gt_frame, res_frame): +def get_labels_with_overlap( + gt_frame, + res_frame, + overlap="iou", +): """Get all labels IDs in gt_frame and res_frame whose bounding boxes overlap. Args: gt_frame (np.ndarray): ground truth segmentation for a single frame res_frame (np.ndarray): result segmentation for a given frame + overlap (str, optional): Choose between intersection-over-ground-truth (``iogt``) + or intersection-over-union (``iou``). Returns: overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes - intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area) + overlaps: List[float], list of IoGT/IoU values for each overlapping pair """ gt_frame = gt_frame.astype(np.uint16, copy=False) res_frame = res_frame.astype(np.uint16, copy=False) @@ -64,16 +70,23 @@ def get_labels_with_overlap(gt_frame, res_frame): overlapping_gt_labels = gt_box_labels[ind_gt] overlapping_res_labels = res_box_labels[ind_res] - intersections_over_gt = [] + overlaps = [] for i, j in zip(ind_gt, ind_res): sslice = _union_slice(gt_props[i].slice, res_props[j].slice) gt_mask = gt_frame[sslice] == gt_box_labels[i] res_mask = res_frame[sslice] == res_box_labels[j] area_inter = np.count_nonzero(np.logical_and(gt_mask, res_mask)) - area_gt = np.count_nonzero(gt_mask) - intersections_over_gt.append(area_inter / area_gt) - return overlapping_gt_labels, overlapping_res_labels, intersections_over_gt + if overlap == "iou": + area_union = np.count_nonzero(np.logical_or(gt_mask, res_mask)) + overlaps.append(area_inter / area_union) + elif overlap == "iogt": + area_gt = np.count_nonzero(gt_mask) + overlaps.append(area_inter / area_gt) + else: + raise ValueError(f"Unknown overlap type: {overlap}") + + return overlapping_gt_labels, overlapping_res_labels, overlaps def compute_overlap(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray: diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 084f50a1..ffb0d742 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -84,7 +84,7 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): overlapping_gt_labels, overlapping_pred_labels, intersection, - ) = get_labels_with_overlap(gt_frame, pred_frame) + ) = get_labels_with_overlap(gt_frame, pred_frame, overlap="iogt") for i in range(len(overlapping_gt_labels)): gt_label = overlapping_gt_labels[i] diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 973e8ef4..28bb0e4f 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -35,16 +35,15 @@ def _match_nodes(gt, res, threshold=0.5, one_to_one=False): # casting to int to avoid issue #152 (result is float with numpy<2, dtype=uint64) iou = np.zeros((int(np.max(gt) + 1), int(np.max(res) + 1))) - overlapping_gt_labels, overlapping_res_labels, _ = get_labels_with_overlap(gt, res) - - for index in range(len(overlapping_gt_labels)): - iou_gt_idx = overlapping_gt_labels[index] - iou_res_idx = overlapping_res_labels[index] - intersection = np.logical_and(gt == iou_gt_idx, res == iou_res_idx) - union = np.logical_or(gt == iou_gt_idx, res == iou_res_idx) - iou_value = intersection.sum() / union.sum() - if iou_value >= threshold: - iou[iou_gt_idx, iou_res_idx] = iou_value + overlapping_gt_labels, overlapping_res_labels, ious = get_labels_with_overlap( + gt, res, overlap="iou" + ) + + for gt_label, res_label, iou_val in zip( + overlapping_gt_labels, overlapping_res_labels, ious + ): + if iou_val >= threshold: + iou[gt_label, res_label] = iou_val if one_to_one: pairs = _one_to_one_assignment(iou) diff --git a/tests/matchers/test_compute_overlap.py b/tests/matchers/test_compute_overlap.py index 18ec557a..78d5a530 100644 --- a/tests/matchers/test_compute_overlap.py +++ b/tests/matchers/test_compute_overlap.py @@ -1,3 +1,4 @@ +import pytest from traccuracy.matchers._compute_overlap import ( get_labels_with_overlap, ) @@ -5,19 +6,8 @@ from tests.test_utils import get_annotated_image -def test_get_labels_with_overlap(): - """Get all labels IDs in gt_frame and res_frame whose bounding boxes - overlap. - - Args: - gt_frame (np.ndarray): ground truth segmentation for a single frame - res_frame (np.ndarray): result segmentation for a given frame - - Returns: - overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes - overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes - intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area) - """ +@pytest.mark.parametrize("overlap", ["iou", "iogt"]) +def test_get_labels_with_overlap(overlap): n_labels = 3 image1 = get_annotated_image( img_size=256, num_labels=n_labels, sequential=True, seed=1 @@ -29,13 +19,19 @@ def test_get_labels_with_overlap(): img_size=256, num_labels=0, sequential=True, seed=1 ) - perfect_gt, perfect_res, perfect_ious = get_labels_with_overlap(image1, image1) + perfect_gt, perfect_res, perfect_ious = get_labels_with_overlap( + image1, image1, overlap + ) assert list(perfect_gt) == list(range(1, n_labels + 1)) assert list(perfect_res) == list(range(1, n_labels + 1)) assert list(perfect_ious) == [1.0] * n_labels - get_labels_with_overlap(image1, image2) + + get_labels_with_overlap(image1, image2, overlap) + # Test empty labels array - empty_gt, empty_res, empty_ious = get_labels_with_overlap(image1, empty_image) + empty_gt, empty_res, empty_ious = get_labels_with_overlap( + image1, empty_image, overlap + ) assert empty_gt == [] assert empty_res == [] assert empty_ious == []