Skip to content

Commit

Permalink
Merge pull request #141 from Janelia-Trackathon-2023/check_empty_subg…
Browse files Browse the repository at this point in the history
…raph

Check empty subgraph
  • Loading branch information
cmalinmayor authored Aug 6, 2024
2 parents 280308d + cfc2173 commit 599aa8e
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 16 deletions.
24 changes: 16 additions & 8 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,12 @@ def __init__(
self.edges_by_flag[edge_flag].add(edge)

# Store first and last frames for reference
self.start_frame = min(self.nodes_by_frame.keys())
self.end_frame = max(self.nodes_by_frame.keys()) + 1
if len(self.nodes_by_frame) == 0:
self.start_frame = None
self.end_frame = None
else:
self.start_frame = min(self.nodes_by_frame.keys())
self.end_frame = max(self.nodes_by_frame.keys()) + 1

# Record types of annotations that have been calculated
self.division_annotations = False
Expand Down Expand Up @@ -301,13 +305,13 @@ def get_connected_components(self) -> list[TrackingGraph]:
return [self.get_subgraph(g) for g in nx.weakly_connected_components(graph)]

def get_subgraph(self, nodes: Iterable[Hashable]) -> TrackingGraph:
"""Returns a new TrackingGraph with the subgraph defined by the list of nodes
"""Returns a new TrackingGraph with the subgraph defined by the list of nodes.
Args:
nodes (list): List of node ids to use in constructing the subgraph
nodes (list): A list of node ids to use in constructing the subgraph
"""

new_graph = self.graph.subgraph(nodes).copy()

new_trackgraph = copy.deepcopy(self)
new_trackgraph.graph = new_graph
for frame, nodes_in_frame in self.nodes_by_frame.items():
Expand All @@ -324,10 +328,14 @@ def get_subgraph(self, nodes: Iterable[Hashable]) -> TrackingGraph:
for edge_flag in EdgeFlag:
new_trackgraph.edges_by_flag[edge_flag] = self.edges_by_flag[
edge_flag
].intersection(nodes)
].intersection(new_trackgraph.edges)

new_trackgraph.start_frame = min(new_trackgraph.nodes_by_frame.keys())
new_trackgraph.end_frame = max(new_trackgraph.nodes_by_frame.keys()) + 1
if len(new_trackgraph.nodes_by_frame) == 0:
new_trackgraph.start_frame = None
new_trackgraph.end_frame = None
else:
new_trackgraph.start_frame = min(new_trackgraph.nodes_by_frame.keys())
new_trackgraph.end_frame = max(new_trackgraph.nodes_by_frame.keys()) + 1

return new_trackgraph

Expand Down
5 changes: 4 additions & 1 deletion src/traccuracy/matchers/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
if mask_gt.shape != mask_pred.shape:
raise ValueError("Segmentation shapes must match between gt and pred")

mapping = []
mapping: list[tuple] = []
# Get overlaps for each frame
if gt.start_frame is None or gt.end_frame is None:
return Matched(gt_graph, pred_graph, mapping)

for i, t in enumerate(
tqdm(
range(gt.start_frame, gt.end_frame),
Expand Down
16 changes: 9 additions & 7 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ def get_edge_errors(matched_data: Matched):
logger.warning("Node errors have not been annotated. Running node annotation.")
get_vertex_errors(matched_data)

induced_graph = comp_graph.get_subgraph(
comp_graph.get_nodes_with_flag(NodeFlag.TRUE_POS)
).graph
comp_tp_nodes = comp_graph.get_nodes_with_flag(NodeFlag.TRUE_POS)
induced_graph = comp_graph.get_subgraph(comp_tp_nodes).graph

comp_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, False)
comp_graph.set_flag_on_all_edges(EdgeFlag.WRONG_SEMANTIC, False)
Expand Down Expand Up @@ -141,12 +140,15 @@ def get_edge_errors(matched_data: Matched):
gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True)
continue

source_comp_id = gt_comp_mapping[source]
target_comp_id = gt_comp_mapping[target]
source_comp_id = gt_comp_mapping.get(source, None)
target_comp_id = gt_comp_mapping.get(target, None)

expected_comp_edge = (source_comp_id, target_comp_id)
if expected_comp_edge not in induced_graph.edges:
if source_comp_id is None or target_comp_id is None:
gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True)
else:
expected_comp_edge = (source_comp_id, target_comp_id)
if expected_comp_edge not in induced_graph.edges:
gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True)

gt_graph.edge_errors = True
comp_graph.edge_errors = True
19 changes: 19 additions & 0 deletions tests/test_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,25 @@ def test_get_connected_components(complex_graph, nx_comp1, nx_comp2):
assert track2.graph.edges == nx_comp2.edges


def test_get_subgraph(simple_graph):
target_nodes = ("1_0", "1_1")
subgraph = simple_graph.get_subgraph(target_nodes)
assert len(subgraph.nodes) == 2
assert len(subgraph.edges) == 1
# test that nodes_by_flag dicts are maintained
assert Counter(subgraph.nodes_by_flag[NodeFlag.TP_DIV]) == Counter(["1_1"])
assert Counter(subgraph.edges_by_flag[EdgeFlag.TRUE_POS]) == Counter(
[("1_0", "1_1")]
)
# test that start and end frame are updated
assert subgraph.start_frame == 0
assert subgraph.end_frame == 2

# test empty target nodes
empty_graph = simple_graph.get_subgraph([])
assert Counter(empty_graph.nodes) == Counter([])


def test_set_flag_on_node(simple_graph):
assert simple_graph.nodes()["1_0"] == {"id": "1_0", "t": 0, "y": 1, "x": 1}
assert simple_graph.nodes()["1_1"] == {
Expand Down
67 changes: 67 additions & 0 deletions tests/track_errors/test_ctc_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,70 @@ def test_assign_edge_errors_semantics():
get_edge_errors(matched_data)

assert matched_data.pred_graph.edges[("1_2", "1_3")][EdgeFlag.WRONG_SEMANTIC]


def test_ns_vertex_fn_edge():
"""Minimal Example of testing for FN edges with a NS Vertex
gt 1 - 2 - 3
4 - 5 - 6
comp 1 - 2
matching [ (1, 1), (4, 1), (2, 2), (5, 2) ]
"""

gt_nodes = [
(1, {"t": 0, "x": 1, "y": 1}),
(2, {"t": 1, "x": 1, "y": 1}),
(3, {"t": 2, "x": 1, "y": 1}),
(4, {"t": 0, "x": 0, "y": 1}),
(5, {"t": 1, "x": 0, "y": 1}),
(6, {"t": 2, "x": 0, "y": 1}),
]
gt_edges = [
(1, 2),
(2, 3),
(4, 5),
(5, 6),
]
gt = nx.DiGraph()
gt.add_nodes_from(gt_nodes)
gt.add_edges_from(gt_edges)

comp_nodes = [
(1, {"t": 0, "x": 0.5, "y": 1}),
(2, {"t": 1, "x": 0.5, "y": 1}),
]
comp_edges = [
(1, 2),
]
comp = nx.DiGraph()
comp.add_nodes_from(comp_nodes)
comp.add_edges_from(comp_edges)

mapping = [
(1, 1),
(5, 1),
(2, 2),
(5, 2),
]

matched_data = Matched(TrackingGraph(gt), TrackingGraph(comp), mapping)
get_vertex_errors(matched_data)
get_edge_errors(matched_data)

for node in comp.nodes:
assert comp.nodes[node][NodeFlag.NON_SPLIT]
for edge in comp_edges:
assert not comp.edges[edge][EdgeFlag.FALSE_POS]

# https://github.com/Janelia-Trackathon-2023/traccuracy/pull/141#issuecomment-2265990197
if False: # TODO: Fix this in a separate PR
for node in [1, 2, 4, 5]:
assert gt.nodes[node][NodeFlag.FALSE_NEG]

for node in [3, 6]:
assert gt.nodes[node][NodeFlag.FALSE_NEG]

for edge in gt_edges:
assert gt.edges[edge][EdgeFlag.FALSE_NEG]

0 comments on commit 599aa8e

Please sign in to comment.