Skip to content

Commit

Permalink
Fix torchscript export to use tokens.txt instead of lang_dir (#1475)
Browse files (browse the repository at this point in the history)
  • Loading branch information
csukuangfj authored Jan 26, 2024
1 parent c401a26 commit 8d39f95
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 69 deletions.
25 changes: 16 additions & 9 deletions egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
Expand All @@ -20,7 +21,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 29 \
--avg 19
Expand All @@ -45,12 +46,13 @@
import logging
from pathlib import Path

import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -85,10 +87,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt.",
)

parser.add_argument(
Expand Down Expand Up @@ -122,10 +124,14 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)

Expand All @@ -152,6 +158,7 @@ def main():
model.eval()

if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
Expand Down
24 changes: 14 additions & 10 deletions egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
Expand All @@ -47,12 +47,13 @@
import logging
from pathlib import Path

import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -98,10 +99,10 @@ def get_parser():
)

parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)

parser.add_argument(
Expand Down Expand Up @@ -135,12 +136,14 @@ def main():

logging.info(f"device: {device}")

sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)

# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)

Expand Down Expand Up @@ -183,6 +186,7 @@ def main():
model.eval()

if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
Expand Down
1 change: 1 addition & 0 deletions egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
7 changes: 3 additions & 4 deletions egs/librispeech/ASR/lstm_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,9 @@ def export_decoder_model_jit_trace(
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])

traced_model = torch.jit.trace(decoder_model, (y, need_pad))
# TODO(fangjun): Change the function name since we are actually using
# torch.jit.script instead of torch.jit.trace
traced_model = torch.jit.script(decoder_model)
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def main():

# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)
Expand Down
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
embedding_out = self.embedding(y)
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
Expand All @@ -45,7 +45,7 @@
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
Expand Down Expand Up @@ -87,7 +87,7 @@
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
Expand All @@ -113,7 +113,7 @@
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
Expand Down
31 changes: 15 additions & 16 deletions egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--tokens ./data/lang_char/tokens.txt \
--epoch 30 \
--avg 24 \
--use-averaged-model True
Expand All @@ -50,8 +50,9 @@
import logging
from pathlib import Path

import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
Expand All @@ -60,8 +61,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -118,13 +118,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt.",
)

parser.add_argument(
Expand Down Expand Up @@ -160,13 +157,14 @@ def main():

logging.info(f"device: {device}")

bpe_model = params.lang_dir + "/bpe.model"
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)

lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>

logging.info(params)

Expand Down Expand Up @@ -256,6 +254,7 @@ def main():
model.eval()

if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
Expand Down
1 change: 1 addition & 0 deletions egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
23 changes: 11 additions & 12 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2 \
--jit 1
Expand All @@ -47,7 +47,7 @@
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2 \
--jit-trace 1
Expand All @@ -63,7 +63,7 @@
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2
Expand Down Expand Up @@ -91,14 +91,14 @@
import logging
from pathlib import Path

import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool


def get_parser():
Expand Down Expand Up @@ -133,10 +133,10 @@ def get_parser():
)

parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
Expand Down Expand Up @@ -313,10 +313,9 @@ def main():

logging.info(f"device: {device}")

lexicon = Lexicon(params.lang_dir)

params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

Expand Down
Loading

0 comments on commit 8d39f95

Please sign in to comment.