Skip to content

Commit

Permalink
Merge pull request #27 from janelia-cellmap/actions/black
Browse files Browse the repository at this point in the history
Format Python code with psf/black push
  • Loading branch information
mzouink authored Feb 8, 2024
2 parents f9f85d3 + c3a81b7 commit 75eaff4
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 17 deletions.
21 changes: 14 additions & 7 deletions dacapo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

logger = logging.getLogger(__file__)


class Run:
name: str
train_until: int
Expand Down Expand Up @@ -58,28 +59,34 @@ def __init__(self, run_config):
return
try:
from ..store import create_config_store

start_config_store = create_config_store()
starter_config = start_config_store.retrieve_run_config(run_config.start_config.run)
starter_config = start_config_store.retrieve_run_config(
run_config.start_config.run
)
except Exception as e:
logger.error(f"could not load start config: {e} Should be added to the database config store RUN")
logger.error(
f"could not load start config: {e} Should be added to the database config store RUN"
)
raise e

# preloaded weights from previous run
if run_config.task_config.name == starter_config.task_config.name:
self.start = Start(run_config.start_config)
else:
# Match labels between old and new head
if hasattr(run_config.task_config,"channels"):
if hasattr(run_config.task_config, "channels"):
# Map old head and new head
old_head = starter_config.task_config.channels
new_head = run_config.task_config.channels
self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head)
self.start = Start(
run_config.start_config, old_head=old_head, new_head=new_head
)
else:
logger.warning("Not implemented channel match for this task")
self.start = Start(run_config.start_config,remove_head=True)
self.start = Start(run_config.start_config, remove_head=True)
self.start.initialize_weights(self.model)


@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
"""
Expand Down
3 changes: 2 additions & 1 deletion dacapo/experiments/tasks/hot_distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .predictors import HotDistancePredictor
from .task import Task


class HotDistanceTask(Task):
"""This is just a Hot Distance Task that combine Binary and distance prediction."""

Expand All @@ -21,4 +22,4 @@ def __init__(self, task_config):
clip_distance=task_config.clip_distance,
tol_distance=task_config.tol_distance,
channels=task_config.channels,
)
)
3 changes: 2 additions & 1 deletion dacapo/experiments/tasks/hot_distance_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import List


class HotDistanceTaskConfig(TaskConfig):
"""This is a Hot Distance task config used for generating and
evaluating signed distance transforms as a way of generating
Expand Down Expand Up @@ -43,4 +44,4 @@ class HotDistanceTaskConfig(TaskConfig):
"object boundary cannot be known. This is anywhere that the distance to crop boundary "
"is less than the distance to object boundary."
},
)
)
17 changes: 11 additions & 6 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .loss import Loss
import torch


# HotDistance is used for predicting hot and distance maps at the same time.
# The first half of the channels are the hot maps, the second half are the distance maps.
# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps.
Expand All @@ -10,15 +11,19 @@ def compute(self, prediction, target, weight):
target_hot, target_distance = self.split(target)
prediction_hot, prediction_distance = self.split(prediction)
weight_hot, weight_distance = self.split(weight)
return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance)

return self.hot_loss(
prediction_hot, target_hot, weight_hot
) + self.distance_loss(prediction_distance, target_distance, weight_distance)

def hot_loss(self, prediction, target, weight):
return torch.nn.BCELoss().forward(prediction * weight, target * weight)

def distance_loss(self, prediction, target, weight):
return torch.nn.MSELoss().forward(prediction * weight, target * weight)

def split(self, x):
assert x.shape[0] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
assert (
x.shape[0] % 2 == 0
), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
mid = x.shape[0] // 2
return x[:mid], x[-mid:]
return x[:mid], x[-mid:]
4 changes: 2 additions & 2 deletions dacapo/utils/balance_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def balance_weights(
scale_slab *= np.take(w, labels_slab)

if cross_class:
# get maximum error scale using first dimension
# get maximum error scale using first dimension
shape = error_scale.shape
error_scale = np.max(error_scale, axis=0)
error_scale = np.broadcast_to(error_scale, shape)

# set error_scale to 0 in masked-out areas
for mask in masks:
error_scale = error_scale * mask
Expand Down

0 comments on commit 75eaff4

Please sign in to comment.