Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up CTC edge errors #59

Merged
13 changes: 6 additions & 7 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import defaultdict
from typing import TYPE_CHECKING

import numpy as np
from tqdm import tqdm

from traccuracy import EdgeAttr, NodeAttr
Expand Down Expand Up @@ -91,15 +90,15 @@ def get_edge_errors(matched_data: "Matched"):
)
gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False)

node_mapping_first = np.array([mp[0] for mp in node_mapping])
node_mapping_second = np.array([mp[1] for mp in node_mapping])
node_mapping_first = {mp[0]: mp[1] for mp in node_mapping}
node_mapping_second = {mp[1]: mp[0] for mp in node_mapping}

# fp edges - edges in induced_graph that aren't in gt_graph
for edge in tqdm(induced_graph.edges, "Evaluating FP edges"):
source, target = edge[0], edge[1]

source_gt_id = node_mapping[np.where(node_mapping_second == source)[0][0]][0]
target_gt_id = node_mapping[np.where(node_mapping_second == target)[0][0]][0]
source_gt_id = node_mapping_second[source]
target_gt_id = node_mapping_second[target]

expected_gt_edge = (source_gt_id, target_gt_id)
if expected_gt_edge not in gt_graph.edges():
Expand All @@ -124,8 +123,8 @@ def get_edge_errors(matched_data: "Matched"):
gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True)
continue

source_comp_id = node_mapping[np.where(node_mapping_first == source)[0][0]][1]
target_comp_id = node_mapping[np.where(node_mapping_first == target)[0][0]][1]
source_comp_id = node_mapping_first[source]
target_comp_id = node_mapping_first[target]

expected_comp_edge = (source_comp_id, target_comp_id)
if expected_comp_edge not in induced_graph.edges:
Expand Down