This repository has been archived by the owner on Nov 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy path: twist_generation.py
100 lines (76 loc) · 2.6 KB
/
twist_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torchaudio
import torch
from speech_lm import generate_with_offset, build_speech_lm
from textless.vocoders.hifigan.vocoder import CodeHiFiGANVocoder
from textless.data.speech_encoder import SpeechEncoder
def run_full_generation(hubert_encoder, twist_model, hifi_vocoder, speech_prompt):
    """Encode a speech prompt to units, continue it with the LM, and vocode to audio.

    Args:
        hubert_encoder: speech encoder mapping raw audio to discrete units.
        twist_model: speech language model used to continue the unit sequence.
        hifi_vocoder: CodeHiFiGAN vocoder turning units back into a waveform.
        speech_prompt: audio tensor used as the generation prompt.

    Returns:
        Waveform tensor synthesized from the full generated unit sequence.
    """
    # The encoder returns a dict; the discrete token ids live under 'units'.
    # unsqueeze(0) adds the batch dimension the LM expects.
    prompt_units = hubert_encoder(speech_prompt)['units'].unsqueeze(0)
    continued_units = generate_with_offset(twist_model, prompt_units)
    # dur_prediction=True makes the vocoder predict durations for the
    # (deduplicated) units before synthesis.
    return hifi_vocoder(continued_units, dur_prediction=True)
def main(args):
    """Run prompted speech continuation and save the synthesized waveform.

    Loads the speech encoder, the unit vocoder and the TWIST language model,
    encodes a (possibly truncated) audio prompt from ``args.input_file``,
    generates a continuation, and writes the result to ``args.output_file``.
    """
    dense_model, quantizer_model, vocab = "mhubert-base-25hz", "kmeans", 500

    # Load speech encoder and vocoder
    encoder = SpeechEncoder.by_name(
        dense_model_name=dense_model,
        quantizer_model_name=quantizer_model,
        vocab_size=vocab,
        deduplicate=True,
        need_f0=False,
        add_bos_eos=False,
    ).eval()
    vocoder = CodeHiFiGANVocoder.by_name(
        dense_model_name=dense_model,
        quantizer_model_name=quantizer_model,
        vocab_size=vocab,
    ).eval()
    if torch.cuda.is_available():
        encoder = encoder.cuda()
        vocoder = vocoder.cuda()

    # Load twist model
    twist_model = build_speech_lm(args.twist_model_name)

    audio, sample_rate = torchaudio.load(args.input_file)
    # Collapse multi-channel input to mono by averaging channels.
    if audio.ndim == 2:
        audio = audio.mean(0)
    # Truncate the prompt to the requested duration; a value of 0 disables this.
    if args.prompt_duration_sec:
        max_samples = int(args.prompt_duration_sec * sample_rate)
        audio = audio[:max_samples]

    generated_audio = run_full_generation(encoder, twist_model, vocoder, audio)

    # NOTE(review): output is written at a fixed 16 kHz — presumably the
    # vocoder's output rate; confirm against the model configuration.
    torchaudio.save(
        args.output_file,
        generated_audio.cpu().unsqueeze(0),
        16000,
    )
def cli_main():
    """Command-line entry point: parse arguments and run generation."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help="Input filepath",
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="Path where generated metadata is saved"
    )
    parser.add_argument(
        "--twist_model_name",
        type=str,
        default="TWIST-350M",
        choices=["TWIST-350M", "TWIST-1.3B", "TWIST-7B"],
        help="Name of TWIST model",
    )
    parser.add_argument(
        "--prompt_duration_sec",
        type=float,
        default=3.0,
        help="Cutting prompts to a maximum duration",
    )
    main(parser.parse_args())
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    cli_main()