diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index fdfca7f6c7..8971c05989 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -1755,8 +1755,6 @@ def __init__( initial_scale: float = 1.0, ): super().__init__() - # create a temporary module of nn.Linear that we'll steal the - # weights and bias from self.l = ScaledLinear_lora( in_features=in_channels, out_features=out_channels, @@ -1767,17 +1765,16 @@ def __init__( bias=bias, ) - self.activation = activation + if activation == "SwooshL": + self.activation = SwooshL() + elif activation == "SwooshR": + self.activation = SwooshR() + else: + assert False, activation self.dropout = Dropout3(dropout_p, dropout_shared_dim) def forward(self, x: Tensor): - if self.activation == "SwooshL": - x = SwooshLForward(x) - elif self.activation == "SwooshR": - x = SwooshRForward(x) - else: - assert False, self.activation - return self.dropout(self.l(x)) + return self.l(self.dropout(self.activation(x))) def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: