From a9764c8757f81c2b69fd8205ecdd867c7644b2f3 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 9 Feb 2024 14:39:22 +0000 Subject: [PATCH] :art: Format Python code with psf/black --- .../architectures/cnnectome_unet.py | 128 ++++++++++-------- .../datasets/arrays/concat_array.py | 3 +- dacapo/experiments/run.py | 21 ++- dacapo/experiments/starts/start.py | 38 ++++-- .../tasks/predictors/distance_predictor.py | 10 +- .../experiments/trainers/gunpowder_trainer.py | 14 +- .../trainers/gunpowder_trainer_config.py | 8 +- dacapo/train.py | 6 +- dacapo/utils/balance_weights.py | 2 +- dacapo/validate.py | 2 +- 10 files changed, 140 insertions(+), 92 deletions(-) diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index 798620e04..ddf847456 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -320,31 +320,29 @@ def __init__( for _ in range(num_heads) ] ) -# if num_fmaps_out is None or level != self.num_levels-1 else num_fmaps_out + # if num_fmaps_out is None or level != self.num_levels-1 else num_fmaps_out if self.use_attention: self.attention = nn.ModuleList( - [ - nn.ModuleList( [ - AttentionBlockModule( - F_g=num_fmaps * fmap_inc_factor ** (level + 1), - F_l=num_fmaps - * fmap_inc_factor - ** level, - F_int=num_fmaps - * fmap_inc_factor - ** (level + (1 - upsample_channel_contraction[level])) - if num_fmaps_out is None or level != 0 - else num_fmaps_out, - dims=self.dims, - upsample_factor=downsample_factors[level], + nn.ModuleList( + [ + AttentionBlockModule( + F_g=num_fmaps * fmap_inc_factor ** (level + 1), + F_l=num_fmaps * fmap_inc_factor**level, + F_int=num_fmaps + * fmap_inc_factor + ** (level + (1 - upsample_channel_contraction[level])) + if num_fmaps_out is None or level != 0 + else num_fmaps_out, + dims=self.dims, + upsample_factor=downsample_factors[level], + ) + for level in range(self.num_levels - 1) + ] ) - for level in range(self.num_levels - 1) + for _ in range(num_heads) ] ) - for _ in range(num_heads) - ] - ) # right convolutional passes self.r_conv = nn.ModuleList( @@ -389,12 +387,15 @@ def rec_forward(self, level, f_in): gs_out = self.rec_forward(level - 1, g_in) if self.use_attention: - f_left_attented = [self.attention[h][i](gs_out[h],f_left) for h in range(self.num_heads)] + f_left_attented = [ + self.attention[h][i](gs_out[h], f_left) + for h in range(self.num_heads) + ] fs_right = [ self.r_up[h][i](gs_out[h], f_left_attented[h]) for h in range(self.num_heads) ] - else: # up, concat, and crop + else: # up, concat, and crop fs_right = [ self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads) ] @@ -617,44 +618,43 @@ def forward(self, g_out, f_left=None): return g_cropped - class AttentionBlockModule(nn.Module): def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): """Attention Block Module:: - The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). + The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). 
- [g] --> W_g --\ /--> psi --> * --> [output] - \ / - [x] --> W_x --> [+] --> relu -- + [g] --> W_g --\ /--> psi --> * --> [output] + \ / + [x] --> W_x --> [+] --> relu -- - Where: - - W_g and W_x are 1x1 Convolution followed by Batch Normalization - - [+] indicates element-wise addition - - relu is the Rectified Linear Unit activation function - - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation - - * indicates element-wise multiplication between the output of psi and input feature 'x' - - [output] has the same dimensions as input 'x', selectively emphasized by attention weights + Where: + - W_g and W_x are 1x1 Convolution followed by Batch Normalization + - [+] indicates element-wise addition + - relu is the Rectified Linear Unit activation function + - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation + - * indicates element-wise multiplication between the output of psi and input feature 'x' + - [output] has the same dimensions as input 'x', selectively emphasized by attention weights - Args: - F_g (int): The number of feature channels in the gating signal (g). - This is the input channel dimension for the W_g convolutional layer. + Args: + F_g (int): The number of feature channels in the gating signal (g). + This is the input channel dimension for the W_g convolutional layer. - F_l (int): The number of feature channels in the input features (x). - This is the input channel dimension for the W_x convolutional layer. + F_l (int): The number of feature channels in the input features (x). + This is the input channel dimension for the W_x convolutional layer. - F_int (int): The number of intermediate feature channels. - This represents the output channel dimension of the W_g and W_x convolutional layers - and the input channel dimension for the psi layer. Typically, F_int is smaller - than F_g and F_l, as it serves to compress the feature representations before - applying the attention mechanism. + F_int (int): The number of intermediate feature channels. + This represents the output channel dimension of the W_g and W_x convolutional layers + and the input channel dimension for the psi layer. Typically, F_int is smaller + than F_g and F_l, as it serves to compress the feature representations before + applying the attention mechanism. - The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, - and applies a sigmoid activation to generate an attention map. This map is then used - to scale the input features 'x', resulting in an output that focuses on important - features as dictated by the gating signal 'g'. + The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, + and applies a sigmoid activation to generate an attention map. This map is then used + to scale the input features 'x', resulting in an output that focuses on important + features as dictated by the gating signal 'g'. 
- """ + """ super(AttentionBlockModule, self).__init__() self.dims = dims @@ -662,23 +662,36 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): if upsample_factor is not None: self.upsample_factor = upsample_factor else: - self.upsample_factor = (2,)*self.dims + self.upsample_factor = (2,) * self.dims self.W_g = ConvPass( - F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same") + F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same" + ) self.W_x = nn.Sequential( - ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes, - activation=None, padding="same"), - Downsample(upsample_factor) + ConvPass( + F_l, + F_int, + kernel_sizes=self.kernel_sizes, + activation=None, + padding="same", + ), + Downsample(upsample_factor), ) self.psi = ConvPass( - F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid", padding="same") + F_int, + 1, + kernel_sizes=self.kernel_sizes, + activation="Sigmoid", + padding="same", + ) - up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims] + up_mode = {2: "bilinear", 3: "trilinear"}[self.dims] - self.up = nn.Upsample(scale_factor=upsample_factor, mode=up_mode, align_corners=True) + self.up = nn.Upsample( + scale_factor=upsample_factor, mode=up_mode, align_corners=True + ) self.relu = nn.ReLU(inplace=True) @@ -702,8 +715,7 @@ def calculate_and_apply_padding(self, smaller_tensor, larger_tensor): padding = padding[::-1] # Apply symmetric padding - return nn.functional.pad(smaller_tensor, padding, mode='constant', value=0) - + return nn.functional.pad(smaller_tensor, padding, mode="constant", value=0) def forward(self, g, x): g1 = self.W_g(g) diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index df01129d8..71976393e 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__file__) + class ConcatArray(Array): """This is a wrapper around other `source_arrays` that concatenates them along the channel dimension.""" @@ -119,7 +120,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray: axis=0, ) if concatenated.shape[0] == 1: - logger.info( + logger.info( f"Concatenated array has only one channel: {self.name} {concatenated.shape}" ) return concatenated diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 1609892c8..9ea496758 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__file__) + class Run: name: str train_until: int @@ -58,28 +59,34 @@ def __init__(self, run_config): return try: from ..store import create_config_store + start_config_store = create_config_store() - starter_config = start_config_store.retrieve_run_config(run_config.start_config.run) + starter_config = start_config_store.retrieve_run_config( + run_config.start_config.run + ) except Exception as e: - logger.error(f"could not load start config: {e} Should be added to the database config store RUN") + logger.error( + f"could not load start config: {e} Should be added to the database config store RUN" + ) raise e - + # preloaded weights from previous run if run_config.task_config.name == starter_config.task_config.name: self.start = Start(run_config.start_config) else: # Match labels between old and new head - if hasattr(run_config.task_config,"channels"): + if hasattr(run_config.task_config, "channels"): # Map old head and new head old_head = 
starter_config.task_config.channels new_head = run_config.task_config.channels - self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head) + self.start = Start( + run_config.start_config, old_head=old_head, new_head=new_head + ) else: logger.warning("Not implemented channel match for this task") - self.start = Start(run_config.start_config,remove_head=True) + self.start = Start(run_config.start_config, remove_head=True) self.start.initialize_weights(self.model) - @staticmethod def get_validation_scores(run_config) -> ValidationScores: """ diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index f43ab6403..c64436294 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -3,15 +3,21 @@ logger = logging.getLogger(__file__) - # self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] - # self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] -head_keys = ["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"] +# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] +# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] +head_keys = [ + "prediction_head.weight", + "prediction_head.bias", + "chain.1.weight", + "chain.1.bias", +] # Hack # if label is mito_peroxisome or peroxisome then change it to mito -mitos = ["mito_proxisome","peroxisome"] +mitos = ["mito_proxisome", "peroxisome"] -def match_heads(model, head_weights, old_head, new_head ): + +def match_heads(model, head_weights, old_head, new_head): # match the heads for label in new_head: old_label = label @@ -30,8 +36,9 @@ def match_heads(model, head_weights, old_head, new_head ): model.state_dict()[key][new_index] = n_val logger.warning(f"matched head for {label} with {old_label}") + class Start(ABC): - def __init__(self, start_config,remove_head = False, old_head= None, new_head = None): + def __init__(self, start_config, remove_head=False, old_head=None, new_head=None): self.run = start_config.run self.criterion = start_config.criterion self.remove_head = remove_head @@ -44,7 +51,9 @@ def initialize_weights(self, model): weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) - logger.warning(f"loading weights from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"loading weights from run {self.run}, criterion: {self.criterion}" + ) try: if self.old_head and self.new_head: @@ -61,15 +70,21 @@ def initialize_weights(self, model): logger.warning(f"ERROR starter: {e}") def load_model_using_head_removal(self, model, weights): - logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"removing head from run {self.run}, criterion: {self.criterion}" + ) for key in head_keys: weights.model.pop(key, None) logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}") model.load_state_dict(weights.model, strict=False) - logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}") + logger.warning( + f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}" + ) def load_model_using_head_matching(self, model, weights): - logger.warning(f"matching heads from run 
{self.run}, criterion: {self.criterion}") + logger.warning( + f"matching heads from run {self.run}, criterion: {self.criterion}" + ) logger.warning(f"old head: {self.old_head}") logger.warning(f"new head: {self.new_head}") head_weights = {} @@ -79,6 +94,3 @@ def load_model_using_head_matching(self, model, weights): weights.model.pop(key, None) model.load_state_dict(weights.model, strict=False) model = match_heads(model, head_weights, self.old_head, self.new_head) - - - diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index 98aa2fa20..ca762fc3e 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -27,7 +27,13 @@ class DistancePredictor(Predictor): in the channels argument. """ - def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool,extra_conv :bool): + def __init__( + self, + channels: List[str], + scale_factor: float, + mask_distances: bool, + extra_conv: bool, + ): self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor @@ -37,7 +43,7 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo self.epsilon = 5e-2 self.threshold = 0.8 self.extra_conv = extra_conv - self.extra_conv_dims =len(self.channels) *2 + self.extra_conv_dims = len(self.channels) * 2 @property def embedding_dims(self): diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 8a4bf8a2f..09ffd2230 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -43,7 +43,9 @@ def __init__(self, trainer_config): self.clip_raw = trainer_config.clip_raw # Testing out if calculating multiple times and multiplying is necessary - self.add_predictor_nodes_to_dataset = trainer_config.add_predictor_nodes_to_dataset + self.add_predictor_nodes_to_dataset = ( + trainer_config.add_predictor_nodes_to_dataset + ) self.finetune_head_only = trainer_config.finetune_head_only self.scheduler = None @@ -177,7 +179,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=datasets_weight_key if self.add_predictor_nodes_to_dataset else weight_key, + weights_key=datasets_weight_key + if self.add_predictor_nodes_to_dataset + else weight_key, mask_key=mask_key, ) @@ -230,7 +234,9 @@ def iterate(self, num_iterations, model, optimizer, device): f"Trainer fetch batch took {time.time() - t_start_fetch} seconds" ) - for param in model.parameters(): # TODO: get parameters from optimizer instead + for ( + param + ) in model.parameters(): # TODO: get parameters from optimizer instead param.grad = None t_start_prediction = time.time() @@ -352,4 +358,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def can_train(self, datasets) -> bool: - return all([dataset.gt is not None for dataset in datasets]) \ No newline at end of file + return all([dataset.gt is not None for dataset in datasets]) diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 17cf411ce..5ed63eee8 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -32,10 +32,12 @@ class GunpowderTrainerConfig(TrainerConfig): add_predictor_nodes_to_dataset: Optional[bool] = attr.ib( default=True, 
- metadata={"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"} + metadata={ + "help_text": "Whether to add a predictor node to dataset_source and apply product of weights" + }, ) finetune_head_only: Optional[bool] = attr.ib( default=False, - metadata={"help_text": "Whether to fine-tune head only or all layers"} - ) \ No newline at end of file + metadata={"help_text": "Whether to fine-tune head only or all layers"}, + ) diff --git a/dacapo/train.py b/dacapo/train.py index e84d33613..5665e043c 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -12,7 +12,9 @@ logger = logging.getLogger(__name__) -def train(run_name: str, compute_context: ComputeContext = LocalTorch(), force_cuda = False): +def train( + run_name: str, compute_context: ComputeContext = LocalTorch(), force_cuda=False +): """Train a run""" if compute_context.train(run_name): @@ -187,7 +189,7 @@ def train_run( ) # make sure to move optimizer back to the correct device - run.move_optimizer(compute_context.device) + run.move_optimizer(compute_context.device) run.model.train() logger.info("Trained until %d, finished.", trained_until) diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index 96fbc80e8..f5adcffca 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -77,4 +77,4 @@ def balance_weights( # scale_slab the masked-in scale_slab with the class weights scale_slab *= np.take(w, labels_slab) - return error_scale, moving_counts \ No newline at end of file + return error_scale, moving_counts diff --git a/dacapo/validate.py b/dacapo/validate.py index bce02a92e..fca055baf 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -143,7 +143,7 @@ def validate_run( run.name, iteration, validation_dataset ) logger.info("Predicting on dataset %s", validation_dataset.name) - + predict( run.model, validation_dataset.raw,