Commit

added scaling factor in lora
caroteu committed Oct 28, 2024
1 parent b273ace commit 7a27ffb
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions micro_sam/models/peft_sam.py
@@ -23,16 +23,17 @@ class LoRASurgery(nn.Module):
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
         block: The chosen attention blocks for implementing lora.
     """
-    def __init__(self, rank: int, block: nn.Module, alpha: float):
+    def __init__(self, rank: int, block: nn.Module, alpha: float = 1):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
+        self.alpha = alpha
         self.rank = rank
 
-        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
-        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
-        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
-        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)
+        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
+        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
+        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
+        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
 
         self.reset_parameters()
@@ -124,7 +125,7 @@ def allow_gradient_update_for_parameters(
         Args:
             prefix: Matches the part of parameter name in front.
             suffix: Matches the part of parameter name at the end.
-            infix: Matches parts of parameter name occuring in between.
+            infix: Matches parts of parameter name occuring in between.
         """
         for k, v in self.block.named_parameters():
             if prefix is not None and k.startswith(tuple(prefix)):
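Note that this diff only touches __init__, so the forward pass that consumes the new alpha attribute is not shown above. The sketch below illustrates how a LoRA scaling factor is commonly applied to a fused qkv projection, scaling the low-rank update by alpha / rank; the class name LoRAQKVSketch, the forward logic, and the alpha / rank convention are assumptions for illustration, not the exact code in peft_sam.py.

# Minimal sketch (not from this commit): one common way an `alpha` scaling
# factor is used in a LoRA update to a fused qkv projection. Names, forward
# logic, and the alpha / rank convention are assumptions for illustration.
import torch
import torch.nn as nn


class LoRAQKVSketch(nn.Module):
    """Hypothetical stand-in for LoRASurgery wrapping a fused qkv projection."""

    def __init__(self, rank: int, qkv_proj: nn.Linear, alpha: float = 1):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.dim = qkv_proj.in_features
        self.alpha = alpha
        self.rank = rank

        # Low-rank factors for the query and value projections (as in the diff above).
        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(x)                 # (..., 3 * dim): fused q, k, v
        scale = self.alpha / self.rank         # common LoRA convention (assumed)
        new_q = self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.w_b_linear_v(self.w_a_linear_v(x))
        qkv[..., : self.dim] += scale * new_q  # add scaled update to queries
        qkv[..., -self.dim:] += scale * new_v  # add scaled update to values
        return qkv


# Usage: wrap a fused qkv projection of width 64 with a rank-4, alpha-8 update.
proj = nn.Linear(64, 3 * 64)
layer = LoRAQKVSketch(rank=4, qkv_proj=proj, alpha=8.0)
print(layer(torch.randn(2, 16, 64)).shape)     # torch.Size([2, 16, 192])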
