diff --git a/egs/libritts/TTS/vocos/discriminators.py b/egs/libritts/TTS/vocos/discriminators.py
index 238b974239..6f6a8b1ad2 100644
--- a/egs/libritts/TTS/vocos/discriminators.py
+++ b/egs/libritts/TTS/vocos/discriminators.py
@@ -3,7 +3,7 @@
 import torch
 from torch import nn
 from torch.nn import Conv2d
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from torchaudio.transforms import Spectrogram
diff --git a/egs/libritts/TTS/vocos/export-onnx.py b/egs/libritts/TTS/vocos/export-onnx.py
new file mode 100755
index 0000000000..18e58c1d9d
--- /dev/null
+++ b/egs/libritts/TTS/vocos/export-onnx.py
@@ -0,0 +1,371 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
+
+"""
+This script exports the Vocos vocoder generator from PyTorch to ONNX.
+
+Usage example (adjust the options to match your own training run):
+
+./vocos/export-onnx.py \
+  --exp-dir vocos/exp \
+  --epoch 28 \
+  --avg 15 \
+  --use-averaged-model True \
+  --fp16 True
+
+It will generate the following files inside the given --exp-dir:
+
+  - vocos-epoch-28-avg-15.onnx
+  - vocos-epoch-28-avg-15.fp16.onnx (only when --fp16 is True)
+  - vocos-epoch-28-avg-15.int8.onnx
+
+See ./onnx_pretrained.py for how to use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxconverter_common import float16
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import make_pad_mask, num_tokens, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=24000,
+        help="The sampling rate of the libritts dataset",
+    )
+
+    parser.add_argument(
+        "--frame-shift",
+        type=int,
+        default=256,
+        help="Frame shift.",
+    )
+
+    parser.add_argument(
+        "--frame-length",
+        type=int,
+        default=1024,
+        help="Frame length.",
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="vocos/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to export models in fp16",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = value
+
+    onnx.save(model, filename)
+
+
+def export_model_onnx(
+    model: nn.Module,
+    model_filename: str,
+    opset_version: int = 13,
+) -> None:
+    """Export the Vocos generator to ONNX format.
+
+    The exported model has one input:
+
+      - features: a float32 tensor of shape (N, feature_dim, T)
+
+    and produces one output:
+
+      - audio: a float32 tensor of shape (N, num_samples)
+    """
+    input_tensor = torch.rand((2, 80, 100), dtype=torch.float32)
+
+    torch.onnx.export(
+        model,
+        (input_tensor,),
+        model_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=[
+            "features",
+        ],
+        output_names=["audio"],
+        dynamic_axes={
+            "features": {0: "N", 2: "F"},
+            "audio": {0: "N", 1: "T"},
+        },
+    )
+
+    meta_data = {
+        "model_type": "Vocos",
+        "version": "1",
+        "model_author": "k2-fsa",
+        "comment": "ConvNext Vocos",
+    }
+    logging.info(f"meta_data: {meta_data}")
+
+    add_meta_data(filename=model_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    params.device = device
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.eval()
+    vocos = model.generator
+
+    if params.iter > 0:
+        suffix = f"iter-{params.iter}"
+    else:
+        suffix = f"epoch-{params.epoch}"
+
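# Aside (not part of export-onnx.py): a minimal sketch of how the exported ONNX
# file can be sanity-checked against the PyTorch generator with onnxruntime.
# The "features"/"audio" names come from export_model_onnx() above; the input
# shape and the tolerance below are assumptions.
import onnxruntime as ort
import torch

def check_vocos_onnx(torch_generator, onnx_path: str) -> None:
    """Compare PyTorch and ONNX outputs on the same random mel input."""
    torch_generator.eval()
    features = torch.rand(1, 80, 200)  # (N, feature_dim, T)
    with torch.no_grad():
        expected = torch_generator(features)
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (audio,) = sess.run(["audio"], {"features": features.numpy()})
    assert torch.allclose(expected, torch.from_numpy(audio), atol=1e-3)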
suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting model") + model_filename = params.exp_dir / f"vocos-{suffix}.onnx" + export_model_onnx( + vocos, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported vocos generator to {model_filename}") + + if params.fp16: + logging.info("Generate fp16 models") + + model = onnx.load(model_filename) + model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True) + model_filename_fp16 = params.exp_dir / f"vocos-{suffix}.fp16.onnx" + onnx.save(model_fp16, model_filename_fp16) + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"vocos-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libritts/TTS/vocos/export.py b/egs/libritts/TTS/vocos/export.py new file mode 100755 index 0000000000..f8ec255ee9 --- /dev/null +++ b/egs/libritts/TTS/vocos/export.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for libritts dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + + +./vocos/export.py \ + --exp-dir ./vocos/exp \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. 
+ +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +- non-streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + +- streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import torch +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, +) + +from icefall.utils import str2bool +from utils import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="The sampleing rate of libritts dataset", + ) + + parser.add_argument( + "--frame-shift", + type=int, + default=256, + help="Frame shift.", + ) + + parser.add_argument( + "--frame-length", + type=int, + default=1024, + help="Frame shift.", + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vocos/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + params.device = device + logging.info(f"device: {device}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, 
iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + model = model.generator + + if params.jit is True: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "generator.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libritts/TTS/vocos/generator.py b/egs/libritts/TTS/vocos/generator.py index 6e1dcdc4c7..fe51fa48d6 100644 --- a/egs/libritts/TTS/vocos/generator.py +++ b/egs/libritts/TTS/vocos/generator.py @@ -1,122 +1,154 @@ +import logging +from typing import Optional + +import numpy as np import torch from torch import nn - -from typing import Optional +from torch.autograd import Variable +from torch.nn import functional as F -class AdaLayerNorm(nn.Module): +def window_sumsquare( + window: torch.Tensor, + n_samples: int, + hop_length: int = 256, + win_length: int = 1024, +): """ - Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes - - Args: - num_embeddings (int): Number of embeddings. - embedding_dim (int): Dimension of the embeddings. + Compute the sum-square envelope of a window function at a given hop length. + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + n_samples : int > 0 + The number of expected samples. + hop_length : int > 0 + The number of samples to advance between frames + win_length : + The length of the window function. + Returns + ------- + wss : torch.Tensor, The sum-squared envelope of the window function. 
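    Example
    -------
    A naive reference (not part of this file) that the fold-based
    implementation below reproduces: accumulate the squared window at every
    hop position and zero-pad the result to ``n_samples``::

        env = torch.zeros(n_samples)
        n_frames = (n_samples - win_length) // hop_length + 1
        for i in range(n_frames):
            start = i * hop_length
            env[start : start + win_length] += window ** 2
        # window_sumsquare(window, n_samples, hop_length, win_length)
        # returns the same envelope.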
""" - def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.dim = embedding_dim - self.scale = nn.Embedding( - num_embeddings=num_embeddings, embedding_dim=embedding_dim - ) - self.shift = nn.Embedding( - num_embeddings=num_embeddings, embedding_dim=embedding_dim - ) - torch.nn.init.ones_(self.scale.weight) - torch.nn.init.zeros_(self.shift.weight) - - def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: - scale = self.scale(cond_embedding_id) - shift = self.shift(cond_embedding_id) - x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) - x = x * scale + shift - return x + n_frames = (n_samples - win_length) // hop_length + 1 + output_size = (n_frames - 1) * hop_length + win_length + device = window.device + # Window envelope + window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, win_length), + stride=(1, hop_length), + ).squeeze() + window_envelope = torch.nn.functional.pad( + window_envelope, (0, n_samples - output_size) + ) + return window_envelope -class ISTFT(nn.Module): - """ - Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with - windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. - See issue: https://github.com/pytorch/pytorch/issues/62323 - Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. - Args: - n_fft (int): Size of Fourier transform. - hop_length (int): The distance between neighboring sliding window frames. - win_length (int): The size of window frame and STFT filter. - padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". - """ +class ISTFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" def __init__( - self, n_fft: int, hop_length: int, win_length: int, padding: str = "same" + self, + filter_length: int = 1024, + hop_length: int = 256, + win_length: int = 1024, + padding: str = "none", + window_type: str = "povey", + max_samples: int = 1440000, # 1440000 / 24000 = 60s ): - super().__init__() - if padding not in ["center", "same"]: - raise ValueError("Padding must be 'center' or 'same'.") - self.padding = padding - self.n_fft = n_fft + super(ISTFT, self).__init__() + self.filter_length = filter_length self.hop_length = hop_length self.win_length = win_length - window = torch.hann_window(win_length) - self.register_buffer("window", window) - - def forward(self, spec: torch.Tensor) -> torch.Tensor: - """ - Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. - - Args: - spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, - N is the number of frequency bins, and T is the number of time frames. + self.padding = padding + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) - Returns: - Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
- """ - if self.padding == "center": - # Fallback to pytorch native implementation - return torch.istft( - spec, - self.n_fft, - self.hop_length, - self.win_length, - self.window, - center=True, - ) - elif self.padding == "same": - pad = (self.win_length - self.hop_length) // 2 - else: - raise ValueError("Padding must be 'center' or 'same'.") + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) - assert spec.dim() == 3, "Expected a 3D tensor as input" - B, N, T = spec.shape + assert filter_length >= win_length + # Consistence with lhotse, search "create_frame_window" in https://github.com/lhotse-speech/lhotse + assert window_type in [ + "hanning", + "povey", + ], f"Only 'hanning' and 'povey' windows are supported, given {window_type}." + fft_window = torch.hann_window(win_length, periodic=False) + if window_type == "povey": + fft_window = fft_window.pow(0.85) + + if filter_length > win_length: + pad_size = (filter_length - win_length) // 2 + fft_window = torch.nn.functional.pad(fft_window, (pad_size, pad_size)) + + window_sum = window_sumsquare( + window=fft_window, + n_samples=max_samples, + hop_length=hop_length, + win_length=filter_length, + ) - # Inverse FFT - ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") - ifft = ifft * self.window[None, :, None] + inverse_basis *= fft_window - # Overlap and Add - output_size = (T - 1) * self.hop_length + self.win_length - y = torch.nn.functional.fold( - ifft, - output_size=(1, output_size), - kernel_size=(1, self.win_length), - stride=(1, self.hop_length), - )[:, 0, 0, :] + self.register_buffer("inverse_basis", inverse_basis.float()) + self.register_buffer("fft_window", fft_window) + self.register_buffer("window_sum", window_sum) + self.tiny = torch.finfo(torch.float16).tiny - # Window envelope - window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) - window_envelope = torch.nn.functional.fold( - window_sq, - output_size=(1, output_size), - kernel_size=(1, self.win_length), - stride=(1, self.hop_length), - ).squeeze() + def forward(self, magnitude, phase): + magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + inverse_transform = F.conv_transpose1d( + magnitude_phase, + Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + inverse_transform = inverse_transform.squeeze(1) + + window_sum = self.window_sum + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + if self.window_sum.size(-1) < inverse_transform.size(-1): + logging.warning( + f"The precomputed `window_sumsquare` is too small, recomputing, " + f"from {self.window_sum.size(-1)} to {inverse_transform.size(-1)}" + ) + window_sum = window_sumsquare( + window=self.fft_window, + n_samples=inverse_transform.size(-1), + win_length=self.filter_length, + hop_length=self.hop_length, + ) + window_sum = window_sum[: inverse_transform.size(-1)] + approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze() - # Normalize - norm_indexes = window_envelope > 1e-11 - y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes] + inverse_transform[:, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] - return y + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + assert self.padding in ["none", "same", 
"center"] + if self.padding == "center": + pad_len = self.filter_length // 2 + elif self.padding == "same": + pad_len = (self.filter_length - self.hop_length) // 2 + else: + return inverse_transform + return inverse_transform[:, pad_len:-pad_len] class ConvNeXtBlock(nn.Module): @@ -127,8 +159,6 @@ class ConvNeXtBlock(nn.Module): intermediate_dim (int): Dimensionality of the intermediate layer. layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. Defaults to None. - adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. - None means non-conditional LayerNorm. Defaults to None. """ def __init__( @@ -136,20 +166,14 @@ def __init__( dim: int, intermediate_dim: int, layer_scale_init_value: Optional[float] = None, - adanorm_num_embeddings: Optional[int] = None, ): super().__init__() self.dwconv = nn.Conv1d( dim, dim, kernel_size=7, padding=3, groups=dim ) # depthwise conv - self.adanorm = adanorm_num_embeddings is not None - if adanorm_num_embeddings: - self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) - else: - self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear( - dim, intermediate_dim - ) # pointwise/1x1 convs, implemented with linear layers + self.norm = nn.LayerNorm(dim, eps=1e-6) + # pointwise/1x1 convs, implemented with linear layers + self.pwconv1 = nn.Linear(dim, intermediate_dim) self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) self.gamma = ( @@ -159,16 +183,13 @@ def __init__( ) def forward( - self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + self, + x: torch.Tensor, ) -> torch.Tensor: residual = x x = self.dwconv(x) x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) - if self.adanorm: - assert cond_embedding_id is not None - x = self.norm(x, cond_embedding_id) - else: - x = self.norm(x) + x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) @@ -189,28 +210,22 @@ def __init__( hop_length: int = 256, intermediate_dim: int = 1536, num_layers: int = 8, - padding: str = "same", - layer_scale_init_value: Optional[float] = None, - adanorm_num_embeddings: Optional[int] = None, + padding: str = "none", + max_samples: int = 1440000, # 1440000 / 24000 = 60s ): super(Generator, self).__init__() self.feature_dim = feature_dim self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3) - self.adanorm = adanorm_num_embeddings is not None - if adanorm_num_embeddings: - self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) - else: - self.norm = nn.LayerNorm(dim, eps=1e-6) + self.norm = nn.LayerNorm(dim, eps=1e-6) - layer_scale_init_value = layer_scale_init_value or 1 / num_layers + layer_scale_init_value = 1 / num_layers self.convnext = nn.ModuleList( [ ConvNeXtBlock( dim=dim, intermediate_dim=intermediate_dim, layer_scale_init_value=layer_scale_init_value, - adanorm_num_embeddings=adanorm_num_embeddings, ) for _ in range(num_layers) ] @@ -221,7 +236,11 @@ def __init__( self.out_proj = torch.nn.Linear(dim, n_fft + 2) self.istft = ISTFT( - n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding + filter_length=n_fft, + hop_length=hop_length, + win_length=n_fft, + padding=padding, + max_samples=max_samples, ) def _init_weights(self, m): @@ -229,29 +248,17 @@ def _init_weights(self, m): nn.init.trunc_normal_(m.weight, std=0.02) nn.init.constant_(m.bias, 0) - def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - bandwidth_id = kwargs.get("bandwidth_id", None) + def forward(self, x: torch.Tensor) -> 
torch.Tensor: x = self.embed(x) - if self.adanorm: - assert bandwidth_id is not None - x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) - else: - x = self.norm(x.transpose(1, 2)) - + x = self.norm(x.transpose(1, 2)) x = x.transpose(1, 2) for conv_block in self.convnext: - x = conv_block(x, cond_embedding_id=bandwidth_id) - + x = conv_block(x) x = self.final_layer_norm(x.transpose(1, 2)) - x = self.out_proj(x).transpose(1, 2) - mag, p = x.chunk(2, dim=1) + mag, phase = x.chunk(2, dim=1) mag = torch.exp(mag) - mag = torch.clip( - mag, max=1e2 - ) # safeguard to prevent excessively large magnitudes - x = torch.cos(p) - y = torch.sin(p) - S = mag * (x + 1j * y) - audio = self.istft(S) + # safeguard to prevent excessively large magnitudes + mag = torch.clip(mag, max=1e2) + audio = self.istft(mag, phase) return audio diff --git a/egs/libritts/TTS/vocos/infer.py b/egs/libritts/TTS/vocos/infer.py old mode 100644 new mode 100755 index 70e0aa6f00..60bdcdcb56 --- a/egs/libritts/TTS/vocos/infer.py +++ b/egs/libritts/TTS/vocos/infer.py @@ -20,6 +20,7 @@ import json import logging import math +import time import os from functools import partial from pathlib import Path @@ -29,7 +30,7 @@ from lhotse.utils import fix_random_seed from scipy.io.wavfile import write from train import add_model_arguments, get_model, get_params -from tts_datamodule import LJSpeechTtsDataModule +from tts_datamodule import LibriTTSDataModule from icefall.checkpoint import ( average_checkpoints, @@ -89,7 +90,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="flow_match/exp", + default="vocos/exp", help="The experiment dir", ) @@ -128,22 +129,31 @@ def decode_one_batch( cut_ids = [cut.id for cut in batch["cut"]] + infer_time = 0 + audio_time = 0 + features = batch["features"] # (B, T, F) utt_durations = batch["features_lens"] x = features.permute(0, 2, 1) # (B, F, T) + audio_time += torch.sum(utt_durations) + + start = time.time() + audios = model(x.to(device)) # (B, T) + infer_time += time.time() - start + wav_dir = f"{params.res_dir}/{params.suffix}" os.makedirs(wav_dir, exist_ok=True) for i in range(audios.shape[0]): - audio = audios[i][ - : int(utt_durations[i] * params.frame_shift_ms / 1000 * 22050) - ] + audio = audios[i][: int(utt_durations[i] * 256)] audio = audio.cpu().squeeze().numpy() - write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio) + write(f"{wav_dir}/{cut_ids[i]}.wav", 24000, audio) + + print(f"RTF : {infer_time / (audio_time * (256/24000))}") def decode_dataset( @@ -173,7 +183,7 @@ def decode_dataset( with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f: for batch_idx, batch in enumerate(dl): - texts = batch["text"] + # texts = batch["text"] cut_ids = [cut.id for cut in batch["cut"]] decode_one_batch( @@ -182,12 +192,12 @@ def decode_dataset( batch=batch, ) - assert len(texts) == len(cut_ids), (len(texts), len(cut_ids)) + # assert len(texts) == len(cut_ids), (len(texts), len(cut_ids)) - for i in range(len(texts)): - f.write(f"{cut_ids[i]}\t{texts[i]}\n") + # for i in range(len(texts)): + # f.write(f"{cut_ids[i]}\t{texts[i]}\n") - num_cuts += len(texts) + # num_cuts += len(texts) if batch_idx % 50 == 0: batch_str = f"{batch_idx}/{num_batches}" @@ -200,7 +210,7 @@ def decode_dataset( @torch.no_grad() def main(): parser = get_parser() - LJSpeechTtsDataModule.add_arguments(parser) + LibriTTSDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -318,11 +328,11 @@ def main(): # we need cut ids to display 
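# Aside (not part of infer.py): an illustrative sketch of the generator's
# input/output shapes and of the RTF formula used above. The import path and
# the constructor arguments are assumptions about this recipe's defaults.
import time
import torch
from generator import Generator

gen = Generator(feature_dim=80, dim=512, n_fft=1024, hop_length=256)
gen.eval()
mel = torch.rand(1, 80, 100)  # 100 frames of 80-dim mel features
start = time.time()
with torch.no_grad():
    audio = gen(mel)  # (1, num_samples)
infer_time = time.time() - start
# At a hop of 256 samples and 24 kHz, 100 frames cover about
# 100 * 256 / 24000 ≈ 1.07 s of audio, so the real-time factor is:
rtf = infer_time / (100 * 256 / 24000)
print(audio.shape, rtf)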
recognition results. args.return_cuts = True - ljspeech = LJSpeechTtsDataModule(args) + libritts = LibriTTSDataModule(args) - test_cuts = ljspeech.test_cuts() + test_cuts = libritts.test_clean_cuts() - test_dl = ljspeech.test_dataloaders(test_cuts) + test_dl = libritts.test_dataloaders(test_cuts) test_sets = ["test"] test_dls = [test_dl] diff --git a/egs/libritts/TTS/vocos/model.py b/egs/libritts/TTS/vocos/model.py index 30c906ef95..3738a1ded0 100644 --- a/egs/libritts/TTS/vocos/model.py +++ b/egs/libritts/TTS/vocos/model.py @@ -19,8 +19,9 @@ def __init__( hop_length: int = 256, intermediate_dim: int = 1536, num_layers: int = 8, - padding: str = "same", + padding: str = "none", sample_rate: int = 24000, + max_seconds: int = 60, ): super(Vocos, self).__init__() self.generator = Generator( @@ -31,6 +32,7 @@ def __init__( num_layers=num_layers, intermediate_dim=intermediate_dim, padding=padding, + max_samples=int(sample_rate * max_seconds), ) self.mpd = MultiPeriodDiscriminator() diff --git a/egs/libritts/TTS/vocos/onnx_pretrained.py b/egs/libritts/TTS/vocos/onnx_pretrained.py new file mode 100755 index 0000000000..68e49c90c5 --- /dev/null +++ b/egs/libritts/TTS/vocos/onnx_pretrained.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. 
Run this file + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import List, Tuple + +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +from lhotse import Fbank, FbankConfig +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="The sampleing rate of libritts dataset", + ) + + parser.add_argument( + "--frame-shift", + type=int, + default=256, + help="Frame shift.", + ) + + parser.add_argument( + "--frame-length", + type=int, + default=1024, + help="Frame shift.", + ) + + parser.add_argument( + "--use-fft-mag", + type=str2bool, + default=True, + help="Whether to use magnitude of fbank, false to use power energy.", + ) + + parser.add_argument( + "--output-dir", + type=str, + default="generated_audios", + help="The generated will be written to.", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.init_model(model_filename) + + def init_model(self, model_filename: str): + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + def run_model( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + ) + return torch.from_numpy(out[0]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + output_dir = Path(args.model_filename).parent / args.output_dir + output_dir.mkdir(exist_ok=True) + args.output_dir = output_dir + logging.info(vars(args)) + + model = OnnxModel(model_filename=args.model_filename) + + config = FbankConfig( + sampling_rate=args.sampling_rate, + frame_length=args.frame_length / args.sampling_rate, # (in second), + frame_shift=args.frame_shift / args.sampling_rate, # (in second) + use_fft_mag=args.use_fft_mag, + ) + fbank = Fbank(config) + + logging.info(f"Reading sound files: {args.sound_files}") + + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=args.sampling_rate + ) + wave_lengths = [w.size(0) for w in waves] + waves = pad_sequence(waves, batch_first=True, padding_value=0) + + logging.info(f"waves : {waves.shape}") + + features = fbank.extract_batch(waves, sampling_rate=args.sampling_rate) + + if features.dim() == 2: + features = features.unsqueeze(0) + + features = features.permute(0, 2, 1) + + logging.info(f"features : {features.shape}") + + logging.info("Generating started") + + # model forward + audios = model.run_model(features) + + for i, filename in enumerate(args.sound_files): + audio = audios[i : i + 1, 0 : wave_lengths[i]] + ofilename = args.output_dir / filename.split("/")[-1] + logging.info(f"Writting audio : {ofilename}") + torchaudio.save(str(ofilename), audio.cpu(), args.sampling_rate) + + logging.info("Generating Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libritts/TTS/vocos/pretrained.py b/egs/libritts/TTS/vocos/pretrained.py new file mode 100755 index 0000000000..1f1c42d183 --- /dev/null +++ b/egs/libritts/TTS/vocos/pretrained.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: +""" + + +import argparse +import logging +import math +from pathlib import Path +from typing import List + +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params +from lhotse import Fbank, FbankConfig + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. 
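# Aside (not part of pretrained.py): the feature-extraction setup shared by
# onnx_pretrained.py above and this script. lhotse's FbankConfig expects
# frame_length/frame_shift in seconds, hence the division by the sampling
# rate; the waveform below is a made-up stand-in.
import torch
from lhotse import Fbank, FbankConfig

sampling_rate, frame_length, frame_shift = 24000, 1024, 256
fbank = Fbank(
    FbankConfig(
        sampling_rate=sampling_rate,
        frame_length=frame_length / sampling_rate,  # 1024 / 24000 ≈ 0.0427 s
        frame_shift=frame_shift / sampling_rate,    # 256 / 24000 ≈ 0.0107 s
        use_fft_mag=True,
    )
)
waves = torch.rand(1, sampling_rate)  # 1 second of audio at 24 kHz
features = fbank.extract_batch(waves, sampling_rate=sampling_rate)
# about 24000 / 256 ≈ 94 frames of 80-dim features
print(features.shape)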
" + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="The sampleing rate of libritts dataset", + ) + + parser.add_argument( + "--frame-shift", + type=int, + default=256, + help="Frame shift.", + ) + + parser.add_argument( + "--frame-length", + type=int, + default=1024, + help="Frame shift.", + ) + + parser.add_argument( + "--use-fft-mag", + type=str2bool, + default=True, + help="Whether to use magnitude of fbank, false to use power energy.", + ) + + parser.add_argument( + "--output-dir", + type=str, + default="generated_audios", + help="The generated will be written to.", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + params.device = device + + output_dir = Path(params.checkpoint).parent / params.output_dir + output_dir.mkdir(exist_ok=True) + params.output_dir = output_dir + + logging.info(f"{params}") + + logging.info("Creating model") + model = get_model(params) + + model = model.generator + + checkpoint = torch.load(params.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + + config = FbankConfig( + sampling_rate=params.sampling_rate, + frame_length=params.frame_length / params.sampling_rate, # (in second), + frame_shift=params.frame_shift / params.sampling_rate, # (in second) + use_fft_mag=params.use_fft_mag, + ) + fbank = Fbank(config) + + logging.info(f"Reading sound files: {params.sound_files}") + + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sampling_rate + ) + wave_lengths = [w.size(0) for w in waves] + waves = pad_sequence(waves, batch_first=True, padding_value=0) + + features = ( + fbank.extract_batch(waves, sampling_rate=params.sampling_rate) + .permute(0, 2, 1) + .to(device) + ) + + logging.info("Generating started") + + # model forward + audios = model(features) + + for i, filename in enumerate(params.sound_files): + audio = audios[i : i + 1, 0 : wave_lengths[i]] + ofilename = params.output_dir / filename.split("/")[-1] + logging.info(f"Writting audio : {ofilename}") + torchaudio.save(str(ofilename), audio.cpu(), params.sampling_rate) + + logging.info("Generating Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s 
[%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libritts/TTS/vocos/train.py b/egs/libritts/TTS/vocos/train.py index c00afecdb7..f36e7ad711 100755 --- a/egs/libritts/TTS/vocos/train.py +++ b/egs/libritts/TTS/vocos/train.py @@ -52,9 +52,11 @@ save_checkpoint, plot_spectrogram, get_cosine_schedule_with_warmup, + save_checkpoint_with_global_batch_idx, ) from icefall import diagnostics +from icefall.checkpoint import remove_checkpoints, update_averaged_model from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks @@ -91,6 +93,20 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Intermediate dim of ConvNeXt module.", ) + parser.add_argument( + "--max-seconds", + type=int, + default=60, + help=""" + The length of the precomputed normalization window sum square + (required by istft). This argument is only for onnx export, it determines + the max length of the audio that be properly normalized. + Note, you can generate audios longer than this value with the exported onnx model, + the part longer than this value will not be normalized yet. + The larger this value is the bigger the exported onnx model will be. + """, + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -203,6 +219,16 @@ def get_parser(): """, ) + parser.add_argument( + "--keep-last-epoch-k", + type=int, + default=50, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `epoch-xxx.pt`. + """, + ) + parser.add_argument( "--average-period", type=int, @@ -290,8 +316,8 @@ def get_params() -> AttributeDict: "valid_interval": 500, "feature_dim": 80, "segment_size": 16384, - "adam_b1": 0.8, - "adam_b2": 0.9, + "adam_b1": 0.9, + "adam_b2": 0.99, "warmup_steps": 0, "max_steps": 2000000, "env_info": get_env_info(), @@ -311,6 +337,7 @@ def get_model(params: AttributeDict) -> nn.Module: intermediate_dim=params.intermediate_dim, num_layers=params.num_layers, sample_rate=params.sampling_rate, + max_seconds=params.max_seconds, ).to(device) num_param_gen = sum([p.numel() for p in model.generator.parameters()]) @@ -479,11 +506,6 @@ def compute_discriminator_loss( info["loss_disc_mrd"] = loss_mrd.detach().cpu().item() info["loss_disc_mpd"] = loss_mpd.detach().cpu().item() - for i in range(len(loss_mpd_real)): - info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i] - for i in range(len(loss_mrd_real)): - info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i] - return loss_disc_all, info @@ -497,6 +519,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -542,6 +565,7 @@ def save_bad_model(suffix: str = ""): save_checkpoint( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, + model_avg=model_avg, params=params, optimizer_g=optimizer_g, optimizer_d=optimizer_d, @@ -588,6 +612,7 @@ def save_bad_model(suffix: str = ""): loss_disc.backward() optimizer_d.step() + scheduler_d.step() optimizer_g.zero_grad() loss_gen, loss_gen_info = compute_generator_loss( @@ -599,6 +624,7 @@ def save_bad_model(suffix: str = ""): loss_gen.backward() optimizer_g.step() + scheduler_g.step() loss_info = loss_gen_info + loss_disc_info tot_loss = (tot_loss * (1 
- 1 / params.reset_interval)) + loss_gen_info @@ -611,6 +637,39 @@ def save_bad_model(suffix: str = ""): if params.print_diagnostics and batch_idx == 5: return + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + if params.batch_idx_train % 100 == 0 and params.use_fp16: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different @@ -641,8 +700,8 @@ def save_bad_model(suffix: str = ""): f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr_g: {cur_lr_g:.2e}, " - f"cur_lr_d: {cur_lr_d:.2e}, " + f"cur_lr_g: {cur_lr_g:.4e}, " + f"cur_lr_d: {cur_lr_d:.4e}, " + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) @@ -685,8 +744,6 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) - scheduler_g.step() - scheduler_d.step() loss_value = tot_loss["loss_gen"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -766,7 +823,7 @@ def compute_validation_loss( params.sampling_rate, ) - logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}") + logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}") if world_size > 1: tot_loss.reduce(device) @@ -811,15 +868,22 @@ def run(rank, world_size, args): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", rank) - logging.info(f"Device: {device}") params.device = device logging.info(params) - logging.info("About to create model") + logging.info("About to create model") model = get_model(params) + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) model = model.to(device) generator = model.generator @@ -915,6 +979,7 @@ def remove_short_and_long_utt(c: Cut): train_one_epoch( params=params, model=model, + model_avg=model_avg, optimizer_g=optimizer_g, optimizer_d=optimizer_d, scheduler_g=scheduler_g, @@ -936,6 +1001,7 @@ def remove_short_and_long_utt(c: Cut): filename=filename, params=params, model=model, + model_avg=model_avg, optimizer_g=optimizer_g, optimizer_d=optimizer_d, scheduler_g=scheduler_g, @@ -945,28 +1011,20 @@ def remove_short_and_long_utt(c: Cut): rank=rank, ) - if params.batch_idx_train % params.save_every_n == 0: - filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - 
scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_epoch_k, + prefix="epoch", + rank=rank, + ) logging.info("Done!") diff --git a/egs/libritts/TTS/vocos/utils.py b/egs/libritts/TTS/vocos/utils.py index c0fb107331..3984f0fd03 100644 --- a/egs/libritts/TTS/vocos/utils.py +++ b/egs/libritts/TTS/vocos/utils.py @@ -34,6 +34,69 @@ def plot_spectrogram(spectrogram): return fig +def save_checkpoint_with_global_batch_idx( + out_dir: Path, + global_batch_idx: int, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + params: Optional[Dict[str, Any]] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +): + """Save training info after processing given number of batches. + + Args: + out_dir: + The directory to save the checkpoint. + global_batch_idx: + The number of batches processed so far from the very start of the + training. The saved checkpoint will have the following filename: + + f'out_dir / checkpoint-{global_batch_idx}.pt' + model: + The neural network model whose `state_dict` will be saved in the + checkpoint. + model_avg: + The stored model averaged from the start of training. + params: + A dict of training configurations to be saved. + optimizer: + The optimizer used in the training. Its `state_dict` will be saved. + scheduler: + The learning rate scheduler used in the training. Its `state_dict` will + be saved. + scaler: + The scaler used for mix precision training. Its `state_dict` will + be saved. + sampler: + The sampler used in the training dataset. + rank: + The rank ID used in DDP training of the current node. Set it to 0 + if DDP is not used. + """ + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + filename = out_dir / f"checkpoint-{global_batch_idx}.pt" + save_checkpoint( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer_g=optimizer_g, + scheduler_g=scheduler_g, + optimizer_d=optimizer_d, + scheduler_d=scheduler_d, + scaler=scaler, + sampler=sampler, + rank=rank, + ) + + def load_checkpoint( filename: Path, model: nn.Module, diff --git a/egs/ljspeech/TTS/local/evaluate_fsd.py b/egs/ljspeech/TTS/local/evaluate_fsd.py new file mode 100644 index 0000000000..f0e94b314a --- /dev/null +++ b/egs/ljspeech/TTS/local/evaluate_fsd.py @@ -0,0 +1,287 @@ +""" +Calculate Frechet Speech Distance betweeen two speech directories. 
+Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py +""" +import argparse +import logging +import os +from multiprocessing.dummy import Pool as ThreadPool + +import librosa +import numpy as np +import soundfile as sf +import torch +from scipy import linalg +from tqdm import tqdm +from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model + +logging.basicConfig(level=logging.INFO) + + +def get_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--real-path", type=str, help="path of the real speech directory" + ) + parser.add_argument( + "--eval-path", type=str, help="path of the evaluated speech directory" + ) + parser.add_argument( + "--model-path", + type=str, + default="model/huggingface/wav2vec2_base", + help="path of the wav2vec 2.0 model directory", + ) + parser.add_argument( + "--real-embds-path", + type=str, + default=None, + help="path of the real embedding directory", + ) + parser.add_argument( + "--eval-embds-path", + type=str, + default=None, + help="path of the evaluated embedding directory", + ) + return parser + + +class FrechetSpeechDistance: + def __init__( + self, + model_path="resources/wav2vec2_base", + pca_dim=128, + speech_load_worker=8, + ): + """ + Initialize FSD + """ + self.sample_rate = 16000 + self.channels = 1 + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + logging.info("[Frechet Speech Distance] Using device: {}".format(self.device)) + self.speech_load_worker = speech_load_worker + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path) + self.model = Wav2Vec2Model.from_pretrained(model_path) + self.model.to(self.device) + self.model.eval() + self.pca_dim = pca_dim + + def load_speech_files(self, dir, dtype="float32"): + def _load_speech_task(fname, sample_rate, channels, dtype="float32"): + if dtype not in ["float64", "float32", "int32", "int16"]: + raise ValueError(f"dtype not supported: {dtype}") + + wav_data, sr = sf.read(fname, dtype=dtype) + # For integer type PCM input, convert to [-1.0, +1.0] + if dtype == "int16": + wav_data = wav_data / 32768.0 + elif dtype == "int32": + wav_data = wav_data / float(2**31) + + # Convert to mono + assert channels in [1, 2], "channels must be 1 or 2" + if len(wav_data.shape) > channels: + wav_data = np.mean(wav_data, axis=1) + + if sr != sample_rate: + wav_data = ( + librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate), + ) + + return wav_data + + task_results = [] + + pool = ThreadPool(self.speech_load_worker) + + logging.info("[Frechet Speech Distance] Loading speech from {}...".format(dir)) + for fname in os.listdir(dir): + res = pool.apply_async( + _load_speech_task, + args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype), + ) + task_results.append(res) + pool.close() + pool.join() + + return [k.get() for k in task_results] + + def get_embeddings(self, x): + """ + Get embeddings + Params: + -- x : a list of np.ndarray speech samples + -- sr : sampling rate. 
+ """ + embd_lst = [] + try: + for speech in tqdm(x): + input_features = self.feature_extractor( + speech, sampling_rate=self.sample_rate, return_tensors="pt" + ).input_values.to(self.device) + with torch.no_grad(): + embd = self.model(input_features).last_hidden_state.mean(1) + + if embd.device != torch.device("cpu"): + embd = embd.cpu() + + if torch.is_tensor(embd): + embd = embd.detach().numpy() + + embd_lst.append(embd) + except Exception as e: + print( + "[Frechet Speech Distance] get_embeddings throw an exception: {}".format( + str(e) + ) + ) + + return np.concatenate(embd_lst, axis=0) + + def calculate_embd_statistics(self, embd_lst): + if isinstance(embd_lst, list): + embd_lst = np.array(embd_lst) + mu = np.mean(embd_lst, axis=0) + sigma = np.cov(embd_lst, rowvar=False) + return mu, sigma + + def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): + """ + Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py + + Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert ( + mu1.shape == mu2.shape + ), "Training and test mean vectors have different lengths" + assert ( + sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates" + ) % eps + logging.info(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm( + (sigma1 + offset).dot(sigma2 + offset).astype(complex) + ) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + def score( + self, + real_path, + eval_path, + real_embds_path=None, + eval_embds_path=None, + dtype="float32", + ): + """ + Computes the Frechet Speech Distance (FSD) between two directories of speech files. + + Parameters: + - real_path (str): Path to the directory containing real speech files. + - eval_path (str): Path to the directory containing evaluation speech files. + - real_embds_path (str, optional): Path to save/load real speech embeddings (e.g., /folder/bkg_embs.npy). If None, embeddings won't be saved. + - eval_embds_path (str, optional): Path to save/load evaluation speech embeddings (e.g., /folder/test_embs.npy). If None, embeddings won't be saved. 
+ - dtype (str, optional): Data type for loading speech. Default is "float32". + + Returns: + - float: The Frechet Speech Distance (FSD) score between the two directories of speech files. + """ + # Load or compute real embeddings + if real_embds_path is not None and os.path.exists(real_embds_path): + logging.info( + f"[Frechet Speech Distance] Loading embeddings from {real_embds_path}..." + ) + embds_real = np.load(real_embds_path) + else: + speech_real = self.load_speech_files(real_path, dtype=dtype) + embds_real = self.get_embeddings(speech_real) + if real_embds_path: + os.makedirs(os.path.dirname(real_embds_path), exist_ok=True) + np.save(real_embds_path, embds_real) + + # Load or compute eval embeddings + if eval_embds_path is not None and os.path.exists(eval_embds_path): + logging.info( + f"[Frechet Speech Distance] Loading embeddings from {eval_embds_path}..." + ) + embds_eval = np.load(eval_embds_path) + else: + speech_eval = self.load_speech_files(eval_path, dtype=dtype) + embds_eval = self.get_embeddings(speech_eval) + if eval_embds_path: + os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True) + np.save(eval_embds_path, embds_eval) + + # Check if embeddings are empty + if len(embds_real) == 0: + logging.info("[Frechet Speech Distance] real set dir is empty, exiting...") + return -10.46 + if len(embds_eval) == 0: + logging.info("[Frechet Speech Distance] eval set dir is empty, exiting...") + return -1 + + # Compute statistics and FSD score + mu_real, sigma_real = self.calculate_embd_statistics(embds_real) + mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval) + + fsd_score = self.calculate_frechet_distance( + mu_real, sigma_real, mu_eval, sigma_eval + ) + + return fsd_score + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + FSD = FrechetSpeechDistance(model_path=args.model_path) + score = FSD.score( + args.real_path, args.eval_path, args.real_embds_path, args.eval_embds_path + ) + logging.info(f"FSD score: {score:.2f}") diff --git a/egs/ljspeech/TTS/local/evaluate_wer_whisper.py b/egs/ljspeech/TTS/local/evaluate_wer_whisper.py new file mode 100644 index 0000000000..baf719bd7c --- /dev/null +++ b/egs/ljspeech/TTS/local/evaluate_wer_whisper.py @@ -0,0 +1,139 @@ +""" +Calculate WER with Whisper model +""" +import argparse +import logging +import os +import re +from pathlib import Path +from typing import List, Tuple + +import librosa +import soundfile as sf +import torch +from num2words import num2words +from tqdm import tqdm +from transformers import pipeline + +from icefall.utils import store_transcripts, write_error_stats + +logging.basicConfig(level=logging.INFO) + + +def get_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--wav-path", type=str, help="path of the speech directory") + parser.add_argument("--decode-path", type=str, help="path of the speech directory") + parser.add_argument( + "--model-path", + type=str, + default="model/huggingface/whisper_medium", + help="path of the huggingface whisper model", + ) + parser.add_argument( + "--transcript-path", + type=str, + default="data/transcript/test.tsv", + help="path of the transcript tsv file", + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="decoding batch size" + ) + parser.add_argument( + "--device", type=str, default="cuda:0", help="decoding device, cuda:0 or cpu" + ) + return parser + + +def post_process(text: str): + def convert_numbers(match): + return num2words(match.group()) + + text = re.sub(r"\b\d{1,2}\b", 
convert_numbers, text) + text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower()) + text = re.sub(r"\s+", " ", text) + return text + + +def save_results( + res_dir: str, + results: List[Tuple[str, List[str], List[str]]], +): + if not os.path.exists(res_dir): + os.makedirs(res_dir) + recog_path = os.path.join(res_dir, "recogs.txt") + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + errs_filename = os.path.join(res_dir, "errs.txt") + with open(errs_filename, "w") as f: + _ = write_error_stats(f, "test", results, enable_log=True) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + +class SpeechEvalDataset(torch.utils.data.Dataset): + def __init__(self, wav_path: str, transcript_path: str): + super().__init__() + self.audio_name = [] + self.audio_paths = [] + self.transcripts = [] + with Path(transcript_path).open("r", encoding="utf8") as f: + meta = [item.split("\t") for item in f.read().rstrip().split("\n")] + for item in meta: + self.audio_name.append(item[0]) + self.audio_paths.append(Path(wav_path, item[0] + ".wav")) + self.transcripts.append(item[1]) + + def __len__(self): + return len(self.audio_paths) + + def __getitem__(self, index: int): + audio, sampling_rate = sf.read(self.audio_paths[index]) + item = { + "array": librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000), + "sampling_rate": 16000, + "reference": self.transcripts[index], + "audio_name": self.audio_name[index], + } + return item + + +def main(args): + + batch_size = args.batch_size + + pipe = pipeline( + "automatic-speech-recognition", + model=args.model_path, + device=args.device, + tokenizer=args.model_path, + ) + + dataset = SpeechEvalDataset(args.wav_path, args.transcript_path) + + results = [] + bar = tqdm( + pipe( + dataset, + generate_kwargs={"language": "english", "task": "transcribe"}, + batch_size=batch_size, + ), + total=len(dataset), + ) + for out in bar: + results.append( + ( + out["audio_name"][0], + post_process(out["reference"][0].strip()).split(), + post_process(out["text"].strip()).split(), + ) + ) + save_results(args.decode_path, results) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/egs/ljspeech/TTS/vocos/export-onnx.py b/egs/ljspeech/TTS/vocos/export-onnx.py new file mode 120000 index 0000000000..47a0dadd7e --- /dev/null +++ b/egs/ljspeech/TTS/vocos/export-onnx.py @@ -0,0 +1 @@ +../../../libritts/TTS/vocos/export-onnx.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vocos/export.py b/egs/ljspeech/TTS/vocos/export.py new file mode 120000 index 0000000000..28b8a06fde --- /dev/null +++ b/egs/ljspeech/TTS/vocos/export.py @@ -0,0 +1 @@ +../../../libritts/TTS/vocos/export.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vocos/infer.py b/egs/ljspeech/TTS/vocos/infer.py new file mode 100755 index 0000000000..4b11b3721d --- /dev/null +++ b/egs/ljspeech/TTS/vocos/infer.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang +# Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import math +import os +from functools import partial +from pathlib import Path + +import torch +import torch.nn as nn +from lhotse.utils import fix_random_seed +from scipy.io.wavfile import write +from train import add_model_arguments, get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=100, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="flow_match/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--generate-dir", + type=str, + default="generated_wavs", + help="Path name of the generated wavs", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +): + """ + Args: + params: + It's the return value of :func:`get_params`. + model: + The text-to-feature neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + + cut_ids = [cut.id for cut in batch["cut"]] + + features = batch["features"] # (B, T, F) + utt_durations = batch["features_lens"] + + x = features.permute(0, 2, 1) # (B, F, T) + + audios = model(x.to(device)) # (B, T) + + wav_dir = f"{params.res_dir}/{params.suffix}" + os.makedirs(wav_dir, exist_ok=True) + + for i in range(audios.shape[0]): + audio = audios[i][: (utt_durations[i] - 1) * 256 + 1024] + audio = audio.cpu().squeeze().numpy() + write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio) + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + test_set: str, +): + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The text-to-feature neural model. + test_set: + The name of the test_set + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f: + for batch_idx, batch in enumerate(dl): + texts = batch["text"] + cut_ids = [cut.id for cut in batch["cut"]] + + decode_one_batch( + params=params, + model=model, + batch=batch, + ) + + assert len(texts) == len(cut_ids), (len(texts), len(cut_ids)) + + for i in range(len(texts)): + f.write(f"{cut_ids[i]}\t{texts[i]}\n") + + num_cuts += len(texts) + + if batch_idx % 50 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + + +@torch.no_grad() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / params.generate_dir + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + params.device = device + + logging.info(f"Device: {device}") + + logging.info(params) + fix_random_seed(666) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 
1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model = model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + + test_dl = ljspeech.test_dataloaders(test_cuts) + + test_sets = ["test"] + test_dls = [test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + decode_dataset( + dl=test_dl, + params=params, + model=model, + test_set=test_set, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vocos/onnx_pretrained.py b/egs/ljspeech/TTS/vocos/onnx_pretrained.py new file mode 120000 index 0000000000..08bf4579cb --- /dev/null +++ b/egs/ljspeech/TTS/vocos/onnx_pretrained.py @@ -0,0 +1 @@ +../../../libritts/TTS/vocos/onnx_pretrained.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vocos/pretrained.py b/egs/ljspeech/TTS/vocos/pretrained.py new file mode 120000 index 0000000000..37802e4c84 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/pretrained.py @@ -0,0 +1 @@ +../../../libritts/TTS/vocos/pretrained.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vocos/train.py b/egs/ljspeech/TTS/vocos/train.py index 51ec024efb..ca61d03054 100755 --- a/egs/ljspeech/TTS/vocos/train.py +++ b/egs/ljspeech/TTS/vocos/train.py @@ -52,9 +52,11 @@ save_checkpoint, plot_spectrogram, get_cosine_schedule_with_warmup, + save_checkpoint_with_global_batch_idx, ) from icefall import diagnostics +from icefall.checkpoint import remove_checkpoints, update_averaged_model from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks @@ -65,7 +67,7 @@ str2bool, get_parameter_groups_with_lrs, ) -from models import Vocos +from model import Vocos from lhotse import Fbank, FbankConfig @@ -91,6 +93,20 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Intermediate dim of ConvNeXt module.", ) + parser.add_argument( + "--max-seconds", + type=int, + default=60, + help=""" + The length of the precomputed normalization window sum square + (required by istft). This argument is only for onnx export, it determines + the max length of the audio that be properly normalized. 
+ Note, you can generate audios longer than this value with the exported onnx model, + the part longer than this value will not be normalized yet. + The larger this value is the bigger the exported onnx model will be. + """, + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -203,6 +219,16 @@ def get_parser(): """, ) + parser.add_argument( + "--keep-last-epoch-k", + type=int, + default=50, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `epoch-xxx.pt`. + """, + ) + parser.add_argument( "--average-period", type=int, @@ -290,8 +316,8 @@ def get_params() -> AttributeDict: "valid_interval": 500, "feature_dim": 80, "segment_size": 16384, - "adam_b1": 0.8, - "adam_b2": 0.9, + "adam_b1": 0.9, + "adam_b2": 0.99, "warmup_steps": 0, "max_steps": 2000000, "env_info": get_env_info(), @@ -311,18 +337,17 @@ def get_model(params: AttributeDict) -> nn.Module: intermediate_dim=params.intermediate_dim, num_layers=params.num_layers, sample_rate=params.sampling_rate, + max_seconds=params.max_seconds, ).to(device) - num_param_head = sum([p.numel() for p in model.head.parameters()]) - logging.info(f"Number of Head parameters : {num_param_head}") - num_param_bone = sum([p.numel() for p in model.backbone.parameters()]) - logging.info(f"Number of Generator parameters : {num_param_bone}") + num_param_gen = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of Generator parameters : {num_param_gen}") num_param_mpd = sum([p.numel() for p in model.mpd.parameters()]) logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}") num_param_mrd = sum([p.numel() for p in model.mrd.parameters()]) logging.info(f"Number of MultiResolutionDiscriminator parameters : {num_param_mrd}") logging.info( - f"Number of model parameters : {num_param_head + num_param_bone + num_param_mpd + num_param_mrd}" + f"Number of model parameters : {num_param_gen + num_param_mpd + num_param_mrd}" ) return model @@ -481,11 +506,6 @@ def compute_discriminator_loss( info["loss_disc_mrd"] = loss_mrd.detach().cpu().item() info["loss_disc_mpd"] = loss_mpd.detach().cpu().item() - for i in range(len(loss_mpd_real)): - info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i] - for i in range(len(loss_mrd_real)): - info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i] - return loss_disc_all, info @@ -499,6 +519,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -544,6 +565,7 @@ def save_bad_model(suffix: str = ""): save_checkpoint( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", model=model, + model_avg=model_avg, params=params, optimizer_g=optimizer_g, optimizer_d=optimizer_d, @@ -566,10 +588,6 @@ def save_bad_model(suffix: str = ""): params.segment_size - params.frame_length ) // params.frame_shift + 1 - # segment_frames = ( - # params.segment_size + params.frame_shift // 2 - # ) // params.frame_shift - start_p = random.randint(0, features_lens.min() - (segment_frames + 1)) features = features[:, start_p : start_p + segment_frames, :].permute( @@ -594,6 +612,7 @@ def save_bad_model(suffix: str = ""): loss_disc.backward() optimizer_d.step() + scheduler_d.step() optimizer_g.zero_grad() loss_gen, loss_gen_info = compute_generator_loss( @@ -605,6 +624,7 @@ def 
save_bad_model(suffix: str = ""):
         loss_gen.backward()
         optimizer_g.step()
+        scheduler_g.step()

         loss_info = loss_gen_info + loss_disc_info
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info
@@ -617,6 +637,39 @@ def save_bad_model(suffix: str = ""):
         if params.print_diagnostics and batch_idx == 5:
             return

+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer_g=optimizer_g, optimizer_d=optimizer_d,
+                scheduler_g=scheduler_g, scheduler_d=scheduler_d,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
         if params.batch_idx_train % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
@@ -647,8 +700,8 @@ def save_bad_model(suffix: str = ""):
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                 f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                 f"loss[{loss_info}], tot_loss[{tot_loss}], "
-                f"cur_lr_g: {cur_lr_g:.2e}, "
-                f"cur_lr_d: {cur_lr_d:.2e}, "
+                f"cur_lr_g: {cur_lr_g:.4e}, "
+                f"cur_lr_d: {cur_lr_d:.4e}, "
                 + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
@@ -668,11 +721,10 @@ def save_bad_model(suffix: str = ""):
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )

-        # if (
-        #     params.batch_idx_train % params.valid_interval == 0
-        #     and not params.print_diagnostics
-        # ):
-        if True:
+        if (
+            params.batch_idx_train % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -692,8 +744,6 @@ def save_bad_model(suffix: str = ""):
                     tb_writer, "train/valid_", params.batch_idx_train
                 )

-    scheduler_g.step()
-    scheduler_d.step()
     loss_value = tot_loss["loss_gen"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -773,7 +823,7 @@ def compute_validation_loss(
             params.sampling_rate,
         )

-    logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}")
+    logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}")

     if world_size > 1:
         tot_loss.reduce(device)
@@ -818,19 +868,25 @@ def run(rank, world_size, args):
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
     params.device = device

     logging.info(params)

-    logging.info("About to create model")
+    logging.info("About to create model")
     model = get_model(params)

+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
     assert params.start_epoch > 0, params.start_epoch
-    checkpoints = load_checkpoint_if_available(params=params, model=model)
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )

     model = model.to(device)
-    head = model.head
-    backbone = model.backbone
+    generator = model.generator
     mrd = model.mrd
     mpd = model.mpd
     if world_size > 1:
@@ -838,7 +894,7 @@ def run(rank, world_size,
args): model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer_g = torch.optim.AdamW( - itertools.chain(head.parameters(), backbone.parameters()), + generator.parameters(), params.learning_rate, betas=[params.adam_b1, params.adam_b2], ) @@ -923,6 +979,7 @@ def remove_short_and_long_utt(c: Cut): train_one_epoch( params=params, model=model, + model_avg=model_avg, optimizer_g=optimizer_g, optimizer_d=optimizer_d, scheduler_g=scheduler_g, @@ -944,6 +1001,7 @@ def remove_short_and_long_utt(c: Cut): filename=filename, params=params, model=model, + model_avg=model_avg, optimizer_g=optimizer_g, optimizer_d=optimizer_d, scheduler_g=scheduler_g, @@ -953,28 +1011,20 @@ def remove_short_and_long_utt(c: Cut): rank=rank, ) - if params.batch_idx_train % params.save_every_n == 0: - filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_epoch_k, + prefix="epoch", + rank=rank, + ) logging.info("Done!") @@ -997,7 +1047,8 @@ def main(): run(rank=0, world_size=1, args=args) +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) main() diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index d31ce13019..e935a5ad73 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -250,18 +250,22 @@ def save_checkpoint_with_global_batch_idx( ) -def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: +def find_checkpoints( + out_dir: Path, + iteration: int = 0, + prefix: str = "checkpoint", +) -> List[str]: """Find all available checkpoints in a directory. - The checkpoint filenames have the form: `checkpoint-xxx.pt` + The checkpoint filenames have the form: `{prefix}-xxx.pt` where xxx is a numerical value. Assume you have the following checkpoints in the folder `foo`: - - checkpoint-1.pt - - checkpoint-20.pt - - checkpoint-300.pt - - checkpoint-4000.pt + - {prefix}-1.pt + - {prefix}-20.pt + - {prefix}-300.pt + - {prefix}-4000.pt Case 1 (Return all checkpoints):: @@ -290,8 +294,8 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: Return a list of checkpoint filenames, sorted in descending order by the numerical value in the filename. 
""" - checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) - pattern = re.compile(r"checkpoint-([0-9]+).pt") + checkpoints = list(glob.glob(f"{out_dir}/{prefix}-[0-9]*.pt")) + pattern = re.compile(rf"{prefix}-([0-9]+).pt") iter_checkpoints = [] for c in checkpoints: result = pattern.search(c) @@ -316,12 +320,13 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: def remove_checkpoints( out_dir: Path, topk: int, + prefix: str = "checkpoint", rank: int = 0, ): """Remove checkpoints from the given directory. - We assume that checkpoint filename has the form `checkpoint-xxx.pt` - where xxx is a number, representing the number of processed batches + We assume that checkpoint filename has the form `{prefix}-xxx.pt` + where xxx is a number, representing the number of processed batches/epochs when saving that checkpoint. We sort checkpoints by filename and keep only the `topk` checkpoints with the highest `xxx`. @@ -330,6 +335,8 @@ def remove_checkpoints( The directory containing checkpoints to be removed. topk: Number of checkpoints to keep. + prefix: + The prefix of the checkpoint filename, normally `epoch`, `checkpoint`. rank: If using DDP for training, it is the rank of the current node. Use 0 if no DDP is used for training. @@ -337,7 +344,7 @@ def remove_checkpoints( assert topk >= 1, topk if rank != 0: return - checkpoints = find_checkpoints(out_dir) + checkpoints = find_checkpoints(out_dir, prefix=prefix) if len(checkpoints) == 0: logging.warn(f"No checkpoints found in {out_dir}")