From 98c4b07ef9da4c84af886dfa227019897ba7f1e7 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Sat, 14 Dec 2024 17:49:12 +0100 Subject: [PATCH] Add SSF method (#727) Add SSF method for PEFT --------- Co-authored-by: Anwai Archit --------- Co-authored-by: Carolin Co-authored-by: Carolin Teuber <115626873+caroteu@users.noreply.github.com> Co-authored-by: Constantin Pape --- micro_sam/models/peft_sam.py | 64 ++++++++++++++++++++++++++++--- micro_sam/training/training.py | 13 ++++--- test/test_models/test_peft_sam.py | 15 ++++++++ 3 files changed, 80 insertions(+), 12 deletions(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index dfe13fee..65d95bf3 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -108,6 +108,54 @@ def forward(self, x): return qkv +class ScaleShiftLayer(nn.Module): + def __init__(self, layer, dim): + super().__init__() + self.layer = layer + self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,))) + self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,))) + layer = self + + def forward(self, x): + x = self.layer(x) + assert self.scale.shape == self.shift.shape + if x.shape[-1] == self.scale.shape[0]: + return x * self.scale + self.shift + elif x.shape[1] == self.scale.shape[0]: + return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1) + else: + raise ValueError('Input tensors do not match the shape of the scale factors.') + + +class SSFSurgery(nn.Module): + """Operates on all layers in the transformer block for adding learnable scale and shift parameters. + + Args: + rank: This parameter is not used in `SSFSurgery`. This is kept here for consistency. + block: The chosen attention blocks for implementing ssf. + dim: The input dimensions determining the shape of scale and shift parameters. + """ + def __init__(self, rank: int, block: nn.Module): + super().__init__() + self.block = block + + # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer. + if hasattr(block, "attn"): # the minimum assumption is to verify the attention layers. + block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3) + block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features) + block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features) + block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features) + block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0]) + block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0]) + + # If we get the embedding block, add one ScaleShiftLayer + elif hasattr(block, "patch_embed"): + block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels) + + def forward(self, x): + return x + + class SelectiveSurgery(nn.Module): """Base class for selectively allowing gradient updates for certain parameters. """ @@ -254,8 +302,10 @@ def __init__( super().__init__() assert rank > 0 - assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, AdaptFormer]), ( - "Invalid PEFT module") + + assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), ( + "Invalid PEFT module" + ) if attention_layers_to_update: self.peft_layers = attention_layers_to_update @@ -269,17 +319,19 @@ def __init__( for param in model.image_encoder.parameters(): param.requires_grad = False + # Add scale and shift parameters to the patch embedding layers. + if issubclass(self.peft_module, SSFSurgery): + self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed)) + for t_layer_i, blk in enumerate(model.image_encoder.blocks): # If we only want specific layers with PEFT instead of all if t_layer_i not in self.peft_layers: continue if issubclass(self.peft_module, SelectiveSurgery): - peft_block = self.peft_module(block=blk) + self.peft_blocks.append(self.peft_module(block=blk)) else: - peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs) - - self.peft_blocks.append(peft_block) + self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs)) self.peft_blocks = nn.ModuleList(self.peft_blocks) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index e2037aab..79485b8d 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -146,12 +146,6 @@ def set_description(self, desc, **kwargs): self._signals.pbar_description.emit(desc) -def _count_parameters(model_parameters): - params = sum(p.numel() for p in model_parameters if p.requires_grad) - params = params / 1e6 - print(f"The number of trainable parameters for the provided model is {round(params, 2)}M") - - @contextmanager def _filter_warnings(ignore_warnings): if ignore_warnings: @@ -163,6 +157,12 @@ def _filter_warnings(ignore_warnings): yield +def _count_parameters(model_parameters): + params = sum(p.numel() for p in model_parameters if p.requires_grad) + params = params / 1e6 + print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)") + + def train_sam( name: str, model_type: str, @@ -249,6 +249,7 @@ def train_sam( peft_kwargs=peft_kwargs, **model_kwargs ) + # This class creates all the training data for a batch (inputs, prompts and labels). convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py index 28d2b950..f480f314 100644 --- a/test/test_models/test_peft_sam.py +++ b/test/test_models/test_peft_sam.py @@ -78,12 +78,27 @@ def test_bias_layer_peft_sam(self): masks = output[0]["masks"] self.assertEqual(masks.shape, expected_shape) + def test_ssf_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + def test_adaptformer_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5) + shape = (3, 1024, 1024) expected_shape = (1, 3, 1024, 1024) with torch.no_grad():