Skip to content

Commit

Permalink
Merge pull request #26 from janelia-cellmap/merge_fix
Browse files Browse the repository at this point in the history
merge main to jeff's dev branch
  • Loading branch information
mzouink authored Feb 8, 2024
2 parents 673484c + 75eaff4 commit 290e57f
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 19 deletions.
2 changes: 1 addition & 1 deletion dacapo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
"-r", "--run_name", required=True, type=str, help="The name of the run to use."
"-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
@click.option(
"-ic",
Expand Down
21 changes: 14 additions & 7 deletions dacapo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

logger = logging.getLogger(__file__)


class Run:
name: str
train_until: int
Expand Down Expand Up @@ -58,28 +59,34 @@ def __init__(self, run_config):
return
try:
from ..store import create_config_store

start_config_store = create_config_store()
starter_config = start_config_store.retrieve_run_config(run_config.start_config.run)
starter_config = start_config_store.retrieve_run_config(
run_config.start_config.run
)
except Exception as e:
logger.error(f"could not load start config: {e} Should be added to the database config store RUN")
logger.error(
f"could not load start config: {e} Should be added to the database config store RUN"
)
raise e

# preloaded weights from previous run
if run_config.task_config.name == starter_config.task_config.name:
self.start = Start(run_config.start_config)
else:
# Match labels between old and new head
if hasattr(run_config.task_config,"channels"):
if hasattr(run_config.task_config, "channels"):
# Map old head and new head
old_head = starter_config.task_config.channels
new_head = run_config.task_config.channels
self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head)
self.start = Start(
run_config.start_config, old_head=old_head, new_head=new_head
)
else:
logger.warning("Not implemented channel match for this task")
self.start = Start(run_config.start_config,remove_head=True)
self.start = Start(run_config.start_config, remove_head=True)
self.start.initialize_weights(self.model)


@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
"""
Expand Down
3 changes: 2 additions & 1 deletion dacapo/experiments/tasks/hot_distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .predictors import HotDistancePredictor
from .task import Task


class HotDistanceTask(Task):
"""This is just a Hot Distance Task that combine Binary and distance prediction."""

Expand All @@ -21,4 +22,4 @@ def __init__(self, task_config):
clip_distance=task_config.clip_distance,
tol_distance=task_config.tol_distance,
channels=task_config.channels,
)
)
3 changes: 2 additions & 1 deletion dacapo/experiments/tasks/hot_distance_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import List


class HotDistanceTaskConfig(TaskConfig):
"""This is a Hot Distance task config used for generating and
evaluating signed distance transforms as a way of generating
Expand Down Expand Up @@ -43,4 +44,4 @@ class HotDistanceTaskConfig(TaskConfig):
"object boundary cannot be known. This is anywhere that the distance to crop boundary "
"is less than the distance to object boundary."
},
)
)
17 changes: 11 additions & 6 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .loss import Loss
import torch


# HotDistance is used for predicting hot and distance maps at the same time.
# The first half of the channels are the hot maps, the second half are the distance maps.
# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps.
Expand All @@ -10,15 +11,19 @@ def compute(self, prediction, target, weight):
target_hot, target_distance = self.split(target)
prediction_hot, prediction_distance = self.split(prediction)
weight_hot, weight_distance = self.split(weight)
return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance)

return self.hot_loss(
prediction_hot, target_hot, weight_hot
) + self.distance_loss(prediction_distance, target_distance, weight_distance)

def hot_loss(self, prediction, target, weight):
return torch.nn.BCELoss().forward(prediction * weight, target * weight)

def distance_loss(self, prediction, target, weight):
return torch.nn.MSELoss().forward(prediction * weight, target * weight)

def split(self, x):
assert x.shape[0] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
assert (
x.shape[0] % 2 == 0
), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
mid = x.shape[0] // 2
return x[:mid], x[-mid:]
return x[:mid], x[-mid:]
2 changes: 1 addition & 1 deletion dacapo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def predict(
num_cpu_workers: int = 4,
compute_context: ComputeContext = LocalTorch(),
output_roi: Optional[Roi] = None,
output_dtype: Optional[np.dtype] = np.uint8,
output_dtype: Optional[np.dtype] = np.float32, # add necessary type conversions
overwrite: bool = False,
):
# get the model's input and output size
Expand Down
4 changes: 2 additions & 2 deletions dacapo/utils/balance_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def balance_weights(
scale_slab *= np.take(w, labels_slab)

if cross_class:
# get maximum error scale using first dimension
# get maximum error scale using first dimension
shape = error_scale.shape
error_scale = np.max(error_scale, axis=0)
error_scale = np.broadcast_to(error_scale, shape)

# set error_scale to 0 in masked-out areas
for mask in masks:
error_scale = error_scale * mask
Expand Down

0 comments on commit 290e57f

Please sign in to comment.