From 54a99f2a4c3fc8aa55724e10e03cd13a314f6e82 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Wed, 20 Sep 2023 11:07:54 -0700 Subject: [PATCH 1/7] Annotate intertrack edges in ctc edge errors function --- src/traccuracy/track_errors/_ctc.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 322711d2..ec736557 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -131,5 +131,16 @@ def get_edge_errors(matched_data: "Matched"): if expected_comp_edge not in induced_graph.edges: gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) + # intertrack edges = connection between parent and daughter + for graph in [comp_graph, gt_graph]: + # Set to False by default + graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False) + + for parent in graph.get_divisions(): + for daughter in graph.get_succs(parent): + graph.set_edge_attribute( + (parent, daughter), EdgeAttr.INTERTRACK_EDGE, True + ) + gt_graph.edge_errors = True comp_graph.edge_errors = True From 2b2d005ed28bd67be6fd84ea3192b5910416944f Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Wed, 20 Sep 2023 12:25:37 -0700 Subject: [PATCH 2/7] CTC loader does not need to annotated intertrack edges --- src/traccuracy/loaders/_ctc.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/traccuracy/loaders/_ctc.py b/src/traccuracy/loaders/_ctc.py index eac346d5..965b3d29 100644 --- a/src/traccuracy/loaders/_ctc.py +++ b/src/traccuracy/loaders/_ctc.py @@ -113,7 +113,6 @@ def ctc_to_graph(df, detections): { "source": cellids[0:-1], "target": cellids[1:], - "is_intertrack_edge": [0 for _ in range(len(cellids) - 1)], } ) ) @@ -128,7 +127,7 @@ def ctc_to_graph(df, detections): edges.append( pd.DataFrame( - {"source": [source], "target": [target], "is_intertrack_edge": [1]} + {"source": [source], "target": [target]} ) ) @@ -149,9 +148,8 @@ def ctc_to_graph(df, detections): # Create graph edges = pd.concat(edges) - edges["is_intertrack_edge"] = edges["is_intertrack_edge"].astype(bool) G = nx.from_pandas_edgelist( - edges, source="source", target="target", create_using=nx.DiGraph, edge_attr=True + edges, source="source", target="target", create_using=nx.DiGraph ) # Add all isolates to graph From ba74e9977c22a5dccb30dcf2761a034a5765257a Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Wed, 20 Sep 2023 12:27:49 -0700 Subject: [PATCH 3/7] Check for node error annotations before calculating edge errors --- src/traccuracy/track_errors/_ctc.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index ec736557..c02e4686 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -80,6 +80,11 @@ def get_edge_errors(matched_data: "Matched"): logger.info("Edge errors already calculated. Skipping graph annotation") return + # Node errors must already be annotated + if not comp_graph.node_errors and not gt_graph.node_errors: + logger.warning("Node errors have not been annotated. Running node annotation.") + get_vertex_errors(matched_data) + induced_graph = comp_graph.get_subgraph( comp_graph.get_nodes_with_attribute(NodeAttr.TRUE_POS, criterion=lambda x: x) ).graph @@ -94,6 +99,17 @@ def get_edge_errors(matched_data: "Matched"): node_mapping_first = np.array([mp[0] for mp in node_mapping]) node_mapping_second = np.array([mp[1] for mp in node_mapping]) + # intertrack edges = connection between parent and daughter + for graph in [comp_graph, gt_graph]: + # Set to False by default + graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False) + + for parent in graph.get_divisions(): + for daughter in graph.get_succs(parent): + graph.set_edge_attribute( + (parent, daughter), EdgeAttr.INTERTRACK_EDGE, True + ) + # fp edges - edges in induced_graph that aren't in gt_graph for edge in tqdm(induced_graph.edges, "Evaluating FP edges"): source, target = edge[0], edge[1] @@ -131,16 +147,5 @@ def get_edge_errors(matched_data: "Matched"): if expected_comp_edge not in induced_graph.edges: gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) - # intertrack edges = connection between parent and daughter - for graph in [comp_graph, gt_graph]: - # Set to False by default - graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False) - - for parent in graph.get_divisions(): - for daughter in graph.get_succs(parent): - graph.set_edge_attribute( - (parent, daughter), EdgeAttr.INTERTRACK_EDGE, True - ) - gt_graph.edge_errors = True comp_graph.edge_errors = True From d531bcef35fe0acea35b76ab17168757b3c55ce6 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Wed, 20 Sep 2023 12:28:04 -0700 Subject: [PATCH 4/7] Update tests --- tests/metrics/test_ctc_metrics.py | 3 -- tests/track_errors/test_ctc_errors.py | 73 ++++++++++++--------------- 2 files changed, 33 insertions(+), 43 deletions(-) diff --git a/tests/metrics/test_ctc_metrics.py b/tests/metrics/test_ctc_metrics.py index 557bf0d4..fb5bd601 100644 --- a/tests/metrics/test_ctc_metrics.py +++ b/tests/metrics/test_ctc_metrics.py @@ -1,5 +1,3 @@ -import networkx as nx -from traccuracy._tracking_graph import EdgeAttr from traccuracy.matchers._ctc import CTCMatched from traccuracy.metrics._ctc import CTCMetrics @@ -11,7 +9,6 @@ def test_compute_mapping(): n_frames = 3 n_labels = 3 track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels) - nx.set_edge_attributes(track_graph.graph, 0, EdgeAttr.INTERTRACK_EDGE) matched = CTCMatched(gt_graph=track_graph, pred_graph=track_graph) metric = CTCMetrics(matched) diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index f80cc87d..2bcc75dc 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -82,7 +82,6 @@ def test_assign_edge_errors(): comp_g.add_nodes_from(comp_ids) comp_g.add_edges_from(comp_edges) nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS) - nx.set_edge_attributes(comp_g, 0, EdgeAttr.INTERTRACK_EDGE) nx.set_node_attributes( comp_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids}, @@ -93,7 +92,6 @@ def test_assign_edge_errors(): gt_g = nx.DiGraph() gt_g.add_nodes_from(gt_ids) gt_g.add_edges_from(gt_edges) - nx.set_edge_attributes(gt_g, 0, EdgeAttr.INTERTRACK_EDGE) nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG) nx.set_node_attributes( gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids} @@ -111,45 +109,40 @@ def test_assign_edge_errors(): def test_assign_edge_errors_semantics(): - comp_ids = [3, 7, 10] - comp_ids_2 = list(np.asarray(comp_ids) + 1) - comp_ids += comp_ids_2 - - gt_ids = [4, 12, 17] - gt_ids_2 = list(np.asarray(gt_ids) + 1) - gt_ids += gt_ids_2 - - mapping = [(4, 3), (12, 7), (17, 10), (5, 4), (18, 11), (13, 8)] - - # need a tp, fp, fn - comp_edges = [(3, 4)] - comp_g = nx.DiGraph() - comp_g.add_nodes_from(comp_ids) - comp_g.add_edges_from(comp_edges) - nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS) - nx.set_edge_attributes(comp_g, 0, EdgeAttr.INTERTRACK_EDGE) - nx.set_node_attributes( - comp_g, - {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids}, - ) - G_comp = TrackingGraph(comp_g) - - gt_edges = [(4, 5), (17, 18)] - gt_g = nx.DiGraph() - gt_g.add_nodes_from(gt_ids) - gt_g.add_edges_from(gt_edges) - nx.set_edge_attributes(gt_g, 0, EdgeAttr.INTERTRACK_EDGE) - nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG) - gt_g.edges[(4, 5)][EdgeAttr.INTERTRACK_EDGE] = 1 - nx.set_node_attributes( - gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids} - ) - G_gt = TrackingGraph(gt_g) - - matched_data = DummyMatched(G_gt, G_comp) + """ + gt: + 1_0 -- 1_1 -- 1_2 -- 1_3 + + comp: + 1_3 + 1_0 -- 1_1 -- 1_2 -< + 2_3 + """ + + gt = nx.DiGraph() + gt.add_edge("1_0", "1_1") + gt.add_edge("1_1", "1_2") + gt.add_edge("1_2", "1_3") + # Set node attrs + attrs = {} + for node in gt.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(gt, attrs) + + comp = gt.copy() + comp.add_edge("1_2", "2_3") + # Set node attrs + attrs = {} + for node in comp.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(comp, attrs) + + # Define mapping with all nodes matching except for 2_3 in comp + mapping = [(n, n) for n in gt.nodes] + + matched_data = DummyMatched(TrackingGraph(gt), TrackingGraph(comp)) matched_data.mapping = mapping get_edge_errors(matched_data) - assert comp_g.edges[(3, 4)][EdgeAttr.WRONG_SEMANTIC] - assert not comp_g.edges[(3, 4)][EdgeAttr.TRUE_POS] + assert comp.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] From ae4159d811639be503e8012455ae04b398b6cf30 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Wed, 20 Sep 2023 12:35:28 -0700 Subject: [PATCH 5/7] Remove annotation of true positive edges since its somewhat subjective --- src/traccuracy/track_errors/_ctc.py | 3 --- tests/track_errors/test_ctc_errors.py | 1 - 2 files changed, 4 deletions(-) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index c02e4686..39a942c2 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -90,7 +90,6 @@ def get_edge_errors(matched_data: "Matched"): ).graph comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.FALSE_POS, False) - comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.TRUE_POS, False) comp_graph.set_edge_attribute( list(comp_graph.edges()), EdgeAttr.WRONG_SEMANTIC, False ) @@ -126,8 +125,6 @@ def get_edge_errors(matched_data: "Matched"): is_parent_comp = comp_graph.edges()[edge][EdgeAttr.INTERTRACK_EDGE] if is_parent_gt != is_parent_comp: comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True) - else: - comp_graph.set_edge_attribute(edge, EdgeAttr.TRUE_POS, True) # fn edges - edges in gt_graph that aren't in induced graph for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"): diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index 2bcc75dc..f79dd614 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -103,7 +103,6 @@ def test_assign_edge_errors(): get_edge_errors(matched_data) - assert comp_g.edges[(3, 4)][EdgeAttr.TRUE_POS] assert comp_g.edges[(7, 8)][EdgeAttr.FALSE_POS] assert gt_g.edges[(17, 18)][EdgeAttr.FALSE_NEG] From 404caf26b05afd1c09f4892f6f78f4667c8dadd1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Sep 2023 19:42:11 +0000 Subject: [PATCH 6/7] style(pre-commit.ci): auto fixes [...] --- src/traccuracy/loaders/_ctc.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/traccuracy/loaders/_ctc.py b/src/traccuracy/loaders/_ctc.py index 965b3d29..863601ce 100644 --- a/src/traccuracy/loaders/_ctc.py +++ b/src/traccuracy/loaders/_ctc.py @@ -125,11 +125,7 @@ def ctc_to_graph(df, detections): target = "{}_{}".format(row["Cell_ID"], row["Start"]) - edges.append( - pd.DataFrame( - {"source": [source], "target": [target]} - ) - ) + edges.append(pd.DataFrame({"source": [source], "target": [target]})) # Store position attributes on nodes detections["node_id"] = ( From a3968e6dfa4d34380543c8db1a463704ac5853a9 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 26 Sep 2023 16:32:02 -0700 Subject: [PATCH 7/7] Add merge detection to tracking graph and assign intertrack edges for merges --- src/traccuracy/_tracking_graph.py | 8 +++++ src/traccuracy/track_errors/_ctc.py | 6 ++++ tests/test_tracking_graph.py | 45 ++++++++++++++++++++++++++++- 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index c67bbc0e..4fd905f1 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -308,6 +308,14 @@ def get_divisions(self): """ return [node for node, degree in self.graph.out_degree() if degree >= 2] + def get_merges(self): + """Get all nodes that have at least two incoming edges from the previous time frame + + Returns: + list of hashable: a list of node ids for nodes that have more than one parent + """ + return [node for node, degree in self.graph.in_degree() if degree >= 2] + def get_preds(self, node): """Get all predecessors of the given node. diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 39a942c2..ec99d131 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -109,6 +109,12 @@ def get_edge_errors(matched_data: "Matched"): (parent, daughter), EdgeAttr.INTERTRACK_EDGE, True ) + for merge in graph.get_merges(): + for parent in graph.get_preds(merge): + graph.set_edge_attribute( + (parent, merge), EdgeAttr.INTERTRACK_EDGE, True + ) + # fp edges - edges in induced_graph that aren't in gt_graph for edge in tqdm(induced_graph.edges, "Evaluating FP edges"): source, target = edge[0], edge[1] diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 0d597046..f994d400 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -65,6 +65,40 @@ def nx_comp2(): return graph +@pytest.fixture +def nx_merge(): + """ + 3_0--3_1--\\ + 3_2--3_3 + 3_4--3_5--/ + """ + cells = [ + {"id": "3_0", "t": 0, "x": 0, "y": 0}, + {"id": "3_1", "t": 1, "x": 0, "y": 0}, + {"id": "3_2", "t": 2, "x": 0, "y": 0}, + {"id": "3_3", "t": 3, "x": 0, "y": 0}, + {"id": "3_4", "t": 0, "x": 0, "y": 0}, + {"id": "3_5", "t": 1, "x": 0, "y": 0}, + ] + + edges = [ + {"source": "3_0", "target": "3_1"}, + {"source": "3_1", "target": "3_2"}, + {"source": "3_2", "target": "3_3"}, + {"source": "3_4", "target": "3_5"}, + {"source": "3_5", "target": "3_2"}, + ] + graph = nx.DiGraph() + graph.add_nodes_from([(cell["id"], cell) for cell in cells]) + graph.add_edges_from([(edge["source"], edge["target"]) for edge in edges]) + return graph + + +@pytest.fixture +def merge_graph(nx_merge): + return TrackingGraph(nx_merge) + + @pytest.fixture def simple_graph(nx_comp1): return TrackingGraph(nx_comp1) @@ -146,11 +180,20 @@ def test_get_divisions(complex_graph): assert complex_graph.get_divisions() == ["1_1", "2_2"] -def test_get_preds(simple_graph): +def test_get_merges(merge_graph): + assert merge_graph.get_merges() == ["3_2"] + + +def test_get_preds(simple_graph, merge_graph): + # Division graph assert simple_graph.get_preds("1_0") == [] assert simple_graph.get_preds("1_1") == ["1_0"] assert simple_graph.get_preds("1_2") == ["1_1"] + # Merge graph + assert merge_graph.get_preds("3_3") == ["3_2"] + assert merge_graph.get_preds("3_2") == ["3_1", "3_5"] + def test_get_succs(simple_graph): assert simple_graph.get_succs("1_0") == ["1_1"]