diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 7677f2e5..3359de01 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -23,16 +23,17 @@ class LoRASurgery(nn.Module):
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
         block: The chosen attention blocks for implementing lora.
     """
-    def __init__(self, rank: int, block: nn.Module, alpha: float):
+    def __init__(self, rank: int, block: nn.Module, alpha: float = 1):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
         self.alpha = alpha
+        self.rank = rank

-        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
-        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
-        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
-        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)
+        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
+        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
+        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
+        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)

         self.reset_parameters()

@@ -124,7 +125,7 @@ def allow_gradient_update_for_parameters(
         Args:
             prefix: Matches the part of parameter name in front.
            suffix: Matches the part of parameter name at the end.
-            infix: Matches parts of parameter name occuring in between.
+            infix: Matches parts of parameter name occurring in between.
         """
         for k, v in self.block.named_parameters():
             if prefix is not None and k.startswith(tuple(prefix)):
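For context, below is a minimal, self-contained sketch of how a LoRA surgery of this shape typically folds the low-rank `w_a`/`w_b` matrices back into a fused qkv projection. The class name `TinyLoRAQKV`, the `alpha / rank` scaling, and the toy tensor shapes are illustrative assumptions, not micro_sam's actual implementation; the diff above only shows the constructor changes (default `alpha`, stored `rank`).

```python
import torch
import torch.nn as nn


class TinyLoRAQKV(nn.Module):
    """Illustrative stand-in for an attention block's fused qkv projection with a
    LoRA update on the q and v slices (sketch only, not micro_sam's code)."""

    def __init__(self, dim: int, rank: int, alpha: float = 1):
        super().__init__()
        self.dim = dim
        self.rank = rank
        self.alpha = alpha
        self.qkv_proj = nn.Linear(dim, 3 * dim)  # base projection producing (q | k | v)
        # Low-rank decomposition: A maps dim -> rank, B maps rank -> dim, as in the diff.
        self.w_a_linear_q = nn.Linear(dim, rank, bias=False)
        self.w_b_linear_q = nn.Linear(rank, dim, bias=False)
        self.w_a_linear_v = nn.Linear(dim, rank, bias=False)
        self.w_b_linear_v = nn.Linear(rank, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(x)  # (..., 3 * dim)
        # Assumed scaling of alpha / rank (the common LoRA convention); the real
        # forward pass may apply alpha differently.
        scale = self.alpha / self.rank
        qkv[..., : self.dim] += scale * self.w_b_linear_q(self.w_a_linear_q(x))
        qkv[..., -self.dim:] += scale * self.w_b_linear_v(self.w_a_linear_v(x))
        return qkv


x = torch.randn(2, 16, 64)              # (batch, tokens, dim)
lora_qkv = TinyLoRAQKV(dim=64, rank=4)  # rank-4 update, default alpha=1
print(lora_qkv(x).shape)                # torch.Size([2, 16, 192])
```

Storing `rank` on the module (as the diff now does) keeps it available for this kind of scaling and for later inspection, instead of only existing implicitly in the shapes of the `w_a`/`w_b` layers.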