Skip to content

Commit

Permalink
Merge branch 'main' into release-notes
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 authored Sep 10, 2024
2 parents d0731d6 + 4160403 commit 8c0781f
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 33 deletions.
25 changes: 19 additions & 6 deletions src/traccuracy/matchers/_compute_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,24 @@ def _union_slice(a: Tuple[slice], b: Tuple[slice]):
return tuple(slice(start, stop) for start, stop in zip(starts, stops))


def get_labels_with_overlap(gt_frame, res_frame):
def get_labels_with_overlap(
gt_frame,
res_frame,
overlap="iou",
):
"""Get all labels IDs in gt_frame and res_frame whose bounding boxes
overlap.
Args:
gt_frame (np.ndarray): ground truth segmentation for a single frame
res_frame (np.ndarray): result segmentation for a given frame
overlap (str, optional): Choose between intersection-over-ground-truth (``iogt``)
or intersection-over-union (``iou``).
Returns:
overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes
overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes
intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area)
overlaps: List[float], list of IoGT/IoU values for each overlapping pair
"""
gt_frame = gt_frame.astype(np.uint16, copy=False)
res_frame = res_frame.astype(np.uint16, copy=False)
Expand Down Expand Up @@ -64,16 +70,23 @@ def get_labels_with_overlap(gt_frame, res_frame):
overlapping_gt_labels = gt_box_labels[ind_gt]
overlapping_res_labels = res_box_labels[ind_res]

intersections_over_gt = []
overlaps = []
for i, j in zip(ind_gt, ind_res):
sslice = _union_slice(gt_props[i].slice, res_props[j].slice)
gt_mask = gt_frame[sslice] == gt_box_labels[i]
res_mask = res_frame[sslice] == res_box_labels[j]
area_inter = np.count_nonzero(np.logical_and(gt_mask, res_mask))
area_gt = np.count_nonzero(gt_mask)
intersections_over_gt.append(area_inter / area_gt)

return overlapping_gt_labels, overlapping_res_labels, intersections_over_gt
if overlap == "iou":
area_union = np.count_nonzero(np.logical_or(gt_mask, res_mask))
overlaps.append(area_inter / area_union)
elif overlap == "iogt":
area_gt = np.count_nonzero(gt_mask)
overlaps.append(area_inter / area_gt)
else:
raise ValueError(f"Unknown overlap type: {overlap}")

return overlapping_gt_labels, overlapping_res_labels, overlaps


def compute_overlap(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion src/traccuracy/matchers/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
overlapping_gt_labels,
overlapping_pred_labels,
intersection,
) = get_labels_with_overlap(gt_frame, pred_frame)
) = get_labels_with_overlap(gt_frame, pred_frame, overlap="iogt")

for i in range(len(overlapping_gt_labels)):
gt_label = overlapping_gt_labels[i]
Expand Down
19 changes: 9 additions & 10 deletions src/traccuracy/matchers/_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,15 @@ def _match_nodes(gt, res, threshold=0.5, one_to_one=False):
# casting to int to avoid issue #152 (result is float with numpy<2, dtype=uint64)
iou = np.zeros((int(np.max(gt) + 1), int(np.max(res) + 1)))

overlapping_gt_labels, overlapping_res_labels, _ = get_labels_with_overlap(gt, res)

for index in range(len(overlapping_gt_labels)):
iou_gt_idx = overlapping_gt_labels[index]
iou_res_idx = overlapping_res_labels[index]
intersection = np.logical_and(gt == iou_gt_idx, res == iou_res_idx)
union = np.logical_or(gt == iou_gt_idx, res == iou_res_idx)
iou_value = intersection.sum() / union.sum()
if iou_value >= threshold:
iou[iou_gt_idx, iou_res_idx] = iou_value
overlapping_gt_labels, overlapping_res_labels, ious = get_labels_with_overlap(
gt, res, overlap="iou"
)

for gt_label, res_label, iou_val in zip(
overlapping_gt_labels, overlapping_res_labels, ious
):
if iou_val >= threshold:
iou[gt_label, res_label] = iou_val

if one_to_one:
pairs = _one_to_one_assignment(iou)
Expand Down
28 changes: 12 additions & 16 deletions tests/matchers/test_compute_overlap.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
import pytest
from traccuracy.matchers._compute_overlap import (
get_labels_with_overlap,
)

from tests.test_utils import get_annotated_image


def test_get_labels_with_overlap():
"""Get all labels IDs in gt_frame and res_frame whose bounding boxes
overlap.
Args:
gt_frame (np.ndarray): ground truth segmentation for a single frame
res_frame (np.ndarray): result segmentation for a given frame
Returns:
overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes
overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes
intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area)
"""
@pytest.mark.parametrize("overlap", ["iou", "iogt"])
def test_get_labels_with_overlap(overlap):
n_labels = 3
image1 = get_annotated_image(
img_size=256, num_labels=n_labels, sequential=True, seed=1
Expand All @@ -29,13 +19,19 @@ def test_get_labels_with_overlap():
img_size=256, num_labels=0, sequential=True, seed=1
)

perfect_gt, perfect_res, perfect_ious = get_labels_with_overlap(image1, image1)
perfect_gt, perfect_res, perfect_ious = get_labels_with_overlap(
image1, image1, overlap
)
assert list(perfect_gt) == list(range(1, n_labels + 1))
assert list(perfect_res) == list(range(1, n_labels + 1))
assert list(perfect_ious) == [1.0] * n_labels
get_labels_with_overlap(image1, image2)

get_labels_with_overlap(image1, image2, overlap)

# Test empty labels array
empty_gt, empty_res, empty_ious = get_labels_with_overlap(image1, empty_image)
empty_gt, empty_res, empty_ious = get_labels_with_overlap(
image1, empty_image, overlap
)
assert empty_gt == []
assert empty_res == []
assert empty_ious == []

1 comment on commit 8c0781f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE 4160403 Mean (s) HEAD 8c0781f Percent Change
test_load_gt_data 1.24147 1.22478 -1.34
test_load_pred_data 1.15828 1.14635 -1.03
test_ctc_checks 0.39794 0.39075 -1.81
test_ctc_matched 1.70852 1.71165 0.18
test_ctc_metrics 0.52704 0.51124 -3
test_ctc_div_metrics 0.27161 0.26364 -2.93
test_iou_matched 1.77892 1.76681 -0.68
test_iou_div_metrics 0.26688 0.26844 0.59

Please sign in to comment.