diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 85fb8333..a2c2cc70 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -210,14 +210,15 @@ def nodes(self, limit_to=None): Will raise KeyError if any of these node_ids are not present. Returns: - dict[hashable, dict]: A dictionary from node ids to node attributes + NodeView: Provides set-like operations on the nodes as well as node attribute lookup. """ - nodes = self.graph.nodes.items() if limit_to is None: - return dict(nodes) + return self.graph.nodes else: - limited_nodes = {_id: data for _id, data in nodes if _id in limit_to} - return limited_nodes + for node in limit_to: + if not self.graph.has_node(node): + raise KeyError(f"Queried node {node} not present in graph.") + return self.graph.subgraph(limit_to).nodes def edges(self, limit_to=None): """Get all the edges in the graph, along with their attributes. @@ -228,14 +229,16 @@ def edges(self, limit_to=None): Will raise KeyError if any of these edge ids are not present. Returns: - dict[tuple[hashable], dict]: A dictionary from edge ids to edge attributes + OutEdgeView: Provides set-like operations on the edge-tuples as well as edge attribute + lookup. """ - edges = self.graph.edges.items() if limit_to is None: - return dict(edges) + return self.graph.edges else: - limited_edges = {_id: data for _id, data in edges if _id in limit_to} - return limited_edges + for edge in limit_to: + if not self.graph.has_edge(*edge): + raise KeyError(f"Queried edge {edge} not present in graph.") + return self.graph.edge_subgraph(limit_to).edges def get_nodes_in_frame(self, frame): """Get the node ids of all nodes in the given frame. diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index ea7071e9..642399df 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -2,7 +2,6 @@ from collections import defaultdict from typing import TYPE_CHECKING -import numpy as np from tqdm import tqdm from traccuracy import EdgeAttr, NodeAttr @@ -95,8 +94,8 @@ def get_edge_errors(matched_data: "Matched"): ) gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False) - node_mapping_first = np.array([mp[0] for mp in node_mapping]) - node_mapping_second = np.array([mp[1] for mp in node_mapping]) + gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph} + comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph} # intertrack edges = connection between parent and daughter for graph in [comp_graph, gt_graph]: @@ -119,8 +118,8 @@ def get_edge_errors(matched_data: "Matched"): for edge in tqdm(induced_graph.edges, "Evaluating FP edges"): source, target = edge[0], edge[1] - source_gt_id = node_mapping[np.where(node_mapping_second == source)[0][0]][0] - target_gt_id = node_mapping[np.where(node_mapping_second == target)[0][0]][0] + source_gt_id = comp_gt_mapping[source] + target_gt_id = comp_gt_mapping[target] expected_gt_edge = (source_gt_id, target_gt_id) if expected_gt_edge not in gt_graph.edges(): @@ -143,8 +142,8 @@ def get_edge_errors(matched_data: "Matched"): gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) continue - source_comp_id = node_mapping[np.where(node_mapping_first == source)[0][0]][1] - target_comp_id = node_mapping[np.where(node_mapping_first == target)[0][0]][1] + source_comp_id = gt_comp_mapping[source] + target_comp_id = gt_comp_mapping[target] expected_comp_edge = (source_comp_id, target_comp_id) if expected_comp_edge not in induced_graph.edges: