Skip to content

Commit

Permalink
Merge pull request #54 from Janelia-Trackathon-2023/intertrack_edges
Browse files Browse the repository at this point in the history
Annotate intertrack edges in the ctc `get_edge_errors` function
  • Loading branch information
cmalinmayor authored Oct 25, 2023
2 parents 73ddb1c + 9b852a5 commit a910bc6
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 56 deletions.
8 changes: 8 additions & 0 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,14 @@ def get_divisions(self):
"""
return [node for node, degree in self.graph.out_degree() if degree >= 2]

def get_merges(self):
"""Get all nodes that have at least two incoming edges from the previous time frame
Returns:
list of hashable: a list of node ids for nodes that have more than one parent
"""
return [node for node, degree in self.graph.in_degree() if degree >= 2]

def get_preds(self, node):
"""Get all predecessors of the given node.
Expand Down
10 changes: 2 additions & 8 deletions src/traccuracy/loaders/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def ctc_to_graph(df, detections):
{
"source": cellids[0:-1],
"target": cellids[1:],
"is_intertrack_edge": [0 for _ in range(len(cellids) - 1)],
}
)
)
Expand All @@ -126,11 +125,7 @@ def ctc_to_graph(df, detections):

target = "{}_{}".format(row["Cell_ID"], row["Start"])

edges.append(
pd.DataFrame(
{"source": [source], "target": [target], "is_intertrack_edge": [1]}
)
)
edges.append(pd.DataFrame({"source": [source], "target": [target]}))

# Store position attributes on nodes
detections["node_id"] = (
Expand All @@ -149,9 +144,8 @@ def ctc_to_graph(df, detections):

# Create graph
edges = pd.concat(edges)
edges["is_intertrack_edge"] = edges["is_intertrack_edge"].astype(bool)
G = nx.from_pandas_edgelist(
edges, source="source", target="target", create_using=nx.DiGraph, edge_attr=True
edges, source="source", target="target", create_using=nx.DiGraph
)

# Add all isolates to graph
Expand Down
25 changes: 22 additions & 3 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,16 @@ def get_edge_errors(matched_data: "Matched"):
logger.info("Edge errors already calculated. Skipping graph annotation")
return

# Node errors must already be annotated
if not comp_graph.node_errors and not gt_graph.node_errors:
logger.warning("Node errors have not been annotated. Running node annotation.")
get_vertex_errors(matched_data)

induced_graph = comp_graph.get_subgraph(
comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS)
).graph

comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.FALSE_POS, False)
comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.TRUE_POS, False)
comp_graph.set_edge_attribute(
list(comp_graph.edges()), EdgeAttr.WRONG_SEMANTIC, False
)
Expand All @@ -94,6 +98,23 @@ def get_edge_errors(matched_data: "Matched"):
node_mapping_first = np.array([mp[0] for mp in node_mapping])
node_mapping_second = np.array([mp[1] for mp in node_mapping])

# intertrack edges = connection between parent and daughter
for graph in [comp_graph, gt_graph]:
# Set to False by default
graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False)

for parent in graph.get_divisions():
for daughter in graph.get_succs(parent):
graph.set_edge_attribute(
(parent, daughter), EdgeAttr.INTERTRACK_EDGE, True
)

for merge in graph.get_merges():
for parent in graph.get_preds(merge):
graph.set_edge_attribute(
(parent, merge), EdgeAttr.INTERTRACK_EDGE, True
)

# fp edges - edges in induced_graph that aren't in gt_graph
for edge in tqdm(induced_graph.edges, "Evaluating FP edges"):
source, target = edge[0], edge[1]
Expand All @@ -110,8 +131,6 @@ def get_edge_errors(matched_data: "Matched"):
is_parent_comp = comp_graph.edges()[edge][EdgeAttr.INTERTRACK_EDGE]
if is_parent_gt != is_parent_comp:
comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True)
else:
comp_graph.set_edge_attribute(edge, EdgeAttr.TRUE_POS, True)

# fn edges - edges in gt_graph that aren't in induced graph
for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"):
Expand Down
3 changes: 0 additions & 3 deletions tests/metrics/test_ctc_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import networkx as nx
from traccuracy._tracking_graph import EdgeAttr
from traccuracy.matchers._ctc import CTCMatched
from traccuracy.metrics._ctc import CTCMetrics

Expand All @@ -11,7 +9,6 @@ def test_compute_mapping():
n_frames = 3
n_labels = 3
track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels)
nx.set_edge_attributes(track_graph.graph, 0, EdgeAttr.INTERTRACK_EDGE)

matched = CTCMatched(gt_graph=track_graph, pred_graph=track_graph)
metric = CTCMetrics(matched)
Expand Down
45 changes: 44 additions & 1 deletion tests/test_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,40 @@ def nx_comp2():
return graph


@pytest.fixture
def nx_merge():
"""
3_0--3_1--\\
3_2--3_3
3_4--3_5--/
"""
cells = [
{"id": "3_0", "t": 0, "x": 0, "y": 0},
{"id": "3_1", "t": 1, "x": 0, "y": 0},
{"id": "3_2", "t": 2, "x": 0, "y": 0},
{"id": "3_3", "t": 3, "x": 0, "y": 0},
{"id": "3_4", "t": 0, "x": 0, "y": 0},
{"id": "3_5", "t": 1, "x": 0, "y": 0},
]

edges = [
{"source": "3_0", "target": "3_1"},
{"source": "3_1", "target": "3_2"},
{"source": "3_2", "target": "3_3"},
{"source": "3_4", "target": "3_5"},
{"source": "3_5", "target": "3_2"},
]
graph = nx.DiGraph()
graph.add_nodes_from([(cell["id"], cell) for cell in cells])
graph.add_edges_from([(edge["source"], edge["target"]) for edge in edges])
return graph


@pytest.fixture
def merge_graph(nx_merge):
return TrackingGraph(nx_merge)


@pytest.fixture
def simple_graph(nx_comp1):
return TrackingGraph(nx_comp1)
Expand Down Expand Up @@ -168,11 +202,20 @@ def test_get_divisions(complex_graph):
assert complex_graph.get_divisions() == ["1_1", "2_2"]


def test_get_preds(simple_graph):
def test_get_merges(merge_graph):
assert merge_graph.get_merges() == ["3_2"]


def test_get_preds(simple_graph, merge_graph):
# Division graph
assert simple_graph.get_preds("1_0") == []
assert simple_graph.get_preds("1_1") == ["1_0"]
assert simple_graph.get_preds("1_2") == ["1_1"]

# Merge graph
assert merge_graph.get_preds("3_3") == ["3_2"]
assert merge_graph.get_preds("3_2") == ["3_1", "3_5"]


def test_get_succs(simple_graph):
assert simple_graph.get_succs("1_0") == ["1_1"]
Expand Down
74 changes: 33 additions & 41 deletions tests/track_errors/test_ctc_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def test_assign_edge_errors():
comp_g.add_nodes_from(comp_ids)
comp_g.add_edges_from(comp_edges)
nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS)
nx.set_edge_attributes(comp_g, 0, EdgeAttr.INTERTRACK_EDGE)
nx.set_node_attributes(
comp_g,
{idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids},
Expand All @@ -90,7 +89,6 @@ def test_assign_edge_errors():
gt_g = nx.DiGraph()
gt_g.add_nodes_from(gt_ids)
gt_g.add_edges_from(gt_edges)
nx.set_edge_attributes(gt_g, 0, EdgeAttr.INTERTRACK_EDGE)
nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG)
nx.set_node_attributes(
gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids}
Expand All @@ -102,51 +100,45 @@ def test_assign_edge_errors():

get_edge_errors(matched_data)

assert comp_g.edges[(3, 4)][EdgeAttr.TRUE_POS]
assert comp_g.edges[(7, 8)][EdgeAttr.FALSE_POS]
assert gt_g.edges[(17, 18)][EdgeAttr.FALSE_NEG]


def test_assign_edge_errors_semantics():
comp_ids = [3, 7, 10]
comp_ids_2 = list(np.asarray(comp_ids) + 1)
comp_ids += comp_ids_2

gt_ids = [4, 12, 17]
gt_ids_2 = list(np.asarray(gt_ids) + 1)
gt_ids += gt_ids_2

mapping = [(4, 3), (12, 7), (17, 10), (5, 4), (18, 11), (13, 8)]

# need a tp, fp, fn
comp_edges = [(3, 4)]
comp_g = nx.DiGraph()
comp_g.add_nodes_from(comp_ids)
comp_g.add_edges_from(comp_edges)
nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS)
nx.set_edge_attributes(comp_g, 0, EdgeAttr.INTERTRACK_EDGE)
nx.set_node_attributes(
comp_g,
{idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids},
)
G_comp = TrackingGraph(comp_g)

gt_edges = [(4, 5), (17, 18)]
gt_g = nx.DiGraph()
gt_g.add_nodes_from(gt_ids)
gt_g.add_edges_from(gt_edges)
nx.set_edge_attributes(gt_g, 0, EdgeAttr.INTERTRACK_EDGE)
nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG)
gt_g.edges[(4, 5)][EdgeAttr.INTERTRACK_EDGE] = 1
nx.set_node_attributes(
gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids}
)
G_gt = TrackingGraph(gt_g)

matched_data = DummyMatched(G_gt, G_comp)
"""
gt:
1_0 -- 1_1 -- 1_2 -- 1_3
comp:
1_3
1_0 -- 1_1 -- 1_2 -<
2_3
"""

gt = nx.DiGraph()
gt.add_edge("1_0", "1_1")
gt.add_edge("1_1", "1_2")
gt.add_edge("1_2", "1_3")
# Set node attrs
attrs = {}
for node in gt.nodes:
attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0}
nx.set_node_attributes(gt, attrs)

comp = gt.copy()
comp.add_edge("1_2", "2_3")
# Set node attrs
attrs = {}
for node in comp.nodes:
attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0}
nx.set_node_attributes(comp, attrs)

# Define mapping with all nodes matching except for 2_3 in comp
mapping = [(n, n) for n in gt.nodes]

matched_data = DummyMatched(TrackingGraph(gt), TrackingGraph(comp))
matched_data.mapping = mapping

get_edge_errors(matched_data)

assert comp_g.edges[(3, 4)][EdgeAttr.WRONG_SEMANTIC]
assert not comp_g.edges[(3, 4)][EdgeAttr.TRUE_POS]
assert comp.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC]

2 comments on commit a910bc6

@msschwartz21
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cmalinmayor I'm not sure why merging this in led to the change in the CTC benchmarking results, but for now I'm leaning towards just ignoring it and resetting the assertion to the new value. My understanding is that the candidate graph we are using for testing is pretty wacky so it's possible that there is a strange edge case buried in there.

@cmalinmayor
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cmalinmayor I'm not sure why merging this in led to the change in the CTC benchmarking results, but for now I'm leaning towards just ignoring it and resetting the assertion to the new value. My understanding is that the candidate graph we are using for testing is pretty wacky so it's possible that there is a strange edge case buried in there.

@DragaDoncila Just FYI - since you were seeing a similar change in the number of WS edges when we merged the annotation on edges branch. Not sure if knowing the benchmark also changed helps you debug or not. Let us know if you think it's an issue we should look into further, or if we can assume it was a prior bug.

Please sign in to comment.