Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up accessing nodes/edges by attribute and node/edge attributes #63

Merged
merged 8 commits into from
Oct 27, 2023
186 changes: 133 additions & 53 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,42 @@ def __init__(
)
self.location_keys = location_keys

# Define empty attributes that will be set by update_graph
self.graph = None
self.nodes_by_frame = None
self.start_frame = None
self.end_frame = None
self.graph = graph

self._update_graph(graph)
# construct dictionaries from attributes to nodes/edges for easy lookup
self.nodes_by_frame = {}
self.nodes_by_flag = {flag: set() for flag in NodeAttr}
self.edges_by_flag = {flag: set() for flag in EdgeAttr}
for node, attrs in self.graph.nodes.items():
# check that every node has the time frame and location specified
assert (
self.frame_key in attrs.keys()
), f"Frame key {self.frame_key} not present for node {node}."
for key in self.location_keys:
assert (
key in attrs.keys()
), f"Location key {key} not present for node {node}."

# store node id in nodes_by_frame mapping
frame = attrs[self.frame_key]
if frame not in self.nodes_by_frame.keys():
self.nodes_by_frame[frame] = {node}
else:
self.nodes_by_frame[frame].add(node)
# store node id in nodes_by_flag mapping
for flag in NodeAttr:
if flag in attrs and attrs[flag]:
self.nodes_by_flag[flag].add(node)

# store edge id in edges_by_flag
for edge, attrs in self.graph.edges.items():
for flag in EdgeAttr:
if flag in attrs and attrs[flag]:
self.edges_by_flag[flag].add(edge)

# Store first and last frames for reference
self.start_frame = min(self.nodes_by_frame.keys())
self.end_frame = max(self.nodes_by_frame.keys()) + 1

# Record types of annotations that have been calculated
self.division_annotations = False
Expand Down Expand Up @@ -220,7 +249,7 @@ def get_nodes_in_frame(self, frame):
list of node_ids: A list of node ids for all nodes in frame.
"""
if frame in self.nodes_by_frame.keys():
return self.nodes_by_frame[frame]
return list(self.nodes_by_frame[frame])
else:
return []

Expand All @@ -247,12 +276,7 @@ def get_nodes_with_flag(self, attr):
"""
if not isinstance(attr, NodeAttr):
raise ValueError(f"Function takes NodeAttr arguments, not {type(attr)}.")
nodes_with_flag = [
node
for node, attrs in self.nodes().items()
if attr in attrs.keys() and attrs[attr] is True
]
return nodes_with_flag
return list(self.nodes_by_flag[attr])

def get_edges_with_flag(self, attr):
"""Get all edges with specified EdgeAttr set to True.
Expand All @@ -266,12 +290,7 @@ def get_edges_with_flag(self, attr):
"""
if not isinstance(attr, EdgeAttr):
raise ValueError(f"Function takes EdgeAttr arguments, not {type(attr)}.")
edges_with_flag = [
edge
for edge, attrs in self.edges().items()
if attr in attrs.keys() and attrs[attr] is True
]
return edges_with_flag
return list(self.edges_by_flag[attr])

def get_nodes_by_roi(self, **kwargs):
"""Gets the nodes in a given region of interest (ROI). The ROI is
Expand All @@ -289,6 +308,7 @@ def get_nodes_by_roi(self, **kwargs):
Returns:
list of hashable: A list of node_ids for all nodes in the ROI.
"""
frames = None
dimensions = []
for dim, limit in kwargs.items():
if not (dim == self.frame_key or dim in self.location_keys):
Expand All @@ -297,9 +317,25 @@ def get_nodes_by_roi(self, **kwargs):
f" {self.frame_key} or one of the location keys"
f" {self.location_keys}."
)
dimensions.append((dim, limit[0], limit[1]))
if dim == self.frame_key:
frames = list(limit)
else:
dimensions.append((dim, limit[0], limit[1]))
nodes = []
for node, attrs in self.graph.nodes().items():
if frames:
if frames[0] is None:
frames[0] = self.start_frame
if frames[1] is None:
frames[1] = self.end_frame
possible_nodes = []
for frame in range(frames[0], frames[1]):
if frame in self.nodes_by_frame:
possible_nodes.extend(self.nodes_by_frame[frame])
else:
possible_nodes = self.graph.nodes()

for node in possible_nodes:
attrs = self.graph.nodes[node]
inside = True
for dim, start, end in dimensions:
if start is not None and attrs[dim] < start:
Expand Down Expand Up @@ -447,41 +483,27 @@ def get_subgraph(self, nodes):

new_graph = self.graph.subgraph(nodes).copy()
new_trackgraph = copy.deepcopy(self)
new_trackgraph._update_graph(new_graph)

return new_trackgraph

def _update_graph(self, graph):
"""Given a new graph, which is expected to be a subgraph of the current graph,
update attributes which are dependent on the graph.

Args:
graph (nx.DiGraph): A networkx graph that is a subgraph of the original graph
"""
self.graph = graph
new_trackgraph.graph = new_graph
for frame, nodes_in_frame in self.nodes_by_frame.items():
new_nodes_in_frame = nodes_in_frame.intersection(nodes)
if new_nodes_in_frame:
new_trackgraph.nodes_by_frame[frame] = new_nodes_in_frame
else:
del new_trackgraph.nodes_by_frame[frame]

# construct a dictionary from frames to node ids for easy lookup
self.nodes_by_frame = {}
for node, attrs in self.graph.nodes.items():
# check that every node has the time frame and location specified
assert (
self.frame_key in attrs.keys()
), f"Frame key {self.frame_key} not present for node {node}."
for key in self.location_keys:
assert (
key in attrs.keys()
), f"Location key {key} not present for node {node}."
for attr in NodeAttr:
new_trackgraph.nodes_by_flag[attr] = self.nodes_by_flag[attr].intersection(
nodes
)
for attr in EdgeAttr:
new_trackgraph.edges_by_flag[attr] = self.edges_by_flag[attr].intersection(
nodes
)

# store node id in nodes_by_frame mapping
frame = attrs[self.frame_key]
if frame not in self.nodes_by_frame.keys():
self.nodes_by_frame[frame] = [node]
else:
self.nodes_by_frame[frame].append(node)
new_trackgraph.start_frame = min(new_trackgraph.nodes_by_frame.keys())
new_trackgraph.end_frame = max(new_trackgraph.nodes_by_frame.keys()) + 1

# Store first and last frames for reference
self.start_frame = min(self.nodes_by_frame.keys())
self.end_frame = max(self.nodes_by_frame.keys()) + 1
return new_trackgraph

def set_node_attribute(self, ids, attr, value=True):
"""Set an attribute flag for a set of nodes specified by
Expand All @@ -507,6 +529,10 @@ def set_node_attribute(self, ids, attr, value=True):
)
for _id in ids:
self.graph.nodes[_id][attr] = value
if value:
self.nodes_by_flag[attr].add(_id)
else:
self.nodes_by_flag[attr].discard(_id)

def set_edge_attribute(self, ids, attr, value=True):
"""Set an attribute flag for a set of edges specified by
Expand All @@ -532,6 +558,60 @@ def set_edge_attribute(self, ids, attr, value=True):
)
for _id in ids:
self.graph.edges[_id][attr] = value
if value:
self.edges_by_flag[attr].add(_id)
else:
self.edges_by_flag[attr].discard(_id)

def get_node_attribute(self, _id, attr):
"""Get the boolean value of a given attribute for a given node.

Args:
_id (hashable): node id
attr (NodeAttr): Node attribute to fetch the value of

Raises:
ValueError: if attr is not a NodeAttr

Returns:
bool: The value of the attribute for that node. If the attribute
is not present on the graph, the value is presumed False.
"""
if not isinstance(attr, NodeAttr):
raise ValueError(
f"Provided attribute {attr} is not of type NodeAttr. "
"Please use the enum instead of passing string values, "
"and add new attributes to the class to avoid key collision."
)

if attr not in self.graph.nodes[_id]:
return False
return self.graph.nodes[_id][attr]

def get_edge_attribute(self, _id, attr):
"""Get the boolean value of a given attribute for a given edge.

Args:
_id (hashable): node id
attr (EdgeAttr): Edge attribute to fetch the value of

Raises:
ValueError: if attr is not a EdgeAttr

Returns:
bool: The value of the attribute for that edge. If the attribute
is not present on the graph, the value is presumed False.
"""
if not isinstance(attr, EdgeAttr):
raise ValueError(
f"Provided attribute {attr} is not of type EdgeAttr. "
"Please use the enum instead of passing string values, "
"and add new attributes to the class to avoid key collision."
)

if attr not in self.graph.edges[_id]:
return False
return self.graph.edges[_id][attr]

def get_tracklets(self):
"""Gets a list of new TrackingGraph objects containing all tracklets of the current graph.
Expand Down
22 changes: 14 additions & 8 deletions tests/test_tracking_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import Counter

import networkx as nx
import pytest
from traccuracy import EdgeAttr, NodeAttr, TrackingGraph
Expand Down Expand Up @@ -114,10 +116,10 @@ def test_constructor(nx_comp1):
assert tracking_graph.start_frame == 0
assert tracking_graph.end_frame == 4
assert tracking_graph.nodes_by_frame == {
0: ["1_0"],
1: ["1_1"],
2: ["1_2", "1_3"],
3: ["1_4"],
0: {"1_0"},
1: {"1_1"},
2: {"1_2", "1_3"},
3: {"1_4"},
}

# raise AssertionError if frame key not present or ValueError if overlaps
Expand All @@ -134,14 +136,18 @@ def test_constructor(nx_comp1):

def test_get_cells_by_frame(simple_graph):
assert simple_graph.get_nodes_in_frame(0) == ["1_0"]
assert simple_graph.get_nodes_in_frame(2) == ["1_2", "1_3"]
assert Counter(simple_graph.get_nodes_in_frame(2)) == Counter(["1_2", "1_3"])
assert simple_graph.get_nodes_in_frame(5) == []


def test_get_nodes_by_roi(simple_graph):
assert simple_graph.get_nodes_by_roi(t=(0, 1)) == ["1_0"]
assert simple_graph.get_nodes_by_roi(x=(1, None)) == ["1_0", "1_1", "1_3", "1_4"]
assert simple_graph.get_nodes_by_roi(x=(None, 2), t=(1, None)) == ["1_1", "1_2"]
assert Counter(simple_graph.get_nodes_by_roi(x=(1, None))) == Counter(
["1_0", "1_1", "1_3", "1_4"]
)
assert Counter(simple_graph.get_nodes_by_roi(x=(None, 2), t=(1, None))) == Counter(
["1_1", "1_2"]
)


def test_get_location(nx_comp1):
Expand Down Expand Up @@ -219,7 +225,7 @@ def test_get_preds(simple_graph, merge_graph):

def test_get_succs(simple_graph):
assert simple_graph.get_succs("1_0") == ["1_1"]
assert simple_graph.get_succs("1_1") == ["1_2", "1_3"]
assert Counter(simple_graph.get_succs("1_1")) == Counter(["1_2", "1_3"])
assert simple_graph.get_succs("1_2") == []


Expand Down