Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up CTC edge errors #59

Merged
14 changes: 8 additions & 6 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,12 @@ def nodes(self, limit_to=None):
Returns:
dict[hashable, dict]: A dictionary from node ids to node attributes
"""
nodes = self.graph.nodes.items()
if limit_to is None:
return dict(nodes)
return self.graph.nodes
else:
limited_nodes = {_id: data for _id, data in nodes if _id in limit_to}
limited_nodes = {
_id: data for _id, data in self.graph.nodes if _id in limit_to
}
return limited_nodes

def edges(self, limit_to=None):
Expand All @@ -201,11 +202,12 @@ def edges(self, limit_to=None):
Returns:
dict[tuple[hashable], dict]: A dictionary from edge ids to edge attributes
"""
edges = self.graph.edges.items()
if limit_to is None:
return dict(edges)
return self.graph.edges
else:
limited_edges = {_id: data for _id, data in edges if _id in limit_to}
limited_edges = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be self.graph.edge_subgraph(limit_to).edges and then the typing would be the same for both return statements.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch, put this in here 7e72be3. networkx.DiGraph.edge_subgraph and networkx.DiGraph.subgraph unfortunately don't throw an error when you ask for non-existing things. I do think that this is a reasonable check and added it. This way we also keep the API that you had before.

Copy link
Contributor Author

@bentaculum bentaculum Sep 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about a good way to make docstrings that include classes from third-party packages, like networkx OutEdgeView here. Writing networkx.classes.reportviews.OutEdgeView in the docstring seems excessive.

Might be a good segway into typing traccuracy ;), or at least the core parts of it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah would really love to have this package better typed especially once we've sorta settled on the API

_id: data for _id, data in self.graph.edges if _id in limit_to
}
return limited_edges

def get_nodes_in_frame(self, frame):
Expand Down
17 changes: 10 additions & 7 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import defaultdict
from typing import TYPE_CHECKING

import numpy as np
from tqdm import tqdm

from traccuracy import EdgeAttr, NodeAttr
Expand Down Expand Up @@ -91,15 +90,19 @@ def get_edge_errors(matched_data: "Matched"):
)
gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False)

node_mapping_first = np.array([mp[0] for mp in node_mapping])
node_mapping_second = np.array([mp[1] for mp in node_mapping])
node_mapping_first = {
DragaDoncila marked this conversation as resolved.
Show resolved Hide resolved
gt: comp for gt, comp in node_mapping if comp in induced_graph
}
node_mapping_second = {
comp: gt for gt, comp in node_mapping if comp in induced_graph
}

# fp edges - edges in induced_graph that aren't in gt_graph
for edge in tqdm(induced_graph.edges, "Evaluating FP edges"):
source, target = edge[0], edge[1]

source_gt_id = node_mapping[np.where(node_mapping_second == source)[0][0]][0]
target_gt_id = node_mapping[np.where(node_mapping_second == target)[0][0]][0]
source_gt_id = node_mapping_second[source]
target_gt_id = node_mapping_second[target]

expected_gt_edge = (source_gt_id, target_gt_id)
if expected_gt_edge not in gt_graph.edges():
Expand All @@ -124,8 +127,8 @@ def get_edge_errors(matched_data: "Matched"):
gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True)
continue

source_comp_id = node_mapping[np.where(node_mapping_first == source)[0][0]][1]
target_comp_id = node_mapping[np.where(node_mapping_first == target)[0][0]][1]
source_comp_id = node_mapping_first[source]
target_comp_id = node_mapping_first[target]

expected_comp_edge = (source_comp_id, target_comp_id)
if expected_comp_edge not in induced_graph.edges:
Expand Down