Skip to content

Commit

Permalink
Merge branch 'iou-tests' into div-test-cases
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 authored Dec 17, 2024
2 parents 2cf71c8 + dd1d438 commit 11a3e70
Show file tree
Hide file tree
Showing 6 changed files with 372 additions and 217 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ jobs:

- name: Coverage
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}


benchmark:
name: Benchmark
Expand Down
106 changes: 69 additions & 37 deletions examples/test-cases.ipynb

Large diffs are not rendered by default.

56 changes: 18 additions & 38 deletions src/traccuracy/matchers/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import TYPE_CHECKING

import numpy as np
from tqdm import tqdm

if TYPE_CHECKING:
Expand Down Expand Up @@ -80,46 +79,27 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
if pred_label_key in G_pred.graph.nodes[node]
}

(
overlapping_gt_labels,
overlapping_pred_labels,
intersection,
) = get_labels_with_overlap(gt_frame, pred_frame, overlap="iogt")

for i in range(len(overlapping_gt_labels)):
gt_label = overlapping_gt_labels[i]
pred_label = overlapping_pred_labels[i]
# CTC metrics only match comp IDs to a single GT ID if there is majority overlap
if intersection[i] > 0.5:
mapping.append(
(gt_label_to_id[gt_label], pred_label_to_id[pred_label])
)
frame_map = match_frame_majority(gt_frame, pred_frame)
# Switch from segmentation ids to node ids
for gt_label, pred_label in frame_map:
mapping.append((gt_label_to_id[gt_label], pred_label_to_id[pred_label]))

return mapping


def detection_test(gt_blob: np.ndarray, comp_blob: np.ndarray) -> int:
"""Check if computed marker overlaps majority of the reference marker.
def match_frame_majority(gt_frame, pred_frame):
mapping = []
(
overlapping_gt_labels,
overlapping_pred_labels,
intersection,
) = get_labels_with_overlap(gt_frame, pred_frame, overlap="iogt")

Given a reference marker and computer marker in original coordinates,
return True if the computed marker overlaps strictly more than half
of the reference marker's pixels, otherwise False.
for gt_label, pred_label, iogt in zip(
overlapping_gt_labels, overlapping_pred_labels, intersection
):
# CTC metrics only match comp IDs to a single GT ID if there is majority overlap
if iogt > 0.5:
mapping.append((gt_label, pred_label))

Parameters
----------
gt_blob : np.ndarray
2D or 3D boolean mask representing the pixels of the ground truth
marker
comp_blob : np.ndarray
2D or 3D boolean mask representing the pixels of the computed
marker
Returns
-------
bool
True if computed marker majority overlaps reference marker, else False.
"""
n_gt_pixels = np.sum(gt_blob)
intersection = np.logical_and(gt_blob, comp_blob)
comp_blob_matches_gt_blob = int(np.sum(intersection) > 0.5 * n_gt_pixels)
return comp_blob_matches_gt_blob
return mapping
64 changes: 63 additions & 1 deletion tests/examples/segs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def make_split_cell_3d(
mask = sphere(center, radius, shape=arr_shape)
im[mask] = labels[0]
# get indices where y value greater than center
mask[:, 0 : center[1]] = 0
mask[:, 0 : center[1] + 1] = 0
im[mask] = labels[1]
return im

Expand Down Expand Up @@ -195,6 +195,37 @@ def undersegmentation_2d() -> tuple[np.ndarray, np.ndarray]:
return gt, pred


def no_overlap_2d() -> tuple[np.ndarray, np.ndarray]:
"""Two cells with no overlap in 2d."""
gt = make_one_cell_2d(label=1, center=(5, 5), radius=7)
pred = make_one_cell_2d(label=2, center=(17, 17), radius=7)
return gt, pred


def multicell_2d() -> tuple[np.ndarray, np.ndarray]:
"""Two cells in each image, one that overlaps and one that doesn't"""
arr_shape = (32, 32)
radius = 5

gt = np.zeros(arr_shape, dtype="int32")
pred = np.zeros(arr_shape, dtype="int32")

# Overlap cell
rr, cc = disk((5, 5), radius, shape=arr_shape)
gt[rr, cc] = 1
pred[rr, cc] = 3

# Unique gt
rr, cc = disk((17, 17), radius, shape=arr_shape)
gt[rr, cc] = 2

# Unique pred
rr, cc = disk((25, 7), radius, shape=arr_shape)
pred[rr, cc] = 4

return gt, pred


### CANONICAL 3D SEGMENTATION EXAMPLES ###
def good_segmentation_3d() -> tuple[np.ndarray, np.ndarray]:
"""A pretty good (but not perfect) pair of segmentations in 3d.
Expand Down Expand Up @@ -263,6 +294,37 @@ def undersegmentation_3d() -> tuple[np.ndarray, np.ndarray]:
return gt, pred


def no_overlap_3d() -> tuple[np.ndarray, np.ndarray]:
"""3D segmentations with no overlap"""
gt = make_one_cell_3d(label=1, center=(5, 5, 5), radius=5)
pred = make_one_cell_3d(label=2, center=(17, 17, 17), radius=6)
return gt, pred


def multicell_3d() -> tuple[np.ndarray, np.ndarray]:
"""Two cells in each image, one that overlaps and one that doesn't"""
arr_shape = (32, 32, 32)
radius = 5

gt = np.zeros(arr_shape, dtype="int32")
pred = np.zeros(arr_shape, dtype="int32")

# Overlap cell
mask = sphere((5, 5, 5), radius, shape=arr_shape)
gt[mask] = 1
pred[mask] = 3

# Unique gt
mask = sphere((17, 17, 17), radius, shape=arr_shape)
gt[mask] = 2

# Unique pred
mask = sphere((25, 7, 7), radius, shape=arr_shape)
pred[mask] = 4

return gt, pred


def nodes_from_segmentation(
seg: np.ndarray,
frame: int,
Expand Down
85 changes: 83 additions & 2 deletions tests/matchers/test_ctc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from collections import Counter

import networkx as nx
import numpy as np
import pytest

import tests.examples.segs as ex_segs
from tests.test_utils import get_annotated_movie
from traccuracy._tracking_graph import TrackingGraph
from traccuracy.matchers._ctc import CTCMatcher
from traccuracy.matchers._ctc import CTCMatcher, match_frame_majority


def test_match_ctc():
def test_CTCMatcher():
matcher = CTCMatcher()

# shapes don't match
Expand Down Expand Up @@ -46,3 +49,81 @@ def test_match_ctc():
# gt and pred node should be the same
for pair in matched.mapping:
assert pair[0] == pair[1]


class Test_match_frame_majority:
@pytest.mark.parametrize(
"data",
[ex_segs.good_segmentation_2d(), ex_segs.good_segmentation_3d()],
ids=["2D", "3D"],
)
def test_good_seg(self, data):
ex_match = [(1, 2)]
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[
ex_segs.false_positive_segmentation_2d(),
ex_segs.false_positive_segmentation_3d(),
],
ids=["2D", "3D"],
)
def test_false_pos_seg(self, data):
ex_match = []
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[
ex_segs.false_negative_segmentation_2d(),
ex_segs.false_negative_segmentation_3d(),
],
ids=["2D", "3D"],
)
def test_false_neg_seg(self, data):
ex_match = []
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[ex_segs.oversegmentation_2d(), ex_segs.oversegmentation_3d()],
ids=["2D", "3D"],
)
def test_split(self, data):
ex_match = [(1, 2)]
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[ex_segs.undersegmentation_2d(), ex_segs.undersegmentation_3d()],
ids=["2D", "3D"],
)
def test_merge(self, data):
ex_match = [(1, 3), (2, 3)]
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[ex_segs.no_overlap_2d(), ex_segs.no_overlap_3d()],
ids=["2D", "3D"],
)
def test_no_overlap(self, data):
ex_match = []
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[ex_segs.multicell_2d(), ex_segs.multicell_3d()],
ids=["2D", "3D"],
)
def test_multicell(self, data):
ex_match = [(1, 3)]
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)
Loading

1 comment on commit 11a3e70

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE dd1d438 Mean (s) HEAD 11a3e70 Percent Change
test_load_gt_ctc_data[2d] 5.43773 5.45835 0.38
test_load_gt_ctc_data[3d] 19.465 15.4373 -20.69
test_load_pred_ctc_data[2d] 1.08051 1.09899 1.71
test_ctc_checks[2d] 0.73922 0.73112 -1.1
test_ctc_checks[3d] 9.32699 9.32012 -0.07
test_ctc_matcher[2d] 1.4971 1.47187 -1.69
test_ctc_matcher[3d] 16.6937 16.6288 -0.39
test_ctc_metrics[2d] 0.26305 0.2621 -0.36
test_ctc_metrics[3d] 4.03157 2.20822 -45.23
test_iou_matcher[2d] 1.57971 1.54575 -2.15
test_iou_matcher[3d] 17.6789 17.5727 -0.6
test_iou_div_metrics[2d] 0.07305 0.07168 -1.88
test_iou_div_metrics[3d] 0.69094 0.72019 4.23

Please sign in to comment.