From d015421ca2ac7768a3ba4347e71a083b3a8e3e5e Mon Sep 17 00:00:00 2001
From: LVeefkind
Date: Thu, 26 Sep 2024 11:29:09 +0200
Subject: [PATCH] Add DINOv2 with LoRA

---
 neural_networks/dino_model.py            | 140 +++++++++++++++++++++++
 neural_networks/pre_processing_for_ml.py |   7 +-
 neural_networks/train_nn.py              | 102 +++++++++++++----
 3 files changed, 220 insertions(+), 29 deletions(-)
 create mode 100644 neural_networks/dino_model.py

diff --git a/neural_networks/dino_model.py b/neural_networks/dino_model.py
new file mode 100644
index 00000000..bd10b05a
--- /dev/null
+++ b/neural_networks/dino_model.py
@@ -0,0 +1,140 @@
+"""
+Adapted from https://github.com/RobvanGastel/dinov2-finetune
+"""
+
+from dino_finetune import LoRA, LinearClassifier, FPNDecoder
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DINOV2FeatureExtractor(nn.Module):
+    def __init__(
+        self,
+        encoder,
+        decoder,
+        r: int = 3,
+        use_lora: bool = False,
+    ):
+        """The DINOv2 encoder-decoder model for finetuning on downstream tasks.
+
+        Args:
+            encoder (nn.Module): The ViT encoder loaded with the DINOv2 model weights.
+            decoder (nn.Module): The decoder head applied to the encoder output,
+                e.g. a linear classifier.
+            r (int, optional): The rank of the LoRA weight matrices. Defaults to 3.
+            use_lora (bool, optional): Whether to add LoRA layers to the encoder. Defaults to False.
+        """
+        super().__init__()
+        assert r > 0
+
+        self.use_lora = use_lora
+
+        # Number of previous layers to use as input
+        self.inter_layers = 4
+
+        # Freeze the encoder; only the decoder (and the LoRA layers, if enabled) are trained.
+        self.encoder = encoder
+        for param in self.encoder.parameters():
+            param.requires_grad = False
+
+        # Decoder
+        # Number of patches for a 490x490 input: (490 / 14) ** 2 = 35 * 35
+        self.decoder = decoder
+
+        # Add LoRA layers to the encoder
+        if self.use_lora:
+            self.lora_layers = list(range(len(self.encoder.blocks)))
+            self.w_a = []
+            self.w_b = []
+
+            for i, block in enumerate(self.encoder.blocks):
+                if i not in self.lora_layers:
+                    continue
+                w_qkv_linear = block.attn.qkv
+                dim = w_qkv_linear.in_features
+
+                w_a_linear_q, w_b_linear_q = self._create_lora_layer(dim, r)
+                w_a_linear_v, w_b_linear_v = self._create_lora_layer(dim, r)
+
+                self.w_a.extend([w_a_linear_q, w_a_linear_v])
+                self.w_b.extend([w_b_linear_q, w_b_linear_v])
+
+                block.attn.qkv = LoRA(
+                    w_qkv_linear,
+                    w_a_linear_q,
+                    w_b_linear_q,
+                    w_a_linear_v,
+                    w_b_linear_v,
+                )
+            self._reset_lora_parameters()
+
+    def _create_lora_layer(self, dim: int, r: int):
+        w_a = nn.Linear(dim, r, bias=False)
+        w_b = nn.Linear(r, dim, bias=False)
+        return w_a, w_b
+
+    def _reset_lora_parameters(self) -> None:
+        for w_a in self.w_a:
+            nn.init.kaiming_uniform_(w_a.weight, a=math.sqrt(5))
+        for w_b in self.w_b:
+            nn.init.zeros_(w_b.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        feature = self.encoder.forward(x)
+
+        # get the patch embeddings - so we exclude the CLS token
+        # patch_embeddings = feature["x_norm_patchtokens"]
+        logits = self.decoder(feature)
+
+        # logits = F.interpolate(
+        #     logits,
+        #     size=x.shape[2:],
+        #     mode="bilinear",
+        #     align_corners=False,
+        # )
+        return logits
+
+    def save_parameters(self, filename: str) -> None:
+        """Save the LoRA weights and decoder weights to a .pt file.
+
+        Args:
+            filename (str): Filename of the weights.
+        """
+        w_a, w_b = {}, {}
+        if self.use_lora:
+            w_a = {f"w_a_{i:03d}": self.w_a[i].weight for i in range(len(self.w_a))}
+            w_b = {f"w_b_{i:03d}": self.w_b[i].weight for i in range(len(self.w_b))}
+
+        decoder_weights = self.decoder.state_dict()
+        torch.save({**w_a, **w_b, **decoder_weights}, filename)
+
+    def load_parameters(self, filename: str) -> None:
+        """Load the LoRA and decoder weights from a file.
+
+        Args:
+            filename (str): File name of the weights.
+        """
+        state_dict = torch.load(filename)
+
+        # Load the LoRA parameters
+        if self.use_lora:
+            for i, w_A_linear in enumerate(self.w_a):
+                saved_key = f"w_a_{i:03d}"
+                saved_tensor = state_dict[saved_key]
+                w_A_linear.weight = nn.Parameter(saved_tensor)
+
+            for i, w_B_linear in enumerate(self.w_b):
+                saved_key = f"w_b_{i:03d}"
+                saved_tensor = state_dict[saved_key]
+                w_B_linear.weight = nn.Parameter(saved_tensor)
+
+        # Load decoder parameters
+        decoder_head_dict = self.decoder.state_dict()
+        decoder_head_keys = list(decoder_head_dict.keys())
+        decoder_state_dict = {k: state_dict[k] for k in decoder_head_keys}
+
+        self.decoder.load_state_dict(decoder_state_dict)
diff --git a/neural_networks/pre_processing_for_ml.py b/neural_networks/pre_processing_for_ml.py
index 67b5f6d6..451a2e25 100644
--- a/neural_networks/pre_processing_for_ml.py
+++ b/neural_networks/pre_processing_for_ml.py
@@ -281,12 +281,11 @@ def make_histogram(root_dir):
 
 
 if __name__ == "__main__":
-    root = f"/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data"
-    # transform_data(root)
+    root = f"public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data"
+    transform_data(root)
     # make_histogram(root)
-    dataset = FitsDataset(root, mode="val")
-    dataset.compute_statistics(normalize=1)
+    # dataset.compute_statistics(normalize=1)
 
     # dataset = FitsDataset(root, mode='train', normalize=1)
     # sources = dataset.sources
 
diff --git a/neural_networks/train_nn.py b/neural_networks/train_nn.py
index 3c9a8e73..14627fea 100644
--- a/neural_networks/train_nn.py
+++ b/neural_networks/train_nn.py
@@ -2,6 +2,7 @@
 import os
 from functools import partial, lru_cache
 from pathlib import Path
+import warnings
 
 import torch
 import torcheval.metrics.functional as tef
@@ -18,6 +19,7 @@
 import random
 
 from pre_processing_for_ml import FitsDataset
+from dino_model import DINOV2FeatureExtractor
 
 PROFILE = False
 SEED = None
@@ -38,36 +40,34 @@ def init_vit(model_name):
     hidden_dim = backbone.heads[0].in_features
     del backbone.heads
 
-    backbone.eval()
     return backbone, hidden_dim
 
 
-def init_dino(model_name):
+def init_dino_old(model_name):
     backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
 
     for param in backbone.parameters():
         param.requires_grad_(False)
     backbone.cls_token.requires_grad_(True)
 
     hidden_dim = backbone.cls_token.shape[-1]
 
-    backbone.eval()
     return backbone, hidden_dim
 
 
-def init_first_conv(conv):
-    kernel_size = conv.kernel_size
-    stride = conv.stride
-    padding = conv.padding
-    bias = conv.bias
-    out_channels = conv.out_channels
-    return nn.Conv2d(
-        in_channels=1,
-        out_channels=out_channels,
-        kernel_size=kernel_size,
-        stride=stride,
-        padding=padding,
-        bias=bias,
+def init_dino(model_name, get_classifier_f, use_lora):
+    backbone = torch.hub.load("facebookresearch/dinov2", model_name)
+    hidden_dim = backbone.cls_token.shape[-1]
+    classifier = get_classifier_f(n_features=hidden_dim)
+
+    dino_lora = DINOV2FeatureExtractor(
+        encoder=backbone,
+        decoder=classifier,
+        r=3,
+        use_lora=use_lora,
     )
 
+    return dino_lora, hidden_dim
+
 
 def init_cnn(name: str, lift="stack"):
     # use partial to prevent loading all models at once
@@ -84,7 +84,6 @@ def init_cnn(name: str, lift="stack"):
     feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
     for param in feature_extractor.parameters():
         param.requires_grad_(False)
-    feature_extractor.eval()
 
     if lift == "reinit_first":
         if name in ("resnet50", "resnet152", "resnext50_32x4d", "resnext101_64x4d"):
@@ -106,6 +105,22 @@ def init_cnn(name: str, lift="stack"):
     return feature_extractor, num_out_features
 
 
+def init_first_conv(conv):
+    kernel_size = conv.kernel_size
+    stride = conv.stride
+    padding = conv.padding
+    bias = conv.bias
+    out_channels = conv.out_channels
+    return nn.Conv2d(
+        in_channels=1,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        bias=bias,
+    )
+
+
 def get_classifier(dropout_p: float, n_features: int, num_target_classes: int):
     assert 0 <= dropout_p <= 1
 
@@ -143,7 +158,8 @@ def __init__(
         model_name: str = "resnet50",
         dropout_p: float = 0.25,
         use_compile: bool = True,
-        lift="stack",
+        lift: str = "stack",
+        use_lora: bool = False,
     ):
         super().__init__()
 
@@ -175,13 +191,16 @@ def __init__(
 
             self.forward = self.vit_forward
 
-        elif model_name == "dino_v2":
-            self.feature_extractor, num_features = init_dino(model_name)
-            self.classifier = get_classifier_f(n_features=num_features)
+        elif "dinov2" in model_name:
+            self.dino, num_features = init_dino(
+                model_name, get_classifier_f, use_lora=use_lora
+            )
+            # self.classifier = get_classifier_f(n_features=num_features)
 
             self.forward = self.dino_forward
 
         else:
             self.feature_extractor, num_features = init_cnn(name=model_name, lift=lift)
+            self.feature_extractor.eval()
 
             self.classifier = get_classifier_f(n_features=num_features)
 
@@ -205,7 +224,8 @@ def vit_forward(self, x):
         return x
 
     def dino_forward(self, x):
-        return self.cnn_forward(x)
+        x = self.lift(x)
+        return self.dino(x)
 
     def step(self, inputs, targets, ratio=1):
         logits = self(inputs).flatten()
@@ -223,12 +243,16 @@ def step(self, inputs, targets, ratio=1):
     def eval(self):
         if self.kwargs["model_name"] == "vit_l_16":
             self.vit.heads.eval()
+        elif "dinov2" in self.kwargs["model_name"]:
+            self.dino.eval()
         else:
             self.classifier.eval()
 
     def train(self):
         if self.kwargs["model_name"] == "vit_l_16":
             self.vit.heads.train()
+        elif "dinov2" in self.kwargs["model_name"]:
+            self.dino.train()
         else:
             self.classifier.train()
 
@@ -352,6 +376,7 @@ def main(
     label_smoothing: float,
     stochastic_smoothing: bool,
     lift: str,
+    use_lora: bool,
 ):
     torch.set_float32_matmul_precision("high")
     torch.backends.cudnn.benchmark = True
@@ -378,6 +403,9 @@ def main(
         normalize=normalize,
         dropout_p=dropout_p,
         use_compile=use_compile,
+        label_smoothing=label_smoothing,
+        stochastic_smoothing=stochastic_smoothing,
+        use_lora=use_lora,
     )
 
     writer = get_tensorboard_logger(logging_dir)
@@ -385,7 +413,11 @@ def main(
     device = torch.device("cuda")
 
     model: nn.Module = ImagenetTransferLearning(
-        model_name=model_name, dropout_p=dropout_p, use_compile=use_compile, lift=lift
+        model_name=model_name,
+        dropout_p=dropout_p,
+        use_compile=use_compile,
+        lift=lift,
+        use_lora=use_lora,
     )
 
     # noinspection PyArgumentList
@@ -695,7 +727,14 @@ def get_argparser():
             "resnext101_64x4d",
             "efficientnet_v2_l",
             "vit_l_16",
-            "dino_v2",
+            "dinov2_vits14",
+            "dinov2_vitb14",
+            "dinov2_vitl14",
+            "dinov2_vitg14",
+            "dinov2_vits14_reg",
+            "dinov2_vitb14_reg",
+            "dinov2_vitl14_reg",
+            "dinov2_vitg14_reg",
         ],
     )
     parser.add_argument(
@@ -749,6 +788,12 @@ def get_argparser():
         help="How to lift single channel to 3 channels. Stacking stacks the single channel thrice. Conv adds a 1x1 convolution layer before the model. reinit_first re-initialises the first layer if applicable.",
     )
 
+    parser.add_argument(
+        "--use_lora",
+        action="store_true",
+        help="Use LoRA adapters in the encoder (only supported for DINOv2 models).",
+    )
+
     return parser.parse_args()
 
 
@@ -766,10 +811,17 @@ def sanity_check_args(parsed_args):
     if parsed_args.model_name == "vit_l_16" and parsed_args.resize != 512:
         print("Setting resize to 512 since vit_16_l is being used")
         parsed_args.resize = 512
-    if parsed_args.model_name == "dino_v2" and parsed_args.resize == 0:
+    if "dinov2" in parsed_args.model_name and parsed_args.resize == 0:
         resize = 504
         print(f"\n#######\nSetting resize to {resize} \n######\n")
         parsed_args.resize = resize
+
+    if parsed_args.use_lora and "dinov2" not in parsed_args.model_name:
+        warnings.warn(
+            "LoRA is only supported for DINOv2 models; ignoring --use_lora.\n",
+            UserWarning,
+        )
+
     assert parsed_args.resize % 14 == 0 or parsed_args.model_name != "dino_v2"
 
     return parsed_args
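
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new pieces fit together, wiring a frozen DINOv2 backbone, LoRA adapters, and a small classification head through train_nn.init_dino. The dinov2_vits14 choice, the dropout_p=0.25 / num_target_classes=1 binding for get_classifier, the 504x504 input size, and the checkpoint filename are illustrative assumptions, not values fixed by the patch.

    from functools import partial

    import torch

    from train_nn import get_classifier, init_dino

    # Bind the classifier factory; dropout_p=0.25 and num_target_classes=1
    # (a single logit for the binary task) are assumptions for this sketch.
    get_classifier_f = partial(get_classifier, dropout_p=0.25, num_target_classes=1)

    # Frozen DINOv2 ViT-S/14 encoder with LoRA adapters on every attention block
    # and a trainable classification head as the decoder.
    model, hidden_dim = init_dino("dinov2_vits14", get_classifier_f, use_lora=True)

    # Only the LoRA matrices and the decoder head should require gradients.
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"embedding dim: {hidden_dim}, trainable parameters: {n_trainable}")

    # The input resolution must be a multiple of the 14-pixel patch size
    # (cf. sanity_check_args, which defaults DINOv2 runs to 504).
    x = torch.randn(2, 3, 504, 504)
    with torch.no_grad():
        logits = model(x)  # one logit per image for the binary classifier
    print(logits.shape)

    # Persist only the LoRA and decoder weights; the frozen backbone is not saved.
    model.save_parameters("dino_lora_head.pt")  # illustrative filename
    model.load_parameters("dino_lora_head.pt")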