diff --git a/connectomics/common/counters.py b/connectomics/common/counters.py
index 9136ebb..5d9e015 100644
--- a/connectomics/common/counters.py
+++ b/connectomics/common/counters.py
@@ -14,7 +14,9 @@
 # limitations under the License.
 """Provides counters for monitoring processing."""

+import contextlib
 import threading
+import time

 from typing import Iterable

@@ -63,6 +65,13 @@ def get_counter(self, name: str) -> ThreadsafeCounter:
       self._counters[name] = counter
       return counter

+  @contextlib.contextmanager
+  def timer_counter(self, name: str):
+    """Counts execution time in ms."""
+    start = time.time()
+    yield
+    self.get_counter(name + '-ms').inc(int((time.time() - start) * 1e3))
+
   def get_nonzero(self) -> Iterable[tuple[str, ThreadsafeCounter]]:
     """Yields name, counter tuples for any counters with value > 0."""
     with self._lock:
diff --git a/connectomics/common/graph_utils.py b/connectomics/common/graph_utils.py
new file mode 100644
index 0000000..7db0f13
--- /dev/null
+++ b/connectomics/common/graph_utils.py
@@ -0,0 +1,91 @@
+# coding=utf-8
+# Copyright 2024 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for manipulating graphs."""
+
+import itertools
+import networkx as nx
+
+
+def _rindex(seq, item):
+  """Returns the position of the last occurrence of `item` in `seq`."""
+  for i, v in enumerate(reversed(seq)):
+    if item == v:
+      return len(seq) - i - 1
+  raise ValueError('%r not in list' % item)
+
+
+# In contrast to networkx/algorithms/components/biconnected, this uses rindex
+# for ~20x faster execution. In the networkx implementation, computing BCCs
+# is 55x slower than computing APs for a 256^3 segmentation subvolume RAG,
+# due to repeated slow linear searches over a potentially large edge
+# stack. Also returns both APs and BCCs from a single pass.
+def biconnected_dfs(
+    g: nx.Graph,
+    start_points=None) -> tuple[list[frozenset[int]], frozenset[int]]:
+  """Returns the biconnected components and articulation points of a graph."""
+  visited = set()
+  aps = set()
+  bccs = []
+
+  if start_points is None:
+    start_points = []
+
+  for start in itertools.chain(start_points, g):
+    if start in visited:
+      continue
+    discovery = {start: 0}  # time of first discovery of node during search
+    low = {start: 0}
+    root_children = 0
+    visited.add(start)
+    edge_stack = []
+    stack = [(start, start, iter(g[start]))]
+
+    while stack:
+      grandparent, parent, children = stack[-1]
+      try:
+        child = next(children)
+        if grandparent == child:
+          continue
+        if child in visited:
+          if discovery[child] <= discovery[parent]:  # back edge
+            low[parent] = min(low[parent], discovery[child])
+            # Record edge, but don't follow.
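+            # The skipped back edge still lands on the edge stack, so the
+            # enclosing BCC will contain it when the stack is sliced below.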
+            edge_stack.append((parent, child))
+        else:
+          low[child] = discovery[child] = len(discovery)
+          visited.add(child)
+          stack.append((parent, child, iter(g[child])))
+          edge_stack.append((parent, child))
+      except StopIteration:
+        stack.pop()
+        if len(stack) > 1:
+          if low[parent] >= discovery[grandparent]:
+            ind = _rindex(edge_stack, (grandparent, parent))
+            bccs.append(
+                frozenset(itertools.chain.from_iterable(edge_stack[ind:])))
+            edge_stack = edge_stack[:ind]
+            aps.add(grandparent)
+          low[grandparent] = min(low[parent], low[grandparent])
+        elif stack:  # length 1 so grandparent is root
+          root_children += 1
+          ind = _rindex(edge_stack, (grandparent, parent))
+          bccs.append(
+              frozenset(itertools.chain.from_iterable(edge_stack[ind:])))
+
+    # Root node is articulation point if it has more than 1 child.
+    if root_children > 1:
+      aps.add(start)
+
+  return bccs, frozenset(aps)
diff --git a/connectomics/segmentation/labels.py b/connectomics/segmentation/labels.py
index 647886d..e435e39 100644
--- a/connectomics/segmentation/labels.py
+++ b/connectomics/segmentation/labels.py
@@ -15,9 +15,10 @@
 """Routines for manipulating numpy arrays of segmentation data."""

 import collections
-from typing import Iterable, Optional, Sequence
+from typing import AbstractSet, Iterable, Optional, Sequence

 import edt
+import networkx as nx
 import numpy as np
 import skimage.measure
 import skimage.morphology
@@ -233,3 +234,98 @@ def split_disconnected_components(labels: np.ndarray, connectivity=1):
   fixed_labels[...] += 1
   fixed_labels[labels == 0] = 0
   return np.cast[labels.dtype](fixed_labels)
+
+
+def get_border_ids(vol3d: np.ndarray, inplane: bool = False) -> set[int]:
+  """Finds ids of objects adjacent to the border of a 3d subvolume."""
+  ret = (set(np.unique(vol3d[:, 0, :]))  #
+         | set(np.unique(vol3d[:, -1, :]))  #
+         | set(np.unique(vol3d[:, :, 0]))  #
+         | set(np.unique(vol3d[:, :, -1])))
+  if not inplane:
+    ret |= set(np.unique(vol3d[0, :, :])) | set(np.unique(vol3d[-1, :, :]))
+  return ret
+
+
+def merge_internal_objects(bcc: list[AbstractSet[int]], aps: AbstractSet[int],
+                           todo_bcc_idx: Iterable[int]) -> dict[int, int]:
+  """Merges objects that are completely internal to other objects.
+
+  Takes as input biconnected components (BCCs) and articulation points (APs)
+  of a region adjacency graph (RAG) representing a segmentation.
+
+  Args:
+    bcc: list of sets of nodes of BCCs of the RAG
+    aps: set of APs of the RAG
+    todo_bcc_idx: indices in `bcc` for components that should be considered
+      for merging
+
+  Returns:
+    map from BCC index to new label for the BCC
+  """
+  ap_to_bcc_idx = {}  # AP -> indices of BCCs they are a part of
+  for ap in aps:
+    ap_to_bcc_idx[ap] = {i for i, cc in enumerate(bcc) if ap in cc}
+
+  ap_merge_forest = nx.DiGraph()
+  to_merge = []
+
+  while True:
+    start_len = len(to_merge)
+    remaining_bccs = []
+    for cc_i in todo_bcc_idx:
+      cc = bcc[cc_i]
+      cc_aps = set(cc & aps)
+
+      if len(cc_aps) == 1:
+        # Direct merge of the BCC into the only AP that is part of it.
+        to_merge.append(cc_i)
+        cc_ap = cc_aps.pop()
+        ap_to_bcc_idx[cc_ap].remove(cc_i)
+      elif len([cc_ap for cc_ap in cc_aps
+                if len(ap_to_bcc_idx[cc_ap]) > 1]) == 1:
+        # Merge into an AP that is the only remaining AP that is part of
+        # more than the current BCC.
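+        # The other APs of this BCC are redirected to that AP through the
+        # merge forest built below.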
+        to_merge.append(cc_i)
+        target = None
+        for cc_ap in cc_aps:
+          if len(ap_to_bcc_idx[cc_ap]) > 1:
+            target = cc_ap
+          ap_to_bcc_idx[cc_ap].remove(cc_i)
+
+        assert target is not None
+        for cc_ap in cc_aps:
+          if cc_ap == target:
+            continue
+          ap_merge_forest.add_edge(target, cc_ap)
+      else:
+        # The current BCC cannot be merged in this iteration because it
+        # still contains multiple APs that are part of more than 1 BCC.
+        remaining_bccs.append(cc_i)
+
+    todo_bcc_idx = remaining_bccs
+
+    # Terminate if no merges were applied in the last iteration.
+    if len(to_merge) == start_len:
+      break
+
+  # Build the AP relabel map by exploring the AP merge forest starting
+  # from the target labels (roots).
+  ap_relabel = {}
+  roots = [n for n, deg in ap_merge_forest.in_degree if deg == 0]
+  for root in roots:
+    for n in nx.dfs_preorder_nodes(ap_merge_forest, source=root):
+      ap_relabel[n] = root
+
+  bcc_relabel = {}
+  for bcc_i in to_merge:
+    cc = bcc[bcc_i]
+    adjacent_aps = cc & aps
+
+    targets = set([ap_relabel.get(ap, ap) for ap in adjacent_aps])
+    assert len(targets) == 1
+    target = targets.pop()
+
+    bcc_relabel[bcc_i] = target
+
+  return bcc_relabel
diff --git a/connectomics/segmentation/rag.py b/connectomics/segmentation/rag.py
index e0cf6de..22acdec 100644
--- a/connectomics/segmentation/rag.py
+++ b/connectomics/segmentation/rag.py
@@ -19,6 +19,43 @@
 from scipy import spatial

 
+def from_subvolume(vol3d: np.ndarray) -> nx.Graph:
+  """Returns the RAG for a 3d subvolume.
+
+  Uses 6-connectivity to find neighbors. Only works for segmentations
+  with IDs that fit in a uint32.
+
+  Args:
+    vol3d: 3d ndarray with the segmentation
+
+  Returns:
+    the corresponding RAG
+  """
+  assert np.max(vol3d) < 2**32
+
+  # Looks for neighboring segments assuming 6-connectivity.
+  seg_nbor_pairs = set()
+  for dim in 0, 1, 2:
+    sel_offset = [slice(None)] * 3
+    sel_offset[dim] = np.s_[:-1]
+
+    sel_base = [slice(None)] * 3
+    sel_base[dim] = np.s_[1:]
+
+    a = vol3d[tuple(sel_offset)].ravel()
+    b = vol3d[tuple(sel_base)].ravel()
+    x = a | (b << 32)
+    x = x[a != b]
+    unique_joint_labels = np.unique(x)
+
+    seg_nbor_pairs |= set(
+        zip(unique_joint_labels & 0xFFFFFFFF, unique_joint_labels >> 32))
+
+  g = nx.Graph()
+  g.add_edges_from(seg_nbor_pairs)
+  return g
+
+
 def from_set(kdts: dict[int, spatial._ckdtree.cKDTree]) -> nx.Graph:
   """Builds a RAG for a set of segments relying on their spatial proximity.
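
# Review sketch, not part of the patch: how the two new pieces compose. A
# segment fully embedded in another should make the containing segment an
# articulation point of the RAG. The toy volume and expected outputs below
# are assumptions for illustration, not code from this change.
import numpy as np
from connectomics.common import graph_utils
from connectomics.segmentation import rag

seg = np.zeros((5, 5, 5), dtype=np.uint64)
seg[1:4, 1:4, 1:4] = 1  # shell segment
seg[2, 2, 2] = 2  # core segment, fully inside segment 1

g = rag.from_subvolume(seg)  # edges: (0, 1) and (1, 2)
bccs, aps = graph_utils.biconnected_dfs(g)
print(sorted(sorted(int(n) for n in cc) for cc in bccs))  # [[0, 1], [1, 2]]
print({int(n) for n in aps})  # {1}: segment 1 separates 2 from the border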
diff --git a/connectomics/segmentation/rag_test.py b/connectomics/segmentation/rag_test.py
index 7d43875..4ed9652 100644
--- a/connectomics/segmentation/rag_test.py
+++ b/connectomics/segmentation/rag_test.py
@@ -17,6 +17,7 @@
 from absl.testing import absltest
 from connectomics.segmentation import rag
 import networkx as nx
+import numpy as np
 from scipy import spatial

@@ -71,6 +72,19 @@ def test_from_set_skeletons(self):
     self.assertEqual(g.edges[2, 3]['idx'][2], 3)
     self.assertEqual(g.edges[2, 3]['idx'][3], 2)

+  def test_from_subvolume(self):
+    seg = np.zeros((10, 10, 2), dtype=np.uint64)
+    seg[2:, :, 0] = 1
+    seg[1:, 3:4, 1] = 3
+    seg[1:, 5:6, 1] = 2
+    seg[2:, 7:, 1] = 3
+
+    result = rag.from_subvolume(seg)
+    expected = nx.Graph()
+    expected.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 3), (0, 3)])
+
+    self.assertTrue(nx.is_isomorphic(result, expected))
+

 if __name__ == '__main__':
   absltest.main()
diff --git a/connectomics/volume/processor/segmentation.py b/connectomics/volume/processor/segmentation.py
new file mode 100644
index 0000000..72a6654
--- /dev/null
+++ b/connectomics/volume/processor/segmentation.py
@@ -0,0 +1,230 @@
+# coding=utf-8
+# Copyright 2024 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Subvolume processors for segmentation."""
+
+import collections
+from connectomics.common import graph_utils
+from connectomics.segmentation import labels
+from connectomics.segmentation import rag
+from connectomics.volume import subvolume
+from connectomics.volume import subvolume_processor
+import numpy as np
+
+
+class MergeInternalObjects(subvolume_processor.SubvolumeProcessor):
+  """Merges internal objects into the containing ones.
+
+  An object A is considered internal to a containing object B if all
+  paths from A to the 'exterior' pass through B. The exterior is
+  defined as the set of objects that touch the border of a subvolume.
+
+  If the segmentation is represented as a region adjacency graph (RAG),
+  with connected components of the background considered as labeled
+  segments, the above definition implies that the containing object
+  is an articulation point (AP) in the RAG.
+  """
+
+  crop_at_borders = False
+
+  def __init__(self, ignore_bg=False):
+    """Constructor.
+
+    Args:
+      ignore_bg: if True, ignores the background component in the RAG
+        construction; this causes objects only adjacent to one other
+        ("containing") object and the background component to be considered
+        "internal" and thus eligible for merging; use with caution.
+    """
+    super().__init__()
+    self.ignore_bg = ignore_bg
+
+  def process(self, subvol: subvolume.Subvolume) -> subvolume.SubvolumeOrMany:
+    box = subvol.bbox
+    input_ndarray = subvol.data
+    with subvolume_processor.timer_counter('segmentation-prep'):
+      seg3d = input_ndarray[0, ...]
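+      # With ignore_bg=False, shifting labels by 1 turns background into a
+      # regular segment, so its connected components take part in the RAG.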
+      if self.ignore_bg:
+        no_bg = seg3d.copy()
+      else:
+        no_bg = seg3d + 1
+      no_bg_ccs = labels.split_disconnected_components(no_bg)
+      border_ids = labels.get_border_ids(no_bg_ccs)
+
+      cc_ids, indices = np.unique(no_bg_ccs, return_index=True)
+      orig_ids = seg3d.ravel()[indices]
+      cc_to_orig = dict(zip(cc_ids, orig_ids))
+
+      g = rag.from_subvolume(no_bg_ccs)
+      if self.ignore_bg and 0 in g:
+        g.remove_node(0)
+        border_ids.discard(0)
+
+      if not g:  # Should only occur on empty input.
+        return self.crop_box_and_data(box, input_ndarray)
+
+    with subvolume_processor.timer_counter('ap-graph'):
+      bcc, aps = graph_utils.biconnected_dfs(g, start_points=border_ids)
+
+    with subvolume_processor.timer_counter('define-relabel'):
+      # Any biconnected component (BCC) of the graph that contains objects
+      # which are not target (labeled) APs and which touch the border cannot
+      # be collapsed in the merging process.
+      no_zero_aps = frozenset(n for n in aps if cc_to_orig[n] != 0)
+      border_bcc_idx = {
+          i for i, cc in enumerate(bcc) if (cc - no_zero_aps) & border_ids
+      }
+      mergeable_bcc_idx = set(range(len(bcc))) - border_bcc_idx
+
+      # Find BCCs that can be collapsed and their new labels.
+      bcc_relabel = labels.merge_internal_objects(bcc, aps, mergeable_bcc_idx)
+      relabel = {}
+      for bcc_i, label in bcc_relabel.items():
+        cc = bcc[bcc_i]
+        for n in cc:
+          relabel[n] = label
+
+          if cc_to_orig[n] == 0:
+            subvolume_processor.counter('merged-segments-zero').inc()
+          else:
+            subvolume_processor.counter('merged-segments-nonzero').inc()
+
+    with subvolume_processor.timer_counter('apply-relabel'):
+      if relabel:
+        for i, cc_id in enumerate(cc_ids):
+          if cc_id in relabel:
+            orig_ids[i] = cc_to_orig[relabel[cc_id]]
+
+        # Map back to original ID space and perform mergers.
+        ret = labels.relabel(no_bg_ccs, cc_ids, orig_ids)
+      else:
+        ret = seg3d
+
+    subvolume_processor.counter('subvolumes-done').inc()
+    return self.crop_box_and_data(box, ret[np.newaxis, ...])
+
+
+class FillHoles(subvolume_processor.SubvolumeProcessor):
+  """Fills holes in segments.
+
+  A hole is a connected component of the background segment (0)
+  that touches exactly one non-background segment and does not
+  touch the border of the subvolume, both assuming 6-connectivity
+  along the canonical axes.
+
+  Run with context ~ 2x largest expected hole diameter to avoid
+  edge effects.
+  """
+
+  crop_at_borders = False
+
+  def __init__(self, min_neighbor_size=1000, inplane=False):
+    """Constructor.
+
+    Args:
+      min_neighbor_size: minimum size (in voxels) of an object within the
+        current subvolume for it to be considered a neighboring segment;
+        setting this to a small non-zero value allows filling of empty space
+        completely embedded in large segments when this space also contains
+        small (< specified size) labeled components
+      inplane: whether to treat the segmentation as 2d and fill holes within
+        XY planes
+    """
+    super().__init__()
+    self._min_neighbor_size = min_neighbor_size
+    self._inplane = inplane
+
+  def _fill_holes(self, seg3d):
+    """Fills holes in a segmentation subvolume."""
+
+    no_bg = seg3d + 1
+    no_bg_ccs = labels.split_disconnected_components(no_bg)
+    sizes = dict(zip(*np.unique(no_bg_ccs, return_counts=True)))
+    border = labels.get_border_ids(no_bg_ccs, inplane=self._inplane)
+
+    # Any connected component that used to be background and does not touch
+    # the border of the volume is potentially a hole to be filled.
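+    # ('no_bg == 1' marks exactly the voxels that were background in the
+    # input segmentation.)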
+    hole_labels = set(np.unique(no_bg_ccs[no_bg == 1])) - border
+    subvolume_processor.counter('potential-holes').inc(len(hole_labels))
+    hole_mask = np.isin(no_bg_ccs, list(hole_labels))
+
+    # (a, b) pairs where 'a' is background and 'b' is labeled.
+    # Looks for neighboring segments assuming 6-connectivity.
+    seg_nbor_pairs = set()
+    for dim in 0, 1, 2:
+      sel_offset = [slice(None)] * 3
+      sel_offset[dim] = np.s_[:-1]
+
+      sel_base = [slice(None)] * 3
+      sel_base[dim] = np.s_[1:]
+
+      sel_offset = tuple(sel_offset)
+      sel_base = tuple(sel_base)
+
+      # Right neighbor; 'b' is to the right of 'a'.
+      right_bg = hole_mask[sel_offset]
+      seg_nbor_pairs |= set(zip(no_bg_ccs[sel_offset][right_bg].ravel(),
+                                no_bg_ccs[sel_base][right_bg].ravel()))
+
+      # Left neighbor; 'b' is to the left of 'a'.
+      left_bg = hole_mask[sel_base]
+      seg_nbor_pairs |= set(zip(no_bg_ccs[sel_base][left_bg].ravel(),
+                                no_bg_ccs[sel_offset][left_bg].ravel()))
+
+    cc_ids, indices = np.unique(no_bg_ccs, return_index=True)
+    orig_ids = seg3d.ravel()[indices]
+    cc_to_orig = dict(zip(cc_ids, orig_ids))
+
+    # Maps connected components of the background region to adjacent
+    # segments.
+    bg_to_nbors = collections.defaultdict(set)
+    for a, b in seg_nbor_pairs:
+      if sizes[b] >= self._min_neighbor_size:
+        bg_to_nbors[a].add(cc_to_orig[b])
+
+    # Build a relabel map mapping hole IDs to the IDs of the segments
+    # containing them.
+    relabel = {}
+    for bg, nbors in bg_to_nbors.items():
+      nbors.discard(0)
+      # If there is more than 1 neighboring labeled component, this is
+      # not a hole.
+      if len(nbors) != 1:
+        continue
+      relabel[bg] = nbors.pop()
+      subvolume_processor.counter('holes-filled').inc()
+
+    if not relabel:
+      return seg3d
+
+    for i, cc_id in enumerate(cc_ids):
+      if cc_id in relabel:
+        assert orig_ids[i] == 0
+        orig_ids[i] = relabel[cc_id]
+
+    # Fill holes and map IDs back to the original ID space.
+    return labels.relabel(no_bg_ccs, cc_ids, orig_ids)
+
+  def process(self, subvol: subvolume.Subvolume) -> subvolume.SubvolumeOrMany:
+    box = subvol.bbox
+    input_ndarray = subvol.data
+    seg3d = input_ndarray[0, ...]
+    if self._inplane:
+      ret = np.zeros_like(seg3d)
+      for z in range(seg3d.shape[0]):
+        ret[z : z + 1, ...] = self._fill_holes(seg3d[z : z + 1, ...])
+    else:
+      ret = self._fill_holes(seg3d)
+
+    return self.crop_box_and_data(box, ret[np.newaxis, ...])
diff --git a/connectomics/volume/subvolume_processor.py b/connectomics/volume/subvolume_processor.py
index 6f7de3f..569b231 100644
--- a/connectomics/volume/subvolume_processor.py
+++ b/connectomics/volume/subvolume_processor.py
@@ -22,6 +22,7 @@
 from connectomics.common import array
 from connectomics.common import bounding_box
+from connectomics.common import counters
 from connectomics.common import file
 from connectomics.volume import descriptor
 from connectomics.volume import subvolume
@@ -36,6 +37,10 @@
 XyzTuple = array.Tuple3i
 SubvolumeOrMany = Union[Subvolume, List[Subvolume]]

+COUNTER_STORE = counters.ThreadsafeCounterStore()
+counter = COUNTER_STORE.get_counter
+timer_counter = COUNTER_STORE.timer_counter
+

 @dataclasses.dataclass
 class SubvolumeProcessorConfig(dataclasses_json.DataClassJsonMixin):
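
# Review sketch, not part of the patch: the counter plumbing added above,
# exercised standalone. The 'demo-*' names are illustrative assumptions;
# only APIs that appear in the diff (timer_counter, get_counter, inc,
# get_nonzero) are used.
import time
from connectomics.common import counters

store = counters.ThreadsafeCounterStore()

with store.timer_counter('demo-work'):  # accumulates into 'demo-work-ms'
  time.sleep(0.05)
store.get_counter('demo-items').inc(3)

for name, _ in store.get_nonzero():
  print(name)  # expect 'demo-work-ms' and 'demo-items'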