diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index eb33fccf..3508c10b 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -161,7 +161,7 @@ def __init__( self.start_frame = None self.end_frame = None - self.update_graph(graph) + self._update_graph(graph) self.node_errors = False self.edge_errors = False @@ -356,11 +356,17 @@ def get_subgraph(self, nodes): new_graph = self.graph.subgraph(nodes).copy() new_trackgraph = copy.deepcopy(self) - new_trackgraph.update_graph(new_graph) + new_trackgraph._update_graph(new_graph) return new_trackgraph - def update_graph(self, graph): + def _update_graph(self, graph): + """Given a new graph, which is expected to be a subgraph of the current graph, + update attributes which are dependent on the graph. + + Args: + graph (nx.DiGraph): A networkx graph that is a subgraph of the original graph + """ self.graph = graph # construct a dictionary from frames to node ids for easy lookup diff --git a/src/traccuracy/metrics/_ctc.py b/src/traccuracy/metrics/_ctc.py index 6ddd7a67..0aadaa64 100644 --- a/src/traccuracy/metrics/_ctc.py +++ b/src/traccuracy/metrics/_ctc.py @@ -1,7 +1,8 @@ from typing import TYPE_CHECKING -from .._tracking_graph import EdgeAttr, NodeAttr -from ..track_errors._ctc import evaluate_ctc_events +from traccuracy._tracking_graph import EdgeAttr, NodeAttr +from traccuracy.track_errors._ctc import evaluate_ctc_events + from ._base import Metric if TYPE_CHECKING: