Fix SSF implementation (#737)
Fixes the SSF implementation: the scale and shift parameters are now applied in the forward pass and behave as expected.
---------

Co-authored-by: Anwai Archit <[email protected]>
caroteu and anwai98 authored Oct 15, 2024
1 parent e9176ce commit 26774e1
Showing 3 changed files with 15 additions and 49 deletions.
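Note on the change: SSF-style parameter-efficient fine-tuning keeps the pretrained weights frozen and learns only a per-feature scale and shift applied to a layer's output. A minimal, generic sketch of the idea (plain PyTorch, not the micro_sam code itself; micro_sam initialises the parameters from a normal distribution, as the diff below shows):

```python
import torch
import torch.nn as nn

# Generic illustration of the SSF idea: freeze the pretrained layer and learn
# only a per-feature scale (gamma) and shift (beta) applied to its output.
layer = nn.Linear(768, 768)
for p in layer.parameters():
    p.requires_grad = False             # pretrained weights stay frozen

scale = nn.Parameter(torch.ones(768))   # gamma, learnable
shift = nn.Parameter(torch.zeros(768))  # beta, learnable

x = torch.randn(4, 768)
y = layer(x) * scale + shift            # y = gamma * f(x) + beta
```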
47 changes: 15 additions & 32 deletions micro_sam/models/peft_sam.py
@@ -112,10 +112,10 @@ def __init__(self, layer, dim):
         self.layer = layer
         self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
         self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
         layer = self
 
     def forward(self, x):
         x = self.layer(x)
-
+        assert self.scale.shape == self.shift.shape
         if x.shape[-1] == self.scale.shape[0]:
             return x * self.scale + self.shift
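For reference, a short usage sketch of the wrapper defined in this hunk (assuming ScaleShiftLayer is imported from micro_sam.models.peft_sam, the file being changed; the shapes are illustrative):

```python
import torch
import torch.nn as nn

from micro_sam.models.peft_sam import ScaleShiftLayer

# Wrap a linear layer; the wrapper's forward computes layer(x) * scale + shift
# whenever the last dimension of x matches the size of the scale/shift vectors.
wrapped = ScaleShiftLayer(nn.Linear(768, 768), dim=768)

x = torch.randn(2, 196, 768)
y = wrapped(x)
print(y.shape)                                   # torch.Size([2, 196, 768])
print(wrapped.scale.shape, wrapped.shift.shape)  # both torch.Size([768])
```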
@@ -133,37 +133,25 @@ class SSFSurgery(nn.Module):
         block: The chosen attention blocks for implementing ssf.
         dim: The input dimensions determining the shape of scale and shift parameters.
     """
-    def __init__(self, rank: int, block: nn.Module, dim: Optional[int] = None):
+    def __init__(self, rank: int, block: nn.Module):
         super().__init__()
         self.block = block
 
         # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
         if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
-            self.scale_shift_layers = nn.ModuleList(self.add_scale_shift_layers_to_block(block))
-        else:  # This is an individual layer after which we apply scale and shift.
-            if dim is None:
-                raise ValueError("'dim' must be provided for the scale and shift parameters.")
-            self.scale_shift_layers = nn.ModuleList([self.create_scale_shift_layer(layer=block, dim=dim)])
-
-    def add_scale_shift_layers_to_block(self, block):
-        peft_blocks = [
-            ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features),
-            ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features),
-            ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.in_features),
-            ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.in_features),
-            ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0]),
-            ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0]),
-        ]
-        return nn.ModuleList(peft_blocks)
-
-    def create_scale_shift_layer(self, layer, dim):
-        return ScaleShiftLayer(layer=layer, dim=dim)
+            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
+            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
+            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
+            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
+            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
+            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])
 
-    def forward(self, x):
-        for layer in self.scale_shift_layers:
-            x = layer(x)
+        # If we get the embedding block, add one ScaleShiftLayer
+        elif hasattr(block, "patch_embed"):
+            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)
 
-        return self.block(x)
+    def forward(self, x):
+        return x
 
 
 class SelectiveSurgery(nn.Module):
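This hunk is the core of the fix. The old surgery kept the ScaleShiftLayer wrappers in a side nn.ModuleList (scale_shift_layers), which is not on the image encoder's forward path, so the scale and shift parameters were apparently never applied. The new version replaces the block's sub-modules in place, so the wrappers sit directly on the forward path and their parameters receive gradients; SSFSurgery.forward becomes a pass-through because all the work happens at construction time. A self-contained sketch of the same in-place pattern on a toy block (names are illustrative, not SAM's):

```python
import torch
import torch.nn as nn


class ScaleShift(nn.Module):
    """Minimal scale-shift wrapper, analogous to ScaleShiftLayer above."""
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer
        self.scale = nn.Parameter(torch.ones(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return self.layer(x) * self.scale + self.shift


class ToyBlock(nn.Module):
    """Toy stand-in for a transformer block, just to show the surgery pattern."""
    def __init__(self, dim=64):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(self.lin(x))


block = ToyBlock()
# In-place surgery: the wrapped modules become part of the block itself, so
# block(x) applies scale and shift, and block.parameters() now contains them.
block.lin = ScaleShift(block.lin, 64)
block.norm = ScaleShift(block.norm, 64)

x = torch.randn(2, 10, 64)
assert block(x).shape == (2, 10, 64)
```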
@@ -243,7 +231,7 @@ def __init__(
         self,
         model: Sam,
         rank: int,
-        peft_module: nn.Module = SSFSurgery,
+        peft_module: nn.Module = LoRASurgery,
         attention_layers_to_update: Union[List[int]] = None,
         **module_kwargs
     ):
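This changes the default PEFT module from SSFSurgery to LoRASurgery, so SSF becomes opt-in. A hedged usage sketch for selecting it explicitly (assuming get_trainable_sam_model forwards peft_kwargs to PEFT_Sam, as the micro_sam/training/util.py hunk further down suggests):

```python
from micro_sam.models.peft_sam import SSFSurgery
from micro_sam.training.util import get_trainable_sam_model

# Opt in to SSF explicitly; LoRA remains the default PEFT module.
# `rank` is required by the PEFT_Sam signature but is not used by SSFSurgery.
model = get_trainable_sam_model(
    model_type="vit_b",
    peft_kwargs={"rank": 2, "peft_module": SSFSurgery},
)
```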
@@ -267,12 +255,7 @@ def __init__(
 
         # Add scale and shift parameters to the patch embedding layers.
         if issubclass(self.peft_module, SSFSurgery):
-            self.peft_blocks.append(
-                self.peft_module(
-                    rank=rank, block=model.image_encoder.patch_embed.proj,
-                    dim=model.image_encoder.patch_embed.proj.out_channels
-                )
-            )
+            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))
 
         for t_layer_i, blk in enumerate(model.image_encoder.blocks):
             # If we only want specific layers with PEFT instead of all
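With this change the whole patch_embed module is handed to SSFSurgery, which wraps its Conv2d projection with dim = out_channels (the elif branch added above). A hedged sketch of why a per-channel scale/shift has to broadcast there (generic shapes, not SAM's exact configuration):

```python
import torch
import torch.nn as nn

# A patch-embedding projection is a Conv2d, so its raw output is channels
# first (B, C, H, W); per-channel scale/shift must broadcast over H and W.
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
scale = nn.Parameter(torch.ones(768))
shift = nn.Parameter(torch.zeros(768))

x = torch.randn(1, 3, 224, 224)
y = proj(x)                                                # (1, 768, 14, 14)
y = y * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)  # broadcast over H, W
print(y.shape)                                             # torch.Size([1, 768, 14, 14])
```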
13 changes: 0 additions & 13 deletions micro_sam/training/training.py
@@ -229,19 +229,6 @@ def train_sam(
         **model_kwargs
     )
 
-    _count_parameters(model.parameters())
-
-    # full
-    # The number of trainable parameters for the provided model is 93.735472 (~93.74M)
-
-    # freeze image encoder
-    # The number of trainable parameters for the provided model is 4.06456 (~4.06M)
-
-    # peft: ssf
-    #
-
-    breakpoint()
-
     # This class creates all the training data for a batch (inputs, prompts and labels).
     convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
 
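The deleted lines were debugging leftovers (an ad-hoc parameter count, scratch notes, and a breakpoint()). For the same check without them, a minimal sketch (count_trainable_parameters is illustrative, not a micro_sam helper):

```python
def count_trainable_parameters(model):
    # Mirrors the removed debug output: report trainable parameters in millions.
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"The number of trainable parameters for the provided model is {n / 1e6:.2f}M")
    return n
```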
4 changes: 0 additions & 4 deletions micro_sam/training/util.py
@@ -87,10 +87,6 @@ def get_trainable_sam_model(
 
         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 
-        for k, v in sam.named_parameters():
-            if k.startswith("image_encoder") and v.requires_grad:
-                print(k)
-
     # freeze components of the model if freeze was passed
     # ideally we would want to add components in such a way that:
     # - we would be able to freeze the choice of encoder/decoder blocks, yet be able to add components to the network
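The removed loop was another debug aid (printing the trainable image-encoder parameters). The surviving comments describe the intended freezing behaviour; a hedged sketch of that idea, with the caveat that PEFT parameters added inside a frozen component need to be exempted (the freeze list and name checks are illustrative, not micro_sam's actual logic):

```python
def freeze_components(sam, freeze=("image_encoder",)):
    """Freeze whole components by parameter-name prefix, but keep PEFT
    parameters (e.g. SSF's scale/shift) trainable inside frozen parts."""
    for name, param in sam.named_parameters():
        if any(name.startswith(prefix) for prefix in freeze):
            if "scale" in name or "shift" in name:
                continue  # PEFT parameters stay trainable
            param.requires_grad = False
```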
