diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 3508c10b..68074776 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -163,6 +163,8 @@ def __init__( self._update_graph(graph) + # Record types of annotations that have been calculated + self.division_annotations = False self.node_errors = False self.edge_errors = False diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 27a4bc85..c1b5ce77 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -1,4 +1,4 @@ -"""This submodule classifies division erros in tracking graphs +"""This submodule classifies division errors in tracking graphs Each division is classifed as one of the following: - true positive @@ -32,57 +32,26 @@ as the late division daughters. """ -import itertools -from collections import Counter -from traccuracy._tracking_graph import TrackingGraph -from traccuracy._utils import find_gt_node_matches, find_pred_node_matches -from traccuracy.track_errors._division_events import DivisionEvents +from traccuracy._tracking_graph import NodeAttr +from traccuracy.track_errors.divisions import _evaluate_division_events from ._base import Metric -def _calculate_metrics(event: DivisionEvents): - try: - recall = event.tp_division_count / ( - event.tp_division_count + event.fn_division_count - ) - except ZeroDivisionError: - recall = 0 - - try: - precision = event.tp_division_count / ( - event.tp_division_count + event.fp_division_count - ) - except ZeroDivisionError: - precision = 0 - - try: - f1 = 2 * (recall * precision) / (recall + precision) - except ZeroDivisionError: - f1 = 0 - - try: - mbc = event.tp_division_count / ( - event.tp_division_count + event.fn_division_count + event.fp_division_count - ) - except ZeroDivisionError: - mbc = 0 - - return { - "Division Recall": recall, - "Division Precision": precision, - "Division F1": f1, - "Mitotic Branching Correctness": mbc, - **event.count_dict, - } - - class DivisionMetrics(Metric): needs_one_to_one = True def __init__(self, matched_data, frame_buffer=(0,)): - """Classify division events and provide summary metrics + """Classify division events and provide the following summary metrics + + - Division Recall + - Division Precision + - Divison F1 Score + - Mitotic Branching Correctness: TP / (TP + FP + FN) as defined by Ulicna, K., + Vallardi, G., Charras, G. & Lowe, A. R. Automated deep lineage tree analysis + using a Bayesian single cell tracking approach. Frontiers in Computer Science + 3, 734559 (2021). Args: matched_data (Matched): Matched object for set of GT and Pred data @@ -99,282 +68,58 @@ def compute(self): Returns: dict: Returns a nested dictionary with one dictionary per frame buffer value """ - events = _evaluate_division_events( - self.data.gt_data.tracking_graph, - self.data.pred_data.tracking_graph, - self.data.mapping, + div_annotations = _evaluate_division_events( + self.data, frame_buffer=self.frame_buffer, ) return { - f"Frame Buffer {event.frame_buffer}": _calculate_metrics(event) - for event in events + f"Frame Buffer {fb}": self._calculate_metrics( + matched_data.gt_data.tracking_graph, + matched_data.pred_data.tracking_graph, + ) + for fb, matched_data in div_annotations.items() } + def _calculate_metrics(self, g_gt, g_pred): + tp_division_count = len( + g_gt.get_nodes_with_attribute(NodeAttr.TP_DIV, lambda x: x) + ) + fn_division_count = len( + g_gt.get_nodes_with_attribute(NodeAttr.FN_DIV, lambda x: x) + ) + fp_division_count = len( + g_pred.get_nodes_with_attribute(NodeAttr.FP_DIV, lambda x: x) + ) -def _classify_divisions(G_gt, G_pred, mapper): - """Identify each division as a true positive, false positive or false negative - - This function only works on node mappers that are one-to-one - - Args: - G_gt (TrackingGraph): `TrackingGraph` of GT data - G_pred (TrackingGraph): `TrackingGraph` of pred data - mapper ([(gt_node, pred_node)]): List of tuples with pairs of gt and pred nodes - - Raises: - TypeError: G_gt and G_pred must be TrackingGraph objects - ValueError: mapper must contain a one-to-one mapping of nodes - - Returns: - DivisionEvents: Counts of gt_divisions, tp_divisions, fp_divisions and fn_divisions - TrackingGraph: G_gt with division annotations - TrackingGraph: G_pred with division annotations - """ - if not isinstance(G_gt, TrackingGraph) or not isinstance(G_pred, TrackingGraph): - raise TypeError("G_gt and G_pred must be TrackingGraph objects") - - # Check that mapper is one to one - if len(mapper) != len({pair[0] for pair in mapper}) or len(mapper) != len( - {pair[1] for pair in mapper} - ): - raise ValueError("Mapping must be one-to-one") - - def _find_gt_node_matches(gt_node): - match = find_gt_node_matches(mapper, gt_node) - if len(match) > 0: - return match[0] - - def _find_pred_node_matches(pred_node): - match = find_pred_node_matches(mapper, pred_node) - if len(match) > 0: - return match[0] - - # Collect list of divisions - div_gt = G_gt.get_divisions() - div_pred = G_pred.get_divisions() - - counts = DivisionEvents() - counts.gt_divisions.extend(div_gt) - - for gt_node in div_gt: - # Find possible matching nodes - pred_node = _find_gt_node_matches(gt_node) - # No matching node so division missed - if pred_node is None: - counts.fn_divisions.append(gt_node) - G_gt.set_node_attribute(gt_node, "is_fn_division", True) - # Check if the division has the corret daughters - else: - succ_gt = G_gt.get_succs(gt_node) - # Map pred succ nodes onto gt, unmapped nodes will return as None - succ_pred = [ - _find_pred_node_matches(n) for n in G_pred.get_succs(pred_node) - ] - - # If daughters are same, division is correct - if Counter(succ_gt) == Counter(succ_pred): - counts.tp_divisions.append(gt_node) - G_gt.set_node_attribute(gt_node, "is_tp_division", True) - G_pred.set_node_attribute(pred_node, "is_tp_division", True) - # If daughters are at all mismatched, division is false negative - else: - counts.fn_divisions.append(gt_node) - G_gt.set_node_attribute(gt_node, "is_fn_division", True) - - # Remove res division to record that we have classified it - if pred_node in div_pred: - div_pred.remove(pred_node) - - # Any remaining pred divisions are false positives - counts.fp_divisions.extend(div_pred) - G_pred.set_node_attribute(div_pred, "is_fp_division", True) - - return counts, G_gt, G_pred - - -def _get_pred_by_t(G, node, delta_frames): - """For a given graph and node, traverses back by predecessor until target_frame - - Args: - G (TrackingGraph): TrackingGraph to search on - node (hashable): Key of starting node - target_frame (int): Frame of the predecessor target node - - Raises: - ValueError: Cannot operate on graphs with merges - - Returns: - hashable: Node key of predecessor in target frame - """ - for _ in range(delta_frames): - nodes = G.get_preds(node) - # Exit if there are no predecessors - if len(nodes) == 0: - return None - # Fail if finding merges - elif len(nodes) > 1: - raise ValueError("Cannot operate on graphs with merges") - node = nodes[0] - - return node - - -def _get_succ_by_t(G, node, delta_frames): - """For a given node, find the successors after delta frames - - If a division event is discovered, returns None - - Args: - G (TrackingGraph): TrackingGraph to search on - node (hashable): Key of starting node - target_frame (int): Frame of the successor target node - - Returns: - hashable: Node id of successor - """ - for _ in range(delta_frames): - nodes = G.get_succs(node) - # Exit if there are no successors another division - if len(nodes) == 0 or len(nodes) >= 2: - return None - node = nodes[0] - - return node - - -def _correct_shifted_divisions(G_gt, G_pred, mapper, n_frames=1): - """Allows for divisions to occur within a frame buffer and still be correct - - This implementation asserts that the parent lineages and daughter lineages must match. - Matching is determined based on the provided mapper - Does not support merges - - Args: - G_gt (TrackingGraph): GT tracking graph with FN division annotations - G_pred (TrackningGraph): Pred tracking graph with FP division annotations - mapper ([(gt_node, pred_node)]): List of tuples with pairs of gt and pred nodes - Must be a one-to-one mapping - n_frames (int): Number of frames to include in the frame buffer - - Returns: - DivisionEvents: Corrected counts of gt_divisions, tp_divisions, fp_divisions - and fn_divisions - """ - - if not isinstance(G_gt, TrackingGraph) or not isinstance(G_pred, TrackingGraph): - raise TypeError("G_gt and G_pred must be TrackingGraph objects") - - # Check that mapper is one to one - if len(mapper) != len({pair[0] for pair in mapper}) or len(mapper) != len( - {pair[1] for pair in mapper} - ): - raise ValueError("Mapping must be one-to-one") - - fp_divs = G_pred.get_nodes_with_attribute("is_fp_division") - fn_divs = G_gt.get_nodes_with_attribute("is_fn_division") - - # Create counts object for collecting error classifications - counts = DivisionEvents( - gt_divisions=G_gt.get_divisions(), - fp_divisions=fp_divs, - fn_divisions=fn_divs, - tp_divisions=G_gt.get_nodes_with_attribute("is_tp_division"), - frame_buffer=n_frames, - ) - - # Compare all pairs of fp and fn - for fp_node, fn_node in itertools.product(fp_divs, fn_divs): - correct = False - t_fp = G_pred.graph.nodes[fp_node][G_pred.frame_key] - t_fn = G_gt.graph.nodes[fn_node][G_gt.frame_key] - - # Move on if nodes are not within frame buffer or within same frame - if abs(t_fp - t_fn) > n_frames or t_fp == t_fn: - continue - - # False positive in pred occurs before false negative in gt - if t_fp < t_fn: - # Check if fp node matches prececessor of fn - fn_pred = _get_pred_by_t(G_gt, fn_node, t_fn - t_fp) - # Check if the match exists - if (fn_pred, fp_node) not in mapper: - # Match does not exist so divisions cannot match - continue - - # Check if daughters match - fp_succ = [ - _get_succ_by_t(G_pred, node, t_fn - t_fp) - for node in G_pred.get_succs(fp_node) - ] - fn_succ = G_gt.get_succs(fn_node) - if Counter(fp_succ) != Counter(fn_succ): - # Daughters don't match so division cannot match - continue - - # At this point daughters and parents match so division is correct - correct = True - # False negative in gt occurs before false positive in pred - else: - # Check if fp node matches fn predecessor - fp_pred = _get_pred_by_t(G_pred, fp_node, t_fp - t_fn) - # Check if match exists - if (fn_node, fp_pred) not in mapper: - # Match does not exist so divisions cannot match - continue - - # Check if daughters match - fn_succ = [ - _get_succ_by_t(G_gt, node, t_fp - t_fn) - for node in G_gt.get_succs(fn_node) - ] - fp_succ = G_pred.get_succs(fp_node) - if Counter(fp_succ) != Counter(fn_succ): - # Daughters don't match so division cannot match - continue - - # At this point daughters and parents match so division is correct - correct = True - - if correct: - # Remove node from error lists - counts.fp_divisions.remove(fp_node) - counts.fn_divisions.remove(fn_node) - - # Add gt node to tp list - counts.tp_divisions.append(fn_node) - - return counts - - -def _evaluate_division_events(G_gt, G_pred, mapper, frame_buffer=(0)): - """Classify division errors and correct shifted divisions according to frame_buffer - One DivisionEvent object will be returned for each value in frame_buffer - - Args: - G_gt (TrackingGraph): TrackingGraph of GT data - G_pred (TrackingGraph): TrackingGraph of pred data - mapper ([(gt_node, pred_node)]): List of tuples with pairs of gt and pred nodes - frame_buffer (tuple, optional): Tuple of integers. Value used as n_frames - to tolerate in correct_shifted_divisions. Defaults to (0). - - Returns: - list[DivisionEvents]: List of one DivisionEvent object per value in frame_buffer - """ - - events = [] + try: + recall = tp_division_count / (tp_division_count + fn_division_count) + except ZeroDivisionError: + recall = 0 - # Baseline division classification - event, G_gt, G_pred = _classify_divisions(G_gt, G_pred, mapper) - events.append(event) + try: + precision = tp_division_count / (tp_division_count + fp_division_count) + except ZeroDivisionError: + precision = 0 - # Correct shifted divisions for each nonzero value in frame_buffer - for delta in frame_buffer: - # Skip 0 because we used that in baseline classification - if delta == 0: - continue + try: + f1 = 2 * (recall * precision) / (recall + precision) + except ZeroDivisionError: + f1 = 0 - event = _correct_shifted_divisions(G_gt, G_pred, mapper, n_frames=delta) - events.append(event) + try: + mbc = tp_division_count / ( + tp_division_count + fn_division_count + fp_division_count + ) + except ZeroDivisionError: + mbc = 0 - return events + return { + "Division Recall": recall, + "Division Precision": precision, + "Division F1": f1, + "Mitotic Branching Correctness": mbc, + "True Positive Divisions": tp_division_count, + "False Positive Divisions": fp_division_count, + "False Negative Divisions": fn_division_count, + } diff --git a/src/traccuracy/track_errors/_division_events.py b/src/traccuracy/track_errors/_division_events.py deleted file mode 100644 index 11fbcdfd..00000000 --- a/src/traccuracy/track_errors/_division_events.py +++ /dev/null @@ -1,81 +0,0 @@ -class DivisionEvents: - """A class to hold counts of tracking events or errors. - - Counts are generated based on the output of a matching - (gt TrackingGraph, predicted TrackingGraph, matched nodes). - - This class provides a set of standard events that our library - keeps track of. - To add custom fields, you can create a subclass of this class. - - Fields: - gt_divisions (list): The number of divisions in the ground truth graph. - Defaults to None. - fp_divisions (list): The number of divisions in the predicted graph - that are not matched to a division in the ground truth graph. - Defaults to None. - fn_divisions (list): The number of divisions in the ground truth graph - that are not matched to a division in the predicted graph. - Defaults to None. - tp_divisions (list): The number of divisions in the ground truth graph - that are correctly identified in the predicted graph. - Defaults to None. - frame_buffer (int): A predicted division can be matched with a ground - truth division within this many frames. Defaults to 0. - - """ - - def __init__( - self, - gt_divisions=None, - tp_divisions=None, - fp_divisions=None, - fn_divisions=None, - frame_buffer=0, - ): - if isinstance(gt_divisions, list): - self.gt_divisions = gt_divisions - else: - self.gt_divisions = [] - - if isinstance(tp_divisions, list): - self.tp_divisions = tp_divisions - else: - self.tp_divisions = [] - - if isinstance(fp_divisions, list): - self.fp_divisions = fp_divisions - else: - self.fp_divisions = [] - - if isinstance(fn_divisions, list): - self.fn_divisions = fn_divisions - else: - self.fn_divisions = [] - - self.frame_buffer = frame_buffer - - @property - def gt_division_count(self): - return len(self.gt_divisions) - - @property - def tp_division_count(self): - return len(self.tp_divisions) - - @property - def fp_division_count(self): - return len(self.fp_divisions) - - @property - def fn_division_count(self): - return len(self.fn_divisions) - - @property - def count_dict(self): - return { - "Total GT Divisions": self.gt_division_count, - "True Positive Divisions": self.tp_division_count, - "False Positive Divisions": self.fp_division_count, - "False Negative Divisions": self.fn_division_count, - } diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py new file mode 100644 index 00000000..9e952aee --- /dev/null +++ b/src/traccuracy/track_errors/divisions.py @@ -0,0 +1,271 @@ +import copy +import itertools +import logging +from collections import Counter +from typing import TYPE_CHECKING + +from traccuracy._tracking_graph import NodeAttr +from traccuracy._utils import find_gt_node_matches, find_pred_node_matches + +if TYPE_CHECKING: + from traccuracy.matchers._matched import Matched + +logger = logging.getLogger(__name__) + + +def _classify_divisions(matched_data: "Matched"): + """Identify each division as a true positive, false positive or false negative + + This function only works on node mappers that are one-to-one + + Graphs are annotated in place and therefore not returned + + Args: + matched_data (Matched): Matched data object containing gt and pred graphs + with their associated mapping + + Raises: + ValueError: mapper must contain a one-to-one mapping of nodes + """ + g_gt = matched_data.gt_data.tracking_graph + g_pred = matched_data.pred_data.tracking_graph + mapper = matched_data.mapping + + if g_gt.division_annotations and g_pred.division_annotations: + logger.info("Divison annotations already present. Skipping graph annotation.") + return + + # Check that mapper is one to one + if len(mapper) != len({pair[0] for pair in mapper}) or len(mapper) != len( + {pair[1] for pair in mapper} + ): + raise ValueError("Mapping must be one-to-one") + + def _find_gt_node_matches(gt_node): + match = find_gt_node_matches(mapper, gt_node) + if len(match) > 0: + return match[0] + + def _find_pred_node_matches(pred_node): + match = find_pred_node_matches(mapper, pred_node) + if len(match) > 0: + return match[0] + + # Collect list of divisions + div_gt = g_gt.get_divisions() + div_pred = g_pred.get_divisions() + + for gt_node in div_gt: + # Find possible matching nodes + pred_node = _find_gt_node_matches(gt_node) + # No matching node so division missed + if pred_node is None: + g_gt.set_node_attribute(gt_node, NodeAttr.FN_DIV, True) + # Check if the division has the corret daughters + else: + succ_gt = g_gt.get_succs(gt_node) + # Map pred succ nodes onto gt, unmapped nodes will return as None + succ_pred = [ + _find_pred_node_matches(n) for n in g_pred.get_succs(pred_node) + ] + + # If daughters are same, division is correct + if Counter(succ_gt) == Counter(succ_pred): + g_gt.set_node_attribute(gt_node, NodeAttr.TP_DIV, True) + g_pred.set_node_attribute(pred_node, NodeAttr.TP_DIV, True) + # If daughters are at all mismatched, division is false negative + else: + g_gt.set_node_attribute(gt_node, NodeAttr.FN_DIV, True) + + # Remove res division to record that we have classified it + if pred_node in div_pred: + div_pred.remove(pred_node) + + # Any remaining pred divisions are false positives + g_pred.set_node_attribute(div_pred, NodeAttr.FP_DIV, True) + + # Set division annotation flag + g_gt.division_annotations = True + g_pred.division_annotations = True + + +def _get_pred_by_t(g, node, delta_frames): + """For a given graph and node, traverses back by predecessor until target_frame + + Args: + G (TrackingGraph): TrackingGraph to search on + node (hashable): Key of starting node + target_frame (int): Frame of the predecessor target node + + Raises: + ValueError: Cannot operate on graphs with merges + + Returns: + hashable: Node key of predecessor in target frame + """ + for _ in range(delta_frames): + nodes = g.get_preds(node) + # Exit if there are no predecessors + if len(nodes) == 0: + return None + # Fail if finding merges + elif len(nodes) > 1: + raise ValueError("Cannot operate on graphs with merges") + node = nodes[0] + + return node + + +def _get_succ_by_t(g, node, delta_frames): + """For a given node, find the successors after delta frames + + If a division event is discovered, returns None + + Args: + G (TrackingGraph): TrackingGraph to search on + node (hashable): Key of starting node + target_frame (int): Frame of the successor target node + + Returns: + hashable: Node id of successor + """ + for _ in range(delta_frames): + nodes = g.get_succs(node) + # Exit if there are no successors another division + if len(nodes) == 0 or len(nodes) >= 2: + return None + node = nodes[0] + + return node + + +def _correct_shifted_divisions(matched_data: "Matched", n_frames=1): + """Allows for divisions to occur within a frame buffer and still be correct + + This implementation asserts that the parent lineages and daughter lineages must match. + Matching is determined based on the provided mapper + Does not support merges + + Copies matched_data before modifying node annotations and returns the new versions + + Args: + matched_data (Matched): Matched data object containing gt and pred graphs + with their associated mapping + n_frames (int): Number of frames to include in the frame buffer + + Returns: + Matched: copy of matched_data with corrected division annotations + """ + # Create copies of the graphs to modify during correction of divisions + new_matched = copy.deepcopy(matched_data) + g_gt = new_matched.gt_data.tracking_graph + g_pred = new_matched.pred_data.tracking_graph + mapper = new_matched.mapping + + # Check that mapper is one to one + if len(mapper) != len({pair[0] for pair in mapper}) or len(mapper) != len( + {pair[1] for pair in mapper} + ): + raise ValueError("Mapping must be one-to-one") + + fp_divs = g_pred.get_nodes_with_attribute(NodeAttr.FP_DIV) + fn_divs = g_gt.get_nodes_with_attribute(NodeAttr.FN_DIV) + + # Compare all pairs of fp and fn + for fp_node, fn_node in itertools.product(fp_divs, fn_divs): + correct = False + t_fp = g_pred.graph.nodes[fp_node][g_pred.frame_key] + t_fn = g_gt.graph.nodes[fn_node][g_gt.frame_key] + + # Move on if nodes are not within frame buffer or within same frame + if abs(t_fp - t_fn) > n_frames or t_fp == t_fn: + continue + + # False positive in pred occurs before false negative in gt + if t_fp < t_fn: + # Check if fp node matches prececessor of fn + fn_pred = _get_pred_by_t(g_gt, fn_node, t_fn - t_fp) + # Check if the match exists + if (fn_pred, fp_node) not in mapper: + # Match does not exist so divisions cannot match + continue + + # Check if daughters match + fp_succ = [ + _get_succ_by_t(g_pred, node, t_fn - t_fp) + for node in g_pred.get_succs(fp_node) + ] + fn_succ = g_gt.get_succs(fn_node) + if Counter(fp_succ) != Counter(fn_succ): + # Daughters don't match so division cannot match + continue + + # At this point daughters and parents match so division is correct + correct = True + # False negative in gt occurs before false positive in pred + else: + # Check if fp node matches fn predecessor + fp_pred = _get_pred_by_t(g_pred, fp_node, t_fp - t_fn) + # Check if match exists + if (fn_node, fp_pred) not in mapper: + # Match does not exist so divisions cannot match + continue + + # Check if daughters match + fn_succ = [ + _get_succ_by_t(g_gt, node, t_fp - t_fn) + for node in g_gt.get_succs(fn_node) + ] + fp_succ = g_pred.get_succs(fp_node) + if Counter(fp_succ) != Counter(fn_succ): + # Daughters don't match so division cannot match + continue + + # At this point daughters and parents match so division is correct + correct = True + + if correct: + # Remove error annotations from pred graph + g_pred.set_node_attribute(fp_node, NodeAttr.FP_DIV, False) + g_gt.set_node_attribute(fn_node, NodeAttr.FN_DIV, False) + + # Add the tp divisions annotations + g_gt.set_node_attribute(fn_node, NodeAttr.TP_DIV, True) + g_pred.set_node_attribute(fp_node, NodeAttr.TP_DIV, True) + + return new_matched + + +def _evaluate_division_events(matched_data: "Matched", frame_buffer=(0)): + """Classify division errors and correct shifted divisions according to frame_buffer + + Note: A copy of matched_data will be created for each frame_buffer other than 0. + For large graphs, creating copies may introduce memory problems. + + Args: + matched_data (Matched): Matched data object containing gt and pred graphs + with their associated mapping + frame_buffer (tuple, optional): Tuple of integers. Value used as n_frames + to tolerate in correct_shifted_divisions. Defaults to (0). + + Returns: + dict {frame_buffer: matched_data}: A dictionary where each key corresponds to a frame + buffer with a tuple of the corresponding ground truth and predicted TrackingGraphs + after division annotations and correction by frame buffer + """ + div_annotations = {} + + # Baseline division classification + _classify_divisions(matched_data) + div_annotations[0] = matched_data + + # Correct shifted divisions for each nonzero value in frame_buffer + for delta in frame_buffer: + # Skip 0 because we used that in baseline classification + if delta == 0: + continue + + corrected_matched = _correct_shifted_divisions(matched_data, n_frames=delta) + div_annotations[delta] = corrected_matched + + return div_annotations diff --git a/tests/metrics/test_divisions.py b/tests/metrics/test_divisions.py index e4f3a368..7f3f73af 100644 --- a/tests/metrics/test_divisions.py +++ b/tests/metrics/test_divisions.py @@ -1,264 +1,8 @@ -import networkx as nx -import pytest from traccuracy import TrackingData, TrackingGraph from traccuracy.matchers._matched import Matched -from traccuracy.metrics._divisions import ( - DivisionMetrics, - _classify_divisions, - _correct_shifted_divisions, - _evaluate_division_events, - _get_pred_by_t, - _get_succ_by_t, -) +from traccuracy.metrics._divisions import DivisionMetrics - -@pytest.fixture -def G(): - """ - 1_0 -- 1_1 -- 1_2 -- 1_3 - 3_3 - 2_0 -- 2_1 -- 2_2 -< - 4_3 - """ - G = nx.DiGraph() - G.add_edge("1_0", "1_1") - G.add_edge("1_1", "1_2") - G.add_edge("1_2", "1_3") - - G.add_edge("2_0", "2_1") - G.add_edge("2_1", "2_2") - - # node 2 divides into 3 and 4 in frame 3 - G.add_edge("2_2", "3_3") - G.add_edge("2_2", "4_3") - - # Set node attributes - attrs = {} - for node in G.nodes: - attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} - nx.set_node_attributes(G, attrs) - - return G - - -def test_classify_divisions_tp(G): - # Define mapper assuming all nodes match - mapper = [(n, n) for n in G.nodes] - - # Test true positive - counts, G_gt, G_pred = _classify_divisions( - TrackingGraph(G), TrackingGraph(G), mapper - ) - assert len(counts.tp_divisions) == 1 - assert len(counts.fn_divisions) == 0 - assert len(counts.fp_divisions) == 0 - assert "is_tp_division" in G_gt.nodes()["2_2"] - assert "is_tp_division" in G_pred.nodes()["2_2"] - - -def test_classify_divisions_fp(G): - """ - 5_3 - 1_0 -- 1_1 -- 1_2 -< - 1_3 - 3_3 - 2_0 -- 2_1 -- 2_2 -< - 4_3 - """ - H = G.copy() - # Add false positive division edge - H.add_edge("1_2", "5_3") - nx.set_node_attributes(H, {"5_3": {"t": 3, "x": 0, "y": 0}}) - mapper = [(n, n) for n in H.nodes] - - counts, G_gt, G_pred = _classify_divisions( - TrackingGraph(G), TrackingGraph(H), mapper - ) - assert len(counts.fp_divisions) == 1 - assert len(counts.tp_divisions) == 1 - assert len(counts.fn_divisions) == 0 - assert "is_fp_division" in G_pred.nodes()["1_2"] - - -def test_classify_divisions_fn(G): - """ - 1_0 -- 1_1 -- 1_2 -- 1_3 - 2_0 -- 2_1 -- 2_2 - """ - # Remove daughters to create false negative - H = G.copy() - H.remove_nodes_from(["3_3", "4_3"]) - mapper = [(n, n) for n in H.nodes] - - counts, G_gt, G_pred = _classify_divisions( - TrackingGraph(G), TrackingGraph(H), mapper - ) - assert len(counts.fp_divisions) == 0 - assert len(counts.tp_divisions) == 0 - assert len(counts.fn_divisions) == 1 - assert "is_fn_division" in G_gt.nodes()["2_2"] - - -@pytest.fixture -def straight_graph(): - G = nx.DiGraph() - for t in range(2, 10): - G.add_edge(f"1_{t}", f"1_{t+1}") - - # Set node attributes - attrs = {} - for node in G.nodes: - attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} - nx.set_node_attributes(G, attrs) - - return G - - -def test__get_pred_by_t(straight_graph): - # Linear graph with node id 1 from frame 2-10 - G = TrackingGraph(straight_graph) - - # Predecessor available - start_frame = 10 - target_frame = 5 - node = _get_pred_by_t(G, f"1_{start_frame}", start_frame - target_frame) - assert node == f"1_{target_frame}" - - # Predecessor does not exist - start_frame = 10 - target_frame = 1 - node = _get_pred_by_t(G, f"1_{start_frame}", start_frame - target_frame) - assert node is None - - -def get_division_graphs(): - """ - G1 - 2_4 - 1_0 -- 1_1 -- 1_2 -- 1_3 -< - 3_4 - G2 - 2_2 -- 2_3 -- 2_4 - 1_0 -- 1_1 -< - 3_2 -- 3_3 -- 3_4 - """ - - G1 = nx.DiGraph() - G1.add_edge("1_0", "1_1") - G1.add_edge("1_1", "1_2") - G1.add_edge("1_2", "1_3") - G1.add_edge("1_3", "2_4") - G1.add_edge("1_3", "3_4") - - attrs = {} - for node in G1.nodes: - attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} - nx.set_node_attributes(G1, attrs) - - G2 = nx.DiGraph() - G2.add_edge("1_0", "1_1") - # Divide to generate 2 lineage - G2.add_edge("1_1", "2_2") - G2.add_edge("2_2", "2_3") - G2.add_edge("2_3", "2_4") - # Divide to generate 3 lineage - G2.add_edge("1_1", "3_2") - G2.add_edge("3_2", "3_3") - G2.add_edge("3_3", "3_4") - - attrs = {} - for node in G2.nodes: - attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} - nx.set_node_attributes(G2, attrs) - - mapper = [("1_0", "1_0"), ("1_1", "1_1"), ("2_4", "2_4"), ("3_4", "3_4")] - - return G1, G2, mapper - - -def test__get_succ_by_t(): - _, G2, _ = get_division_graphs() - G2 = TrackingGraph(G2) - - # Find 2 frames forward correctly - start_node = "2_2" - delta_t = 2 - end_node = "2_4" - node = _get_succ_by_t(G2, start_node, delta_t) - assert node == end_node - - # 3 frames forward returns None - start_node = "2_2" - delta_t = 3 - end_node = None - node = _get_succ_by_t(G2, start_node, delta_t) - assert node == end_node - - -class Test_correct_shifted_divisions: - def test_no_change(self): - # Early division in gt - G_pred, G_gt, mapper = get_division_graphs() - G_gt.nodes["1_1"]["is_fn_division"] = True - G_pred.nodes["1_3"]["is_fp_division"] = True - - # buffer of 1, no change - counts = _correct_shifted_divisions( - TrackingGraph(G_gt), TrackingGraph(G_pred), mapper, n_frames=1 - ) - assert len(counts.fp_divisions) == 1 - assert len(counts.fn_divisions) == 1 - assert len(counts.tp_divisions) == 0 - - def test_fn_early(self): - # Early division in gt - G_pred, G_gt, mapper = get_division_graphs() - G_gt.nodes["1_1"]["is_fn_division"] = True - G_pred.nodes["1_3"]["is_fp_division"] = True - - # buffer of 3, corrections - counts = _correct_shifted_divisions( - TrackingGraph(G_gt), TrackingGraph(G_pred), mapper, n_frames=3 - ) - assert len(counts.tp_divisions) == 1 - assert len(counts.fp_divisions) == 0 - assert len(counts.fn_divisions) == 0 - - def test_fp_early(self): - # Early division in pred - G_gt, G_pred, mapper = get_division_graphs() - G_pred.nodes["1_1"]["is_fp_division"] = True - G_gt.nodes["1_3"]["is_fn_division"] = True - - # buffer of 3, corrections - counts = _correct_shifted_divisions( - TrackingGraph(G_gt), TrackingGraph(G_pred), mapper, n_frames=3 - ) - assert len(counts.tp_divisions) == 1 - assert len(counts.fp_divisions) == 0 - assert len(counts.fn_divisions) == 0 - - -def test_evaluate_division_events(): - G_gt, G_pred, mapper = get_division_graphs() - frame_buffer = (0, 1, 2) - - events = _evaluate_division_events( - TrackingGraph(G_gt), TrackingGraph(G_pred), mapper, frame_buffer=frame_buffer - ) - - for e in events: - assert e.frame_buffer in frame_buffer - if e.frame_buffer in (0, 1): - # No corrections - assert len(e.tp_divisions) == 0 - assert len(e.fp_divisions) == 1 - assert len(e.fn_divisions) == 1 - else: - # Correction - assert len(e.tp_divisions) == 1 - assert len(e.fp_divisions) == 0 - assert len(e.fn_divisions) == 0 +from ..test_utils import get_division_graphs class DummyMatched(Matched): @@ -271,10 +15,10 @@ def compute_mapping(self): def test_DivisionMetrics(): - G_gt, G_pred, mapper = get_division_graphs() + g_gt, g_pred, mapper = get_division_graphs() matched = DummyMatched( - TrackingData(TrackingGraph(G_gt)), - TrackingData(TrackingGraph(G_pred)), + TrackingData(TrackingGraph(g_gt)), + TrackingData(TrackingGraph(g_pred)), mapper=mapper, ) frame_buffer = (0, 1, 2) diff --git a/tests/test_utils.py b/tests/test_utils.py index 804be766..6bb03cb8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -103,3 +103,48 @@ def get_movie_with_graph(ndims=3, n_frames=3, n_labels=3): nx.set_node_attributes(G, attrs) return TrackingGraph(G), movie + + +def get_division_graphs(): + """ + G1 + 2_4 + 1_0 -- 1_1 -- 1_2 -- 1_3 -< + 3_4 + G2 + 2_2 -- 2_3 -- 2_4 + 1_0 -- 1_1 -< + 3_2 -- 3_3 -- 3_4 + """ + + G1 = nx.DiGraph() + G1.add_edge("1_0", "1_1") + G1.add_edge("1_1", "1_2") + G1.add_edge("1_2", "1_3") + G1.add_edge("1_3", "2_4") + G1.add_edge("1_3", "3_4") + + attrs = {} + for node in G1.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(G1, attrs) + + G2 = nx.DiGraph() + G2.add_edge("1_0", "1_1") + # Divide to generate 2 lineage + G2.add_edge("1_1", "2_2") + G2.add_edge("2_2", "2_3") + G2.add_edge("2_3", "2_4") + # Divide to generate 3 lineage + G2.add_edge("1_1", "3_2") + G2.add_edge("3_2", "3_3") + G2.add_edge("3_3", "3_4") + + attrs = {} + for node in G2.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(G2, attrs) + + mapper = [("1_0", "1_0"), ("1_1", "1_1"), ("2_4", "2_4"), ("3_4", "3_4")] + + return G1, G2, mapper diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py new file mode 100644 index 00000000..926a27fd --- /dev/null +++ b/tests/track_errors/test_divisions.py @@ -0,0 +1,255 @@ +import networkx as nx +import numpy as np +import pytest +from traccuracy import NodeAttr, TrackingData, TrackingGraph +from traccuracy.matchers._matched import Matched +from traccuracy.track_errors.divisions import ( + _classify_divisions, + _correct_shifted_divisions, + _evaluate_division_events, + _get_pred_by_t, + _get_succ_by_t, +) + +from ..test_utils import get_division_graphs + + +class DummyMatched(Matched): + def compute_mapping(self): + return [] + + +@pytest.fixture +def g(): + """ + 1_0 -- 1_1 -- 1_2 -- 1_3 + 3_3 + 2_0 -- 2_1 -- 2_2 -< + 4_3 + """ + g = nx.DiGraph() + g.add_edge("1_0", "1_1") + g.add_edge("1_1", "1_2") + g.add_edge("1_2", "1_3") + + g.add_edge("2_0", "2_1") + g.add_edge("2_1", "2_2") + + # node 2 divides into 3 and 4 in frame 3 + g.add_edge("2_2", "3_3") + g.add_edge("2_2", "4_3") + + # Set node attributes + attrs = {} + for node in g.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(g, attrs) + + return g + + +def test_classify_divisions_tp(g): + # Define mapper assuming all nodes match + mapper = [(n, n) for n in g.nodes] + g_gt = TrackingGraph(g.copy()) + g_pred = TrackingGraph(g.copy()) + matched_data = DummyMatched(TrackingData(g_gt), TrackingData(g_pred)) + matched_data.mapping = mapper + + # Test true positive + _classify_divisions(matched_data) + + assert len(g_gt.get_nodes_with_attribute(NodeAttr.FN_DIV, lambda x: x is True)) == 0 + assert ( + len(g_pred.get_nodes_with_attribute(NodeAttr.FP_DIV, lambda x: x is True)) == 0 + ) + assert NodeAttr.TP_DIV in g_gt.nodes()["2_2"] + assert NodeAttr.TP_DIV in g_pred.nodes()["2_2"] + + # Check division flag + assert g_gt.division_annotations + assert g_pred.division_annotations + + +def test_classify_divisions_fp(g): + """ + 5_3 + 1_0 -- 1_1 -- 1_2 -< + 1_3 + 3_3 + 2_0 -- 2_1 -- 2_2 -< + 4_3 + """ + h = g.copy() + # Add false positive division edge + h.add_edge("1_2", "5_3") + nx.set_node_attributes(h, {"5_3": {"t": 3, "x": 0, "y": 0}}) + mapper = [(n, n) for n in h.nodes] + + g_gt = TrackingGraph(g) + g_pred = TrackingGraph(h) + matched_data = DummyMatched(TrackingData(g_gt), TrackingData(g_pred)) + matched_data.mapping = mapper + + _classify_divisions(matched_data) + + assert len(g_gt.get_nodes_with_attribute(NodeAttr.FN_DIV, lambda x: x is True)) == 0 + assert NodeAttr.FP_DIV in g_pred.nodes()["1_2"] + assert NodeAttr.TP_DIV in g_gt.nodes()["2_2"] + assert NodeAttr.TP_DIV in g_pred.nodes()["2_2"] + + +def test_classify_divisions_fn(g): + """ + 1_0 -- 1_1 -- 1_2 -- 1_3 + 2_0 -- 2_1 -- 2_2 + """ + # Remove daughters to create false negative + h = g.copy() + h.remove_nodes_from(["3_3", "4_3"]) + mapper = [(n, n) for n in h.nodes] + + g_gt = TrackingGraph(g) + g_pred = TrackingGraph(h) + matched_data = DummyMatched(TrackingData(g_gt), TrackingData(g_pred)) + matched_data.mapping = mapper + + _classify_divisions(matched_data) + + assert ( + len(g_pred.get_nodes_with_attribute(NodeAttr.FP_DIV, lambda x: x is True)) == 0 + ) + assert len(g_gt.get_nodes_with_attribute(NodeAttr.TP_DIV, lambda x: x is True)) == 0 + assert NodeAttr.FN_DIV in g_gt.nodes()["2_2"] + + +@pytest.fixture +def straight_graph(): + g = nx.DiGraph() + for t in range(2, 10): + g.add_edge(f"1_{t}", f"1_{t+1}") + + # Set node attributes + attrs = {} + for node in g.nodes: + attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} + nx.set_node_attributes(g, attrs) + + return g + + +def test__get_pred_by_t(straight_graph): + # Linear graph with node id 1 from frame 2-10 + g = TrackingGraph(straight_graph) + + # Predecessor available + start_frame = 10 + target_frame = 5 + node = _get_pred_by_t(g, f"1_{start_frame}", start_frame - target_frame) + assert node == f"1_{target_frame}" + + # Predecessor does not exist + start_frame = 10 + target_frame = 1 + node = _get_pred_by_t(g, f"1_{start_frame}", start_frame - target_frame) + assert node is None + + +def test__get_succ_by_t(): + _, g2, _ = get_division_graphs() + g2 = TrackingGraph(g2) + + # Find 2 frames forward correctly + start_node = "2_2" + delta_t = 2 + end_node = "2_4" + node = _get_succ_by_t(g2, start_node, delta_t) + assert node == end_node + + # 3 frames forward returns None + start_node = "2_2" + delta_t = 3 + end_node = None + node = _get_succ_by_t(g2, start_node, delta_t) + assert node == end_node + + +class Test_correct_shifted_divisions: + def test_no_change(self): + # Early division in gt + g_pred, g_gt, mapper = get_division_graphs() + g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True + g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True + + matched_data = DummyMatched( + TrackingData(TrackingGraph(g_gt)), TrackingData(TrackingGraph(g_pred)) + ) + matched_data.mapping = mapper + + # buffer of 1, no change + new_matched = _correct_shifted_divisions(matched_data, n_frames=1) + ng_pred = new_matched.pred_data.tracking_graph + ng_gt = new_matched.gt_data.tracking_graph + + assert ng_pred.nodes()["1_3"][NodeAttr.FP_DIV] is True + assert ng_gt.nodes()["1_1"][NodeAttr.FN_DIV] is True + assert ( + len(ng_gt.get_nodes_with_attribute(NodeAttr.TP_DIV, lambda x: x is True)) + == 0 + ) + + def test_fn_early(self): + # Early division in gt + g_pred, g_gt, mapper = get_division_graphs() + g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True + g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True + + matched_data = DummyMatched( + TrackingData(TrackingGraph(g_gt)), TrackingData(TrackingGraph(g_pred)) + ) + matched_data.mapping = mapper + + # buffer of 3, corrections + new_matched = _correct_shifted_divisions(matched_data, n_frames=3) + ng_pred = new_matched.pred_data.tracking_graph + ng_gt = new_matched.gt_data.tracking_graph + + assert ng_pred.nodes()["1_3"][NodeAttr.FP_DIV] is False + assert ng_gt.nodes()["1_1"][NodeAttr.FN_DIV] is False + assert ng_pred.nodes()["1_3"][NodeAttr.TP_DIV] is True + assert ng_gt.nodes()["1_1"][NodeAttr.TP_DIV] is True + + def test_fp_early(self): + # Early division in pred + g_gt, g_pred, mapper = get_division_graphs() + g_pred.nodes["1_1"][NodeAttr.FP_DIV] = True + g_gt.nodes["1_3"][NodeAttr.FN_DIV] = True + + matched_data = DummyMatched( + TrackingData(TrackingGraph(g_gt)), TrackingData(TrackingGraph(g_pred)) + ) + matched_data.mapping = mapper + + # buffer of 3, corrections + new_matched = _correct_shifted_divisions(matched_data, n_frames=3) + ng_pred = new_matched.pred_data.tracking_graph + ng_gt = new_matched.gt_data.tracking_graph + + assert ng_pred.nodes()["1_1"][NodeAttr.FP_DIV] is False + assert ng_gt.nodes()["1_3"][NodeAttr.FN_DIV] is False + assert ng_pred.nodes()["1_1"][NodeAttr.TP_DIV] is True + assert ng_gt.nodes()["1_3"][NodeAttr.TP_DIV] is True + + +def test_evaluate_division_events(): + g_gt, g_pred, mapper = get_division_graphs() + frame_buffer = (0, 1, 2) + + matched_data = DummyMatched( + TrackingData(TrackingGraph(g_gt)), TrackingData(TrackingGraph(g_pred)) + ) + matched_data.mapping = mapper + + results = _evaluate_division_events(matched_data, frame_buffer=frame_buffer) + + assert np.all([isinstance(k, int) for k in results.keys()])