Skip to content

Commit

Permalink
Merge pull request #59 from bentaculum/faster_edge_errors
Browse files Browse the repository at this point in the history
Speed up CTC edge errors
  • Loading branch information
cmalinmayor authored Oct 30, 2023
2 parents 674c0f8 + 2a7ca53 commit ed2b7b1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
23 changes: 13 additions & 10 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,15 @@ def nodes(self, limit_to=None):
Will raise KeyError if any of these node_ids are not present.
Returns:
dict[hashable, dict]: A dictionary from node ids to node attributes
NodeView: Provides set-like operations on the nodes as well as node attribute lookup.
"""
nodes = self.graph.nodes.items()
if limit_to is None:
return dict(nodes)
return self.graph.nodes
else:
limited_nodes = {_id: data for _id, data in nodes if _id in limit_to}
return limited_nodes
for node in limit_to:
if not self.graph.has_node(node):
raise KeyError(f"Queried node {node} not present in graph.")
return self.graph.subgraph(limit_to).nodes

def edges(self, limit_to=None):
"""Get all the edges in the graph, along with their attributes.
Expand All @@ -228,14 +229,16 @@ def edges(self, limit_to=None):
Will raise KeyError if any of these edge ids are not present.
Returns:
dict[tuple[hashable], dict]: A dictionary from edge ids to edge attributes
OutEdgeView: Provides set-like operations on the edge-tuples as well as edge attribute
lookup.
"""
edges = self.graph.edges.items()
if limit_to is None:
return dict(edges)
return self.graph.edges
else:
limited_edges = {_id: data for _id, data in edges if _id in limit_to}
return limited_edges
for edge in limit_to:
if not self.graph.has_edge(*edge):
raise KeyError(f"Queried edge {edge} not present in graph.")
return self.graph.edge_subgraph(limit_to).edges

def get_nodes_in_frame(self, frame):
"""Get the node ids of all nodes in the given frame.
Expand Down
13 changes: 6 additions & 7 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import defaultdict
from typing import TYPE_CHECKING

import numpy as np
from tqdm import tqdm

from traccuracy import EdgeAttr, NodeAttr
Expand Down Expand Up @@ -95,8 +94,8 @@ def get_edge_errors(matched_data: "Matched"):
)
gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False)

node_mapping_first = np.array([mp[0] for mp in node_mapping])
node_mapping_second = np.array([mp[1] for mp in node_mapping])
gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph}
comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph}

# intertrack edges = connection between parent and daughter
for graph in [comp_graph, gt_graph]:
Expand All @@ -119,8 +118,8 @@ def get_edge_errors(matched_data: "Matched"):
for edge in tqdm(induced_graph.edges, "Evaluating FP edges"):
source, target = edge[0], edge[1]

source_gt_id = node_mapping[np.where(node_mapping_second == source)[0][0]][0]
target_gt_id = node_mapping[np.where(node_mapping_second == target)[0][0]][0]
source_gt_id = comp_gt_mapping[source]
target_gt_id = comp_gt_mapping[target]

expected_gt_edge = (source_gt_id, target_gt_id)
if expected_gt_edge not in gt_graph.edges():
Expand All @@ -143,8 +142,8 @@ def get_edge_errors(matched_data: "Matched"):
gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True)
continue

source_comp_id = node_mapping[np.where(node_mapping_first == source)[0][0]][1]
target_comp_id = node_mapping[np.where(node_mapping_first == target)[0][0]][1]
source_comp_id = gt_comp_mapping[source]
target_comp_id = gt_comp_mapping[target]

expected_comp_edge = (source_comp_id, target_comp_id)
if expected_comp_edge not in induced_graph.edges:
Expand Down

0 comments on commit ed2b7b1

Please sign in to comment.