Commit 5f99dd4
🎨 Format Python code with psf/black
rhoadesScholar authored Feb 11, 2024
1 parent c810a0e commit 5f99dd4
Showing 6 changed files with 19 additions and 12 deletions.
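
A commit like this is typically produced by running Black over the whole repository (for example, "black ." from the repo root). As a rough sketch, the same normalization is reachable through Black's Python API; the snippet below is illustrative only and not part of this commit:

import black

# Black rewrites layout only; the result is AST-equivalent to the input.
src = "loss = torch.nn.BCEWithLogitsLoss(reduction='none')\n"
print(black.format_str(src, mode=black.Mode()), end="")
# -> loss = torch.nn.BCEWithLogitsLoss(reduction="none")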
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/__init__.py
@@ -5,4 +5,4 @@
from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa
from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa
from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa
-from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
+from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
@@ -5,6 +5,7 @@

from typing import List

+
@attr.s
class HotDistanceTaskConfig(TaskConfig):
    """This is a Hot Distance task config used for generating and
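
The single added line here is Black enforcing its rule of two blank lines before a top-level definition. A minimal illustration with a hypothetical file (assuming black is installed):

import black

src = "from typing import List\n\n@attr.s\nclass C:\n    pass\n"
print(black.format_str(src, mode=black.Mode()))
# The output contains two blank lines between the import and "@attr.s".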
16 changes: 9 additions & 7 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
@@ -14,17 +14,19 @@ def compute(self, prediction, target, weight):
        return self.hot_loss(
            prediction_hot, target_hot, weight_hot
        ) + self.distance_loss(prediction_distance, target_distance, weight_distance)

    def hot_loss(self, prediction, target, weight):
-        loss = torch.nn.BCEWithLogitsLoss(reduction='none')
-        return torch.mean(loss(prediction , target) * weight)
+        loss = torch.nn.BCEWithLogitsLoss(reduction="none")
+        return torch.mean(loss(prediction, target) * weight)

    def distance_loss(self, prediction, target, weight):
        loss = torch.nn.MSELoss()
        return loss(prediction * weight, target * weight)

    def split(self, x):
        # Shape[0] is the batch size and Shape[1] is the number of channels.
-        assert x.shape[1] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
+        assert (
+            x.shape[1] % 2 == 0
+        ), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
        mid = x.shape[1] // 2
-        return torch.split(x,mid,dim=1)
+        return torch.split(x, mid, dim=1)
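
For readers skimming the hunk, below is a self-contained sketch of the loss as it reads after formatting. The opening lines of compute() sit above the visible hunk, so the three split() calls (and the class name HotDistanceLoss, inferred from the module path) are plausible reconstructions rather than verbatim source:

import torch

class HotDistanceLoss:  # name assumed from the module path
    def compute(self, prediction, target, weight):
        # Reconstructed prologue: split each tensor into hot / distance halves.
        prediction_hot, prediction_distance = self.split(prediction)
        target_hot, target_distance = self.split(target)
        weight_hot, weight_distance = self.split(weight)
        return self.hot_loss(
            prediction_hot, target_hot, weight_hot
        ) + self.distance_loss(prediction_distance, target_distance, weight_distance)

    def hot_loss(self, prediction, target, weight):
        # Per-element BCE on logits, weighted, then averaged.
        loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        return torch.mean(loss(prediction, target) * weight)

    def distance_loss(self, prediction, target, weight):
        # MSE on the weighted distance channels.
        loss = torch.nn.MSELoss()
        return loss(prediction * weight, target * weight)

    def split(self, x):
        # Channel dimension must be even: first half hot, second half distance.
        mid = x.shape[1] // 2
        return torch.split(x, mid, dim=1)

# Dummy usage: batch of 2, 4 channels (2 hot + 2 distance), 8x8 spatial.
pred = torch.randn(2, 4, 8, 8)
tgt = torch.rand(2, 4, 8, 8)
w = torch.ones(2, 4, 8, 8)
print(HotDistanceLoss().compute(pred, tgt, w))  # scalar tensor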
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/predictors/__init__.py
@@ -3,4 +3,4 @@
from .one_hot_predictor import OneHotPredictor # noqa
from .predictor import Predictor # noqa
from .affinities_predictor import AffinitiesPredictor # noqa
-from .hot_distance_predictor import HotDistancePredictor # noqa
+from .hot_distance_predictor import HotDistancePredictor # noqa
8 changes: 6 additions & 2 deletions dacapo/experiments/tasks/predictors/hot_distance_predictor.py
@@ -76,7 +76,9 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
            2,
            slab=tuple(1 if c == "c" else -1 for c in gt.axes),
            masks=[mask[target.roi]],
-            moving_counts=None if moving_class_counts is None else moving_class_counts[: self.classes],
+            moving_counts=None
+            if moving_class_counts is None
+            else moving_class_counts[: self.classes],
        )

        if self.mask_distances:
@@ -95,7 +97,9 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
            2,
            slab=tuple(1 if c == "c" else -1 for c in gt.axes),
            masks=[mask[target.roi], distance_mask],
-            moving_counts=None if moving_class_counts is None else moving_class_counts[-self.classes :],
+            moving_counts=None
+            if moving_class_counts is None
+            else moving_class_counts[-self.classes :],
        )

        weights = np.concatenate((one_hot_weights, distance_weights))
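
The moving_counts edits above are layout-only: Black breaks a long conditional expression across lines without changing its value. A quick equivalence check with hypothetical values:

moving_class_counts = None  # as when no running counts are passed
classes = 2

compact = None if moving_class_counts is None else moving_class_counts[:classes]
reflowed = (
    None
    if moving_class_counts is None
    else moving_class_counts[:classes]
)
assert compact is reflowed is None  # same expression, same result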
2 changes: 1 addition & 1 deletion dacapo/train.py
@@ -98,7 +98,7 @@ def train_run(

        weights_store.retrieve_weights(run, iteration=trained_until)

-    elif latest_weights_iteration > trained_until:
+    elif latest_weights_iteration > trained_until:
        logger.warn(
            f"Found weights for iteration {latest_weights_iteration}, but "
            f"run {run.name} was only trained until {trained_until}. "
