diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index f6dd21fc..45dd65de 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -3,6 +3,7 @@ import copy import enum import logging +from collections import defaultdict from typing import TYPE_CHECKING, Hashable, Iterable import networkx as nx @@ -169,7 +170,7 @@ def __init__( self.graph = graph # construct dictionaries from attributes to nodes/edges for easy lookup - self.nodes_by_frame: dict[int, set[Hashable]] = {} + self.nodes_by_frame: defaultdict[int, set[Hashable]] = defaultdict(set) self.nodes_by_flag: dict[NodeFlag, set[Hashable]] = { flag: set() for flag in NodeFlag } @@ -231,22 +232,6 @@ def edges(self) -> OutEdgeView: """ return self.graph.edges - def get_nodes_in_frame(self, frame: int) -> set[Hashable]: - """Get the node ids of all nodes in the given frame. - - Args: - frame (int): The frame to return all node ids for. - If the provided frame is outside of the range - (self.start_frame, self.end_frame), returns an empty iterable. - - Returns: - Iterable[Hashable]: An iterable of node ids for all nodes in frame. - """ - if frame in self.nodes_by_frame.keys(): - return self.nodes_by_frame[frame] - else: - return set() - def get_location(self, node_id: Hashable) -> list[float]: """Get the spatial location of the node with node_id using self.location_keys. diff --git a/src/traccuracy/matchers/_compute_overlap.py b/src/traccuracy/matchers/_compute_overlap.py index e0eb1784..abd04085 100644 --- a/src/traccuracy/matchers/_compute_overlap.py +++ b/src/traccuracy/matchers/_compute_overlap.py @@ -46,6 +46,8 @@ def get_labels_with_overlap(gt_frame, res_frame): res_box_labels = np.asarray( [int(res_prop.label) for res_prop in res_props], dtype=np.uint16 ) + if len(gt_props) == 0 or len(res_props) == 0: + return [], [], [] if gt_frame.ndim == 3: overlaps = compute_overlap_3D(gt_boxes, res_boxes) diff --git a/tests/matchers/test_compute_overlap.py b/tests/matchers/test_compute_overlap.py new file mode 100644 index 00000000..18ec557a --- /dev/null +++ b/tests/matchers/test_compute_overlap.py @@ -0,0 +1,41 @@ +from traccuracy.matchers._compute_overlap import ( + get_labels_with_overlap, +) + +from tests.test_utils import get_annotated_image + + +def test_get_labels_with_overlap(): + """Get all labels IDs in gt_frame and res_frame whose bounding boxes + overlap. + + Args: + gt_frame (np.ndarray): ground truth segmentation for a single frame + res_frame (np.ndarray): result segmentation for a given frame + + Returns: + overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes + overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes + intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area) + """ + n_labels = 3 + image1 = get_annotated_image( + img_size=256, num_labels=n_labels, sequential=True, seed=1 + ) + image2 = get_annotated_image( + img_size=256, num_labels=n_labels + 1, sequential=True, seed=2 + ) + empty_image = get_annotated_image( + img_size=256, num_labels=0, sequential=True, seed=1 + ) + + perfect_gt, perfect_res, perfect_ious = get_labels_with_overlap(image1, image1) + assert list(perfect_gt) == list(range(1, n_labels + 1)) + assert list(perfect_res) == list(range(1, n_labels + 1)) + assert list(perfect_ious) == [1.0] * n_labels + get_labels_with_overlap(image1, image2) + # Test empty labels array + empty_gt, empty_res, empty_ious = get_labels_with_overlap(image1, empty_image) + assert empty_gt == [] + assert empty_res == [] + assert empty_ious == [] diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 045584a4..6ed67c3d 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -135,9 +135,10 @@ def test_constructor(nx_comp1): def test_get_cells_by_frame(simple_graph): - assert Counter(simple_graph.get_nodes_in_frame(0)) == Counter({"1_0"}) - assert Counter(simple_graph.get_nodes_in_frame(2)) == Counter(["1_2", "1_3"]) - assert Counter(simple_graph.get_nodes_in_frame(5)) == Counter([]) + assert Counter(simple_graph.nodes_by_frame[0]) == Counter({"1_0"}) + assert Counter(simple_graph.nodes_by_frame[2]) == Counter(["1_2", "1_3"]) + # Test non-existent frame + assert Counter(simple_graph.nodes_by_frame[5]) == Counter([]) def test_get_nodes_with_flag(simple_graph): diff --git a/tests/test_utils.py b/tests/test_utils.py index 1dbf6d7c..6f854ad1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,10 @@ def get_annotated_image(img_size=256, num_labels=3, sequential=True, seed=1): np.random.seed(seed) + if num_labels == 0: + im = np.zeros((img_size, img_size)) + return im.astype("int32") + num_labels_act = False trial = 0 while num_labels != num_labels_act: