From c3a81b789bc30e36d9b3d8db05e0e8d7f37b19a4 Mon Sep 17 00:00:00 2001 From: mzouink Date: Thu, 8 Feb 2024 17:43:40 +0000 Subject: [PATCH] :art: Format Python code with psf/black --- dacapo/experiments/run.py | 21 ++++++++++++------- dacapo/experiments/tasks/hot_distance_task.py | 3 ++- .../tasks/hot_distance_task_config.py | 3 ++- .../tasks/losses/hot_distance_loss.py | 17 +++++++++------ dacapo/utils/balance_weights.py | 4 ++-- 5 files changed, 31 insertions(+), 17 deletions(-) diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 1609892c8..9ea496758 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__file__) + class Run: name: str train_until: int @@ -58,28 +59,34 @@ def __init__(self, run_config): return try: from ..store import create_config_store + start_config_store = create_config_store() - starter_config = start_config_store.retrieve_run_config(run_config.start_config.run) + starter_config = start_config_store.retrieve_run_config( + run_config.start_config.run + ) except Exception as e: - logger.error(f"could not load start config: {e} Should be added to the database config store RUN") + logger.error( + f"could not load start config: {e} Should be added to the database config store RUN" + ) raise e - + # preloaded weights from previous run if run_config.task_config.name == starter_config.task_config.name: self.start = Start(run_config.start_config) else: # Match labels between old and new head - if hasattr(run_config.task_config,"channels"): + if hasattr(run_config.task_config, "channels"): # Map old head and new head old_head = starter_config.task_config.channels new_head = run_config.task_config.channels - self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head) + self.start = Start( + run_config.start_config, old_head=old_head, new_head=new_head + ) else: logger.warning("Not implemented channel match for this task") - self.start = Start(run_config.start_config,remove_head=True) + self.start = Start(run_config.start_config, remove_head=True) self.start.initialize_weights(self.model) - @staticmethod def get_validation_scores(run_config) -> ValidationScores: """ diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py index 7f1e4dd96..ef0d03229 100644 --- a/dacapo/experiments/tasks/hot_distance_task.py +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -4,6 +4,7 @@ from .predictors import HotDistancePredictor from .task import Task + class HotDistanceTask(Task): """This is just a Hot Distance Task that combine Binary and distance prediction.""" @@ -21,4 +22,4 @@ def __init__(self, task_config): clip_distance=task_config.clip_distance, tol_distance=task_config.tol_distance, channels=task_config.channels, - ) \ No newline at end of file + ) diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py index aab2b01d6..951226476 100644 --- a/dacapo/experiments/tasks/hot_distance_task_config.py +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -5,6 +5,7 @@ from typing import List + class HotDistanceTaskConfig(TaskConfig): """This is a Hot Distance task config used for generating and evaluating signed distance transforms as a way of generating @@ -43,4 +44,4 @@ class HotDistanceTaskConfig(TaskConfig): "object boundary cannot be known. This is anywhere that the distance to crop boundary " "is less than the distance to object boundary." }, - ) \ No newline at end of file + ) diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py index 77f34fd08..2e99ab5e1 100644 --- a/dacapo/experiments/tasks/losses/hot_distance_loss.py +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -1,6 +1,7 @@ from .loss import Loss import torch + # HotDistance is used for predicting hot and distance maps at the same time. # The first half of the channels are the hot maps, the second half are the distance maps. # The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. @@ -10,15 +11,19 @@ def compute(self, prediction, target, weight): target_hot, target_distance = self.split(target) prediction_hot, prediction_distance = self.split(prediction) weight_hot, weight_distance = self.split(weight) - return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance) - + return self.hot_loss( + prediction_hot, target_hot, weight_hot + ) + self.distance_loss(prediction_distance, target_distance, weight_distance) + def hot_loss(self, prediction, target, weight): return torch.nn.BCELoss().forward(prediction * weight, target * weight) - + def distance_loss(self, prediction, target, weight): return torch.nn.MSELoss().forward(prediction * weight, target * weight) - + def split(self, x): - assert x.shape[0] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." + assert ( + x.shape[0] % 2 == 0 + ), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." mid = x.shape[0] // 2 - return x[:mid], x[-mid:] \ No newline at end of file + return x[:mid], x[-mid:] diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index 949bde0c4..5cd5ee597 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -75,11 +75,11 @@ def balance_weights( scale_slab *= np.take(w, labels_slab) if cross_class: - # get maximum error scale using first dimension + # get maximum error scale using first dimension shape = error_scale.shape error_scale = np.max(error_scale, axis=0) error_scale = np.broadcast_to(error_scale, shape) - + # set error_scale to 0 in masked-out areas for mask in masks: error_scale = error_scale * mask