Skip to content

Commit

Permalink
Remove get_preds and get_succs
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Jan 10, 2024
1 parent 528323b commit f036f28
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 60 deletions.
34 changes: 1 addition & 33 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,38 +302,6 @@ def get_merges(self) -> list[Hashable]:
"""
return [node for node, degree in self.graph.in_degree() if degree >= 2]

def get_preds(self, node: Hashable) -> list[Hashable]:
"""Get all predecessors of the given node.
A predecessor node is any node from a previous time point that has an edge to
the given node. In a case where merges are not allowed, each node will have a
maximum of one predecessor.
Args:
node (hashable): A node id
Returns:
list of hashable: A list of node ids containing all nodes that
have an edge to the given node.
"""
return [pred for pred, _ in self.graph.in_edges(node)]

def get_succs(self, node: Hashable) -> list[Hashable]:
"""Get all successor nodes of the given node.
A successor node is any node from a later time point that has an edge
from the given node. In a case where divisions are not allowed,
a node will have a maximum of one successor.
Args:
node (hashable): A node id
Returns:
list of hashable: A list of node ids containing all nodes that have
an edge from the given node.
"""
return [succ for _, succ in self.graph.out_edges(node)]

def get_connected_components(self) -> list[TrackingGraph]:
"""Get a list of TrackingGraphs, each corresponding to one track
(i.e., a connected component in the track graph).
Expand Down Expand Up @@ -506,7 +474,7 @@ def get_tracklets(
# Remove all intertrack edges from a copy of the original graph
removed_edges = []
for parent in self.get_divisions():
for daughter in self.get_succs(parent):
for daughter in self.graph.successors(parent):
graph_copy.remove_edge(parent, daughter)
removed_edges.append((parent, daughter))

Expand Down
4 changes: 2 additions & 2 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ def get_edge_errors(matched_data: Matched):
graph.set_flag_on_all_edges(EdgeFlag.INTERTRACK_EDGE, False)

for parent in graph.get_divisions():
for daughter in graph.get_succs(parent):
for daughter in graph.graph.successors(parent):
graph.set_flag_on_edge(
(parent, daughter), EdgeFlag.INTERTRACK_EDGE, True
)

for merge in graph.get_merges():
for parent in graph.get_preds(merge):
for parent in graph.graph.predecessors(merge):
graph.set_flag_on_edge((parent, merge), EdgeFlag.INTERTRACK_EDGE, True)

# fp edges - edges in induced_graph that aren't in gt_graph
Expand Down
16 changes: 8 additions & 8 deletions src/traccuracy/track_errors/divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def _find_pred_node_matches(pred_node):
g_gt.set_flag_on_node(gt_node, NodeFlag.FN_DIV, True)
# Check if the division has the correct daughters
else:
succ_gt = g_gt.get_succs(gt_node)
succ_gt = g_gt.graph.successors(gt_node)
# Map pred succ nodes onto gt, unmapped nodes will return as None
succ_pred = [
_find_pred_node_matches(n) for n in g_pred.get_succs(pred_node)
_find_pred_node_matches(n) for n in g_pred.graph.successors(pred_node)
]

# If daughters are same, division is correct
Expand Down Expand Up @@ -107,7 +107,7 @@ def _get_pred_by_t(g, node, delta_frames):
hashable: Node key of predecessor in target frame
"""
for _ in range(delta_frames):
nodes = g.get_preds(node)
nodes = list(g.graph.predecessors(node))
# Exit if there are no predecessors
if len(nodes) == 0:
return None
Expand All @@ -133,7 +133,7 @@ def _get_succ_by_t(g, node, delta_frames):
hashable: Node id of successor
"""
for _ in range(delta_frames):
nodes = g.get_succs(node)
nodes = list(g.graph.successors(node))
# Exit if there are no successors another division
if len(nodes) == 0 or len(nodes) >= 2:
return None
Expand Down Expand Up @@ -196,9 +196,9 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1):
# Check if daughters match
fp_succ = [
_get_succ_by_t(g_pred, node, t_fn - t_fp)
for node in g_pred.get_succs(fp_node)
for node in g_pred.graph.successors(fp_node)
]
fn_succ = g_gt.get_succs(fn_node)
fn_succ = g_gt.graph.successors(fn_node)
if Counter(fp_succ) != Counter(fn_succ):
# Daughters don't match so division cannot match
continue
Expand All @@ -217,9 +217,9 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1):
# Check if daughters match
fn_succ = [
_get_succ_by_t(g_gt, node, t_fp - t_fn)
for node in g_gt.get_succs(fn_node)
for node in g_gt.graph.successors(fn_node)
]
fp_succ = g_pred.get_succs(fp_node)
fp_succ = g_pred.graph.successors(fp_node)
if Counter(fp_succ) != Counter(fn_succ):
# Daughters don't match so division cannot match
continue
Expand Down
17 changes: 0 additions & 17 deletions tests/test_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,6 @@ def test_get_merges(merge_graph):
assert merge_graph.get_merges() == ["3_2"]


def test_get_preds(simple_graph, merge_graph):
# Division graph
assert simple_graph.get_preds("1_0") == []
assert simple_graph.get_preds("1_1") == ["1_0"]
assert simple_graph.get_preds("1_2") == ["1_1"]

# Merge graph
assert merge_graph.get_preds("3_3") == ["3_2"]
assert merge_graph.get_preds("3_2") == ["3_1", "3_5"]


def test_get_succs(simple_graph):
assert simple_graph.get_succs("1_0") == ["1_1"]
assert Counter(simple_graph.get_succs("1_1")) == Counter(["1_2", "1_3"])
assert simple_graph.get_succs("1_2") == []


def test_get_connected_components(complex_graph, nx_comp1, nx_comp2):
tracks = complex_graph.get_connected_components()
assert len(tracks) == 2
Expand Down

0 comments on commit f036f28

Please sign in to comment.