Skip to content

Commit

Permalink
Update export-onnx.py
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Jan 26, 2024
1 parent d922766 commit c606ef5
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
2. Export the model to ONNX
./pruned_transducer_stateless2/export-onnx.py \
--lang-dir $repo/data/lang_char \
--tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp
Expand All @@ -48,6 +48,7 @@
from pathlib import Path
from typing import Dict, Tuple

import k2
import onnx
import torch
import torch.nn as nn
Expand All @@ -57,14 +58,8 @@
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model

from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import num_tokens, setup_logger, str2bool


def get_parser():
Expand Down Expand Up @@ -110,10 +105,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -397,9 +392,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down

0 comments on commit c606ef5

Please sign in to comment.