diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index db60a204..93c8198d 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -205,8 +205,12 @@ def __init__( self.edges_by_flag[edge_flag].add(edge) # Store first and last frames for reference - self.start_frame = min(self.nodes_by_frame.keys()) - self.end_frame = max(self.nodes_by_frame.keys()) + 1 + if len(self.nodes_by_frame) == 0: + self.start_frame = None + self.end_frame = None + else: + self.start_frame = min(self.nodes_by_frame.keys()) + self.end_frame = max(self.nodes_by_frame.keys()) + 1 # Record types of annotations that have been calculated self.division_annotations = False @@ -301,13 +305,13 @@ def get_connected_components(self) -> list[TrackingGraph]: return [self.get_subgraph(g) for g in nx.weakly_connected_components(graph)] def get_subgraph(self, nodes: Iterable[Hashable]) -> TrackingGraph: - """Returns a new TrackingGraph with the subgraph defined by the list of nodes + """Returns a new TrackingGraph with the subgraph defined by the list of nodes. Args: - nodes (list): List of node ids to use in constructing the subgraph + nodes (list): A list of node ids to use in constructing the subgraph """ - new_graph = self.graph.subgraph(nodes).copy() + new_trackgraph = copy.deepcopy(self) new_trackgraph.graph = new_graph for frame, nodes_in_frame in self.nodes_by_frame.items(): @@ -324,10 +328,14 @@ def get_subgraph(self, nodes: Iterable[Hashable]) -> TrackingGraph: for edge_flag in EdgeFlag: new_trackgraph.edges_by_flag[edge_flag] = self.edges_by_flag[ edge_flag - ].intersection(nodes) + ].intersection(new_trackgraph.edges) - new_trackgraph.start_frame = min(new_trackgraph.nodes_by_frame.keys()) - new_trackgraph.end_frame = max(new_trackgraph.nodes_by_frame.keys()) + 1 + if len(new_trackgraph.nodes_by_frame) == 0: + new_trackgraph.start_frame = None + new_trackgraph.end_frame = None + else: + new_trackgraph.start_frame = min(new_trackgraph.nodes_by_frame.keys()) + new_trackgraph.end_frame = max(new_trackgraph.nodes_by_frame.keys()) + 1 return new_trackgraph diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 6712e70b..084f50a1 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -51,8 +51,11 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): if mask_gt.shape != mask_pred.shape: raise ValueError("Segmentation shapes must match between gt and pred") - mapping = [] + mapping: list[tuple] = [] # Get overlaps for each frame + if gt.start_frame is None or gt.end_frame is None: + return Matched(gt_graph, pred_graph, mapping) + for i, t in enumerate( tqdm( range(gt.start_frame, gt.end_frame), diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 01a7c7bb..49e8476d 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -87,9 +87,8 @@ def get_edge_errors(matched_data: Matched): logger.warning("Node errors have not been annotated. Running node annotation.") get_vertex_errors(matched_data) - induced_graph = comp_graph.get_subgraph( - comp_graph.get_nodes_with_flag(NodeFlag.TRUE_POS) - ).graph + comp_tp_nodes = comp_graph.get_nodes_with_flag(NodeFlag.TRUE_POS) + induced_graph = comp_graph.get_subgraph(comp_tp_nodes).graph comp_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, False) comp_graph.set_flag_on_all_edges(EdgeFlag.WRONG_SEMANTIC, False) @@ -141,12 +140,15 @@ def get_edge_errors(matched_data: Matched): gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True) continue - source_comp_id = gt_comp_mapping[source] - target_comp_id = gt_comp_mapping[target] + source_comp_id = gt_comp_mapping.get(source, None) + target_comp_id = gt_comp_mapping.get(target, None) - expected_comp_edge = (source_comp_id, target_comp_id) - if expected_comp_edge not in induced_graph.edges: + if source_comp_id is None or target_comp_id is None: gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True) + else: + expected_comp_edge = (source_comp_id, target_comp_id) + if expected_comp_edge not in induced_graph.edges: + gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True) gt_graph.edge_errors = True comp_graph.edge_errors = True diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 6ed67c3d..02739070 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -182,6 +182,25 @@ def test_get_connected_components(complex_graph, nx_comp1, nx_comp2): assert track2.graph.edges == nx_comp2.edges +def test_get_subgraph(simple_graph): + target_nodes = ("1_0", "1_1") + subgraph = simple_graph.get_subgraph(target_nodes) + assert len(subgraph.nodes) == 2 + assert len(subgraph.edges) == 1 + # test that nodes_by_flag dicts are maintained + assert Counter(subgraph.nodes_by_flag[NodeFlag.TP_DIV]) == Counter(["1_1"]) + assert Counter(subgraph.edges_by_flag[EdgeFlag.TRUE_POS]) == Counter( + [("1_0", "1_1")] + ) + # test that start and end frame are updated + assert subgraph.start_frame == 0 + assert subgraph.end_frame == 2 + + # test empty target nodes + empty_graph = simple_graph.get_subgraph([]) + assert Counter(empty_graph.nodes) == Counter([]) + + def test_set_flag_on_node(simple_graph): assert simple_graph.nodes()["1_0"] == {"id": "1_0", "t": 0, "y": 1, "x": 1} assert simple_graph.nodes()["1_1"] == { diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index a66517d0..323fc4a3 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -134,3 +134,70 @@ def test_assign_edge_errors_semantics(): get_edge_errors(matched_data) assert matched_data.pred_graph.edges[("1_2", "1_3")][EdgeFlag.WRONG_SEMANTIC] + + +def test_ns_vertex_fn_edge(): + """Minimal Example of testing for FN edges with a NS Vertex + gt 1 - 2 - 3 + 4 - 5 - 6 + + comp 1 - 2 + + matching [ (1, 1), (4, 1), (2, 2), (5, 2) ] + """ + + gt_nodes = [ + (1, {"t": 0, "x": 1, "y": 1}), + (2, {"t": 1, "x": 1, "y": 1}), + (3, {"t": 2, "x": 1, "y": 1}), + (4, {"t": 0, "x": 0, "y": 1}), + (5, {"t": 1, "x": 0, "y": 1}), + (6, {"t": 2, "x": 0, "y": 1}), + ] + gt_edges = [ + (1, 2), + (2, 3), + (4, 5), + (5, 6), + ] + gt = nx.DiGraph() + gt.add_nodes_from(gt_nodes) + gt.add_edges_from(gt_edges) + + comp_nodes = [ + (1, {"t": 0, "x": 0.5, "y": 1}), + (2, {"t": 1, "x": 0.5, "y": 1}), + ] + comp_edges = [ + (1, 2), + ] + comp = nx.DiGraph() + comp.add_nodes_from(comp_nodes) + comp.add_edges_from(comp_edges) + + mapping = [ + (1, 1), + (5, 1), + (2, 2), + (5, 2), + ] + + matched_data = Matched(TrackingGraph(gt), TrackingGraph(comp), mapping) + get_vertex_errors(matched_data) + get_edge_errors(matched_data) + + for node in comp.nodes: + assert comp.nodes[node][NodeFlag.NON_SPLIT] + for edge in comp_edges: + assert not comp.edges[edge][EdgeFlag.FALSE_POS] + + # https://github.com/Janelia-Trackathon-2023/traccuracy/pull/141#issuecomment-2265990197 + if False: # TODO: Fix this in a separate PR + for node in [1, 2, 4, 5]: + assert gt.nodes[node][NodeFlag.FALSE_NEG] + + for node in [3, 6]: + assert gt.nodes[node][NodeFlag.FALSE_NEG] + + for edge in gt_edges: + assert gt.edges[edge][EdgeFlag.FALSE_NEG]