Add SubvolumeProcessors for filling holes and merging internal objects.
PiperOrigin-RevId: 597614184
chinasaur authored and copybara-github committed Jan 11, 2024
1 parent 58ef059 commit 19e9ea5
Showing 7 changed files with 483 additions and 1 deletion.
9 changes: 9 additions & 0 deletions connectomics/common/counters.py
@@ -14,7 +14,9 @@
# limitations under the License.
"""Provides counters for monitoring processing."""

import contextlib
import threading
import time
from typing import Iterable


@@ -63,6 +65,13 @@ def get_counter(self, name: str) -> ThreadsafeCounter:
      self._counters[name] = counter
      return counter

  @contextlib.contextmanager
  def timer_counter(self, name: str):
    """Counts execution time in ms."""
    start = time.time()
    yield
    self.get_counter(name + '-ms').inc(int((time.time() - start) * 1e3))

  def get_nonzero(self) -> Iterable[tuple[str, ThreadsafeCounter]]:
    """Yields name, counter tuples for any counters with value > 0."""
    with self._lock:
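For context, a minimal usage sketch of the new timer_counter context manager above. It assumes the enclosing class is the counter store that exposes get_counter; the name ThreadsafeCounterStore is a guess, since the class statement falls outside the visible hunk.

import time
from connectomics.common import counters

store = counters.ThreadsafeCounterStore()  # class name assumed, not shown in hunk

with store.timer_counter('subvolume-processing'):
  time.sleep(0.01)  # stand-in for real work

# The elapsed time, in integer milliseconds, was added to the
# 'subvolume-processing-ms' counter.
for name, counter in store.get_nonzero():
  print(name)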
91 changes: 91 additions & 0 deletions connectomics/common/graph_utils.py
@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for manipulating graphs."""

import itertools
import networkx as nx


def _rindex(seq, item):
  """Returns the position of the last occurrence of `item` in `seq`."""
  for i, v in enumerate(reversed(seq)):
    if item == v:
      return len(seq) - i - 1
  raise ValueError('%r not in list' % item)


# In contrast to networkx/algorithms/components/biconnected, this uses rindex
# for ~20x faster execution. In the networkx implementation, computing BCCs
# is 55x slower than computing APs for a 256^3 segmentation subvolume RAG,
# due to repeated slow linear searches over a potentially large edge
# stack. Also returns both APs and BCCs from a single pass.
def biconnected_dfs(
    g: nx.Graph,
    start_points=None) -> tuple[list[frozenset[int]], frozenset[int]]:
  """Returns the biconnected components and articulation points of a graph."""
  visited = set()
  aps = set()
  bccs = []

  if start_points is None:
    start_points = []

  for start in itertools.chain(start_points, g):
    if start in visited:
      continue
    discovery = {start: 0}  # time of first discovery of node during search
    low = {start: 0}
    root_children = 0
    visited.add(start)
    edge_stack = []
    stack = [(start, start, iter(g[start]))]

    while stack:
      grandparent, parent, children = stack[-1]
      try:
        child = next(children)
        if grandparent == child:
          continue
        if child in visited:
          if discovery[child] <= discovery[parent]:  # back edge
            low[parent] = min(low[parent], discovery[child])
            # Record edge, but don't follow.
            edge_stack.append((parent, child))
        else:
          low[child] = discovery[child] = len(discovery)
          visited.add(child)
          stack.append((parent, child, iter(g[child])))
          edge_stack.append((parent, child))
      except StopIteration:
        stack.pop()
        if len(stack) > 1:
          if low[parent] >= discovery[grandparent]:
            ind = _rindex(edge_stack, (grandparent, parent))
            bccs.append(
                frozenset(itertools.chain.from_iterable(edge_stack[ind:])))
            edge_stack = edge_stack[:ind]
            aps.add(grandparent)
          low[grandparent] = min(low[parent], low[grandparent])
        elif stack:  # length 1 so grandparent is root
          root_children += 1
          ind = _rindex(edge_stack, (grandparent, parent))
          bccs.append(
              frozenset(itertools.chain.from_iterable(edge_stack[ind:])))

    # Root node is articulation point if it has more than 1 child.
    if root_children > 1:
      aps.add(start)

  return bccs, frozenset(aps)
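For reference, a small usage sketch of biconnected_dfs (the graph here is illustrative, not part of the commit). A 4-cycle shares node 0 with a pendant edge, so the graph has two biconnected components and a single articulation point:

import networkx as nx
from connectomics.common import graph_utils

g = nx.Graph([(0, 1), (1, 2), (2, 3), (3, 0), (0, 4)])
bccs, aps = graph_utils.biconnected_dfs(g)

# BCCs are returned as frozensets of nodes, not edges.
assert set(bccs) == {frozenset({0, 1, 2, 3}), frozenset({0, 4})}
assert aps == frozenset({0})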
98 changes: 97 additions & 1 deletion connectomics/segmentation/labels.py
@@ -15,9 +15,10 @@
"""Routines for manipulating numpy arrays of segmentation data."""

import collections
from typing import Iterable, Optional, Sequence
from typing import AbstractSet, Iterable, Optional, Sequence

import edt
import networkx as nx
import numpy as np
import skimage.measure
import skimage.morphology
@@ -233,3 +234,98 @@ def split_disconnected_components(labels: np.ndarray, connectivity=1):
  fixed_labels[...] += 1
  fixed_labels[labels == 0] = 0
  return np.cast[labels.dtype](fixed_labels)


def get_border_ids(vol3d: np.ndarray, inplane: bool = False) -> set[int]:
  """Finds ids of objects adjacent to the border of a 3d subvolume."""
  ret = (set(np.unique(vol3d[:, 0, :]))  #
         | set(np.unique(vol3d[:, -1, :]))  #
         | set(np.unique(vol3d[:, :, 0]))  #
         | set(np.unique(vol3d[:, :, -1])))
  if not inplane:
    ret |= set(np.unique(vol3d[0, :, :])) | set(np.unique(vol3d[-1, :, :]))
  return ret


def merge_internal_objects(bcc: list[AbstractSet[int]], aps: AbstractSet[int],
                           todo_bcc_idx: Iterable[int]) -> dict[int, int]:
  """Merges objects that are completely internal to other objects.

  Takes as input biconnected components (BCCs) and articulation points (APs)
  of a region adjacency graph (RAG) representing a segmentation.

  Args:
    bcc: list of sets of nodes of BCCs of the RAG
    aps: set of APs of the RAG
    todo_bcc_idx: indices in `bcc` for components that should be considered for
      merging

  Returns:
    map from BCC index to new label for the BCC
  """
  ap_to_bcc_idx = {}  # AP -> indices of BCCs they are a part of
  for ap in aps:
    ap_to_bcc_idx[ap] = {i for i, cc in enumerate(bcc) if ap in cc}

  ap_merge_forest = nx.DiGraph()
  to_merge = []

  while True:
    start_len = len(to_merge)
    remaining_bccs = []
    for cc_i in todo_bcc_idx:
      cc = bcc[cc_i]
      cc_aps = set(cc & aps)

      if len(cc_aps) == 1:
        # Direct merge of the BCC into the only AP that is part of it.
        to_merge.append(cc_i)
        cc_ap = cc_aps.pop()
        ap_to_bcc_idx[cc_ap].remove(cc_i)
      elif len([cc_ap for cc_ap in cc_aps
                if len(ap_to_bcc_idx[cc_ap]) > 1]) == 1:
        # Merge into an AP that is the only remaining AP that is part of
        # more than the current BCC.
        to_merge.append(cc_i)
        target = None
        for cc_ap in cc_aps:
          if len(ap_to_bcc_idx[cc_ap]) > 1:
            target = cc_ap
          ap_to_bcc_idx[cc_ap].remove(cc_i)

        assert target is not None
        for cc_ap in cc_aps:
          if cc_ap == target:
            continue
          ap_merge_forest.add_edge(target, cc_ap)
      else:
        # The current BCC cannot be merged in this iteration because it
        # still contains multiple APs that are part of more than 1 BCC.
        remaining_bccs.append(cc_i)

    todo_bcc_idx = remaining_bccs

    # Terminate if no merges were applied in the last iteration.
    if len(to_merge) == start_len:
      break

  # Build the AP relabel map by exploring the AP merge forest starting
  # from the target labels (roots).
  ap_relabel = {}
  roots = [n for n, deg in ap_merge_forest.in_degree if deg == 0]
  for root in roots:
    for n in nx.dfs_preorder_nodes(ap_merge_forest, source=root):
      ap_relabel[n] = root

  bcc_relabel = {}
  for bcc_i in to_merge:
    cc = bcc[bcc_i]
    adjacent_aps = cc & aps

    targets = set([ap_relabel.get(ap, ap) for ap in adjacent_aps])
    assert len(targets) == 1
    target = targets.pop()

    bcc_relabel[bcc_i] = target

  return bcc_relabel
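To see how the new pieces compose, here is a hedged end-to-end sketch: build the RAG for a subvolume, decompose it into BCCs and APs, and merge objects that do not touch the subvolume border. The border-filtering step mirrors what a hole-filling/merging processor might plausibly do; the actual SubvolumeProcessor wiring in this commit is in files not shown above.

import numpy as np
from connectomics.common import graph_utils
from connectomics.segmentation import labels
from connectomics.segmentation import rag

# Object 2 is completely enclosed by object 1, which floats in background 0.
seg = np.zeros((8, 8, 8), dtype=np.uint64)
seg[2:6, 2:6, 2:6] = 1
seg[3:5, 3:5, 3:5] = 2

g = rag.from_subvolume(seg)                 # RAG edges: (0, 1), (1, 2)
bccs, aps = graph_utils.biconnected_dfs(g)  # aps == frozenset({1})

# Only components that stay clear of the subvolume border are candidates.
border = labels.get_border_ids(seg)         # {0}
todo = [i for i, cc in enumerate(bccs) if not cc & border]

print(labels.merge_internal_objects(bccs, aps, todo))
# Maps the index of frozenset({1, 2}) in `bccs` to label 1, i.e. the
# internal object 2 merges into its enclosing object 1.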
37 changes: 37 additions & 0 deletions connectomics/segmentation/rag.py
@@ -19,6 +19,43 @@
from scipy import spatial


def from_subvolume(vol3d: np.ndarray) -> nx.Graph:
  """Returns the RAG for a 3d subvolume.

  Uses 6-connectivity to find neighbors. Only works for segmentations
  with IDs that fit in a uint32.

  Args:
    vol3d: 3d ndarray with the segmentation

  Returns:
    the corresponding RAG
  """
  assert np.max(vol3d) < 2**32

  # Looks for neighboring segments assuming 6-connectivity.
  seg_nbor_pairs = set()
  for dim in 0, 1, 2:
    sel_offset = [slice(None)] * 3
    sel_offset[dim] = np.s_[:-1]

    sel_base = [slice(None)] * 3
    sel_base[dim] = np.s_[1:]

    a = vol3d[tuple(sel_offset)].ravel()
    b = vol3d[tuple(sel_base)].ravel()
    x = a | (b << 32)
    x = x[a != b]
    unique_joint_labels = np.unique(x)

    seg_nbor_pairs |= set(
        zip(unique_joint_labels & 0xFFFFFFFF, unique_joint_labels >> 32))

  g = nx.Graph()
  g.add_edges_from(seg_nbor_pairs)
  return g


def from_set(kdts: dict[int, spatial._ckdtree.cKDTree]) -> nx.Graph:
  """Builds a RAG for a set of segments relying on their spatial proximity.
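A note on the packing step in from_subvolume above: each pair of face-adjacent labels is packed into a single 64-bit value, so deduplicating neighbor pairs reduces to one vectorized np.unique call. A standalone sketch with illustrative values:

import numpy as np

# Two flattened views of a volume, offset by one voxel along an axis.
a = np.array([1, 1, 2, 3], dtype=np.uint64)
b = np.array([3, 3, 3, 3], dtype=np.uint64)

x = a | (b << 32)   # pack each (a, b) pair into one uint64
x = x[a != b]       # drop same-segment pairs such as (3, 3)
pairs = np.unique(x)

print(list(zip(pairs & 0xFFFFFFFF, pairs >> 32)))  # [(1, 3), (2, 3)]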
14 changes: 14 additions & 0 deletions connectomics/segmentation/rag_test.py
@@ -17,6 +17,7 @@
from absl.testing import absltest
from connectomics.segmentation import rag
import networkx as nx
import numpy as np
from scipy import spatial


@@ -71,6 +72,19 @@ def test_from_set_skeletons(self):
    self.assertEqual(g.edges[2, 3]['idx'][2], 3)
    self.assertEqual(g.edges[2, 3]['idx'][3], 2)

  def test_from_subvolume(self):
    seg = np.zeros((10, 10, 2), dtype=np.uint64)
    seg[2:, :, 0] = 1
    seg[1:, 3:4, 1] = 3
    seg[1:, 5:6, 1] = 2
    seg[2:, 7:, 1] = 3

    result = rag.from_subvolume(seg)
    expected = nx.Graph()
    expected.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 3), (0, 3)])

    self.assertTrue(nx.is_isomorphic(result, expected))


if __name__ == '__main__':
  absltest.main()

