Skip to content

Commit

Permalink
Merge pull request #111 from Janelia-Trackathon-2023/prune-tracking-graph
Browse files Browse the repository at this point in the history

Prune tracking graph API
  • Loading branch information
cmalinmayor authored Jan 10, 2024
2 parents 6f6406e + f50eca2 commit 3d85480
Show file tree
Hide file tree
Showing 13 changed files with 407 additions and 546 deletions.
4 changes: 2 additions & 2 deletions src/traccuracy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
__version__ = "uninstalled"

from ._run_metrics import run_metrics
from ._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph
from ._tracking_graph import EdgeFlag, NodeFlag, TrackingGraph

__all__ = ["TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"]
__all__ = ["TrackingGraph", "run_metrics", "NodeFlag", "EdgeFlag"]
484 changes: 161 additions & 323 deletions src/traccuracy/_tracking_graph.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/traccuracy/matchers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def compute_mapping(
matched.matcher_info = self.info

# Report matching performance
total_gt = len(matched.gt_graph.nodes())
total_gt = len(matched.gt_graph.nodes)
matched_gt = len({m[0] for m in matched.mapping})
total_pred = len(matched.pred_graph.nodes())
total_pred = len(matched.pred_graph.nodes)
matched_pred = len({m[1] for m in matched.mapping})
logger.info(f"Matched {matched_gt} out of {total_gt} ground truth nodes.")
logger.info(f"Matched {matched_pred} out of {total_pred} predicted nodes.")
Expand Down
5 changes: 4 additions & 1 deletion src/traccuracy/matchers/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
traccuracy.matchers.Matched: Matched data object containing the CTC mapping
Raises:
ValueError: GT and pred segmentations must be the same shape
ValueError: if GT and pred segmentations are None or are not the same shape
"""
gt = gt_graph
pred = pred_graph
Expand All @@ -46,6 +46,9 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
G_gt, mask_gt = gt, gt.segmentation
G_pred, mask_pred = pred, pred.segmentation

if mask_gt is None or mask_pred is None:
raise ValueError("Segmentation is None, cannot perform matching")

if mask_gt.shape != mask_pred.shape:
raise ValueError("Segmentation shapes must match between gt and pred")

Expand Down
50 changes: 38 additions & 12 deletions src/traccuracy/matchers/_iou.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Hashable

import numpy as np
from tqdm import tqdm

Expand Down Expand Up @@ -46,6 +48,35 @@ def _match_nodes(gt, res, threshold=1):
return gtcells, rescells


def _construct_time_to_seg_id_map(
graph: TrackingGraph,
) -> dict[int, dict[Hashable, Hashable]]:
"""For each time frame in the graph, create a mapping from segmentation ids
(the ids in the segmentation array, stored in graph.label_key) to the
node ids (the ids of the TrackingGraph nodes).
Args:
graph(TrackingGraph): a tracking graph with a label_key on each node
Returns:
dict[int, dict[Hashable, Hashable]]: a dictionary from {time: {segmentation_id: node_id}}
Raises:
AssertionError: If two nodes in a time frame have the same segmentation_id
"""
time_to_seg_id_map: dict[int, dict[Hashable, Hashable]] = {}
for node_id, data in graph.nodes(data=True):
time = data[graph.frame_key]
seg_id = data[graph.label_key]
seg_id_to_node_id_map = time_to_seg_id_map.get(time, {})
assert (
seg_id not in seg_id_to_node_id_map
), f"Segmentation ID {seg_id} occurred twice in frame {time}."
seg_id_to_node_id_map[seg_id] = node_id
time_to_seg_id_map[time] = seg_id_to_node_id_map
return time_to_seg_id_map


def match_iou(gt, pred, threshold=0.6):
"""Identifies pairs of cells between gt and pred that have iou > threshold
Expand Down Expand Up @@ -78,24 +109,19 @@ def match_iou(gt, pred, threshold=0.6):
# Get overlaps for each frame
frame_range = range(gt.start_frame, gt.end_frame)
total = len(list(frame_range))

gt_time_to_seg_id_map = _construct_time_to_seg_id_map(gt)
pred_time_to_seg_id_map = _construct_time_to_seg_id_map(pred)

for i, t in tqdm(enumerate(frame_range), desc="Matching frames", total=total):
matches = _match_nodes(
gt.segmentation[i], pred.segmentation[i], threshold=threshold
)

# Construct node id tuple for each match
for gt_id, pred_id in zip(*matches):
for gt_seg_id, pred_seg_id in zip(*matches):
# Find node id based on time and segmentation label
gt_node = gt.get_nodes_with_attribute(
gt.label_key,
criterion=lambda x: x == gt_id, # noqa
limit_to=gt.get_nodes_in_frame(t),
)[0]
pred_node = pred.get_nodes_with_attribute(
pred.label_key,
criterion=lambda x: x == pred_id, # noqa
limit_to=pred.get_nodes_in_frame(t),
)[0]
gt_node = gt_time_to_seg_id_map[t][gt_seg_id]
pred_node = pred_time_to_seg_id_map[t][pred_seg_id]
mapper.append((gt_node, pred_node))
return mapper

Expand Down
14 changes: 7 additions & 7 deletions src/traccuracy/metrics/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING

from traccuracy._tracking_graph import EdgeAttr, NodeAttr
from traccuracy._tracking_graph import EdgeFlag, NodeFlag
from traccuracy.track_errors._ctc import evaluate_ctc_events

from ._base import Metric
Expand Down Expand Up @@ -38,14 +38,14 @@ def _compute(self, data: Matched):
evaluate_ctc_events(data)

vertex_error_counts = {
"ns": len(data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)),
"fp": len(data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)),
"fn": len(data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)),
"ns": len(data.pred_graph.get_nodes_with_flag(NodeFlag.NON_SPLIT)),
"fp": len(data.pred_graph.get_nodes_with_flag(NodeFlag.FALSE_POS)),
"fn": len(data.gt_graph.get_nodes_with_flag(NodeFlag.FALSE_NEG)),
}
edge_error_counts = {
"ws": len(data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC)),
"fp": len(data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)),
"fn": len(data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)),
"ws": len(data.pred_graph.get_edges_with_flag(EdgeFlag.WRONG_SEMANTIC)),
"fp": len(data.pred_graph.get_edges_with_flag(EdgeFlag.FALSE_POS)),
"fn": len(data.gt_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG)),
}
error_sum = get_weighted_error_sum(
vertex_error_counts,
Expand Down
14 changes: 4 additions & 10 deletions src/traccuracy/metrics/_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from typing import TYPE_CHECKING

from traccuracy._tracking_graph import NodeAttr
from traccuracy._tracking_graph import NodeFlag
from traccuracy.track_errors.divisions import _evaluate_division_events

from ._base import Metric
Expand Down Expand Up @@ -90,15 +90,9 @@ def _compute(self, data: Matched):
}

def _calculate_metrics(self, g_gt, g_pred):
tp_division_count = len(
g_gt.get_nodes_with_attribute(NodeAttr.TP_DIV, lambda x: x)
)
fn_division_count = len(
g_gt.get_nodes_with_attribute(NodeAttr.FN_DIV, lambda x: x)
)
fp_division_count = len(
g_pred.get_nodes_with_attribute(NodeAttr.FP_DIV, lambda x: x)
)
tp_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.TP_DIV))
fn_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.FN_DIV))
fp_division_count = len(g_pred.get_nodes_with_flag(NodeFlag.FP_DIV))

try:
recall = tp_division_count / (tp_division_count + fn_division_count)
Expand Down
67 changes: 32 additions & 35 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tqdm import tqdm

from traccuracy._tracking_graph import EdgeAttr, NodeAttr
from traccuracy._tracking_graph import EdgeFlag, NodeFlag

if TYPE_CHECKING:
from traccuracy.matchers import Matched
Expand Down Expand Up @@ -39,12 +39,12 @@ def get_vertex_errors(matched_data: Matched):
logger.info("Node errors already calculated. Skipping graph annotation")
return

comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.TRUE_POS, False)
comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.NON_SPLIT, False)
comp_graph.set_flag_on_all_nodes(NodeFlag.TRUE_POS, False)
comp_graph.set_flag_on_all_nodes(NodeFlag.NON_SPLIT, False)

# will flip this when we come across the vertex in the mapping
comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.FALSE_POS, True)
gt_graph.set_node_attribute(list(gt_graph.nodes()), NodeAttr.FALSE_NEG, True)
comp_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, True)
gt_graph.set_flag_on_all_nodes(NodeFlag.FALSE_NEG, True)

# we need to know how many computed vertices are "non-split", so we make
# a mapping of gt vertices to their matched comp vertices
Expand All @@ -57,15 +57,16 @@ def get_vertex_errors(matched_data: Matched):
gt_ids = dict_mapping[pred_id]
if len(gt_ids) == 1:
gid = gt_ids[0]
comp_graph.set_node_attribute(pred_id, NodeAttr.TRUE_POS, True)
comp_graph.set_node_attribute(pred_id, NodeAttr.FALSE_POS, False)
gt_graph.set_node_attribute(gid, NodeAttr.FALSE_NEG, False)
comp_graph.set_flag_on_node(pred_id, NodeFlag.TRUE_POS, True)
comp_graph.set_flag_on_node(pred_id, NodeFlag.FALSE_POS, False)
gt_graph.set_flag_on_node(gid, NodeFlag.FALSE_NEG, False)
elif len(gt_ids) > 1:
comp_graph.set_node_attribute(pred_id, NodeAttr.NON_SPLIT, True)
comp_graph.set_node_attribute(pred_id, NodeAttr.FALSE_POS, False)
comp_graph.set_flag_on_node(pred_id, NodeFlag.NON_SPLIT, True)
comp_graph.set_flag_on_node(pred_id, NodeFlag.FALSE_POS, False)
# number of split operations that would be required to correct the vertices
ns_count += len(gt_ids) - 1
gt_graph.set_node_attribute(gt_ids, NodeAttr.FALSE_NEG, False)
for gt_id in gt_ids:
gt_graph.set_flag_on_node(gt_id, NodeFlag.FALSE_NEG, False)

# Record presence of annotations on the TrackingGraph
comp_graph.node_errors = True
Expand All @@ -87,34 +88,30 @@ def get_edge_errors(matched_data: Matched):
get_vertex_errors(matched_data)

induced_graph = comp_graph.get_subgraph(
comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS)
comp_graph.get_nodes_with_flag(NodeFlag.TRUE_POS)
).graph

comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.FALSE_POS, False)
comp_graph.set_edge_attribute(
list(comp_graph.edges()), EdgeAttr.WRONG_SEMANTIC, False
)
gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False)
comp_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, False)
comp_graph.set_flag_on_all_edges(EdgeFlag.WRONG_SEMANTIC, False)
gt_graph.set_flag_on_all_edges(EdgeFlag.FALSE_NEG, False)

gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph}
comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph}

# intertrack edges = connection between parent and daughter
for graph in [comp_graph, gt_graph]:
# Set to False by default
graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False)
graph.set_flag_on_all_edges(EdgeFlag.INTERTRACK_EDGE, False)

for parent in graph.get_divisions():
for daughter in graph.get_succs(parent):
graph.set_edge_attribute(
(parent, daughter), EdgeAttr.INTERTRACK_EDGE, True
for daughter in graph.graph.successors(parent):
graph.set_flag_on_edge(
(parent, daughter), EdgeFlag.INTERTRACK_EDGE, True
)

for merge in graph.get_merges():
for parent in graph.get_preds(merge):
graph.set_edge_attribute(
(parent, merge), EdgeAttr.INTERTRACK_EDGE, True
)
for parent in graph.graph.predecessors(merge):
graph.set_flag_on_edge((parent, merge), EdgeFlag.INTERTRACK_EDGE, True)

# fp edges - edges in induced_graph that aren't in gt_graph
for edge in tqdm(induced_graph.edges, "Evaluating FP edges"):
Expand All @@ -124,32 +121,32 @@ def get_edge_errors(matched_data: Matched):
target_gt_id = comp_gt_mapping[target]

expected_gt_edge = (source_gt_id, target_gt_id)
if expected_gt_edge not in gt_graph.edges():
comp_graph.set_edge_attribute(edge, EdgeAttr.FALSE_POS, True)
if expected_gt_edge not in gt_graph.edges:
comp_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_POS, True)
else:
# check if semantics are correct
is_parent_gt = gt_graph.edges()[expected_gt_edge][EdgeAttr.INTERTRACK_EDGE]
is_parent_comp = comp_graph.edges()[edge][EdgeAttr.INTERTRACK_EDGE]
is_parent_gt = gt_graph.edges[expected_gt_edge][EdgeFlag.INTERTRACK_EDGE]
is_parent_comp = comp_graph.edges[edge][EdgeFlag.INTERTRACK_EDGE]
if is_parent_gt != is_parent_comp:
comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True)
comp_graph.set_flag_on_edge(edge, EdgeFlag.WRONG_SEMANTIC, True)

# fn edges - edges in gt_graph that aren't in induced graph
for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"):
for edge in tqdm(gt_graph.edges, "Evaluating FN edges"):
source, target = edge[0], edge[1]
# this edge is adjacent to an edge we didn't detect, so it definitely is an fn
if (
gt_graph.nodes()[source][NodeAttr.FALSE_NEG]
or gt_graph.nodes()[target][NodeAttr.FALSE_NEG]
gt_graph.nodes[source][NodeFlag.FALSE_NEG]
or gt_graph.nodes[target][NodeFlag.FALSE_NEG]
):
gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True)
gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True)
continue

source_comp_id = gt_comp_mapping[source]
target_comp_id = gt_comp_mapping[target]

expected_comp_edge = (source_comp_id, target_comp_id)
if expected_comp_edge not in induced_graph.edges:
gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True)
gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True)

gt_graph.edge_errors = True
comp_graph.edge_errors = True
Loading

0 comments on commit 3d85480

Please sign in to comment.