Commit ba18cc3

Add module for correcting merge errors based on subcompartment consistency.

PiperOrigin-RevId: 597099188

1 parent: ca60d17
2 changed files: 330 additions, 0 deletions.
connectomics/segmentation/consistency.py (new file)
@@ -0,0 +1,245 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for computing skeleton consistency and optimal cuts."""

import collections
from typing import Optional, Sequence

import networkx as nx
import numpy as np

Node = int
Edge = tuple[Node, Node]
class NodeCounter:
  """ABC for counters used by CentripetalSkeletonConsistency."""

  def total_counts(self) -> np.ndarray:
    raise NotImplementedError

  def add_node(self, node: Node, counts: np.ndarray):
    raise NotImplementedError


class IndexCounter(NodeCounter):
  """Counter where the node label is used to index into the counter vector."""

  def __init__(self, nx_skeleton: nx.Graph, node_property_name: str):
    self._nx_skeleton = nx_skeleton
    self._node_property_name = node_property_name

  def total_counts(self) -> np.ndarray:
    """Counts node labels for the entire nx_skeleton.

    Returns:
      ndarray of counts with length equal to max label + 1.
    """
    node_labels = [
        v for _, v in self._nx_skeleton.nodes(data=self._node_property_name)
    ]
    labels, counts = np.unique(node_labels, return_counts=True)
    total_counts = np.zeros(max(labels) + 1, dtype=np.int64)
    for label, count in zip(labels, counts):
      total_counts[label] = count
    return total_counts

  def add_node(self, node: Node, counts: np.ndarray):
    """Increments the count for the label of the given node."""
    label = self._nx_skeleton.nodes[node][self._node_property_name]
    counts[label] += 1
class VectorCounter(NodeCounter):
  """Counter for nodes that already hold vector counts / probabilities."""

  def __init__(self, nx_skeleton: nx.Graph, node_property_name: str):
    self._nx_skeleton = nx_skeleton
    self._node_property_name = node_property_name

  def total_counts(self) -> np.ndarray:
    node_counts_or_probabilities = [
        v for _, v in self._nx_skeleton.nodes(data=self._node_property_name)
    ]
    return np.sum(node_counts_or_probabilities, axis=0)

  def add_node(self, node: Node, counts: np.ndarray):
    counts += self._nx_skeleton.nodes[node][self._node_property_name]


class UnitCounter(NodeCounter):
  """Counter where every node simply gets a value of 1.

  This is used to count leaving nodes rather than to actually compute
  consistencies.
  """

  def __init__(self, nx_skeleton: nx.Graph):
    self._nx_skeleton = nx_skeleton

  def total_counts(self) -> np.ndarray:
    return np.array([len(self._nx_skeleton)])

  def add_node(self, unused_node: Node, counts: np.ndarray):
    counts[0] += 1
class CentripetalSkeletonConsistency(object):
  """Visits edges from leaf nodes in, tallying the consistency for each cut."""

  def __init__(self, nx_skeleton: nx.Graph, counter: NodeCounter,
               remain_sources: Sequence[Node] = ()):
    """Constructor.

    Args:
      nx_skeleton: NetworkX skeleton whose edges will be annotated in place.
      counter: NodeCounter used to tally node labels / counts / probabilities.
      remain_sources: A Sequence of Nodes; if given, the cut search starts as
        far away from these sources as possible. This makes it likely that
        cuts will be oriented with their leaving direction pointing away from
        remain_sources.

    Raises:
      ValueError: if the input nx_skeleton is not a single connected
        component.
    """
    if nx.number_connected_components(nx_skeleton) != 1:
      raise ValueError(
          'Skeleton consistency only works with a single connected component.')
    self._nx_skeleton = nx_skeleton
    self._counter = counter
    self._total_label_counts = self._counter.total_counts()
    self._remain_sources = remain_sources

  def _consistency(self, class_label_counts: np.ndarray) -> float:
    if class_label_counts.size == 0:
      return 0.0
    return float(class_label_counts.max())

  def init_consistency(self) -> float:
    """Gets the initial global consistency of nx_skeleton."""
    return self._consistency(self._total_label_counts)

  def _find_leaf_nodes(self) -> list[Node]:
    """Finds leaf nodes via BFS from any remain_sources.

    Returns:
      List of leaf nodes, in the order visited by a BFS starting from
      remain_sources. This order matters: it lets the cut search be biased to
      generate cuts whose leaving direction points away from remain_sources.
    """
    nx_skeleton = self._nx_skeleton
    to_visit = collections.deque(self._remain_sources)
    if not to_visit:  # No remain_sources; start anywhere.
      to_visit.append(next(iter(nx_skeleton)))

    visited = set()
    leaf_nodes = []
    while to_visit:
      node = to_visit.popleft()
      visited.add(node)
      if nx_skeleton.degree(node) == 1:
        leaf_nodes.append(node)
      unvisited_neighbors = set(nx_skeleton.neighbors(node)) - visited
      to_visit.extend(unvisited_neighbors)
    return leaf_nodes

  def _normalize_edge(self, edge: Edge) -> Edge:
    n0, n1 = edge
    return (n0, n1) if n0 < n1 else (n1, n0)
  def best_consistency_cut(self,
                           filter_func=None) -> tuple[Optional[Edge], float]:
    """Moves from leaf nodes in, annotating edges and tracking the best cut.

    Annotates 'leaving_counts' and 'leaving_direction_node' on each edge as it
    moves in. The leaving_counts are then used to compute the consistency of
    the 'leaving' branch and the 'remaining' branch, and thus the global
    post-cut consistency.

    The algorithm starts at leaf nodes and moves in until it reaches a branch
    point. It cannot move in from a branch point until all adjacent edges but
    one have been visited. The leaf nodes are used in reverse order, and
    traversal is by DFS; this makes it likely that leaving directions are
    oriented away from any remain_sources.

    Args:
      filter_func: Optional filter function accepting leaving_counts,
        cut_consistency, and init_consistency keyword arguments. If given,
        cuts for which the filter function returns False are not considered
        for the best cut.

    Returns:
      (best_cut, best_cut_consistency)
      best_cut: (node0, node1) edge identifying the best consistency cut, or
        None if no good cut is found.
      best_cut_consistency: float giving the consistency of the best cut, or
        init_consistency if no good cut is found.

    Raises:
      networkx.HasACycle: if the algorithm detects a cycle due to failure to
        visit all edges.
      ValueError: if the algorithm fails to visit all edges without detecting
        a cycle. This can occur if the graph has been previously traversed and
        leaving_counts are already marked on edges.
    """
    init_consistency = self.init_consistency()
    best_cut = None, init_consistency
    start_nodes = self._find_leaf_nodes()
    if not start_nodes:
      return best_cut
    edges_visited = 0
    while start_nodes:
      start_node = start_nodes.pop()
      edges = self._nx_skeleton.edges(nbunch=start_node, data='leaving_counts')
      unvisited_edges = [edge for edge in edges if edge[2] is None]
      if not unvisited_edges:
        continue  # This node is done.
      if len(unvisited_edges) > 1:
        # With more than one unvisited edge out, we can't move in from here
        # at this time. Wait for this node to be added back into start_nodes
        # later.
        continue
      edge_to_visit = unvisited_edges[0][:2]

      # Check whether cutting here is the new best.
      summed_counts = np.zeros_like(self._total_label_counts)
      for edge in edges:
        if edge[2] is not None:
          summed_counts += edge[2]
      self._counter.add_node(start_node, summed_counts)
      remain_counts = self._total_label_counts - summed_counts
      leave_consistency = self._consistency(summed_counts)
      remain_consistency = self._consistency(remain_counts)
      consistency = leave_consistency + remain_consistency
      if consistency > best_cut[1]:
        if filter_func is None or filter_func(
            leaving_counts=summed_counts, cut_consistency=consistency,
            init_consistency=init_consistency):
          best_cut = self._normalize_edge(edge_to_visit), consistency

      # Mark this edge visited and add the next node.
      data = self._nx_skeleton.edges[edge_to_visit]
      data['leaving_counts'] = summed_counts
      data['leaving_direction_node'] = start_node  # Not used here, but useful.
      n0, n1 = edge_to_visit
      start_nodes.append(n0 if n0 != start_node else n1)
      edges_visited += 1

    # If there are still unvisited edges, we might have gotten stuck at a
    # cycle, which prevents len(unvisited_edges) == 1 above.
    if edges_visited < self._nx_skeleton.number_of_edges():
      if nx.cycle_basis(self._nx_skeleton):
        raise nx.HasACycle('Failed to visit all edges; cycle detected.')
      raise ValueError('Failed to visit all edges; graph previously traversed?')
    return best_cut
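A minimal usage sketch, mirroring the first test below. The module path connectomics.segmentation.consistency follows the test file's import; min_size_filter is a hypothetical filter shown only to illustrate the keyword arguments that best_consistency_cut passes to filter_func.

import networkx as nx

from connectomics.segmentation import consistency

# Path skeleton whose two halves carry different subcompartment labels.
skel = nx.path_graph(10)
for i in range(10):
  skel.nodes[i]['class_label'] = 1 if i < 5 else 2

# remain_sources=[9] biases the search to start from node 0, so leaving
# directions point away from node 9.
csc = consistency.CentripetalSkeletonConsistency(
    skel, consistency.IndexCounter(skel, 'class_label'), remain_sources=[9])


def min_size_filter(leaving_counts, cut_consistency, init_consistency):
  # Hypothetical policy: only accept cuts that improve on the uncut
  # consistency and would split off at least two nodes.
  return cut_consistency > init_consistency and leaving_counts.sum() >= 2


best_cut, best_cut_consistency = csc.best_consistency_cut(
    filter_func=min_size_filter)
# For this labeling: best_cut == (4, 5), best_cut_consistency == 10.0.

Note that the traversal annotates edges in place, so a second call on the same graph raises ValueError; build a fresh skeleton (or clear the edge attributes) before re-running.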
connectomics/segmentation/consistency_test.py (new file)
@@ -0,0 +1,85 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for connectomics.segmentation.consistency."""

from absl.testing import absltest
from connectomics.segmentation import consistency
import networkx as nx
import numpy as np
class CountersTest(absltest.TestCase):

  def test_index_counter(self):
    path10 = nx.generators.path_graph(10)
    for i in range(10):
      path10.nodes[i]['class_label'] = i // 2
    index_counter = consistency.IndexCounter(path10, 'class_label')
    total_counts = index_counter.total_counts()
    np.testing.assert_equal(total_counts, np.ones(5) * 2)
    self.assertEqual(total_counts.dtype, np.int64)

    counts = np.zeros(5, dtype=np.int64)
    expected = counts.copy()
    expected[2] = 1
    index_counter.add_node(5, counts)
    np.testing.assert_equal(counts, expected)

  def test_vector_counter(self):
    path10 = nx.generators.path_graph(10)
    for i in range(10):
      path10.nodes[i]['class_probabilities'] = [0.9, 0.1]
    vector_counter = consistency.VectorCounter(path10, 'class_probabilities')
    total_counts = vector_counter.total_counts()
    np.testing.assert_almost_equal(total_counts, np.array([9.0, 1.0]))
class ConsistencyTest(absltest.TestCase):

  def test_class_label_best_consistency_cut(self):
    path10 = nx.generators.path_graph(10)
    for i in range(5):
      path10.nodes[i]['class_label'] = 1
    for i in range(5, 10):
      path10.nodes[i]['class_label'] = 2
    # Use remain_sources=[9] to make the cut search start from node 0.
    csc = consistency.CentripetalSkeletonConsistency(
        path10, consistency.IndexCounter(path10, 'class_label'),
        remain_sources=[9])
    self.assertEqual(csc.init_consistency(), 5.0)
    best_cut, best_cut_consistency = csc.best_consistency_cut()

    self.assertCountEqual(best_cut, (4, 5))
    self.assertEqual(best_cut_consistency, 10.0)
    np.testing.assert_equal(path10.edges[(1, 2)]['leaving_counts'],
                            np.array([0, 2, 0]))
    self.assertEqual(path10.edges[(1, 2)]['leaving_direction_node'], 1)

  def test_class_probability_best_consistency_cut(self):
    path10 = nx.generators.path_graph(10)
    for i in range(5):
      path10.nodes[i]['class_probability'] = 0.9, 0.1
    for i in range(5, 10):
      path10.nodes[i]['class_probability'] = 0.1, 0.9
    csc = consistency.CentripetalSkeletonConsistency(
        path10, consistency.VectorCounter(path10, 'class_probability'))
    self.assertAlmostEqual(csc.init_consistency(), 5.0)
    best_cut, best_cut_consistency = csc.best_consistency_cut()
    self.assertCountEqual(best_cut, (4, 5))
    self.assertAlmostEqual(best_cut_consistency, 9.0)


if __name__ == '__main__':
  absltest.main()
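A sketch of how UnitCounter can be driven, under the reading suggested by its docstring. With UnitCounter, every candidate cut's leave + remain tally sums to the total node count, so no cut ever beats the initial consistency and best_consistency_cut returns (None, total). The useful output is the in-place annotation: after the traversal, each edge's 'leaving_counts' holds the number of nodes on its leaving side. The side-branch topology below is an arbitrary illustration.

import networkx as nx

from connectomics.segmentation import consistency

# Path skeleton with a two-node side branch hanging off node 6.
skel = nx.path_graph(10)
skel.add_edges_from([(6, 10), (10, 11)])

csc = consistency.CentripetalSkeletonConsistency(
    skel, consistency.UnitCounter(skel), remain_sources=[0])

# Returns (None, 12.0) here: a UnitCounter cut can never improve consistency.
csc.best_consistency_cut()

# The edge annotations now give leaving-branch sizes.
print(skel.edges[(6, 10)]['leaving_counts'])   # [2]: nodes 10 and 11.
print(skel.edges[(10, 11)]['leaving_counts'])  # [1]: node 11.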