From 7ba43be165df441ac5eb3fba98d2e81db0c05180 Mon Sep 17 00:00:00 2001 From: Benjamin Gallusser Date: Fri, 22 Sep 2023 18:07:43 +0200 Subject: [PATCH] Speed up CTC edge errors When calculating edge errors, there is a 1-to-1 mapping between computed graph nodes and GT graph nodes, see details below. It is faster to use a dictionary for doing the node mapping with O(n log(n)) runtime, compared to repeatedly iterating over a list with np.where (O(n^2)). Details: Potential 1-to-many matches, namely computed nodes that match multiple GT nodes (called non-split), are not part of the induced graph. Therefore, all edges incident to such nodes are also not part of the induced graph, leaving us with the desired 1-to-1 mapping for all nodes that are incident to existing edges in the induced graph, which we iterate over in the loop for finding FP edges. Conversely, each GT node is only matched to at most one computed node, directly yielding the 1-1 matching for finding FN edges. --- src/traccuracy/track_errors/_ctc.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 7acba9b0..f80449be 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -2,7 +2,6 @@ from collections import defaultdict from typing import TYPE_CHECKING -import numpy as np from tqdm import tqdm from traccuracy import EdgeAttr, NodeAttr @@ -91,15 +90,15 @@ def get_edge_errors(matched_data: "Matched"): ) gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False) - node_mapping_first = np.array([mp[0] for mp in node_mapping]) - node_mapping_second = np.array([mp[1] for mp in node_mapping]) + node_mapping_first = {mp[0]: mp[1] for mp in node_mapping} + node_mapping_second = {mp[1]: mp[0] for mp in node_mapping} # fp edges - edges in induced_graph that aren't in gt_graph for edge in tqdm(induced_graph.edges, "Evaluating FP edges"): source, target = edge[0], edge[1] - source_gt_id = node_mapping[np.where(node_mapping_second == source)[0][0]][0] - target_gt_id = node_mapping[np.where(node_mapping_second == target)[0][0]][0] + source_gt_id = node_mapping_second[source] + target_gt_id = node_mapping_second[target] expected_gt_edge = (source_gt_id, target_gt_id) if expected_gt_edge not in gt_graph.edges(): @@ -124,8 +123,8 @@ def get_edge_errors(matched_data: "Matched"): gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) continue - source_comp_id = node_mapping[np.where(node_mapping_first == source)[0][0]][1] - target_comp_id = node_mapping[np.where(node_mapping_first == target)[0][0]][1] + source_comp_id = node_mapping_first[source] + target_comp_id = node_mapping_first[target] expected_comp_edge = (source_comp_id, target_comp_id) if expected_comp_edge not in induced_graph.edges: