Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add script to convert vits models #355

Merged
merged 12 commits on Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions .github/workflows/export-vits-ljspeech-to-onnx.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# CI workflow: export the pre-trained VITS LJ Speech model to ONNX and
# publish the synthesized test wave files as build artifacts.
name: export-vits-ljspeech-to-onnx

on:
  push:
    branches:
      - master
    # Only run when the export script or this workflow itself changes.
    paths:
      - 'scripts/vits/**'
      - '.github/workflows/export-vits-ljspeech-to-onnx.yaml'
  pull_request:
    paths:
      - 'scripts/vits/**'
      - '.github/workflows/export-vits-ljspeech-to-onnx.yaml'

  # Allow manual triggering from the Actions tab.
  workflow_dispatch:

# Cancel an in-flight run when a newer commit arrives on the same ref.
concurrency:
  group: export-vits-ljspeech-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-vits-ljspeech-onnx:
    # Run only in the upstream repo and the maintainer's fork, not in
    # arbitrary forks.
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: vits ljspeech
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        torch: ["1.13.0"]

    steps:
      - uses: actions/checkout@v4

      - name: Install dependencies
        shell: bash
        run: |
          python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html numpy
          python3 -m pip install onnxruntime onnx soundfile
          python3 -m pip install scipy cython unidecode phonemizer

          # required by phonemizer
          # See https://bootphon.github.io/phonemizer/install.html
          # To fix the following error: RuntimeError: espeak not installed on your system
          #
          sudo apt-get install festival espeak-ng mbrola


      - name: export vits ljspeech
        shell: bash
        run: |
          cd scripts/vits

          echo "Downloading vits"
          git clone https://github.com/jaywalnut310/vits
          pushd vits/monotonic_align
          python3 setup.py build
          ls -lh build/
          ls -lh build/lib*/
          ls -lh build/lib*/*/

          # Copy the compiled extension next to __init__.py and patch the
          # import path so it resolves in place (outside a built package).
          cp build/lib*/monotonic_align/core*.so .
          sed -i.bak s/.monotonic_align.core/.core/g ./__init__.py
          git diff
          popd

          export PYTHONPATH=$PWD/vits:$PYTHONPATH

          echo "Download models"

          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
          wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/test.py

          python3 ./export-onnx-ljs.py --config vits/configs/ljs_base.json --checkpoint ./pretrained_ljs.pth
          python3 ./test.py
          ls -lh *.wav

      # Upload the three synthesized sentences produced by test.py.
      - uses: actions/upload-artifact@v3
        with:
          name: test-0.wav
          path: scripts/vits/test-0.wav

      - uses: actions/upload-artifact@v3
        with:
          name: test-1.wav
          path: scripts/vits/test-1.wav

      - uses: actions/upload-artifact@v3
        with:
          name: test-2.wav
          path: scripts/vits/test-2.wav
1 change: 1 addition & 0 deletions scripts/vits/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tokens-ljs.txt
Empty file added scripts/vits/__init__.py
Empty file.
213 changes: 213 additions & 0 deletions scripts/vits/export-onnx-ljs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This script converts vits models trained using the LJ Speech dataset.

Usage:

(1) Download vits

cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits

(2) Download pre-trained models from
https://huggingface.co/csukuangfj/vits-ljs/tree/main

wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth

(3) Run this file

./export-onnx-ljs.py \
  --config ~/open-source/vits/configs/ljs_base.json \
--checkpoint ~/open-source/icefall-models/vits-ljs/pretrained_ljs.pth

It will generate the following two files:

$ ls -lh *.onnx
-rw-r--r-- 1 fangjun staff 36M Oct 10 20:48 vits-ljs.int8.onnx
-rw-r--r-- 1 fangjun staff 109M Oct 10 20:48 vits-ljs.onnx
"""
import sys

# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys.path.insert(0, "/Users/fangjun/open-source/vits") # noqa

import argparse
from pathlib import Path
from typing import Dict, Any

import commons
import onnx
import torch
import utils
from models import SynthesizerTrn
from onnxruntime.quantization import QuantType, quantize_dynamic
from text import text_to_sequence
from text.symbols import symbols
from text.symbols import _punctuation


def get_args():
    """Parse command-line arguments; both flags are mandatory."""
    ap = argparse.ArgumentParser()

    ap.add_argument(
        "--config",
        required=True,
        type=str,
        help="""Path to ljs_base.json.
        You can find it at
        https://huggingface.co/csukuangfj/vits-ljs/resolve/main/ljs_base.json
        """,
    )

    ap.add_argument(
        "--checkpoint",
        required=True,
        type=str,
        help="""Path to the checkpoint file.
        You can find it at
        https://huggingface.co/csukuangfj/vits-ljs/resolve/main/pretrained_ljs.pth

        """,
    )

    return ap.parse_args()


class OnnxModel(torch.nn.Module):
    """Adapter exposing SynthesizerTrn.infer() as forward().

    torch.onnx.export traces a module's forward(); the VITS model keeps
    its inference entry point in infer(), so this wrapper forwards the
    call and keeps only element 0 of the tuple infer() returns.
    """

    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        sid=None,
        max_len=None,
    ):
        # infer() returns a tuple; the generated audio is the first item.
        outputs = self.model.infer(
            x=x,
            x_lengths=x_lengths,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
            max_len=max_len,
            sid=sid,
        )
        return outputs[0]


def get_text(text, hps):
    """Convert a transcript string into a LongTensor of token IDs.

    The text is cleaned per hps.data.text_cleaners; when
    hps.data.add_blank is set, a blank token (ID 0) is interspersed
    between symbols to match the training-time preprocessing.
    """
    seq = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        seq = commons.intersperse(seq, 0)
    return torch.LongTensor(seq)


def check_args(args):
    """Fail fast if either input path is not an existing file."""
    for p in (args.config, args.checkpoint):
        assert Path(p).is_file(), p


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Attach key/value metadata to an ONNX model file, in place.

    Args:
      filename:
        Path of the ONNX model to modify; the file is overwritten.
      meta_data:
        Key-value pairs; every value is stringified before storage.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)

    onnx.save(model, filename)


def generate_tokens(filename: str = "tokens-ljs.txt") -> None:
    """Write the symbol table, one "<symbol> <id>" pair per line.

    Args:
      filename:
        Output path. Defaults to tokens-ljs.txt, matching the previous
        hard-coded name, so existing callers are unaffected.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for i, s in enumerate(symbols):
            f.write(f"{s} {i}\n")
    print(f"Generated {filename}")


@torch.no_grad()
def main():
    """Export the VITS LJ Speech model to ONNX (fp32 and int8).

    Steps:
      1. Generate tokens-ljs.txt mapping symbols to IDs.
      2. Build SynthesizerTrn from the config and load the checkpoint.
      3. Export to vits-ljs.onnx with dynamic batch (N) and sequence
         length (L) axes, then attach runtime metadata.
      4. Produce a dynamically quantized copy, vits-ljs.int8.onnx.
    """
    args = get_args()
    check_args(args)

    generate_tokens()

    hps = utils.get_hparams_from_file(args.config)

    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.checkpoint, net_g, None)

    # A sample sentence used only to trace the graph during export.
    x = get_text("Liliana is the most beautiful assistant", hps)
    x = x.unsqueeze(0)  # add a batch dimension: (1, L)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)

    model = OnnxModel(net_g)

    opset_version = 13

    filename = "vits-ljs.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale, noise_scale_w),
        filename,
        opset_version=opset_version,
        input_names=["x", "x_length", "noise_scale", "length_scale", "noise_scale_w"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # N is also known as batch_size
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},
        },
    )
    meta_data = {
        "model_type": "vits",
        "comment": "ljspeech",
        "language": "English",
        "add_blank": int(hps.data.add_blank),
        "sample_rate": hps.data.sampling_rate,
        "punctuation": " ".join(list(_punctuation)),
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "vits-ljs.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    # Fix: the original f-string printed a literal leaked placeholder
    # instead of the fp32 model's filename.
    print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
    # Entry point when executed directly as a script; see the usage
    # examples in the module docstring.
    main()