Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Oct 21, 2024
1 parent dc0106a commit 8da9acd
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 55 deletions.
10 changes: 10 additions & 0 deletions egs/libritts/TTS/local/prepare_tokens_libritts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@
from tqdm.auto import tqdm


# Characters kept verbatim (after upper-casing); every other character is
# treated as a word separator.  Hoisted to module level so the set is built
# once instead of on every call.
_KEEP_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'"
)


def remove_punc_to_upper(text: str) -> str:
    """Upper-case ``text``, keeping only letters, digits and apostrophes.

    Curly single quotes are first normalized to the ASCII apostrophe, then
    every character outside the kept set is replaced by a space and runs of
    whitespace collapse to single spaces (no leading/trailing whitespace).

    Args:
        text: Arbitrary transcript text.

    Returns:
        The normalized upper-case string, e.g. ``"It’s a test."`` becomes
        ``"IT'S A TEST"``.
    """
    text = text.replace("‘", "'").replace("’", "'")
    kept = [ch.upper() if ch in _KEEP_CHARS else " " for ch in text]
    # join/split collapses interior runs of spaces and drops edge spaces,
    # so no extra strip() is needed.
    return " ".join("".join(kept).split())

def prepare_tokens_libritts():
output_dir = Path("data/spectrogram")
prefix = "libritts"
Expand Down Expand Up @@ -60,6 +68,8 @@ def prepare_tokens_libritts():
for t in tokens_list:
tokens.extend(t)
cut.tokens = tokens
cut.supervisions[0].normalized_text = remove_punc_to_upper(text)

new_cuts.append(cut)

new_cut_set = CutSet.from_cuts(new_cuts)
Expand Down
71 changes: 42 additions & 29 deletions egs/libritts/TTS/vits/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
Expand Down Expand Up @@ -331,16 +332,22 @@ def prepare_input(
batch: dict,
tokenizer: Tokenizer,
device: torch.device,
speaker_map: Dict[str, int],
speaker_map: KaldiReader,
):
"""Parse batch data"""

def parse_sids(batch: dict) -> List[str]:
    # A LibriTTS cut id looks like "<spk>_<chapter>_<utt>_<seg>"; the
    # speaker key used for x-vector lookup is the first two fields.
    sids = []
    for cut in batch["cut"]:
        fields = cut.id.split("_")
        sids.append("_".join(fields[:2]))
    return sids

audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
.squeeze(1)
.to(device)
)

tokens = tokenizer.tokens_to_token_ids(
Expand All @@ -366,8 +373,9 @@ def train_one_epoch(
scheduler_g: LRSchedulerType,
scheduler_d: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
speaker_map: Dict[str, int],
dev_dl: torch.utils.data.DataLoader,
train_speaker_map: KaldiReader,
dev_speaker_map: KaldiReader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
Expand Down Expand Up @@ -442,7 +450,7 @@ def save_bad_model(suffix: str = ""):
tokens,
tokens_lens,
speakers,
) = prepare_input(batch, tokenizer, device, speaker_map)
) = prepare_input(batch, tokenizer, device, train_speaker_map)

loss_info = MetricsTracker()
loss_info["samples"] = batch_size
Expand All @@ -457,7 +465,7 @@ def save_bad_model(suffix: str = ""):
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=False,
)
for k, v in stats_d.items():
Expand All @@ -476,7 +484,7 @@ def save_bad_model(suffix: str = ""):
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
Expand Down Expand Up @@ -583,8 +591,8 @@ def save_bad_model(suffix: str = ""):
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
speaker_map=speaker_map,
dev_dl=dev_dl,
dev_speaker_map=dev_speaker_map,
world_size=world_size,
)
model.train()
Expand Down Expand Up @@ -620,8 +628,8 @@ def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader,
speaker_map: Dict[str, int],
dev_dl: torch.utils.data.DataLoader,
dev_speaker_map: KaldiReader,
world_size: int = 1,
rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
Expand All @@ -634,7 +642,7 @@ def compute_validation_loss(
returned_sample = None

with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
for batch_idx, batch in enumerate(dev_dl):
batch_size = len(batch["tokens"])
(
audio,
Expand All @@ -644,7 +652,7 @@ def compute_validation_loss(
tokens,
tokens_lens,
speakers,
) = prepare_input(batch, tokenizer, device, speaker_map)
) = prepare_input(batch, tokenizer, device, dev_speaker_map)

loss_info = MetricsTracker()
loss_info["samples"] = batch_size
Expand All @@ -657,7 +665,7 @@ def compute_validation_loss(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=False,
)
assert loss_d.requires_grad is False
Expand All @@ -672,7 +680,7 @@ def compute_validation_loss(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=True,
)
assert loss_g.requires_grad is False
Expand All @@ -687,7 +695,7 @@ def compute_validation_loss(
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
text=tokens[0, : tokens_lens[0].item()],
sids=speakers[0],
spembs=speakers[0],
)
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (
Expand Down Expand Up @@ -717,7 +725,7 @@ def scan_pessimistic_batches_for_oom(
tokenizer: Tokenizer,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
speaker_map: Dict[str, int],
train_speaker_map: KaldiReader,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
Expand All @@ -737,7 +745,7 @@ def scan_pessimistic_batches_for_oom(
tokens,
tokens_lens,
speakers,
) = prepare_input(batch, tokenizer, device, speaker_map)
) = prepare_input(batch, tokenizer, device, train_speaker_map)
try:
# for discriminator
with autocast(enabled=params.use_fp16):
Expand All @@ -748,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=False,
)
optimizer_d.zero_grad()
Expand All @@ -762,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=True,
)
optimizer_g.zero_grad()
Expand Down Expand Up @@ -820,9 +828,12 @@ def run(rank, world_size, args):

libritts = LibrittsTtsDataModule(args)

train_cuts = libritts.train_cuts()
speaker_map = libritts.speakers()
params.num_spks = len(speaker_map)
if params.full_libri:
train_cuts = libritts.train_all_shuf_cuts()
train_speaker_map = libritts.train_all_shuf_xvector()
else:
train_cuts = libritts.train_clean_460_cuts()
train_speaker_map = libritts.train_clean_460_xvector()

logging.info(params)

Expand Down Expand Up @@ -896,8 +907,9 @@ def remove_short_and_long_utt(c: Cut):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = libritts.train_dataloaders(train_cuts)

valid_cuts = libritts.valid_cuts()
valid_dl = libritts.valid_dataloaders(valid_cuts)
dev_clean_cuts = libritts.dev_clean_cuts()
dev_speaker_map = libritts.dev_clean_xvector()
dev_dl = libritts.dev_dataloaders(dev_clean_cuts)

if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
Expand All @@ -906,7 +918,7 @@ def remove_short_and_long_utt(c: Cut):
tokenizer=tokenizer,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
speaker_map=speaker_map,
train_speaker_map=train_speaker_map,
params=params,
)

Expand Down Expand Up @@ -935,8 +947,9 @@ def remove_short_and_long_utt(c: Cut):
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
train_dl=train_dl,
valid_dl=valid_dl,
speaker_map=speaker_map,
dev_dl=dev_dl,
train_speaker_map=train_speaker_map,
dev_speaker_map=dev_speaker_map,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
Expand Down
Loading

0 comments on commit 8da9acd

Please sign in to comment.