From bbec33d178fc794e8314f69bf62a9888a7c2628c Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 14 Feb 2024 11:25:17 -0500 Subject: [PATCH 1/2] inner distance target and task --- dacapo/experiments/tasks/__init__.py | 1 + .../experiments/tasks/inner_distance_task.py | 24 +++ .../tasks/inner_distance_task_config.py | 40 ++++ .../experiments/tasks/predictors/__init__.py | 1 + .../predictors/inner_distance_predictor.py | 191 ++++++++++++++++++ 5 files changed, 257 insertions(+) create mode 100644 dacapo/experiments/tasks/inner_distance_task.py create mode 100644 dacapo/experiments/tasks/inner_distance_task_config.py create mode 100644 dacapo/experiments/tasks/predictors/inner_distance_predictor.py diff --git a/dacapo/experiments/tasks/__init__.py b/dacapo/experiments/tasks/__init__.py index 780f343d1..7c63fafdb 100644 --- a/dacapo/experiments/tasks/__init__.py +++ b/dacapo/experiments/tasks/__init__.py @@ -5,3 +5,4 @@ from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa +from .inner_distance_task_config import InnerDistanceTaskConfig, InnerDistanceTask # noqa \ No newline at end of file diff --git a/dacapo/experiments/tasks/inner_distance_task.py b/dacapo/experiments/tasks/inner_distance_task.py new file mode 100644 index 000000000..e02d82705 --- /dev/null +++ b/dacapo/experiments/tasks/inner_distance_task.py @@ -0,0 +1,24 @@ +from .evaluators import BinarySegmentationEvaluator +from .losses import MSELoss +from .post_processors import ThresholdPostProcessor +from .predictors import InnerDistancePredictor +from .task import Task + +# Goal is have a distance task but with distance inside the forground only +class InnerDistanceTask(Task): + """This is just a dummy task for testing.""" + + def __init__(self, task_config): + """Create a `DummyTask` from a `DummyTaskConfig`.""" + + self.predictor = InnerDistancePredictor( + channels=task_config.channels, + scale_factor=task_config.scale_factor, + ) + self.loss = MSELoss() + self.post_processor = ThresholdPostProcessor() + self.evaluator = BinarySegmentationEvaluator( + clip_distance=task_config.clip_distance, + tol_distance=task_config.tol_distance, + channels=task_config.channels, + ) diff --git a/dacapo/experiments/tasks/inner_distance_task_config.py b/dacapo/experiments/tasks/inner_distance_task_config.py new file mode 100644 index 000000000..1a66cc47d --- /dev/null +++ b/dacapo/experiments/tasks/inner_distance_task_config.py @@ -0,0 +1,40 @@ +import attr + +from .inner_distance_task import InnerDistanceTask +from .task_config import TaskConfig + +from typing import List + + +@attr.s +class InnerDistanceTaskConfig(TaskConfig): + """This is a Distance task config used for generating and + evaluating signed distance transforms as a way of generating + segmentations. + + The advantage of generating distance transforms over regular + affinities is you can get a denser signal, i.e. 1 misclassified + pixel in an affinity prediction could merge 2 otherwise very + distinct objects, this cannot happen with distances. + """ + + task_type = InnerDistanceTask + + channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."}) + clip_distance: float = attr.ib( + metadata={ + "help_text": "Maximum distance to consider for false positive/negatives." + }, + ) + tol_distance: float = attr.ib( + metadata={ + "help_text": "Tolerance distance for counting false positives/negatives" + }, + ) + scale_factor: float = attr.ib( + default=1, + metadata={ + "help_text": "The amount by which to scale distances before applying " + "a tanh normalization." + }, + ) diff --git a/dacapo/experiments/tasks/predictors/__init__.py b/dacapo/experiments/tasks/predictors/__init__.py index 76f82138d..7f9bd6285 100644 --- a/dacapo/experiments/tasks/predictors/__init__.py +++ b/dacapo/experiments/tasks/predictors/__init__.py @@ -3,3 +3,4 @@ from .one_hot_predictor import OneHotPredictor # noqa from .predictor import Predictor # noqa from .affinities_predictor import AffinitiesPredictor # noqa +from .inner_distance_predictor import InnerDistancePredictor # noqa diff --git a/dacapo/experiments/tasks/predictors/inner_distance_predictor.py b/dacapo/experiments/tasks/predictors/inner_distance_predictor.py new file mode 100644 index 000000000..f0a354c6d --- /dev/null +++ b/dacapo/experiments/tasks/predictors/inner_distance_predictor.py @@ -0,0 +1,191 @@ +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.arraytypes import DistanceArray +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray +from dacapo.utils.balance_weights import balance_weights + +from funlib.geometry import Coordinate + +from scipy.ndimage.morphology import distance_transform_edt +import numpy as np +import torch + +import logging +from typing import List + +logger = logging.getLogger(__name__) + + +class InnerDistancePredictor(Predictor): + """ + Predict signed distances for a binary segmentation task. + + Distances deep within background are pushed to -inf, distances deep within + the foreground object are pushed to inf. After distances have been + calculated they are passed through a tanh so that distances saturate at +-1. + Multiple classes can be predicted via multiple distance channels. The names + of each class that is being segmented can be passed in as a list of strings + in the channels argument. + """ + + def __init__(self, channels: List[str], scale_factor: float): + self.channels = channels + self.norm = "tanh" + self.dt_scale_factor = scale_factor + + self.max_distance = 1 * scale_factor + self.epsilon = 5e-2 + self.threshold = 0.8 + + @property + def embedding_dims(self): + return len(self.channels) + + def create_model(self, architecture): + if architecture.dims == 2: + head = torch.nn.Conv2d( + architecture.num_out_channels, self.embedding_dims, kernel_size=1 + ) + elif architecture.dims == 3: + head = torch.nn.Conv3d( + architecture.num_out_channels, self.embedding_dims, kernel_size=1 + ) + + return Model(architecture, head) + + def create_target(self, gt): + distances = self.process( + gt.data, gt.voxel_size, self.norm, self.dt_scale_factor + ) + return NumpyArray.from_np_array( + distances, + gt.roi, + gt.voxel_size, + gt.axes, + ) + + def create_weight(self, gt, target, mask, moving_class_counts=None): + # balance weights independently for each channel + + weights, moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi]], + moving_counts=moving_class_counts, + ) + return ( + NumpyArray.from_np_array( + weights, + gt.roi, + gt.voxel_size, + gt.axes, + ), + moving_class_counts, + ) + + @property + def output_array_type(self): + return DistanceArray(self.embedding_dims) + + def process( + self, + labels: np.ndarray, + voxel_size: Coordinate, + normalize=None, + normalize_args=None, + ): + all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 + for ii, channel in enumerate(labels): + boundaries = self.__find_boundaries(channel) + + # mark boundaries with 0 (not 1) + boundaries = 1.0 - boundaries + + if np.sum(boundaries == 0) == 0: + max_distance = min( + dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size) + ) + if np.sum(channel) == 0: + distances = -np.ones(channel.shape, dtype=np.float32) * max_distance + else: + distances = np.ones(channel.shape, dtype=np.float32) * max_distance + else: + # get distances (voxel_size/2 because image is doubled) + distances = distance_transform_edt( + boundaries, sampling=tuple(float(v) / 2 for v in voxel_size) + ) + distances = distances.astype(np.float32) + + # restore original shape + downsample = (slice(None, None, 2),) * len(voxel_size) + distances = distances[downsample] + + # todo: inverted distance + distances[channel == 0] = -distances[channel == 0] + + if normalize is not None: + distances = self.__normalize(distances, normalize, normalize_args) + + all_distances[ii] = distances + + return all_distances * labels + + def __find_boundaries(self, labels): + # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n + # shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 + # diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 + # bound.: 00000001000100000001000 2n - 1 + + logger.debug("computing boundaries for %s", labels.shape) + + dims = len(labels.shape) + in_shape = labels.shape + out_shape = tuple(2 * s - 1 for s in in_shape) + + boundaries = np.zeros(out_shape, dtype=bool) + + logger.debug("boundaries shape is %s", boundaries.shape) + + for d in range(dims): + logger.debug("processing dimension %d", d) + + shift_p = [slice(None)] * dims + shift_p[d] = slice(1, in_shape[d]) + + shift_n = [slice(None)] * dims + shift_n[d] = slice(0, in_shape[d] - 1) + + diff = (labels[tuple(shift_p)] - labels[tuple(shift_n)]) != 0 + + logger.debug("diff shape is %s", diff.shape) + + target = [slice(None, None, 2)] * dims + target[d] = slice(1, out_shape[d], 2) + + logger.debug("target slices are %s", target) + + boundaries[tuple(target)] = diff + + return boundaries + + def __normalize(self, distances, norm, normalize_args): + if norm == "tanh": + scale = normalize_args + return np.tanh(distances / scale) + else: + raise ValueError("Only tanh is supported for normalization") + + def gt_region_for_roi(self, target_spec): + if self.mask_distances: + gt_spec = target_spec.copy() + gt_spec.roi = gt_spec.roi.grow( + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + ).snap_to_grid(gt_spec.voxel_size, mode="shrink") + else: + gt_spec = target_spec.copy() + return gt_spec + + def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + return Coordinate((self.max_distance,) * gt_voxel_size.dims) From 9eb2e718950c964712c5fbc901fdef51783acadd Mon Sep 17 00:00:00 2001 From: mzouink Date: Wed, 14 Feb 2024 16:28:30 +0000 Subject: [PATCH 2/2] :art: Format Python code with psf/black --- dacapo/experiments/tasks/__init__.py | 5 ++++- dacapo/experiments/tasks/inner_distance_task.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dacapo/experiments/tasks/__init__.py b/dacapo/experiments/tasks/__init__.py index 7c63fafdb..1951ba403 100644 --- a/dacapo/experiments/tasks/__init__.py +++ b/dacapo/experiments/tasks/__init__.py @@ -5,4 +5,7 @@ from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa -from .inner_distance_task_config import InnerDistanceTaskConfig, InnerDistanceTask # noqa \ No newline at end of file +from .inner_distance_task_config import ( + InnerDistanceTaskConfig, + InnerDistanceTask, +) # noqa diff --git a/dacapo/experiments/tasks/inner_distance_task.py b/dacapo/experiments/tasks/inner_distance_task.py index e02d82705..eeea236cc 100644 --- a/dacapo/experiments/tasks/inner_distance_task.py +++ b/dacapo/experiments/tasks/inner_distance_task.py @@ -4,6 +4,7 @@ from .predictors import InnerDistancePredictor from .task import Task + # Goal is have a distance task but with distance inside the forground only class InnerDistanceTask(Task): """This is just a dummy task for testing."""