Add SSF method #727

Merged (6 commits) on Dec 14, 2024
Changes from 5 commits
62 changes: 57 additions & 5 deletions micro_sam/models/peft_sam.py
@@ -1,6 +1,7 @@
import math
from typing import List, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam
@@ -107,6 +108,54 @@ def forward(self, x):
        return qkv


class ScaleShiftLayer(nn.Module):
    """Wraps a layer and applies learnable scale and shift parameters to its output.

    Args:
        layer: The layer whose outputs are scaled and shifted.
        dim: The feature dimension, which determines the shape of the scale and shift parameters.
    """
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer
        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))

    def forward(self, x):
        x = self.layer(x)
        assert self.scale.shape == self.shift.shape
        if x.shape[-1] == self.scale.shape[0]:  # channels-last outputs (linear and norm layers)
            return x * self.scale + self.shift
        elif x.shape[1] == self.scale.shape[0]:  # channels-first outputs (the patch embedding convolution)
            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
        else:
            raise ValueError("Input tensors do not match the shape of the scale factors.")


class SSFSurgery(nn.Module):
    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.

    Args:
        rank: This parameter is not used in `SSFSurgery`. It is kept here for consistency.
        block: The chosen attention blocks for implementing SSF.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.block = block

        # If we get a transformer block (with multiple sub-layers), we perform surgery on each layer.
        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])

        # If we get the patch embedding block, wrap its projection (`proj`) with a single ScaleShiftLayer.
        elif hasattr(block, "proj"):
            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)

    def forward(self, x):
        return x
class SelectiveSurgery(nn.Module):
"""Base class for selectively allowing gradient updates for certain parameters.
"""
@@ -191,7 +240,8 @@ def __init__(
        super().__init__()

        assert rank > 0
-       assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."
+       assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery]), (
+           "Invalid PEFT module")

        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
@@ -205,17 +255,19 @@ def __init__(
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        # Add scale and shift parameters to the patch embedding layers.
        if issubclass(self.peft_module, SSFSurgery):
            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
            # If we only want specific layers with PEFT instead of all
            if t_layer_i not in self.peft_layers:
                continue

            if issubclass(self.peft_module, SelectiveSurgery):
-               peft_block = self.peft_module(block=blk)
+               self.peft_blocks.append(self.peft_module(block=blk))
            else:
-               peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs)
-
-           self.peft_blocks.append(peft_block)
+               self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

        self.peft_blocks = nn.ModuleList(self.peft_blocks)

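For readers skimming the diff, here is a minimal sketch of how the new ScaleShiftLayer behaves on its own, outside of PEFT_Sam. The wrapped nn.Linear and the tensor shapes are made up for illustration; only the import of ScaleShiftLayer from the module changed above comes from this PR.

import torch
import torch.nn as nn

from micro_sam.models.peft_sam import ScaleShiftLayer

# Wrap a stand-in linear layer; the surgery above does this for qkv, proj, the MLP and the norm layers.
base = nn.Linear(64, 64)
ssf = ScaleShiftLayer(base, dim=64)

# Freeze the wrapped layer, as PEFT_Sam does for the whole image encoder.
for p in base.parameters():
    p.requires_grad = False

x = torch.rand(2, 16, 64)  # (batch, tokens, features) hits the channels-last branch
y = ssf(x)                 # y = base(x) * scale + shift
assert y.shape == x.shape

# Only the scale and shift vectors remain trainable (2 * 64 = 128 parameters here).
print(sum(p.numel() for p in ssf.parameters() if p.requires_grad))
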
13 changes: 7 additions & 6 deletions micro_sam/training/training.py
@@ -146,12 +146,6 @@ def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


-def _count_parameters(model_parameters):
-    params = sum(p.numel() for p in model_parameters if p.requires_grad)
-    params = params / 1e6
-    print(f"The number of trainable parameters for the provided model is {round(params, 2)}M")


@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
@@ -163,6 +157,12 @@ def _filter_warnings(ignore_warnings):
        yield


+def _count_parameters(model_parameters):
+    params = sum(p.numel() for p in model_parameters if p.requires_grad)
+    params_in_millions = params / 1e6
+    print(f"The number of trainable parameters for the provided model is {params} (~{round(params_in_millions, 2)}M)")


def train_sam(
    name: str,
    model_type: str,
@@ -249,6 +249,7 @@ def train_sam(
        peft_kwargs=peft_kwargs,
        **model_kwargs
    )

    # This class creates all the training data for a batch (inputs, prompts and labels).
    convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

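Putting the two files together, the snippet below is an untested sketch of how one might inspect the effect of SSF on the trainable parameter count and run a forward pass. It assumes a locally available SAM checkpoint for model_type="vit_b" (the model type is an assumption, not part of this PR) and reuses the private _count_parameters helper relocated in the diff above.

import torch

from micro_sam import util
from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery
from micro_sam.training.training import _count_parameters

# Wrap SAM so that the image encoder is frozen and only the scale/shift parameters
# (plus the prompt encoder and mask decoder) remain trainable.
# `rank` is ignored by SSFSurgery but required by the PEFT_Sam interface.
_, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery)

# Report how many parameters would actually be updated during fine-tuning.
_count_parameters(peft_sam.parameters())

# A forward pass works exactly like for the underlying SAM model (see the test below).
with torch.no_grad():
    batched_input = [{"image": torch.rand(3, 1024, 1024), "original_size": (1024, 1024)}]
    output = peft_sam(batched_input, multimask_output=True)
print(output[0]["masks"].shape)  # torch.Size([1, 3, 1024, 1024])
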
14 changes: 14 additions & 0 deletions test/test_models/test_peft_sam.py
@@ -78,6 +78,20 @@ def test_bias_layer_peft_sam(self):
        masks = output[0]["masks"]
        self.assertEqual(masks.shape, expected_shape)

    def test_ssf_peft_sam(self):
        from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery

        _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
        peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery)

        shape = (3, 1024, 1024)
        expected_shape = (1, 3, 1024, 1024)
        with torch.no_grad():
            batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
            output = peft_sam(batched_input, multimask_output=True)
            masks = output[0]["masks"]
            self.assertEqual(masks.shape, expected_shape)


if __name__ == "__main__":
    unittest.main()