Skip to content

Commit

Permalink
Hot distance (#89)
Browse files Browse the repository at this point in the history
 not ready to merge
  • Loading branch information
rhoadesScholar authored Feb 14, 2024
2 parents 36748eb + c9015ff commit c956257
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 0 deletions.
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa
from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa
from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa
from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
25 changes: 25 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from .evaluators import BinarySegmentationEvaluator
from .losses import HotDistanceLoss
from .post_processors import ThresholdPostProcessor
from .predictors import HotDistancePredictor
from .task import Task


class HotDistanceTask(Task):
"""This is just a Hot Distance Task that combine Binary and distance prediction."""

def __init__(self, task_config):
"""Create a `HotDistanceTask` from a `HotDistanceTaskConfig`."""

self.predictor = HotDistancePredictor(
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
)
self.loss = HotDistanceLoss()
self.post_processor = ThresholdPostProcessor()
self.evaluator = BinarySegmentationEvaluator(
clip_distance=task_config.clip_distance,
tol_distance=task_config.tol_distance,
channels=task_config.channels,
)
48 changes: 48 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import attr

from .hot_distance_task import HotDistanceTask
from .task_config import TaskConfig

from typing import List


@attr.s
class HotDistanceTaskConfig(TaskConfig):
"""This is a Hot Distance task config used for generating and
evaluating signed distance transforms as a way of generating
segmentations.
The advantage of generating distance transforms over regular
affinities is you can get a denser signal, i.e. 1 misclassified
pixel in an affinity prediction could merge 2 otherwise very
distinct objects, this cannot happen with distances.
"""

task_type = HotDistanceTask

channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."})
clip_distance: float = attr.ib(
metadata={
"help_text": "Maximum distance to consider for false positive/negatives."
},
)
tol_distance: float = attr.ib(
metadata={
"help_text": "Tolerance distance for counting false positives/negatives"
},
)
scale_factor: float = attr.ib(
default=1,
metadata={
"help_text": "The amount by which to scale distances before applying "
"a tanh normalization."
},
)
mask_distances: bool = attr.ib(
default=False,
metadata={
"help_text": "Whether or not to mask out regions where the true distance to "
"object boundary cannot be known. This is anywhere that the distance to crop boundary "
"is less than the distance to object boundary."
},
)
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .mse_loss import MSELoss # noqa
from .loss import Loss # noqa
from .affinities_loss import AffinitiesLoss # noqa
from .hot_distance_loss import HotDistanceLoss # noqa
32 changes: 32 additions & 0 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .loss import Loss
import torch


# HotDistance is used for predicting hot and distance maps at the same time.
# The first half of the channels are the hot maps, the second half are the distance maps.
# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps.
# Model should predict twice the number of channels as the target.
class HotDistanceLoss(Loss):
def compute(self, prediction, target, weight):
target_hot, target_distance = self.split(target)
prediction_hot, prediction_distance = self.split(prediction)
weight_hot, weight_distance = self.split(weight)
return self.hot_loss(
prediction_hot, target_hot, weight_hot
) + self.distance_loss(prediction_distance, target_distance, weight_distance)

def hot_loss(self, prediction, target, weight):
loss = torch.nn.BCEWithLogitsLoss(reduction="none")
return torch.mean(loss(prediction, target) * weight)

def distance_loss(self, prediction, target, weight):
loss = torch.nn.MSELoss()
return loss(prediction * weight, target * weight)

def split(self, x):
# Shape[0] is the batch size and Shape[1] is the number of channels.
assert (
x.shape[1] % 2 == 0
), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
mid = x.shape[1] // 2
return torch.split(x, mid, dim=1)
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/predictors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .one_hot_predictor import OneHotPredictor # noqa
from .predictor import Predictor # noqa
from .affinities_predictor import AffinitiesPredictor # noqa
from .hot_distance_predictor import HotDistancePredictor # noqa
Loading

0 comments on commit c956257

Please sign in to comment.