Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/github_actions/codecov/codecov-ac…
Browse files Browse the repository at this point in the history
…tion-4
  • Loading branch information
cmalinmayor authored Feb 5, 2024
2 parents 19a558f + 894a5b6 commit 7706f09
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 20 deletions.
19 changes: 2 additions & 17 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import enum
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Hashable, Iterable

import networkx as nx
Expand Down Expand Up @@ -169,7 +170,7 @@ def __init__(
self.graph = graph

# construct dictionaries from attributes to nodes/edges for easy lookup
self.nodes_by_frame: dict[int, set[Hashable]] = {}
self.nodes_by_frame: defaultdict[int, set[Hashable]] = defaultdict(set)
self.nodes_by_flag: dict[NodeFlag, set[Hashable]] = {
flag: set() for flag in NodeFlag
}
Expand Down Expand Up @@ -231,22 +232,6 @@ def edges(self) -> OutEdgeView:
"""
return self.graph.edges

def get_nodes_in_frame(self, frame: int) -> set[Hashable]:
"""Get the node ids of all nodes in the given frame.
Args:
frame (int): The frame to return all node ids for.
If the provided frame is outside of the range
(self.start_frame, self.end_frame), returns an empty iterable.
Returns:
Iterable[Hashable]: An iterable of node ids for all nodes in frame.
"""
if frame in self.nodes_by_frame.keys():
return self.nodes_by_frame[frame]
else:
return set()

def get_location(self, node_id: Hashable) -> list[float]:
"""Get the spatial location of the node with node_id using self.location_keys.
Expand Down
2 changes: 2 additions & 0 deletions src/traccuracy/matchers/_compute_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def get_labels_with_overlap(gt_frame, res_frame):
res_box_labels = np.asarray(
[int(res_prop.label) for res_prop in res_props], dtype=np.uint16
)
if len(gt_props) == 0 or len(res_props) == 0:
return [], [], []

if gt_frame.ndim == 3:
overlaps = compute_overlap_3D(gt_boxes, res_boxes)
Expand Down
41 changes: 41 additions & 0 deletions tests/matchers/test_compute_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from traccuracy.matchers._compute_overlap import (
get_labels_with_overlap,
)

from tests.test_utils import get_annotated_image


def test_get_labels_with_overlap():
"""Get all labels IDs in gt_frame and res_frame whose bounding boxes
overlap.
Args:
gt_frame (np.ndarray): ground truth segmentation for a single frame
res_frame (np.ndarray): result segmentation for a given frame
Returns:
overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes
overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes
intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area)
"""
n_labels = 3
image1 = get_annotated_image(
img_size=256, num_labels=n_labels, sequential=True, seed=1
)
image2 = get_annotated_image(
img_size=256, num_labels=n_labels + 1, sequential=True, seed=2
)
empty_image = get_annotated_image(
img_size=256, num_labels=0, sequential=True, seed=1
)

perfect_gt, perfect_res, perfect_ious = get_labels_with_overlap(image1, image1)
assert list(perfect_gt) == list(range(1, n_labels + 1))
assert list(perfect_res) == list(range(1, n_labels + 1))
assert list(perfect_ious) == [1.0] * n_labels
get_labels_with_overlap(image1, image2)
# Test empty labels array
empty_gt, empty_res, empty_ious = get_labels_with_overlap(image1, empty_image)
assert empty_gt == []
assert empty_res == []
assert empty_ious == []
7 changes: 4 additions & 3 deletions tests/test_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ def test_constructor(nx_comp1):


def test_get_cells_by_frame(simple_graph):
assert Counter(simple_graph.get_nodes_in_frame(0)) == Counter({"1_0"})
assert Counter(simple_graph.get_nodes_in_frame(2)) == Counter(["1_2", "1_3"])
assert Counter(simple_graph.get_nodes_in_frame(5)) == Counter([])
assert Counter(simple_graph.nodes_by_frame[0]) == Counter({"1_0"})
assert Counter(simple_graph.nodes_by_frame[2]) == Counter(["1_2", "1_3"])
# Test non-existent frame
assert Counter(simple_graph.nodes_by_frame[5]) == Counter([])


def test_get_nodes_with_flag(simple_graph):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

def get_annotated_image(img_size=256, num_labels=3, sequential=True, seed=1):
np.random.seed(seed)
if num_labels == 0:
im = np.zeros((img_size, img_size))
return im.astype("int32")

num_labels_act = False
trial = 0
while num_labels != num_labels_act:
Expand Down

1 comment on commit 7706f09

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE 894a5b6 Mean (s) HEAD 7706f09 Percent Change
test_load_gt_data 1.24436 1.22954 -1.19
test_load_pred_data 1.17267 1.14096 -2.7
test_ctc_checks 0.42199 0.41242 -2.27
test_ctc_matched 2.27439 2.26281 -0.51
test_ctc_metrics 0.46854 0.49811 6.31
test_ctc_div_metrics 0.27987 0.28591 2.16
test_iou_matched 8.58173 8.54167 -0.47
test_iou_div_metrics 0.26712 0.27547 3.13

Please sign in to comment.