diff --git a/.github/scripts/.gitignore b/.github/scripts/.gitignore
new file mode 100644
index 0000000000..672e477d8d
--- /dev/null
+++ b/.github/scripts/.gitignore
@@ -0,0 +1 @@
+piper_phonemize.html
diff --git a/.github/scripts/generate-piper-phonemize-page.py b/.github/scripts/generate-piper-phonemize-page.py
new file mode 100755
index 0000000000..3784d5fa58
--- /dev/null
+++ b/.github/scripts/generate-piper-phonemize-page.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+
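+"""
+Generate piper_phonemize.html, a simple pip "find links" page listing pre-built
+piper_phonemize wheels, so that one can run
+    pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
+The build-doc workflow copies the generated page into the published documentation.
+"""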
+
+def main():
+ prefix = (
+ "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
+ )
+ files = [
+ "piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ ]
+ with open("piper_phonemize.html", "w") as f:
+ for file in files:
+ url = prefix + file
+ f.write(f'<a href="{url}">{file}</a>\n')
+
+
+if __name__ == "__main__":
+ main()
diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh
index 7e9bd8a478..293ed66e53 100755
--- a/.github/scripts/librispeech/ASR/run.sh
+++ b/.github/scripts/librispeech/ASR/run.sh
@@ -15,9 +15,9 @@ function prepare_data() {
# cause OOM error for CI later.
mkdir -p download/lm
pushd download/lm
- wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
- wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
- wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+ wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
+ wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
+ wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
ls -lh
gunzip librispeech-lm-norm.txt.gz
diff --git a/.github/scripts/ljspeech/TTS/run.sh b/.github/scripts/ljspeech/TTS/run.sh
new file mode 100755
index 0000000000..707361782f
--- /dev/null
+++ b/.github/scripts/ljspeech/TTS/run.sh
@@ -0,0 +1,157 @@
+#!/usr/bin/env bash
+
+set -ex
+
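+# This CI script prepares a small LJSpeech subset, trains tiny low/medium/high
+# VITS models for one epoch, runs inference, exports the models to ONNX, and
+# then tests pre-trained medium/low models downloaded from Hugging Face.
+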
+python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
+python3 -m pip install espnet_tts_frontend
+python3 -m pip install numba
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/ljspeech/TTS
+
+sed -i.bak s/600/8/g ./prepare.sh
+sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
+sed -i.bak s/500/5/g ./prepare.sh
+git diff
+
+function prepare_data() {
+ # We have created a subset of the data for testing
+ #
+ mkdir download
+ pushd download
+ wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
+ tar xvf LJSpeech-1.1.tar.bz2
+ popd
+
+ ./prepare.sh
+ tree .
+}
+
+function train() {
+ pushd ./vits
+ sed -i.bak s/200/3/g ./train.py
+ git diff .
+ popd
+
+ for t in low medium high; do
+ ./vits/train.py \
+ --exp-dir vits/exp-$t \
+ --model-type $t \
+ --num-epochs 1 \
+ --save-every-n 1 \
+ --num-buckets 2 \
+ --tokens data/tokens.txt \
+ --max-duration 20
+
+ ls -lh vits/exp-$t
+ done
+}
+
+function infer() {
+ for t in low medium high; do
+ ./vits/infer.py \
+ --num-buckets 2 \
+ --model-type $t \
+ --epoch 1 \
+ --exp-dir ./vits/exp-$t \
+ --tokens data/tokens.txt \
+ --max-duration 20
+ done
+}
+
+function export_onnx() {
+ for t in low medium high; do
+ ./vits/export-onnx.py \
+ --model-type $t \
+ --epoch 1 \
+ --exp-dir ./vits/exp-$t \
+ --tokens data/tokens.txt
+
+ ls -lh vits/exp-$t/
+ done
+}
+
+function test_medium() {
+ git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12
+
+ ./vits/export-onnx.py \
+ --model-type medium \
+ --epoch 820 \
+ --exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
+ --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt
+
+ ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp
+
+ ./vits/test_onnx.py \
+ --model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
+ --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
+ --output-filename /icefall/test-medium.wav
+
+ ls -lh /icefall/test-medium.wav
+
+ d=/icefall/vits-icefall-en_US-ljspeech-medium
+ mkdir $d
+ cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
+ cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx
+
+ rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12
+
+ pushd $d
+ wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
+ tar xf espeak-ng-data.tar.bz2
+ rm espeak-ng-data.tar.bz2
+ cd ..
+ tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
+ rm -rf vits-icefall-en_US-ljspeech-medium
+ ls -lh *.tar.bz2
+ popd
+}
+
+function test_low() {
+ git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12
+
+ ./vits/export-onnx.py \
+ --model-type low \
+ --epoch 1600 \
+ --exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
+ --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt
+
+ ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp
+
+ ./vits/test_onnx.py \
+ --model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
+ --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
+ --output-filename /icefall/test-low.wav
+
+ ls -lh /icefall/test-low.wav
+
+ d=/icefall/vits-icefall-en_US-ljspeech-low
+ mkdir $d
+ cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
+ cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx
+
+ rm -rf icefall-tts-ljspeech-vits-low-2024-03-12
+
+ pushd $d
+ wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
+ tar xf espeak-ng-data.tar.bz2
+ rm espeak-ng-data.tar.bz2
+ cd ..
+ tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
+ rm -rf vits-icefall-en_US-ljspeech-low
+ ls -lh *.tar.bz2
+ popd
+}
+
+prepare_data
+train
+infer
+export_onnx
+rm -rf vits/exp-{low,medium,high}
+test_medium
+test_low
diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml
index d7fe2c9643..c622476f2d 100644
--- a/.github/workflows/build-doc.yml
+++ b/.github/workflows/build-doc.yml
@@ -56,11 +56,14 @@ jobs:
- name: Build doc
shell: bash
run: |
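+ # Generate piper_phonemize.html so that it gets published together with the docs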
+ .github/scripts/generate-piper-phonemize-page.py
cd docs
python3 -m pip install -r ./requirements.txt
make html
touch build/html/.nojekyll
+ cp -v ../piper_phonemize.html ./build/html/
+
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:
diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml
new file mode 100644
index 0000000000..25402275b4
--- /dev/null
+++ b/.github/workflows/ljspeech.yml
@@ -0,0 +1,102 @@
+name: ljspeech
+
+on:
+ push:
+ branches:
+ - master
+
+ pull_request:
+ branches:
+ - master
+
+ workflow_dispatch:
+
+concurrency:
+ group: ljspeech-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ generate_build_matrix:
+ if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+ # see https://github.com/pytorch/pytorch/pull/50633
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Generating build matrix
+ id: set-matrix
+ run: |
+ # outputting for debugging purposes
+ python ./.github/scripts/docker/generate_build_matrix.py
+ MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+ echo "::set-output name=matrix::${MATRIX}"
+
+ ljspeech:
+ needs: generate_build_matrix
+ name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Free space
+ shell: bash
+ run: |
+ ls -lh
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+ echo "pwd: $PWD"
+ echo "github.workspace ${{ github.workspace }}"
+
+ - name: Run tests
+ uses: addnab/docker-run-action@v3
+ with:
+ image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ options: |
+ --volume ${{ github.workspace }}/:/icefall
+ shell: bash
+ run: |
+ export PYTHONPATH=/icefall:$PYTHONPATH
+ cd /icefall
+ git config --global --add safe.directory /icefall
+
+ .github/scripts/ljspeech/TTS/run.sh
+
+ - name: display files
+ shell: bash
+ run: |
+ ls -lh
+
+ - uses: actions/upload-artifact@v4
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
+ with:
+ name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
+ path: ./*.wav
+
+ - uses: actions/upload-artifact@v4
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
+ with:
+ name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
+ path: ./*.wav
+
+ - name: Release exported onnx models
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
+ uses: svenstaro/upload-release-action@v2
+ with:
+ file_glob: true
+ overwrite: true
+ file: vits-icefall-*.tar.bz2
+ repo_name: k2-fsa/sherpa-onnx
+ repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+ tag: tts-models
+
diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
index 323d0adfc8..9499a3aea2 100644
--- a/docs/source/recipes/TTS/ljspeech/vits.rst
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -13,6 +13,14 @@ with the `LJSpeech `_ dataset.
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_
+Install extra dependencies
+--------------------------
+
+.. code-block:: bash
+
+ pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
+ pip install numba espnet_tts_frontend
+
Data preparation
----------------
@@ -56,7 +64,8 @@ Training
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
- --tokens data/tokens.txt
+ --tokens data/tokens.txt \
+ --model-type high \
--max-duration 500
.. note::
@@ -64,6 +73,11 @@ Training
You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``.
+.. warning::
+
+ If you want a model that runs faster on CPU, please use ``--model-type low``
+ or ``--model-type medium``.
+
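+For example, a ``medium`` model can be trained by adapting the command above
+(the values below are only illustrative):
+
+.. code-block:: bash
+
+ ./vits/train.py \
+ --world-size 4 \
+ --num-epochs 1000 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir vits/exp-medium \
+ --tokens data/tokens.txt \
+ --model-type medium \
+ --max-duration 500
+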
.. note::
The training can take a long time (usually a couple of days).
@@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
Export models
-------------
-Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
-``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
+``vits-epoch-*.onnx``.
.. code-block:: bash
@@ -120,4 +134,68 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:
- - ``_
+ - ``--model-type=high``: ``_
+ - ``--model-type=medium``: ``_
+ - ``--model-type=low``: ``_
+
+Usage in sherpa-onnx
+--------------------
+
+The following describes how to test the exported ONNX model in `sherpa-onnx`_.
+
+.. hint::
+
+ `sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
+ Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
+
+ We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
+ Please refer to ``_
+ for more documentation.
+
+Install sherpa-onnx
+^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+ pip install sherpa-onnx
+
+To check that you have installed `sherpa-onnx`_ successfully, please run:
+
+.. code-block:: bash
+
+ which sherpa-onnx-offline-tts
+ sherpa-onnx-offline-tts --help
+
+Download lexicon files
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+ cd /tmp
+ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
+ tar xf espeak-ng-data.tar.bz2
+
+Run sherpa-onnx
+^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+ cd egs/ljspeech/TTS
+
+ sherpa-onnx-offline-tts \
+ --vits-model=vits/exp/vits-epoch-1000.onnx \
+ --vits-tokens=data/tokens.txt \
+ --vits-data-dir=/tmp/espeak-ng-data \
+ --num-threads=1 \
+ --output-filename=./high.wav \
+ "Ask not what your country can do for you; ask what you can do for your country."
+
+.. hint::
+
+ You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
+ while it is being generated.
+
+You should get a file ``high.wav`` after running the above command.
+
+Congratulations! You have successfully trained and exported a text-to-speech
+model and run it with `sherpa-onnx`_.
diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md
index b547191625..d088072a71 100644
--- a/egs/aishell/ASR/README.md
+++ b/egs/aishell/ASR/README.md
@@ -19,7 +19,9 @@ The following table lists the differences among them.
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
-| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 |
+| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
+| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
+
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index b7be89bc8e..13be69534f 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
- log "Stage 11: Train RNN LM model"
+ log "Stage 12: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \
--world-size 1 \
diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh
index 301ab01117..996a1da2d4 100755
--- a/egs/alimeeting/ASR/prepare.sh
+++ b/egs/alimeeting/ASR/prepare.sh
@@ -15,7 +15,7 @@ perturb_speed=true
#
# - $dl_dir/alimeeting
# This directory contains the following files downloaded from
-# https://openslr.org/62/
+# https://openslr.org/119/
#
# - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz
diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh
index 1098840f85..15c20692da 100755
--- a/egs/alimeeting/ASR_v2/prepare.sh
+++ b/egs/alimeeting/ASR_v2/prepare.sh
@@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
#
# - $dl_dir/alimeeting
# This directory contains the following files downloaded from
-# https://openslr.org/62/
+# https://openslr.org/119/
#
# - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
index da8e620349..91220bd112 100644
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
@@ -232,7 +232,7 @@ def train_dataloaders(
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
- CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
index e23c9b1b7e..4985f3f4c3 100644
--- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
@@ -232,7 +232,7 @@ def train_dataloaders(
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
- CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md
index 80be5a3155..7b112c12c8 100644
--- a/egs/ljspeech/TTS/README.md
+++ b/egs/ljspeech/TTS/README.md
@@ -1,10 +1,10 @@
# Introduction
-This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
-A transcription is provided for each clip.
+This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
+A transcription is provided for each clip.
Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.
-The texts were published between 1884 and 1964, and are in the public domain.
+The texts were published between 1884 and 1964, and are in the public domain.
The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.
The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
@@ -35,4 +35,69 @@ To inference, use:
--exp-dir vits/exp \
--epoch 1000 \
--tokens data/tokens.txt
-```
\ No newline at end of file
+```
+
+## Quality vs speed
+
+If you feel that the trained model is slow at runtime, you can specify the
+argument `--model-type` during training. Possible values are:
+
+ - `low`, means **low** quality. The resulting model is very small in file size
+ and runs very fast. The following is a wave file generated by a `low` quality model
+
+ https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633
+
+ The text is `Ask not what your country can do for you; ask what you can do for your country.`
+
+ The exported onnx model has a file size of ``26.8 MB`` (float32).
+
+ - `medium`, means **medium** quality.
+ The following is a wave file generated by a `medium` quality model
+
+ https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac
+
+ The text is `Ask not what your country can do for you; ask what you can do for your country.`
+
+ The exported onnx model has a file size of ``70.9 MB`` (float32).
+
+ - `high`, means **high** quality. This is the default value.
+
+ The following is a wave file generated by a `high` quality model
+
+ https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc
+
+ The text is `Ask not what your country can do for you; ask what you can do for your country.`
+
+ The exported onnx model has a file size of ``113 MB`` (float32).
+
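+As a concrete sketch (the epoch number and `exp-dir` below are only illustrative),
+you can export a trained model of a given `--model-type` to ONNX and synthesize a
+test wave file with:
+
+```bash
+# The exported model is written to <exp-dir>/vits-epoch-<epoch>.onnx
+./vits/export-onnx.py \
+ --model-type medium \
+ --epoch 1000 \
+ --exp-dir vits/exp-medium \
+ --tokens data/tokens.txt
+
+./vits/test_onnx.py \
+ --model-filename vits/exp-medium/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt \
+ --output-filename ./test-medium.wav
+```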
+
+A pre-trained `low` model trained using 4xV100 32GB GPUs with the following command can be found at <https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>
+
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+./vits/train.py \
+ --world-size 4 \
+ --num-epochs 1601 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir vits/exp \
+ --model-type low \
+ --max-duration 800
+```
+
+A pre-trained `medium` model trained using 4xV100 32GB GPUs with the following command can be found at <https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>
+
+```bash
+export CUDA_VISIBLE_DEVICES=4,5,6,7
+./vits/train.py \
+ --world-size 4 \
+ --num-epochs 1000 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir vits/exp-medium \
+ --model-type medium \
+ --max-duration 500
+
+# (Note it is killed after `epoch-820.pt`)
+```
diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
index 08fe7430ef..4ba88604ce 100755
--- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
+++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
@@ -23,7 +23,11 @@
import logging
from pathlib import Path
-import tacotron_cleaner.cleaners
+try:
+ import tacotron_cleaner.cleaners
+except ModuleNotFoundError as ex:
+ raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n")
+
from lhotse import CutSet, load_manifest
from piper_phonemize import phonemize_espeak
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
index cbf27bd423..9ed0f93fde 100755
--- a/egs/ljspeech/TTS/prepare.sh
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -28,7 +28,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: build monotonic_align lib"
if [ ! -d vits/monotonic_align/build ]; then
cd vits/monotonic_align
- python setup.py build_ext --inplace
+ python3 setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib already built"
@@ -54,7 +54,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LJSpeech manifest"
# We assume that you have downloaded the LJSpeech corpus
- # to $dl_dir/LJSpeech
+ # to $dl_dir/LJSpeech-1.1
mkdir -p data/manifests
if [ ! -e data/manifests/.ljspeech.done ]; then
lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
@@ -82,8 +82,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for LJSpeech"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
- # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
- # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+ # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html,
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
./local/prepare_tokens_ljspeech.py
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index 58b1663684..0740757c06 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -25,9 +25,8 @@
--exp-dir vits/exp \
--tokens data/tokens.txt
-It will generate two files inside vits/exp:
+It will generate one file inside vits/exp:
- vits-epoch-1000.onnx
- - vits-epoch-1000.int8.onnx (quantizated model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
@@ -40,7 +39,6 @@
import onnx
import torch
import torch.nn as nn
-from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
@@ -75,6 +73,16 @@ def get_parser():
help="""Path to vocabulary.""",
)
+ parser.add_argument(
+ "--model-type",
+ type=str,
+ default="high",
+ choices=["low", "medium", "high"],
+ help="""If not empty, valid values are: low, medium, high.
+ It controls the model size. low -> runs faster.
+ """,
+ )
+
return parser
@@ -136,7 +144,7 @@ def forward(
Return a tuple containing:
- audio, generated wavform tensor, (B, T_wav)
"""
- audio, _, _ = self.model.inference(
+ audio, _, _ = self.model.generator.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
@@ -198,6 +206,11 @@ def export_model_onnx(
},
)
+ if model.model.spks is None:
+ num_speakers = 1
+ else:
+ num_speakers = model.model.spks
+
meta_data = {
"model_type": "vits",
"version": "1",
@@ -206,8 +219,8 @@ def export_model_onnx(
"language": "English",
"voice": "en-us", # Choose your language appropriately
"has_espeak": 1,
- "n_speakers": 1,
- "sample_rate": 22050, # Must match the real sample rate
+ "n_speakers": num_speakers,
+ "sample_rate": model.model.sampling_rate, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")
@@ -233,14 +246,13 @@ def main():
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
- model = model.generator
model.to("cpu")
model.eval()
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
- logging.info(f"generator parameters: {num_param}")
+ logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")
suffix = f"epoch-{params.epoch}"
@@ -256,18 +268,6 @@ def main():
)
logging.info(f"Exported generator to {model_filename}")
- # Generate int8 quantization models
- # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-
- logging.info("Generate int8 quantization models")
-
- model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
- quantize_dynamic(
- model_input=model_filename,
- model_output=model_filename_int8,
- weight_type=QuantType.QUInt8,
- )
-
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py
index 66c8cedb19..b9add9e828 100644
--- a/egs/ljspeech/TTS/vits/generator.py
+++ b/egs/ljspeech/TTS/vits/generator.py
@@ -189,7 +189,7 @@ def __init__(
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
- assert global_channels > 0
+ assert global_channels > 0, global_channels
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None
diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py
index 9e7c71c6dc..7be76e3151 100755
--- a/egs/ljspeech/TTS/vits/infer.py
+++ b/egs/ljspeech/TTS/vits/infer.py
@@ -72,6 +72,16 @@ def get_parser():
help="""Path to vocabulary.""",
)
+ parser.add_argument(
+ "--model-type",
+ type=str,
+ default="high",
+ choices=["low", "medium", "high"],
+ help="""If not empty, valid values are: low, medium, high.
+ It controls the model size. low -> runs faster.
+ """,
+ )
+
return parser
@@ -94,6 +104,7 @@ def infer_dataset(
tokenizer:
Used to convert text to phonemes.
"""
+
# Background worker save audios to disk.
def _save_worker(
batch_size: int,
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
index 2b35654f51..5dc3641e59 100644
--- a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
+++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
@@ -10,7 +10,11 @@
import numpy as np
import torch
-from numba import njit, prange
+
+try:
+ from numba import njit, prange
+except ModuleNotFoundError as ex:
+ raise RuntimeError(f"{ex}/nPlease run\n pip install numba")
try:
from .core import maximum_path_c
diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py
new file mode 100755
index 0000000000..1de10f012b
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/test_model.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from vits import VITS
+
+
+def test_model_type(model_type):
+ tokens = "./data/tokens.txt"
+
+ params = get_params()
+
+ tokenizer = Tokenizer(tokens)
+ params.blank_id = tokenizer.pad_id
+ params.vocab_size = tokenizer.vocab_size
+ params.model_type = model_type
+
+ model = get_model(params)
+ generator = model.generator
+
+ num_param = sum([p.numel() for p in generator.parameters()])
+ print(
+ f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
+ )
+
+
+def main():
+ test_model_type("high") # 35.63 M
+ test_model_type("low") # 7.55 M
+ test_model_type("medium") # 23.61 M
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py
index 4f46e8e6c5..b3805fadb3 100755
--- a/egs/ljspeech/TTS/vits/test_onnx.py
+++ b/egs/ljspeech/TTS/vits/test_onnx.py
@@ -54,6 +54,20 @@ def get_parser():
help="""Path to vocabulary.""",
)
+ parser.add_argument(
+ "--text",
+ type=str,
+ default="Ask not what your country can do for you; ask what you can do for your country.",
+ help="Text to generate speech for",
+ )
+
+ parser.add_argument(
+ "--output-filename",
+ type=str,
+ default="test_onnx.wav",
+ help="Filename to save the generated wave file.",
+ )
+
return parser
@@ -61,7 +75,7 @@ class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
- session_opts.intra_op_num_threads = 4
+ session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
@@ -72,6 +86,9 @@ def __init__(self, model_filename: str):
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+ metadata = self.model.get_modelmeta().custom_metadata_map
+ self.sample_rate = int(metadata["sample_rate"])
+
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
@@ -101,13 +118,14 @@ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Ten
def main():
args = get_parser().parse_args()
+ logging.info(vars(args))
tokenizer = Tokenizer(args.tokens)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
- text = "I went there to see the land, the people and how their system works, end quote."
+ text = args.text
tokens = tokenizer.texts_to_token_ids(
[text], intersperse_blank=True, add_sos=True, add_eos=True
)
@@ -115,8 +133,9 @@ def main():
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')
- torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
- logging.info("Saved to test_onnx.wav")
+ output_filename = args.output_filename
+ torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
+ logging.info(f"Saved to {output_filename}")
if __name__ == "__main__":
diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py
index fcbae7103f..9b21ed9cb0 100644
--- a/egs/ljspeech/TTS/vits/text_encoder.py
+++ b/egs/ljspeech/TTS/vits/text_encoder.py
@@ -92,9 +92,9 @@ def forward(
x_lengths (Tensor): Length tensor (B,).
Returns:
- Tensor: Encoded hidden representation (B, attention_dim, T_text).
- Tensor: Projected mean tensor (B, attention_dim, T_text).
- Tensor: Projected scale tensor (B, attention_dim, T_text).
+ Tensor: Encoded hidden representation (B, embed_dim, T_text).
+ Tensor: Projected mean tensor (B, embed_dim, T_text).
+ Tensor: Projected scale tensor (B, embed_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
@@ -108,6 +108,7 @@ def forward(
# encoder assume the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
+ # Note: attention_dim == embed_dim
# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
index 9a5a9090ec..3c9046adde 100644
--- a/egs/ljspeech/TTS/vits/tokenizer.py
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -18,7 +18,15 @@
from typing import Dict, List
import tacotron_cleaner.cleaners
-from piper_phonemize import phonemize_espeak
+
+try:
+ from piper_phonemize import phonemize_espeak
+except Exception as ex:
+ raise RuntimeError(
+ f"{ex}\nPlease run\n"
+ "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
+ )
+
from utils import intersperse
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
index 6589b75ff6..34b943765a 100755
--- a/egs/ljspeech/TTS/vits/train.py
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -153,6 +153,16 @@ def get_parser():
help="Whether to use half precision training.",
)
+ parser.add_argument(
+ "--model-type",
+ type=str,
+ default="high",
+ choices=["low", "medium", "high"],
+ help="""If not empty, valid values are: low, medium, high.
+ It controls the model size. low -> runs faster.
+ """,
+ )
+
return parser
@@ -189,15 +199,6 @@ def get_params() -> AttributeDict:
- feature_dim: The model input dim. It has to match the one used
in computing features.
-
- - subsampling_factor: The subsampling factor for the model.
-
- - encoder_dim: Hidden dim for multi-head attention model.
-
- - num_decoder_layers: Number of decoder layer of transformer decoder.
-
- - warm_step: The warmup period that dictates the decay of the
- scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
@@ -278,6 +279,7 @@ def get_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
+ model_type=params.model_type,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
@@ -363,7 +365,7 @@ def train_one_epoch(
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
- # used to summary the stats over iterations in one epoch
+ # used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
index 8ff868bc8b..e1a9c7b3ca 100644
--- a/egs/ljspeech/TTS/vits/tts_datamodule.py
+++ b/egs/ljspeech/TTS/vits/tts_datamodule.py
@@ -255,6 +255,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
+ num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
@@ -294,6 +295,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
+ num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py
index b4f0c21e6d..0b9575cbde 100644
--- a/egs/ljspeech/TTS/vits/vits.py
+++ b/egs/ljspeech/TTS/vits/vits.py
@@ -5,6 +5,7 @@
"""VITS module for GAN-TTS task."""
+import copy
from typing import Any, Dict, Optional, Tuple
import torch
@@ -38,6 +39,36 @@
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}
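+
+# Generator configurations selected via the --model-type command-line option.
+# They override the corresponding entries of generator_params in VITS.__init__
+# below; "low" and "medium" use a smaller decoder than "high" (the default),
+# so the resulting models are smaller and faster at inference time.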
+LOW_CONFIG = {
+ "hidden_channels": 96,
+ "decoder_upsample_scales": (8, 8, 4),
+ "decoder_channels": 256,
+ "decoder_upsample_kernel_sizes": (16, 16, 8),
+ "decoder_resblock_kernel_sizes": (3, 5, 7),
+ "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
+ "text_encoder_cnn_module_kernel": 3,
+}
+
+MEDIUM_CONFIG = {
+ "hidden_channels": 192,
+ "decoder_upsample_scales": (8, 8, 4),
+ "decoder_channels": 256,
+ "decoder_upsample_kernel_sizes": (16, 16, 8),
+ "decoder_resblock_kernel_sizes": (3, 5, 7),
+ "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
+ "text_encoder_cnn_module_kernel": 3,
+}
+
+HIGH_CONFIG = {
+ "hidden_channels": 192,
+ "decoder_upsample_scales": (8, 8, 2, 2),
+ "decoder_channels": 512,
+ "decoder_upsample_kernel_sizes": (16, 16, 4, 4),
+ "decoder_resblock_kernel_sizes": (3, 7, 11),
+ "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ "text_encoder_cnn_module_kernel": 5,
+}
+
class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
@@ -49,6 +80,7 @@ def __init__(
feature_dim: int = 513,
sampling_rate: int = 22050,
generator_type: str = "vits_generator",
+ model_type: str = "",
generator_params: Dict[str, Any] = {
"hidden_channels": 192,
"spks": None,
@@ -155,12 +187,13 @@ def __init__(
"""Initialize VITS module.
Args:
- idim (int): Input vocabrary size.
+ idim (int): Input vocabulary size.
odim (int): Acoustic feature dimension. The actual output channels will
be 1 since VITS is the end-to-end text-to-wave model but for the
compatibility odim is used to indicate the acoustic feature dimension.
sampling_rate (int): Sampling rate, not used for the training but it will
be referred in saving waveform during the inference.
+ model_type (str): If not empty, must be one of: low, medium, high
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
@@ -181,6 +214,24 @@ def __init__(
"""
super().__init__()
+ generator_params = copy.deepcopy(generator_params)
+ discriminator_params = copy.deepcopy(discriminator_params)
+ generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
+ discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
+ feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
+ mel_loss_params = copy.deepcopy(mel_loss_params)
+
+ if model_type != "":
+ assert model_type in ("low", "medium", "high"), model_type
+ if model_type == "low":
+ generator_params.update(LOW_CONFIG)
+ elif model_type == "medium":
+ generator_params.update(MEDIUM_CONFIG)
+ elif model_type == "high":
+ generator_params.update(HIGH_CONFIG)
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}")
+
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":
diff --git a/egs/mdcc/ASR/README.md b/egs/mdcc/ASR/README.md
new file mode 100644
index 0000000000..112845b734
--- /dev/null
+++ b/egs/mdcc/ASR/README.md
@@ -0,0 +1,19 @@
+# Introduction
+
+The Multi-Domain Cantonese Corpus (MDCC) consists of 73.6 hours of clean read speech paired with
+transcripts, collected from Cantonese audiobooks from Hong Kong. It comprises the philosophy,
+politics, education, culture, lifestyle and family domains, covering a wide range of topics.
+
+Manuscript can be found at: https://arxiv.org/abs/2201.02419
+
+# Transducers
+
+
+
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|-----------------------------|
+| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
+
+The decoder is modified from the paper
+[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
diff --git a/egs/mdcc/ASR/RESULTS.md b/egs/mdcc/ASR/RESULTS.md
new file mode 100644
index 0000000000..ff7ddc9579
--- /dev/null
+++ b/egs/mdcc/ASR/RESULTS.md
@@ -0,0 +1,41 @@
+## Results
+
+#### Zipformer
+
+See
+
+[./zipformer](./zipformer)
+
+##### normal-scaled model, number of model parameters: 74470867, i.e., 74.47 M
+
+| | test | valid | comment |
+|------------------------|------|-------|-----------------------------------------|
+| greedy search | 7.45 | 7.51 | --epoch 45 --avg 35 |
+| modified beam search | 6.68 | 6.73 | --epoch 45 --avg 35 |
+| fast beam search | 7.22 | 7.28 | --epoch 45 --avg 35 |
+
+The training command:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer/train.py \
+ --world-size 4 \
+ --start-epoch 1 \
+ --num-epochs 50 \
+ --use-fp16 1 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 1000
+```
+
+The decoding command:
+
+```
+ ./zipformer/decode.py \
+ --epoch 45 \
+ --avg 35 \
+ --exp-dir ./zipformer/exp \
+ --decoding-method greedy_search # modified_beam_search
+```
+
+The pretrained model is available at: https://huggingface.co/zrjin/icefall-asr-mdcc-zipformer-2024-03-11/
\ No newline at end of file
diff --git a/egs/mdcc/ASR/local/compile_hlg.py b/egs/mdcc/ASR/local/compile_hlg.py
new file mode 120000
index 0000000000..471aa7fb40
--- /dev/null
+++ b/egs/mdcc/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/local/compile_hlg_using_openfst.py b/egs/mdcc/ASR/local/compile_hlg_using_openfst.py
new file mode 120000
index 0000000000..d34edd7f30
--- /dev/null
+++ b/egs/mdcc/ASR/local/compile_hlg_using_openfst.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg_using_openfst.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/local/compile_lg.py b/egs/mdcc/ASR/local/compile_lg.py
new file mode 120000
index 0000000000..462d6d3fb9
--- /dev/null
+++ b/egs/mdcc/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/local/compute_fbank_mdcc.py b/egs/mdcc/ASR/local/compute_fbank_mdcc.py
new file mode 100755
index 0000000000..647b211270
--- /dev/null
+++ b/egs/mdcc/ASR/local/compute_fbank_mdcc.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the MDCC dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ WhisperFbank,
+ WhisperFbankConfig,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor, str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_mdcc(
+ num_mel_bins: int = 80,
+ perturb_speed: bool = False,
+ whisper_fbank: bool = False,
+ output_dir: str = "data/fbank",
+):
+ src_dir = Path("data/manifests")
+ output_dir = Path(output_dir)
+ num_jobs = min(15, os.cpu_count())
+
+ dataset_parts = (
+ "train",
+ "valid",
+ "test",
+ )
+ prefix = "mdcc"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+ if whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ )
+ else:
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if "train" in partition and perturb_speed:
+ logging.info("Doing speed perturb")
+ cut_set = (
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+ parser.add_argument(
+ "--perturb-speed",
+ type=str2bool,
+ default=False,
+ help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
+ )
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=False,
+ help="Use WhisperFbank instead of Fbank. Default: False.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="data/fbank",
+ help="Output directory. Default: data/fbank.",
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ compute_fbank_mdcc(
+ num_mel_bins=args.num_mel_bins,
+ perturb_speed=args.perturb_speed,
+ whisper_fbank=args.whisper_fbank,
+ output_dir=args.output_dir,
+ )
diff --git a/egs/mdcc/ASR/local/display_manifest_statistics.py b/egs/mdcc/ASR/local/display_manifest_statistics.py
new file mode 100755
index 0000000000..27cf8c9439
--- /dev/null
+++ b/egs/mdcc/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+See the function `remove_short_and_long_utt()` in transducer/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
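+ # Only the last assignment below takes effect; keep the manifest you want
+ # to inspect and comment out the others.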
+ path = "./data/fbank/mdcc_cuts_train.jsonl.gz"
+ path = "./data/fbank/mdcc_cuts_valid.jsonl.gz"
+ path = "./data/fbank/mdcc_cuts_test.jsonl.gz"
+
+ cuts = load_manifest_lazy(path)
+ cuts.describe(full=True)
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+data/fbank/mdcc_cuts_train.jsonl.gz (with speed perturbation)
+_________________________________________
+_ Cuts count: _ 195360 _
+_________________________________________
+_ Total duration (hh:mm:ss) _ 173:44:59 _
+_________________________________________
+_ mean _ 3.2 _
+_________________________________________
+_ std _ 2.1 _
+_________________________________________
+_ min _ 0.2 _
+_________________________________________
+_ 25% _ 1.8 _
+_________________________________________
+_ 50% _ 2.7 _
+_________________________________________
+_ 75% _ 4.0 _
+_________________________________________
+_ 99% _ 11.0 _
+_________________________________________
+_ 99.5% _ 12.4 _
+_________________________________________
+_ 99.9% _ 14.8 _
+_________________________________________
+_ max _ 16.7 _
+_________________________________________
+_ Recordings available: _ 195360 _
+_________________________________________
+_ Features available: _ 195360 _
+_________________________________________
+_ Supervisions available: _ 195360 _
+_________________________________________
+
+data/fbank/mdcc_cuts_valid.jsonl.gz
+________________________________________
+_ Cuts count: _ 5663 _
+________________________________________
+_ Total duration (hh:mm:ss) _ 05:03:12 _
+________________________________________
+_ mean _ 3.2 _
+________________________________________
+_ std _ 2.0 _
+________________________________________
+_ min _ 0.3 _
+________________________________________
+_ 25% _ 1.8 _
+________________________________________
+_ 50% _ 2.7 _
+________________________________________
+_ 75% _ 4.0 _
+________________________________________
+_ 99% _ 10.9 _
+________________________________________
+_ 99.5% _ 12.3 _
+________________________________________
+_ 99.9% _ 14.4 _
+________________________________________
+_ max _ 14.8 _
+________________________________________
+_ Recordings available: _ 5663 _
+________________________________________
+_ Features available: _ 5663 _
+________________________________________
+_ Supervisions available: _ 5663 _
+________________________________________
+
+data/fbank/mdcc_cuts_test.jsonl.gz
+________________________________________
+_ Cuts count: _ 12492 _
+________________________________________
+_ Total duration (hh:mm:ss) _ 11:00:31 _
+________________________________________
+_ mean _ 3.2 _
+________________________________________
+_ std _ 2.0 _
+________________________________________
+_ min _ 0.2 _
+________________________________________
+_ 25% _ 1.8 _
+________________________________________
+_ 50% _ 2.7 _
+________________________________________
+_ 75% _ 4.0 _
+________________________________________
+_ 99% _ 10.5 _
+________________________________________
+_ 99.5% _ 12.1 _
+________________________________________
+_ 99.9% _ 14.0 _
+________________________________________
+_ max _ 14.8 _
+________________________________________
+_ Recordings available: _ 12492 _
+________________________________________
+_ Features available: _ 12492 _
+________________________________________
+_ Supervisions available: _ 12492 _
+________________________________________
+
+"""
diff --git a/egs/mdcc/ASR/local/prepare_char.py b/egs/mdcc/ASR/local/prepare_char.py
new file mode 120000
index 0000000000..42743b5449
--- /dev/null
+++ b/egs/mdcc/ASR/local/prepare_char.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/prepare_char.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/local/prepare_char_lm_training_data.py b/egs/mdcc/ASR/local/prepare_char_lm_training_data.py
new file mode 120000
index 0000000000..2374cafddb
--- /dev/null
+++ b/egs/mdcc/ASR/local/prepare_char_lm_training_data.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/prepare_char_lm_training_data.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/local/prepare_lang.py b/egs/mdcc/ASR/local/prepare_lang.py
new file mode 120000
index 0000000000..bee8d5f036
--- /dev/null
+++ b/egs/mdcc/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/local/prepare_lang_fst.py b/egs/mdcc/ASR/local/prepare_lang_fst.py
new file mode 120000
index 0000000000..c5787c5340
--- /dev/null
+++ b/egs/mdcc/ASR/local/prepare_lang_fst.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_fst.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/local/preprocess_mdcc.py b/egs/mdcc/ASR/local/preprocess_mdcc.py
new file mode 100755
index 0000000000..cd0dc7de82
--- /dev/null
+++ b/egs/mdcc/ASR/local/preprocess_mdcc.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes a text file "data/lang_char/text" as input. The file consists of
+lines, each containing a transcript. The script applies text normalization and
+generates the following files in the directory "data/lang_char":
+ - text_norm
+ - words.txt
+ - words_no_ids.txt
+ - text_words_segmentation
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List
+
+import pycantonese
+from tqdm.auto import tqdm
+
+from icefall.utils import is_cjk
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Prepare char lexicon",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input-file",
+ "-i",
+ default="data/lang_char/text",
+ type=str,
+ help="The input text file",
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ default="data/lang_char",
+ type=str,
+ help="The output directory",
+ )
+ return parser
+
+
+def get_norm_lines(lines: List[str]) -> List[str]:
+ def _text_norm(text: str) -> str:
+ # to cope with the protocol for transcription:
+ # When taking notes, the annotators adhere to the following guidelines:
+ # 1) If the audio contains pure music, the annotators mark the label
+ # "(music)" in the file name of its transcript. 2) If the utterance
+ # contains one or several sentences with background music or noise, the
+ # annotators mark the label "(music)" before each sentence in the transcript.
+ # 3) The annotators use {} symbols to enclose words they are uncertain
+ # about, for example, {梁佳佳},我是{}人.
+
+ # here we manually fix some errors in the transcript
+
+ return (
+ text.strip()
+ .replace("(music)", "")
+ .replace("(music", "")
+ .replace("{", "")
+ .replace("}", "")
+ .replace("BB所以就指腹為親喇", "BB 所以就指腹為親喇")
+ .upper()
+ )
+
+ return [_text_norm(line) for line in lines]
+
+
+def get_word_segments(lines: List[str]) -> List[str]:
+ # the current pycantonese segmenter does not handle the case when the input
+ # is code switching, so we need to handle it separately
+
+ new_lines = []
+
+ for line in tqdm(lines, desc="Segmenting lines"):
+ try:
+ # code switching
+ if len(line.strip().split(" ")) > 1:
+ segments = []
+ for segment in line.strip().split(" "):
+ if segment.strip() == "":
+ continue
+ try:
+ if not is_cjk(segment[0]): # en segment
+ segments.append(segment)
+ else: # zh segment
+ segments.extend(pycantonese.segment(segment))
+ except Exception as e:
+ logging.error(f"Failed to process segment: {segment}")
+ raise e
+ new_lines.append(" ".join(segments) + "\n")
+ # not code switching
+ else:
+ new_lines.append(" ".join(pycantonese.segment(line)) + "\n")
+ except Exception as e:
+ logging.error(f"Failed to process line: {line}")
+ raise e
+ return new_lines
+
+
+def get_words(lines: List[str]) -> List[str]:
+ words = set()
+ for line in tqdm(lines, desc="Getting words"):
+ words.update(line.strip().split(" "))
+ return list(words)
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = Path(args.input_file)
+ output_dir = Path(args.output_dir)
+
+ assert output_dir.is_dir(), f"{output_dir} does not exist"
+ assert input_file.is_file(), f"{input_file} does not exist"
+
+ lines = input_file.read_text(encoding="utf-8").strip().split("\n")
+
+ norm_lines = get_norm_lines(lines)
+ with open(output_dir / "text_norm", "w+", encoding="utf-8") as f:
+ f.writelines([line + "\n" for line in norm_lines])
+
+ text_words_segments = get_word_segments(norm_lines)
+ with open(output_dir / "text_words_segmentation", "w+", encoding="utf-8") as f:
+ f.writelines(text_words_segments)
+
+ words = get_words(text_words_segments)[1:] # remove "\n" from words
+ with open(output_dir / "words_no_ids.txt", "w+", encoding="utf-8") as f:
+ f.writelines([word + "\n" for word in sorted(words)])
+
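+ # Add the special symbols used by icefall's lexicon/LM preparation:
+ # <eps> and the disambiguation symbol #0 are needed by the FSTs, while
+ # <s> and </s> mark sentence boundaries for LM training.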
+ words = (
+ ["", "!SIL", "", ""]
+ + sorted(words)
+ + ["#0", "", "<\s>"]
+ )
+
+ with open(output_dir / "words.txt", "w+", encoding="utf-8") as f:
+ f.writelines([f"{word} {i}\n" for i, word in enumerate(words)])
diff --git a/egs/mdcc/ASR/local/text2segments.py b/egs/mdcc/ASR/local/text2segments.py
new file mode 100755
index 0000000000..8ce7ab7e58
--- /dev/null
+++ b/egs/mdcc/ASR/local/text2segments.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
+# 2022 Xiaomi Corp. (authors: Weiji Zhuang)
+# 2024 Xiaomi Corp. (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input the transcript file "text" of MDCC:
+ - text
+and generates an output file containing the word-segmented transcripts:
+ - text_words_segmentation
+"""
+
+import argparse
+from typing import List
+
+import pycantonese
+from tqdm.auto import tqdm
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Cantonese Word Segmentation for text",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input-file",
+ "-i",
+ default="data/lang_char/text",
+ type=str,
+ help="the input text file for MDCC",
+ )
+ parser.add_argument(
+ "--output-file",
+ "-o",
+ default="data/lang_char/text_words_segmentation",
+ type=str,
+ help="the text implemented with words segmenting for MDCC",
+ )
+
+ return parser
+
+
+def get_word_segments(lines: List[str]) -> List[str]:
+ return [
+ " ".join(pycantonese.segment(line)) + "\n"
+ for line in tqdm(lines, desc="Segmenting lines")
+ ]
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = args.input_file
+ output_file = args.output_file
+
+ with open(input_file, "r", encoding="utf-8") as fr:
+ lines = fr.readlines()
+
+ new_lines = get_word_segments(lines)
+
+ with open(output_file, "w", encoding="utf-8") as fw:
+ fw.writelines(new_lines)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mdcc/ASR/local/text2token.py b/egs/mdcc/ASR/local/text2token.py
new file mode 120000
index 0000000000..81e459d69e
--- /dev/null
+++ b/egs/mdcc/ASR/local/text2token.py
@@ -0,0 +1 @@
+../../../aidatatang_200zh/ASR/local/text2token.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/prepare.sh b/egs/mdcc/ASR/prepare.sh
new file mode 100755
index 0000000000..f4d9bc47e1
--- /dev/null
+++ b/egs/mdcc/ASR/prepare.sh
@@ -0,0 +1,308 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+perturb_speed=true
+
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/mdcc
+# |-- README.md
+# |-- audio/
+# |-- clip_info_rthk.csv
+# |-- cnt_asr_metadata_full.csv
+# |-- cnt_asr_test_metadata.csv
+# |-- cnt_asr_train_metadata.csv
+# |-- cnt_asr_valid_metadata.csv
+# |-- data_statistic.py
+# |-- length
+# |-- podcast_447_2021.csv
+# |-- test.txt
+# |-- transcription/
+# `-- words_length
+# You can download them from:
+# https://drive.google.com/file/d/1epfYMMhXdBKA6nxPgUugb2Uj4DllSxkn/view?usp=drive_link
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/mdcc,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/mdcc $dl_dir/mdcc
+ #
+ # The directory structure is
+ # mdcc/
+ # |-- README.md
+ # |-- audio/
+ # |-- clip_info_rthk.csv
+ # |-- cnt_asr_metadata_full.csv
+ # |-- cnt_asr_test_metadata.csv
+ # |-- cnt_asr_train_metadata.csv
+ # |-- cnt_asr_valid_metadata.csv
+ # |-- data_statistic.py
+ # |-- length
+ # |-- podcast_447_2021.csv
+ # |-- test.txt
+ # |-- transcription/
+ # `-- words_length
+
+ if [ ! -d $dl_dir/mdcc/audio ]; then
+ lhotse download mdcc $dl_dir
+
+ # this will download and unzip dataset.zip to $dl_dir/
+
+ mv $dl_dir/dataset $dl_dir/mdcc
+ fi
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/musan
+ #
+ if [ ! -d $dl_dir/musan ]; then
+ lhotse download musan $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare MDCC manifest"
+ # We assume that you have downloaded the MDCC corpus
+ # to $dl_dir/mdcc
+ if [ ! -f data/manifests/.mdcc_manifests.done ]; then
+ log "Might take 40 minutes to traverse the directory."
+ mkdir -p data/manifests
+ lhotse prepare mdcc $dl_dir/mdcc data/manifests
+ touch data/manifests/.mdcc_manifests.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to data/musan
+ if [ ! -f data/manifests/.musan_manifests.done ]; then
+ log "It may take 6 minutes"
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan_manifests.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Compute fbank for MDCC"
+ if [ ! -f data/fbank/.mdcc.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_mdcc.py --perturb-speed ${perturb_speed}
+ touch data/fbank/.mdcc.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for musan"
+ if [ ! -f data/fbank/.musan.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done
+ fi
+fi
+
+lang_char_dir=data/lang_char
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Prepare char based lang"
+ mkdir -p $lang_char_dir
+
+ # Prepare text.
+ # Note: in Linux, you can install jq with the following command:
+ # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+ # 2. chmod +x ./jq
+ # 3. cp jq /usr/bin
+ if [ ! -f $lang_char_dir/text ]; then
+ gunzip -c data/manifests/mdcc_supervisions_train.jsonl.gz \
+ |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \
+ > $lang_char_dir/train_text
+
+ cat $lang_char_dir/train_text > $lang_char_dir/text
+
+ gunzip -c data/manifests/mdcc_supervisions_valid.jsonl.gz \
+ |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \
+ > $lang_char_dir/valid_text
+
+ cat $lang_char_dir/valid_text >> $lang_char_dir/text
+
+ gunzip -c data/manifests/mdcc_supervisions_test.jsonl.gz \
+ |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \
+ > $lang_char_dir/test_text
+
+ cat $lang_char_dir/test_text >> $lang_char_dir/text
+ fi
+
+ if [ ! -f $lang_char_dir/text_words_segmentation ]; then
+ ./local/preprocess_mdcc.py --input-file $lang_char_dir/text \
+ --output-dir $lang_char_dir
+
+ mv $lang_char_dir/text $lang_char_dir/_text
+ cp $lang_char_dir/text_words_segmentation $lang_char_dir/text
+ fi
+
+ if [ ! -f $lang_char_dir/tokens.txt ]; then
+ ./local/prepare_char.py --lang-dir $lang_char_dir
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Prepare G"
+
+ mkdir -p data/lm
+
+ # Train LM on transcripts
+ if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
+ python3 ./shared/make_kn_lm.py \
+ -ngram-order 3 \
+ -text $lang_char_dir/text_words_segmentation \
+ -lm data/lm/3-gram.unpruned.arpa
+ fi
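+
+ # Quick sanity check (illustrative): the generated ARPA file begins with a
+ # "\data\" header listing the n-gram counts, e.g.
+ #   head -n 5 data/lm/3-gram.unpruned.arpa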
+
+ # We assume you have installed kaldilm, if not, please install
+ # it using: pip install kaldilm
+ if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
+ # It is used in building HLG
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_char_dir/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=3 \
+ data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
+ fi
+
+ if [ ! -f $lang_char_dir/HLG.fst ]; then
+ ./local/prepare_lang_fst.py \
+ --lang-dir $lang_char_dir \
+ --ngram-G ./data/lm/G_3_gram_char.fst.txt
+ fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Compile LG & HLG"
+
+ ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
+ ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Generate LM training data"
+
+ log "Processing char based data"
+ out_dir=data/lm_training_char
+ mkdir -p $out_dir $dl_dir/lm
+
+ if [ ! -f $dl_dir/lm/mdcc-train-word.txt ]; then
+ ./local/text2segments.py --input-file $lang_char_dir/train_text \
+ --output-file $dl_dir/lm/mdcc-train-word.txt
+ fi
+
+ # training words
+ ./local/prepare_char_lm_training_data.py \
+ --lang-char data/lang_char \
+ --lm-data $dl_dir/lm/mdcc-train-word.txt \
+ --lm-archive $out_dir/lm_data.pt
+
+ # valid words
+ if [ ! -f $dl_dir/lm/mdcc-valid-word.txt ]; then
+ ./local/text2segments.py --input-file $lang_char_dir/valid_text \
+ --output-file $dl_dir/lm/mdcc-valid-word.txt
+ fi
+
+ ./local/prepare_char_lm_training_data.py \
+ --lang-char data/lang_char \
+ --lm-data $dl_dir/lm/mdcc-valid-word.txt \
+ --lm-archive $out_dir/lm_data_valid.pt
+
+ # test words
+ if [ ! -f $dl_dir/lm/mdcc-test-word.txt ]; then
+ ./local/text2segments.py --input-file $lang_char_dir/test_text \
+ --output-file $dl_dir/lm/mdcc-test-word.txt
+ fi
+
+ ./local/prepare_char_lm_training_data.py \
+ --lang-char data/lang_char \
+ --lm-data $dl_dir/lm/mdcc-test-word.txt \
+ --lm-archive $out_dir/lm_data_test.pt
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+ log "Stage 9: Sort LM training data"
+ # Sort LM training data by sentence length in descending order
+ # for ease of training.
+ #
+ # Sentence length equals the number of tokens
+ # in a sentence.
+
+ out_dir=data/lm_training_char
+ mkdir -p $out_dir
+ ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data.pt \
+ --out-lm-data $out_dir/sorted_lm_data.pt \
+ --out-statistics $out_dir/statistics.txt
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data_valid.pt \
+ --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+ --out-statistics $out_dir/statistics-valid.txt
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data_test.pt \
+ --out-lm-data $out_dir/sorted_lm_data-test.pt \
+ --out-statistics $out_dir/statistics-test.txt
+fi
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+ log "Stage 12: Train RNN LM model"
+ python ../../../icefall/rnn_lm/train.py \
+ --start-epoch 0 \
+ --world-size 1 \
+ --num-epochs 20 \
+ --use-fp16 0 \
+ --embedding-dim 512 \
+ --hidden-dim 512 \
+ --num-layers 2 \
+ --batch-size 400 \
+ --exp-dir rnnlm_char/exp \
+ --lm-data $out_dir/sorted_lm_data.pt \
+ --lm-data-valid $out_dir/sorted_lm_data-valid.pt \
+ --vocab-size 4336 \
+ --master-port 12345
+fi
diff --git a/egs/mdcc/ASR/shared b/egs/mdcc/ASR/shared
new file mode 120000
index 0000000000..4c5e91438c
--- /dev/null
+++ b/egs/mdcc/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/__init__.py b/egs/mdcc/ASR/zipformer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/egs/mdcc/ASR/zipformer/asr_datamodule.py b/egs/mdcc/ASR/zipformer/asr_datamodule.py
new file mode 100644
index 0000000000..1f49b65206
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1,382 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+# Copyright 2024 Xiaomi Corporation (Author: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class MdccAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ def train_dataloaders(
+ self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ transforms.append(
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.info("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=(
+ OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures()
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ cuts_train = load_manifest_lazy(
+ self.args.manifest_dir / "mdcc_cuts_train.jsonl.gz"
+ )
+ return cuts_train
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get valid cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_valid.jsonl.gz")
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_test.jsonl.gz")
diff --git a/egs/mdcc/ASR/zipformer/beam_search.py b/egs/mdcc/ASR/zipformer/beam_search.py
new file mode 120000
index 0000000000..e24eca39f2
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/decode.py b/egs/mdcc/ASR/zipformer/decode.py
new file mode 100755
index 0000000000..ce104baf72
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/decode.py
@@ -0,0 +1,813 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Mingshuang Luo,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./zipformer/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(3) fast beam search (trivial_graph)
+./zipformer/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(4) fast beam search (LG)
+./zipformer/decode.py \
+ --epoch 30 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import MdccAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from train import add_model_arguments, get_model, get_params
+
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_char",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_LG
+ - fast_beam_search_nbest_oracle
+ If you use fast_beam_search_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_LG, or fast_beam_search_nbest_oracle.
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--ilme-scale",
+ type=float,
+ default=0.2,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for the internal language model estimation.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=1,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search";
+ if beam search with a beam size of 7 is used, it would be
+ "beam_7".
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+ or fast_beam_search_nbest_oracle.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ x, x_lens = model.encoder_embed(feature, feature_lens)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "fast_beam_search_LG":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ ilme_scale=params.ilme_scale,
+ )
+ for hyp in hyp_tokens:
+ sentence = "".join([lexicon.word_table[i] for i in hyp])
+ hyps.append(list(sentence))
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ blank_penalty=params.blank_penalty,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ key = f"blank_penalty_{params.blank_penalty}"
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search_" + key: hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key += f"_beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ilme_scale_{params.ilme_scale}"
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}_" + key: hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+ or fast_beam_search_nbest_oracle.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut id, the reference transcript, and the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ texts = [list("".join(text.split())) for text in texts]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ graph_compiler=graph_compiler,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ this_batch.append((cut_id, ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ MdccAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "modified_beam_search",
+ "fast_beam_search",
+ "fast_beam_search_LG",
+ "fast_beam_search_nbest_oracle",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"_ilme_scale_{params.ilme_scale}"
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if "LG" in params.decoding_method:
+ lexicon = Lexicon(params.lang_dir)
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ mdcc = MdccAsrDataModule(args)
+
+ def remove_short_utt(c: Cut):
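+ # Rough estimate of the number of frames left after the encoder's
+ # convolutional front-end (about 4x subsampling); cuts that would end up
+ # with no frames are excluded from decoding.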
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ if T <= 0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+ )
+ return T > 0
+
+ valid_cuts = mdcc.valid_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_utt)
+ valid_dl = mdcc.valid_dataloaders(valid_cuts)
+
+ test_cuts = mdcc.test_cuts()
+ test_cuts = test_cuts.filter(remove_short_utt)
+ test_dl = mdcc.test_dataloaders(test_cuts)
+
+ test_sets = ["valid", "test"]
+ test_dls = [valid_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ graph_compiler=graph_compiler,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mdcc/ASR/zipformer/decode_stream.py b/egs/mdcc/ASR/zipformer/decode_stream.py
new file mode 120000
index 0000000000..b8d8ddfc4c
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decode_stream.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/decoder.py b/egs/mdcc/ASR/zipformer/decoder.py
new file mode 120000
index 0000000000..5a8018680d
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/encoder_interface.py b/egs/mdcc/ASR/zipformer/encoder_interface.py
new file mode 120000
index 0000000000..c2eaca6712
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/export-onnx-ctc.py b/egs/mdcc/ASR/zipformer/export-onnx-ctc.py
new file mode 120000
index 0000000000..f9d7563520
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/export-onnx-ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-ctc.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py
new file mode 120000
index 0000000000..652346001e
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/export-onnx-streaming.py b/egs/mdcc/ASR/zipformer/export-onnx-streaming.py
new file mode 120000
index 0000000000..2962eb7847
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/export-onnx-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/export-onnx.py b/egs/mdcc/ASR/zipformer/export-onnx.py
new file mode 120000
index 0000000000..70a15683c2
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/export.py b/egs/mdcc/ASR/zipformer/export.py
new file mode 120000
index 0000000000..dfc1bec080
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/joiner.py b/egs/mdcc/ASR/zipformer/joiner.py
new file mode 120000
index 0000000000..5b8a36332e
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/model.py b/egs/mdcc/ASR/zipformer/model.py
new file mode 120000
index 0000000000..cd7e07d72b
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/onnx_check.py b/egs/mdcc/ASR/zipformer/onnx_check.py
new file mode 120000
index 0000000000..f3dd420046
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/onnx_check.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_check.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/onnx_decode.py b/egs/mdcc/ASR/zipformer/onnx_decode.py
new file mode 100755
index 0000000000..1ed4a9fa11
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/onnx_decode.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Xiaoyu Yang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads ONNX exported models and uses them to decode the test sets.
+"""
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import MdccAsrDataModule
+from lhotse.cut import Cut
+from onnx_pretrained import OnnxModel, greedy_search
+
+from icefall.utils import setup_logger, store_transcripts, write_error_stats
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--encoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the encoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--decoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the decoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--joiner-model-filename",
+ type=str,
+ required=True,
+ help="Path to the joiner onnx model. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="Valid values are greedy_search and modified_beam_search",
+ )
+
+ return parser
+
+
+def decode_one_batch(
+ model: OnnxModel, token_table: k2.SymbolTable, batch: dict
+) -> List[List[str]]:
+ """Decode one batch and return the result.
+ Currently only greedy_search is supported.
+
+ Args:
+ model:
+ The neural model.
+ token_table:
+ Mapping ids to tokens.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
+
+ encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
+
+ hyps = greedy_search(
+ model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
+ )
+
+ hyps = [[token_table[h] for h in hyp] for hyp in hyps]
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ model:
+ The neural model.
+ token_table:
+ Mapping ids to tokens.
+
+ Returns:
+ - A list of tuples. Each tuple contains three elements:
+ - cut_id,
+ - reference transcript,
+ - predicted result.
+ - The total duration (in seconds) of the dataset.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ log_interval = 10
+ total_duration = 0
+
+ results = []
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+ total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
+
+ hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = list(ref_text)
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results.extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+ return results, total_duration
+
+
+def save_results(
+ res_dir: Path,
+ test_set_name: str,
+ results: List[Tuple[str, List[str], List[str]]],
+):
+ recog_path = res_dir / f"recogs-{test_set_name}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = res_dir / f"errs-{test_set_name}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
+ with open(errs_info, "w") as f:
+ print("WER", file=f)
+ print(wer, file=f)
+
+ s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ MdccAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+
+ assert (
+ args.decoding_method == "greedy_search"
+ ), "Only supports greedy_search currently."
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
+
+ setup_logger(f"{res_dir}/log-decode")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+
+ token_table = k2.SymbolTable.from_file(args.tokens)
+ assert token_table[0] == "<blk>"
+
+ logging.info(vars(args))
+
+ logging.info("About to create model")
+ model = OnnxModel(
+ encoder_model_filename=args.encoder_model_filename,
+ decoder_model_filename=args.decoder_model_filename,
+ joiner_model_filename=args.joiner_model_filename,
+ )
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+
+ mdcc = MdccAsrDataModule(args)
+
+ def remove_short_utt(c: Cut):
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ if T <= 0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+ )
+ return T > 0
+
+ valid_cuts = mdcc.valid_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_utt)
+ valid_dl = mdcc.valid_dataloaders(valid_cuts)
+
+ test_cuts = mdcc.test_cuts()
+ test_cuts = test_cuts.filter(remove_short_utt)
+ test_dl = mdcc.test_dataloaders(test_cuts)
+
+ test_sets = ["valid", "test"]
+ test_dls = [valid_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ start_time = time.time()
+ results, total_duration = decode_dataset(
+ dl=test_dl, model=model, token_table=token_table
+ )
+ end_time = time.time()
+ elapsed_seconds = end_time - start_time
+ rtf = elapsed_seconds / total_duration
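+ # e.g. decoding 120 s of audio in 30 s gives RTF = 30 / 120 = 0.25,
+ # i.e. roughly 4x faster than real time.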
+
+ logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
+ logging.info(f"Wave duration: {total_duration:.3f} s")
+ logging.info(
+ f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
+ )
+
+ save_results(res_dir=res_dir, test_set_name=test_set, results=results)
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mdcc/ASR/zipformer/optim.py b/egs/mdcc/ASR/zipformer/optim.py
new file mode 120000
index 0000000000..5eaa3cffd4
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/scaling.py b/egs/mdcc/ASR/zipformer/scaling.py
new file mode 120000
index 0000000000..6f398f431d
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/scaling_converter.py b/egs/mdcc/ASR/zipformer/scaling_converter.py
new file mode 120000
index 0000000000..b0ecee05e1
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/streaming_beam_search.py b/egs/mdcc/ASR/zipformer/streaming_beam_search.py
new file mode 120000
index 0000000000..b1ed545579
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/streaming_decode.py b/egs/mdcc/ASR/zipformer/streaming_decode.py
new file mode 100755
index 0000000000..dadb0b55fd
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/streaming_decode.py
@@ -0,0 +1,881 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
+# Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./zipformer/streaming_decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 256 \
+ --exp-dir ./zipformer/exp \
+ --decoding-method greedy_search \
+ --num-decode-streams 2000
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import torch
+from asr_datamodule import MdccAsrDataModule
+from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from streaming_beam_search import (
+ fast_beam_search_one_best,
+ greedy_search,
+ modified_beam_search,
+)
+from torch import Tensor, nn
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="Path to the lang dir(containing lexicon, tokens, etc.)",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Supported decoding methods are:
+ greedy_search
+ modified_beam_search
+ fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--num_active_paths",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=32,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=1,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
+
+ parser.add_argument(
+ "--num-decode-streams",
+ type=int,
+ default=2000,
+ help="The number of streams that can be decoded parallel.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_init_states(
+ model: nn.Module,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+ """
+ states = model.encoder.get_init_states(batch_size, device)
+
+ embed_states = model.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
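+# A quick size check for the state list returned above: each encoder layer
+# contributes 6 cached tensors, plus the cached embed left-padding and
+# processed_lens, so len(states) == 6 * num_layers + 2. With the default
+# --num-encoder-layers "2,2,3,4,3,2" (16 layers in total) that is
+# 6 * 16 + 2 = 98 tensors; stack_states/unstack_states below rely on this
+# layout via the assertion (len(states) - 2) % 6 == 0.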
+
+def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """Stack list of zipformer states that correspond to separate utterances
+ into a single zipformer state, so that it can be used as an input for
+ zipformer when those utterances are formed into a batch.
+
+ Args:
+ state_list:
+ Each element in state_list corresponds to the internal state
+ of the zipformer model for a single utterance. For element-n,
+ state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
+ state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+ cached_val2, cached_conv1, cached_conv2).
+ state_list[n][-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ state_list[n][-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Note:
+ It is the inverse of :func:`unstack_states`.
+ """
+ batch_size = len(state_list)
+ assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
+ tot_num_layers = (len(state_list[0]) - 2) // 6
+
+ batch_states = []
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key = torch.cat(
+ [state_list[i][layer_offset] for i in range(batch_size)], dim=1
+ )
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn = torch.cat(
+ [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1 = torch.cat(
+ [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2 = torch.cat(
+ [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1 = torch.cat(
+ [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2 = torch.cat(
+ [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
+ )
+ batch_states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ cached_embed_left_pad = torch.cat(
+ [state_list[i][-2] for i in range(batch_size)], dim=0
+ )
+ batch_states.append(cached_embed_left_pad)
+
+ processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
+ batch_states.append(processed_lens)
+
+ return batch_states
+
+
+def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
+ """Unstack the zipformer state corresponding to a batch of utterances
+ into a list of states, where the i-th entry is the state from the i-th
+ utterance in the batch.
+
+ Note:
+ It is the inverse of :func:`stack_states`.
+
+ Args:
+ batch_states: A list of cached tensors of all encoder layers. For layer-i,
+ batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+ cached_conv1, cached_conv2).
+ batch_states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ batch_states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Returns:
+ state_list: A list of lists. Each element in state_list corresponds to the internal state
+ of the zipformer model for a single utterance.
+ """
+ assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
+ tot_num_layers = (len(batch_states) - 2) // 6
+
+ processed_lens = batch_states[-1]
+ batch_size = processed_lens.shape[0]
+
+ state_list = [[] for _ in range(batch_size)]
+
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1_list = batch_states[layer_offset + 2].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2_list = batch_states[layer_offset + 3].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1_list = batch_states[layer_offset + 4].chunk(
+ chunks=batch_size, dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2_list = batch_states[layer_offset + 5].chunk(
+ chunks=batch_size, dim=0
+ )
+ for i in range(batch_size):
+ state_list[i] += [
+ cached_key_list[i],
+ cached_nonlin_attn_list[i],
+ cached_val1_list[i],
+ cached_val2_list[i],
+ cached_conv1_list[i],
+ cached_conv2_list[i],
+ ]
+
+ cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(cached_embed_left_pad_list[i])
+
+ processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(processed_lens_list[i])
+
+ return state_list
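+
+# stack_states and unstack_states are inverses of each other. A minimal
+# sanity-check sketch (hypothetical, not part of the recipe):
+#
+#   states_a = get_init_states(model)
+#   states_b = get_init_states(model)
+#   restored = unstack_states(stack_states([states_a, states_b]))
+#   assert all(torch.equal(x, y) for x, y in zip(restored[0], states_a))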
+
+
+def streaming_forward(
+ features: Tensor,
+ feature_lens: Tensor,
+ model: nn.Module,
+ states: List[Tensor],
+ chunk_size: int,
+ left_context_len: int,
+) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ cached_embed_left_pad = states[-2]
+ (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lens,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
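+ # e.g. with left_context_len = 4 and processed_lens = [2], the line above
+ # yields flip([False, False, True, True]) = [True, True, False, False]:
+ # the two oldest, still unfilled cache positions are masked out.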
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = model.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+
+def decode_one_chunk(
+ params: AttributeDict,
+ model: nn.Module,
+ decode_streams: List[DecodeStream],
+) -> List[int]:
+ """Decode one chunk frames of features for each decode_streams and
+ return the indexes of finished streams in a List.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ decode_streams:
+ A List of DecodeStream, each belonging to an utterance.
+ Returns:
+ Return a List containing which DecodeStreams are finished.
+ """
+ device = model.device
+ chunk_size = int(params.chunk_size)
+ left_context_len = int(params.left_context_frames)
+
+ features = []
+ feature_lens = []
+ states = []
+ processed_lens = [] # Used in fast-beam-search
+
+ for stream in decode_streams:
+ feat, feat_len = stream.get_feature_frames(chunk_size * 2)
+ features.append(feat)
+ feature_lens.append(feat_len)
+ states.append(stream.states)
+ processed_lens.append(stream.done_frames)
+
+ feature_lens = torch.tensor(feature_lens, device=device)
+ features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
+
+ # Make sure the length after encoder_embed is at least 1.
+ # The encoder_embed subsamples features: T -> (T - 7) // 2
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ tail_length = chunk_size * 2 + 7 + 2 * 3
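+ # e.g. for --chunk-size 16 this is 16 * 2 + 7 + 2 * 3 = 45 feature frames.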
+ if features.size(1) < tail_length:
+ pad_length = tail_length - features.size(1)
+ feature_lens += pad_length
+ features = torch.nn.functional.pad(
+ features,
+ (0, 0, 0, pad_length),
+ mode="constant",
+ value=LOG_EPS,
+ )
+
+ states = stack_states(states)
+
+ encoder_out, encoder_out_lens, new_states = streaming_forward(
+ features=features,
+ feature_lens=feature_lens,
+ model=model,
+ states=states,
+ chunk_size=chunk_size,
+ left_context_len=left_context_len,
+ )
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ if params.decoding_method == "greedy_search":
+ greedy_search(
+ model=model,
+ encoder_out=encoder_out,
+ streams=decode_streams,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "fast_beam_search":
+ processed_lens = torch.tensor(processed_lens, device=device)
+ processed_lens = processed_lens + encoder_out_lens
+ fast_beam_search_one_best(
+ model=model,
+ encoder_out=encoder_out,
+ processed_lens=processed_lens,
+ streams=decode_streams,
+ beam=params.beam,
+ max_states=params.max_states,
+ max_contexts=params.max_contexts,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ modified_beam_search(
+ model=model,
+ streams=decode_streams,
+ encoder_out=encoder_out,
+ num_active_paths=params.num_active_paths,
+ blank_penalty=params.blank_penalty,
+ )
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+ states = unstack_states(new_states)
+
+ finished_streams = []
+ for i in range(len(decode_streams)):
+ decode_streams[i].states = states[i]
+ decode_streams[i].done_frames += encoder_out_lens[i]
+ if decode_streams[i].done:
+ finished_streams.append(i)
+
+ return finished_streams
+
+
+def decode_dataset(
+ cuts: CutSet,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ cuts:
+ Lhotse CutSet containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ lexicon:
+ The Lexicon.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+ only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ device = model.device
+
+ opts = FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
+
+ log_interval = 100
+
+ decode_results = []
+ # Contain decode streams currently running.
+ decode_streams = []
+ for num, cut in enumerate(cuts):
+ # each utterance has a DecodeStream.
+ initial_states = get_init_states(model=model, batch_size=1, device=device)
+ decode_stream = DecodeStream(
+ params=params,
+ cut_id=cut.id,
+ initial_states=initial_states,
+ decoding_graph=decoding_graph,
+ device=device,
+ )
+
+ audio: np.ndarray = cut.load_audio()
+ # audio.shape: (1, num_samples)
+ assert len(audio.shape) == 2
+ assert audio.shape[0] == 1, "Should be single channel"
+ assert audio.dtype == np.float32, audio.dtype
+
+ # The trained model is using normalized samples
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
+
+ samples = torch.from_numpy(audio).squeeze(0)
+
+ fbank = Fbank(opts)
+ feature = fbank(samples.to(device))
+ decode_stream.set_features(feature, tail_pad_len=30)
+ decode_stream.ground_truth = cut.supervisions[0].text
+
+ decode_streams.append(decode_stream)
+
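+ # Once --num-decode-streams streams are active, decode them chunk by
+ # chunk in a batch and retire the streams that have finished, freeing
+ # room for the remaining cuts.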
+ while len(decode_streams) >= params.num_decode_streams:
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ list(decode_streams[i].ground_truth.strip()),
+ [
+ lexicon.token_table[idx]
+ for idx in decode_streams[i].decoding_result()
+ ],
+ )
+ )
+ del decode_streams[i]
+
+ if num % log_interval == 0:
+ logging.info(f"Cuts processed until now is {num}.")
+
+ # decode final chunks of last sequences
+ while len(decode_streams):
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ list(decode_streams[i].ground_truth.strip()),
+ [
+ lexicon.token_table[idx]
+ for idx in decode_streams[i].decoding_result()
+ ],
+ )
+ )
+ del decode_streams[i]
+
+ key = f"blank_penalty_{params.blank_penalty}"
+ if params.decoding_method == "greedy_search":
+ key = f"greedy_search_{key}"
+ elif params.decoding_method == "fast_beam_search":
+ key = (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}_{key}"
+ )
+ elif params.decoding_method == "modified_beam_search":
+ key = f"num_active_paths_{params.num_active_paths}_{key}"
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+ return {key: decode_results}
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ MdccAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ assert params.causal, params.causal
+ assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
+ # for fast_beam_search
+ if params.decoding_method == "fast_beam_search":
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ decoding_graph = None
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ mdcc = MdccAsrDataModule(args)
+
+ valid_cuts = mdcc.valid_cuts()
+ test_cuts = mdcc.test_cuts()
+
+ test_sets = ["valid", "test"]
+ test_cut_sets = [valid_cuts, test_cuts]
+
+ for test_set, test_cut in zip(test_sets, test_cut_sets):
+ results_dict = decode_dataset(
+ cuts=test_cut,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ )
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mdcc/ASR/zipformer/subsampling.py b/egs/mdcc/ASR/zipformer/subsampling.py
new file mode 120000
index 0000000000..01ae9002c6
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py
new file mode 100755
index 0000000000..2fae668444
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/train.py
@@ -0,0 +1,1345 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 50 \
+ --start-epoch 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 350
+
+# For mixed precision training:
+
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 50 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import MdccAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
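+# e.g. with --max-duration 350, --world-size 4 and the default
+# --ref-duration 600, batch_idx_train = 1200 corresponds to an adjusted
+# count of 1200 * 350 * 4 / 600 = 2800 batches.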
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="""Maximum left-contexts for causal training, measured in frames which will
+ be converted to a number of chunks. If splitting into chunks,
+ chunk left-context frames will be chosen randomly from this list; else not relevant.""",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not changing this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=1,
+ help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="""The prune range for rnnt loss, it means how many symbols(context)
+ we are using to compute the loss""",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="""The scale to smooth the loss with lm
+ (output of prediction network) part.""",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="""The scale to smooth the loss with am (output of encoder network) part.""",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="""To get pruning ranges, we will calculate a simple version
+ loss(joiner is just addition), this simple loss also uses for
+ training (as a regularization item). We will scale the simple loss
+ with this parameter before adding to the final loss.""",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to write statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000,
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
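+# e.g. _to_int_tuple("2,2,3,4,3,2") == (2, 2, 3, 4, 3, 2)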
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
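+ # e.g. assuming the usual 10 ms frame shift, a 10 s utterance (~1000
+ # frames) becomes (1000 - 7) // 2 = 496 frames (~50 Hz) after
+ # encoder_embed, and ~248 frames (~25 Hz) after the final downsampling.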
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(max(params.encoder_dim.split(","))),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute the pruned RNN-T (transducer) loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ graph_compiler:
+ It is used to convert transcripts to token ids.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, _ = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
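+ # e.g. halfway through warm-up (batch_idx_train = warm_step / 2) with the
+ # default --simple-loss-scale 0.5, simple_loss_scale = 0.75 and
+ # pruned_loss_scale = 0.1 + 0.9 * 0.5 = 0.55.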
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler; we call step_batch() after every training batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
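+ # model_avg is a running average of the model over the whole training
+ # run; per the --average-period help, at batch_idx_train = 1000 with
+ # average_period = 200 this update amounts to
+ # model_avg = 0.2 * model + 0.8 * model_avg.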
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
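+ # The averaged model is kept in float64 so the running average accumulates
+ # with less rounding error; it is refreshed every params.average_period
+ # batches inside train_one_epoch().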
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ mdcc = MdccAsrDataModule(args)
+
+ train_cuts = mdcc.train_cuts()
+ valid_cuts = mdcc.valid_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold.
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
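+ # For example, assuming 100 feature frames per second, a 10-second cut has
+ # num_frames = 1000 and T = ((1000 - 7) // 2 + 1) // 2 = 248 frames after
+ # subsampling.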
+ tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when resuming from a checkpoint
+ # that was saved in the middle of an epoch.
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = mdcc.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict)
+
+ valid_dl = mdcc.valid_dataloaders(valid_cuts)
+
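+ # The OOM sanity scan below is disabled by the `False and` guard; change it
+ # to `True` (or drop the guard) to run find_pessimistic_batches on the
+ # largest batches before training starts.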
+ if False and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ graph_compiler:
+ The compiler to encode texts to ids.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
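+ # The saved batch can later be reloaded with torch.load(filename) for
+ # offline debugging of the failing case.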
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ texts = supervisions["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ MdccAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.lang_dir = Path(args.lang_dir)
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
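+ # Limit each process to a single intra-op and inter-op CPU thread so that
+ # DDP workers and dataloader processes do not oversubscribe the CPU.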
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mdcc/ASR/zipformer/zipformer.py b/egs/mdcc/ASR/zipformer/zipformer.py
new file mode 120000
index 0000000000..23011dda71
--- /dev/null
+++ b/egs/mdcc/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
index 662ae01c51..489b38e657 100644
--- a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
+++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
@@ -216,7 +216,7 @@ def train_dataloaders(
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
- CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
diff --git a/requirements.txt b/requirements.txt
index e64afd1eec..6bafa6aca3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-kaldifst
+kaldifst>1.7.0
kaldilm
kaldialign
num2words
@@ -14,4 +14,7 @@ onnxruntime==1.16.3
# style check session:
black==22.3.0
isort==5.10.1
-flake8==5.0.4
\ No newline at end of file
+flake8==5.0.4
+
+# cantonese word segment support
+pycantonese==3.4.0
\ No newline at end of file