Skip to content

Commit

Permalink
Merge branch 'k2-fsa:master' into gigaspeech_streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
wgb14 authored Aug 22, 2022
2 parents 865fd21 + c010118 commit 5ea4d94
Show file tree
Hide file tree
Showing 136 changed files with 11,221 additions and 391 deletions.
3 changes: 2 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ per-file-ignores =
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/*/ASR/*/optim.py: E501,
egs/*/ASR/*/scaling.py: E501,
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203,
egs/librispeech/ASR/lstm_transducer_stateless/*.py: E501, E203
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conformer_ctc2/*py: E501,
egs/librispeech/ASR/RESULTS.md: E999,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,80 @@ ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
popd

log "Test exporting to ONNX format"

./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--onnx 1

log "Export to torchscript model"
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--jit 1

./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--jit-trace 1

ls -lh $repo/exp/*.onnx
ls -lh $repo/exp/*.pt

log "Decode with ONNX models"

./pruned_transducer_stateless3/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx

./pruned_transducer_stateless3/onnx_check_all_in_one.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx

./pruned_transducer_stateless3/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

log "Decode with models exported by torch.jit.trace()"

./pruned_transducer_stateless3/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

log "Decode with models exported by torch.jit.script()"

./pruned_transducer_stateless3/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav


for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ on:

jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
Expand Down
1 change: 1 addition & 0 deletions docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd -

# install lhotse
RUN pip install torchaudio==0.7.2
RUN pip install git+https://github.com/lhotse-speech/lhotse
#RUN pip install lhotse

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

hyps_dict = decode_one_batch(
params=params,
Expand All @@ -379,8 +380,8 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))

results[name].extend(this_batch)

Expand All @@ -405,6 +406,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")

Expand Down Expand Up @@ -528,6 +530,8 @@ def main():
from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset

# we need cut ids to display recognition results.
args.return_cuts = True
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)

dev = "dev"
Expand Down
52 changes: 52 additions & 0 deletions egs/aishell/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,58 @@ We have a tutorial in [sherpa](https://github.com/k2-fsa/sherpa) about how
to use the pre-trained model for non-streaming ASR. See
<https://k2-fsa.github.io/sherpa/offline_asr/conformer/aishell.html>


#### Pruned transducer stateless 2

See https://github.com/k2-fsa/icefall/pull/536

[./pruned_transducer_stateless2](./pruned_transducer_stateless2)

It uses pruned RNN-T.

| | test | dev | comment |
| -------------------- | ---- | ---- | -------------------------------------- |
| greedy search | 5.20 | 4.78 | --epoch 72 --avg 14 --max-duration 200 |
| modified beam search | 5.07 | 4.63 | --epoch 72 --avg 14 --max-duration 200 |
| fast beam search | 5.13 | 4.70 | --epoch 72 --avg 14 --max-duration 200 |

Training command is:

```bash
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1"

./pruned_transducer_stateless2/train.py \
--world-size 2 \
--num-epochs 90 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--max-duration 200 \
```

The tensorboard log is available at
https://tensorboard.dev/experiment/QI3PVzrGRrebxpbWUPwmkA/

The decoding command is:
```bash
for m in greedy_search modified_beam_search fast_beam_search ; do
./pruned_transducer_stateless2/decode.py \
--epoch 72 \
--avg 14 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 200 \
--decoding-method $m

done
```

Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/teapoly/icefall-aishell-pruned-transducer-stateless2-2022-08-18>


#### 2022-03-01

[./transducer_stateless_modified-2](./transducer_stateless_modified-2)
Expand Down
12 changes: 9 additions & 3 deletions egs/aishell/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

hyps_dict = decode_one_batch(
params=params,
Expand All @@ -389,9 +390,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))

results[lm_scale].extend(this_batch)

Expand Down Expand Up @@ -419,6 +420,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
Expand All @@ -429,7 +431,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
Expand Down Expand Up @@ -537,6 +541,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
Expand Down
12 changes: 9 additions & 3 deletions egs/aishell/ASR/conformer_mmi/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

hyps_dict = decode_one_batch(
params=params,
Expand All @@ -401,9 +402,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))

results[lm_scale].extend(this_batch)

Expand Down Expand Up @@ -431,6 +432,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
Expand All @@ -441,7 +443,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
Expand Down Expand Up @@ -556,6 +560,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
Expand Down
1 change: 1 addition & 0 deletions egs/aishell/ASR/pruned_transducer_stateless2/conformer.py
Loading

0 comments on commit 5ea4d94

Please sign in to comment.