Skip to content

Commit

Permalink
Update TTS export-onnx.py scripts for handling variable token counts (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ahazned authored Dec 25, 2023
1 parent c855a58 commit ddd7131
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 5 additions & 1 deletion egs/ljspeech/TTS/vits/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def forward(
def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
Expand All @@ -165,10 +166,12 @@ def export_model_onnx(
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
vocab_size:
Number of tokens used in training.
opset_version:
The opset version to use.
"""
- tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
Expand Down Expand Up @@ -244,6 +247,7 @@ def main():
export_model_onnx(
model,
model_filename,
params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
Expand Down
6 changes: 5 additions & 1 deletion egs/vctk/TTS/vits/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def forward(
def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
Expand All @@ -175,10 +176,12 @@ def export_model_onnx(
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
vocab_size:
Number of tokens used in training.
opset_version:
The opset version to use.
"""
- tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
Expand Down Expand Up @@ -261,6 +264,7 @@ def main():
export_model_onnx(
model,
model_filename,
params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
Expand Down

0 comments on commit ddd7131

Please sign in to comment.