Skip to content

Commit

Permalink
Add decoding with H and HL for LibriSpeech
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 26, 2023
1 parent 2d70677 commit 9384d63
Show file tree
Hide file tree
Showing 14 changed files with 674 additions and 116 deletions.
43 changes: 43 additions & 0 deletions .github/scripts/run-pre-trained-conformer-ctc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,46 @@ log "HLG decoding"
$repo/test_wavs/1089-134686-0001.flac \
$repo/test_wavs/1221-135766-0001.flac \
$repo/test_wavs/1221-135766-0002.flac

log "CTC decoding on CPU with kaldi decoders using OpenFst"

log "Exporting model with torchscript"

# NOTE(review): presumably export.py resolves "--epoch 99" to epoch-99.pt
# inside --exp-dir, which is why pretrained.pt is symlinked under that name.
# Confirm against conformer_ctc/export.py.
pushd $repo/exp
ln -s pretrained.pt epoch-99.pt
popd

# Export the checkpoint with --jit 1; the decoding steps below read
# $repo/exp/cpu_jit.pt.
./conformer_ctc/export.py \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--jit 1

# List the directory so that the CI log shows what was produced.
ls -lh $repo/exp


log "Generating H.fst, HL.fst"

# The decoding steps below read H.fst and HL.fst from this lang dir.
./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500
ls -lh $repo/data/lang_bpe_500

log "Decoding with H on CPU with OpenFst"

# Decode the three test waves with the H graph and the BPE token table.
./conformer_ctc/jit_pretrained_decode_with_H.py \
--nn-model $repo/exp/cpu_jit.pt \
--H $repo/data/lang_bpe_500/H.fst \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.flac \
$repo/test_wavs/1221-135766-0001.flac \
$repo/test_wavs/1221-135766-0002.flac

log "Decoding with HL on CPU with OpenFst"

# Use the HL script here: jit_pretrained_decode_with_H.py requires --H and
# --tokens and defines neither --HL nor --words, so invoking it with the
# arguments below would abort with an argparse error.
./conformer_ctc/jit_pretrained_decode_with_HL.py \
  --nn-model $repo/exp/cpu_jit.pt \
  --HL $repo/data/lang_bpe_500/HL.fst \
  --words $repo/data/lang_bpe_500/words.txt \
  $repo/test_wavs/1089-134686-0001.flac \
  $repo/test_wavs/1221-135766-0001.flac \
  $repo/test_wavs/1221-135766-0002.flac
2 changes: 1 addition & 1 deletion .github/workflows/run-pretrained-conformer-ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ concurrency:

jobs:
run_pre_trained_conformer_ctc:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'ctc'
runs-on: ${{ matrix.os }}
strategy:
matrix:
Expand Down
221 changes: 221 additions & 0 deletions egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This file shows how to use a torchscript model for decoding with H
on CPU using OpenFST and decoders from kaldi.
Usage:
./conformer_ctc/jit_pretrained_decode_with_H.py \
--nn-model ./cpu_jit.pt \
--H ./data/lang_bpe_500/H.fst \
--tokens ./data/lang_bpe_500/tokens.txt \
./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
Note that to generate cpu_jit.pt,
you can use ./conformer_ctc/export.py --jit 1
"""

import argparse
import logging
import math
from typing import Dict, List

import kaldi_hmm_gmm
import kaldifeat
import kaldifst
import torch
import torchaudio
from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    """Build the command-line parser of this script.

    Returns:
      An argparse.ArgumentParser with the required options --nn-model,
      --tokens and --H, followed by one or more positional sound files.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The help text points at the conformer_ctc recipe (this script's own
    # recipe), not at the tdnn recipe it was originally copied from.
    parser.add_argument(
        "--nn-model",
        type=str,
        required=True,
        help="""Path to the torchscript model.
        You can use ./conformer_ctc/export.py --jit 1
        to obtain it
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument("--H", type=str, required=True, help="Path to H.fst")

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. ",
    )

    return parser


def read_tokens(tokens_txt: str) -> Dict[int, str]:
    """Load the ID -> token table from a tokens.txt file.

    Every non-blank line must consist of exactly two whitespace-separated
    fields: the token followed by its integer ID. Blank lines are skipped.

    Args:
      tokens_txt:
        Path to tokens.txt. It is read as UTF-8.
    Returns:
      Return a dict mapping the integer ID to the token string.
    Raises:
      ValueError:
        If a non-blank line does not have exactly two fields or if the
        second field is not an integer.
    """
    id2token = {}
    with open(tokens_txt, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # A blank (or whitespace-only) line carries no entry;
                # unpacking it would raise a ValueError.
                continue
            token, idx = fields
            id2token[int(idx)] = token

    return id2token


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load sound files and return the first channel of each one.

    A file whose sample rate differs from `expected_sample_rate` is
    resampled to `expected_sample_rate` before its first channel is taken.

    Args:
      filenames:
        Paths of the sound files to load with torchaudio.load().
      expected_sample_rate:
        The sample rate the returned waves should have.
    Returns:
      Return a list with one contiguous tensor per input file, in the same
      order as `filenames`. Each tensor is the first channel of the file.
    """
    waves = []
    for path in filenames:
        samples, rate = torchaudio.load(path)

        needs_resampling = rate != expected_sample_rate
        if needs_resampling:
            samples = torchaudio.functional.resample(
                samples,
                orig_freq=rate,
                new_freq=expected_sample_rate,
            )

        # Only the first channel is kept; any other channel is dropped.
        first_channel = samples[0]
        waves.append(first_channel.contiguous())

    return waves


def decode(
    filename: str,
    nnet_output: torch.Tensor,
    H: kaldifst.StdVectorFst,
    id2token: Dict[int, str],
) -> List[str]:
    """Decode the network output of a single utterance with the H graph.

    Args:
      filename:
        Name of the sound file being decoded. It is used only in log
        messages.
      nnet_output:
        The network output of this utterance; it is wrapped in a
        DecodableCtc as is.
        NOTE(review): presumably a 2-D (num_frames, vocab_size) tensor of
        log-probs -- confirm against the exported model.
      H:
        The decoding graph, e.g., the one returned by
        kaldifst.StdVectorFst.read().
      id2token:
        A dict mapping token IDs to token strings (see read_tokens()).
    Returns:
      Return the pieces obtained by concatenating the decoded tokens and
      splitting the result on U+2581. An empty list is returned if decoding
      fails at any stage, so the return value is always a list.
    """
    logging.info(f"{filename}, {nnet_output.shape}")
    decodable = DecodableCtc(nnet_output)

    decoder_opts = FasterDecoderOptions(max_active=3000)
    decoder = FasterDecoder(H, decoder_opts)
    decoder.decode(decodable)

    if not decoder.reached_final():
        logging.error(f"failed to decode {filename}")
        return []

    ok, best_path = decoder.get_best_path()
    if not ok:
        # Without this check the flag would be silently overwritten by the
        # one returned from get_linear_symbol_sequence() below.
        logging.error(f"failed to get best path for {filename}")
        return []

    (
        ok,
        isymbols_out,
        osymbols_out,
        total_weight,
    ) = kaldifst.get_linear_symbol_sequence(best_path)
    if not ok:
        logging.error(f"failed to get linear symbol sequence for {filename}")
        return []

    # tokens are incremented during graph construction
    # so they need to be decremented
    tokens = [id2token[i - 1] for i in osymbols_out]

    # "\u2581" is the unicode codepoint of the word-boundary marker used in
    # the token strings. Note that str.split() yields a leading empty string
    # when the joined text starts with that marker; this is kept unchanged.
    return "".join(tokens).split("\u2581")


@torch.no_grad()
def main():
    """Transcribe the sound files given on the command line.

    The steps are: load the torchscript model and the H graph, compute
    80-bin fbank features for all files, run the model once on the padded
    batch, decode every utterance with decode(), and log the results.
    Everything runs on CPU.
    """
    args = get_parser().parse_args()

    device = torch.device("cpu")
    logging.info(f"device: {device}")

    logging.info("Loading torchscript model")
    model = torch.jit.load(args.nn_model)
    model.eval()
    model.to(device)

    logging.info(f"Loading H from {args.H}")
    H = kaldifst.StdVectorFst.read(args.H)

    sample_rate = 16000

    logging.info("Constructing Fbank computer")
    fbank_opts = kaldifeat.FbankOptions()
    fbank_opts.device = device
    fbank_opts.frame_opts.dither = 0
    fbank_opts.frame_opts.snip_edges = False
    fbank_opts.frame_opts.samp_freq = sample_rate
    fbank_opts.mel_opts.num_bins = 80
    fbank = kaldifeat.Fbank(fbank_opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = [
        wave.to(device)
        for wave in read_sound_files(
            filenames=args.sound_files, expected_sample_rate=sample_rate
        )
    ]

    logging.info("Decoding started")
    features = fbank(waves)
    num_utterances = len(features)
    feature_lengths = torch.tensor([feat.shape[0] for feat in features])

    supervisions = {
        "sequence_idx": torch.arange(num_utterances),
        "start_frame": torch.zeros(num_utterances),
        "num_frames": feature_lengths,
    }

    # Shorter utterances are padded with log(1e-10) up to the longest one.
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )

    nnet_output, _, _ = model(features, supervisions)

    # NOTE(review): presumably mirrors the frame subsampling done inside the
    # model, so that padded output frames are dropped below -- confirm.
    output_lengths = ((feature_lengths - 1) // 2 - 1) // 2

    id2token = read_tokens(args.tokens)

    hyps = [
        decode(
            filename=args.sound_files[i],
            nnet_output=nnet_output[i, : output_lengths[i]],
            H=H,
            id2token=id2token,
        )
        for i in range(nnet_output.shape[0])
    ]

    sections = []
    for filename, hyp in zip(args.sound_files, hyps):
        words = " ".join(hyp)
        sections.append(f"{filename}:\n{words}\n\n")
    logging.info("\n" + "".join(sections))

    logging.info("Decoding Done")


if __name__ == "__main__":
    # Configure logging before main() runs so that its INFO messages show
    # the timestamp, the level and the source location.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()
Loading

0 comments on commit 9384d63

Please sign in to comment.