diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 36a9de27f0..f82f9dbe9b 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -149,6 +149,7 @@ def forward( def export_model_onnx( model: nn.Module, model_filename: str, + vocab_size: int, opset_version: int = 11, ) -> None: """Export the given generator model to ONNX format. @@ -165,10 +166,12 @@ def export_model_onnx( The VITS generator. model_filename: The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. opset_version: The opset version to use. """ - tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) noise_scale = torch.tensor([1], dtype=torch.float32) noise_scale_dur = torch.tensor([1], dtype=torch.float32) @@ -244,6 +247,7 @@ def main(): export_model_onnx( model, model_filename, + params.vocab_size, opset_version=opset_version, ) logging.info(f"Exported generator to {model_filename}") diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index 667ac284b3..80d1556261 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -159,6 +159,7 @@ def forward( def export_model_onnx( model: nn.Module, model_filename: str, + vocab_size: int, opset_version: int = 11, ) -> None: """Export the given generator model to ONNX format. @@ -175,10 +176,12 @@ def export_model_onnx( The VITS generator. model_filename: The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. opset_version: The opset version to use. """ - tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) noise_scale = torch.tensor([1], dtype=torch.float32) noise_scale_dur = torch.tensor([1], dtype=torch.float32) @@ -261,6 +264,7 @@ def main(): export_model_onnx( model, model_filename, + params.vocab_size, opset_version=opset_version, ) logging.info(f"Exported generator to {model_filename}")