-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'rhoadesj/dev' into merge_fix
- Loading branch information
Showing
31 changed files
with
977 additions
and
214 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from .evaluators import BinarySegmentationEvaluator | ||
from .losses import HotDistanceLoss | ||
from .post_processors import ThresholdPostProcessor | ||
from .predictors import HotDistancePredictor | ||
from .task import Task | ||
|
||
class HotDistanceTask(Task): | ||
"""This is just a Hot Distance Task that combine Binary and distance prediction.""" | ||
|
||
def __init__(self, task_config): | ||
"""Create a `HotDistanceTask` from a `HotDistanceTaskConfig`.""" | ||
|
||
self.predictor = HotDistancePredictor( | ||
channels=task_config.channels, | ||
scale_factor=task_config.scale_factor, | ||
mask_distances=task_config.mask_distances, | ||
) | ||
self.loss = HotDistanceLoss() | ||
self.post_processor = ThresholdPostProcessor() | ||
self.evaluator = BinarySegmentationEvaluator( | ||
clip_distance=task_config.clip_distance, | ||
tol_distance=task_config.tol_distance, | ||
channels=task_config.channels, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import attr | ||
|
||
from .hot_distance_task import HotDistanceTask | ||
from .task_config import TaskConfig | ||
|
||
from typing import List | ||
|
||
class HotDistanceTaskConfig(TaskConfig): | ||
"""This is a Hot Distance task config used for generating and | ||
evaluating signed distance transforms as a way of generating | ||
segmentations. | ||
The advantage of generating distance transforms over regular | ||
affinities is you can get a denser signal, i.e. 1 misclassified | ||
pixel in an affinity prediction could merge 2 otherwise very | ||
distinct objects, this cannot happen with distances. | ||
""" | ||
|
||
task_type = HotDistanceTask | ||
|
||
channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."}) | ||
clip_distance: float = attr.ib( | ||
metadata={ | ||
"help_text": "Maximum distance to consider for false positive/negatives." | ||
}, | ||
) | ||
tol_distance: float = attr.ib( | ||
metadata={ | ||
"help_text": "Tolerance distance for counting false positives/negatives" | ||
}, | ||
) | ||
scale_factor: float = attr.ib( | ||
default=1, | ||
metadata={ | ||
"help_text": "The amount by which to scale distances before applying " | ||
"a tanh normalization." | ||
}, | ||
) | ||
mask_distances: bool = attr.ib( | ||
default=False, | ||
metadata={ | ||
"help_text": "Whether or not to mask out regions where the true distance to " | ||
"object boundary cannot be known. This is anywhere that the distance to crop boundary " | ||
"is less than the distance to object boundary." | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.