Add AdaptFormer (#741)
Add AdaptFormer for PEFT Finetuning

---------

Co-authored-by: Anwai Archit <[email protected]>
3 people authored Dec 14, 2024
1 parent 001878b commit 4e88e1c
Showing 2 changed files with 79 additions and 1 deletion.
66 changes: 65 additions & 1 deletion micro_sam/models/peft_sam.py
@@ -1,6 +1,7 @@
import math
from typing import List, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam
@@ -143,6 +144,68 @@ def forward(self, x):
        return x


class AdaptFormer(nn.Module):
    """Adds the AdaptFormer module in place of the MLP layers.

    Args:
        rank: The rank is not used in this class, but is kept here for consistency with the other PEFT modules.
        block: The chosen encoder block for implementing AdaptFormer.
        alpha: A parameter that scales the adapter path. Can be either learnable or a fixed value.
        dropout: The dropout rate for the dropout layer between the down and up projection layers.
        projection_size: The size of the projection layer.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
        dropout: Optional[float] = None,  # Does not have an obvious advantage.
        projection_size: int = 64,  # Stable choice from our preliminary exp.
    ):
        super().__init__()

        self.mlp_proj = block.mlp
        self.n_embd = block.mlp.lin1.in_features

        if alpha == 'learnable_scalar':
            self.alpha = nn.Parameter(torch.ones(1))
        else:
            self.alpha = alpha

        self.projection_size = projection_size
        self.dropout = dropout

        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.projection_size, self.n_embd)

        block.mlp = self

        if self.dropout is not None:
            self.dropout_layer = nn.Dropout(self.dropout)

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        nn.init.zeros_(self.down_proj.bias)
        nn.init.zeros_(self.up_proj.bias)

    def forward(self, x):
        residual = x
        mlp_output = self.mlp_proj(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)

        if self.dropout is not None:
            down = self.dropout_layer(down)

        up = self.up_proj(down)
        up = up * self.alpha
        output = up + residual + mlp_output

        return output


class AttentionSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for parameters in attention layers.
    """
@@ -191,7 +254,8 @@ def __init__(
        super().__init__()

        assert rank > 0
        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."
        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, AdaptFormer]), (
            "Invalid PEFT module")

        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
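For reference, in the notation of the forward pass above, the new module returns the sum of the original MLP output, its own input (an in-module residual connection), and a scaled bottleneck path:

    output = x + mlp_proj(x) + alpha * up_proj(ReLU(down_proj(x)))

with dropout optionally applied after the ReLU. Here down_proj maps the embedding dimension to projection_size and up_proj maps it back; since up_proj is zero-initialized, the bottleneck path contributes nothing at the start of fine-tuning and only takes effect as the adapter weights are trained.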
14 changes: 14 additions & 0 deletions test/test_models/test_peft_sam.py
@@ -78,6 +78,20 @@ def test_bias_layer_peft_sam(self):
        masks = output[0]["masks"]
        self.assertEqual(masks.shape, expected_shape)

    def test_adaptformer_peft_sam(self):
        from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer

        _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
        peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5)

        shape = (3, 1024, 1024)
        expected_shape = (1, 3, 1024, 1024)
        with torch.no_grad():
            batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
            output = peft_sam(batched_input, multimask_output=True)
            masks = output[0]["masks"]
            self.assertEqual(masks.shape, expected_shape)


if __name__ == "__main__":
    unittest.main()
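For a quick sanity check outside the test suite, wrapping a SAM model with the new module might look like the sketch below. It mirrors the call pattern of the test above; the concrete model type ("vit_b") and the parameter counting are illustrative assumptions, not part of this commit.

from micro_sam import util
from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer

# Load a vanilla SAM model and wrap it so that AdaptFormer modules are inserted
# in place of the MLP layers of the chosen image-encoder blocks.
_, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5)

# Inspect how many parameters remain trainable after wrapping.
trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
total = sum(p.numel() for p in peft_sam.parameters())
print(f"Trainable parameters: {trainable} / {total} ({100 * trainable / total:.2f}%)")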
