bug fix
marcoyang1998 committed Mar 11, 2024
1 parent 3492d94 commit 6806810
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -1755,8 +1755,6 @@ def __init__(
         initial_scale: float = 1.0,
     ):
         super().__init__()
-        # create a temporary module of nn.Linear that we'll steal the
-        # weights and bias from
         self.l = ScaledLinear_lora(
             in_features=in_channels,
             out_features=out_channels,
@@ -1767,17 +1765,16 @@ def __init__(
             bias=bias,
         )
 
-        self.activation = activation
+        if activation == "SwooshL":
+            self.activation = SwooshL()
+        elif activation == "SwooshR":
+            self.activation = SwooshR()
+        else:
+            assert False, activation
         self.dropout = Dropout3(dropout_p, dropout_shared_dim)
 
     def forward(self, x: Tensor):
-        if self.activation == "SwooshL":
-            x = SwooshLForward(x)
-        elif self.activation == "SwooshR":
-            x = SwooshRForward(x)
-        else:
-            assert False, self.activation
-        return self.dropout(self.l(x))
+        return self.l(self.dropout(self.activation(x)))
 
 
 def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
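For reference, below is a minimal, self-contained sketch of what the module computes after this fix. The class name and the use of nn.Linear, nn.SiLU, and nn.Dropout are stand-ins I am assuming for ScaledLinear_lora, SwooshL/SwooshR, and Dropout3, which are defined elsewhere in scaling.py; the real constructor also takes arguments (e.g. initial_scale, dropout_shared_dim) not modelled here. The essence of the change: the activation is resolved to a module once in __init__ rather than by string comparison on every call, and forward() now applies activation, then dropout, then the linear layer, instead of applying dropout after the linear.

# Hypothetical sketch of the module's behaviour after this commit.
# Stand-ins: nn.Linear for ScaledLinear_lora, nn.SiLU for SwooshL/SwooshR,
# nn.Dropout for Dropout3. Not the actual icefall implementation.
import torch
from torch import Tensor, nn


class ActivationDropoutAndLinearSketch(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: str = "SwooshL",
        dropout_p: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        # Stand-in for ScaledLinear_lora(...).
        self.l = nn.Linear(in_channels, out_channels, bias=bias)
        # After the fix, the activation is chosen once here, as a module.
        if activation == "SwooshL":
            self.activation: nn.Module = nn.SiLU()  # stand-in for SwooshL()
        elif activation == "SwooshR":
            self.activation = nn.SiLU()  # stand-in for SwooshR()
        else:
            assert False, activation
        # Stand-in for Dropout3(dropout_p, dropout_shared_dim).
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x: Tensor) -> Tensor:
        # Fixed order: activation -> dropout -> linear
        # (the old code applied dropout after the linear layer instead).
        return self.l(self.dropout(self.activation(x)))


if __name__ == "__main__":
    m = ActivationDropoutAndLinearSketch(8, 4, activation="SwooshR", dropout_p=0.1)
    y = m(torch.randn(2, 8))
    print(y.shape)  # torch.Size([2, 4])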
