diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 7e6b62a83..0f745cee4 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -36,7 +36,7 @@ def __init__( ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) - self.gamma = ( + self.weight = ( nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None @@ -55,8 +55,8 @@ def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: del y y = self.pwconv2(x) del x - if self.gamma is not None: - y *= self.gamma + if self.weight is not None: + y *= self.weight y.transpose_(1, 2) # (B, T, C) -> (B, C, T) x = y + residual