-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
not ready to merge
- Loading branch information
Showing
7 changed files
with
392 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from .evaluators import BinarySegmentationEvaluator | ||
from .losses import HotDistanceLoss | ||
from .post_processors import ThresholdPostProcessor | ||
from .predictors import HotDistancePredictor | ||
from .task import Task | ||
|
||
|
||
class HotDistanceTask(Task): | ||
"""This is just a Hot Distance Task that combine Binary and distance prediction.""" | ||
|
||
def __init__(self, task_config): | ||
"""Create a `HotDistanceTask` from a `HotDistanceTaskConfig`.""" | ||
|
||
self.predictor = HotDistancePredictor( | ||
channels=task_config.channels, | ||
scale_factor=task_config.scale_factor, | ||
mask_distances=task_config.mask_distances, | ||
) | ||
self.loss = HotDistanceLoss() | ||
self.post_processor = ThresholdPostProcessor() | ||
self.evaluator = BinarySegmentationEvaluator( | ||
clip_distance=task_config.clip_distance, | ||
tol_distance=task_config.tol_distance, | ||
channels=task_config.channels, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import attr | ||
|
||
from .hot_distance_task import HotDistanceTask | ||
from .task_config import TaskConfig | ||
|
||
from typing import List | ||
|
||
|
||
@attr.s | ||
class HotDistanceTaskConfig(TaskConfig): | ||
"""This is a Hot Distance task config used for generating and | ||
evaluating signed distance transforms as a way of generating | ||
segmentations. | ||
The advantage of generating distance transforms over regular | ||
affinities is you can get a denser signal, i.e. 1 misclassified | ||
pixel in an affinity prediction could merge 2 otherwise very | ||
distinct objects, this cannot happen with distances. | ||
""" | ||
|
||
task_type = HotDistanceTask | ||
|
||
channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."}) | ||
clip_distance: float = attr.ib( | ||
metadata={ | ||
"help_text": "Maximum distance to consider for false positive/negatives." | ||
}, | ||
) | ||
tol_distance: float = attr.ib( | ||
metadata={ | ||
"help_text": "Tolerance distance for counting false positives/negatives" | ||
}, | ||
) | ||
scale_factor: float = attr.ib( | ||
default=1, | ||
metadata={ | ||
"help_text": "The amount by which to scale distances before applying " | ||
"a tanh normalization." | ||
}, | ||
) | ||
mask_distances: bool = attr.ib( | ||
default=False, | ||
metadata={ | ||
"help_text": "Whether or not to mask out regions where the true distance to " | ||
"object boundary cannot be known. This is anywhere that the distance to crop boundary " | ||
"is less than the distance to object boundary." | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from .loss import Loss | ||
import torch | ||
|
||
|
||
# HotDistance is used for predicting hot and distance maps at the same time. | ||
# The first half of the channels are the hot maps, the second half are the distance maps. | ||
# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. | ||
# Model should predict twice the number of channels as the target. | ||
class HotDistanceLoss(Loss): | ||
def compute(self, prediction, target, weight): | ||
target_hot, target_distance = self.split(target) | ||
prediction_hot, prediction_distance = self.split(prediction) | ||
weight_hot, weight_distance = self.split(weight) | ||
return self.hot_loss( | ||
prediction_hot, target_hot, weight_hot | ||
) + self.distance_loss(prediction_distance, target_distance, weight_distance) | ||
|
||
def hot_loss(self, prediction, target, weight): | ||
loss = torch.nn.BCEWithLogitsLoss(reduction="none") | ||
return torch.mean(loss(prediction, target) * weight) | ||
|
||
def distance_loss(self, prediction, target, weight): | ||
loss = torch.nn.MSELoss() | ||
return loss(prediction * weight, target * weight) | ||
|
||
def split(self, x): | ||
# Shape[0] is the batch size and Shape[1] is the number of channels. | ||
assert ( | ||
x.shape[1] % 2 == 0 | ||
), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." | ||
mid = x.shape[1] // 2 | ||
return torch.split(x, mid, dim=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.