diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index febbccf6..7677f2e5 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -23,10 +23,13 @@ class LoRASurgery(nn.Module):
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
         block: The chosen attention blocks for implementing lora.
+        alpha: The scaling factor for the low-rank update (the update is scaled by alpha / rank).
     """
-    def __init__(self, rank: int, block: nn.Module):
+    def __init__(self, rank: int, block: nn.Module, alpha: float):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
+        self.alpha = alpha
+        self.rank = rank
 
         self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
         self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
@@ -45,8 +48,9 @@ def reset_parameters(self):
 
     def forward(self, x):
         qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
-        new_q = self.w_b_linear_q(self.w_a_linear_q(x))
-        new_v = self.w_b_linear_v(self.w_a_linear_v(x))
+        # Scale the low-rank updates by alpha / rank (the standard LoRA scaling).
+        new_q = self.alpha / self.rank * self.w_b_linear_q(self.w_a_linear_q(x))
+        new_v = self.alpha / self.rank * self.w_b_linear_v(self.w_a_linear_v(x))
         qkv[:, :, :, :self.dim] += new_q
         qkv[:, :, :, -self.dim:] += new_v
         return qkv
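
Usage sketch (not part of the patch): a minimal smoke test of the scaled LoRA update, assuming the change above is applied to an installed micro_sam. The SimpleNamespace stand-in only mimics the block.attn.qkv attribute that LoRASurgery reads; the chosen sizes, rank, and alpha are illustrative.

from types import SimpleNamespace

import torch
import torch.nn as nn

from micro_sam.models.peft_sam import LoRASurgery

dim = 64
# Dummy block exposing only the attribute LoRASurgery needs: a qkv projection.
block = SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, 3 * dim)))

# The low-rank update is scaled by alpha / rank, so its magnitude stays
# comparable when the rank is changed.
lora = LoRASurgery(rank=4, block=block, alpha=8.0)

x = torch.randn(1, 14, 14, dim)  # B, H, W, C token grid of the image encoder
qkv = lora(x)                    # B, H, W, 3 * C, with the LoRA update added to q and v
print(qkv.shape)                 # torch.Size([1, 14, 14, 192])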