Skip to content

Commit

Permalink
Merge branch 'main' into check_empty_subgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor authored Aug 5, 2024
2 parents f493cff + 280308d commit a140746
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/traccuracy/matchers/_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def _match_nodes(gt, res, threshold=0.5, one_to_one=False):
gtcells (np arr): Array of overlapping ids in the gt frame.
rescells (np arr): Array of overlapping ids in the res frame.
"""
iou = np.zeros((np.max(gt) + 1, np.max(res) + 1))
if threshold == 0.0 and not one_to_one:
raise ValueError("Threshold of 0 is not valid unless one_to_one is True")
# casting to int to avoid issue #152 (result is float with numpy<2, dtype=uint64)
iou = np.zeros((int(np.max(gt) + 1), int(np.max(res) + 1)))

overlapping_gt_labels, overlapping_res_labels, _ = get_labels_with_overlap(gt, res)

Expand All @@ -39,12 +42,14 @@ def _match_nodes(gt, res, threshold=0.5, one_to_one=False):
iou_res_idx = overlapping_res_labels[index]
intersection = np.logical_and(gt == iou_gt_idx, res == iou_res_idx)
union = np.logical_or(gt == iou_gt_idx, res == iou_res_idx)
iou[iou_gt_idx, iou_res_idx] = intersection.sum() / union.sum()
iou_value = intersection.sum() / union.sum()
if iou_value >= threshold:
iou[iou_gt_idx, iou_res_idx] = iou_value

if one_to_one:
pairs = _one_to_one_assignment(iou)
else:
pairs = np.where(iou >= threshold)
pairs = np.where(iou)

# Catch the case where there are no overlaps
if len(pairs) < 2:
Expand Down Expand Up @@ -76,6 +81,8 @@ def _one_to_one_assignment(iou, unmapped_cost=4):

# Assign 1 - iou to top left and bottom right
cost = 1 - iou[1:, 1:]
# increase the cost for those with no IOU to higher than the unmapped cost
cost[cost == 1] = unmapped_cost + 1
matrix[:n0, :n1] = cost
matrix[n_obj - n1 :, n_obj - n0 :] = cost.T

Expand Down
58 changes: 58 additions & 0 deletions tests/matchers/test_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ def get_two_to_one(w, h, imw, imh):
return merge.astype("int"), split.astype("int")


def get_no_overlap(imw, imh):
"""two non-overlapping segmentations that start at high number labels
Mapping []
"""

im1 = np.zeros((imw, imh))
im1[0:2, 0:2] = 100

im2 = np.zeros((imw, imh))
im2[4:6, 4:6] = 222

return im1.astype("int"), im2.astype("int")


def test__match_nodes():
# creat dummy image to test against
num_labels = 5
Expand Down Expand Up @@ -62,6 +78,48 @@ def test__match_nodes():
# Check that only one of the merge matches is present
assert ((2, 4) in matches) != ((2, 5) in matches)

with pytest.raises(ValueError):
# Test that threshold 0 is not valid when not one-to-one
gtcells, rescells = _match_nodes(im1, im2, threshold=0.0)


def test__match_nodes_threshold():
im1, im2 = get_two_to_one(10, 10, 30, 30)
# Test high threshold
gtcells, rescells = _match_nodes(im1, im2, threshold=1)
# Create match tuples
matches = list(zip(gtcells, rescells))
# Check that nothing is matched
assert len(matches) == 1

# Test for high threshold and one to one
gtcells, rescells = _match_nodes(im1, im2, threshold=0.7, one_to_one=True)
# Create match tuples
matches = list(zip(gtcells, rescells))
# Check that nothing is matched
assert len(matches) == 1


def test__match_nodes_non_sequential():
# test when the segmentation ids are high numbers (the lower numbers should never appear)

im1, im2 = get_no_overlap(30, 30)

# Test that phantom segmentations are not matched
gtcells, rescells = _match_nodes(im1, im2, threshold=0.1)
# Create match tuples
matches = list(zip(gtcells, rescells))
# Check that nothing is matched
assert len(matches) == 0

# Test that with one-to-one, phantom segmentations are not matched,
# even with threshold 0
gtcells, rescells = _match_nodes(im1, im2, threshold=0.0, one_to_one=True)
# Create match tuples
matches = list(zip(gtcells, rescells))
# Check that nothing is matched
assert len(matches) == 0


def test__construct_time_to_seg_id_map():
# Test 2d data
Expand Down

1 comment on commit a140746

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE 280308d Mean (s) HEAD a140746 Percent Change
test_load_gt_data 1.26417 1.21836 -3.62
test_load_pred_data 1.15023 1.1144 -3.12
test_ctc_checks 0.43034 0.43516 1.12
test_ctc_matched 1.74139 1.77198 1.76
test_ctc_metrics 0.54313 0.47859 -11.88
test_ctc_div_metrics 0.28762 0.29413 2.26
test_iou_matched 8.99149 8.52445 -5.19
test_iou_div_metrics 0.27955 0.27606 -1.25

Please sign in to comment.