diff --git a/src/traccuracy/__init__.py b/src/traccuracy/__init__.py index 6fbbeeb8..662c5f31 100644 --- a/src/traccuracy/__init__.py +++ b/src/traccuracy/__init__.py @@ -7,6 +7,6 @@ __version__ = "uninstalled" from ._run_metrics import run_metrics -from ._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph +from ._tracking_graph import EdgeFlag, NodeFlag, TrackingGraph -__all__ = ["TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"] +__all__ = ["TrackingGraph", "run_metrics", "NodeFlag", "EdgeFlag"] diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 2a038b2f..f6dd21fc 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -1,21 +1,26 @@ +from __future__ import annotations + import copy import enum import logging +from typing import TYPE_CHECKING, Hashable, Iterable import networkx as nx +if TYPE_CHECKING: + import numpy as np + from networkx.classes.reportviews import NodeView, OutEdgeView + logger = logging.getLogger(__name__) @enum.unique -class NodeAttr(str, enum.Enum): - """An enum containing all valid attributes that can be used to - annotate the nodes of a TrackingGraph. If new metrics require new - annotations, they should be added here to ensure strings do not overlap and - are standardized. Note that the user specified frame and location +class NodeFlag(str, enum.Enum): + """An enum containing standard flags that are used to annotate the nodes + of a TrackingGraph. Note that the user specified frame and location attributes are also valid node attributes that will be stored on the graph and should not overlap with these values. Additionally, if a graph already - has annotations using these strings before becoming a TrackGraph, + has annotations using these strings before becoming a TrackingGraph, this will likely ruin metrics computation! """ @@ -52,12 +57,10 @@ def has_value(cls, value): @enum.unique -class EdgeAttr(str, enum.Enum): - """An enum containing all valid attributes that can be used to - annotate the edges of a TrackingGraph. If new metrics require new - annotations, they should be added here to ensure strings do not overlap and - are standardized. Additionally, if a graph already - has annotations using these strings before becoming a TrackGraph, +class EdgeFlag(str, enum.Enum): + """An enum containing standard flags that are used to + annotate the edges of a TrackingGraph. If a graph already has + annotations using these strings before becoming a TrackingGraph, this will likely ruin metrics computation! """ @@ -105,25 +108,25 @@ class TrackingGraph: def __init__( self, - graph, - segmentation=None, - frame_key="t", - label_key="segmentation_id", - location_keys=("x", "y"), - name=None, + graph: nx.DiGraph, + segmentation: np.ndarray | None = None, + frame_key: str = "t", + label_key: str = "segmentation_id", + location_keys: tuple[str, ...] = ("x", "y"), + name: str | None = None, ): """A directed graph representing a tracking solution where edges go forward in time. If the provided graph already has annotations that are strings - included in NodeAttrs or EdgeAttrs, this will likely ruin + included in NodeFlags or EdgeFlags, this will likely ruin metric computation! Args: graph (networkx.DiGraph): A directed graph representing a tracking solution where edges go forward in time. If the graph already - has annotations that are strings included in NodeAttrs or - EdgeAttrs, this will likely ruin metrics computation! + has annotations that are strings included in NodeFlags or + EdgeFlags, this will likely ruin metrics computation! segmentation (numpy-like array, optional): A numpy-like array of segmentations. The location of each node in tracking_graph is assumed to be inside the area of the corresponding segmentation. Defaults to None. @@ -142,20 +145,20 @@ def __init__( outputs associated with this object """ self.segmentation = segmentation - if NodeAttr.has_value(frame_key): + if NodeFlag.has_value(frame_key): raise ValueError( f"Specified frame key {frame_key} is reserved for graph " "annotation. Please change the frame key." ) self.frame_key = frame_key - if label_key is not None and NodeAttr.has_value(label_key): + if label_key is not None and NodeFlag.has_value(label_key): raise ValueError( f"Specified label key {label_key} is reserved for graph" "annotation. Please change the label key." ) self.label_key = label_key for loc_key in location_keys: - if NodeAttr.has_value(loc_key): + if NodeFlag.has_value(loc_key): raise ValueError( f"Specified location key {loc_key} is reserved for graph" "annotation. Please change the location key." @@ -166,9 +169,13 @@ def __init__( self.graph = graph # construct dictionaries from attributes to nodes/edges for easy lookup - self.nodes_by_frame = {} - self.nodes_by_flag = {flag: set() for flag in NodeAttr} - self.edges_by_flag = {flag: set() for flag in EdgeAttr} + self.nodes_by_frame: dict[int, set[Hashable]] = {} + self.nodes_by_flag: dict[NodeFlag, set[Hashable]] = { + flag: set() for flag in NodeFlag + } + self.edges_by_flag: dict[EdgeFlag, set[tuple[Hashable, Hashable]]] = { + flag: set() for flag in EdgeFlag + } for node, attrs in self.graph.nodes.items(): # check that every node has the time frame and location specified assert ( @@ -186,15 +193,15 @@ def __init__( else: self.nodes_by_frame[frame].add(node) # store node id in nodes_by_flag mapping - for flag in NodeAttr: - if flag in attrs and attrs[flag]: - self.nodes_by_flag[flag].add(node) + for node_flag in NodeFlag: + if node_flag in attrs and attrs[node_flag]: + self.nodes_by_flag[node_flag].add(node) # store edge id in edges_by_flag for edge, attrs in self.graph.edges.items(): - for flag in EdgeAttr: - if flag in attrs and attrs[flag]: - self.edges_by_flag[flag].add(edge) + for edge_flag in EdgeFlag: + if edge_flag in attrs and attrs[edge_flag]: + self.edges_by_flag[edge_flag].add(edge) # Store first and last frames for reference self.start_frame = min(self.nodes_by_frame.keys()) @@ -205,62 +212,42 @@ def __init__( self.node_errors = False self.edge_errors = False - def nodes(self, limit_to=None): + @property + def nodes(self) -> NodeView: """Get all the nodes in the graph, along with their attributes. - Args: - limit_to (list[hashable], optional): Limit returned dictionary - to nodes with the provided ids. Defaults to None. - Will raise KeyError if any of these node_ids are not present. - Returns: NodeView: Provides set-like operations on the nodes as well as node attribute lookup. """ - if limit_to is None: - return self.graph.nodes - else: - for node in limit_to: - if not self.graph.has_node(node): - raise KeyError(f"Queried node {node} not present in graph.") - return self.graph.subgraph(limit_to).nodes + return self.graph.nodes - def edges(self, limit_to=None): + @property + def edges(self) -> OutEdgeView: """Get all the edges in the graph, along with their attributes. - Args: - limit_to (list[tuple[hashable]], optional): Limit returned dictionary - to edges with the provided ids. Defaults to None. - Will raise KeyError if any of these edge ids are not present. - Returns: OutEdgeView: Provides set-like operations on the edge-tuples as well as edge attribute lookup. """ - if limit_to is None: - return self.graph.edges - else: - for edge in limit_to: - if not self.graph.has_edge(*edge): - raise KeyError(f"Queried edge {edge} not present in graph.") - return self.graph.edge_subgraph(limit_to).edges + return self.graph.edges - def get_nodes_in_frame(self, frame): + def get_nodes_in_frame(self, frame: int) -> set[Hashable]: """Get the node ids of all nodes in the given frame. Args: frame (int): The frame to return all node ids for. If the provided frame is outside of the range - (self.start_frame, self.end_frame), returns an empty list. + (self.start_frame, self.end_frame), returns an empty iterable. Returns: - list of node_ids: A list of node ids for all nodes in frame. + Iterable[Hashable]: An iterable of node ids for all nodes in frame. """ if frame in self.nodes_by_frame.keys(): - return list(self.nodes_by_frame[frame]) + return self.nodes_by_frame[frame] else: - return [] + return set() - def get_location(self, node_id): + def get_location(self, node_id: Hashable) -> list[float]: """Get the spatial location of the node with node_id using self.location_keys. Args: @@ -271,156 +258,35 @@ def get_location(self, node_id): """ return [self.graph.nodes[node_id][key] for key in self.location_keys] - def get_nodes_with_flag(self, attr): - """Get all nodes with specified NodeAttr set to True. + def get_nodes_with_flag(self, flag: NodeFlag) -> set[Hashable]: + """Get all nodes with specified NodeFlag set to True. Args: - attr (traccuracy.NodeAttr): the node attribute to query for + flag (traccuracy.NodeFlag): the node flag to query for Returns: - (List(hashable)): A list of node_ids which have the given attribute + (List(hashable)): An iterable of node_ids which have the given flag and the value is True. """ - if not isinstance(attr, NodeAttr): - raise ValueError(f"Function takes NodeAttr arguments, not {type(attr)}.") - return list(self.nodes_by_flag[attr]) + if not isinstance(flag, NodeFlag): + raise ValueError(f"Function takes NodeFlag arguments, not {type(flag)}.") + return self.nodes_by_flag[flag] - def get_edges_with_flag(self, attr): - """Get all edges with specified EdgeAttr set to True. + def get_edges_with_flag(self, flag: EdgeFlag) -> set[tuple[Hashable, Hashable]]: + """Get all edges with specified EdgeFlag set to True. Args: - attr (traccuracy.EdgeAttr): the edge attribute to query for + flag (traccuracy.EdgeFlag): the edge flag to query for Returns: - (List(hashable)): A list of edge ids which have the given attribute + (List(hashable)): An iterable of edge ids which have the given flag and the value is True. """ - if not isinstance(attr, EdgeAttr): - raise ValueError(f"Function takes EdgeAttr arguments, not {type(attr)}.") - return list(self.edges_by_flag[attr]) - - def get_nodes_by_roi(self, **kwargs): - """Gets the nodes in a given region of interest (ROI). The ROI is - defined by keyword arguments that correspond to the frame key and - location keys, where each argument should be a (start, end) tuple - (the end is exclusive). Dimensions that are not passed as arguments - are unbounded. None can be passed as an element of the tuple to - signify an unbounded ROI on that side. - - For example, if frame_key='t' and location_keys=('x', 'y'): - `graph.get_nodes_by_roi(t=(10, None), x=(0, 100))` - would return all nodes with time >= 10, and 0 <= x < 100, with no limit - on the y values. - - Returns: - list of hashable: A list of node_ids for all nodes in the ROI. - """ - frames = None - dimensions = [] - for dim, limit in kwargs.items(): - if not (dim == self.frame_key or dim in self.location_keys): - raise ValueError( - f"Provided argument {dim} is neither the frame key" - f" {self.frame_key} or one of the location keys" - f" {self.location_keys}." - ) - if dim == self.frame_key: - frames = list(limit) - else: - dimensions.append((dim, limit[0], limit[1])) - nodes = [] - if frames: - if frames[0] is None: - frames[0] = self.start_frame - if frames[1] is None: - frames[1] = self.end_frame - possible_nodes = [] - for frame in range(frames[0], frames[1]): - if frame in self.nodes_by_frame: - possible_nodes.extend(self.nodes_by_frame[frame]) - else: - possible_nodes = self.graph.nodes() - - for node in possible_nodes: - attrs = self.graph.nodes[node] - inside = True - for dim, start, end in dimensions: - if start is not None and attrs[dim] < start: - inside = False - break - if end is not None and attrs[dim] >= end: - inside = False - break - if inside: - nodes.append(node) - return nodes - - def get_nodes_with_attribute(self, attr, criterion=None, limit_to=None): - """Get the node_ids of all nodes who have an attribute, optionally - limiting to nodes whose value at that attribute meet a given criteria. - - For example, get all nodes that have an attribute called "division", - or where the value for "division" == True. - This also works on location keys, for example to get all nodes with y > 100. - - Args: - attr (str): the name of the attribute to search for in the node metadata - criterion ((any)->bool, optional): A function that takes a value and returns - a boolean. If provided, nodes will only be returned if the value at - node[attr] meets this criterion. Defaults to None. - limit_to (list[hashable], optional): If provided the function will only - return node ids in this list. Will raise KeyError if ids provided here - are not present. + if not isinstance(flag, EdgeFlag): + raise ValueError(f"Function takes EdgeFlag arguments, not {type(flag)}.") + return self.edges_by_flag[flag] - Returns: - list of hashable: A list of node_ids which have the given attribute - (and optionally have values at that attribute that meet the given criterion, - and/or are in the list of node ids.) - """ - if not limit_to: - limit_to = self.graph.nodes.keys() - - nodes = [] - for node in limit_to: - attributes = self.graph.nodes[node] - if attr in attributes.keys(): - if criterion is None or criterion(attributes[attr]): - nodes.append(node) - return nodes - - def get_edges_with_attribute(self, attr, criterion=None, limit_to=None): - """Get the edge_ids of all edges who have an attribute, optionally - limiting to edges whose value at that attribute meet a given criteria. - - For example, get all edges that have an attribute called "fp", - or where the value for "fp" == True. - - Args: - attr (str): the name of the attribute to search for in the edge metadata - criterion ((any)->bool, optional): A function that takes a value and returns - a boolean. If provided, edges will only be returned if the value at - edge[attr] meets this criterion. Defaults to None. - limit_to (list[hashable], optional): If provided the function will only - return edge ids in this list. Will raise KeyError if ids provided here - are not present. - - Returns: - list of hashable: A list of edge_ids which have the given attribute - (and optionally have values at that attribute that meet the given criterion, - and/or are in the list of edge ids.) - """ - if not limit_to: - limit_to = self.graph.edges.keys() - - edges = [] - for edge in limit_to: - attributes = self.graph.edges[edge] - if attr in attributes.keys(): - if criterion is None or criterion(attributes[attr]): - edges.append(edge) - return edges - - def get_divisions(self): + def get_divisions(self) -> list[Hashable]: """Get all nodes that have at least two edges pointing to the next time frame Returns: @@ -428,7 +294,7 @@ def get_divisions(self): """ return [node for node, degree in self.graph.out_degree() if degree >= 2] - def get_merges(self): + def get_merges(self) -> list[Hashable]: """Get all nodes that have at least two incoming edges from the previous time frame Returns: @@ -436,39 +302,7 @@ def get_merges(self): """ return [node for node, degree in self.graph.in_degree() if degree >= 2] - def get_preds(self, node): - """Get all predecessors of the given node. - - A predecessor node is any node from a previous time point that has an edge to - the given node. In a case where merges are not allowed, each node will have a - maximum of one predecessor. - - Args: - node (hashable): A node id - - Returns: - list of hashable: A list of node ids containing all nodes that - have an edge to the given node. - """ - return [pred for pred, _ in self.graph.in_edges(node)] - - def get_succs(self, node): - """Get all successor nodes of the given node. - - A successor node is any node from a later time point that has an edge - from the given node. In a case where divisions are not allowed, - a node will have a maximum of one successor. - - Args: - node (hashable): A node id - - Returns: - list of hashable: A list of node ids containing all nodes that have - an edge from the given node. - """ - return [succ for _, succ in self.graph.out_edges(node)] - - def get_connected_components(self): + def get_connected_components(self) -> list[TrackingGraph]: """Get a list of TrackingGraphs, each corresponding to one track (i.e., a connected component in the track graph). @@ -481,7 +315,7 @@ def get_connected_components(self): return [self.get_subgraph(g) for g in nx.weakly_connected_components(graph)] - def get_subgraph(self, nodes): + def get_subgraph(self, nodes: Iterable[Hashable]) -> TrackingGraph: """Returns a new TrackingGraph with the subgraph defined by the list of nodes Args: @@ -498,129 +332,133 @@ def get_subgraph(self, nodes): else: del new_trackgraph.nodes_by_frame[frame] - for attr in NodeAttr: - new_trackgraph.nodes_by_flag[attr] = self.nodes_by_flag[attr].intersection( - nodes - ) - for attr in EdgeAttr: - new_trackgraph.edges_by_flag[attr] = self.edges_by_flag[attr].intersection( - nodes - ) + for node_flag in NodeFlag: + new_trackgraph.nodes_by_flag[node_flag] = self.nodes_by_flag[ + node_flag + ].intersection(nodes) + for edge_flag in EdgeFlag: + new_trackgraph.edges_by_flag[edge_flag] = self.edges_by_flag[ + edge_flag + ].intersection(nodes) new_trackgraph.start_frame = min(new_trackgraph.nodes_by_frame.keys()) new_trackgraph.end_frame = max(new_trackgraph.nodes_by_frame.keys()) + 1 return new_trackgraph - def set_node_attribute(self, ids, attr, value=True): - """Set an attribute flag for a set of nodes specified by - ids. If an id is not found in the graph, a KeyError will be raised. - If the key already exists, the existing value will be overwritten. + def set_flag_on_node( + self, _id: Hashable, flag: NodeFlag, value: bool = True + ) -> None: + """Set an attribute flag for a single node. + If the id is not found in the graph, a KeyError will be raised. + If the flag already exists, the existing value will be overwritten. Args: - ids (hashable | list[hashable]): The node id or list of node ids - to set the attribute for. - attr (traccuracy.NodeAttr): The node attribute to set. Must be - of type NodeAttr - you may not not pass strings, even if they - are included in the NodeAttr enum values. - value (bool, optional): Attributes are flags and can only be set to + _id (Hashable): The node id on which to set the flag. + flag (traccuracy.NodeFlag): The node flag to set. Must be + of type NodeFlag - you may not not pass strings, even if they + are included in the NodeFlag enum values. + value (bool, optional): Flags can only be set to True or False. Defaults to True. + + Raises: + KeyError if the provided id is not in the graph. + ValueError if the provided flag is not a NodeFlag """ - if not isinstance(ids, list): - ids = [ids] - if not isinstance(attr, NodeAttr): + if not isinstance(flag, NodeFlag): raise ValueError( - f"Provided attribute {attr} is not of type NodeAttr. " - "Please use the enum instead of passing string values, " - "and add new attributes to the class to avoid key collision." + f"Provided flag {flag} is not of type NodeFlag. " + "Please use the enum instead of passing string values." ) - for _id in ids: - self.graph.nodes[_id][attr] = value - if value: - self.nodes_by_flag[attr].add(_id) - else: - self.nodes_by_flag[attr].discard(_id) + self.graph.nodes[_id][flag] = value + if value: + self.nodes_by_flag[flag].add(_id) + else: + self.nodes_by_flag[flag].discard(_id) - def set_edge_attribute(self, ids, attr, value=True): - """Set an attribute flag for a set of edges specified by - ids. If an edge is not found in the graph, a KeyError will be raised. - If the key already exists, the existing value will be overwritten. + def set_flag_on_all_nodes(self, flag: NodeFlag, value: bool = True) -> None: + """Set an attribute flag for all nodes in the graph. + If the flag already exists, the existing values will be overwritten. Args: - ids (tuple(hashable) | list[tuple(hashable)]): The edge id or list of edge ids - to set the attribute for. Edge ids are a 2-tuple of node ids. - attr (traccuracy.EdgeAttr): The edge attribute to set. Must be - of type EdgeAttr - you may not pass strings, even if they are - included in the EdgeAttr enum values. - value (bool): Attributes are flags and can only be set to - True or False. Defaults to True. + flag (traccuracy.NodeFlag): The node flag to set. Must be + of type NodeFlag - you may not not pass strings, even if they + are included in the NodeFlag enum values. + value (bool, optional): Flags can only be set to True or False. + Defaults to True. + + Raises: + ValueError if the provided flag is not a NodeFlag. """ - if not isinstance(ids, list): - ids = [ids] - if not isinstance(attr, EdgeAttr): + if not isinstance(flag, NodeFlag): raise ValueError( - f"Provided attribute {attr} is not of type EdgeAttr. " - "Please use the enum instead of passing string values, " - "and add new attributes to the class to avoid key collision." + f"Provided flag {flag} is not of type NodeFlag. " + "Please use the enum instead of passing string values." ) - for _id in ids: - self.graph.edges[_id][attr] = value - if value: - self.edges_by_flag[attr].add(_id) - else: - self.edges_by_flag[attr].discard(_id) + nx.set_node_attributes(self.graph, value, name=flag) + if value: + self.nodes_by_flag[flag].update(self.graph.nodes) + else: + self.nodes_by_flag[flag] = set() - def get_node_attribute(self, _id, attr): - """Get the boolean value of a given attribute for a given node. + def set_flag_on_edge( + self, _id: tuple[Hashable, Hashable], flag: EdgeFlag, value: bool = True + ) -> None: + """Set an attribute flag for an edge. + If the flag already exists, the existing value will be overwritten. Args: - _id (hashable): node id - attr (NodeAttr): Node attribute to fetch the value of + ids (tuple[Hashable]): The edge id or list of edge ids + to set the attribute for. Edge ids are a 2-tuple of node ids. + flag (traccuracy.EdgeFlag): The edge flag to set. Must be + of type EdgeFlag - you may not pass strings, even if they are + included in the EdgeFlag enum values. + value (bool): Flags can only be set to True or False. + Defaults to True. Raises: - ValueError: if attr is not a NodeAttr - - Returns: - bool: The value of the attribute for that node. If the attribute - is not present on the graph, the value is presumed False. + KeyError if edge with _id not in graph. """ - if not isinstance(attr, NodeAttr): + if not isinstance(flag, EdgeFlag): raise ValueError( - f"Provided attribute {attr} is not of type NodeAttr. " - "Please use the enum instead of passing string values, " - "and add new attributes to the class to avoid key collision." + f"Provided attribute {flag} is not of type EdgeFlag. " + "Please use the enum instead of passing string values." ) + self.graph.edges[_id][flag] = value + if value: + self.edges_by_flag[flag].add(_id) + else: + self.edges_by_flag[flag].discard(_id) - if attr not in self.graph.nodes[_id]: - return False - return self.graph.nodes[_id][attr] - - def get_edge_attribute(self, _id, attr): - """Get the boolean value of a given attribute for a given edge. + def set_flag_on_all_edges(self, flag: EdgeFlag, value: bool = True) -> None: + """Set an attribute flag for all edges in the graph. + If the flag already exists, the existing values will be overwritten. Args: - _id (hashable): node id - attr (EdgeAttr): Edge attribute to fetch the value of + flag (traccuracy.EdgeFlag): The edge flag to set. Must be + of type EdgeFlag - you may not not pass strings, even if they + are included in the EdgeFlag enum values. + value (bool, optional): Flags can only be set to True or False. + Defaults to True. Raises: - ValueError: if attr is not a EdgeAttr - - Returns: - bool: The value of the attribute for that edge. If the attribute - is not present on the graph, the value is presumed False. + ValueError if the provided flag is not an EdgeFlag. """ - if not isinstance(attr, EdgeAttr): + if not isinstance(flag, EdgeFlag): raise ValueError( - f"Provided attribute {attr} is not of type EdgeAttr. " + f"Provided flag {flag} is not of type EdgeFlag. " "Please use the enum instead of passing string values, " "and add new attributes to the class to avoid key collision." ) + nx.set_edge_attributes(self.graph, value, name=flag) + if value: + self.edges_by_flag[flag].update(self.graph.edges) + else: + self.edges_by_flag[flag] = set() - if attr not in self.graph.edges[_id]: - return False - return self.graph.edges[_id][attr] - - def get_tracklets(self, include_division_edges: bool = False): + def get_tracklets( + self, include_division_edges: bool = False + ) -> list[TrackingGraph]: """Gets a list of new TrackingGraph objects containing all tracklets of the current graph. Tracklet is defined as all connected components between divisions (daughter to next @@ -636,7 +474,7 @@ def get_tracklets(self, include_division_edges: bool = False): # Remove all intertrack edges from a copy of the original graph removed_edges = [] for parent in self.get_divisions(): - for daughter in self.get_succs(parent): + for daughter in self.graph.successors(parent): graph_copy.remove_edge(parent, daughter) removed_edges.append((parent, daughter)) diff --git a/src/traccuracy/matchers/_base.py b/src/traccuracy/matchers/_base.py index f075aa96..d3056356 100644 --- a/src/traccuracy/matchers/_base.py +++ b/src/traccuracy/matchers/_base.py @@ -50,9 +50,9 @@ def compute_mapping( matched.matcher_info = self.info # Report matching performance - total_gt = len(matched.gt_graph.nodes()) + total_gt = len(matched.gt_graph.nodes) matched_gt = len({m[0] for m in matched.mapping}) - total_pred = len(matched.pred_graph.nodes()) + total_pred = len(matched.pred_graph.nodes) matched_pred = len({m[1] for m in matched.mapping}) logger.info(f"Matched {matched_gt} out of {total_gt} ground truth nodes.") logger.info(f"Matched {matched_pred} out of {total_pred} predicted nodes.") diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 23f02be8..43c33f8b 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -37,7 +37,7 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): traccuracy.matchers.Matched: Matched data object containing the CTC mapping Raises: - ValueError: GT and pred segmentations must be the same shape + ValueError: if GT and pred segmentations are None or are not the same shape """ gt = gt_graph pred = pred_graph @@ -46,6 +46,9 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): G_gt, mask_gt = gt, gt.segmentation G_pred, mask_pred = pred, pred.segmentation + if mask_gt is None or mask_pred is None: + raise ValueError("Segmentation is None, cannot perform matching") + if mask_gt.shape != mask_pred.shape: raise ValueError("Segmentation shapes must match between gt and pred") diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 7a48fbc4..5bcd7c39 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Hashable + import numpy as np from tqdm import tqdm @@ -46,6 +48,35 @@ def _match_nodes(gt, res, threshold=1): return gtcells, rescells +def _construct_time_to_seg_id_map( + graph: TrackingGraph, +) -> dict[int, dict[Hashable, Hashable]]: + """For each time frame in the graph, create a mapping from segmentation ids + (the ids in the segmentation array, stored in graph.label_key) to the + node ids (the ids of the TrackingGraph nodes). + + Args: + graph(TrackingGraph): a tracking graph with a label_key on each node + + Returns: + dict[int, dict[Hashable, Hashable]]: a dictionary from {time: {segmentation_id: node_id}} + + Raises: + AssertionError: If two nodes in a time frame have the same segmentation_id + """ + time_to_seg_id_map: dict[int, dict[Hashable, Hashable]] = {} + for node_id, data in graph.nodes(data=True): + time = data[graph.frame_key] + seg_id = data[graph.label_key] + seg_id_to_node_id_map = time_to_seg_id_map.get(time, {}) + assert ( + seg_id not in seg_id_to_node_id_map + ), f"Segmentation ID {seg_id} occurred twice in frame {time}." + seg_id_to_node_id_map[seg_id] = node_id + time_to_seg_id_map[time] = seg_id_to_node_id_map + return time_to_seg_id_map + + def match_iou(gt, pred, threshold=0.6): """Identifies pairs of cells between gt and pred that have iou > threshold @@ -78,24 +109,19 @@ def match_iou(gt, pred, threshold=0.6): # Get overlaps for each frame frame_range = range(gt.start_frame, gt.end_frame) total = len(list(frame_range)) + + gt_time_to_seg_id_map = _construct_time_to_seg_id_map(gt) + pred_time_to_seg_id_map = _construct_time_to_seg_id_map(pred) + for i, t in tqdm(enumerate(frame_range), desc="Matching frames", total=total): matches = _match_nodes( gt.segmentation[i], pred.segmentation[i], threshold=threshold ) - # Construct node id tuple for each match - for gt_id, pred_id in zip(*matches): + for gt_seg_id, pred_seg_id in zip(*matches): # Find node id based on time and segmentation label - gt_node = gt.get_nodes_with_attribute( - gt.label_key, - criterion=lambda x: x == gt_id, # noqa - limit_to=gt.get_nodes_in_frame(t), - )[0] - pred_node = pred.get_nodes_with_attribute( - pred.label_key, - criterion=lambda x: x == pred_id, # noqa - limit_to=pred.get_nodes_in_frame(t), - )[0] + gt_node = gt_time_to_seg_id_map[t][gt_seg_id] + pred_node = pred_time_to_seg_id_map[t][pred_seg_id] mapper.append((gt_node, pred_node)) return mapper diff --git a/src/traccuracy/metrics/_ctc.py b/src/traccuracy/metrics/_ctc.py index 71079469..9d40d4c0 100644 --- a/src/traccuracy/metrics/_ctc.py +++ b/src/traccuracy/metrics/_ctc.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from traccuracy._tracking_graph import EdgeAttr, NodeAttr +from traccuracy._tracking_graph import EdgeFlag, NodeFlag from traccuracy.track_errors._ctc import evaluate_ctc_events from ._base import Metric @@ -38,14 +38,14 @@ def _compute(self, data: Matched): evaluate_ctc_events(data) vertex_error_counts = { - "ns": len(data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)), - "fp": len(data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)), - "fn": len(data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)), + "ns": len(data.pred_graph.get_nodes_with_flag(NodeFlag.NON_SPLIT)), + "fp": len(data.pred_graph.get_nodes_with_flag(NodeFlag.FALSE_POS)), + "fn": len(data.gt_graph.get_nodes_with_flag(NodeFlag.FALSE_NEG)), } edge_error_counts = { - "ws": len(data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC)), - "fp": len(data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)), - "fn": len(data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)), + "ws": len(data.pred_graph.get_edges_with_flag(EdgeFlag.WRONG_SEMANTIC)), + "fp": len(data.pred_graph.get_edges_with_flag(EdgeFlag.FALSE_POS)), + "fn": len(data.gt_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG)), } error_sum = get_weighted_error_sum( vertex_error_counts, diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 9dd275e3..86636727 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -36,7 +36,7 @@ from typing import TYPE_CHECKING -from traccuracy._tracking_graph import NodeAttr +from traccuracy._tracking_graph import NodeFlag from traccuracy.track_errors.divisions import _evaluate_division_events from ._base import Metric @@ -90,15 +90,9 @@ def _compute(self, data: Matched): } def _calculate_metrics(self, g_gt, g_pred): - tp_division_count = len( - g_gt.get_nodes_with_attribute(NodeAttr.TP_DIV, lambda x: x) - ) - fn_division_count = len( - g_gt.get_nodes_with_attribute(NodeAttr.FN_DIV, lambda x: x) - ) - fp_division_count = len( - g_pred.get_nodes_with_attribute(NodeAttr.FP_DIV, lambda x: x) - ) + tp_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.TP_DIV)) + fn_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.FN_DIV)) + fp_division_count = len(g_pred.get_nodes_with_flag(NodeFlag.FP_DIV)) try: recall = tp_division_count / (tp_division_count + fn_division_count) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index bee72a8c..01a7c7bb 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -6,7 +6,7 @@ from tqdm import tqdm -from traccuracy._tracking_graph import EdgeAttr, NodeAttr +from traccuracy._tracking_graph import EdgeFlag, NodeFlag if TYPE_CHECKING: from traccuracy.matchers import Matched @@ -39,12 +39,12 @@ def get_vertex_errors(matched_data: Matched): logger.info("Node errors already calculated. Skipping graph annotation") return - comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.TRUE_POS, False) - comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.NON_SPLIT, False) + comp_graph.set_flag_on_all_nodes(NodeFlag.TRUE_POS, False) + comp_graph.set_flag_on_all_nodes(NodeFlag.NON_SPLIT, False) # will flip this when we come across the vertex in the mapping - comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.FALSE_POS, True) - gt_graph.set_node_attribute(list(gt_graph.nodes()), NodeAttr.FALSE_NEG, True) + comp_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, True) + gt_graph.set_flag_on_all_nodes(NodeFlag.FALSE_NEG, True) # we need to know how many computed vertices are "non-split", so we make # a mapping of gt vertices to their matched comp vertices @@ -57,15 +57,16 @@ def get_vertex_errors(matched_data: Matched): gt_ids = dict_mapping[pred_id] if len(gt_ids) == 1: gid = gt_ids[0] - comp_graph.set_node_attribute(pred_id, NodeAttr.TRUE_POS, True) - comp_graph.set_node_attribute(pred_id, NodeAttr.FALSE_POS, False) - gt_graph.set_node_attribute(gid, NodeAttr.FALSE_NEG, False) + comp_graph.set_flag_on_node(pred_id, NodeFlag.TRUE_POS, True) + comp_graph.set_flag_on_node(pred_id, NodeFlag.FALSE_POS, False) + gt_graph.set_flag_on_node(gid, NodeFlag.FALSE_NEG, False) elif len(gt_ids) > 1: - comp_graph.set_node_attribute(pred_id, NodeAttr.NON_SPLIT, True) - comp_graph.set_node_attribute(pred_id, NodeAttr.FALSE_POS, False) + comp_graph.set_flag_on_node(pred_id, NodeFlag.NON_SPLIT, True) + comp_graph.set_flag_on_node(pred_id, NodeFlag.FALSE_POS, False) # number of split operations that would be required to correct the vertices ns_count += len(gt_ids) - 1 - gt_graph.set_node_attribute(gt_ids, NodeAttr.FALSE_NEG, False) + for gt_id in gt_ids: + gt_graph.set_flag_on_node(gt_id, NodeFlag.FALSE_NEG, False) # Record presence of annotations on the TrackingGraph comp_graph.node_errors = True @@ -87,14 +88,12 @@ def get_edge_errors(matched_data: Matched): get_vertex_errors(matched_data) induced_graph = comp_graph.get_subgraph( - comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS) + comp_graph.get_nodes_with_flag(NodeFlag.TRUE_POS) ).graph - comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.FALSE_POS, False) - comp_graph.set_edge_attribute( - list(comp_graph.edges()), EdgeAttr.WRONG_SEMANTIC, False - ) - gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False) + comp_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, False) + comp_graph.set_flag_on_all_edges(EdgeFlag.WRONG_SEMANTIC, False) + gt_graph.set_flag_on_all_edges(EdgeFlag.FALSE_NEG, False) gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph} comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph} @@ -102,19 +101,17 @@ def get_edge_errors(matched_data: Matched): # intertrack edges = connection between parent and daughter for graph in [comp_graph, gt_graph]: # Set to False by default - graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False) + graph.set_flag_on_all_edges(EdgeFlag.INTERTRACK_EDGE, False) for parent in graph.get_divisions(): - for daughter in graph.get_succs(parent): - graph.set_edge_attribute( - (parent, daughter), EdgeAttr.INTERTRACK_EDGE, True + for daughter in graph.graph.successors(parent): + graph.set_flag_on_edge( + (parent, daughter), EdgeFlag.INTERTRACK_EDGE, True ) for merge in graph.get_merges(): - for parent in graph.get_preds(merge): - graph.set_edge_attribute( - (parent, merge), EdgeAttr.INTERTRACK_EDGE, True - ) + for parent in graph.graph.predecessors(merge): + graph.set_flag_on_edge((parent, merge), EdgeFlag.INTERTRACK_EDGE, True) # fp edges - edges in induced_graph that aren't in gt_graph for edge in tqdm(induced_graph.edges, "Evaluating FP edges"): @@ -124,24 +121,24 @@ def get_edge_errors(matched_data: Matched): target_gt_id = comp_gt_mapping[target] expected_gt_edge = (source_gt_id, target_gt_id) - if expected_gt_edge not in gt_graph.edges(): - comp_graph.set_edge_attribute(edge, EdgeAttr.FALSE_POS, True) + if expected_gt_edge not in gt_graph.edges: + comp_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_POS, True) else: # check if semantics are correct - is_parent_gt = gt_graph.edges()[expected_gt_edge][EdgeAttr.INTERTRACK_EDGE] - is_parent_comp = comp_graph.edges()[edge][EdgeAttr.INTERTRACK_EDGE] + is_parent_gt = gt_graph.edges[expected_gt_edge][EdgeFlag.INTERTRACK_EDGE] + is_parent_comp = comp_graph.edges[edge][EdgeFlag.INTERTRACK_EDGE] if is_parent_gt != is_parent_comp: - comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True) + comp_graph.set_flag_on_edge(edge, EdgeFlag.WRONG_SEMANTIC, True) # fn edges - edges in gt_graph that aren't in induced graph - for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"): + for edge in tqdm(gt_graph.edges, "Evaluating FN edges"): source, target = edge[0], edge[1] # this edge is adjacent to an edge we didn't detect, so it definitely is an fn if ( - gt_graph.nodes()[source][NodeAttr.FALSE_NEG] - or gt_graph.nodes()[target][NodeAttr.FALSE_NEG] + gt_graph.nodes[source][NodeFlag.FALSE_NEG] + or gt_graph.nodes[target][NodeFlag.FALSE_NEG] ): - gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) + gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True) continue source_comp_id = gt_comp_mapping[source] @@ -149,7 +146,7 @@ def get_edge_errors(matched_data: Matched): expected_comp_edge = (source_comp_id, target_comp_id) if expected_comp_edge not in induced_graph.edges: - gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) + gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True) gt_graph.edge_errors = True comp_graph.edge_errors = True diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py index 7ab3fc6b..a0c41b0e 100644 --- a/src/traccuracy/track_errors/divisions.py +++ b/src/traccuracy/track_errors/divisions.py @@ -6,7 +6,7 @@ from collections import Counter from typing import TYPE_CHECKING -from traccuracy._tracking_graph import NodeAttr +from traccuracy._tracking_graph import NodeFlag from traccuracy._utils import find_gt_node_matches, find_pred_node_matches if TYPE_CHECKING: @@ -62,29 +62,30 @@ def _find_pred_node_matches(pred_node): pred_node = _find_gt_node_matches(gt_node) # No matching node so division missed if pred_node is None: - g_gt.set_node_attribute(gt_node, NodeAttr.FN_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeFlag.FN_DIV, True) # Check if the division has the correct daughters else: - succ_gt = g_gt.get_succs(gt_node) + succ_gt = g_gt.graph.successors(gt_node) # Map pred succ nodes onto gt, unmapped nodes will return as None succ_pred = [ - _find_pred_node_matches(n) for n in g_pred.get_succs(pred_node) + _find_pred_node_matches(n) for n in g_pred.graph.successors(pred_node) ] # If daughters are same, division is correct if Counter(succ_gt) == Counter(succ_pred): - g_gt.set_node_attribute(gt_node, NodeAttr.TP_DIV, True) - g_pred.set_node_attribute(pred_node, NodeAttr.TP_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeFlag.TP_DIV, True) + g_pred.set_flag_on_node(pred_node, NodeFlag.TP_DIV, True) # If daughters are at all mismatched, division is false negative else: - g_gt.set_node_attribute(gt_node, NodeAttr.FN_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeFlag.FN_DIV, True) # Remove res division to record that we have classified it if pred_node in div_pred: div_pred.remove(pred_node) # Any remaining pred divisions are false positives - g_pred.set_node_attribute(div_pred, NodeAttr.FP_DIV, True) + for fp_div in div_pred: + g_pred.set_flag_on_node(fp_div, NodeFlag.FP_DIV, True) # Set division annotation flag g_gt.division_annotations = True @@ -106,7 +107,7 @@ def _get_pred_by_t(g, node, delta_frames): hashable: Node key of predecessor in target frame """ for _ in range(delta_frames): - nodes = g.get_preds(node) + nodes = list(g.graph.predecessors(node)) # Exit if there are no predecessors if len(nodes) == 0: return None @@ -132,7 +133,7 @@ def _get_succ_by_t(g, node, delta_frames): hashable: Node id of successor """ for _ in range(delta_frames): - nodes = g.get_succs(node) + nodes = list(g.graph.successors(node)) # Exit if there are no successors another division if len(nodes) == 0 or len(nodes) >= 2: return None @@ -170,8 +171,8 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): ): raise ValueError("Mapping must be one-to-one") - fp_divs = g_pred.get_nodes_with_flag(NodeAttr.FP_DIV) - fn_divs = g_gt.get_nodes_with_flag(NodeAttr.FN_DIV) + fp_divs = g_pred.get_nodes_with_flag(NodeFlag.FP_DIV) + fn_divs = g_gt.get_nodes_with_flag(NodeFlag.FN_DIV) # Compare all pairs of fp and fn for fp_node, fn_node in itertools.product(fp_divs, fn_divs): @@ -195,9 +196,9 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): # Check if daughters match fp_succ = [ _get_succ_by_t(g_pred, node, t_fn - t_fp) - for node in g_pred.get_succs(fp_node) + for node in g_pred.graph.successors(fp_node) ] - fn_succ = g_gt.get_succs(fn_node) + fn_succ = g_gt.graph.successors(fn_node) if Counter(fp_succ) != Counter(fn_succ): # Daughters don't match so division cannot match continue @@ -216,9 +217,9 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): # Check if daughters match fn_succ = [ _get_succ_by_t(g_gt, node, t_fp - t_fn) - for node in g_gt.get_succs(fn_node) + for node in g_gt.graph.successors(fn_node) ] - fp_succ = g_pred.get_succs(fp_node) + fp_succ = g_pred.graph.successors(fp_node) if Counter(fp_succ) != Counter(fn_succ): # Daughters don't match so division cannot match continue @@ -228,12 +229,12 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): if correct: # Remove error annotations from pred graph - g_pred.set_node_attribute(fp_node, NodeAttr.FP_DIV, False) - g_gt.set_node_attribute(fn_node, NodeAttr.FN_DIV, False) + g_pred.set_flag_on_node(fp_node, NodeFlag.FP_DIV, False) + g_gt.set_flag_on_node(fn_node, NodeFlag.FN_DIV, False) # Add the tp divisions annotations - g_gt.set_node_attribute(fn_node, NodeAttr.TP_DIV, True) - g_pred.set_node_attribute(fp_node, NodeAttr.TP_DIV, True) + g_gt.set_flag_on_node(fn_node, NodeFlag.TP_DIV, True) + g_pred.set_flag_on_node(fp_node, NodeFlag.TP_DIV, True) return new_matched diff --git a/tests/matchers/test_iou.py b/tests/matchers/test_iou.py index edf942b7..d1133a0d 100644 --- a/tests/matchers/test_iou.py +++ b/tests/matchers/test_iou.py @@ -2,7 +2,12 @@ import numpy as np import pytest from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers._iou import IOUMatcher, _match_nodes, match_iou +from traccuracy.matchers._iou import ( + IOUMatcher, + _construct_time_to_seg_id_map, + _match_nodes, + match_iou, +) from tests.test_utils import get_annotated_image, get_movie_with_graph @@ -21,6 +26,24 @@ def test__match_nodes(): gtcells, rescells = _match_nodes(y1, y2) +def test__construct_time_to_seg_id_map(): + # Test 2d data + n_frames = 3 + n_labels = 3 + track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels) + time_to_seg_id_map = _construct_time_to_seg_id_map(track_graph) + for t in range(n_frames): + for i in range(1, n_labels): + assert time_to_seg_id_map[t][i] == f"{i}_{t}" + + # Test 3d data + track_graph = get_movie_with_graph(ndims=4, n_frames=n_frames, n_labels=n_labels) + time_to_seg_id_map = _construct_time_to_seg_id_map(track_graph) + for t in range(n_frames): + for i in range(1, n_labels): + assert time_to_seg_id_map[t][i] == f"{i}_{t}" + + def test_match_iou(): # Bad input with pytest.raises(ValueError): diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 2ee56e01..045584a4 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -2,7 +2,7 @@ import networkx as nx import pytest -from traccuracy import EdgeAttr, NodeAttr, TrackingGraph +from traccuracy import EdgeFlag, NodeFlag, TrackingGraph @pytest.fixture @@ -127,83 +127,37 @@ def test_constructor(nx_comp1): with pytest.raises(AssertionError): TrackingGraph(nx_comp1, frame_key="f") with pytest.raises(ValueError): - TrackingGraph(nx_comp1, frame_key=NodeAttr.FALSE_NEG) + TrackingGraph(nx_comp1, frame_key=NodeFlag.FALSE_NEG) with pytest.raises(AssertionError): TrackingGraph(nx_comp1, location_keys=["x", "y", "z"]) with pytest.raises(ValueError): - TrackingGraph(nx_comp1, location_keys=["x", NodeAttr.FALSE_NEG]) + TrackingGraph(nx_comp1, location_keys=["x", NodeFlag.FALSE_NEG]) def test_get_cells_by_frame(simple_graph): - assert simple_graph.get_nodes_in_frame(0) == ["1_0"] + assert Counter(simple_graph.get_nodes_in_frame(0)) == Counter({"1_0"}) assert Counter(simple_graph.get_nodes_in_frame(2)) == Counter(["1_2", "1_3"]) - assert simple_graph.get_nodes_in_frame(5) == [] - - -def test_get_nodes_by_roi(simple_graph): - assert simple_graph.get_nodes_by_roi(t=(0, 1)) == ["1_0"] - assert Counter(simple_graph.get_nodes_by_roi(x=(1, None))) == Counter( - ["1_0", "1_1", "1_3", "1_4"] - ) - assert Counter(simple_graph.get_nodes_by_roi(x=(None, 2), t=(1, None))) == Counter( - ["1_1", "1_2"] - ) - - -def test_get_location(nx_comp1): - graph1 = TrackingGraph(nx_comp1, location_keys=["x", "y"]) - assert graph1.get_location("1_2") == [0, 1] - assert graph1.get_location("1_4") == [2, 1] - graph2 = TrackingGraph(nx_comp1, location_keys=["y", "x"]) - assert graph2.get_location("1_2") == [1, 0] - assert graph2.get_location("1_4") == [1, 2] + assert Counter(simple_graph.get_nodes_in_frame(5)) == Counter([]) def test_get_nodes_with_flag(simple_graph): - assert simple_graph.get_nodes_with_flag(NodeAttr.TP_DIV) == ["1_1"] - assert simple_graph.get_nodes_with_flag(NodeAttr.FP_DIV) == [] + assert Counter(simple_graph.get_nodes_with_flag(NodeFlag.TP_DIV)) == Counter( + ["1_1"] + ) + assert Counter(simple_graph.get_nodes_with_flag(NodeFlag.FP_DIV)) == Counter([]) with pytest.raises(ValueError): assert simple_graph.get_nodes_with_flag("is_tp_division") def test_get_edges_with_flag(simple_graph): - assert simple_graph.get_edges_with_flag(EdgeAttr.TRUE_POS) == [("1_0", "1_1")] - assert simple_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG) == [] + assert Counter(simple_graph.get_edges_with_flag(EdgeFlag.TRUE_POS)) == Counter( + [("1_0", "1_1")] + ) + assert Counter(simple_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG)) == Counter([]) with pytest.raises(ValueError): assert simple_graph.get_nodes_with_flag("is_tp") -def test_get_nodes_with_attribute(simple_graph): - assert simple_graph.get_nodes_with_attribute("is_tp_division") == ["1_1"] - assert simple_graph.get_nodes_with_attribute("null") == [] - assert simple_graph.get_nodes_with_attribute( - "is_tp_division", criterion=lambda x: x - ) == ["1_1"] - assert ( - simple_graph.get_nodes_with_attribute( - "is_tp_division", criterion=lambda x: not x - ) - == [] - ) - assert simple_graph.get_nodes_with_attribute("x", criterion=lambda x: x > 1) == [ - "1_3", - "1_4", - ] - assert simple_graph.get_nodes_with_attribute( - "x", criterion=lambda x: x > 1, limit_to=["1_3"] - ) == [ - "1_3", - ] - assert ( - simple_graph.get_nodes_with_attribute( - "x", criterion=lambda x: x > 1, limit_to=["1_0"] - ) - == [] - ) - with pytest.raises(KeyError): - simple_graph.get_nodes_with_attribute("x", limit_to=["5"]) - - def test_get_divisions(complex_graph): assert complex_graph.get_divisions() == ["1_1", "2_2"] @@ -212,23 +166,6 @@ def test_get_merges(merge_graph): assert merge_graph.get_merges() == ["3_2"] -def test_get_preds(simple_graph, merge_graph): - # Division graph - assert simple_graph.get_preds("1_0") == [] - assert simple_graph.get_preds("1_1") == ["1_0"] - assert simple_graph.get_preds("1_2") == ["1_1"] - - # Merge graph - assert merge_graph.get_preds("3_3") == ["3_2"] - assert merge_graph.get_preds("3_2") == ["3_1", "3_5"] - - -def test_get_succs(simple_graph): - assert simple_graph.get_succs("1_0") == ["1_1"] - assert Counter(simple_graph.get_succs("1_1")) == Counter(["1_2", "1_3"]) - assert simple_graph.get_succs("1_2") == [] - - def test_get_connected_components(complex_graph, nx_comp1, nx_comp2): tracks = complex_graph.get_connected_components() assert len(tracks) == 2 @@ -244,7 +181,7 @@ def test_get_connected_components(complex_graph, nx_comp1, nx_comp2): assert track2.graph.edges == nx_comp2.edges -def test_get_and_set_node_attributes(simple_graph): +def test_set_flag_on_node(simple_graph): assert simple_graph.nodes()["1_0"] == {"id": "1_0", "t": 0, "y": 1, "x": 1} assert simple_graph.nodes()["1_1"] == { "id": "1_1", @@ -254,26 +191,68 @@ def test_get_and_set_node_attributes(simple_graph): "is_tp_division": True, } - simple_graph.set_node_attribute("1_0", NodeAttr.FALSE_POS, value=False) + simple_graph.set_flag_on_node("1_0", NodeFlag.FALSE_POS, value=True) assert simple_graph.nodes()["1_0"] == { "id": "1_0", "t": 0, "y": 1, "x": 1, - NodeAttr.FALSE_POS: False, + NodeFlag.FALSE_POS: True, } + assert "1_0" in simple_graph.nodes_by_flag[NodeFlag.FALSE_POS] + + simple_graph.set_flag_on_node("1_0", NodeFlag.FALSE_POS, value=False) + assert simple_graph.nodes()["1_0"] == { + "id": "1_0", + "t": 0, + "y": 1, + "x": 1, + NodeFlag.FALSE_POS: False, + } + assert "1_0" not in simple_graph.nodes_by_flag[NodeFlag.FALSE_POS] + + simple_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, value=True) + for node in simple_graph.nodes: + assert simple_graph.nodes[node][NodeFlag.FALSE_POS] is True + assert Counter(simple_graph.nodes_by_flag[NodeFlag.FALSE_POS]) == Counter( + list(simple_graph.nodes()) + ) + + simple_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, value=False) + for node in simple_graph.nodes: + assert simple_graph.nodes[node][NodeFlag.FALSE_POS] is False + assert not simple_graph.nodes_by_flag[NodeFlag.FALSE_POS] + with pytest.raises(ValueError): - simple_graph.set_node_attribute("1_0", "x", 2) + simple_graph.set_flag_on_node("1_0", "x", 2) + +def test_set_flag_on_edge(simple_graph): + edge_id = ("1_1", "1_3") + assert EdgeFlag.TRUE_POS not in simple_graph.edges()[edge_id] + + simple_graph.set_flag_on_edge(edge_id, EdgeFlag.TRUE_POS, value=True) + assert simple_graph.edges()[edge_id][EdgeFlag.TRUE_POS] is True + assert edge_id in simple_graph.edges_by_flag[EdgeFlag.TRUE_POS] + + simple_graph.set_flag_on_edge(edge_id, EdgeFlag.TRUE_POS, value=False) + assert simple_graph.edges()[edge_id][EdgeFlag.TRUE_POS] is False + assert edge_id not in simple_graph.edges_by_flag[EdgeFlag.TRUE_POS] + + simple_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, value=True) + for edge in simple_graph.edges: + assert simple_graph.edges[edge][EdgeFlag.FALSE_POS] is True + assert Counter(simple_graph.edges_by_flag[EdgeFlag.FALSE_POS]) == Counter( + list(simple_graph.edges) + ) -def test_get_and_set_edge_attributes(simple_graph): - print(simple_graph.edges()) - assert EdgeAttr.TRUE_POS not in simple_graph.edges()[("1_1", "1_3")] + simple_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, value=False) + for edge in simple_graph.edges: + assert simple_graph.edges[edge][EdgeFlag.FALSE_POS] is False + assert not simple_graph.edges_by_flag[EdgeFlag.FALSE_POS] - simple_graph.set_edge_attribute(("1_1", "1_3"), EdgeAttr.TRUE_POS, value=False) - assert simple_graph.edges()[("1_1", "1_3")][EdgeAttr.TRUE_POS] is False with pytest.raises(ValueError): - simple_graph.set_edge_attribute(("1_1", "1_3"), "x", 2) + simple_graph.set_flag_on_edge(("1_1", "1_3"), "x", 2) def test_get_tracklets(simple_graph): diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index 53d4a7f7..a66517d0 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -1,6 +1,6 @@ import networkx as nx import numpy as np -from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph +from traccuracy._tracking_graph import EdgeFlag, NodeFlag, TrackingGraph from traccuracy.matchers import Matched from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors @@ -38,22 +38,22 @@ def test_get_vertex_errors(): get_vertex_errors(matched_data) - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)) == 1 - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.TRUE_POS)) == 3 - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)) == 2 - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)) == 3 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.NON_SPLIT)) == 1 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.TRUE_POS)) == 3 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.FALSE_POS)) == 2 + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.FALSE_NEG)) == 3 - assert matched_data.gt_graph.graph.nodes[15][NodeAttr.FALSE_NEG] - assert not matched_data.gt_graph.graph.nodes[17][NodeAttr.FALSE_NEG] + assert matched_data.gt_graph.nodes[15][NodeFlag.FALSE_NEG] + assert not matched_data.gt_graph.nodes[17][NodeFlag.FALSE_NEG] - assert matched_data.pred_graph.graph.nodes[3][NodeAttr.NON_SPLIT] - assert not matched_data.pred_graph.graph.nodes[7][NodeAttr.NON_SPLIT] + assert matched_data.pred_graph.nodes[3][NodeFlag.NON_SPLIT] + assert not matched_data.pred_graph.nodes[7][NodeFlag.NON_SPLIT] - assert matched_data.pred_graph.graph.nodes[7][NodeAttr.TRUE_POS] - assert not matched_data.pred_graph.graph.nodes[3][NodeAttr.TRUE_POS] + assert matched_data.pred_graph.nodes[7][NodeFlag.TRUE_POS] + assert not matched_data.pred_graph.nodes[3][NodeFlag.TRUE_POS] - assert matched_data.pred_graph.graph.nodes[10][NodeAttr.FALSE_POS] - assert not matched_data.pred_graph.graph.nodes[7][NodeAttr.FALSE_POS] + assert matched_data.pred_graph.nodes[10][NodeFlag.FALSE_POS] + assert not matched_data.pred_graph.nodes[7][NodeFlag.FALSE_POS] def test_assign_edge_errors(): @@ -72,7 +72,7 @@ def test_assign_edge_errors(): comp_g = nx.DiGraph() comp_g.add_nodes_from(comp_ids) comp_g.add_edges_from(comp_edges) - nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS) + nx.set_node_attributes(comp_g, True, NodeFlag.TRUE_POS) nx.set_node_attributes( comp_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids}, @@ -83,7 +83,7 @@ def test_assign_edge_errors(): gt_g = nx.DiGraph() gt_g.add_nodes_from(gt_ids) gt_g.add_edges_from(gt_edges) - nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG) + nx.set_node_attributes(gt_g, False, NodeFlag.FALSE_NEG) nx.set_node_attributes( gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids} ) @@ -93,8 +93,8 @@ def test_assign_edge_errors(): get_edge_errors(matched_data) - assert matched_data.pred_graph.graph.edges[(7, 8)][EdgeAttr.FALSE_POS] - assert matched_data.gt_graph.graph.edges[(17, 18)][EdgeAttr.FALSE_NEG] + assert matched_data.pred_graph.edges[(7, 8)][EdgeFlag.FALSE_POS] + assert matched_data.gt_graph.edges[(17, 18)][EdgeFlag.FALSE_NEG] def test_assign_edge_errors_semantics(): @@ -133,4 +133,4 @@ def test_assign_edge_errors_semantics(): get_edge_errors(matched_data) - assert matched_data.pred_graph.graph.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] + assert matched_data.pred_graph.edges[("1_2", "1_3")][EdgeFlag.WRONG_SEMANTIC] diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index 6538e644..609c8ed5 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -1,7 +1,7 @@ import networkx as nx import numpy as np import pytest -from traccuracy import NodeAttr, TrackingGraph +from traccuracy import NodeFlag, TrackingGraph from traccuracy.matchers import Matched from traccuracy.track_errors.divisions import ( _classify_divisions, @@ -51,10 +51,10 @@ def test_classify_divisions_tp(g): # Test true positive _classify_divisions(matched_data) - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert NodeAttr.TP_DIV in matched_data.gt_graph.nodes()["2_2"] - assert NodeAttr.TP_DIV in matched_data.pred_graph.nodes()["2_2"] + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.FN_DIV)) == 0 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.FP_DIV)) == 0 + assert NodeFlag.TP_DIV in matched_data.gt_graph.nodes["2_2"] + assert NodeFlag.TP_DIV in matched_data.pred_graph.nodes["2_2"] # Check division flag assert matched_data.gt_graph.division_annotations @@ -80,10 +80,10 @@ def test_classify_divisions_fp(g): _classify_divisions(matched_data) - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert NodeAttr.FP_DIV in matched_data.pred_graph.nodes()["1_2"] - assert NodeAttr.TP_DIV in matched_data.gt_graph.nodes()["2_2"] - assert NodeAttr.TP_DIV in matched_data.pred_graph.nodes()["2_2"] + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.FN_DIV)) == 0 + assert NodeFlag.FP_DIV in matched_data.pred_graph.nodes["1_2"] + assert NodeFlag.TP_DIV in matched_data.gt_graph.nodes["2_2"] + assert NodeFlag.TP_DIV in matched_data.pred_graph.nodes["2_2"] def test_classify_divisions_fn(g): @@ -100,9 +100,9 @@ def test_classify_divisions_fn(g): _classify_divisions(matched_data) - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 - assert NodeAttr.FN_DIV in matched_data.gt_graph.nodes()["2_2"] + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.FP_DIV)) == 0 + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.TP_DIV)) == 0 + assert NodeFlag.FN_DIV in matched_data.gt_graph.nodes["2_2"] @pytest.fixture @@ -160,8 +160,8 @@ class Test_correct_shifted_divisions: def test_no_change(self): # Early division in gt g_pred, g_gt, mapper = get_division_graphs() - g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True - g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True + g_gt.nodes["1_1"][NodeFlag.FN_DIV] = True + g_pred.nodes["1_3"][NodeFlag.FP_DIV] = True matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) @@ -170,15 +170,15 @@ def test_no_change(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes()["1_3"][NodeAttr.FP_DIV] is True - assert ng_gt.nodes()["1_1"][NodeAttr.FN_DIV] is True - assert len(ng_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 + assert ng_pred.nodes["1_3"][NodeFlag.FP_DIV] is True + assert ng_gt.nodes["1_1"][NodeFlag.FN_DIV] is True + assert len(ng_gt.get_nodes_with_flag(NodeFlag.TP_DIV)) == 0 def test_fn_early(self): # Early division in gt g_pred, g_gt, mapper = get_division_graphs() - g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True - g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True + g_gt.nodes["1_1"][NodeFlag.FN_DIV] = True + g_pred.nodes["1_3"][NodeFlag.FP_DIV] = True matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) @@ -187,16 +187,16 @@ def test_fn_early(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes()["1_3"][NodeAttr.FP_DIV] is False - assert ng_gt.nodes()["1_1"][NodeAttr.FN_DIV] is False - assert ng_pred.nodes()["1_3"][NodeAttr.TP_DIV] is True - assert ng_gt.nodes()["1_1"][NodeAttr.TP_DIV] is True + assert ng_pred.nodes["1_3"][NodeFlag.FP_DIV] is False + assert ng_gt.nodes["1_1"][NodeFlag.FN_DIV] is False + assert ng_pred.nodes["1_3"][NodeFlag.TP_DIV] is True + assert ng_gt.nodes["1_1"][NodeFlag.TP_DIV] is True def test_fp_early(self): # Early division in pred g_gt, g_pred, mapper = get_division_graphs() - g_pred.nodes["1_1"][NodeAttr.FP_DIV] = True - g_gt.nodes["1_3"][NodeAttr.FN_DIV] = True + g_pred.nodes["1_1"][NodeFlag.FP_DIV] = True + g_gt.nodes["1_3"][NodeFlag.FN_DIV] = True matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) @@ -205,10 +205,10 @@ def test_fp_early(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes()["1_1"][NodeAttr.FP_DIV] is False - assert ng_gt.nodes()["1_3"][NodeAttr.FN_DIV] is False - assert ng_pred.nodes()["1_1"][NodeAttr.TP_DIV] is True - assert ng_gt.nodes()["1_3"][NodeAttr.TP_DIV] is True + assert ng_pred.nodes["1_1"][NodeFlag.FP_DIV] is False + assert ng_gt.nodes["1_3"][NodeFlag.FN_DIV] is False + assert ng_pred.nodes["1_1"][NodeFlag.TP_DIV] is True + assert ng_gt.nodes["1_3"][NodeFlag.TP_DIV] is True def test_evaluate_division_events():