Add Dino with LoRA
LVeefkind committed Sep 26, 2024
1 parent b9371bf commit d015421
Showing 3 changed files with 220 additions and 29 deletions.
140 changes: 140 additions & 0 deletions neural_networks/dino_model.py
@@ -0,0 +1,140 @@
"""
From https://github.com/RobvanGastel/dinov2-finetune
"""

from dino_finetune import LoRA, LinearClassifier, FPNDecoder
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DINOV2FeatureExtractor(nn.Module):
def __init__(
self,
encoder,
decoder,
r: int = 3,
use_lora: bool = False,
):
"""The DINOv2 encoder-decoder model for finetuning to downstream tasks.
Args:
encoder (nn.Module): The ViT encoder model loaded with the DINOv2 model weights.
r (int, optional): The rank parameter of the LoRA weights. Defaults to 3.
emb_dim (int, optional): The embedding dimension of the encoder. Defaults to 1024.
n_classes (int, optional): The number of classes to output. Defaults to 1000.
use_lora (bool, optional): Determines whether to use LoRA. Defaults to False.
"""
super().__init__()
assert r > 0

self.use_lora = use_lora

# Number of previous layers to use as input
self.inter_layers = 4

self.encoder = encoder
for param in self.encoder.parameters():
param.requires_grad = False

# Decoder
        # Number of patches for a 490x490 input is (490 / 14)**2 = 35 * 35

self.decoder = decoder
# Add LoRA layers to the encoder
if self.use_lora:
self.lora_layers = list(range(len(self.encoder.blocks)))
self.w_a = []
self.w_b = []

for i, block in enumerate(self.encoder.blocks):
if i not in self.lora_layers:
continue
w_qkv_linear = block.attn.qkv
dim = w_qkv_linear.in_features

w_a_linear_q, w_b_linear_q = self._create_lora_layer(dim, r)
w_a_linear_v, w_b_linear_v = self._create_lora_layer(dim, r)

self.w_a.extend([w_a_linear_q, w_a_linear_v])
self.w_b.extend([w_b_linear_q, w_b_linear_v])

block.attn.qkv = LoRA(
w_qkv_linear,
w_a_linear_q,
w_b_linear_q,
w_a_linear_v,
w_b_linear_v,
)
self._reset_lora_parameters()

def _create_lora_layer(self, dim: int, r: int):
w_a = nn.Linear(dim, r, bias=False)
w_b = nn.Linear(r, dim, bias=False)
return w_a, w_b

def _reset_lora_parameters(self) -> None:
for w_a in self.w_a:
nn.init.kaiming_uniform_(w_a.weight, a=math.sqrt(5))
for w_b in self.w_b:
nn.init.zeros_(w_b.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:

feature = self.encoder.forward(x)

# get the patch embeddings - so we exclude the CLS token
# patch_embeddings = feature["x_norm_patchtokens"]
logits = self.decoder(feature)

# logits = F.interpolate(
# logits,
# size=x.shape[2:],
# mode="bilinear",
# align_corners=False,
# )
return logits

def save_parameters(self, filename: str) -> None:
"""Save the LoRA weights and decoder weights to a .pt file
Args:
filename (str): Filename of the weights
"""
w_a, w_b = {}, {}
if self.use_lora:
w_a = {f"w_a_{i:03d}": self.w_a[i].weight for i in range(len(self.w_a))}
w_b = {f"w_b_{i:03d}": self.w_b[i].weight for i in range(len(self.w_a))}

decoder_weights = self.decoder.state_dict()
torch.save({**w_a, **w_b, **decoder_weights}, filename)

def load_parameters(self, filename: str) -> None:
"""Load the LoRA and decoder weights from a file
Args:
filename (str): File name of the weights
"""
state_dict = torch.load(filename)

# Load the LoRA parameters
if self.use_lora:
for i, w_A_linear in enumerate(self.w_a):
saved_key = f"w_a_{i:03d}"
saved_tensor = state_dict[saved_key]
w_A_linear.weight = nn.Parameter(saved_tensor)

for i, w_B_linear in enumerate(self.w_b):
saved_key = f"w_b_{i:03d}"
saved_tensor = state_dict[saved_key]
w_B_linear.weight = nn.Parameter(saved_tensor)

# Load decoder parameters
decoder_head_dict = self.decoder.state_dict()
decoder_head_keys = [k for k in decoder_head_dict.keys()]
decoder_state_dict = {k: state_dict[k] for k in decoder_head_keys}

self.decoder.load_state_dict(decoder_state_dict)
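
Note: the LoRA module that wraps block.attn.qkv above is imported from the external dino_finetune package and is not part of this commit. As a point of reference, below is a minimal sketch of what such a wrapper typically looks like; the class name LoRAQKVSketch and the choice to adapt only the query and value projections are assumptions about the upstream implementation, not vendored code.

import torch
import torch.nn as nn


class LoRAQKVSketch(nn.Module):
    """Hypothetical stand-in for dino_finetune.LoRA: adds rank-r updates to the
    query and value slices of a frozen, fused qkv projection."""

    def __init__(
        self,
        qkv: nn.Linear,
        w_a_q: nn.Linear,
        w_b_q: nn.Linear,
        w_a_v: nn.Linear,
        w_b_v: nn.Linear,
    ):
        super().__init__()
        self.qkv = qkv  # frozen base projection with out_features == 3 * dim
        self.dim = qkv.in_features
        self.w_a_q, self.w_b_q = w_a_q, w_b_q
        self.w_a_v, self.w_b_v = w_a_v, w_b_v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv(x)  # (..., 3 * dim): concatenated q, k, v
        qkv[..., : self.dim] += self.w_b_q(self.w_a_q(x))  # low-rank query update
        qkv[..., -self.dim :] += self.w_b_v(self.w_a_v(x))  # low-rank value update
        return qkv

Under this reading, only the w_a_* / w_b_* matrices (and the decoder) are trainable, which matches save_parameters and load_parameters above serialising only the LoRA weights and the decoder state.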
7 changes: 3 additions & 4 deletions neural_networks/pre_processing_for_ml.py
@@ -281,12 +281,11 @@ def make_histogram(root_dir):


if __name__ == "__main__":
root = f"/dev/shm/scratch-shared/CORTEX/public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data"
# transform_data(root)
root = f"public.spider.surfsara.nl/project/lofarvwf/jdejong/CORTEX/calibrator_selection_robertjan/cnn_data"
transform_data(root)

# make_histogram(root)
dataset = FitsDataset(root, mode="val")
dataset.compute_statistics(normalize=1)
# dataset.compute_statistics(normalize=1)

# dataset = FitsDataset(root, mode='train', normalize=1)
# sources = dataset.sources
102 changes: 77 additions & 25 deletions neural_networks/train_nn.py
@@ -2,6 +2,7 @@
import os
from functools import partial, lru_cache
from pathlib import Path
import warnings

import torch
import torcheval.metrics.functional as tef
@@ -18,6 +19,7 @@
import random

from pre_processing_for_ml import FitsDataset
from dino_model import DINOV2FeatureExtractor

PROFILE = False
SEED = None
@@ -38,36 +40,34 @@ def init_vit(model_name):
hidden_dim = backbone.heads[0].in_features

del backbone.heads
backbone.eval()
return backbone, hidden_dim


def init_dino(model_name):
def init_dino_old(model_name):
backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
for param in backbone.parameters():
param.requires_grad_(False)
backbone.cls_token.requires_grad_(True)

hidden_dim = backbone.cls_token.shape[-1]
backbone.eval()
return backbone, hidden_dim


def init_first_conv(conv):
kernel_size = conv.kernel_size
stride = conv.stride
padding = conv.padding
bias = conv.bias
out_channels = conv.out_channels
return nn.Conv2d(
in_channels=1,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
def init_dino(model_name, get_classifier_f, use_lora):

backbone = torch.hub.load("facebookresearch/dinov2", model_name)
hidden_dim = backbone.cls_token.shape[-1]
classifier = get_classifier_f(n_features=hidden_dim)

dino_lora = DINOV2FeatureExtractor(
encoder=backbone,
decoder=classifier,
r=3,
use_lora=use_lora,
)

return dino_lora, hidden_dim


def init_cnn(name: str, lift="stack"):
# use partial to prevent loading all models at once
@@ -84,7 +84,6 @@ def init_cnn(name: str, lift="stack"):
feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
for param in feature_extractor.parameters():
param.requires_grad_(False)
feature_extractor.eval()

if lift == "reinit_first":
if name in ("resnet50", "resnet152", "resnext50_32x4d", "resnext101_64x4d"):
@@ -106,6 +105,22 @@ def init_cnn(name: str, lift="stack"):
return feature_extractor, num_out_features


def init_first_conv(conv):
kernel_size = conv.kernel_size
stride = conv.stride
padding = conv.padding
bias = conv.bias
out_channels = conv.out_channels
return nn.Conv2d(
in_channels=1,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)


def get_classifier(dropout_p: float, n_features: int, num_target_classes: int):
assert 0 <= dropout_p <= 1

@@ -143,7 +158,8 @@ def __init__(
model_name: str = "resnet50",
dropout_p: float = 0.25,
use_compile: bool = True,
lift="stack",
lift: str = "stack",
use_lora: bool = False,
):
super().__init__()

@@ -175,13 +191,16 @@ def __init__(

self.forward = self.vit_forward

elif model_name == "dino_v2":
self.feature_extractor, num_features = init_dino(model_name)
self.classifier = get_classifier_f(n_features=num_features)
elif "dinov2" in model_name:
self.dino, num_features = init_dino(
model_name, get_classifier_f, use_lora=use_lora
)
# self.classifier = get_classifier_f(n_features=num_features)
self.forward = self.dino_forward

else:
self.feature_extractor, num_features = init_cnn(name=model_name, lift=lift)
self.feature_extractor.eval()

self.classifier = get_classifier_f(n_features=num_features)

@@ -205,7 +224,8 @@ def vit_forward(self, x):
return x

def dino_forward(self, x):
return self.cnn_forward(x)
x = self.lift(x)
return self.dino(x)

def step(self, inputs, targets, ratio=1):
logits = self(inputs).flatten()
@@ -223,12 +243,16 @@ def step(self, inputs, targets, ratio=1):
def eval(self):
if self.kwargs["model_name"] == "vit_l_16":
self.vit.heads.eval()
elif "dinov2" in self.kwargs["model_name"]:
self.dino.eval()
else:
self.classifier.eval()

def train(self):
if self.kwargs["model_name"] == "vit_l_16":
self.vit.heads.train()
elif "dinov2" in self.kwargs["model_name"]:
self.dino.train()
else:
self.classifier.train()

@@ -352,6 +376,7 @@ def main(
label_smoothing: float,
stochastic_smoothing: bool,
lift: str,
use_lora: bool,
):
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
@@ -378,14 +403,21 @@
normalize=normalize,
dropout_p=dropout_p,
use_compile=use_compile,
label_smoothing=label_smoothing,
stochastic_smoothing=stochastic_smoothing,
use_lora=use_lora,
)

writer = get_tensorboard_logger(logging_dir)

device = torch.device("cuda")

model: nn.Module = ImagenetTransferLearning(
model_name=model_name, dropout_p=dropout_p, use_compile=use_compile, lift=lift
model_name=model_name,
dropout_p=dropout_p,
use_compile=use_compile,
lift=lift,
use_lora=use_lora,
)

# noinspection PyArgumentList
@@ -695,7 +727,14 @@ def get_argparser():
"resnext101_64x4d",
"efficientnet_v2_l",
"vit_l_16",
"dino_v2",
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vitg14",
"dinov2_vits14_reg",
"dinov2_vitb14_reg",
"dinov2_vitl14_reg",
"dinov2_vitg14_reg",
],
)
parser.add_argument(
@@ -749,6 +788,12 @@ def get_argparser():
help="How to lift single channel to 3 channels. Stacking stacks the single channel thrice. Conv adds a 1x1 convolution layer before the model. reinit_first re-initialises the first layer if applicable.",
)

parser.add_argument(
"--use_lora",
action="store_true",
help="Whether to use LoRA if applicable.",
)

return parser.parse_args()


@@ -766,10 +811,17 @@ def sanity_check_args(parsed_args):
if parsed_args.model_name == "vit_l_16" and parsed_args.resize != 512:
print("Setting resize to 512 since vit_16_l is being used")
parsed_args.resize = 512
if parsed_args.model_name == "dino_v2" and parsed_args.resize == 0:
if "dinov2" in parsed_args.model_name and parsed_args.resize == 0:
resize = 504
print(f"\n#######\nSetting resize to {resize} \n######\n")
parsed_args.resize = resize

    if parsed_args.use_lora and "dinov2" not in parsed_args.model_name:
warnings.warn(
"Warning: LoRA is only supported for Dino V2 models. Ignoring setting....\n",
UserWarning,
)

    assert parsed_args.resize % 14 == 0 or "dinov2" not in parsed_args.model_name

return parsed_args
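
For reference, a rough smoke-test sketch of how the new --use_lora flag reaches the model; the constructor arguments, batch size, and single-channel 504x504 input shape are illustrative assumptions (504 is divisible by the 14-pixel DINOv2 patch size), not values taken from the training script.

import torch

from train_nn import ImagenetTransferLearning

# Hypothetical example; values below are assumptions for illustration only.
model = ImagenetTransferLearning(
    model_name="dinov2_vitl14",
    dropout_p=0.25,
    use_compile=False,
    lift="stack",
    use_lora=True,  # wraps the frozen DINOv2 encoder's qkv layers with LoRA
)
x = torch.randn(2, 1, 504, 504)  # single-channel input, lifted to 3 channels internally
logits = model(x)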