Speed up CTC edge errors
When calculating edge errors, there is a 1-to-1 mapping between computed
graph nodes and GT graph nodes, see details below.

It is faster to do the node mapping with a dictionary, which takes O(n)
expected time overall (average O(1) per lookup), than to repeatedly scan
an array with np.where (O(n^2)).
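
The difference between the two lookup strategies can be sketched as follows. This is a minimal illustration, assuming `node_mapping` is a list of `(gt_id, comp_id)` pairs as in the diff. The helper names `lookup_by_scan` and `lookup_by_dict` are hypothetical and not part of traccuracy, and a plain-Python linear scan stands in for the `np.where` call so the sketch has no dependencies.

```python
# Hypothetical 1-to-1 mapping of (gt_id, comp_id) pairs.
node_mapping = [(10, 100), (11, 101), (12, 102)]


def lookup_by_scan(mapping, comp_id):
    """Old style: scan every pair until the computed ID is found."""
    for gt_id, candidate in mapping:
        if candidate == comp_id:
            return gt_id
    raise KeyError(comp_id)


def lookup_by_dict(comp_to_gt, comp_id):
    """New style: one hash lookup per query."""
    return comp_to_gt[comp_id]


# Build the dictionary once, before the loop over edges.
comp_to_gt = {comp_id: gt_id for gt_id, comp_id in node_mapping}

queries = [100, 101, 102]
results_scan = [lookup_by_scan(node_mapping, q) for q in queries]
results_dict = [lookup_by_dict(comp_to_gt, q) for q in queries]
```

Both strategies return the same GT IDs. The scan repeats its pass over the mapping for each query, while the dictionary is built once and then reused.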

Details: Potential 1-to-many matches, namely computed nodes that match multiple
GT nodes (called non-split), are not part of the induced graph.
Therefore, all edges incident to such nodes are also not part of the induced
graph, leaving us with the desired 1-to-1 mapping for all nodes that are
incident to existing edges in the induced graph, which we iterate over
in the loop for finding FP edges.

Conversely, each GT node is matched to at most one computed node,
which directly yields the 1-to-1 mapping for finding FN edges.
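
The following sketch shows why the 1-to-1 property matters for dictionary lookups. It assumes hypothetical match data, and it filters out non-split nodes explicitly for illustration. Per the commit message, the actual code instead relies on such nodes being absent from the induced graph.

```python
from collections import Counter

# Hypothetical raw matches as (gt_id, comp_id) pairs. Computed node 200
# matches two GT nodes, so it is a non-split node.
raw_matches = [(1, 100), (2, 200), (3, 200), (4, 400)]

# Count how many GT nodes each computed node is matched to.
comp_counts = Counter(comp for _, comp in raw_matches)

# Keep only pairs whose computed node has exactly one GT match.
node_mapping = [(gt, comp) for gt, comp in raw_matches if comp_counts[comp] == 1]

gt_to_comp = {gt: comp for gt, comp in node_mapping}
comp_to_gt = {comp: gt for gt, comp in node_mapping}

# With a 1-to-1 mapping, the two dictionaries are exact inverses.
is_bijection = all(comp_to_gt[comp] == gt for gt, comp in gt_to_comp.items())

# Without the filter, the later pair silently overwrites the earlier one.
lossy = {comp: gt for gt, comp in raw_matches}
```

In the unfiltered dictionary, the entry for computed node 200 keeps only the last GT node seen, so a lookup on a non-split node would silently return one of several matches.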
bentaculum committed Sep 22, 2023
1 parent c18408d commit 7ba43be
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/traccuracy/track_errors/_ctc.py
@@ -2,7 +2,6 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING

-import numpy as np
 from tqdm import tqdm

 from traccuracy import EdgeAttr, NodeAttr
@@ -91,15 +90,15 @@ def get_edge_errors(matched_data: "Matched"):
     )
     gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False)

-    node_mapping_first = np.array([mp[0] for mp in node_mapping])
-    node_mapping_second = np.array([mp[1] for mp in node_mapping])
+    node_mapping_first = {mp[0]: mp[1] for mp in node_mapping}
+    node_mapping_second = {mp[1]: mp[0] for mp in node_mapping}

     # fp edges - edges in induced_graph that aren't in gt_graph
     for edge in tqdm(induced_graph.edges, "Evaluating FP edges"):
         source, target = edge[0], edge[1]

-        source_gt_id = node_mapping[np.where(node_mapping_second == source)[0][0]][0]
-        target_gt_id = node_mapping[np.where(node_mapping_second == target)[0][0]][0]
+        source_gt_id = node_mapping_second[source]
+        target_gt_id = node_mapping_second[target]

         expected_gt_edge = (source_gt_id, target_gt_id)
         if expected_gt_edge not in gt_graph.edges():
@@ -124,8 +123,8 @@ def get_edge_errors(matched_data: "Matched"):
             gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True)
             continue

-        source_comp_id = node_mapping[np.where(node_mapping_first == source)[0][0]][1]
-        target_comp_id = node_mapping[np.where(node_mapping_first == target)[0][0]][1]
+        source_comp_id = node_mapping_first[source]
+        target_comp_id = node_mapping_first[target]

         expected_comp_edge = (source_comp_id, target_comp_id)
         if expected_comp_edge not in induced_graph.edges: