From 9c917a0366607b5e80e5182935f6fd47d989f881 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 15 Nov 2024 15:01:57 +0800 Subject: [PATCH] Add vocos vocoder --- egs/ljspeech/TTS/vocos/backbone.py | 127 +++ egs/ljspeech/TTS/vocos/discriminators.py | 296 +++++++ egs/ljspeech/TTS/vocos/heads.py | 178 ++++ egs/ljspeech/TTS/vocos/loss.py | 133 +++ egs/ljspeech/TTS/vocos/models.py | 51 ++ egs/ljspeech/TTS/vocos/modules.py | 213 +++++ egs/ljspeech/TTS/vocos/spectral_ops.py | 230 +++++ egs/ljspeech/TTS/vocos/train.py | 1003 ++++++++++++++++++++++ egs/ljspeech/TTS/vocos/tts_datamodule.py | 372 ++++++++ egs/ljspeech/TTS/vocos/utils.py | 205 +++++ 10 files changed, 2808 insertions(+) create mode 100644 egs/ljspeech/TTS/vocos/backbone.py create mode 100644 egs/ljspeech/TTS/vocos/discriminators.py create mode 100644 egs/ljspeech/TTS/vocos/heads.py create mode 100644 egs/ljspeech/TTS/vocos/loss.py create mode 100644 egs/ljspeech/TTS/vocos/models.py create mode 100644 egs/ljspeech/TTS/vocos/modules.py create mode 100644 egs/ljspeech/TTS/vocos/spectral_ops.py create mode 100755 egs/ljspeech/TTS/vocos/train.py create mode 100644 egs/ljspeech/TTS/vocos/tts_datamodule.py create mode 100644 egs/ljspeech/TTS/vocos/utils.py diff --git a/egs/ljspeech/TTS/vocos/backbone.py b/egs/ljspeech/TTS/vocos/backbone.py new file mode 100644 index 0000000000..168c8847b1 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/backbone.py @@ -0,0 +1,127 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. 
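+
+    Example (a minimal shape sketch, assuming the 80-dim mel features and
+    dim=512 defaults used in this recipe; output is (B, L, dim)):
+
+        >>> backbone = VocosBackbone(
+        ...     input_channels=80, dim=512, intermediate_dim=1536, num_layers=8
+        ... )
+        >>> mel = torch.randn(2, 80, 100)   # (B, C, L)
+        >>> backbone(mel).shape
+        torch.Size([2, 100, 512])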
+ """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + bandwidth_id = kwargs.get("bandwidth_id", None) + x = self.embed(x) + if self.adanorm: + assert bandwidth_id is not None + x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, cond_embedding_id=bandwidth_id) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. + """ + + def __init__( + self, + input_channels, + dim, + num_blocks, + layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm( + nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) + ) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ + ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) + for _ in range(num_blocks) + ] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x diff --git a/egs/ljspeech/TTS/vocos/discriminators.py b/egs/ljspeech/TTS/vocos/discriminators.py new file mode 100644 index 0000000000..6b013e392a --- /dev/null +++ b/egs/ljspeech/TTS/vocos/discriminators.py @@ -0,0 +1,296 @@ +from typing import List, Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torch.nn import Conv2d +from torch.nn.utils import weight_norm +from torchaudio.transforms import Spectrogram + + +class MultiPeriodDiscriminator(nn.Module): + """ + Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + periods (tuple[int]): Tuple of periods for each discriminator. + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + def __init__( + self, + periods: Tuple[int, ...] 
= (2, 3, 5, 7, 11), + num_embeddings: Optional[int] = None, + ): + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods] + ) + + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor, + bandwidth_id: Optional[torch.Tensor] = None, + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(nn.Module): + def __init__( + self, + period: int, + in_channels: int = 1, + kernel_size: int = 5, + stride: int = 3, + lrelu_slope: float = 0.1, + num_embeddings: Optional[int] = None, + ): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + weight_norm( + Conv2d( + in_channels, + 32, + (kernel_size, 1), + (stride, 1), + padding=(kernel_size // 2, 0), + ) + ), + weight_norm( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(kernel_size // 2, 0), + ) + ), + weight_norm( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(kernel_size // 2, 0), + ) + ), + weight_norm( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(kernel_size // 2, 0), + ) + ), + weight_norm( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + (1, 1), + padding=(kernel_size // 2, 0), + ) + ), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=1024 + ) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + self.lrelu_slope = lrelu_slope + + def forward( + self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x = x.unsqueeze(1) + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = torch.nn.functional.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for i, l in enumerate(self.convs): + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + if i > 0: + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + fft_sizes: Tuple[int, ...] = (2048, 1024, 512), + num_embeddings: Optional[int] = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. 
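+
+        Example (illustrative sketch of the list structure returned by
+        forward with the default three FFT sizes; tensors are random):
+
+            >>> mrd = MultiResolutionDiscriminator()
+            >>> y, y_hat = torch.randn(2, 16384), torch.randn(2, 16384)
+            >>> scores_r, scores_g, fmaps_r, fmaps_g = mrd(y, y_hat)
+            >>> len(scores_r), len(fmaps_r)   # one entry per resolution
+            (3, 3)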
+ """ + + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorR(window_length=w, num_embeddings=num_embeddings) + for w in fft_sizes + ] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + window_length: int, + num_embeddings: Optional[int] = None, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] = ( + (0.0, 0.1), + (0.1, 0.25), + (0.25, 0.5), + (0.5, 0.75), + (0.75, 1.0), + ), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, + hop_length=int(window_length * hop_factor), + win_length=window_length, + power=None, + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) + ), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) + ), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) + ), + weight_norm( + nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1)) + ), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + if num_embeddings is not None: + self.emb = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=channels + ) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm( + nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)) + ) + + def spectrogram(self, x): + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + x = rearrange(x, "b f t c -> b c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): + x_bands = self.spectrogram(x) + fmap = [] + x = [] + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + x = torch.cat(x, dim=-1) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + + return x, fmap diff --git a/egs/ljspeech/TTS/vocos/heads.py b/egs/ljspeech/TTS/vocos/heads.py new file mode 100644 index 0000000000..ed4d623a8d --- /dev/null +++ b/egs/ljspeech/TTS/vocos/heads.py @@ -0,0 +1,178 @@ +from typing import Optional + +import torch +from torch import nn +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz + +from spectral_ops import IMDCT, ISTFT +from modules import symexp + 
+ +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT( + n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip( + mag, max=1e2 + ) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + S = mag * (x + 1j * y) + audio = self.istft(S) + return audio + + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + sample_rate: Optional[int] = None, + clip_audio: bool = False, + ): + super().__init__() + out_dim = mdct_frame_len // 2 + self.out = nn.Linear(dim, out_dim) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + self.clip_audio = clip_audio + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.out.weight.mul_(scale.view(-1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. 
+ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + x = symexp(x) + x = torch.clip( + x, min=-1e2, max=1e2 + ) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + + return audio + + +class IMDCTCosHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + clip_audio: bool = False, + ): + super().__init__() + self.clip_audio = clip_audio + self.out = nn.Linear(dim, mdct_frame_len) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTCosHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + m, p = x.chunk(2, dim=2) + m = torch.exp(m).clip( + max=1e2 + ) # safeguard to prevent excessively large magnitudes + audio = self.imdct(m * torch.cos(p)) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + return audio diff --git a/egs/ljspeech/TTS/vocos/loss.py b/egs/ljspeech/TTS/vocos/loss.py new file mode 100644 index 0000000000..c89d818349 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/loss.py @@ -0,0 +1,133 @@ +from typing import List, Tuple + +import torch +import torchaudio +from torch import nn + +from modules import safe_log + + +class MelSpecReconstructionLoss(nn.Module): + """ + L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample + """ + + def __init__( + self, + sample_rate: int = 24000, + n_fft: int = 1024, + hop_length: int = 256, + n_mels: int = 100, + ): + super().__init__() + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + center=True, + power=1, + ) + + def forward(self, y_hat, y) -> torch.Tensor: + """ + Args: + y_hat (Tensor): Predicted audio waveform. + y (Tensor): Ground truth audio waveform. + + Returns: + Tensor: L1 loss between the mel-scaled magnitude spectrograms. + """ + mel_hat = safe_log(self.mel_spec(y_hat)) + mel = safe_log(self.mel_spec(y)) + + loss = torch.nn.functional.l1_loss(mel, mel_hat) + + return loss + + +class GeneratorLoss(nn.Module): + """ + Generator Loss module. Calculates the loss for the generator based on discriminator outputs. + """ + + def forward( + self, disc_outputs: List[torch.Tensor] + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + disc_outputs (List[Tensor]): List of discriminator outputs. 
+ + Returns: + Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from + the sub-discriminators + """ + loss = torch.zeros( + 1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype + ) + gen_losses = [] + for dg in disc_outputs: + l = torch.mean(torch.clamp(1 - dg, min=0)) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class DiscriminatorLoss(nn.Module): + """ + Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. + """ + + def forward( + self, + disc_real_outputs: List[torch.Tensor], + disc_generated_outputs: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Args: + disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. + disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. + + Returns: + Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from + the sub-discriminators for real outputs, and a list of + loss values for generated outputs. + """ + loss = torch.zeros( + 1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype + ) + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(torch.clamp(1 - dr, min=0)) + g_loss = torch.mean(torch.clamp(1 + dg, min=0)) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +class FeatureMatchingLoss(nn.Module): + """ + Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. + """ + + def forward( + self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]] + ) -> torch.Tensor: + """ + Args: + fmap_r (List[List[Tensor]]): List of feature maps from real samples. + fmap_g (List[List[Tensor]]): List of feature maps from generated samples. + + Returns: + Tensor: The calculated feature matching loss. 
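+
+        Example (toy tensors only; illustrates the expected nesting of one
+        inner list of feature maps per sub-discriminator):
+
+            >>> fm_loss = FeatureMatchingLoss()
+            >>> fmap_r = [[torch.randn(2, 32, 10)], [torch.randn(2, 32, 5)]]
+            >>> fmap_g = [[torch.randn(2, 32, 10)], [torch.randn(2, 32, 5)]]
+            >>> fm_loss(fmap_r, fmap_g).shape
+            torch.Size([1])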
+ """ + loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype) + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss diff --git a/egs/ljspeech/TTS/vocos/models.py b/egs/ljspeech/TTS/vocos/models.py new file mode 100644 index 0000000000..5dbadbad85 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/models.py @@ -0,0 +1,51 @@ +import logging +import torch +from backbone import VocosBackbone +from heads import ISTFTHead +from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator +from loss import ( + DiscriminatorLoss, + GeneratorLoss, + FeatureMatchingLoss, + MelSpecReconstructionLoss, +) + + +class Vocos(torch.nn.Module): + def __init__( + self, + dim: int = 512, + n_fft: int = 1024, + hop_length: int = 256, + feature_dim: int = 80, + intermediate_dim: int = 1536, + num_layers: int = 8, + padding: str = "same", + sample_rate: int = 22050, + ): + super(Vocos, self).__init__() + self.backbone = VocosBackbone( + input_channels=feature_dim, + dim=dim, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + ) + self.head = ISTFTHead( + dim=dim, + n_fft=n_fft, + hop_length=hop_length, + padding=padding, + ) + + self.mpd = MultiPeriodDiscriminator() + self.mrd = MultiResolutionDiscriminator() + + self.disc_loss = DiscriminatorLoss() + self.gen_loss = GeneratorLoss() + self.feat_matching_loss = FeatureMatchingLoss() + self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate) + + def forward(self, features: torch.Tensor): + x = self.backbone(features) + audio_output = self.head(x) + return audio_output diff --git a/egs/ljspeech/TTS/vocos/modules.py b/egs/ljspeech/TTS/vocos/modules.py new file mode 100644 index 0000000000..af1d6db16e --- /dev/null +++ b/egs/ljspeech/TTS/vocos/modules.py @@ -0,0 +1,213 @@ +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn.utils import weight_norm, remove_weight_norm + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. 
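+
+    Example (shape sketch; the block is shape-preserving on (B, C, T) input):
+
+        >>> block = ConvNeXtBlock(dim=512, intermediate_dim=1536,
+        ...                       layer_scale_init_value=1 / 8)
+        >>> x = torch.randn(2, 512, 100)
+        >>> block(x).shape
+        torch.Size([2, 512, 100])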
+ """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) diff --git a/egs/ljspeech/TTS/vocos/spectral_ops.py b/egs/ljspeech/TTS/vocos/spectral_ops.py new file mode 100644 index 0000000000..c0ad35ab31 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/spectral_ops.py @@ -0,0 +1,230 @@ +import numpy as np +import scipy +import torch +from torch import nn, view_as_real, view_as_complex + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. 
+ The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__( + self, n_fft: int, hop_length: int, win_length: int, padding: str = "same" + ): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft( + spec, + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + ) + elif self.padding == "same": + # return torch.istft( + # spec, + # self.n_fft, + # self.hop_length, + # self.win_length, + # self.window, + # center=False, + # ) + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + )[:, 0, 0, :] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + ).squeeze() + + # Normalize + norm_indexes = window_envelope > 1e-11 + + y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes] + # assert (window_envelope > 1e-11).all() + # y = y / window_envelope + + return y + + +class MDCT(nn.Module): + """ + Modified Discrete Cosine Transform (MDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
+ """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) + post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) + # view_as_real: NCCL Backend does not support ComplexFloat data type + # https://github.com/pytorch/pytorch/issues/71613 + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + """ + Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. + + Args: + audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size + and T is the length of the audio. + + Returns: + Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames + and N is the number of frequency bins. + """ + if self.padding == "center": + audio = torch.nn.functional.pad( + audio, (self.frame_len // 2, self.frame_len // 2) + ) + elif self.padding == "same": + # hop_length is 1/2 frame_len + audio = torch.nn.functional.pad( + audio, (self.frame_len // 4, self.frame_len // 4) + ) + else: + raise ValueError("Padding must be 'center' or 'same'.") + + x = audio.unfold(-1, self.frame_len, self.frame_len // 2) + N = self.frame_len // 2 + x = x * self.window.expand(x.shape) + X = torch.fft.fft( + x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1 + )[..., :N] + res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) + return torch.real(res) * np.sqrt(2) + + +class IMDCT(nn.Module): + """ + Inverse Modified Discrete Cosine Transform (IMDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. 
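+
+        Example (shape sketch only; with "same" padding and an input length
+        divisible by frame_len // 2, the MDCT -> IMDCT round trip preserves
+        the length):
+
+            >>> mdct, imdct = MDCT(frame_len=512), IMDCT(frame_len=512)
+            >>> coeffs = mdct(torch.randn(2, 16384))   # (B, L, N) = (2, 64, 256)
+            >>> imdct(coeffs).shape
+            torch.Size([2, 16384])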
+ """ + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft( + Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1 + ) + y = ( + torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) + * np.sqrt(N) + * np.sqrt(2) + ) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio diff --git a/egs/ljspeech/TTS/vocos/train.py b/egs/ljspeech/TTS/vocos/train.py new file mode 100755 index 0000000000..e2092096e4 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/train.py @@ -0,0 +1,1003 @@ +#!/usr/bin/env python3 +# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union +import itertools +import json +import copy +import math +import os +import random +import time + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule + +from torch.optim.lr_scheduler import ExponentialLR, LRScheduler +from torch.optim import Optimizer + +from utils import ( + load_checkpoint, + save_checkpoint, + plot_spectrogram, + get_cosine_schedule_with_warmup, +) + +from icefall import diagnostics +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + get_parameter_groups_with_lrs, +) +from models import Vocos +from lhotse import Fbank, FbankConfig + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-layers", + type=int, + default=8, + help="Number of ConvNeXt layers.", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=512, + help="Hidden dim of ConvNeXt module.", + ) + + parser.add_argument( + "--intermediate-dim", + type=int, + default=1536, + help="Intermediate dim of ConvNeXt module.", + ) + + +def get_parser(): + parser = 
argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=100, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vocos/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--learning-rate", type=float, default=0.0005, help="The learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=500, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--mrd-loss-scale", + type=float, + default=0.1, + help="The scale of MultiResolutionDiscriminator loss.", + ) + + parser.add_argument( + "--mel-loss-scale", + type=float, + default=45, + help="The scale of melspectrogram loss.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 500, + "feature_dim": 80, + "segment_size": 16384, + "adam_b1": 0.8, + "adam_b2": 0.9, + "warmup_steps": 0, + "max_steps": 2000000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_model(params: AttributeDict) -> nn.Module: + device = params.device + model = Vocos( + feature_dim=params.feature_dim, + dim=params.hidden_dim, + n_fft=params.frame_length, + hop_length=params.frame_shift, + intermediate_dim=params.intermediate_dim, + num_layers=params.num_layers, + sample_rate=params.sampling_rate, + ).to(device) + + num_param_head = sum([p.numel() for p in model.head.parameters()]) + logging.info(f"Number of Head parameters : {num_param_head}") + num_param_bone = sum([p.numel() for p in model.backbone.parameters()]) + logging.info(f"Number of Generator parameters : {num_param_bone}") + num_param_mpd = sum([p.numel() for p in model.mpd.parameters()]) + logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}") + num_param_mrd = sum([p.numel() for p in model.mrd.parameters()]) + logging.info(f"Number of MultiResolutionDiscriminator parameters : {num_param_mrd}") + logging.info( + f"Number of model parameters : {num_param_head + num_param_bone + num_param_mpd + num_param_mrd}" + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, +) -> Optional[Dict[str, Any]]: + 
"""Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def compute_generator_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + audios: Tensor, +) -> Tuple[Tensor, MetricsTracker]: + device = params.device + model = model.module if isinstance(model, DDP) else model + + audios_hat = model(features) # (B, T) + + mel_loss = model.melspec_loss(audios_hat, audios) + + _, gen_score_mpd, fmap_rs_mpd, fmap_gs_mpd = model.mpd(y=audios, y_hat=audios_hat) + _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = model.mrd(y=audios, y_hat=audios_hat) + + loss_gen_mpd, list_loss_gen_mpd = model.gen_loss(disc_outputs=gen_score_mpd) + loss_gen_mrd, list_loss_gen_mrd = model.gen_loss(disc_outputs=gen_score_mrd) + + loss_gen_mpd = loss_gen_mpd / len(list_loss_gen_mpd) + loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) + + loss_fm_mpd = model.feat_matching_loss( + fmap_r=fmap_rs_mpd, fmap_g=fmap_gs_mpd + ) / len(fmap_rs_mpd) + loss_fm_mrd = model.feat_matching_loss( + fmap_r=fmap_gs_mrd, fmap_g=fmap_gs_mrd + ) / len(fmap_rs_mrd) + + loss_gen_all = ( + loss_gen_mpd + + params.mrd_loss_scale * loss_gen_mrd + + loss_fm_mpd + + params.mrd_loss_scale * loss_fm_mrd + + params.mel_loss_scale * mel_loss + ) + + assert loss_gen_all.requires_grad == True + + info = MetricsTracker() + info["frames"] = 1 + info["loss_gen"] = loss_gen_all.detach().cpu().item() + info["loss_mel"] = mel_loss.detach().cpu().item() + info["loss_feature_mpd"] = loss_fm_mpd.detach().cpu().item() + info["loss_feature_mrd"] = loss_fm_mrd.detach().cpu().item() + info["loss_gen_mrd"] = loss_gen_mrd.detach().cpu().item() + info["loss_gen_mpd"] = loss_gen_mpd.detach().cpu().item() + + return loss_gen_all, info + + +def compute_discriminator_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + audios: Tensor, +) -> Tuple[Tensor, MetricsTracker]: + device = params.device + model = model.module if isinstance(model, DDP) else model + + with torch.no_grad(): + audios_hat = model(features) # (B, 1, T) + + 
real_score_mpd, gen_score_mpd, _, _ = model.mpd(y=audios, y_hat=audios_hat) + real_score_mrd, gen_score_mrd, _, _ = model.mrd(y=audios, y_hat=audios_hat) + loss_mpd, loss_mpd_real, loss_mpd_gen = model.disc_loss( + disc_real_outputs=real_score_mpd, disc_generated_outputs=gen_score_mpd + ) + loss_mrd, loss_mrd_real, loss_mrd_gen = model.disc_loss( + disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd + ) + loss_mpd /= len(loss_mpd_real) + loss_mrd /= len(loss_mrd_real) + + loss_disc_all = loss_mpd + params.mrd_loss_scale * loss_mrd + + info = MetricsTracker() + # MetricsTracker will norm the loss value with "frames", set it to 1 here to + # make tot_loss look normal. + info["frames"] = 1 + info["loss_disc"] = loss_disc_all.detach().cpu().item() + info["loss_disc_mrd"] = loss_mrd.detach().cpu().item() + info["loss_disc_mpd"] = loss_mpd.detach().cpu().item() + + for i in range(len(loss_mpd_real)): + info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i] + for i in range(len(loss_mrd_real)): + info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i] + + return loss_disc_all, info + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: ExponentialLR, + scheduler_d: ExponentialLR, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer. + scheduler: + The learning rate scheduler, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
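+
+    Note (worked example of the segment cropping below, assuming
+    frame_length=1024 and frame_shift=256, which are supplied by the data
+    module arguments rather than defined in this file):
+
+        segment_size   = 16384 audio samples
+        segment_frames = (16384 - 1024) // 256 + 1 = 61 feature frames
+        audio slice    = 16384 samples starting at start_p * 256, matching
+                         the (61 - 1) * 256 + 1024 = 16384 samples that the
+                         ISTFT head produces for 61 input frames.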
+ """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = batch["features_lens"].size(0) + + features = batch["features"].to(device) # (B, T, F) + features_lens = batch["features_lens"].to(device) + audios = batch["audio"].to(device) + + segment_frames = ( + params.segment_size - params.frame_length + ) // params.frame_shift + 1 + + # segment_frames = ( + # params.segment_size + params.frame_shift // 2 + # ) // params.frame_shift + + start_p = random.randint(0, features_lens.min() - (segment_frames + 1)) + + features = features[:, start_p : start_p + segment_frames, :].permute( + 0, 2, 1 + ) # (B, F, T) + + audios = audios[ + :, + start_p * params.frame_shift : start_p * params.frame_shift + + params.segment_size, + ] # (B, T) + + try: + optimizer_d.zero_grad() + + loss_disc, loss_disc_info = compute_discriminator_loss( + params=params, + model=model, + features=features, + audios=audios, + ) + + loss_disc.backward() + optimizer_d.step() + + optimizer_g.zero_grad() + loss_gen, loss_gen_info = compute_generator_loss( + params=params, + model=model, + features=features, + audios=audios, + ) + + loss_gen.backward() + optimizer_g.step() + + loss_info = loss_gen_info + loss_disc_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info + + except Exception as e: + logging.info(f"Caught exception : {e}.") + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, " + f"cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_gen", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_disc", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + # if ( + # params.batch_idx_train % params.valid_interval == 0 + # and not params.print_diagnostics + # ): + if True: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + tb_writer=tb_writer, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + scheduler_g.step() + scheduler_d.step() + loss_value = tot_loss["loss_gen"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, + tb_writer: Optional[SummaryWriter] = None, +) -> MetricsTracker: + """Run the validation process.""" + + model.eval() + torch.cuda.empty_cache() + model = model.module if isinstance(model, DDP) else model + device = next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + infer_time = 0 + audio_time = 0 + for batch_idx, batch in enumerate(valid_dl): + features = batch["features"] # (B, T, F) + features_lens = batch["features_lens"] + + audio_time += torch.sum(features_lens) + + x = features.permute(0, 2, 1) # (B, F, T) + y = batch["audio"].to(device) # (B, T) + + start = time.time() + y_g_hat = model(x.to(device)) # (B, T) + infer_time += time.time() - start + + if y_g_hat.size(1) > y.size(1): + y = torch.cat( + [ + y, + torch.zeros( + (y.size(0), y_g_hat.size(1) - y.size(1)), device=device + ), + ], + dim=1, + ) + else: + y = y[:, 0 : y_g_hat.size(1)] + + loss_mel_error = 
model.melspec_loss(y_g_hat, y) + + loss_info = MetricsTracker() + # MetricsTracker will norm the loss value with "frames", set it to 1 here to + # make tot_loss look normal. + loss_info["frames"] = 1 + loss_info["loss_mel_error"] = loss_mel_error.item() + + tot_loss = tot_loss + loss_info + + if batch_idx <= 5 and rank == 0 and tb_writer is not None: + if params.batch_idx_train == params.valid_interval: + tb_writer.add_audio( + "gt/y_{}".format(batch_idx), + y[0], + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "generated/y_hat_{}".format(batch_idx), + y_g_hat[0], + params.batch_idx_train, + params.sampling_rate, + ) + + # features_lens counts feature frames; convert frames to seconds with the + # actual frame shift instead of assuming a 10 ms hop. + logging.info( + f"RTF : {infer_time / (audio_time * params.frame_shift / params.sampling_rate)}" + ) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["loss_mel_error"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + + def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + params.device = device + logging.info(params) + logging.info("About to create model") + + model = get_model(params) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model = model.to(device) + head = model.head + backbone = model.backbone + mrd = model.mrd + mpd = model.mpd + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + itertools.chain(head.parameters(), backbone.parameters()), + params.learning_rate, + betas=[params.adam_b1, params.adam_b2], + ) + optimizer_d = torch.optim.AdamW( + itertools.chain(mrd.parameters(), mpd.parameters()), + params.learning_rate, + betas=[params.adam_b1, params.adam_b2], + ) + + scheduler_g = get_cosine_schedule_with_warmup( + optimizer_g, + num_warmup_steps=params.warmup_steps, + num_training_steps=params.max_steps, + ) + scheduler_d = get_cosine_schedule_with_warmup( + optimizer_d, + num_warmup_steps=params.warmup_steps, + num_training_steps=params.max_steps, + ) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading generator optimizer state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading discriminator optimizer state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading generator scheduler state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in 
checkpoints: + logging.info("Loading discriminator scheduler state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = ljspeech.train_dataloaders(train_cuts) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + if params.batch_idx_train % params.save_every_n == 0: + filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/ljspeech/TTS/vocos/tts_datamodule.py 
b/egs/ljspeech/TTS/vocos/tts_datamodule.py new file mode 100644 index 0000000000..44b8052096 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/tts_datamodule.py @@ -0,0 +1,372 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler " + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--sampling-rate", + type=int, + default=22050, + help="The sampling rate of the LJSpeech dataset.", + ) + + group.add_argument( + "--frame-shift", + type=int, + default=256, + help="Frame shift in samples.", + ) + + group.add_argument( + "--frame-length", + type=int, + default=1024, + help="Frame length in samples.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--use-fft-mag", + type=str2bool, + default=True, + help="Whether to use the FFT magnitude for fbank computation; " + "if false, the power spectrum is used.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
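+ + Example (mirroring how train.py in this patch builds the training loader): + + ljspeech = LJSpeechTtsDataModule(args) + train_cuts = ljspeech.train_cuts() + train_dl = ljspeech.train_dataloaders(train_cuts)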
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + 
return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def train_cuts_finetune(self) -> CutSet: + logging.info("About to get train cuts finetune") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train_finetune.jsonl.gz" + ) + + @lru_cache() + def valid_cuts_finetune(self) -> CutSet: + logging.info("About to get validation cuts finetune") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid_finetune.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/ljspeech/TTS/vocos/utils.py b/egs/ljspeech/TTS/vocos/utils.py new file mode 100644 index 0000000000..c8132e208d --- /dev/null +++ b/egs/ljspeech/TTS/vocos/utils.py @@ -0,0 +1,205 @@ +import glob +import os +import logging +import matplotlib +import math +import torch +import torch.nn as nn +from functools import partial +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from torch.nn.utils import weight_norm +from torch.optim.lr_scheduler import LRScheduler +from torch.optim import Optimizer +from torch.cuda.amp import GradScaler +from lhotse.dataset.sampling.base import CutSampler +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def load_checkpoint( + filename: Path, + model: nn.Module, + model_avg: Optional[nn.Module] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + strict: bool = False, +) -> Dict[str, Any]: + logging.info(f"Loading checkpoint from {filename}") + checkpoint = torch.load(filename, map_location="cpu") + + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + 
model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + + checkpoint.pop("model") + + if model_avg is not None and "model_avg" in checkpoint: + logging.info("Loading averaged model") + model_avg.load_state_dict(checkpoint["model_avg"], strict=strict) + checkpoint.pop("model_avg") + + def load(name, obj): + s = checkpoint.get(name, None) + if obj and s: + obj.load_state_dict(s) + checkpoint.pop(name) + + load("optimizer_g", optimizer_g) + load("optimizer_d", optimizer_d) + load("scheduler_g", scheduler_g) + load("scheduler_d", scheduler_d) + load("grad_scaler", scaler) + load("sampler", sampler) + + return checkpoint + + +def save_checkpoint( + filename: Path, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + params: Optional[Dict[str, Any]] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +) -> None: + """Save training information to a file. + + Args: + filename: + The checkpoint filename. + model: + The model to be saved. We only save its `state_dict()`. + model_avg: + The stored model averaged from the start of training. + params: + User defined parameters, e.g., epoch, loss. + optimizer: + The optimizer to be saved. We only save its `state_dict()`. + scheduler: + The scheduler to be saved. We only save its `state_dict()`. + scalar: + The GradScaler to be saved. We only save its `state_dict()`. + rank: + Used in DDP. We save checkpoint only for the node whose rank is 0. + Returns: + Return None. + """ + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, + "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, + "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, + "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if model_avg is not None: + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + torch.save(checkpoint, filename) + + +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float, + min_lr_rate: float = 0.0, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + factor = factor * (1 - min_lr_rate) + min_lr_rate + return max(0, factor) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it 
increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the default is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch)
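+ + + # A minimal usage sketch (hypothetical values; in this recipe the actual settings + # come from params.learning_rate, params.warmup_steps and params.max_steps in + # train.py): + # + # optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4) + # scheduler = get_cosine_schedule_with_warmup( + # optimizer, num_warmup_steps=1000, num_training_steps=100000 + # ) + # # The LR factor rises linearly from 0 to 1 over the first 1000 scheduler steps, + # # then follows half a cosine from 1 down to 0 at step 100000.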