Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotate intertrack edges in the ctc get_edge_errors function #54

Merged
merged 10 commits into from
Oct 25, 2023
8 changes: 8 additions & 0 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,14 @@ def get_divisions(self):
"""
return [node for node, degree in self.graph.out_degree() if degree >= 2]

def get_merges(self):
"""Get all nodes that have at least two incoming edges from the previous time frame

Returns:
list of hashable: a list of node ids for nodes that have more than one parent
"""
return [node for node, degree in self.graph.in_degree() if degree >= 2]

def get_preds(self, node):
"""Get all predecessors of the given node.

Expand Down
10 changes: 2 additions & 8 deletions src/traccuracy/loaders/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def ctc_to_graph(df, detections):
{
"source": cellids[0:-1],
"target": cellids[1:],
"is_intertrack_edge": [0 for _ in range(len(cellids) - 1)],
}
)
)
Expand All @@ -126,11 +125,7 @@ def ctc_to_graph(df, detections):

target = "{}_{}".format(row["Cell_ID"], row["Start"])

edges.append(
pd.DataFrame(
{"source": [source], "target": [target], "is_intertrack_edge": [1]}
)
)
edges.append(pd.DataFrame({"source": [source], "target": [target]}))

# Store position attributes on nodes
detections["node_id"] = (
Expand All @@ -149,9 +144,8 @@ def ctc_to_graph(df, detections):

# Create graph
edges = pd.concat(edges)
edges["is_intertrack_edge"] = edges["is_intertrack_edge"].astype(bool)
G = nx.from_pandas_edgelist(
edges, source="source", target="target", create_using=nx.DiGraph, edge_attr=True
edges, source="source", target="target", create_using=nx.DiGraph
)

# Add all isolates to graph
Expand Down
25 changes: 22 additions & 3 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,16 @@ def get_edge_errors(matched_data: "Matched"):
logger.info("Edge errors already calculated. Skipping graph annotation")
return

# Node errors must already be annotated
if not comp_graph.node_errors and not gt_graph.node_errors:
logger.warning("Node errors have not been annotated. Running node annotation.")
get_vertex_errors(matched_data)

induced_graph = comp_graph.get_subgraph(
comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS)
).graph

comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.FALSE_POS, False)
comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.TRUE_POS, False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@msschwartz21 I don't think I fully understand why we are not annotating true positive edges? You mention in the description that we can't rely on the assumption that TP and WS edges are mutually exclusive, but I think we can? I don't think a TP edge should ever be annotated with WS.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My stab at explaining this, that Morgan can elaborate on at her leisure: I admit that I am always confused byWS edges; however, it is my understanding that if there is a division and we correctly detect one of the edges to a daughter but not the other, the one we correctly detected would be labeled WS - correct? I would consider that a TP - we correctly identified the edge! Just because we missed the other daughter edge (a FN edge), that doesn't make the correctly recovered edge incorrect. In that way, something that is a WS edge in the CTC metric could be considered a TP elsewhere, making them not mutually exclusive.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ Caroline's description captures my thoughts perfectly

comp_graph.set_edge_attribute(
list(comp_graph.edges()), EdgeAttr.WRONG_SEMANTIC, False
)
Expand All @@ -94,6 +98,23 @@ def get_edge_errors(matched_data: "Matched"):
node_mapping_first = np.array([mp[0] for mp in node_mapping])
node_mapping_second = np.array([mp[1] for mp in node_mapping])

# intertrack edges = connection between parent and daughter
for graph in [comp_graph, gt_graph]:
# Set to False by default
graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False)

for parent in graph.get_divisions():
msschwartz21 marked this conversation as resolved.
Show resolved Hide resolved
for daughter in graph.get_succs(parent):
graph.set_edge_attribute(
(parent, daughter), EdgeAttr.INTERTRACK_EDGE, True
)

for merge in graph.get_merges():
for parent in graph.get_preds(merge):
graph.set_edge_attribute(
(parent, merge), EdgeAttr.INTERTRACK_EDGE, True
)

# fp edges - edges in induced_graph that aren't in gt_graph
for edge in tqdm(induced_graph.edges, "Evaluating FP edges"):
source, target = edge[0], edge[1]
Expand All @@ -110,8 +131,6 @@ def get_edge_errors(matched_data: "Matched"):
is_parent_comp = comp_graph.edges()[edge][EdgeAttr.INTERTRACK_EDGE]
if is_parent_gt != is_parent_comp:
comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True)
else:
comp_graph.set_edge_attribute(edge, EdgeAttr.TRUE_POS, True)

# fn edges - edges in gt_graph that aren't in induced graph
for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"):
Expand Down
3 changes: 0 additions & 3 deletions tests/metrics/test_ctc_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import networkx as nx
from traccuracy._tracking_graph import EdgeAttr
from traccuracy.matchers._ctc import CTCMatched
from traccuracy.metrics._ctc import CTCMetrics

Expand All @@ -11,7 +9,6 @@ def test_compute_mapping():
n_frames = 3
n_labels = 3
track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels)
nx.set_edge_attributes(track_graph.graph, 0, EdgeAttr.INTERTRACK_EDGE)

matched = CTCMatched(gt_graph=track_graph, pred_graph=track_graph)
metric = CTCMetrics(matched)
Expand Down
45 changes: 44 additions & 1 deletion tests/test_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,40 @@ def nx_comp2():
return graph


@pytest.fixture
def nx_merge():
"""
3_0--3_1--\\
3_2--3_3
3_4--3_5--/
"""
cells = [
{"id": "3_0", "t": 0, "x": 0, "y": 0},
{"id": "3_1", "t": 1, "x": 0, "y": 0},
{"id": "3_2", "t": 2, "x": 0, "y": 0},
{"id": "3_3", "t": 3, "x": 0, "y": 0},
{"id": "3_4", "t": 0, "x": 0, "y": 0},
{"id": "3_5", "t": 1, "x": 0, "y": 0},
]

edges = [
{"source": "3_0", "target": "3_1"},
{"source": "3_1", "target": "3_2"},
{"source": "3_2", "target": "3_3"},
{"source": "3_4", "target": "3_5"},
{"source": "3_5", "target": "3_2"},
]
graph = nx.DiGraph()
graph.add_nodes_from([(cell["id"], cell) for cell in cells])
graph.add_edges_from([(edge["source"], edge["target"]) for edge in edges])
return graph


@pytest.fixture
def merge_graph(nx_merge):
return TrackingGraph(nx_merge)


@pytest.fixture
def simple_graph(nx_comp1):
return TrackingGraph(nx_comp1)
Expand Down Expand Up @@ -168,11 +202,20 @@ def test_get_divisions(complex_graph):
assert complex_graph.get_divisions() == ["1_1", "2_2"]


def test_get_preds(simple_graph):
def test_get_merges(merge_graph):
assert merge_graph.get_merges() == ["3_2"]


def test_get_preds(simple_graph, merge_graph):
# Division graph
assert simple_graph.get_preds("1_0") == []
assert simple_graph.get_preds("1_1") == ["1_0"]
assert simple_graph.get_preds("1_2") == ["1_1"]

# Merge graph
assert merge_graph.get_preds("3_3") == ["3_2"]
assert merge_graph.get_preds("3_2") == ["3_1", "3_5"]


def test_get_succs(simple_graph):
assert simple_graph.get_succs("1_0") == ["1_1"]
Expand Down
74 changes: 33 additions & 41 deletions tests/track_errors/test_ctc_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def test_assign_edge_errors():
comp_g.add_nodes_from(comp_ids)
comp_g.add_edges_from(comp_edges)
nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS)
nx.set_edge_attributes(comp_g, 0, EdgeAttr.INTERTRACK_EDGE)
nx.set_node_attributes(
comp_g,
{idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids},
Expand All @@ -90,7 +89,6 @@ def test_assign_edge_errors():
gt_g = nx.DiGraph()
gt_g.add_nodes_from(gt_ids)
gt_g.add_edges_from(gt_edges)
nx.set_edge_attributes(gt_g, 0, EdgeAttr.INTERTRACK_EDGE)
nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG)
nx.set_node_attributes(
gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids}
Expand All @@ -102,51 +100,45 @@ def test_assign_edge_errors():

get_edge_errors(matched_data)

assert comp_g.edges[(3, 4)][EdgeAttr.TRUE_POS]
assert comp_g.edges[(7, 8)][EdgeAttr.FALSE_POS]
assert gt_g.edges[(17, 18)][EdgeAttr.FALSE_NEG]


def test_assign_edge_errors_semantics():
comp_ids = [3, 7, 10]
comp_ids_2 = list(np.asarray(comp_ids) + 1)
comp_ids += comp_ids_2

gt_ids = [4, 12, 17]
gt_ids_2 = list(np.asarray(gt_ids) + 1)
gt_ids += gt_ids_2

mapping = [(4, 3), (12, 7), (17, 10), (5, 4), (18, 11), (13, 8)]

# need a tp, fp, fn
comp_edges = [(3, 4)]
comp_g = nx.DiGraph()
comp_g.add_nodes_from(comp_ids)
comp_g.add_edges_from(comp_edges)
nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS)
nx.set_edge_attributes(comp_g, 0, EdgeAttr.INTERTRACK_EDGE)
nx.set_node_attributes(
comp_g,
{idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids},
)
G_comp = TrackingGraph(comp_g)

gt_edges = [(4, 5), (17, 18)]
gt_g = nx.DiGraph()
gt_g.add_nodes_from(gt_ids)
gt_g.add_edges_from(gt_edges)
nx.set_edge_attributes(gt_g, 0, EdgeAttr.INTERTRACK_EDGE)
nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG)
gt_g.edges[(4, 5)][EdgeAttr.INTERTRACK_EDGE] = 1
nx.set_node_attributes(
gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids}
)
G_gt = TrackingGraph(gt_g)

matched_data = DummyMatched(G_gt, G_comp)
"""
gt:
1_0 -- 1_1 -- 1_2 -- 1_3

comp:
1_3
1_0 -- 1_1 -- 1_2 -<
2_3
"""

gt = nx.DiGraph()
gt.add_edge("1_0", "1_1")
gt.add_edge("1_1", "1_2")
gt.add_edge("1_2", "1_3")
# Set node attrs
attrs = {}
for node in gt.nodes:
attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0}
nx.set_node_attributes(gt, attrs)

comp = gt.copy()
comp.add_edge("1_2", "2_3")
# Set node attrs
attrs = {}
for node in comp.nodes:
attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0}
nx.set_node_attributes(comp, attrs)

# Define mapping with all nodes matching except for 2_3 in comp
mapping = [(n, n) for n in gt.nodes]

matched_data = DummyMatched(TrackingGraph(gt), TrackingGraph(comp))
matched_data.mapping = mapping

get_edge_errors(matched_data)

assert comp_g.edges[(3, 4)][EdgeAttr.WRONG_SEMANTIC]
assert not comp_g.edges[(3, 4)][EdgeAttr.TRUE_POS]
assert comp.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC]