Commit

Added scaling factor to LoRASurgery
caroteu committed Oct 17, 2024
1 parent 896ea00 commit b273ace
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions micro_sam/models/peft_sam.py
@@ -23,10 +23,11 @@ class LoRASurgery(nn.Module):
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
         block: The chosen attention blocks for implementing lora.
     """
-    def __init__(self, rank: int, block: nn.Module):
+    def __init__(self, rank: int, block: nn.Module, alpha: float):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
+        self.alpha = alpha

         self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
         self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
@@ -45,8 +46,8 @@ def reset_parameters(self):

     def forward(self, x):
         qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
-        new_q = self.w_b_linear_q(self.w_a_linear_q(x))
-        new_v = self.w_b_linear_v(self.w_a_linear_v(x))
+        new_q = self.alpha / self.rank * self.w_b_linear_q(self.w_a_linear_q(x))
+        new_v = self.alpha / self.rank * self.w_b_linear_v(self.w_a_linear_v(x))
         qkv[:, :, :, :self.dim] += new_q
         qkv[:, :, :, -self.dim:] += new_v
         return qkv
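The commit scales the low-rank updates to the q and v projections by alpha / rank, the usual LoRA scaling. Below is a minimal standalone sketch of that scaling, not part of the commit: the ScaledLoRA class, the dim value, and the example tensor are illustrative, and it assumes the rank is stored on the module (the truncated hunk above does not show where self.rank is set).

import torch
import torch.nn as nn


class ScaledLoRA(nn.Module):
    """Hypothetical module illustrating the alpha / rank scaling of a low-rank update."""

    def __init__(self, dim: int, rank: int, alpha: float):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.w_a = nn.Linear(dim, rank, bias=False)  # A: dim -> rank
        self.w_b = nn.Linear(rank, dim, bias=False)  # B: rank -> dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scaled low-rank update: (alpha / rank) * B(A(x))
        return self.alpha / self.rank * self.w_b(self.w_a(x))


# Usage: the scaled update is added to the output of the frozen projection,
# mirroring how new_q and new_v are added to the qkv tensor in the diff above.
x = torch.randn(2, 16, 256)
lora = ScaledLoRA(dim=256, rank=4, alpha=8.0)
update = lora(x)  # same shape as x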
