merge main to jeff's dev branch #26

Merged: 6 commits, Feb 8, 2024
dacapo/cli.py (2 changes: 1 addition & 1 deletion)
@@ -42,7 +42,7 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
"-r", "--run_name", required=True, type=str, help="The name of the run to use."
"-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
@click.option(
"-ic",
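The only change here renames the dashed flag from --run_name to --run-name and tweaks its help text. Click maps a dashed option name onto an underscore-named parameter, so the handler still receives run_name. A minimal sketch of that behavior; the command name apply_run is hypothetical, not the actual dacapo command, which isn't fully shown in this hunk:

import click

@click.command()
@click.option(
    "-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
def apply_run(run_name):  # click exposes --run-name as the run_name parameter
    click.echo(f"applying run {run_name}")

if __name__ == "__main__":
    apply_run()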
dacapo/experiments/run.py (21 changes: 14 additions & 7 deletions)
@@ -11,6 +11,7 @@

logger = logging.getLogger(__file__)


class Run:
name: str
train_until: int
@@ -58,28 +59,34 @@ def __init__(self, run_config):
return
try:
from ..store import create_config_store

start_config_store = create_config_store()
starter_config = start_config_store.retrieve_run_config(run_config.start_config.run)
starter_config = start_config_store.retrieve_run_config(
run_config.start_config.run
)
except Exception as e:
logger.error(f"could not load start config: {e} Should be added to the database config store RUN")
logger.error(
f"could not load start config: {e} Should be added to the database config store RUN"
)
raise e

# preloaded weights from previous run
if run_config.task_config.name == starter_config.task_config.name:
self.start = Start(run_config.start_config)
else:
# Match labels between old and new head
if hasattr(run_config.task_config,"channels"):
if hasattr(run_config.task_config, "channels"):
# Map old head and new head
old_head = starter_config.task_config.channels
new_head = run_config.task_config.channels
self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head)
self.start = Start(
run_config.start_config, old_head=old_head, new_head=new_head
)
else:
logger.warning("Not implemented channel match for this task")
self.start = Start(run_config.start_config,remove_head=True)
self.start = Start(run_config.start_config, remove_head=True)
self.start.initialize_weights(self.model)


@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
"""
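The reworked __init__ retrieves the starter run's config and, when the tasks differ but both define channels, passes old_head and new_head to Start so weights can be remapped; otherwise the head is removed. A rough sketch of what channel-wise remapping could look like, purely illustrative since Start's internals are not part of this diff (match_head_weights is a hypothetical helper):

import torch

def match_head_weights(old_weight, old_head, new_head):
    # old_weight: (C_old, ...) tensor of final-layer weights;
    # old_head / new_head: lists of channel (label) names.
    new_weight = torch.zeros((len(new_head),) + tuple(old_weight.shape[1:]))
    for new_idx, name in enumerate(new_head):
        if name in old_head:
            # copy weights only for labels present in both heads
            new_weight[new_idx] = old_weight[old_head.index(name)]
    return new_weight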
dacapo/experiments/tasks/hot_distance_task.py (3 changes: 2 additions & 1 deletion)
@@ -4,6 +4,7 @@
from .predictors import HotDistancePredictor
from .task import Task


class HotDistanceTask(Task):
"""This is just a Hot Distance Task that combine Binary and distance prediction."""

@@ -21,4 +22,4 @@ def __init__(self, task_config):
clip_distance=task_config.clip_distance,
tol_distance=task_config.tol_distance,
channels=task_config.channels,
)
)
dacapo/experiments/tasks/hot_distance_task_config.py (3 changes: 2 additions & 1 deletion)
@@ -5,6 +5,7 @@

from typing import List


class HotDistanceTaskConfig(TaskConfig):
"""This is a Hot Distance task config used for generating and
evaluating signed distance transforms as a way of generating
@@ -43,4 +44,4 @@ class HotDistanceTaskConfig(TaskConfig):
"object boundary cannot be known. This is anywhere that the distance to crop boundary "
"is less than the distance to object boundary."
},
)
)
dacapo/experiments/tasks/losses/hot_distance_loss.py (17 changes: 11 additions & 6 deletions)
@@ -1,6 +1,7 @@
from .loss import Loss
import torch


# HotDistance is used for predicting hot and distance maps at the same time.
# The first half of the channels are the hot maps, the second half are the distance maps.
# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps.
@@ -10,15 +11,19 @@ def compute(self, prediction, target, weight):
target_hot, target_distance = self.split(target)
prediction_hot, prediction_distance = self.split(prediction)
weight_hot, weight_distance = self.split(weight)
return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance)

return self.hot_loss(
prediction_hot, target_hot, weight_hot
) + self.distance_loss(prediction_distance, target_distance, weight_distance)

def hot_loss(self, prediction, target, weight):
return torch.nn.BCELoss().forward(prediction * weight, target * weight)

def distance_loss(self, prediction, target, weight):
return torch.nn.MSELoss().forward(prediction * weight, target * weight)

def split(self, x):
assert x.shape[0] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
assert (
x.shape[0] % 2 == 0
), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
mid = x.shape[0] // 2
return x[:mid], x[-mid:]
return x[:mid], x[-mid:]
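Reformatting aside, the behavior is unchanged: the channel dimension is split in half, the first half is scored as binary ("hot") maps with BCELoss, the second half as distance maps with MSELoss, and the two terms are summed. A standalone sketch of that computation with dummy tensors (shapes are illustrative; BCELoss assumes the hot predictions are already in [0, 1]):

import torch

channels, spatial = 2, (8, 8, 8)
prediction = torch.rand((2 * channels,) + spatial)  # first half hot, second half distance
target = torch.rand((2 * channels,) + spatial)
weight = torch.ones_like(target)

mid = prediction.shape[0] // 2
hot_loss = torch.nn.BCELoss()(prediction[:mid] * weight[:mid], target[:mid] * weight[:mid])
distance_loss = torch.nn.MSELoss()(prediction[mid:] * weight[mid:], target[mid:] * weight[mid:])
loss = hot_loss + distance_loss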
dacapo/predict.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ def predict(
num_cpu_workers: int = 4,
compute_context: ComputeContext = LocalTorch(),
output_roi: Optional[Roi] = None,
output_dtype: Optional[np.dtype] = np.uint8,
output_dtype: Optional[np.dtype] = np.float32, # add necessary type conversions
overwrite: bool = False,
):
# get the model's input and output size
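The default output dtype moves from np.uint8 to np.float32, with the inline comment flagging that any integer conversion now has to happen explicitly. One way a caller might still get uint8 output, assuming predictions are probabilities in [0, 1] (the 255 scaling is illustrative, not a dacapo convention):

import numpy as np

def to_uint8(prediction: np.ndarray) -> np.ndarray:
    # clip to [0, 1], rescale, and cast; lossy compared to float32 output
    return (np.clip(prediction, 0.0, 1.0) * 255).astype(np.uint8)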
dacapo/utils/balance_weights.py (4 changes: 2 additions & 2 deletions)
@@ -75,11 +75,11 @@ def balance_weights(
scale_slab *= np.take(w, labels_slab)

if cross_class:
# get maximum error scale using first dimension
# get maximum error scale using first dimension
shape = error_scale.shape
error_scale = np.max(error_scale, axis=0)
error_scale = np.broadcast_to(error_scale, shape)

# set error_scale to 0 in masked-out areas
for mask in masks:
error_scale = error_scale * mask
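For context on the cross_class branch shown here: it takes the per-voxel maximum of the error scale across the class axis and broadcasts it back to the original shape, so every class channel shares the largest scale. A small numpy illustration (array shapes are made up):

import numpy as np

error_scale = np.random.rand(3, 4, 4)            # (classes, *spatial), illustrative shape
shape = error_scale.shape
max_scale = np.max(error_scale, axis=0)          # per-voxel maximum across classes -> (4, 4)
error_scale = np.broadcast_to(max_scale, shape)  # back to (3, 4, 4), each class sees the max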