diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 5dd9d280..ab280c2f 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -388,6 +388,14 @@ def get_divisions(self): """ return [node for node, degree in self.graph.out_degree() if degree >= 2] + def get_merges(self): + """Get all nodes that have at least two incoming edges from the previous time frame + + Returns: + list of hashable: a list of node ids for nodes that have more than one parent + """ + return [node for node, degree in self.graph.in_degree() if degree >= 2] + def get_preds(self, node): """Get all predecessors of the given node. diff --git a/src/traccuracy/loaders/_ctc.py b/src/traccuracy/loaders/_ctc.py index eac346d5..863601ce 100644 --- a/src/traccuracy/loaders/_ctc.py +++ b/src/traccuracy/loaders/_ctc.py @@ -113,7 +113,6 @@ def ctc_to_graph(df, detections): { "source": cellids[0:-1], "target": cellids[1:], - "is_intertrack_edge": [0 for _ in range(len(cellids) - 1)], } ) ) @@ -126,11 +125,7 @@ def ctc_to_graph(df, detections): target = "{}_{}".format(row["Cell_ID"], row["Start"]) - edges.append( - pd.DataFrame( - {"source": [source], "target": [target], "is_intertrack_edge": [1]} - ) - ) + edges.append(pd.DataFrame({"source": [source], "target": [target]})) # Store position attributes on nodes detections["node_id"] = ( @@ -149,9 +144,8 @@ def ctc_to_graph(df, detections): # Create graph edges = pd.concat(edges) - edges["is_intertrack_edge"] = edges["is_intertrack_edge"].astype(bool) G = nx.from_pandas_edgelist( - edges, source="source", target="target", create_using=nx.DiGraph, edge_attr=True + edges, source="source", target="target", create_using=nx.DiGraph ) # Add all isolates to graph diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index f1bf3a38..642399df 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -79,12 +79,16 @@ def get_edge_errors(matched_data: "Matched"): logger.info("Edge errors already calculated. Skipping graph annotation") return + # Node errors must already be annotated + if not comp_graph.node_errors and not gt_graph.node_errors: + logger.warning("Node errors have not been annotated. Running node annotation.") + get_vertex_errors(matched_data) + induced_graph = comp_graph.get_subgraph( comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS) ).graph comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.FALSE_POS, False) - comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.TRUE_POS, False) comp_graph.set_edge_attribute( list(comp_graph.edges()), EdgeAttr.WRONG_SEMANTIC, False ) @@ -93,6 +97,23 @@ def get_edge_errors(matched_data: "Matched"): gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph} comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph} + # intertrack edges = connection between parent and daughter + for graph in [comp_graph, gt_graph]: + # Set to False by default + graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False) + + for parent in graph.get_divisions(): + for daughter in graph.get_succs(parent): + graph.set_edge_attribute( + (parent, daughter), EdgeAttr.INTERTRACK_EDGE, True + ) + + for merge in graph.get_merges(): + for parent in graph.get_preds(merge): + graph.set_edge_attribute( + (parent, merge), EdgeAttr.INTERTRACK_EDGE, True + ) + # fp edges - edges in induced_graph that aren't in gt_graph for edge in tqdm(induced_graph.edges, "Evaluating FP edges"): source, target = edge[0], edge[1] @@ -109,8 +130,6 @@ def get_edge_errors(matched_data: "Matched"): is_parent_comp = comp_graph.edges()[edge][EdgeAttr.INTERTRACK_EDGE] if is_parent_gt != is_parent_comp: comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True) - else: - comp_graph.set_edge_attribute(edge, EdgeAttr.TRUE_POS, True) # fn edges - edges in gt_graph that aren't in induced graph for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"): diff --git a/tests/bench.py b/tests/bench.py index b1df152b..41451bad 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -103,7 +103,7 @@ def run_compute(): assert ctc_results["fp_edges"] == 60 assert ctc_results["fp_nodes"] == 0 assert ctc_results["ns_nodes"] == 0 - assert ctc_results["ws_edges"] == 51 + assert ctc_results["ws_edges"] == 47 def test_ctc_div_metrics(benchmark, ctc_matched): diff --git a/tests/metrics/test_ctc_metrics.py b/tests/metrics/test_ctc_metrics.py index 557bf0d4..fb5bd601 100644 --- a/tests/metrics/test_ctc_metrics.py +++ b/tests/metrics/test_ctc_metrics.py @@ -1,5 +1,3 @@ -import networkx as nx -from traccuracy._tracking_graph import EdgeAttr from traccuracy.matchers._ctc import CTCMatched from traccuracy.metrics._ctc import CTCMetrics @@ -11,7 +9,6 @@ def test_compute_mapping(): n_frames = 3 n_labels = 3 track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels) - nx.set_edge_attributes(track_graph.graph, 0, EdgeAttr.INTERTRACK_EDGE) matched = CTCMatched(gt_graph=track_graph, pred_graph=track_graph) metric = CTCMetrics(matched) diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index cb0d989c..71df1a36 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -65,6 +65,40 @@ def nx_comp2(): return graph +@pytest.fixture +def nx_merge(): + """ + 3_0--3_1--\\ + 3_2--3_3 + 3_4--3_5--/ + """ + cells = [ + {"id": "3_0", "t": 0, "x": 0, "y": 0}, + {"id": "3_1", "t": 1, "x": 0, "y": 0}, + {"id": "3_2", "t": 2, "x": 0, "y": 0}, + {"id": "3_3", "t": 3, "x": 0, "y": 0}, + {"id": "3_4", "t": 0, "x": 0, "y": 0}, + {"id": "3_5", "t": 1, "x": 0, "y": 0}, + ] + + edges = [ + {"source": "3_0", "target": "3_1"}, + {"source": "3_1", "target": "3_2"}, + {"source": "3_2", "target": "3_3"}, + {"source": "3_4", "target": "3_5"}, + {"source": "3_5", "target": "3_2"}, + ] + graph = nx.DiGraph() + graph.add_nodes_from([(cell["id"], cell) for cell in cells]) + graph.add_edges_from([(edge["source"], edge["target"]) for edge in edges]) + return graph + + +@pytest.fixture +def merge_graph(nx_merge): + return TrackingGraph(nx_merge) + + @pytest.fixture def simple_graph(nx_comp1): return TrackingGraph(nx_comp1) @@ -168,11 +202,20 @@ def test_get_divisions(complex_graph): assert complex_graph.get_divisions() == ["1_1", "2_2"] -def test_get_preds(simple_graph): +def test_get_merges(merge_graph): + assert merge_graph.get_merges() == ["3_2"] + + +def test_get_preds(simple_graph, merge_graph): + # Division graph assert simple_graph.get_preds("1_0") == [] assert simple_graph.get_preds("1_1") == ["1_0"] assert simple_graph.get_preds("1_2") == ["1_1"] + # Merge graph + assert merge_graph.get_preds("3_3") == ["3_2"] + assert merge_graph.get_preds("3_2") == ["3_1", "3_5"] + def test_get_succs(simple_graph): assert simple_graph.get_succs("1_0") == ["1_1"] diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index 56060a81..ceba4989 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -79,7 +79,6 @@ def test_assign_edge_errors(): comp_g.add_nodes_from(comp_ids) comp_g.add_edges_from(comp_edges) nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS) - nx.set_edge_attributes(comp_g, 0, EdgeAttr.INTERTRACK_EDGE) nx.set_node_attributes( comp_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids}, @@ -90,7 +89,6 @@ def test_assign_edge_errors(): gt_g = nx.DiGraph() gt_g.add_nodes_from(gt_ids) gt_g.add_edges_from(gt_edges) - nx.set_edge_attributes(gt_g, 0, EdgeAttr.INTERTRACK_EDGE) nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG) nx.set_node_attributes( gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids} @@ -102,51 +100,45 @@ def test_assign_edge_errors(): get_edge_errors(matched_data) - assert comp_g.edges[(3, 4)][EdgeAttr.TRUE_POS] assert comp_g.edges[(7, 8)][EdgeAttr.FALSE_POS] assert gt_g.edges[(17, 18)][EdgeAttr.FALSE_NEG] def test_assign_edge_errors_semantics(): - comp_ids = [3, 7, 10] - comp_ids_2 = list(np.asarray(comp_ids) + 1) - comp_ids += comp_ids_2 - - gt_ids = [4, 12, 17] - gt_ids_2 = list(np.asarray(gt_ids) + 1) - gt_ids += gt_ids_2 - - mapping = [(4, 3), (12, 7), (17, 10), (5, 4), (18, 11), (13, 8)] - - # need a tp, fp, fn - comp_edges = [(3, 4)] - comp_g = nx.DiGraph() - comp_g.add_nodes_from(comp_ids) - comp_g.add_edges_from(comp_edges) - nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS) - nx.set_edge_attributes(comp_g, 0, EdgeAttr.INTERTRACK_EDGE) - nx.set_node_attributes( - comp_g, - {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids}, - ) - G_comp = TrackingGraph(comp_g) - - gt_edges = [(4, 5), (17, 18)] - gt_g = nx.DiGraph() - gt_g.add_nodes_from(gt_ids) - gt_g.add_edges_from(gt_edges) - nx.set_edge_attributes(gt_g, 0, EdgeAttr.INTERTRACK_EDGE) - nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG) - gt_g.edges[(4, 5)][EdgeAttr.INTERTRACK_EDGE] = 1 - nx.set_node_attributes( - gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids} - ) - G_gt = TrackingGraph(gt_g) - - matched_data = DummyMatched(G_gt, G_comp) + """ + gt: + 1_0 -- 1_1 -- 1_2 -- 1_3 + + comp: + 1_3 + 1_0 -- 1_1 -- 1_2 -< + 2_3 + """ + + gt = nx.DiGraph() + gt.add_edge("1_0", "1_1") + gt.add_edge("1_1", "1_2") + gt.add_edge("1_2", "1_3") + # Set node attrs + attrs = {} + for node in gt.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(gt, attrs) + + comp = gt.copy() + comp.add_edge("1_2", "2_3") + # Set node attrs + attrs = {} + for node in comp.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(comp, attrs) + + # Define mapping with all nodes matching except for 2_3 in comp + mapping = [(n, n) for n in gt.nodes] + + matched_data = DummyMatched(TrackingGraph(gt), TrackingGraph(comp)) matched_data.mapping = mapping get_edge_errors(matched_data) - assert comp_g.edges[(3, 4)][EdgeAttr.WRONG_SEMANTIC] - assert not comp_g.edges[(3, 4)][EdgeAttr.TRUE_POS] + assert comp.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC]