Use piper_phonemize as text tokenizer in vctk TTS recipe (#1522)
* to align with PR #1524
JinZr authored Mar 18, 2024
1 parent 9b0eae3 commit eec12f0
Showing 9 changed files with 56 additions and 135 deletions.
3 changes: 1 addition & 2 deletions egs/vctk/TTS/README.md
@@ -10,7 +10,7 @@ The above information is from the [CSTR VCTK website](https://datashare.ed.ac.uk

This recipe provides a VITS model trained on the VCTK dataset.

- Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
+ Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2024-03-18), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.

For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/vctk/vits.html).

@@ -21,7 +21,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
- --use-fp16 1 \
--exp-dir vits/exp \
--tokens data/tokens.txt
--max-duration 350
104 changes: 0 additions & 104 deletions egs/vctk/TTS/local/prepare_token_file.py

This file was deleted.

1 change: 1 addition & 0 deletions egs/vctk/TTS/local/prepare_token_file.py (the old 104-line script is replaced by a one-line file, presumably a symlink to a shared implementation)
11 changes: 7 additions & 4 deletions egs/vctk/TTS/local/prepare_tokens_vctk.py
@@ -24,9 +24,9 @@
import logging
from pathlib import Path

- import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
+ from piper_phonemize import phonemize_espeak
from tqdm.auto import tqdm


@@ -37,17 +37,20 @@ def prepare_tokens_vctk():
partition = "all"

cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
- g2p = g2p_en.G2p()

new_cuts = []
for cut in tqdm(cut_set):
# Each cut only contains one supervision
- assert len(cut.supervisions) == 1, len(cut.supervisions)
+ assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
- cut.tokens = g2p(text)
+ tokens_list = phonemize_espeak(text, "en-us")
+ tokens = []
+ for t in tokens_list:
+ tokens.extend(t)
+ cut.tokens = tokens
new_cuts.append(cut)

new_cut_set = CutSet.from_cuts(new_cuts)
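For reference on the change above: `phonemize_espeak` returns a nested list of phoneme strings (roughly one sub-list per sentence-like chunk), which is why the new code flattens it before assigning `cut.tokens`. A minimal sketch of the assumed behaviour, not part of the commit:

```python
# Sketch (assumes piper_phonemize is installed and that phonemize_espeak
# returns List[List[str]], one phoneme list per sentence-like chunk).
from piper_phonemize import phonemize_espeak

text = "Hello there. How are you?"
tokens_list = phonemize_espeak(text, "en-us")  # e.g. [['h', 'ə', ...], ['h', 'a', ...]]

tokens = []
for t in tokens_list:
    tokens.extend(t)  # flatten into one phoneme sequence for the whole utterance

print(tokens)
```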
20 changes: 14 additions & 6 deletions egs/vctk/TTS/prepare.sh
@@ -78,6 +78,13 @@ fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for VCTK"
+ # We assume you have installed piper_phonemize and espnet_tts_frontend.
+ # If not, please install them with:
+ # - piper_phonemize:
+ # refer to https://github.com/rhasspy/piper-phonemize,
+ # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+ # - espnet_tts_frontend:
+ # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
./local/prepare_tokens_vctk.py
mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
@@ -111,14 +118,15 @@ fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
- # We assume you have installed g2p_en and espnet_tts_frontend.
+ # We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
- # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
- # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+ # - piper_phonemize:
+ # refer to https://github.com/rhasspy/piper-phonemize,
+ # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+ # - espnet_tts_frontend:
+ # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
- ./local/prepare_token_file.py \
- --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
- --tokens data/tokens.txt
+ ./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
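For reference: with piper_phonemize the token inventory is a fixed espeak phoneme set rather than whatever symbols occur in the training manifest, which is why Stage 5 no longer passes `--manifest-file`. A rough standalone sketch of how such a token file could be produced, assuming piper_phonemize exposes `get_espeak_map` (the repository's actual `prepare_token_file.py` may differ):

```python
# Hypothetical sketch of writing data/tokens.txt from the fixed espeak
# phoneme inventory; get_espeak_map is an assumed piper_phonemize API
# returning {phoneme: [id, ...]}.
from piper_phonemize import get_espeak_map

def write_token_file(path: str) -> None:
    token2id = {token: ids[0] for token, ids in get_espeak_map().items()}
    with open(path, "w", encoding="utf-8") as f:
        for token, idx in sorted(token2id.items(), key=lambda kv: kv[1]):
            f.write(f"{token} {idx}\n")

write_token_file("data/tokens.txt")
```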

19 changes: 13 additions & 6 deletions egs/vctk/TTS/vits/export-onnx.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
- # Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+ # Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+ # Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -97,7 +98,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
- meta.value = value
+ meta.value = str(value)

onnx.save(model, filename)
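The `str(value)` change above is needed because ONNX `metadata_props` entries are string/string pairs, and the metadata now includes integers such as `n_speakers` and `sample_rate`. A minimal illustration of the constraint (standalone, not from the commit):

```python
# onnx.ModelProto.metadata_props holds StringStringEntryProto pairs,
# so values must be strings before assignment.
import onnx

model = onnx.ModelProto()
meta = model.metadata_props.add()
meta.key = "sample_rate"
meta.value = str(22050)   # fine: stored as the string "22050"
# meta.value = 22050      # would raise a TypeError from protobuf
```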

@@ -160,6 +161,7 @@ def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
+ n_speakers: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@@ -212,10 +214,15 @@
)

meta_data = {
- "model_type": "VITS",
+ "model_type": "vits",
"version": "1",
"model_author": "k2-fsa",
- "comment": "VITS generator",
+ "comment": "icefall", # must be icefall for models from icefall
+ "language": "English",
+ "voice": "en-us", # Choose your language appropriately
+ "has_espeak": 1,
+ "n_speakers": n_speakers,
+ "sample_rate": 22050, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")
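The extra fields (`comment = "icefall"`, `voice`, `has_espeak`, `n_speakers`, `sample_rate`) let a downstream runtime identify the model and configure espeak phonemization from the ONNX file alone. A short sketch of reading them back with onnxruntime (the file name is hypothetical):

```python
# Inspect the metadata written by export-onnx.py; all values come back as strings.
import onnxruntime as ort

session = ort.InferenceSession("vits-vctk.onnx", providers=["CPUExecutionProvider"])
meta = session.get_modelmeta().custom_metadata_map  # Dict[str, str]

print(meta.get("model_type"))                 # "vits"
print(int(meta.get("n_speakers", "0")))       # number of speakers baked into the export
print(int(meta.get("sample_rate", "22050")))  # must match the vocoder's sample rate
```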

@@ -231,8 +238,7 @@ def main():
params.update(vars(args))

tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size

with open(args.speakers) as f:
@@ -265,6 +271,7 @@ def main():
model,
model_filename,
params.vocab_size,
+ params.num_spks,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
9 changes: 5 additions & 4 deletions egs/vctk/TTS/vits/infer.py
@@ -135,14 +135,16 @@ def _save_worker(
batch_size = len(batch["tokens"])

tokens = batch["tokens"]
- tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = tokenizer.tokens_to_token_ids(
+ tokens, intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
- tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
.int()
Expand Down Expand Up @@ -214,8 +216,7 @@ def main():
device = torch.device("cuda", 0)

tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size

# we need cut ids to display recognition results.
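The new `intersperse_blank=True, add_sos=True, add_eos=True` arguments change the id sequences fed to VITS: a blank id is inserted between every pair of phoneme ids and the sequence is wrapped with start/end symbols, roughly doubling its length. A rough sketch of the idea, not the recipe Tokenizer's actual implementation:

```python
# Rough sketch of interspersing a blank id and adding sos/eos; the real
# Tokenizer may use different ids and handle edge cases differently.
from typing import List

def intersperse(ids: List[int], blank_id: int = 0, sos_eos_id: int = 1) -> List[int]:
    out = [blank_id] * (2 * len(ids) + 1)
    out[1::2] = ids                        # blank, t0, blank, t1, ..., blank
    return [sos_eos_id] + out + [sos_eos_id]

print(intersperse([7, 12, 31]))            # [1, 0, 7, 0, 12, 0, 31, 0, 1]
```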
7 changes: 5 additions & 2 deletions egs/vctk/TTS/vits/test_onnx.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
- # Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+ # Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+ # Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -122,7 +123,9 @@ def main():
model = OnnxModel(args.model_filename)

text = "I went there to see the land, the people and how their system works, end quote."
- tokens = tokenizer.texts_to_token_ids([text])
+ tokens = tokenizer.texts_to_token_ids(
+ [text], intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
speaker = torch.tensor([1], dtype=torch.int64) # (1, )
12 changes: 7 additions & 5 deletions egs/vctk/TTS/vits/train.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
- # Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+ # Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+ # Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -342,14 +343,16 @@ def prepare_input(
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
)

- tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = tokenizer.tokens_to_token_ids(
+ tokens, intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
- tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)

return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
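Because the per-utterance id lists have different lengths, `prepare_input` builds a `k2.RaggedTensor`, derives the lengths from its row splits, and pads to a dense `(B, T)` matrix with `pad_id`. A small standalone illustration with made-up ids:

```python
# Standalone illustration of the ragged-to-padded conversion used above.
import k2

token_ids = [[5, 3, 9], [2, 8], [4, 4, 4, 4]]        # three utterances of different lengths
tokens = k2.RaggedTensor(token_ids)

row_splits = tokens.shape.row_splits(1)               # tensor([0, 3, 5, 9])
tokens_lens = row_splits[1:] - row_splits[:-1]        # tensor([3, 2, 4])

pad_id = 0                                             # assumed value of tokenizer.pad_id
padded = tokens.pad(mode="constant", padding_value=pad_id)  # shape (3, 4)
# padded == [[5, 3, 9, 0],
#            [2, 8, 0, 0],
#            [4, 4, 4, 4]]
```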

@@ -812,8 +815,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")

tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size

vctk = VctkTtsDataModule(args)
5 changes: 3 additions & 2 deletions egs/vctk/TTS/vits/tts_datamodule.py
@@ -1,6 +1,7 @@
# Copyright 2021 Piotr Żelasko
- # Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
- # Zengwei Yao)
+ # Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
+ # Zengwei Yao,
+ # Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#