From 33ee65f1b6332fddda0f6239d2d192421318d23b Mon Sep 17 00:00:00 2001
From: LVeefkind
Date: Thu, 21 Nov 2024 14:34:01 +0100
Subject: [PATCH] fix finetuning of pos embedding

---
 neural_networks/train_nn.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/neural_networks/train_nn.py b/neural_networks/train_nn.py
index df9db03..19cedd2 100644
--- a/neural_networks/train_nn.py
+++ b/neural_networks/train_nn.py
@@ -153,7 +153,7 @@ def __init__(
         use_lora: bool = False,
         alpha: float = 16.0,
         rank: int = 16,
-        tune_pos_embed: bool = False,
+        pos_embed: bool = False,
     ):
         super().__init__()
 
@@ -170,7 +170,7 @@ def __init__(
             "use_lora": use_lora,
             "rank": rank,
             "alpha": alpha,
-            "tune_pos_embed": tune_pos_embed,
+            "pos_embed": pos_embed,
         }
 
         if lift == "stack":
@@ -194,6 +194,13 @@ def __init__(
                 model_name, get_classifier_f, use_lora=use_lora, alpha=alpha, rank=rank
             )
             # self.classifier = get_classifier_f(n_features=num_features)
+            if "zeros" in self.kwargs["pos_embed"]:
+                self.dino.encoder.pos_embed[:, 1:, :] = torch.zeros_like(
+                    self.dino.encoder.pos_embed[:, 1:, :]
+                )
+            self.dino.encoder.pos_embed.requires_grad = (
+                True if "fine-tune" in self.kwargs["pos_embed"] else False
+            )
 
             self.forward = self.dino_forward
         else:
@@ -246,7 +253,6 @@ def eval(self):
                 self.dino.eval()
             else:
                 self.dino.decoder.eval()
-                self.dino.encoder.pos_embed.requires_grad = False
         else:
             self.classifier.eval()
 
@@ -259,7 +265,7 @@ def train(self):
             else:
                 self.dino.decoder.train()
                 # Finetune learnable pos_embedding
-                self.dino.encoder.pos_embed.requires_grad = self.kwargs["tune_pos_embed"]
+
         else:
             self.classifier.train()
 
@@ -389,7 +395,7 @@ def main(
     log_path: Path = "runs",
     epochs: int = 120,
     flip_augmentations: bool = False,
-    tune_pos_embed: bool = False,
+    pos_embed: str = "pre-trained",
 ):
     torch.set_float32_matmul_precision("high")
     torch.backends.cudnn.benchmark = True
@@ -424,7 +430,7 @@ def main(
         alpha=alpha,
         lift=lift,
         flip_augmentations=flip_augmentations,
-        tune_pos_embed=tune_pos_embed,
+        pos_embed=pos_embed,
     )
 
     writer = get_tensorboard_logger(logging_dir)
@@ -439,7 +445,7 @@ def main(
         use_lora=use_lora,
         alpha=alpha,
         rank=rank,
-        tune_pos_embed=tune_pos_embed,
+        pos_embed=pos_embed,
     )
 
     # noinspection PyArgumentList
@@ -518,7 +524,7 @@ def main(
             "rank": rank,
             "alpha": alpha,
             "dropout_p": dropout_p,
-            "tune_pos_embed": tune_pos_embed,
+            "pos_embed": pos_embed,
         },
     )
 
@@ -636,7 +642,6 @@ def train_step(
         enumerate(train_dataloader), desc="Training", total=len(train_dataloader)
     ):
         global_step += 1
-
         data, labels = prepare_data_f(data, labels)
         smoothed_label = smoothing_fn(labels)
         data = augmentation_fn(data)
@@ -895,9 +900,11 @@ def get_argparser():
     )
 
     parser.add_argument(
-        "--tune_pos_embed",
-        action="store_true",
-        help="Whether to fine-tune the positional embedding if applicable",
+        "--pos_embed",
+        type=str,
+        default="pre-trained",
+        choices=["pre-trained", "fine-tune", "zeros", "zeros-fine-tune"],
+        help="How to handle positional embeddings",
     )
 
     parser.add_argument("--rank", type=int, default=16, help="rank of LoRA")
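
Note: the following is a minimal, illustrative sketch of the positional-embedding handling this patch introduces, written as a standalone helper rather than the patch's in-constructor code. It assumes a DINOv2-style encoder that exposes pos_embed as an nn.Parameter of shape (1, 1 + num_patches, dim) with the CLS position first, and that pos_embed is one of the string choices added to the CLI ("pre-trained", "fine-tune", "zeros", "zeros-fine-tune"); the patched constructor default is still a bool, so the string value is expected to come from main()/the argparser. The configure_pos_embed name and the torch.no_grad() guard around the zeroing are additions for the sketch, not part of the patch.

    import torch
    import torch.nn as nn

    def configure_pos_embed(encoder: nn.Module, pos_embed: str = "pre-trained") -> None:
        # Mirrors the patch's logic: the "zeros" variants zero the patch
        # positions (the CLS position at index 0 is left untouched), and only
        # the "fine-tune" variants keep the embedding trainable.
        if "zeros" in pos_embed:
            # no_grad avoids autograd's "in-place operation on a leaf Variable
            # that requires grad" error if the parameter is still trainable here.
            with torch.no_grad():
                encoder.pos_embed[:, 1:, :].zero_()
        encoder.pos_embed.requires_grad_("fine-tune" in pos_embed)

With the new flag this corresponds to, for example, --pos_embed zeros-fine-tune (zero-initialize the patch positions, then learn them) versus the default --pos_embed pre-trained (keep the pre-trained positions frozen).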