Add SSF method (#727)
Add SSF method for PEFT
---------

Co-authored-by: Anwai Archit <[email protected]>

---------

Co-authored-by: Carolin <[email protected]>
Co-authored-by: Carolin Teuber <[email protected]>
Co-authored-by: Constantin Pape <[email protected]>
4 people authored Dec 14, 2024
1 parent 0a62171 commit 98c4b07
Showing 3 changed files with 80 additions and 12 deletions.
64 changes: 58 additions & 6 deletions micro_sam/models/peft_sam.py
@@ -108,6 +108,54 @@ def forward(self, x):
return qkv


class ScaleShiftLayer(nn.Module):
def __init__(self, layer, dim):
super().__init__()
self.layer = layer
self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))

def forward(self, x):
x = self.layer(x)
assert self.scale.shape == self.shift.shape
if x.shape[-1] == self.scale.shape[0]:
return x * self.scale + self.shift
elif x.shape[1] == self.scale.shape[0]:
return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
else:
raise ValueError('Input tensors do not match the shape of the scale factors.')
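
A minimal usage sketch of the layer above (illustrative only, not part of this commit), wrapping a plain linear layer; the sizes are arbitrary assumptions and the example just demonstrates the broadcasting behaviour of the forward pass:

    import torch
    import torch.nn as nn

    from micro_sam.models.peft_sam import ScaleShiftLayer

    # Wrap a linear layer so its output is scaled and shifted element-wise.
    base = nn.Linear(16, 32)
    ssf = ScaleShiftLayer(base, dim=32)

    x = torch.rand(2, 10, 16)   # (batch, tokens, in_features)
    y = ssf(x)                  # base(x) * scale + shift, broadcast over the last dimension
    print(y.shape)              # torch.Size([2, 10, 32])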


class SSFSurgery(nn.Module):
"""Operates on all layers in the transformer block for adding learnable scale and shift parameters.

    Args:
        rank: This parameter is not used in `SSFSurgery`. It is kept here for consistency.
        block: The chosen attention blocks for implementing SSF.
"""
def __init__(self, rank: int, block: nn.Module):
super().__init__()
self.block = block

# If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
        if hasattr(block, "attn"):  # the minimal check is whether the block has attention layers.
block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])

# If we get the embedding block, add one ScaleShiftLayer
elif hasattr(block, "patch_embed"):
block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)

def forward(self, x):
return x
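
A sketch of what the surgery does to a block (again illustrative, not part of the commit): the ToyBlock below is a hypothetical stand-in that exposes the attributes SSFSurgery looks for, not SAM's actual ViT block.

    import torch.nn as nn

    from micro_sam.models.peft_sam import ScaleShiftLayer, SSFSurgery

    class ToyAttention(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.qkv = nn.Linear(dim, dim * 3)
            self.proj = nn.Linear(dim, dim)

    class ToyMLP(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.lin1 = nn.Linear(dim, 4 * dim)
            self.lin2 = nn.Linear(4 * dim, dim)

    class ToyBlock(nn.Module):
        def __init__(self, dim=64):
            super().__init__()
            self.attn = ToyAttention(dim)
            self.mlp = ToyMLP(dim)
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)

    block = ToyBlock()
    SSFSurgery(rank=1, block=block)  # rank is ignored; the sub-layers are wrapped in place
    assert isinstance(block.attn.qkv, ScaleShiftLayer)
    assert isinstance(block.norm1, ScaleShiftLayer)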


class SelectiveSurgery(nn.Module):
"""Base class for selectively allowing gradient updates for certain parameters.
"""
@@ -254,8 +302,10 @@ def __init__(
super().__init__()

assert rank > 0
-        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, AdaptFormer]), (
-            "Invalid PEFT module")
+
+        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
+            "Invalid PEFT module"
+        )

if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
@@ -269,17 +319,19 @@ def __init__(
for param in model.image_encoder.parameters():
param.requires_grad = False

# Add scale and shift parameters to the patch embedding layers.
if issubclass(self.peft_module, SSFSurgery):
self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))

for t_layer_i, blk in enumerate(model.image_encoder.blocks):
# If we only want specific layers with PEFT instead of all
if t_layer_i not in self.peft_layers:
continue

if issubclass(self.peft_module, SelectiveSurgery):
-                peft_block = self.peft_module(block=blk)
+                self.peft_blocks.append(self.peft_module(block=blk))
            else:
-                peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs)
-
-            self.peft_blocks.append(peft_block)
+                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

self.peft_blocks = nn.ModuleList(self.peft_blocks)
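
Putting it together, a short sketch that mirrors the new unit test further below (not part of the diff; the "vit_b" model type is an assumption):

    from micro_sam import util
    from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery

    # Load SAM and add learnable scale/shift parameters to the patch embedding
    # and to every attention block of the image encoder.
    _, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
    peft_sam = PEFT_Sam(sam, rank=1, peft_module=SSFSurgery)  # rank is unused by SSFSurgery

    # The image encoder weights are frozen; only the scale/shift parameters
    # (plus SAM's prompt encoder and mask decoder) remain trainable.
    n_trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
    print(f"Trainable parameters: {n_trainable / 1e6:.2f}M")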

13 changes: 7 additions & 6 deletions micro_sam/training/training.py
@@ -146,12 +146,6 @@ def set_description(self, desc, **kwargs):
self._signals.pbar_description.emit(desc)


def _count_parameters(model_parameters):
params = sum(p.numel() for p in model_parameters if p.requires_grad)
params = params / 1e6
print(f"The number of trainable parameters for the provided model is {round(params, 2)}M")


@contextmanager
def _filter_warnings(ignore_warnings):
if ignore_warnings:
@@ -163,6 +157,12 @@ def _filter_warnings(ignore_warnings):
yield


def _count_parameters(model_parameters):
params = sum(p.numel() for p in model_parameters if p.requires_grad)
params = params / 1e6
print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")
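
A tiny illustration of what this helper prints (hypothetical toy model, not part of the commit); the function is private, so calling it directly is only for demonstration:

    import torch.nn as nn
    from micro_sam.training.training import _count_parameters

    # 16 * 4 weights + 4 biases = 68 trainable parameters, reported as 6.8e-05 (~0.0M).
    _count_parameters(nn.Linear(16, 4).parameters())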


def train_sam(
name: str,
model_type: str,
@@ -249,6 +249,7 @@ def train_sam(
peft_kwargs=peft_kwargs,
**model_kwargs
)

# This class creates all the training data for a batch (inputs, prompts and labels).
convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

15 changes: 15 additions & 0 deletions test/test_models/test_peft_sam.py
@@ -78,12 +78,27 @@ def test_bias_layer_peft_sam(self):
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_ssf_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_adaptformer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5)


shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():