From c88016c5bba37c9f481956cce7c74074f8eee022 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Wed, 7 Feb 2024 21:22:29 +0000 Subject: [PATCH] =?UTF-8?q?style:=20=F0=9F=8E=A8=20Black=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../datasets/arrays/concat_array.py | 5 ++- dacapo/experiments/run.py | 21 ++++++---- dacapo/experiments/starts/start.py | 38 ++++++++++++------- .../tasks/predictors/distance_predictor.py | 10 ++++- .../experiments/trainers/gunpowder_trainer.py | 14 +++++-- .../trainers/gunpowder_trainer_config.py | 8 ++-- dacapo/train.py | 2 +- dacapo/utils/balance_weights.py | 2 +- docs/source/conf.py | 23 +++++------ 9 files changed, 80 insertions(+), 43 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index 59a350d49..1475c7b97 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__file__) + class ConcatArray(Array): """This is a wrapper around other `source_arrays` that concatenates them along the channel dimension.""" @@ -118,5 +119,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray: axis=0, ) if concatenated.shape[0] == 1: - logger.info(f"Concatenated array has only one channel: {self.name} {concatenated.shape}") + logger.info( + f"Concatenated array has only one channel: {self.name} {concatenated.shape}" + ) return concatenated diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 1609892c8..9ea496758 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__file__) + class Run: name: str train_until: int @@ -58,28 +59,34 @@ def __init__(self, run_config): return try: from ..store import create_config_store + start_config_store = create_config_store() - starter_config = start_config_store.retrieve_run_config(run_config.start_config.run) + starter_config = start_config_store.retrieve_run_config( + run_config.start_config.run + ) except Exception as e: - logger.error(f"could not load start config: {e} Should be added to the database config store RUN") + logger.error( + f"could not load start config: {e} Should be added to the database config store RUN" + ) raise e - + # preloaded weights from previous run if run_config.task_config.name == starter_config.task_config.name: self.start = Start(run_config.start_config) else: # Match labels between old and new head - if hasattr(run_config.task_config,"channels"): + if hasattr(run_config.task_config, "channels"): # Map old head and new head old_head = starter_config.task_config.channels new_head = run_config.task_config.channels - self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head) + self.start = Start( + run_config.start_config, old_head=old_head, new_head=new_head + ) else: logger.warning("Not implemented channel match for this task") - self.start = Start(run_config.start_config,remove_head=True) + self.start = Start(run_config.start_config, remove_head=True) self.start.initialize_weights(self.model) - @staticmethod def get_validation_scores(run_config) -> ValidationScores: """ diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index f43ab6403..c64436294 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -3,15 
+3,21 @@ logger = logging.getLogger(__file__) - # self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] - # self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] -head_keys = ["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"] +# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] +# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] +head_keys = [ + "prediction_head.weight", + "prediction_head.bias", + "chain.1.weight", + "chain.1.bias", +] # Hack # if label is mito_peroxisome or peroxisome then change it to mito -mitos = ["mito_proxisome","peroxisome"] +mitos = ["mito_proxisome", "peroxisome"] -def match_heads(model, head_weights, old_head, new_head ): + +def match_heads(model, head_weights, old_head, new_head): # match the heads for label in new_head: old_label = label @@ -30,8 +36,9 @@ def match_heads(model, head_weights, old_head, new_head ): model.state_dict()[key][new_index] = n_val logger.warning(f"matched head for {label} with {old_label}") + class Start(ABC): - def __init__(self, start_config,remove_head = False, old_head= None, new_head = None): + def __init__(self, start_config, remove_head=False, old_head=None, new_head=None): self.run = start_config.run self.criterion = start_config.criterion self.remove_head = remove_head @@ -44,7 +51,9 @@ def initialize_weights(self, model): weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) - logger.warning(f"loading weights from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"loading weights from run {self.run}, criterion: {self.criterion}" + ) try: if self.old_head and self.new_head: @@ -61,15 +70,21 @@ def initialize_weights(self, model): logger.warning(f"ERROR starter: {e}") def load_model_using_head_removal(self, model, weights): - logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"removing head from run {self.run}, criterion: {self.criterion}" + ) for key in head_keys: weights.model.pop(key, None) logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}") model.load_state_dict(weights.model, strict=False) - logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}" + ) def load_model_using_head_matching(self, model, weights): - logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"matching heads from run {self.run}, criterion: {self.criterion}" + ) logger.warning(f"old head: {self.old_head}") logger.warning(f"new head: {self.new_head}") head_weights = {} @@ -79,6 +94,3 @@ def load_model_using_head_matching(self, model, weights): weights.model.pop(key, None) model.load_state_dict(weights.model, strict=False) model = match_heads(model, head_weights, self.old_head, self.new_head) - - - diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index 98aa2fa20..ca762fc3e 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -27,7 +27,13 @@ 
class DistancePredictor(Predictor): in the channels argument. """ - def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool,extra_conv :bool): + def __init__( + self, + channels: List[str], + scale_factor: float, + mask_distances: bool, + extra_conv: bool, + ): self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor @@ -37,7 +43,7 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo self.epsilon = 5e-2 self.threshold = 0.8 self.extra_conv = extra_conv - self.extra_conv_dims =len(self.channels) *2 + self.extra_conv_dims = len(self.channels) * 2 @property def embedding_dims(self): diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 8a4bf8a2f..09ffd2230 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -43,7 +43,9 @@ def __init__(self, trainer_config): self.clip_raw = trainer_config.clip_raw # Testing out if calculating multiple times and multiplying is necessary - self.add_predictor_nodes_to_dataset = trainer_config.add_predictor_nodes_to_dataset + self.add_predictor_nodes_to_dataset = ( + trainer_config.add_predictor_nodes_to_dataset + ) self.finetune_head_only = trainer_config.finetune_head_only self.scheduler = None @@ -177,7 +179,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=datasets_weight_key if self.add_predictor_nodes_to_dataset else weight_key, + weights_key=datasets_weight_key + if self.add_predictor_nodes_to_dataset + else weight_key, mask_key=mask_key, ) @@ -230,7 +234,9 @@ def iterate(self, num_iterations, model, optimizer, device): f"Trainer fetch batch took {time.time() - t_start_fetch} seconds" ) - for param in model.parameters(): # TODO: get parameters from optimizer instead + for ( + param + ) in model.parameters(): # TODO: get parameters from optimizer instead param.grad = None t_start_prediction = time.time() @@ -352,4 +358,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def can_train(self, datasets) -> bool: - return all([dataset.gt is not None for dataset in datasets]) \ No newline at end of file + return all([dataset.gt is not None for dataset in datasets]) diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 17cf411ce..5ed63eee8 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -32,10 +32,12 @@ class GunpowderTrainerConfig(TrainerConfig): add_predictor_nodes_to_dataset: Optional[bool] = attr.ib( default=True, - metadata={"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"} + metadata={ + "help_text": "Whether to add a predictor node to dataset_source and apply product of weights" + }, ) finetune_head_only: Optional[bool] = attr.ib( default=False, - metadata={"help_text": "Whether to fine-tune head only or all layers"} - ) \ No newline at end of file + metadata={"help_text": "Whether to fine-tune head only or all layers"}, + ) diff --git a/dacapo/train.py b/dacapo/train.py index 6f8894d7b..86473ee36 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -178,7 +178,7 @@ def train_run( stats_store.store_training_stats(run.name, run.training_stats) # make sure to move optimizer back to the correct device - 
run.move_optimizer(compute_context.device) + run.move_optimizer(compute_context.device) run.model.train() weights_store.store_weights(run, run.training_stats.trained_until()) diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index 96fbc80e8..f5adcffca 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -77,4 +77,4 @@ def balance_weights( # scale_slab the masked-in scale_slab with the class weights scale_slab *= np.take(w, labels_slab) - return error_scale, moving_counts \ No newline at end of file + return error_scale, moving_counts diff --git a/docs/source/conf.py b/docs/source/conf.py index cd5823612..7df2f563b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,14 +12,15 @@ # import os import sys -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) # -- Project information ----------------------------------------------------- -project = 'DaCapo' -copyright = '2022, William Patton, David Ackerman, Jan Funke' -author = 'William Patton, David Ackerman, Jan Funke' +project = "DaCapo" +copyright = "2022, William Patton, David Ackerman, Jan Funke" +author = "William Patton, David Ackerman, Jan Funke" # -- General configuration --------------------------------------------------- @@ -27,15 +28,15 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx_autodoc_typehints'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_autodoc_typehints"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -43,12 +44,12 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_material' +html_theme = "sphinx_material" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_css_files = [ - 'css/custom.css', -] \ No newline at end of file + "css/custom.css", +]
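
Note: beyond the Black formatting, the start.py hunks above carry the logic for
reusing pretrained weights across runs whose tasks predict different channel
sets. match_heads() copies each shared channel's head weights (the tensors
listed in head_keys) from its index in the old head to its index in the new
head, while the head-removal path pops the head_keys entries and loads the
remaining weights with load_state_dict(strict=False). A minimal sketch of the
matching idea, assuming toy channel lists and a 2-D head weight tensor (this is
an illustration, not the DaCapo API):

    import torch

    # Toy stand-ins for the old/new task channels (assumed for illustration).
    old_head = ["mito", "nucleus", "ecs"]
    new_head = ["nucleus", "ld", "mito"]

    # Head weights shaped (out_channels, features); values are placeholders.
    old_w = torch.randn(len(old_head), 64)   # pretrained head
    new_w = torch.zeros(len(new_head), 64)   # freshly initialized head

    # Copy each shared channel's row into the slot it occupies in the new
    # head; unmatched channels (here "ld") keep their fresh initialization.
    for new_index, label in enumerate(new_head):
        if label in old_head:
            new_w[new_index] = old_w[old_head.index(label)]

In the patched code this assignment happens in place on model.state_dict()
after the non-head weights are loaded non-strictly, which is presumably why the
head parameters are popped from weights.model first.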