diff --git a/README.md b/README.md index 0b0b3adc..1e84692c 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ The `traccuracy` library has three main components: loaders, matchers, and metri Loaders load tracking graphs from other formats, such as the CTC format, into a [TrackingGraph](https://traccuracy.readthedocs.io/en/latest/autoapi/traccuracy/index.html#traccuracy.TrackingGraph) object. A TrackingGraph is a spatiotemporal graph. Nodes represent a single cell in a given time point, and are annotated with a time and a location. -Edges point from a node representing a cell in time point `t` to the same cell or its daughter in `t+1`. +Edges point forward in time from a node representing a cell in time point `t` to the same cell or its daughter in frame `t+1` (or beyond, to represent gap-closing). To load TrackingGraphs from a custom format, you will likely need to implement a loader: see documentation [here](https://traccuracy.readthedocs.io/en/latest/autoapi/traccuracy/loaders/index.html#module-traccuracy.loaders) for more information. @@ -55,4 +55,7 @@ pipelines, [documented here](https://traccuracy.readthedocs.io/en/latest/cli.htm : A single non-dividing cell tracked over time. In graph terms, this is the connected component of a track between divisions (daughter to next parent). Tracklets can also start or end with a non-dividing cell at the beginning and end of the captured time or if the track leaves the field of view. **Track** -: A single cell and all of its progeny. In graph terms, a connected component including divisions. \ No newline at end of file +: A single cell and all of its progeny. In graph terms, a connected component including divisions. + +**Gap-Closing** +: Also known as *frame-skipping*, these are edges that connect non-consecutive frames to signify a cell being occluded or missing for some frames, before the track continues. \ No newline at end of file diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 082e0ac3..8f21bc62 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -81,6 +81,9 @@ class TrackingGraph: location (defaults to 'x' and 'y'). As in networkx, every cell must have a unique id, but these can be of any (hashable) type. + Edges typically connect nodes across consecutive frames, but gap closing or frame + skipping edges are valid, which connect nodes in frame t to nodes in frames beyond t+1. + We provide common functions for accessing parts of the track graph, for example all nodes in a certain frame, or all previous or next edges for a given node. Additional functionality can be accessed by querying the stored networkx graph diff --git a/tests/loaders/test_ctc.py b/tests/loaders/test_ctc.py index 38dcc06f..a621fa4d 100644 --- a/tests/loaders/test_ctc.py +++ b/tests/loaders/test_ctc.py @@ -66,6 +66,23 @@ def test_ctc_single_nodes(): TrackingGraph(G) +def test_ctc_with_gap_closing(): + data = [ + {"Cell_ID": 1, "Start": 0, "End": 1, "Parent_ID": 0}, + {"Cell_ID": 2, "Start": 0, "End": 1, "Parent_ID": 0}, + # Connecting frame 1 to frame 3 + {"Cell_ID": 3, "Start": 3, "End": 5, "Parent_ID": 1}, + # Connecting frame 1 to frame 6 + {"Cell_ID": 4, "Start": 6, "End": 8, "Parent_ID": 2}, + ] + df = pd.DataFrame(data) + G = _ctc.ctc_to_graph( + df, pd.DataFrame({"segmentation_id": [], "x": [], "y": [], "z": [], "t": []}) + ) + assert G.has_edge("1_1", "3_3") + assert G.has_edge("2_1", "4_6") + + def test_load_data(): test_dir = os.path.abspath(__file__) data_dir = os.path.abspath( diff --git a/tests/metrics/test_ctc_metrics.py b/tests/metrics/test_ctc_metrics.py index bd92fc95..e8885b8a 100644 --- a/tests/metrics/test_ctc_metrics.py +++ b/tests/metrics/test_ctc_metrics.py @@ -1,7 +1,9 @@ +from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph +from traccuracy.matchers._base import Matched from traccuracy.matchers._ctc import CTCMatcher from traccuracy.metrics._ctc import CTCMetrics -from tests.test_utils import get_movie_with_graph +from tests.test_utils import get_gap_close_graphs, get_movie_with_graph def test_compute_mapping(): @@ -17,3 +19,19 @@ def test_compute_mapping(): assert "DET" in results assert results["TRA"] == 1 assert results["DET"] == 1 + + +def test_compute_metrics_gap_close(): + g_gt, g_pred, mapper = get_gap_close_graphs() + matched = Matched( + gt_graph=TrackingGraph(g_gt), pred_graph=TrackingGraph(g_pred), mapping=mapper + ) + CTCMetrics().compute(matched) + + # check that missing gap closing edge is false negative + assert g_gt.edges[("1_1", "2_3")][EdgeAttr.FALSE_NEG] + # check that "extra" node is FP + assert g_pred.nodes["1_2"][NodeAttr.FALSE_POS] + # check that correct edge is not annotated with errors + for error_attr in [EdgeAttr.FALSE_POS, EdgeAttr.WRONG_SEMANTIC]: + assert not g_pred.edges[("2_6", "4_10")][error_attr] diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index f70a9304..8a2da45e 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -6,6 +6,8 @@ from traccuracy.matchers import Matched from traccuracy.metrics._track_overlap import TrackOverlapMetrics, _mapping_to_dict +from tests.test_utils import get_gap_close_graphs + def add_frame(tree): attrs = {} @@ -183,6 +185,19 @@ def test_track_overlap_metrics(data, inverse) -> None: assert results == expected, f"{data['name']} failed without division edges" +def test_track_overlap_gap_close(): + g_gt, g_pred, mapping = get_gap_close_graphs() + matched = Matched( + TrackingGraph(g_gt), + TrackingGraph(g_pred), + mapping, + ) + metric = TrackOverlapMetrics() + results = metric.compute(matched) + assert results["track_purity"] == 7 / 9 + assert results["target_effectiveness"] == 7 / 8 + + def test_mapping_to_dict(): mapping = [("1", "2"), ("2", "3"), ("1", "3"), ("2", "3")] mapping_dict = _mapping_to_dict(mapping) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1dbf6d7c..9d98cdb9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -148,3 +148,124 @@ def get_division_graphs(): mapper = [("1_0", "1_0"), ("1_1", "1_1"), ("2_4", "2_4"), ("3_4", "3_4")] return G1, G2, mapper + + +def get_gap_close_graphs(): + """ + G1 + 3_5 -- 3_6 -- -- -- 5_10 + 1_0 -- 1_1 -- -- -- 2_3 -- 2_4 -< + 4_5 -- 4_6 + G2 + 2_5 -- 2_6 -- -- -- 4_10 + 1_0 -- 1_1 -- 1_2 -- 1_3 -- 1_4 -< + 3_5 -- 3_6 + """ + G1 = nx.DiGraph() + G1.add_edge("1_0", "1_1") + # gap closing edge + G1.add_edge("1_1", "2_3") + G1.add_edge("2_3", "2_4") + # Divide to generate 3 lineage + G1.add_edge("2_4", "3_5") + G1.add_edge("3_5", "3_6") + # gap closing edge + G1.add_edge("3_6", "5_10") + # Divide to generate 4 lineage + G1.add_edge("2_4", "4_5") + G1.add_edge("4_5", "4_6") + + attrs = {} + for node in G1.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(G1, attrs) + + G2 = nx.DiGraph() + G2.add_edge("1_0", "1_1") + # missing gap closing edge + G2.add_edge("1_1", "1_2") + G2.add_edge("1_2", "1_3") + G2.add_edge("1_3", "1_4") + # Divide to generate 2 lineage + G2.add_edge("1_4", "2_5") + G2.add_edge("2_5", "2_6") + # correct gap closing edge + G2.add_edge("2_6", "4_10") + # Divide to generate 3 lineage + G2.add_edge("1_4", "3_5") + G2.add_edge("3_5", "3_6") + + attrs = {} + for node in G2.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(G2, attrs) + + # G1, G2 mapper + mapper = [ + ("1_0", "1_0"), + ("1_1", "1_1"), + ("2_3", "1_3"), + ("2_4", "1_4"), + ("3_5", "2_5"), + ("3_6", "2_6"), + ("5_10", "4_10"), + ("4_5", "3_5"), + ("4_6", "3_6"), + ] + + return G1, G2, mapper + + +def get_division_gap_close_graphs(): + """ + G1 + -- -- 2_3 -- 2_4 + 1_0 -- 1_1 -< + 3_2 -- 3_3 -- 3_4 + G2 + 2_2 -- 2_3 -- 2_4 + 1_0 -- 1_1 -< + 3_2 -- -- -- 4_4 + """ + + G1 = nx.DiGraph() + G1.add_edge("1_0", "1_1") + # gap division + G1.add_edge("1_1", "2_3") + G1.add_edge("2_3", "2_4") + # divide into 3 lineage + G1.add_edge("1_1", "3_2") + G1.add_edge("3_2", "3_3") + G1.add_edge("3_3", "3_4") + + attrs = {} + for node in G1.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(G1, attrs) + + G2 = nx.DiGraph() + G2.add_edge("1_0", "1_1") + # Divide to generate 2 lineage + G2.add_edge("1_1", "2_2") + G2.add_edge("2_2", "2_3") + G2.add_edge("2_3", "2_4") + # Divide to generate 3 lineage + G2.add_edge("1_1", "3_2") + # incorrect gap closing edge + G2.add_edge("3_2", "4_4") + + attrs = {} + for node in G2.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(G2, attrs) + + mapper = [ + ("1_0", "1_0"), + ("1_1", "1_1"), + ("2_3", "2_3"), + ("2_4", "2_4"), + ("3_2", "3_2"), + ("3_4", "4_4"), + ] + + return G1, G2, mapper diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index 6538e644..8822c75e 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -11,7 +11,7 @@ _get_succ_by_t, ) -from tests.test_utils import get_division_graphs +from tests.test_utils import get_division_gap_close_graphs, get_division_graphs @pytest.fixture @@ -220,3 +220,30 @@ def test_evaluate_division_events(): results = _evaluate_division_events(matched_data, frame_buffer=frame_buffer) assert np.all([isinstance(k, int) for k in results.keys()]) + + +def test_gap_close_divisions(): + g_gt, g_pred, mapper = get_division_gap_close_graphs() + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) + _classify_divisions(matched_data) + + # missing gap close div edge so FN DIV + assert g_gt.nodes["1_1"][NodeAttr.FN_DIV] + + # fix division, assert it's identified correctly + g_pred.remove_node("2_2") + g_pred.add_edge("1_1", "2_3") + # mapper doesn't need to change as removed node was always missing + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) + _classify_divisions(matched_data) + assert g_gt.nodes["1_1"][NodeAttr.TP_DIV] + assert g_pred.nodes["1_1"][NodeAttr.TP_DIV] + + g_gt, g_pred, mapper = get_division_gap_close_graphs() + # remove gt division + g_gt.remove_edge("1_1", "2_3") + g_gt.remove_edge("1_1", "3_2") + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) + _classify_divisions(matched_data) + # assert fp division classified correctly + assert g_pred.nodes["1_1"][NodeAttr.FP_DIV]