Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hot distance #89

Merged
merged 43 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
68258d7
feat: better tqdm, training reloading
rhoadesScholar Aug 26, 2023
9eee5a7
feat: numpy version requirement (1.22.3); watershed
rhoadesScholar Aug 29, 2023
703a31d
feat: improve training logging
rhoadesScholar Aug 30, 2023
b97f3d7
Merge branch 'master' into rhoadesj/dev
rhoadesScholar Aug 30, 2023
8a03a57
feat: stdout logging, version
rhoadesScholar Sep 1, 2023
9626922
feat: validation logging, keep best validation
rhoadesScholar Sep 5, 2023
8212b8b
WIP: apply.py
rhoadesScholar Sep 8, 2023
8326400
feat: keep actual best scoring dataset
rhoadesScholar Sep 11, 2023
663fe0f
feat: ready to debug apply.py
rhoadesScholar Sep 11, 2023
beeed60
feat!: add cli for applying models
rhoadesScholar Sep 12, 2023
8553b5a
feat!: add cli for applying models
rhoadesScholar Sep 12, 2023
273ecb7
feat: apply.py roi
rhoadesScholar Sep 12, 2023
2622b94
feat: postprocess parameter parsing
rhoadesScholar Sep 13, 2023
bd192e5
feat: overwrite option for predict & postprocess
rhoadesScholar Sep 13, 2023
6421bd1
bugfix!?: removed odd zarr creation
rhoadesScholar Sep 14, 2023
8198648
!: attempts to fix roi mismatch failed.
rhoadesScholar Sep 15, 2023
ed0d4ad
bugfix: prediction works with zarrs
rhoadesScholar Sep 15, 2023
ae23179
feat: backlogged updates and wip
rhoadesScholar Sep 20, 2023
53b3556
feat: zarr fix dimension detection, simple augment config kwargs, val…
rhoadesScholar Sep 21, 2023
8e6dfa3
bugfix: simple augment config
rhoadesScholar Sep 21, 2023
0869015
bugfix: simple augment config
rhoadesScholar Sep 21, 2023
922ba62
feat: black format, iou score
rhoadesScholar Sep 25, 2023
7d954f9
bugfix: remove array overspecification
rhoadesScholar Sep 25, 2023
ec1b0d8
hot distance loss function
mzouink Sep 25, 2023
be0f6db
feat: hotdistance predictor, model/target
rhoadesScholar Sep 25, 2023
da48612
hot distance task
mzouink Sep 25, 2023
a6714c0
Merge branch 'hot_distance' of github.com:janelia-cellmap/dacapo into…
rhoadesScholar Sep 25, 2023
9567ec4
feat: hot_distance_predictor target/weight
rhoadesScholar Sep 25, 2023
37de7e0
init show predictor
mzouink Sep 26, 2023
ee03505
fix hotdistance bugs
mzouink Sep 26, 2023
cd4077d
fix bce loss
mzouink Sep 27, 2023
448f766
feat: ⚡️ Incorporate hot_distance related changes from rhoadesj/dev
rhoadesScholar Feb 9, 2024
53b57b6
Merge branch 'hot_distance' into rhoadesj/hot_distance
rhoadesScholar Feb 9, 2024
70169e2
Merge branch 'rhoadesj/hot_distance' into actions/black
rhoadesScholar Feb 11, 2024
daa41b3
Merge pull request #47 from janelia-cellmap/actions/black
rhoadesScholar Feb 11, 2024
c810a0e
Revert GunpowderTrainer class and configuration to main
rhoadesScholar Feb 11, 2024
5f99dd4
:art: Format Python code with psf/black
rhoadesScholar Feb 11, 2024
62772b2
Merge pull request #39 from janelia-cellmap/rhoadesj/hot_distance
rhoadesScholar Feb 11, 2024
33cbb06
Merge branch 'hot_distance' into actions/black
rhoadesScholar Feb 11, 2024
0abfdc4
Merge pull request #52 from janelia-cellmap/actions/black
rhoadesScholar Feb 11, 2024
cfe42ee
Merge branch 'main' into hot_distance
mzouink Feb 14, 2024
06f2dc3
remove irrelevant stuff to hotdistance
mzouink Feb 14, 2024
c9015ff
Merge branch 'main' into hot_distance
rhoadesScholar Feb 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa
from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa
from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa
from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
25 changes: 25 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from .evaluators import BinarySegmentationEvaluator
from .losses import HotDistanceLoss
from .post_processors import ThresholdPostProcessor
from .predictors import HotDistancePredictor
from .task import Task


class HotDistanceTask(Task):
"""This is just a Hot Distance Task that combine Binary and distance prediction."""

def __init__(self, task_config):
"""Create a `HotDistanceTask` from a `HotDistanceTaskConfig`."""

self.predictor = HotDistancePredictor(
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
)
self.loss = HotDistanceLoss()
self.post_processor = ThresholdPostProcessor()
self.evaluator = BinarySegmentationEvaluator(
clip_distance=task_config.clip_distance,
tol_distance=task_config.tol_distance,
channels=task_config.channels,
)
48 changes: 48 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import attr

from .hot_distance_task import HotDistanceTask
from .task_config import TaskConfig

from typing import List


@attr.s
class HotDistanceTaskConfig(TaskConfig):
"""This is a Hot Distance task config used for generating and
evaluating signed distance transforms as a way of generating
segmentations.

The advantage of generating distance transforms over regular
affinities is you can get a denser signal, i.e. 1 misclassified
pixel in an affinity prediction could merge 2 otherwise very
distinct objects, this cannot happen with distances.
"""

task_type = HotDistanceTask

channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."})
clip_distance: float = attr.ib(
metadata={
"help_text": "Maximum distance to consider for false positive/negatives."
},
)
tol_distance: float = attr.ib(
metadata={
"help_text": "Tolerance distance for counting false positives/negatives"
},
)
scale_factor: float = attr.ib(
default=1,
metadata={
"help_text": "The amount by which to scale distances before applying "
"a tanh normalization."
},
)
mask_distances: bool = attr.ib(
default=False,
metadata={
"help_text": "Whether or not to mask out regions where the true distance to "
"object boundary cannot be known. This is anywhere that the distance to crop boundary "
"is less than the distance to object boundary."
},
)
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .mse_loss import MSELoss # noqa
from .loss import Loss # noqa
from .affinities_loss import AffinitiesLoss # noqa
from .hot_distance_loss import HotDistanceLoss # noqa
32 changes: 32 additions & 0 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .loss import Loss
import torch


# HotDistance is used for predicting hot and distance maps at the same time.
# The first half of the channels are the hot maps, the second half are the distance maps.
# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps.
# Model should predict twice the number of channels as the target.
class HotDistanceLoss(Loss):
def compute(self, prediction, target, weight):
target_hot, target_distance = self.split(target)
prediction_hot, prediction_distance = self.split(prediction)
weight_hot, weight_distance = self.split(weight)
return self.hot_loss(
prediction_hot, target_hot, weight_hot
) + self.distance_loss(prediction_distance, target_distance, weight_distance)

def hot_loss(self, prediction, target, weight):
loss = torch.nn.BCEWithLogitsLoss(reduction="none")
return torch.mean(loss(prediction, target) * weight)

def distance_loss(self, prediction, target, weight):
loss = torch.nn.MSELoss()
return loss(prediction * weight, target * weight)

def split(self, x):
# Shape[0] is the batch size and Shape[1] is the number of channels.
assert (
x.shape[1] % 2 == 0
), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
mid = x.shape[1] // 2
return torch.split(x, mid, dim=1)
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/predictors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .one_hot_predictor import OneHotPredictor # noqa
from .predictor import Predictor # noqa
from .affinities_predictor import AffinitiesPredictor # noqa
from .hot_distance_predictor import HotDistancePredictor # noqa
Loading
Loading