Skip to content

Commit

Permalink
Merge pull request #63 from Janelia-Trackathon-2023/access_by_attr
Browse files Browse the repository at this point in the history
Speed up accessing nodes/edges by attribute and node/edge attributes
  • Loading branch information
cmalinmayor authored Oct 27, 2023
2 parents 082e010 + a680f4c commit 674c0f8
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 61 deletions.
186 changes: 133 additions & 53 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,42 @@ def __init__(
)
self.location_keys = location_keys

# Define empty attributes that will be set by update_graph
self.graph = None
self.nodes_by_frame = None
self.start_frame = None
self.end_frame = None
self.graph = graph

self._update_graph(graph)
# construct dictionaries from attributes to nodes/edges for easy lookup
self.nodes_by_frame = {}
self.nodes_by_flag = {flag: set() for flag in NodeAttr}
self.edges_by_flag = {flag: set() for flag in EdgeAttr}
for node, attrs in self.graph.nodes.items():
# check that every node has the time frame and location specified
assert (
self.frame_key in attrs.keys()
), f"Frame key {self.frame_key} not present for node {node}."
for key in self.location_keys:
assert (
key in attrs.keys()
), f"Location key {key} not present for node {node}."

# store node id in nodes_by_frame mapping
frame = attrs[self.frame_key]
if frame not in self.nodes_by_frame.keys():
self.nodes_by_frame[frame] = {node}
else:
self.nodes_by_frame[frame].add(node)
# store node id in nodes_by_flag mapping
for flag in NodeAttr:
if flag in attrs and attrs[flag]:
self.nodes_by_flag[flag].add(node)

# store edge id in edges_by_flag
for edge, attrs in self.graph.edges.items():
for flag in EdgeAttr:
if flag in attrs and attrs[flag]:
self.edges_by_flag[flag].add(edge)

# Store first and last frames for reference
self.start_frame = min(self.nodes_by_frame.keys())
self.end_frame = max(self.nodes_by_frame.keys()) + 1

# Record types of annotations that have been calculated
self.division_annotations = False
Expand Down Expand Up @@ -220,7 +249,7 @@ def get_nodes_in_frame(self, frame):
list of node_ids: A list of node ids for all nodes in frame.
"""
if frame in self.nodes_by_frame.keys():
return self.nodes_by_frame[frame]
return list(self.nodes_by_frame[frame])
else:
return []

Expand All @@ -247,12 +276,7 @@ def get_nodes_with_flag(self, attr):
"""
if not isinstance(attr, NodeAttr):
raise ValueError(f"Function takes NodeAttr arguments, not {type(attr)}.")
nodes_with_flag = [
node
for node, attrs in self.nodes().items()
if attr in attrs.keys() and attrs[attr] is True
]
return nodes_with_flag
return list(self.nodes_by_flag[attr])

def get_edges_with_flag(self, attr):
"""Get all edges with specified EdgeAttr set to True.
Expand All @@ -266,12 +290,7 @@ def get_edges_with_flag(self, attr):
"""
if not isinstance(attr, EdgeAttr):
raise ValueError(f"Function takes EdgeAttr arguments, not {type(attr)}.")
edges_with_flag = [
edge
for edge, attrs in self.edges().items()
if attr in attrs.keys() and attrs[attr] is True
]
return edges_with_flag
return list(self.edges_by_flag[attr])

def get_nodes_by_roi(self, **kwargs):
"""Gets the nodes in a given region of interest (ROI). The ROI is
Expand All @@ -289,6 +308,7 @@ def get_nodes_by_roi(self, **kwargs):
Returns:
list of hashable: A list of node_ids for all nodes in the ROI.
"""
frames = None
dimensions = []
for dim, limit in kwargs.items():
if not (dim == self.frame_key or dim in self.location_keys):
Expand All @@ -297,9 +317,25 @@ def get_nodes_by_roi(self, **kwargs):
f" {self.frame_key} or one of the location keys"
f" {self.location_keys}."
)
dimensions.append((dim, limit[0], limit[1]))
if dim == self.frame_key:
frames = list(limit)
else:
dimensions.append((dim, limit[0], limit[1]))
nodes = []
for node, attrs in self.graph.nodes().items():
if frames:
if frames[0] is None:
frames[0] = self.start_frame
if frames[1] is None:
frames[1] = self.end_frame
possible_nodes = []
for frame in range(frames[0], frames[1]):
if frame in self.nodes_by_frame:
possible_nodes.extend(self.nodes_by_frame[frame])
else:
possible_nodes = self.graph.nodes()

for node in possible_nodes:
attrs = self.graph.nodes[node]
inside = True
for dim, start, end in dimensions:
if start is not None and attrs[dim] < start:
Expand Down Expand Up @@ -447,41 +483,27 @@ def get_subgraph(self, nodes):

new_graph = self.graph.subgraph(nodes).copy()
new_trackgraph = copy.deepcopy(self)
new_trackgraph._update_graph(new_graph)

return new_trackgraph

def _update_graph(self, graph):
"""Given a new graph, which is expected to be a subgraph of the current graph,
update attributes which are dependent on the graph.
Args:
graph (nx.DiGraph): A networkx graph that is a subgraph of the original graph
"""
self.graph = graph
new_trackgraph.graph = new_graph
for frame, nodes_in_frame in self.nodes_by_frame.items():
new_nodes_in_frame = nodes_in_frame.intersection(nodes)
if new_nodes_in_frame:
new_trackgraph.nodes_by_frame[frame] = new_nodes_in_frame
else:
del new_trackgraph.nodes_by_frame[frame]

# construct a dictionary from frames to node ids for easy lookup
self.nodes_by_frame = {}
for node, attrs in self.graph.nodes.items():
# check that every node has the time frame and location specified
assert (
self.frame_key in attrs.keys()
), f"Frame key {self.frame_key} not present for node {node}."
for key in self.location_keys:
assert (
key in attrs.keys()
), f"Location key {key} not present for node {node}."
for attr in NodeAttr:
new_trackgraph.nodes_by_flag[attr] = self.nodes_by_flag[attr].intersection(
nodes
)
for attr in EdgeAttr:
new_trackgraph.edges_by_flag[attr] = self.edges_by_flag[attr].intersection(
nodes
)

# store node id in nodes_by_frame mapping
frame = attrs[self.frame_key]
if frame not in self.nodes_by_frame.keys():
self.nodes_by_frame[frame] = [node]
else:
self.nodes_by_frame[frame].append(node)
new_trackgraph.start_frame = min(new_trackgraph.nodes_by_frame.keys())
new_trackgraph.end_frame = max(new_trackgraph.nodes_by_frame.keys()) + 1

# Store first and last frames for reference
self.start_frame = min(self.nodes_by_frame.keys())
self.end_frame = max(self.nodes_by_frame.keys()) + 1
return new_trackgraph

def set_node_attribute(self, ids, attr, value=True):
"""Set an attribute flag for a set of nodes specified by
Expand All @@ -507,6 +529,10 @@ def set_node_attribute(self, ids, attr, value=True):
)
for _id in ids:
self.graph.nodes[_id][attr] = value
if value:
self.nodes_by_flag[attr].add(_id)
else:
self.nodes_by_flag[attr].discard(_id)

def set_edge_attribute(self, ids, attr, value=True):
"""Set an attribute flag for a set of edges specified by
Expand All @@ -532,6 +558,60 @@ def set_edge_attribute(self, ids, attr, value=True):
)
for _id in ids:
self.graph.edges[_id][attr] = value
if value:
self.edges_by_flag[attr].add(_id)
else:
self.edges_by_flag[attr].discard(_id)

def get_node_attribute(self, _id, attr):
"""Get the boolean value of a given attribute for a given node.
Args:
_id (hashable): node id
attr (NodeAttr): Node attribute to fetch the value of
Raises:
ValueError: if attr is not a NodeAttr
Returns:
bool: The value of the attribute for that node. If the attribute
is not present on the graph, the value is presumed False.
"""
if not isinstance(attr, NodeAttr):
raise ValueError(
f"Provided attribute {attr} is not of type NodeAttr. "
"Please use the enum instead of passing string values, "
"and add new attributes to the class to avoid key collision."
)

if attr not in self.graph.nodes[_id]:
return False
return self.graph.nodes[_id][attr]

def get_edge_attribute(self, _id, attr):
"""Get the boolean value of a given attribute for a given edge.
Args:
_id (hashable): node id
attr (EdgeAttr): Edge attribute to fetch the value of
Raises:
ValueError: if attr is not a EdgeAttr
Returns:
bool: The value of the attribute for that edge. If the attribute
is not present on the graph, the value is presumed False.
"""
if not isinstance(attr, EdgeAttr):
raise ValueError(
f"Provided attribute {attr} is not of type EdgeAttr. "
"Please use the enum instead of passing string values, "
"and add new attributes to the class to avoid key collision."
)

if attr not in self.graph.edges[_id]:
return False
return self.graph.edges[_id][attr]

def get_tracklets(self):
"""Gets a list of new TrackingGraph objects containing all tracklets of the current graph.
Expand Down
22 changes: 14 additions & 8 deletions tests/test_tracking_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import Counter

import networkx as nx
import pytest
from traccuracy import EdgeAttr, NodeAttr, TrackingGraph
Expand Down Expand Up @@ -114,10 +116,10 @@ def test_constructor(nx_comp1):
assert tracking_graph.start_frame == 0
assert tracking_graph.end_frame == 4
assert tracking_graph.nodes_by_frame == {
0: ["1_0"],
1: ["1_1"],
2: ["1_2", "1_3"],
3: ["1_4"],
0: {"1_0"},
1: {"1_1"},
2: {"1_2", "1_3"},
3: {"1_4"},
}

# raise AssertionError if frame key not present or ValueError if overlaps
Expand All @@ -134,14 +136,18 @@ def test_constructor(nx_comp1):

def test_get_cells_by_frame(simple_graph):
assert simple_graph.get_nodes_in_frame(0) == ["1_0"]
assert simple_graph.get_nodes_in_frame(2) == ["1_2", "1_3"]
assert Counter(simple_graph.get_nodes_in_frame(2)) == Counter(["1_2", "1_3"])
assert simple_graph.get_nodes_in_frame(5) == []


def test_get_nodes_by_roi(simple_graph):
assert simple_graph.get_nodes_by_roi(t=(0, 1)) == ["1_0"]
assert simple_graph.get_nodes_by_roi(x=(1, None)) == ["1_0", "1_1", "1_3", "1_4"]
assert simple_graph.get_nodes_by_roi(x=(None, 2), t=(1, None)) == ["1_1", "1_2"]
assert Counter(simple_graph.get_nodes_by_roi(x=(1, None))) == Counter(
["1_0", "1_1", "1_3", "1_4"]
)
assert Counter(simple_graph.get_nodes_by_roi(x=(None, 2), t=(1, None))) == Counter(
["1_1", "1_2"]
)


def test_get_location(nx_comp1):
Expand Down Expand Up @@ -219,7 +225,7 @@ def test_get_preds(simple_graph, merge_graph):

def test_get_succs(simple_graph):
assert simple_graph.get_succs("1_0") == ["1_1"]
assert simple_graph.get_succs("1_1") == ["1_2", "1_3"]
assert Counter(simple_graph.get_succs("1_1")) == Counter(["1_2", "1_3"])
assert simple_graph.get_succs("1_2") == []


Expand Down

0 comments on commit 674c0f8

Please sign in to comment.