fix finetuning of pos embedding
LVeefkind committed Nov 21, 2024
1 parent 6245511 commit 33ee65f
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions neural_networks/train_nn.py
@@ -153,7 +153,7 @@ def __init__(
         use_lora: bool = False,
         alpha: float = 16.0,
         rank: int = 16,
-        tune_pos_embed: bool = False,
+        pos_embed: bool = False,
     ):
         super().__init__()

@@ -170,7 +170,7 @@ def __init__(
             "use_lora": use_lora,
             "rank": rank,
             "alpha": alpha,
-            "tune_pos_embed": tune_pos_embed,
+            "pos_embed": pos_embed,
         }

         if lift == "stack":
@@ -194,6 +194,13 @@ def __init__(
                 model_name, get_classifier_f, use_lora=use_lora, alpha=alpha, rank=rank
             )
             # self.classifier = get_classifier_f(n_features=num_features)
+            if "zeros" in self.kwargs["pos_embed"]:
+                self.dino.encoder.pos_embed[:, 1:, :] = torch.zeros_like(
+                    self.dino.encoder.pos_embed[:, 1:, :]
+                )
+            self.dino.encoder.pos_embed.requires_grad = (
+                True if "fine-tune" in self.kwargs["pos_embed"] else False
+            )
             self.forward = self.dino_forward

         else:
@@ -246,7 +253,6 @@ def eval(self):
                 self.dino.eval()
             else:
                 self.dino.decoder.eval()
-                self.dino.encoder.pos_embed.requires_grad = False
         else:
             self.classifier.eval()

@@ -259,7 +265,7 @@ def train(self):
             else:
                 self.dino.decoder.train()
                 # Finetune learnable pos_embedding
-                self.dino.encoder.pos_embed.requires_grad = self.kwargs["tune_pos_embed"]
+
         else:
             self.classifier.train()

@@ -389,7 +395,7 @@ def main(
     log_path: Path = "runs",
     epochs: int = 120,
     flip_augmentations: bool = False,
-    tune_pos_embed: bool = False,
+    pos_embed: str = "pre-trained",
 ):
     torch.set_float32_matmul_precision("high")
     torch.backends.cudnn.benchmark = True
Expand Down Expand Up @@ -424,7 +430,7 @@ def main(
alpha=alpha,
lift=lift,
flip_augmentations=flip_augmentations,
tune_pos_embed=tune_pos_embed,
pos_embed=pos_embed,
)

writer = get_tensorboard_logger(logging_dir)
@@ -439,7 +445,7 @@ def main(
         use_lora=use_lora,
         alpha=alpha,
         rank=rank,
-        tune_pos_embed=tune_pos_embed,
+        pos_embed=pos_embed,
     )

     # noinspection PyArgumentList
@@ -518,7 +524,7 @@ def main(
             "rank": rank,
             "alpha": alpha,
             "dropout_p": dropout_p,
-            "tune_pos_embed": tune_pos_embed,
+            "pos_embed": pos_embed,
         },
     )

@@ -636,7 +642,6 @@ def train_step(
         enumerate(train_dataloader), desc="Training", total=len(train_dataloader)
     ):
         global_step += 1
-
         data, labels = prepare_data_f(data, labels)
         smoothed_label = smoothing_fn(labels)
         data = augmentation_fn(data)
@@ -895,9 +900,11 @@ def get_argparser():
     )

     parser.add_argument(
-        "--tune_pos_embed",
-        action="store_true",
-        help="Whether to fine-tune the positional embedding if applicable",
+        "--pos_embed",
+        type=str,
+        default="pre-trained",
+        choices=["pre-trained", "fine-tune", "zeros", "zeros-fine-tune"],
+        help="How to handle positional embeddings",
     )

     parser.add_argument("--rank", type=int, default=16, help="rank of LoRA")