fade_in_out error when stream = True #524

Open · wants to merge 1 commit into main
119 changes: 35 additions & 84 deletions cosyvoice/bin/train.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2020 Mobvoi Inc (Di Wu)
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -14,138 +15,88 @@

 from __future__ import print_function
 import argparse
 import datetime
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
-from copy import deepcopy
 import os
 import torch
 import torch.distributed as dist
 import deepspeed
 
 from hyperpyyaml import load_hyperpyyaml
 
-from torch.distributed.elastic.multiprocessing.errors import record
-
+from copy import deepcopy
 from cosyvoice.utils.executor import Executor
 from cosyvoice.utils.train_utils import (
     init_distributed,
     init_dataset_and_dataloader,
     init_optimizer_and_scheduler,
     init_summarywriter, save_model,
-    wrap_cuda_model, check_modify_and_save_config)
+    wrap_cuda_model, check_modify_and_save_config
+)
+from torch.distributed.elastic.multiprocessing.errors import record

 def get_args():
-    parser = argparse.ArgumentParser(description='training your network')
-    parser.add_argument('--train_engine',
-                        default='torch_ddp',
-                        choices=['torch_ddp', 'deepspeed'],
-                        help='Engine for paralleled training')
-    parser.add_argument('--model', required=True, help='model which will be trained')
-    parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--train_data', required=True, help='train data file')
-    parser.add_argument('--cv_data', required=True, help='cv data file')
-    parser.add_argument('--checkpoint', help='checkpoint model')
-    parser.add_argument('--model_dir', required=True, help='save model dir')
-    parser.add_argument('--tensorboard_dir',
-                        default='tensorboard',
-                        help='tensorboard log dir')
-    parser.add_argument('--ddp.dist_backend',
-                        dest='dist_backend',
-                        default='nccl',
-                        choices=['nccl', 'gloo'],
-                        help='distributed backend')
-    parser.add_argument('--num_workers',
-                        default=0,
-                        type=int,
-                        help='num of subprocess workers for reading')
-    parser.add_argument('--prefetch',
-                        default=100,
-                        type=int,
-                        help='prefetch number')
-    parser.add_argument('--pin_memory',
-                        action='store_true',
-                        default=False,
-                        help='Use pinned memory buffers used for reading')
-    parser.add_argument('--deepspeed.save_states',
-                        dest='save_states',
-                        default='model_only',
-                        choices=['model_only', 'model+optimizer'],
-                        help='save model/optimizer states')
-    parser.add_argument('--timeout',
-                        default=60,
-                        type=int,
-                        help='timeout (in seconds) of cosyvoice_join.')
+    parser = argparse.ArgumentParser(description='Training your network')
+    parser.add_argument('--train_engine', default='torch_ddp', choices=['torch_ddp', 'deepspeed'], help='Engine for parallelized training')
+    parser.add_argument('--model', required=True, help='Model to be trained')
+    parser.add_argument('--config', required=True, help='Config file')
+    parser.add_argument('--train_data', required=True, help='Training data file')
+    parser.add_argument('--cv_data', required=True, help='CV data file')
+    parser.add_argument('--checkpoint', help='Checkpoint model path')
+    parser.add_argument('--model_dir', required=True, help='Directory to save the model')
+    parser.add_argument('--tensorboard_dir', default='tensorboard', help='Tensorboard log directory')
+    parser.add_argument('--ddp.dist_backend', dest='dist_backend', default='nccl', choices=['nccl', 'gloo'], help='Distributed backend')
+    parser.add_argument('--num_workers', default=0, type=int, help='Number of subprocess workers for reading')
+    parser.add_argument('--prefetch', default=100, type=int, help='Prefetch number')
+    parser.add_argument('--pin_memory', action='store_true', default=False, help='Use pinned memory buffers for reading')
+    parser.add_argument('--deepspeed.save_states', dest='save_states', default='model_only', choices=['model_only', 'model+optimizer'], help='Save model/optimizer states')
+    parser.add_argument('--timeout', default=60, type=int, help='Timeout (in seconds) for cosyvoice_join')
     parser = deepspeed.add_config_arguments(parser)
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()

 @record
 def main():
     args = get_args()
-    logging.basicConfig(level=logging.DEBUG,
-                        format='%(asctime)s %(levelname)s %(message)s')
-    # gan train has some special initialization logic
-    gan = True if args.model == 'hifigan' else False
+    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
+
+    gan = True if args.model == 'hifigan' else False
     override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
-    if gan is True:
-        override_dict.pop('hift')
+    if gan: override_dict.pop('hift')
 
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f, overrides=override_dict)
-    if gan is True:
+
+    if gan:
         configs['train_conf'] = configs['train_conf_gan']
     configs['train_conf'].update(vars(args))
 
-    # Init env for ddp
     init_distributed(args)
 
-    # Get dataset & dataloader
-    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs, gan)
-
-    # Do some sanity checks and save config to arsg.model_dir
+    train_dataset, cv_dataset, train_data_loader, cv_data_loader = init_dataset_and_dataloader(args, configs, gan)
     configs = check_modify_and_save_config(args, configs)
 
-    # Tensorboard summary
     writer = init_summarywriter(args)
 
-    # load checkpoint
     model = configs[args.model]
-    if args.checkpoint is not None:
-        if os.path.exists(args.checkpoint):
-            model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
-        else:
-            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
-
-    # Dispatch model from cpu to gpu
+    if args.checkpoint and os.path.exists(args.checkpoint):
+        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
+
     model = wrap_cuda_model(args, model)
 
-    # Get optimizer & scheduler
     model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
 
-    # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
     save_model(model, 'init', info_dict)
 
-    # Get executor
     executor = Executor(gan=gan)
 
-    # Start training loop
     for epoch in range(info_dict['max_epoch']):
         executor.epoch = epoch
         train_dataset.set_epoch(epoch)
         dist.barrier()
         group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
-        if gan is True:
-            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                                        writer, info_dict, group_join)
+
+        if gan:
+            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join)
         else:
             executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
 
         dist.destroy_process_group(group_join)
 
 
 if __name__ == '__main__':
     main()
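
A behavioral note on the condensed checkpoint block above: the original code logged a warning when `--checkpoint` pointed at a non-existent file, while the new one-liner skips loading silently. If that warning is worth keeping, a small helper in the same condensed spirit could look like the sketch below (`load_checkpoint_if_present` is a hypothetical name, not part of this PR):

```python
import logging
import os

import torch


def load_checkpoint_if_present(model, checkpoint):
    # Sketch only: keeps the missing-checkpoint warning that the refactor drops.
    if checkpoint and os.path.exists(checkpoint):
        model.load_state_dict(torch.load(checkpoint, map_location='cpu'), strict=False)
    elif checkpoint:
        logging.warning('checkpoint %s does not exist!', checkpoint)
```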
14 changes: 4 additions & 10 deletions cosyvoice/llm/llm.py
@@ -80,8 +80,7 @@ def encode(
     def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
-                    for i in range(len(text_token))]
+        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
         return lm_input, lm_input_len
@@ -105,8 +104,7 @@ def forward(
         embedding = batch['embedding'].to(device)
 
         # 1. prepare llm_target
-        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
-                     [self.speech_token_size]) for i in range(text_token.size(0))]
+        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
 
         # 1. encode text_token
@@ -126,8 +124,7 @@ def forward(
         speech_token = self.speech_embedding(speech_token)
 
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
-                                                         task_id_emb, speech_token, speech_token_len)
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
 
         # 6. run lm forward
         lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
@@ -197,10 +194,7 @@ def inference(
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
-                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
-                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
-                                                                                      device=lm_input.device)).to(torch.bool))
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache, att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
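
For context on the reflowed lines above, `pad_unpad_sequence` implements an unpad, concatenate, repad pattern: each padded batch is stripped back to its true lengths, the per-utterance pieces are concatenated, and the ragged results are padded into a single LM input batch. A self-contained sketch with toy shapes (the `IGNORE_ID = -1` value here is an assumption, and the embedding and task-id pieces are omitted):

```python
import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

IGNORE_ID = -1  # assumed padding label; CosyVoice defines its own constant

text = torch.arange(12.).reshape(2, 3, 2)    # (batch, max_text_len, dim)
text_len = torch.tensor([3, 2])
speech = torch.arange(16.).reshape(2, 4, 2)  # (batch, max_speech_len, dim)
speech_len = torch.tensor([4, 3])

# unpad: list of (true_len, dim) tensors; concat per utterance; repad the ragged list
text_u = unpad_sequence(text, text_len.cpu(), batch_first=True)
speech_u = unpad_sequence(speech, speech_len.cpu(), batch_first=True)
lm_input = [torch.concat([text_u[i], speech_u[i]], dim=0) for i in range(len(text_u))]
lm_input_len = torch.tensor([x.size(0) for x in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
print(lm_input.shape, lm_input_len)          # torch.Size([2, 7, 2]) tensor([7, 5], dtype=torch.int32)
```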
1 change: 1 addition & 0 deletions cosyvoice/utils/common.py
@@ -141,6 +141,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
     device = fade_in_mel.device
     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
     mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel = fade_in_mel.clone()  # clone so the in-place fade below does not modify the caller's tensor
     fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)
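
This one-line addition is the actual bug fix in the PR. `Tensor.cpu()` returns the same underlying storage when the tensor already lives on the CPU, so the in-place assignment below it was overwriting the caller's cached overlap chunk between streaming calls, which is presumably what broke inference with `stream=True`. `clone()` breaks that aliasing before the write. A minimal repro sketch of the pre-fix behavior (toy shapes, not CosyVoice's actual streaming path):

```python
import torch


def fade_in_out_no_clone(fade_in_mel, fade_out_mel, window):
    # Pre-fix behavior: .cpu() on a CPU tensor is a no-op, so this writes
    # straight into the caller's tensor.
    n = int(window.shape[0] / 2)
    fade_in_mel = fade_in_mel.cpu()
    fade_in_mel[..., :n] = fade_in_mel[..., :n] * window[:n] + fade_out_mel[..., -n:] * window[n:]
    return fade_in_mel


chunk = torch.ones(1, 80, 10)     # stands in for a cached mel chunk
overlap = torch.zeros(1, 80, 10)  # stands in for the previous chunk's tail
window = torch.hamming_window(8)  # 4-frame cross-fade

before = chunk.clone()
fade_in_out_no_clone(chunk, overlap, window)
print(torch.equal(chunk, before))  # False: the cached chunk was mutated in place
```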