From cbbaa234ceb37ccfdb677dcaa385bc0d86aa41d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90?= <110042431+zly-idleness@users.noreply.github.com> Date: Thu, 29 Aug 2024 18:40:41 +0800 Subject: [PATCH] Update dvae.py change gamma to weight --- ChatTTS/model/dvae.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 5602071eb..1bf18e258 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -36,7 +36,7 @@ def __init__( ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) - self.gamma = ( + self.weight = ( nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None @@ -55,8 +55,8 @@ def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: del y y = self.pwconv2(x) del x - if self.gamma is not None: - y *= self.gamma + if self.weight is not None: + y *= self.weight y.transpose_(1, 2) # (B, T, C) -> (B, C, T) x = y + residual