From 4e88e1cae9ea57eb18328fb514219288aa6de470 Mon Sep 17 00:00:00 2001
From: Carolin Teuber <115626873+caroteu@users.noreply.github.com>
Date: Sat, 14 Dec 2024 17:19:21 +0100
Subject: [PATCH] Add AdaptFormer (#741)

Add AdaptFormer for PEFT finetuning

---------

Co-authored-by: Anwai Archit
Co-authored-by: Anwai Archit <52396323+anwai98@users.noreply.github.com>
---
 micro_sam/models/peft_sam.py      | 66 ++++++++++++++++++++++++++++++-
 test/test_models/test_peft_sam.py | 14 +++++++
 2 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index c1b5dcc6..dfe13fee 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -1,6 +1,7 @@
 import math
 from typing import List, Union, Optional
 
+import torch
 import torch.nn as nn
 
 from segment_anything.modeling import Sam
@@ -143,6 +144,68 @@ def forward(self, x):
         return x
 
 
+class AdaptFormer(nn.Module):
+    """Adds the AdaptFormer module in place of the MLP layers.
+
+    Args:
+        rank: The rank is not used in this class, but kept here for consistency.
+        block: The chosen encoder block for implementing AdaptFormer.
+        alpha: A parameter that scales the adapter path. Can either be learnable or a fixed value.
+        dropout: The dropout rate for the dropout layer between the down and up projection layers.
+        projection_size: The size of the projection layer.
+    """
+    def __init__(
+        self,
+        rank: int,
+        block: nn.Module,
+        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
+        dropout: Optional[float] = None,  # Does not have an obvious advantage.
+        projection_size: int = 64,  # Stable choice from our preliminary exp.
+    ):
+        super().__init__()
+
+        self.mlp_proj = block.mlp
+        self.n_embd = block.mlp.lin1.in_features
+
+        if alpha == 'learnable_scalar':
+            self.alpha = nn.Parameter(torch.ones(1))
+        else:
+            self.alpha = alpha
+
+        self.projection_size = projection_size
+        self.dropout = dropout
+
+        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
+        self.non_linear_func = nn.ReLU()
+        self.up_proj = nn.Linear(self.projection_size, self.n_embd)
+
+        block.mlp = self
+
+        if self.dropout is not None:
+            self.dropout_layer = nn.Dropout(self.dropout)
+
+        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
+        nn.init.zeros_(self.up_proj.weight)
+        nn.init.zeros_(self.down_proj.bias)
+        nn.init.zeros_(self.up_proj.bias)
+
+    def forward(self, x):
+        residual = x
+        mlp_output = self.mlp_proj(x)
+
+        down = self.down_proj(x)
+        down = self.non_linear_func(down)
+
+        if self.dropout is not None:
+            down = self.dropout_layer(down)
+
+        up = self.up_proj(down)
+        up = up * self.alpha
+        output = up + residual + mlp_output
+
+        return output
+
+
 class AttentionSurgery(SelectiveSurgery):
     """Child class for allowing gradient updates for parameters in attention layers.
     """
@@ -191,7 +254,8 @@ def __init__(
         super().__init__()
 
         assert rank > 0
-        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."
+        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, AdaptFormer]), (
+            "Invalid PEFT module")
 
         if attention_layers_to_update:
             self.peft_layers = attention_layers_to_update
diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py
index 4461aa9b..28d2b950 100644
--- a/test/test_models/test_peft_sam.py
+++ b/test/test_models/test_peft_sam.py
@@ -78,6 +78,20 @@ def test_bias_layer_peft_sam(self):
             masks = output[0]["masks"]
             self.assertEqual(masks.shape, expected_shape)
 
+    def test_adaptformer_peft_sam(self):
+        from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer
+
+        _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
+        peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5)
+
+        shape = (3, 1024, 1024)
+        expected_shape = (1, 3, 1024, 1024)
+        with torch.no_grad():
+            batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
+            output = peft_sam(batched_input, multimask_output=True)
+            masks = output[0]["masks"]
+            self.assertEqual(masks.shape, expected_shape)
+
 
 if __name__ == "__main__":
     unittest.main()
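
Usage note (not part of the committed patch): a minimal sketch of how the new AdaptFormer module is exercised through PEFT_Sam, mirroring the added test. It assumes util.get_sam_model is imported from micro_sam.util as in the test file, and the model type "vit_b" is only an illustrative choice; the keyword arguments projection_size, alpha, and dropout are forwarded by PEFT_Sam to the AdaptFormer constructor, as the test demonstrates.

    import torch

    from micro_sam import util
    from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer

    # Load a SAM model and wrap it, replacing the MLP of the selected image-encoder
    # blocks with AdaptFormer adapters ("vit_b" is only an illustrative model type).
    _, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
    peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5)

    # Run a forward pass in SAM's batched-input format, as in test_adaptformer_peft_sam.
    image = torch.rand(3, 1024, 1024)
    with torch.no_grad():
        output = peft_sam([{"image": image, "original_size": image.shape[1:]}], multimask_output=True)
    print(output[0]["masks"].shape)  # expected: torch.Size([1, 3, 1024, 1024])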