-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from shivammehta25/dev
Merging dev to main | adding ONNX support
- Loading branch information
Showing
8 changed files
with
424 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
0.0.3 | ||
0.0.4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import argparse | ||
import random | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import torch | ||
from lightning import LightningModule | ||
|
||
from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder | ||
|
||
# Default ONNX opset version passed to the exporter (overridable via `--opset`).
DEFAULT_OPSET = 15

# Fix all RNG seeds so the dummy tracing inputs — and therefore the traced
# graph — are reproducible across export runs.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
|
||
|
||
class MatchaWithVocoder(LightningModule):
    """
    Wrap a Matcha acoustic model together with a vocoder so that the
    pair can be exported as a single end-to-end ONNX graph
    (text -> waveform instead of text -> mel).
    """

    def __init__(self, matcha, vocoder):
        super().__init__()
        self.matcha = matcha
        self.vocoder = vocoder

    def forward(self, x, x_lengths, scales, spks=None):
        # Acoustic model: token ids -> mel spectrogram (+ per-item frame counts)
        mel_spec, mel_spec_lengths = self.matcha(x, x_lengths, scales, spks)
        # Vocoder: mel -> waveform, clipped to the valid audio range [-1, 1]
        audio = self.vocoder(mel_spec).clamp(-1, 1)
        # 256 is presumably the vocoder hop length, converting frame counts
        # to sample counts — TODO confirm against the vocoder config.
        audio_lengths = mel_spec_lengths * 256
        return audio.squeeze(1), audio_lengths
|
||
|
||
def get_exportable_module(matcha, vocoder, n_timesteps):
    """
    Prepare the module to export and the matching ONNX output-node names.

    When `vocoder` is None only the acoustic model is exported and the graph
    emits mel outputs; otherwise both models are wrapped together and the
    graph emits waveform outputs.
    """

    def onnx_forward_func(x, x_lengths, scales, spks=None):
        """
        Custom forward function for accepting
        scaler parameters as tensors
        """
        # Unpack the scalar synthesis parameters packed into one tensor
        temperature, length_scale = scales[0], scales[1]
        synth_out = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale)
        return synth_out["mel"], synth_out["mel_lengths"]

    # Monkey-patch the acoustic model so tracing goes through the custom forward
    matcha.forward = onnx_forward_func

    if vocoder is None:
        return matcha, ["mel", "mel_lengths"]
    return MatchaWithVocoder(matcha, vocoder), ["wav", "wav_lengths"]
|
||
|
||
def get_inputs(is_multi_speaker):
    """
    Build the dummy tensors used to trace the model for ONNX export.

    Returns a tuple `(model_inputs, input_names)`; a speaker-id tensor is
    appended only for multi-speaker checkpoints.
    """
    seq_len = 50
    x = torch.randint(low=0, high=20, size=(1, seq_len), dtype=torch.long)
    x_lengths = torch.LongTensor([seq_len])

    # Synthesis scales packed into one tensor: [temperature, length_scale]
    scales = torch.Tensor([0.667, 1.0])

    inputs = [x, x_lengths, scales]
    names = ["x", "x_lengths", "scales"]

    if is_multi_speaker:
        inputs.append(torch.LongTensor([1]))
        names.append("spks")

    return tuple(inputs), names
|
||
|
||
def main():
    """
    CLI entry point: export a 🍵 Matcha-TTS checkpoint to ONNX.

    Optionally embeds a vocoder in the exported graph so the ONNX model
    produces waveforms directly instead of mel spectrograms.

    Raises:
        ValueError: if only one of --vocoder-name / --vocoder-checkpoint-path
            is supplied (both are needed to embed the vocoder).
    """
    parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX")

    parser.add_argument(
        "checkpoint_path",
        type=str,
        help="Path to the model checkpoint",
    )
    parser.add_argument("output", type=str, help="Path to output `.onnx` file")
    parser.add_argument(
        "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)"
    )
    parser.add_argument(
        "--vocoder-name",
        type=str,
        choices=list(VOCODER_URLS.keys()),
        default=None,
        help="Name of the vocoder to embed in the ONNX graph",
    )
    parser.add_argument(
        "--vocoder-checkpoint-path",
        type=str,
        default=None,
        help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience",
    )
    # Fix: the help string previously read "(default 15" with an unclosed paren
    # and hard-coded the number; derive it from DEFAULT_OPSET instead.
    parser.add_argument(
        "--opset", type=int, default=DEFAULT_OPSET, help=f"ONNX opset version to use (default {DEFAULT_OPSET})"
    )

    args = parser.parse_args()

    print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}")
    print(f"Setting n_timesteps to {args.n_timesteps}")

    checkpoint_path = Path(args.checkpoint_path)
    matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu")

    if args.vocoder_name or args.vocoder_checkpoint_path:
        # Explicit exception instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this validation.
        if not (args.vocoder_name and args.vocoder_checkpoint_path):
            raise ValueError(
                "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph."
            )
        vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
    else:
        vocoder = None

    is_multi_speaker = matcha.n_spks > 1

    dummy_input, input_names = get_inputs(is_multi_speaker)
    model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps)

    # Dynamic axes let the exported graph accept variable batch sizes and
    # sequence lengths instead of the dummy tracing shapes.
    dynamic_axes = {
        "x": {0: "batch_size", 1: "time"},
        "x_lengths": {0: "batch_size"},
    }

    if vocoder is None:
        dynamic_axes.update(
            {
                "mel": {0: "batch_size", 2: "time"},
                "mel_lengths": {0: "batch_size"},
            }
        )
    else:
        print("Embedding the vocoder in the ONNX graph")
        dynamic_axes.update(
            {
                "wav": {0: "batch_size", 1: "time"},
                "wav_lengths": {0: "batch_size"},
            }
        )

    if is_multi_speaker:
        dynamic_axes["spks"] = {0: "batch_size"}

    # Create the output directory (if it does not exist)
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    model.to_onnx(
        args.output,
        dummy_input,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=args.opset,
        export_params=True,
        do_constant_folding=True,
    )
    print(f"[🍵] ONNX model exported to {args.output}")
|
||
|
||
# Run the exporter only when invoked as a script, not on import.
if __name__ == "__main__":
    main()
Oops, something went wrong.