From 925284352b59a27641be1b6099db5602efb90ac8 Mon Sep 17 00:00:00 2001
From: Carolin
Date: Tue, 15 Oct 2024 11:09:32 +0200
Subject: [PATCH 1/3] fixes in the ssf implementation

---
 micro_sam/models/peft_sam.py   | 43 ++++++++++++----------------------
 micro_sam/training/training.py |  2 +-
 2 files changed, 16 insertions(+), 29 deletions(-)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 2881b171..e9cd774d 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -112,10 +112,10 @@ def __init__(self, layer, dim):
         self.layer = layer
         self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
         self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
+        layer = self
 
     def forward(self, x):
         x = self.layer(x)
-        assert self.scale.shape == self.shift.shape
         if x.shape[-1] == self.scale.shape[0]:
             return x * self.scale + self.shift
@@ -133,37 +133,25 @@ class SSFSurgery(nn.Module):
         block: The chosen attention blocks for implementing ssf.
         dim: The input dimensions determining the shape of scale and shift parameters.
     """
-    def __init__(self, rank: int, block: nn.Module, dim: Optional[int] = None):
+    def __init__(self, rank: int, block: nn.Module):
         super().__init__()
         self.block = block
 
         # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
         if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
-            self.scale_shift_layers = nn.ModuleList(self.add_scale_shift_layers_to_block(block))
-        else:  # This is an individual layer after which we apply scale and shift.
-            if dim is None:
-                raise ValueError("'dim' must be provided for the scale and shift parameters.")
-            self.scale_shift_layers = nn.ModuleList([self.create_scale_shift_layer(layer=block, dim=dim)])
-
-    def add_scale_shift_layers_to_block(self, block):
-        peft_blocks = [
-            ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features),
-            ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features),
-            ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.in_features),
-            ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.in_features),
-            ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0]),
-            ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0]),
-        ]
-        return nn.ModuleList(peft_blocks)
-
-    def create_scale_shift_layer(self, layer, dim):
-        return ScaleShiftLayer(layer=layer, dim=dim)
+            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
+            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
+            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
+            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
+            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
+            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])
 
-    def forward(self, x):
-        for layer in self.scale_shift_layers:
-            x = layer(x)
+        # If we get the embedding block, add one ScaleShiftLayer
+        elif hasattr(block, "patch_embed"):
+            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)
 
-        return self.block(x)
+    def forward(self, x):
+        return x
 
 
 class SelectiveSurgery(nn.Module):
@@ -269,8 +257,7 @@ def __init__(
         if issubclass(self.peft_module, SSFSurgery):
             self.peft_blocks.append(
                 self.peft_module(
-                    rank=rank, block=model.image_encoder.patch_embed.proj,
-                    dim=model.image_encoder.patch_embed.proj.out_channels
+                    rank=rank, block=model.image_encoder.patch_embed
                 )
             )
@@ -289,4 +276,4 @@ def __init__(
         self.sam = model
 
     def forward(self, batched_input, multimask_output):
-        return self.sam(batched_input, multimask_output)
+        return self.sam(batched_input, multimask_output)
\ No newline at end of file
diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py
index 8b911a26..ee68a6d8 100644
--- a/micro_sam/training/training.py
+++ b/micro_sam/training/training.py
@@ -238,7 +238,7 @@ def train_sam(
     # The number of trainable parameters for the provided model is 4.06456 (~4.06M)
 
     # peft: ssf
-    #
+    # The number of trainable parameters for the provided model is 4.267312 (~4.27M)
 
     breakpoint()

From 42bc6ae489b8143d598808c2c2408b0b9eedfcd1 Mon Sep 17 00:00:00 2001
From: Carolin
Date: Tue, 15 Oct 2024 13:17:10 +0200
Subject: [PATCH 2/3] deactivate breakpoint

---
 micro_sam/training/training.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py
index ee68a6d8..c9e88d68 100644
--- a/micro_sam/training/training.py
+++ b/micro_sam/training/training.py
@@ -240,7 +240,7 @@ def train_sam(
     # peft: ssf
     # The number of trainable parameters for the provided model is 4.267312 (~4.27M)
 
-    breakpoint()
+    # breakpoint()
 
     # This class creates all the training data for a batch (inputs, prompts and labels).
     convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

From 36ffc05e21828b96948055dd866e8a69a19ffc8f Mon Sep 17 00:00:00 2001
From: Anwai Archit
Date: Tue, 15 Oct 2024 22:49:43 +0200
Subject: [PATCH 3/3] Remove debug scripts

---
 micro_sam/models/peft_sam.py   | 10 +++-------
 micro_sam/training/training.py | 13 -------------
 micro_sam/training/util.py     |  4 ----
 3 files changed, 3 insertions(+), 24 deletions(-)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index e9cd774d..19ac3a0b 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -231,7 +231,7 @@ def __init__(
         self,
         model: Sam,
         rank: int,
-        peft_module: nn.Module = SSFSurgery,
+        peft_module: nn.Module = LoRASurgery,
         attention_layers_to_update: Union[List[int]] = None,
         **module_kwargs
     ):
@@ -255,11 +255,7 @@ def __init__(
 
         # Add scale and shift parameters to the patch embedding layers.
         if issubclass(self.peft_module, SSFSurgery):
-            self.peft_blocks.append(
-                self.peft_module(
-                    rank=rank, block=model.image_encoder.patch_embed
-                )
-            )
+            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))
 
         for t_layer_i, blk in enumerate(model.image_encoder.blocks):
             # If we only want specific layers with PEFT instead of all
@@ -276,4 +272,4 @@ def __init__(
         self.sam = model
 
     def forward(self, batched_input, multimask_output):
-        return self.sam(batched_input, multimask_output)
\ No newline at end of file
+        return self.sam(batched_input, multimask_output)
diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py
index c9e88d68..841cc4b7 100644
--- a/micro_sam/training/training.py
+++ b/micro_sam/training/training.py
@@ -229,19 +229,6 @@ def train_sam(
         **model_kwargs
     )
 
-    _count_parameters(model.parameters())
-
-    # full
-    # The number of trainable parameters for the provided model is 93.735472 (~93.74M)
-
-    # freeze image encoder
-    # The number of trainable parameters for the provided model is 4.06456 (~4.06M)
-
-    # peft: ssf
-    # The number of trainable parameters for the provided model is 4.267312 (~4.27M)
-
-    # breakpoint()
-
     # This class creates all the training data for a batch (inputs, prompts and labels).
     convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py
index eefd7b4a..7ecf41cd 100644
--- a/micro_sam/training/util.py
+++ b/micro_sam/training/util.py
@@ -87,10 +87,6 @@ def get_trainable_sam_model(
 
         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 
-        for k, v in sam.named_parameters():
-            if k.startswith("image_encoder") and v.requires_grad:
-                print(k)
-
     # freeze components of the model if freeze was passed
     # ideally we would want to add components in such a way that:
    # - we would be able to freeze the choice of encoder/decoder blocks, yet be able to add components to the network
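---
Appendix: a minimal, self-contained sketch of the in-place SSF surgery these
patches converge on. Everything below is illustrative: `ToyBlock` and all
dimensions are invented, and the ScaleShiftLayer here is a simplified version
of the one in micro_sam/models/peft_sam.py (the shape dispatch for conv
outputs is omitted). It shows the core idea: replace each target layer with a
wrapper around the frozen base layer, so that only the new per-feature scale
and shift parameters train.

import torch
import torch.nn as nn


class ScaleShiftLayer(nn.Module):
    """Wraps a layer and applies a learnable element-wise scale and shift to its output."""
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer
        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))

    def forward(self, x):
        return self.layer(x) * self.scale + self.shift


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer MLP sub-block."""
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(8, 16)
        self.lin2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.lin2(torch.relu(self.lin1(x)))


block = ToyBlock()
for p in block.parameters():  # freeze the base weights; only scale/shift will train
    p.requires_grad = False

# In-place surgery, as in the patched SSFSurgery: the scale/shift dimension
# follows the *output* features of each wrapped layer.
block.lin1 = ScaleShiftLayer(block.lin1, block.lin1.out_features)
block.lin2 = ScaleShiftLayer(block.lin2, block.lin2.out_features)

out = block(torch.randn(2, 8))  # forward path is unchanged apart from scale/shift
n_trainable = sum(p.numel() for p in block.parameters() if p.requires_grad)
print(out.shape, n_trainable)  # torch.Size([2, 8]) 48, i.e. 2 * (16 + 8)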