From 82e2743d80959e63ff836edf3ca7d1e4e910a734 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Sat, 14 Oct 2023 21:26:43 -0400 Subject: [PATCH 01/14] fix use coord --- dacapo/experiments/datasplits/datasets/arrays/zarr_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index cadfcb6cd..42030e701 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -54,7 +54,7 @@ def axes(self): f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n" f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}", ) - return ["t", "z", "y", "x"][-self.dims : :] + return ["c", "z", "y", "x"][-self.dims : :] @property def dims(self) -> int: From 8f648cd4e683bdc89c99cc824b32f5dd9bf0fa43 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Sat, 14 Oct 2023 21:31:00 -0400 Subject: [PATCH 02/14] fix use coord --- dacapo/experiments/datasplits/datasets/arrays/dvid_array.py | 2 +- dacapo/experiments/datasplits/datasets/arrays/numpy_array.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py index beaa474d1..e08ffe562 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py @@ -41,7 +41,7 @@ def attrs(self): @property def axes(self): - return ["t", "z", "y", "x"][-self.dims :] + return ["c", "z", "y", "x"][-self.dims :] @property def dims(self) -> int: diff --git a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py index 7101d737e..5f2bc0483 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py @@ -35,7 +35,7 @@ def from_gp_array(cls, array: gp.Array): ((["b", "c"] if len(array.data.shape) == instance.dims + 2 else [])) + (["c"] if len(array.data.shape) == instance.dims + 1 else []) + [ - "t", + "c", "z", "y", "x", From d95cf7aacf558167fc3b6b97e26556ed1be557d1 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Sat, 14 Oct 2023 21:31:19 -0400 Subject: [PATCH 03/14] weight cross class --- dacapo/utils/balance_weights.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index f5adcffca..949bde0c4 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -12,6 +12,7 @@ def balance_weights( clipmin: float = 0.05, clipmax: float = 0.95, moving_counts: Optional[List[Dict[int, Tuple[int, int]]]] = None, + cross_class: bool = True, ): if moving_counts is None: moving_counts = [] @@ -29,10 +30,6 @@ def balance_weights( # initialize error scale with 1s error_scale = np.ones(label_data.shape, dtype=np.float32) - # set error_scale to 0 in masked-out areas - for mask in masks: - error_scale = error_scale * mask - if slab is None: slab = error_scale.shape else: @@ -77,4 +74,14 @@ def balance_weights( # scale_slab the masked-in scale_slab with the class weights scale_slab *= np.take(w, labels_slab) + if cross_class: + # get maximum error scale using first dimension + shape = error_scale.shape + error_scale = np.max(error_scale, axis=0) + error_scale = 
np.broadcast_to(error_scale, shape) + + # set error_scale to 0 in masked-out areas + for mask in masks: + error_scale = error_scale * mask + return error_scale, moving_counts From a8884a1f4d026fdfbe6dc96c8a27b7d7304696df Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Mon, 30 Oct 2023 17:53:16 -0400 Subject: [PATCH 04/14] start head matching --- dacapo/experiments/run.py | 34 ++++++++++++++++----- dacapo/experiments/starts/start.py | 48 ++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 129f947ab..1609892c8 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -6,9 +6,10 @@ from .validation_scores import ValidationScores from .starts import Start from .model import Model - +import logging import torch +logger = logging.getLogger(__file__) class Run: name: str @@ -53,14 +54,31 @@ def __init__(self, run_config): self.task.parameters, self.datasplit.validate, self.task.evaluation_scores ) + if run_config.start_config is None: + return + try: + from ..store import create_config_store + start_config_store = create_config_store() + starter_config = start_config_store.retrieve_run_config(run_config.start_config.run) + except Exception as e: + logger.error(f"could not load start config: {e} Should be added to the database config store RUN") + raise e + # preloaded weights from previous run - self.start = ( - Start(run_config.start_config) - if run_config.start_config is not None - else None - ) - if self.start is not None: - self.start.initialize_weights(self.model) + if run_config.task_config.name == starter_config.task_config.name: + self.start = Start(run_config.start_config) + else: + # Match labels between old and new head + if hasattr(run_config.task_config,"channels"): + # Map old head and new head + old_head = starter_config.task_config.channels + new_head = run_config.task_config.channels + self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head) + else: + logger.warning("Not implemented channel match for this task") + self.start = Start(run_config.start_config,remove_head=True) + self.start.initialize_weights(self.model) + @staticmethod def get_validation_scores(run_config) -> ValidationScores: diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index a5b68069c..6d812fbfc 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -3,21 +3,63 @@ logger = logging.getLogger(__file__) + # self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] + # self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] + +def match_heads(model, weights, old_head, new_head ): + # match the heads + for label in new_head: + if label in old_head: + logger.warning(f"matching head for {label}") + # find the index of the label in the old_head + old_index = old_head.index(label) + # find the index of the label in the new_head + new_index = new_head.index(label) + # get the weight and bias of the old head + for key in ["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"]: + if key in model.state_dict().keys(): + n_val = weights.model[key][old_index] + model.state_dict()[key][new_index] = n_val + logger.warning(f"matched head for {label}") + return model class Start(ABC): - def __init__(self, start_config): + def __init__(self, 
start_config,remove_head = False, old_head= None, new_head = None): self.run = start_config.run self.criterion = start_config.criterion + self.remove_head = remove_head + self.old_head = old_head + self.new_head = new_head def initialize_weights(self, model): from dacapo.store.create_store import create_weights_store weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) + logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}") - # load the model weights (taken from torch load_state_dict source) try: - model.load_state_dict(weights.model) + if self.old_head and self.new_head: + logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}") + logger.info(f"old head: {self.old_head}") + logger.info(f"new head: {self.new_head}") + model = match_heads(model, weights, self.old_head, self.new_head) + logger.warning(f"matched heads from run {self.run}, criterion: {self.criterion}") + self.remove_head = True + if self.remove_head: + logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}") + weights.model.pop("prediction_head.weight", None) + weights.model.pop("prediction_head.bias", None) + weights.model.pop("chain.1.weight", None) + weights.model.pop("chain.1.bias", None) + logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}") + model.load_state_dict(weights.model, strict=False) + logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}") + else: + model.load_state_dict(weights.model) except RuntimeError as e: logger.warning(e) + + + From 0a6a171c565017f67c04c374eb4916bb3beacf26 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Tue, 14 Nov 2023 16:45:40 -0500 Subject: [PATCH 05/14] head only train --- .../experiments/trainers/gunpowder_trainer.py | 44 ++++++++++++++----- .../trainers/gunpowder_trainer_config.py | 10 +++++ 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index efec630f0..18902aa4e 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -42,10 +42,24 @@ def __init__(self, trainer_config): self.mask_integral_downsample_factor = 4 self.clip_raw = trainer_config.clip_raw + # Testing out if calculating multiple times and multiplying is necessary + self.add_predictor_nodes_to_dataset = trainer_config.add_predictor_nodes_to_dataset + self.finetune_head_only = trainer_config.finetune_head_only + self.scheduler = None def create_optimizer(self, model): - optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) + if self.finetune_head_only: + logger.warning("Finetuning head only") + parameters = [] + for key in model.state_dict().keys(): + if "prediction_head" in key: + parameters.append(model.state_dict()[key]) + else: + model.state_dict()[key].requires_grad = False + else: + parameters = model.parameters() + optimizer = torch.optim.RAdam(lr=self.learning_rate, params=parameters) self.scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=0.01, @@ -146,13 +160,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) - # Add predictor nodes to dataset_source - dataset_source += DaCapoTargetFilter( - task.predictor, - gt_key=gt_key, - weights_key=dataset_weight_key, - 
mask_key=mask_key, - ) + if self.add_predictor_nodes_to_dataset: + # Add predictor nodes to dataset_source + dataset_source += DaCapoTargetFilter( + task.predictor, + gt_key=gt_key, + weights_key=dataset_weight_key, + mask_key=mask_key, + ) dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) @@ -162,11 +177,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=datasets_weight_key, + weights_key=datasets_weight_key if self.add_predictor_nodes_to_dataset else weight_key, mask_key=mask_key, ) - pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) + if self.add_predictor_nodes_to_dataset: + pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) # Trainer attributes: if self.num_data_fetchers > 1: @@ -209,6 +225,11 @@ def iterate(self, num_iterations, model, optimizer, device): t_start_fetch = time.time() logger.info("Starting iteration!") + if self.finetune_head_only: + logger.warning("Finetuning head only") + for key in model.state_dict().keys(): + if "prediction_head" not in key: + model.state_dict()[key].requires_grad = False for iteration in range(self.iteration, self.iteration + num_iterations): raw, gt, target, weight, mask = self.next() @@ -227,6 +248,7 @@ def iterate(self, num_iterations, model, optimizer, device): torch.as_tensor(target[target.roi]).to(device).float(), torch.as_tensor(weight[weight.roi]).to(device).float(), ) + loss.backward() optimizer.step() @@ -337,4 +359,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def can_train(self, datasets) -> bool: - return all([dataset.gt is not None for dataset in datasets]) + return all([dataset.gt is not None for dataset in datasets]) \ No newline at end of file diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index ae4243059..17cf411ce 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -29,3 +29,13 @@ class GunpowderTrainerConfig(TrainerConfig): ) min_masked: Optional[float] = attr.ib(default=0.15) clip_raw: bool = attr.ib(default=True) + + add_predictor_nodes_to_dataset: Optional[bool] = attr.ib( + default=True, + metadata={"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"} + ) + + finetune_head_only: Optional[bool] = attr.ib( + default=False, + metadata={"help_text": "Whether to fine-tune head only or all layers"} + ) \ No newline at end of file From 8933e76f2b58f19181ca387a858b787f3d428e2c Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 22 Nov 2023 13:03:10 -0500 Subject: [PATCH 06/14] fix starter --- dacapo/experiments/starts/start.py | 67 ++++++++++++------- .../experiments/trainers/gunpowder_trainer.py | 17 ++--- dacapo/utils/balance_weights.py | 17 ++--- 3 files changed, 53 insertions(+), 48 deletions(-) diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index 6d812fbfc..f43ab6403 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -5,23 +5,30 @@ # self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] # self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] +head_keys = 
["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"] -def match_heads(model, weights, old_head, new_head ): +# Hack +# if label is mito_peroxisome or peroxisome then change it to mito +mitos = ["mito_proxisome","peroxisome"] + +def match_heads(model, head_weights, old_head, new_head ): # match the heads for label in new_head: - if label in old_head: + old_label = label + if label in mitos: + old_label = "mito" + if old_label in old_head: logger.warning(f"matching head for {label}") # find the index of the label in the old_head - old_index = old_head.index(label) + old_index = old_head.index(old_label) # find the index of the label in the new_head new_index = new_head.index(label) # get the weight and bias of the old head - for key in ["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"]: + for key in head_keys: if key in model.state_dict().keys(): - n_val = weights.model[key][old_index] + n_val = head_weights[key][old_index] model.state_dict()[key][new_index] = n_val - logger.warning(f"matched head for {label}") - return model + logger.warning(f"matched head for {label} with {old_label}") class Start(ABC): def __init__(self, start_config,remove_head = False, old_head= None, new_head = None): @@ -37,29 +44,41 @@ def initialize_weights(self, model): weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) - logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}") + logger.warning(f"loading weights from run {self.run}, criterion: {self.criterion}") try: if self.old_head and self.new_head: - logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}") - logger.info(f"old head: {self.old_head}") - logger.info(f"new head: {self.new_head}") - model = match_heads(model, weights, self.old_head, self.new_head) - logger.warning(f"matched heads from run {self.run}, criterion: {self.criterion}") - self.remove_head = True - if self.remove_head: - logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}") - weights.model.pop("prediction_head.weight", None) - weights.model.pop("prediction_head.bias", None) - weights.model.pop("chain.1.weight", None) - weights.model.pop("chain.1.bias", None) - logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}") - model.load_state_dict(weights.model, strict=False) - logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}") + try: + self.load_model_using_head_matching(model, weights) + except RuntimeError as e: + logger.error(f"ERROR starter matching head: {e}") + self.load_model_using_head_removal(model, weights) + elif self.remove_head: + self.load_model_using_head_removal(model, weights) else: model.load_state_dict(weights.model) except RuntimeError as e: - logger.warning(e) + logger.warning(f"ERROR starter: {e}") + + def load_model_using_head_removal(self, model, weights): + logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}") + for key in head_keys: + weights.model.pop(key, None) + logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}") + model.load_state_dict(weights.model, strict=False) + logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}") + + def load_model_using_head_matching(self, model, weights): + logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}") + logger.warning(f"old head: 
{self.old_head}") + logger.warning(f"new head: {self.new_head}") + head_weights = {} + for key in head_keys: + head_weights[key] = weights.model[key] + for key in head_keys: + weights.model.pop(key, None) + model.load_state_dict(weights.model, strict=False) + model = match_heads(model, head_weights, self.old_head, self.new_head) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 18902aa4e..8a4bf8a2f 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -52,11 +52,11 @@ def create_optimizer(self, model): if self.finetune_head_only: logger.warning("Finetuning head only") parameters = [] - for key in model.state_dict().keys(): - if "prediction_head" in key: - parameters.append(model.state_dict()[key]) + for name, param in model.named_parameters(): + if "prediction_head" in name: + parameters.append(param) else: - model.state_dict()[key].requires_grad = False + param.requires_grad = False else: parameters = model.parameters() optimizer = torch.optim.RAdam(lr=self.learning_rate, params=parameters) @@ -224,20 +224,13 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): def iterate(self, num_iterations, model, optimizer, device): t_start_fetch = time.time() - logger.info("Starting iteration!") - if self.finetune_head_only: - logger.warning("Finetuning head only") - for key in model.state_dict().keys(): - if "prediction_head" not in key: - model.state_dict()[key].requires_grad = False - for iteration in range(self.iteration, self.iteration + num_iterations): raw, gt, target, weight, mask = self.next() logger.debug( f"Trainer fetch batch took {time.time() - t_start_fetch} seconds" ) - for param in model.parameters(): + for param in model.parameters(): # TODO: get parameters from optimizer instead param.grad = None t_start_prediction = time.time() diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index 949bde0c4..96fbc80e8 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -12,7 +12,6 @@ def balance_weights( clipmin: float = 0.05, clipmax: float = 0.95, moving_counts: Optional[List[Dict[int, Tuple[int, int]]]] = None, - cross_class: bool = True, ): if moving_counts is None: moving_counts = [] @@ -30,6 +29,10 @@ def balance_weights( # initialize error scale with 1s error_scale = np.ones(label_data.shape, dtype=np.float32) + # set error_scale to 0 in masked-out areas + for mask in masks: + error_scale = error_scale * mask + if slab is None: slab = error_scale.shape else: @@ -74,14 +77,4 @@ def balance_weights( # scale_slab the masked-in scale_slab with the class weights scale_slab *= np.take(w, labels_slab) - if cross_class: - # get maximum error scale using first dimension - shape = error_scale.shape - error_scale = np.max(error_scale, axis=0) - error_scale = np.broadcast_to(error_scale, shape) - - # set error_scale to 0 in masked-out areas - for mask in masks: - error_scale = error_scale * mask - - return error_scale, moving_counts + return error_scale, moving_counts \ No newline at end of file From a9e452d8d799d9d4b564c1dbe3ac244b35b316cd Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 22 Nov 2023 13:07:28 -0500 Subject: [PATCH 07/14] fix bugs cpu bugs and more informative logs --- .../datasplits/datasets/arrays/concat_array.py | 6 +++++- dacapo/train.py | 14 ++++++++++---- dacapo/validate.py | 7 +++++-- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git 
a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index 122526b14..3090c17ee 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -5,7 +5,9 @@ import numpy as np from typing import Dict, Any +import logging +logger = logging.getLogger(__file__) class ConcatArray(Array): """This is a wrapper around other `source_arrays` that concatenates @@ -93,6 +95,7 @@ def num_channels(self): return len(self.channels) def __getitem__(self, roi: Roi) -> np.ndarray: + logger.info(f"Concat Array: Get Item {self.name} {roi}") default = ( np.zeros_like(self.source_array[roi]) if self.default_array is None @@ -116,5 +119,6 @@ def __getitem__(self, roi: Roi) -> np.ndarray: axis=0, ) if concatenated.shape[0] == 1: - raise Exception(f"{concatenated.shape}, shapes") + logger.info(f"Concatenated array has only one channel: {self.name} {concatenated.shape}") + # raise Exception(f"{concatenated.shape}, shapes") return concatenated diff --git a/dacapo/train.py b/dacapo/train.py index 9203c1be3..c940b8889 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -12,10 +12,11 @@ logger = logging.getLogger(__name__) -def train(run_name: str, compute_context: ComputeContext = LocalTorch()): +def train(run_name: str, compute_context: ComputeContext = LocalTorch(), force_cuda = False): """Train a run""" if compute_context.train(run_name): + logger.error("Run %s is already being trained", run_name) # if compute context runs train in some other process # we are done here. return @@ -96,10 +97,15 @@ def train_run( weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration > trained_until: - raise RuntimeError( + weights_store.retrieve_weights(run, iteration=latest_weights_iteration) + logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}." ) + # raise RuntimeError( + # f"Found weights for iteration {latest_weights_iteration}, but " + # f"run {run.name} was only trained until {trained_until}." 
+ # ) # start/resume training @@ -157,7 +163,7 @@ def train_run( run.model.eval() # free up optimizer memory to allow larger validation blocks - run.model = run.model.to(torch.device("cpu")) + # run.model = run.model.to(torch.device("cpu")) run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True) weights_store.store_weights(run, iteration_stats.iteration + 1) @@ -172,7 +178,7 @@ def train_run( stats_store.store_training_stats(run.name, run.training_stats) # make sure to move optimizer back to the correct device - run.move_optimizer(compute_context.device) + run.move_optimizer(compute_context.device) run.model.train() weights_store.store_weights(run, run.training_stats.trained_until()) diff --git a/dacapo/validate.py b/dacapo/validate.py index 25b7463e1..3458aadf7 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -79,6 +79,7 @@ def validate_run( evaluator.set_best(run.validation_scores) for validation_dataset in run.datasplit.validate: + logger.warning("Validating on dataset %s", validation_dataset.name) assert ( validation_dataset.gt is not None ), "We do not yet support validating on datasets without ground truth" @@ -98,7 +99,7 @@ def validate_run( f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" ).exists() ): - logger.info("Copying validation inputs!") + logger.warning("Copying validation inputs!") input_voxel_size = validation_dataset.raw.voxel_size output_voxel_size = run.model.scale(input_voxel_size) input_shape = run.model.eval_input_shape @@ -136,11 +137,12 @@ def validate_run( ) input_gt[output_roi] = validation_dataset.gt[output_roi] else: - logger.info("validation inputs already copied!") + logger.warning("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( run.name, iteration, validation_dataset ) + logger.warning("Predicting on dataset %s", validation_dataset.name) predict( run.model, validation_dataset.raw, @@ -148,6 +150,7 @@ def validate_run( compute_context=compute_context, output_roi=validation_dataset.gt.roi, ) + logger.warning("Predicted on dataset %s", validation_dataset.name) post_processor.set_prediction(prediction_array_identifier) From a8477354a3e7520f97d0409defc728af45e98c8e Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 22 Nov 2023 13:08:10 -0500 Subject: [PATCH 08/14] extra conv used for head only --- dacapo/experiments/tasks/distance_task.py | 1 + .../experiments/tasks/distance_task_config.py | 7 +++ .../tasks/predictors/distance_predictor.py | 50 +++++++++++++++---- 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/dacapo/experiments/tasks/distance_task.py b/dacapo/experiments/tasks/distance_task.py index cdb82e95c..2092d70d6 100644 --- a/dacapo/experiments/tasks/distance_task.py +++ b/dacapo/experiments/tasks/distance_task.py @@ -15,6 +15,7 @@ def __init__(self, task_config): channels=task_config.channels, scale_factor=task_config.scale_factor, mask_distances=task_config.mask_distances, + extra_conv=task_config.extra_conv, ) self.loss = MSELoss() self.post_processor = ThresholdPostProcessor() diff --git a/dacapo/experiments/tasks/distance_task_config.py b/dacapo/experiments/tasks/distance_task_config.py index 130cf1c20..b4eb73e3f 100644 --- a/dacapo/experiments/tasks/distance_task_config.py +++ b/dacapo/experiments/tasks/distance_task_config.py @@ -46,3 +46,10 @@ class DistanceTaskConfig(TaskConfig): "is less than the distance to object boundary." 
}, ) + + extra_conv: bool = attr.ib( + default=False, + metadata={ + "help_text": "Whether or not to add an extra conv layer before the head" + }, + ) diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index 70c2bde4a..98aa2fa20 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -27,7 +27,7 @@ class DistancePredictor(Predictor): in the channels argument. """ - def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): + def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool,extra_conv :bool): self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor @@ -36,20 +36,52 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo self.max_distance = 1 * scale_factor self.epsilon = 5e-2 self.threshold = 0.8 + self.extra_conv = extra_conv + self.extra_conv_dims =len(self.channels) *2 @property def embedding_dims(self): return len(self.channels) def create_model(self, architecture): - if architecture.dims == 2: - head = torch.nn.Conv2d( - architecture.num_out_channels, self.embedding_dims, kernel_size=1 - ) - elif architecture.dims == 3: - head = torch.nn.Conv3d( - architecture.num_out_channels, self.embedding_dims, kernel_size=1 - ) + if self.extra_conv: + if architecture.dims == 2: + head = torch.nn.Sequential( + torch.nn.Conv2d( + architecture.num_out_channels, + self.extra_conv_dims, + kernel_size=3, + padding=1, + ), + torch.nn.Conv2d( + self.extra_conv_dims, + self.embedding_dims, + kernel_size=1, + ), + ) + elif architecture.dims == 3: + head = torch.nn.Sequential( + torch.nn.Conv3d( + architecture.num_out_channels, + self.extra_conv_dims, + kernel_size=3, + padding=1, + ), + torch.nn.Conv3d( + self.extra_conv_dims, + self.embedding_dims, + kernel_size=1, + ), + ) + else: + if architecture.dims == 2: + head = torch.nn.Conv2d( + architecture.num_out_channels, self.embedding_dims, kernel_size=1 + ) + elif architecture.dims == 3: + head = torch.nn.Conv3d( + architecture.num_out_channels, self.embedding_dims, kernel_size=1 + ) return Model(architecture, head) From 20b540425f35f25b5fbf46247c51cd0fdc2c4fd2 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 22 Nov 2023 15:15:43 -0500 Subject: [PATCH 09/14] attention block --- .../architectures/attention_unet.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 dacapo/experiments/architectures/attention_unet.py diff --git a/dacapo/experiments/architectures/attention_unet.py b/dacapo/experiments/architectures/attention_unet.py new file mode 100644 index 000000000..f9c7f767f --- /dev/null +++ b/dacapo/experiments/architectures/attention_unet.py @@ -0,0 +1,71 @@ + +import torch +import torch.nn as nn +from .cnnectome_unet import ConvPass,Downsample,Upsample + +class AttentionBlockModule(nn.Module): + def __init__(self, F_g, F_l, F_int, dims): + """Attention Block Module:: + + The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). 
+ + [g] --> W_g --\ /--> psi --> * --> [output] + \ / + [x] --> W_x --> [+] --> relu -- + + Where: + - W_g and W_x are 1x1 Convolution followed by Batch Normalization + - [+] indicates element-wise addition + - relu is the Rectified Linear Unit activation function + - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation + - * indicates element-wise multiplication between the output of psi and input feature 'x' + - [output] has the same dimensions as input 'x', selectively emphasized by attention weights + + Args: + F_g (int): The number of feature channels in the gating signal (g). + This is the input channel dimension for the W_g convolutional layer. + + F_l (int): The number of feature channels in the input features (x). + This is the input channel dimension for the W_x convolutional layer. + + F_int (int): The number of intermediate feature channels. + This represents the output channel dimension of the W_g and W_x convolutional layers + and the input channel dimension for the psi layer. Typically, F_int is smaller + than F_g and F_l, as it serves to compress the feature representations before + applying the attention mechanism. + + The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, + and applies a sigmoid activation to generate an attention map. This map is then used + to scale the input features 'x', resulting in an output that focuses on important + features as dictated by the gating signal 'g'. + + """ + + + super(AttentionBlockModule, self).__init__() + self.dims = dims + self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims] + print("kernel_sizes:",self.kernel_sizes) + + self.W_g = ConvPass(F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None,padding="same") + + self.W_x = nn.Sequential( + ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes, activation=None,padding="same"), + Downsample((2,)*self.dims) + ) + + self.psi = ConvPass(F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid",padding="same") + + up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims] + + self.up = nn.Upsample(scale_factor=2, mode=up_mode, align_corners=True) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, g, x): + g1 = self.W_g(g) + x1 = self.W_x(x) + psi = self.relu(g1 + x1) + psi = self.psi(psi) + psi = self.up(psi) + return x * psi \ No newline at end of file From 2c17e176d79650bcf92d81092cd7975b11ed637c Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 22 Nov 2023 16:04:21 -0500 Subject: [PATCH 10/14] create CNNectomeUNetModule using attention --- .../architectures/attention_unet.py | 71 -------------- .../architectures/cnnectome_unet.py | 98 ++++++++++++++++++- 2 files changed, 94 insertions(+), 75 deletions(-) delete mode 100644 dacapo/experiments/architectures/attention_unet.py diff --git a/dacapo/experiments/architectures/attention_unet.py b/dacapo/experiments/architectures/attention_unet.py deleted file mode 100644 index f9c7f767f..000000000 --- a/dacapo/experiments/architectures/attention_unet.py +++ /dev/null @@ -1,71 +0,0 @@ - -import torch -import torch.nn as nn -from .cnnectome_unet import ConvPass,Downsample,Upsample - -class AttentionBlockModule(nn.Module): - def __init__(self, F_g, F_l, F_int, dims): - """Attention Block Module:: - - The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). 
- - [g] --> W_g --\ /--> psi --> * --> [output] - \ / - [x] --> W_x --> [+] --> relu -- - - Where: - - W_g and W_x are 1x1 Convolution followed by Batch Normalization - - [+] indicates element-wise addition - - relu is the Rectified Linear Unit activation function - - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation - - * indicates element-wise multiplication between the output of psi and input feature 'x' - - [output] has the same dimensions as input 'x', selectively emphasized by attention weights - - Args: - F_g (int): The number of feature channels in the gating signal (g). - This is the input channel dimension for the W_g convolutional layer. - - F_l (int): The number of feature channels in the input features (x). - This is the input channel dimension for the W_x convolutional layer. - - F_int (int): The number of intermediate feature channels. - This represents the output channel dimension of the W_g and W_x convolutional layers - and the input channel dimension for the psi layer. Typically, F_int is smaller - than F_g and F_l, as it serves to compress the feature representations before - applying the attention mechanism. - - The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, - and applies a sigmoid activation to generate an attention map. This map is then used - to scale the input features 'x', resulting in an output that focuses on important - features as dictated by the gating signal 'g'. - - """ - - - super(AttentionBlockModule, self).__init__() - self.dims = dims - self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims] - print("kernel_sizes:",self.kernel_sizes) - - self.W_g = ConvPass(F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None,padding="same") - - self.W_x = nn.Sequential( - ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes, activation=None,padding="same"), - Downsample((2,)*self.dims) - ) - - self.psi = ConvPass(F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid",padding="same") - - up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims] - - self.up = nn.Upsample(scale_factor=2, mode=up_mode, align_corners=True) - - self.relu = nn.ReLU(inplace=True) - - def forward(self, g, x): - g1 = self.W_g(g) - x1 = self.W_x(x) - psi = self.relu(g1 + x1) - psi = self.psi(psi) - psi = self.up(psi) - return x * psi \ No newline at end of file diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index 01a261d09..8f3e74dfe 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -125,6 +125,7 @@ def __init__( padding="valid", upsample_channel_contraction=False, activation_on_upsample=False, + use_attention=False, ): """Create a U-Net:: @@ -244,6 +245,7 @@ def __init__( ) self.dims = len(downsample_factors[0]) + self.use_attention = use_attention # default arguments @@ -317,6 +319,17 @@ def __init__( ] ) + if self.use_attention: + self.attention = nn.ModuleList( + [ + AttentionBlockModule( + F_g=num_fmaps * fmap_inc_factor ** (level ), + F_l=num_fmaps * fmap_inc_factor ** (level ), + F_int=num_fmaps * fmap_inc_factor ** (level - 1), + dims=self.dims, + )for level in range(1,self.num_levels) + ]) + # right convolutional passes self.r_conv = nn.ModuleList( [ @@ -359,10 +372,16 @@ def rec_forward(self, level, f_in): # nested levels gs_out = self.rec_forward(level - 1, g_in) - # up, concat, and crop - fs_right = [ - self.r_up[h][i](gs_out[h], f_left) for h in 
range(self.num_heads) - ] + if self.use_attention: + f_left_attented = [self.attention[i-1](gs_out[h],f_left) for h in range(self.num_heads)] + fs_right = [ + self.r_up[h][i](gs_out[h], f_left_attented[h]) + for h in range(self.num_heads) + ] + else: # up, concat, and crop + fs_right = [ + self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads) + ] # convolve fs_out = [self.r_conv[h][i](fs_right[h]) for h in range(self.num_heads)] @@ -580,3 +599,74 @@ def forward(self, g_out, f_left=None): return torch.cat([f_cropped, g_cropped], dim=1) else: return g_cropped + + + +class AttentionBlockModule(nn.Module): + def __init__(self, F_g, F_l, F_int, dims): + """Attention Block Module:: + + The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). + + [g] --> W_g --\ /--> psi --> * --> [output] + \ / + [x] --> W_x --> [+] --> relu -- + + Where: + - W_g and W_x are 1x1 Convolution followed by Batch Normalization + - [+] indicates element-wise addition + - relu is the Rectified Linear Unit activation function + - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation + - * indicates element-wise multiplication between the output of psi and input feature 'x' + - [output] has the same dimensions as input 'x', selectively emphasized by attention weights + + Args: + F_g (int): The number of feature channels in the gating signal (g). + This is the input channel dimension for the W_g convolutional layer. + + F_l (int): The number of feature channels in the input features (x). + This is the input channel dimension for the W_x convolutional layer. + + F_int (int): The number of intermediate feature channels. + This represents the output channel dimension of the W_g and W_x convolutional layers + and the input channel dimension for the psi layer. Typically, F_int is smaller + than F_g and F_l, as it serves to compress the feature representations before + applying the attention mechanism. + + The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, + and applies a sigmoid activation to generate an attention map. This map is then used + to scale the input features 'x', resulting in an output that focuses on important + features as dictated by the gating signal 'g'. 
+ + """ + + super(AttentionBlockModule, self).__init__() + self.dims = dims + self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims] + print("kernel_sizes:", self.kernel_sizes) + + self.W_g = ConvPass( + F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same") + + self.W_x = nn.Sequential( + ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes, + activation=None, padding="same"), + Downsample((2,)*self.dims) + ) + + self.psi = ConvPass( + F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid", padding="same") + + up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims] + + self.up = nn.Upsample(scale_factor=2, mode=up_mode, align_corners=True) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, g, x): + g1 = self.W_g(g) + x1 = self.W_x(x) + psi = self.relu(g1 + x1) + psi = self.psi(psi) + psi = self.up(psi) + return x * psi From e2a29749c57bcbff3a334ea90a8ee3792846027d Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 22 Nov 2023 19:16:37 -0500 Subject: [PATCH 11/14] unet using attention --- .../architectures/cnnectome_unet.py | 32 +++++++++++++++++-- .../architectures/cnnectome_unet_config.py | 6 ++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index 8f3e74dfe..32cbe1744 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -25,6 +25,7 @@ def __init__(self, architecture_config): self.upsample_factors = ( self.upsample_factors if self.upsample_factors is not None else [] ) + self.use_attention = architecture_config.use_attention self.unet = self.module() @@ -64,6 +65,7 @@ def module(self): activation_on_upsample=True, upsample_channel_contraction=[False] + [True] * (len(downsample_factors) - 1), + use_attention=self.use_attention, ) if len(self.upsample_factors) > 0: layers = [unet] @@ -323,9 +325,9 @@ def __init__( self.attention = nn.ModuleList( [ AttentionBlockModule( - F_g=num_fmaps * fmap_inc_factor ** (level ), - F_l=num_fmaps * fmap_inc_factor ** (level ), - F_int=num_fmaps * fmap_inc_factor ** (level - 1), + F_g=num_fmaps * fmap_inc_factor ** (self.num_levels - level ), + F_l=num_fmaps * fmap_inc_factor ** (self.num_levels - level -1 ), + F_int=num_fmaps * fmap_inc_factor ** (self.num_levels - level -1 ), dims=self.dims, )for level in range(1,self.num_levels) ]) @@ -663,9 +665,33 @@ def __init__(self, F_g, F_l, F_int, dims): self.relu = nn.ReLU(inplace=True) + def calculate_and_apply_padding(self, smaller_tensor, larger_tensor): + """ + Calculate and apply symmetric padding to the smaller tensor to match the dimensions of the larger tensor. + + Args: + smaller_tensor (Tensor): The tensor to be padded. + larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match. + + Returns: + Tensor: The padded smaller tensor with the same dimensions as the larger tensor. 
+ """ + padding = [] + for i in range(2, 2 + self.dims): + diff = larger_tensor.size(i) - smaller_tensor.size(i) + padding.extend([diff // 2, diff - diff // 2]) + + # Reverse padding to match the 'pad' function's expectation + padding = padding[::-1] + + # Apply symmetric padding + return nn.functional.pad(smaller_tensor, padding, mode='constant', value=0) + + def forward(self, g, x): g1 = self.W_g(g) x1 = self.W_x(x) + g1 = self.calculate_and_apply_padding(g1, x1) psi = self.relu(g1 + x1) psi = self.psi(psi) psi = self.up(psi) diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index 5a40cca6d..c0e9e5b9d 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -82,3 +82,9 @@ class CNNectomeUNetConfig(ArchitectureConfig): default="valid", metadata={"help_text": "The padding to use in convolution operations."}, ) + use_attention: bool = attr.ib( + default=False, + metadata={ + "help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D." + }, + ) From c45e93c84257f04bef07f2e959945442e5d104ff Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 22 Nov 2023 21:36:36 -0500 Subject: [PATCH 12/14] fix fmap calculation for attention --- .../architectures/cnnectome_unet.py | 39 +++++++++++++------ 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index 32cbe1744..798620e04 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -320,17 +320,31 @@ def __init__( for _ in range(num_heads) ] ) - +# if num_fmaps_out is None or level != self.num_levels-1 else num_fmaps_out if self.use_attention: self.attention = nn.ModuleList( + [ + nn.ModuleList( [ AttentionBlockModule( - F_g=num_fmaps * fmap_inc_factor ** (self.num_levels - level ), - F_l=num_fmaps * fmap_inc_factor ** (self.num_levels - level -1 ), - F_int=num_fmaps * fmap_inc_factor ** (self.num_levels - level -1 ), + F_g=num_fmaps * fmap_inc_factor ** (level + 1), + F_l=num_fmaps + * fmap_inc_factor + ** level, + F_int=num_fmaps + * fmap_inc_factor + ** (level + (1 - upsample_channel_contraction[level])) + if num_fmaps_out is None or level != 0 + else num_fmaps_out, dims=self.dims, - )for level in range(1,self.num_levels) - ]) + upsample_factor=downsample_factors[level], + ) + for level in range(self.num_levels - 1) + ] + ) + for _ in range(num_heads) + ] + ) # right convolutional passes self.r_conv = nn.ModuleList( @@ -375,7 +389,7 @@ def rec_forward(self, level, f_in): gs_out = self.rec_forward(level - 1, g_in) if self.use_attention: - f_left_attented = [self.attention[i-1](gs_out[h],f_left) for h in range(self.num_heads)] + f_left_attented = [self.attention[h][i](gs_out[h],f_left) for h in range(self.num_heads)] fs_right = [ self.r_up[h][i](gs_out[h], f_left_attented[h]) for h in range(self.num_heads) @@ -605,7 +619,7 @@ def forward(self, g_out, f_left=None): class AttentionBlockModule(nn.Module): - def __init__(self, F_g, F_l, F_int, dims): + def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): """Attention Block Module:: The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). 
@@ -645,7 +659,10 @@ def __init__(self, F_g, F_l, F_int, dims): super(AttentionBlockModule, self).__init__() self.dims = dims self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims] - print("kernel_sizes:", self.kernel_sizes) + if upsample_factor is not None: + self.upsample_factor = upsample_factor + else: + self.upsample_factor = (2,)*self.dims self.W_g = ConvPass( F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same") @@ -653,7 +670,7 @@ def __init__(self, F_g, F_l, F_int, dims): self.W_x = nn.Sequential( ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"), - Downsample((2,)*self.dims) + Downsample(upsample_factor) ) self.psi = ConvPass( @@ -661,7 +678,7 @@ def __init__(self, F_g, F_l, F_int, dims): up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims] - self.up = nn.Upsample(scale_factor=2, mode=up_mode, align_corners=True) + self.up = nn.Upsample(scale_factor=upsample_factor, mode=up_mode, align_corners=True) self.relu = nn.ReLU(inplace=True) From a9764c8757f81c2b69fd8205ecdd867c7644b2f3 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 9 Feb 2024 14:39:22 +0000 Subject: [PATCH 13/14] :art: Format Python code with psf/black --- .../architectures/cnnectome_unet.py | 128 ++++++++++-------- .../datasets/arrays/concat_array.py | 3 +- dacapo/experiments/run.py | 21 ++- dacapo/experiments/starts/start.py | 38 ++++-- .../tasks/predictors/distance_predictor.py | 10 +- .../experiments/trainers/gunpowder_trainer.py | 14 +- .../trainers/gunpowder_trainer_config.py | 8 +- dacapo/train.py | 6 +- dacapo/utils/balance_weights.py | 2 +- dacapo/validate.py | 2 +- 10 files changed, 140 insertions(+), 92 deletions(-) diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index 798620e04..ddf847456 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -320,31 +320,29 @@ def __init__( for _ in range(num_heads) ] ) -# if num_fmaps_out is None or level != self.num_levels-1 else num_fmaps_out + # if num_fmaps_out is None or level != self.num_levels-1 else num_fmaps_out if self.use_attention: self.attention = nn.ModuleList( - [ - nn.ModuleList( [ - AttentionBlockModule( - F_g=num_fmaps * fmap_inc_factor ** (level + 1), - F_l=num_fmaps - * fmap_inc_factor - ** level, - F_int=num_fmaps - * fmap_inc_factor - ** (level + (1 - upsample_channel_contraction[level])) - if num_fmaps_out is None or level != 0 - else num_fmaps_out, - dims=self.dims, - upsample_factor=downsample_factors[level], + nn.ModuleList( + [ + AttentionBlockModule( + F_g=num_fmaps * fmap_inc_factor ** (level + 1), + F_l=num_fmaps * fmap_inc_factor**level, + F_int=num_fmaps + * fmap_inc_factor + ** (level + (1 - upsample_channel_contraction[level])) + if num_fmaps_out is None or level != 0 + else num_fmaps_out, + dims=self.dims, + upsample_factor=downsample_factors[level], + ) + for level in range(self.num_levels - 1) + ] ) - for level in range(self.num_levels - 1) + for _ in range(num_heads) ] ) - for _ in range(num_heads) - ] - ) # right convolutional passes self.r_conv = nn.ModuleList( @@ -389,12 +387,15 @@ def rec_forward(self, level, f_in): gs_out = self.rec_forward(level - 1, g_in) if self.use_attention: - f_left_attented = [self.attention[h][i](gs_out[h],f_left) for h in range(self.num_heads)] + f_left_attented = [ + self.attention[h][i](gs_out[h], f_left) + for h in range(self.num_heads) + ] fs_right = [ self.r_up[h][i](gs_out[h], 
f_left_attented[h]) for h in range(self.num_heads) ] - else: # up, concat, and crop + else: # up, concat, and crop fs_right = [ self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads) ] @@ -617,44 +618,43 @@ def forward(self, g_out, f_left=None): return g_cropped - class AttentionBlockModule(nn.Module): def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): """Attention Block Module:: - The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). + The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). - [g] --> W_g --\ /--> psi --> * --> [output] - \ / - [x] --> W_x --> [+] --> relu -- + [g] --> W_g --\ /--> psi --> * --> [output] + \ / + [x] --> W_x --> [+] --> relu -- - Where: - - W_g and W_x are 1x1 Convolution followed by Batch Normalization - - [+] indicates element-wise addition - - relu is the Rectified Linear Unit activation function - - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation - - * indicates element-wise multiplication between the output of psi and input feature 'x' - - [output] has the same dimensions as input 'x', selectively emphasized by attention weights + Where: + - W_g and W_x are 1x1 Convolution followed by Batch Normalization + - [+] indicates element-wise addition + - relu is the Rectified Linear Unit activation function + - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation + - * indicates element-wise multiplication between the output of psi and input feature 'x' + - [output] has the same dimensions as input 'x', selectively emphasized by attention weights - Args: - F_g (int): The number of feature channels in the gating signal (g). - This is the input channel dimension for the W_g convolutional layer. + Args: + F_g (int): The number of feature channels in the gating signal (g). + This is the input channel dimension for the W_g convolutional layer. - F_l (int): The number of feature channels in the input features (x). - This is the input channel dimension for the W_x convolutional layer. + F_l (int): The number of feature channels in the input features (x). + This is the input channel dimension for the W_x convolutional layer. - F_int (int): The number of intermediate feature channels. - This represents the output channel dimension of the W_g and W_x convolutional layers - and the input channel dimension for the psi layer. Typically, F_int is smaller - than F_g and F_l, as it serves to compress the feature representations before - applying the attention mechanism. + F_int (int): The number of intermediate feature channels. + This represents the output channel dimension of the W_g and W_x convolutional layers + and the input channel dimension for the psi layer. Typically, F_int is smaller + than F_g and F_l, as it serves to compress the feature representations before + applying the attention mechanism. - The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, - and applies a sigmoid activation to generate an attention map. This map is then used - to scale the input features 'x', resulting in an output that focuses on important - features as dictated by the gating signal 'g'. + The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, + and applies a sigmoid activation to generate an attention map. This map is then used + to scale the input features 'x', resulting in an output that focuses on important + features as dictated by the gating signal 'g'. 
- """ + """ super(AttentionBlockModule, self).__init__() self.dims = dims @@ -662,23 +662,36 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): if upsample_factor is not None: self.upsample_factor = upsample_factor else: - self.upsample_factor = (2,)*self.dims + self.upsample_factor = (2,) * self.dims self.W_g = ConvPass( - F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same") + F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same" + ) self.W_x = nn.Sequential( - ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes, - activation=None, padding="same"), - Downsample(upsample_factor) + ConvPass( + F_l, + F_int, + kernel_sizes=self.kernel_sizes, + activation=None, + padding="same", + ), + Downsample(upsample_factor), ) self.psi = ConvPass( - F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid", padding="same") + F_int, + 1, + kernel_sizes=self.kernel_sizes, + activation="Sigmoid", + padding="same", + ) - up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims] + up_mode = {2: "bilinear", 3: "trilinear"}[self.dims] - self.up = nn.Upsample(scale_factor=upsample_factor, mode=up_mode, align_corners=True) + self.up = nn.Upsample( + scale_factor=upsample_factor, mode=up_mode, align_corners=True + ) self.relu = nn.ReLU(inplace=True) @@ -702,8 +715,7 @@ def calculate_and_apply_padding(self, smaller_tensor, larger_tensor): padding = padding[::-1] # Apply symmetric padding - return nn.functional.pad(smaller_tensor, padding, mode='constant', value=0) - + return nn.functional.pad(smaller_tensor, padding, mode="constant", value=0) def forward(self, g, x): g1 = self.W_g(g) diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index df01129d8..71976393e 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__file__) + class ConcatArray(Array): """This is a wrapper around other `source_arrays` that concatenates them along the channel dimension.""" @@ -119,7 +120,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray: axis=0, ) if concatenated.shape[0] == 1: - logger.info( + logger.info( f"Concatenated array has only one channel: {self.name} {concatenated.shape}" ) return concatenated diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 1609892c8..9ea496758 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__file__) + class Run: name: str train_until: int @@ -58,28 +59,34 @@ def __init__(self, run_config): return try: from ..store import create_config_store + start_config_store = create_config_store() - starter_config = start_config_store.retrieve_run_config(run_config.start_config.run) + starter_config = start_config_store.retrieve_run_config( + run_config.start_config.run + ) except Exception as e: - logger.error(f"could not load start config: {e} Should be added to the database config store RUN") + logger.error( + f"could not load start config: {e} Should be added to the database config store RUN" + ) raise e - + # preloaded weights from previous run if run_config.task_config.name == starter_config.task_config.name: self.start = Start(run_config.start_config) else: # Match labels between old and new head - if hasattr(run_config.task_config,"channels"): + if hasattr(run_config.task_config, "channels"): # Map old head and new head old_head = 
starter_config.task_config.channels new_head = run_config.task_config.channels - self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head) + self.start = Start( + run_config.start_config, old_head=old_head, new_head=new_head + ) else: logger.warning("Not implemented channel match for this task") - self.start = Start(run_config.start_config,remove_head=True) + self.start = Start(run_config.start_config, remove_head=True) self.start.initialize_weights(self.model) - @staticmethod def get_validation_scores(run_config) -> ValidationScores: """ diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index f43ab6403..c64436294 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -3,15 +3,21 @@ logger = logging.getLogger(__file__) - # self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] - # self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] -head_keys = ["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"] +# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] +# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] +head_keys = [ + "prediction_head.weight", + "prediction_head.bias", + "chain.1.weight", + "chain.1.bias", +] # Hack # if label is mito_peroxisome or peroxisome then change it to mito -mitos = ["mito_proxisome","peroxisome"] +mitos = ["mito_proxisome", "peroxisome"] -def match_heads(model, head_weights, old_head, new_head ): + +def match_heads(model, head_weights, old_head, new_head): # match the heads for label in new_head: old_label = label @@ -30,8 +36,9 @@ def match_heads(model, head_weights, old_head, new_head ): model.state_dict()[key][new_index] = n_val logger.warning(f"matched head for {label} with {old_label}") + class Start(ABC): - def __init__(self, start_config,remove_head = False, old_head= None, new_head = None): + def __init__(self, start_config, remove_head=False, old_head=None, new_head=None): self.run = start_config.run self.criterion = start_config.criterion self.remove_head = remove_head @@ -44,7 +51,9 @@ def initialize_weights(self, model): weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) - logger.warning(f"loading weights from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"loading weights from run {self.run}, criterion: {self.criterion}" + ) try: if self.old_head and self.new_head: @@ -61,15 +70,21 @@ def initialize_weights(self, model): logger.warning(f"ERROR starter: {e}") def load_model_using_head_removal(self, model, weights): - logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"removing head from run {self.run}, criterion: {self.criterion}" + ) for key in head_keys: weights.model.pop(key, None) logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}") model.load_state_dict(weights.model, strict=False) - logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}" + ) def load_model_using_head_matching(self, model, weights): - logger.warning(f"matching heads from run 
{self.run}, criterion: {self.criterion}") + logger.warning( + f"matching heads from run {self.run}, criterion: {self.criterion}" + ) logger.warning(f"old head: {self.old_head}") logger.warning(f"new head: {self.new_head}") head_weights = {} @@ -79,6 +94,3 @@ def load_model_using_head_matching(self, model, weights): weights.model.pop(key, None) model.load_state_dict(weights.model, strict=False) model = match_heads(model, head_weights, self.old_head, self.new_head) - - - diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index 98aa2fa20..ca762fc3e 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -27,7 +27,13 @@ class DistancePredictor(Predictor): in the channels argument. """ - def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool,extra_conv :bool): + def __init__( + self, + channels: List[str], + scale_factor: float, + mask_distances: bool, + extra_conv: bool, + ): self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor @@ -37,7 +43,7 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo self.epsilon = 5e-2 self.threshold = 0.8 self.extra_conv = extra_conv - self.extra_conv_dims =len(self.channels) *2 + self.extra_conv_dims = len(self.channels) * 2 @property def embedding_dims(self): diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 8a4bf8a2f..09ffd2230 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -43,7 +43,9 @@ def __init__(self, trainer_config): self.clip_raw = trainer_config.clip_raw # Testing out if calculating multiple times and multiplying is necessary - self.add_predictor_nodes_to_dataset = trainer_config.add_predictor_nodes_to_dataset + self.add_predictor_nodes_to_dataset = ( + trainer_config.add_predictor_nodes_to_dataset + ) self.finetune_head_only = trainer_config.finetune_head_only self.scheduler = None @@ -177,7 +179,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=datasets_weight_key if self.add_predictor_nodes_to_dataset else weight_key, + weights_key=datasets_weight_key + if self.add_predictor_nodes_to_dataset + else weight_key, mask_key=mask_key, ) @@ -230,7 +234,9 @@ def iterate(self, num_iterations, model, optimizer, device): f"Trainer fetch batch took {time.time() - t_start_fetch} seconds" ) - for param in model.parameters(): # TODO: get parameters from optimizer instead + for ( + param + ) in model.parameters(): # TODO: get parameters from optimizer instead param.grad = None t_start_prediction = time.time() @@ -352,4 +358,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def can_train(self, datasets) -> bool: - return all([dataset.gt is not None for dataset in datasets]) \ No newline at end of file + return all([dataset.gt is not None for dataset in datasets]) diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 17cf411ce..5ed63eee8 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -32,10 +32,12 @@ class GunpowderTrainerConfig(TrainerConfig): add_predictor_nodes_to_dataset: Optional[bool] = attr.ib( default=True, 
- metadata={"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"} + metadata={ + "help_text": "Whether to add a predictor node to dataset_source and apply product of weights" + }, ) finetune_head_only: Optional[bool] = attr.ib( default=False, - metadata={"help_text": "Whether to fine-tune head only or all layers"} - ) \ No newline at end of file + metadata={"help_text": "Whether to fine-tune head only or all layers"}, + ) diff --git a/dacapo/train.py b/dacapo/train.py index e84d33613..5665e043c 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -12,7 +12,9 @@ logger = logging.getLogger(__name__) -def train(run_name: str, compute_context: ComputeContext = LocalTorch(), force_cuda = False): +def train( + run_name: str, compute_context: ComputeContext = LocalTorch(), force_cuda=False +): """Train a run""" if compute_context.train(run_name): @@ -187,7 +189,7 @@ def train_run( ) # make sure to move optimizer back to the correct device - run.move_optimizer(compute_context.device) + run.move_optimizer(compute_context.device) run.model.train() logger.info("Trained until %d, finished.", trained_until) diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index 96fbc80e8..f5adcffca 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -77,4 +77,4 @@ def balance_weights( # scale_slab the masked-in scale_slab with the class weights scale_slab *= np.take(w, labels_slab) - return error_scale, moving_counts \ No newline at end of file + return error_scale, moving_counts diff --git a/dacapo/validate.py b/dacapo/validate.py index bce02a92e..fca055baf 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -143,7 +143,7 @@ def validate_run( run.name, iteration, validation_dataset ) logger.info("Predicting on dataset %s", validation_dataset.name) - + predict( run.model, validation_dataset.raw, From a16448c66d1045a289a5eeb37e10b98105533ece Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 14 Feb 2024 09:52:43 -0500 Subject: [PATCH 14/14] remove extra irrelevant stuff --- .../datasets/arrays/concat_array.py | 1 - dacapo/experiments/run.py | 41 ++------ dacapo/experiments/starts/start.py | 97 ++++--------------- dacapo/experiments/tasks/distance_task.py | 1 - .../experiments/tasks/distance_task_config.py | 7 -- .../tasks/predictors/distance_predictor.py | 56 ++--------- .../experiments/trainers/gunpowder_trainer.py | 20 +--- .../trainers/gunpowder_trainer_config.py | 5 - dacapo/train.py | 10 +- dacapo/validate.py | 6 +- 10 files changed, 43 insertions(+), 201 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index 71976393e..1475c7b97 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -96,7 +96,6 @@ def num_channels(self): return len(self.channels) def __getitem__(self, roi: Roi) -> np.ndarray: - logger.info(f"Concat Array: Get Item {self.name} {roi}") default = ( np.zeros_like(self.source_array[roi]) if self.default_array is None diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 9ea496758..129f947ab 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -6,10 +6,8 @@ from .validation_scores import ValidationScores from .starts import Start from .model import Model -import logging -import torch -logger = logging.getLogger(__file__) +import torch class Run: @@ -55,37 
+53,14 @@ def __init__(self, run_config): self.task.parameters, self.datasplit.validate, self.task.evaluation_scores ) - if run_config.start_config is None: - return - try: - from ..store import create_config_store - - start_config_store = create_config_store() - starter_config = start_config_store.retrieve_run_config( - run_config.start_config.run - ) - except Exception as e: - logger.error( - f"could not load start config: {e} Should be added to the database config store RUN" - ) - raise e - # preloaded weights from previous run - if run_config.task_config.name == starter_config.task_config.name: - self.start = Start(run_config.start_config) - else: - # Match labels between old and new head - if hasattr(run_config.task_config, "channels"): - # Map old head and new head - old_head = starter_config.task_config.channels - new_head = run_config.task_config.channels - self.start = Start( - run_config.start_config, old_head=old_head, new_head=new_head - ) - else: - logger.warning("Not implemented channel match for this task") - self.start = Start(run_config.start_config, remove_head=True) - self.start.initialize_weights(self.model) + self.start = ( + Start(run_config.start_config) + if run_config.start_config is not None + else None + ) + if self.start is not None: + self.start.initialize_weights(self.model) @staticmethod def get_validation_scores(run_config) -> ValidationScores: diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index c64436294..da7badbf9 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -3,94 +3,33 @@ logger = logging.getLogger(__file__) -# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] -# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] -head_keys = [ - "prediction_head.weight", - "prediction_head.bias", - "chain.1.weight", - "chain.1.bias", -] - -# Hack -# if label is mito_peroxisome or peroxisome then change it to mito -mitos = ["mito_proxisome", "peroxisome"] - - -def match_heads(model, head_weights, old_head, new_head): - # match the heads - for label in new_head: - old_label = label - if label in mitos: - old_label = "mito" - if old_label in old_head: - logger.warning(f"matching head for {label}") - # find the index of the label in the old_head - old_index = old_head.index(old_label) - # find the index of the label in the new_head - new_index = new_head.index(label) - # get the weight and bias of the old head - for key in head_keys: - if key in model.state_dict().keys(): - n_val = head_weights[key][old_index] - model.state_dict()[key][new_index] = n_val - logger.warning(f"matched head for {label} with {old_label}") - class Start(ABC): - def __init__(self, start_config, remove_head=False, old_head=None, new_head=None): + def __init__(self, start_config): self.run = start_config.run self.criterion = start_config.criterion - self.remove_head = remove_head - self.old_head = old_head - self.new_head = new_head def initialize_weights(self, model): from dacapo.store.create_store import create_weights_store weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) - - logger.warning( - f"loading weights from run {self.run}, criterion: {self.criterion}" - ) - + logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}") + # load the model weights (taken from torch load_state_dict source) 
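+        # Try a strict load first; if the stored weights come from a run with a
+        # different architecture (for example a different prediction head, which
+        # raises a RuntimeError on mismatched keys or shapes), fall back to
+        # copying only the parameters whose names and sizes match this model.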
try: - if self.old_head and self.new_head: - try: - self.load_model_using_head_matching(model, weights) - except RuntimeError as e: - logger.error(f"ERROR starter matching head: {e}") - self.load_model_using_head_removal(model, weights) - elif self.remove_head: - self.load_model_using_head_removal(model, weights) - else: - model.load_state_dict(weights.model) + model.load_state_dict(weights.model) except RuntimeError as e: - logger.warning(f"ERROR starter: {e}") - - def load_model_using_head_removal(self, model, weights): - logger.warning( - f"removing head from run {self.run}, criterion: {self.criterion}" - ) - for key in head_keys: - weights.model.pop(key, None) - logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}") - model.load_state_dict(weights.model, strict=False) - logger.warning( - f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}" - ) - - def load_model_using_head_matching(self, model, weights): - logger.warning( - f"matching heads from run {self.run}, criterion: {self.criterion}" - ) - logger.warning(f"old head: {self.old_head}") - logger.warning(f"new head: {self.new_head}") - head_weights = {} - for key in head_keys: - head_weights[key] = weights.model[key] - for key in head_keys: - weights.model.pop(key, None) - model.load_state_dict(weights.model, strict=False) - model = match_heads(model, head_weights, self.old_head, self.new_head) + logger.warning(e) + # if the model is not the same, we can try to load the weights + # of the common layers + model_dict = model.state_dict() + pretrained_dict = { + k: v + for k, v in weights.model.items() + if k in model_dict and v.size() == model_dict[k].size() + } + model_dict.update( + pretrained_dict + ) # update only the existing and matching layers + model.load_state_dict(model_dict) + logger.warning(f"loaded only common layers from weights") diff --git a/dacapo/experiments/tasks/distance_task.py b/dacapo/experiments/tasks/distance_task.py index 2092d70d6..cdb82e95c 100644 --- a/dacapo/experiments/tasks/distance_task.py +++ b/dacapo/experiments/tasks/distance_task.py @@ -15,7 +15,6 @@ def __init__(self, task_config): channels=task_config.channels, scale_factor=task_config.scale_factor, mask_distances=task_config.mask_distances, - extra_conv=task_config.extra_conv, ) self.loss = MSELoss() self.post_processor = ThresholdPostProcessor() diff --git a/dacapo/experiments/tasks/distance_task_config.py b/dacapo/experiments/tasks/distance_task_config.py index b4eb73e3f..130cf1c20 100644 --- a/dacapo/experiments/tasks/distance_task_config.py +++ b/dacapo/experiments/tasks/distance_task_config.py @@ -46,10 +46,3 @@ class DistanceTaskConfig(TaskConfig): "is less than the distance to object boundary." }, ) - - extra_conv: bool = attr.ib( - default=False, - metadata={ - "help_text": "Whether or not to add an extra conv layer before the head" - }, - ) diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index ca762fc3e..70c2bde4a 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -27,13 +27,7 @@ class DistancePredictor(Predictor): in the channels argument. 
""" - def __init__( - self, - channels: List[str], - scale_factor: float, - mask_distances: bool, - extra_conv: bool, - ): + def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor @@ -42,52 +36,20 @@ def __init__( self.max_distance = 1 * scale_factor self.epsilon = 5e-2 self.threshold = 0.8 - self.extra_conv = extra_conv - self.extra_conv_dims = len(self.channels) * 2 @property def embedding_dims(self): return len(self.channels) def create_model(self, architecture): - if self.extra_conv: - if architecture.dims == 2: - head = torch.nn.Sequential( - torch.nn.Conv2d( - architecture.num_out_channels, - self.extra_conv_dims, - kernel_size=3, - padding=1, - ), - torch.nn.Conv2d( - self.extra_conv_dims, - self.embedding_dims, - kernel_size=1, - ), - ) - elif architecture.dims == 3: - head = torch.nn.Sequential( - torch.nn.Conv3d( - architecture.num_out_channels, - self.extra_conv_dims, - kernel_size=3, - padding=1, - ), - torch.nn.Conv3d( - self.extra_conv_dims, - self.embedding_dims, - kernel_size=1, - ), - ) - else: - if architecture.dims == 2: - head = torch.nn.Conv2d( - architecture.num_out_channels, self.embedding_dims, kernel_size=1 - ) - elif architecture.dims == 3: - head = torch.nn.Conv3d( - architecture.num_out_channels, self.embedding_dims, kernel_size=1 - ) + if architecture.dims == 2: + head = torch.nn.Conv2d( + architecture.num_out_channels, self.embedding_dims, kernel_size=1 + ) + elif architecture.dims == 3: + head = torch.nn.Conv3d( + architecture.num_out_channels, self.embedding_dims, kernel_size=1 + ) return Model(architecture, head) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 09ffd2230..f5d8fcd52 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -46,22 +46,11 @@ def __init__(self, trainer_config): self.add_predictor_nodes_to_dataset = ( trainer_config.add_predictor_nodes_to_dataset ) - self.finetune_head_only = trainer_config.finetune_head_only self.scheduler = None def create_optimizer(self, model): - if self.finetune_head_only: - logger.warning("Finetuning head only") - parameters = [] - for name, param in model.named_parameters(): - if "prediction_head" in name: - parameters.append(param) - else: - param.requires_grad = False - else: - parameters = model.parameters() - optimizer = torch.optim.RAdam(lr=self.learning_rate, params=parameters) + optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) self.scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=0.01, @@ -228,15 +217,15 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): def iterate(self, num_iterations, model, optimizer, device): t_start_fetch = time.time() + logger.info("Starting iteration!") + for iteration in range(self.iteration, self.iteration + num_iterations): raw, gt, target, weight, mask = self.next() logger.debug( f"Trainer fetch batch took {time.time() - t_start_fetch} seconds" ) - for ( - param - ) in model.parameters(): # TODO: get parameters from optimizer instead + for param in model.parameters(): param.grad = None t_start_prediction = time.time() @@ -247,7 +236,6 @@ def iterate(self, num_iterations, model, optimizer, device): torch.as_tensor(target[target.roi]).to(device).float(), torch.as_tensor(weight[weight.roi]).to(device).float(), ) - loss.backward() optimizer.step() diff 
--git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 5ed63eee8..539e3c5e1 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -36,8 +36,3 @@ class GunpowderTrainerConfig(TrainerConfig): "help_text": "Whether to add a predictor node to dataset_source and apply product of weights" }, ) - - finetune_head_only: Optional[bool] = attr.ib( - default=False, - metadata={"help_text": "Whether to fine-tune head only or all layers"}, - ) diff --git a/dacapo/train.py b/dacapo/train.py index 5665e043c..7beb096b4 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -12,9 +12,7 @@ logger = logging.getLogger(__name__) -def train( - run_name: str, compute_context: ComputeContext = LocalTorch(), force_cuda=False -): +def train(run_name: str, compute_context: ComputeContext = LocalTorch()): """Train a run""" if compute_context.train(run_name): @@ -104,10 +102,6 @@ def train_run( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. " ) - # raise RuntimeError( - # f"Found weights for iteration {latest_weights_iteration}, but " - # f"run {run.name} was only trained until {trained_until}." - # ) # start/resume training @@ -167,7 +161,7 @@ def train_run( run.model.eval() # free up optimizer memory to allow larger validation blocks - # run.model = run.model.to(torch.device("cpu")) + run.model = run.model.to(torch.device("cpu")) run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True) stats_store.store_training_stats(run.name, run.training_stats) diff --git a/dacapo/validate.py b/dacapo/validate.py index fca055baf..a1cf9da7d 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -79,7 +79,6 @@ def validate_run( evaluator.set_best(run.validation_scores) for validation_dataset in run.datasplit.validate: - logger.warning("Validating on dataset %s", validation_dataset.name) assert ( validation_dataset.gt is not None ), "We do not yet support validating on datasets without ground truth" @@ -99,7 +98,7 @@ def validate_run( f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" ).exists() ): - logger.warning("Copying validation inputs!") + logger.info("Copying validation inputs!") input_voxel_size = validation_dataset.raw.voxel_size output_voxel_size = run.model.scale(input_voxel_size) input_shape = run.model.eval_input_shape @@ -137,13 +136,12 @@ def validate_run( ) input_gt[output_roi] = validation_dataset.gt[output_roi] else: - logger.warning("validation inputs already copied!") + logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( run.name, iteration, validation_dataset ) logger.info("Predicting on dataset %s", validation_dataset.name) - predict( run.model, validation_dataset.raw,