diff --git a/.github/scripts/.gitignore b/.github/scripts/.gitignore new file mode 100644 index 0000000000..672e477d8d --- /dev/null +++ b/.github/scripts/.gitignore @@ -0,0 +1 @@ +piper_phonemize.html diff --git a/.github/scripts/generate-piper-phonemize-page.py b/.github/scripts/generate-piper-phonemize-page.py new file mode 100755 index 0000000000..3784d5fa58 --- /dev/null +++ b/.github/scripts/generate-piper-phonemize-page.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + + +def main(): + prefix = ( + "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/" + ) + files = [ + "piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl", + "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + ] + with open("piper_phonemize.html", "w") as f: + for file in files: + url = prefix + file + f.write(f'{file}
\n') + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh index 7e9bd8a478..293ed66e53 100755 --- a/.github/scripts/librispeech/ASR/run.sh +++ b/.github/scripts/librispeech/ASR/run.sh @@ -15,9 +15,9 @@ function prepare_data() { # cause OOM error for CI later. mkdir -p download/lm pushd download/lm - wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt - wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt - wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz + wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz + wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt + wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt ls -lh gunzip librispeech-lm-norm.txt.gz diff --git a/.github/scripts/ljspeech/TTS/run.sh b/.github/scripts/ljspeech/TTS/run.sh new file mode 100755 index 0000000000..707361782f --- /dev/null +++ b/.github/scripts/ljspeech/TTS/run.sh @@ -0,0 +1,157 @@ +#!/usr/bin/env bash + +set -ex + +python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html +python3 -m pip install espnet_tts_frontend +python3 -m pip install numba + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/ljspeech/TTS + +sed -i.bak s/600/8/g ./prepare.sh +sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh +sed -i.bak s/500/5/g ./prepare.sh +git diff + +function prepare_data() { + # We have created a subset of the data for testing + # + mkdir download + pushd download + wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2 + tar xvf LJSpeech-1.1.tar.bz2 + popd + + ./prepare.sh + tree . +} + +function train() { + pushd ./vits + sed -i.bak s/200/3/g ./train.py + git diff . + popd + + for t in low medium high; do + ./vits/train.py \ + --exp-dir vits/exp-$t \ + --model-type $t \ + --num-epochs 1 \ + --save-every-n 1 \ + --num-buckets 2 \ + --tokens data/tokens.txt \ + --max-duration 20 + + ls -lh vits/exp-$t + done +} + +function infer() { + for t in low medium high; do + ./vits/infer.py \ + --num-buckets 2 \ + --model-type $t \ + --epoch 1 \ + --exp-dir ./vits/exp-$t \ + --tokens data/tokens.txt \ + --max-duration 20 + done +} + +function export_onnx() { + for t in low medium high; do + ./vits/export-onnx.py \ + --model-type $t \ + --epoch 1 \ + --exp-dir ./vits/exp-$t \ + --tokens data/tokens.txt + + ls -lh vits/exp-$t/ + done +} + +function test_medium() { + git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12 + + ./vits/export-onnx.py \ + --model-type medium \ + --epoch 820 \ + --exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \ + --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt + + ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp + + ./vits/test_onnx.py \ + --model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \ + --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \ + --output-filename /icefall/test-medium.wav + + ls -lh /icefall/test-medium.wav + + d=/icefall/vits-icefall-en_US-ljspeech-medium + mkdir $d + cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/ + cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx + + rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12 + + pushd $d + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 + tar xf espeak-ng-data.tar.bz2 + rm espeak-ng-data.tar.bz2 + cd .. + tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium + rm -rf vits-icefall-en_US-ljspeech-medium + ls -lh *.tar.bz2 + popd +} + +function test_low() { + git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12 + + ./vits/export-onnx.py \ + --model-type low \ + --epoch 1600 \ + --exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \ + --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt + + ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp + + ./vits/test_onnx.py \ + --model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \ + --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \ + --output-filename /icefall/test-low.wav + + ls -lh /icefall/test-low.wav + + d=/icefall/vits-icefall-en_US-ljspeech-low + mkdir $d + cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/ + cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx + + rm -rf icefall-tts-ljspeech-vits-low-2024-03-12 + + pushd $d + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 + tar xf espeak-ng-data.tar.bz2 + rm espeak-ng-data.tar.bz2 + cd .. + tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low + rm -rf vits-icefall-en_US-ljspeech-low + ls -lh *.tar.bz2 + popd +} + +prepare_data +train +infer +export_onnx +rm -rf vits/exp-{low,medium,high} +test_medium +test_low diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index d7fe2c9643..c622476f2d 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -56,11 +56,14 @@ jobs: - name: Build doc shell: bash run: | + .github/scripts/generate-piper-phonemize-page.py cd docs python3 -m pip install -r ./requirements.txt make html touch build/html/.nojekyll + cp -v ../piper_phonemize.html ./build/html/ + - name: Deploy uses: peaceiris/actions-gh-pages@v3 with: diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml new file mode 100644 index 0000000000..25402275b4 --- /dev/null +++ b/.github/workflows/ljspeech.yml @@ -0,0 +1,102 @@ +name: ljspeech + +on: + push: + branches: + - master + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: ljspeech-${{ github.ref }} + cancel-in-progress: true + +jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + + ljspeech: + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + ls -lh + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run tests + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall + + .github/scripts/ljspeech/TTS/run.sh + + - name: display files + shell: bash + run: | + ls -lh + + - uses: actions/upload-artifact@v4 + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + with: + name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }} + path: ./*.wav + + - uses: actions/upload-artifact@v4 + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + with: + name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }} + path: ./*.wav + + - name: Release exported onnx models + if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + overwrite: true + file: vits-icefall-*.tar.bz2 + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: tts-models + diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 323d0adfc8..9499a3aea2 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -13,6 +13,14 @@ with the `LJSpeech `_ dataset. The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ +Install extra dependencies +-------------------------- + +.. code-block:: bash + + pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html + pip install numba espnet_tts_frontend + Data preparation ---------------- @@ -56,7 +64,8 @@ Training --start-epoch 1 \ --use-fp16 1 \ --exp-dir vits/exp \ - --tokens data/tokens.txt + --tokens data/tokens.txt \ + --model-type high \ --max-duration 500 .. note:: @@ -64,6 +73,11 @@ Training You can adjust the hyper-parameters to control the size of the VITS model and the training configurations. For more details, please run ``./vits/train.py --help``. +.. warning:: + + If you want a model that runs faster on CPU, please use ``--model-type low`` + or ``--model-type medium``. + .. note:: The training can take a long time (usually a couple of days). @@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir Export models ------------- -Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: -``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. +Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``: +``vits-epoch-*.onnx``. .. code-block:: bash @@ -120,4 +134,68 @@ Download pretrained models If you don't want to train from scratch, you can download the pretrained models by visiting the following link: - - ``_ + - ``--model-type=high``: ``_ + - ``--model-type=medium``: ``_ + - ``--model-type=low``: ``_ + +Usage in sherpa-onnx +-------------------- + +The following describes how to test the exported ONNX model in `sherpa-onnx`_. + +.. hint:: + + `sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python, + Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS. + + We only describe how to use pre-built binaries from `sherpa-onnx`_ below. + Please refer to ``_ + for more documentation. + +Install sherpa-onnx +^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + pip install sherpa-onnx + +To check that you have installed `sherpa-onnx`_ successfully, please run: + +.. code-block:: bash + + which sherpa-onnx-offline-tts + sherpa-onnx-offline-tts --help + +Download lexicon files +^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + cd /tmp + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 + tar xf espeak-ng-data.tar.bz2 + +Run sherpa-onnx +^^^^^^^^^^^^^^^ + +.. code-block:: bash + + cd egs/ljspeech/TTS + + sherpa-onnx-offline-tts \ + --vits-model=vits/exp/vits-epoch-1000.onnx \ + --vits-tokens=data/tokens.txt \ + --vits-data-dir=/tmp/espeak-ng-data \ + --num-threads=1 \ + --output-filename=./high.wav \ + "Ask not what your country can do for you; ask what you can do for your country." + +.. hint:: + + You can also use ``sherpa-onnx-offline-tts-play`` to play the audio + as it is generating. + +You should get a file ``high.wav`` after running the above command. + +Congratulations! You have successfully trained and exported a text-to-speech +model and run it with `sherpa-onnx`_. diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index b547191625..d088072a71 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -19,7 +19,9 @@ The following table lists the differences among them. | `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` | | `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data | | `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data| -| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 | +| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 | +| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 | + The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index b7be89bc8e..13be69534f 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then fi if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 11: Train RNN LM model" + log "Stage 12: Train RNN LM model" python ../../../icefall/rnn_lm/train.py \ --start-epoch 0 \ --world-size 1 \ diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 301ab01117..996a1da2d4 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -15,7 +15,7 @@ perturb_speed=true # # - $dl_dir/alimeeting # This directory contains the following files downloaded from -# https://openslr.org/62/ +# https://openslr.org/119/ # # - Train_Ali_far.tar.gz # - Train_Ali_near.tar.gz diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh index 1098840f85..15c20692da 100755 --- a/egs/alimeeting/ASR_v2/prepare.sh +++ b/egs/alimeeting/ASR_v2/prepare.sh @@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting # # - $dl_dir/alimeeting # This directory contains the following files downloaded from -# https://openslr.org/62/ +# https://openslr.org/119/ # # - Train_Ali_far.tar.gz # - Train_Ali_near.tar.gz diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py index da8e620349..91220bd112 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py @@ -232,7 +232,7 @@ def train_dataloaders( logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py index e23c9b1b7e..4985f3f4c3 100644 --- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -232,7 +232,7 @@ def train_dataloaders( logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 80be5a3155..7b112c12c8 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -1,10 +1,10 @@ # Introduction -This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. -A transcription is provided for each clip. +This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. +A transcription is provided for each clip. Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours. -The texts were published between 1884 and 1964, and are in the public domain. +The texts were published between 1884 and 1964, and are in the public domain. The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain. The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/). @@ -35,4 +35,69 @@ To inference, use: --exp-dir vits/exp \ --epoch 1000 \ --tokens data/tokens.txt -``` \ No newline at end of file +``` + +## Quality vs speed + +If you feel that the trained model is slow at runtime, you can specify the +argument `--model-type` during training. Possible values are: + + - `low`, means **low** quality. The resulting model is very small in file size + and runs very fast. The following is a wave file generatd by a `low` quality model + + https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633 + + The text is `Ask not what your country can do for you; ask what you can do for your country.` + + The exported onnx model has a file size of ``26.8 MB`` (float32). + + - `medium`, means **medium** quality. + The following is a wave file generatd by a `medium` quality model + + https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac + + The text is `Ask not what your country can do for you; ask what you can do for your country.` + + The exported onnx model has a file size of ``70.9 MB`` (float32). + + - `high`, means **high** quality. This is the default value. + + The following is a wave file generatd by a `high` quality model + + https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc + + The text is `Ask not what your country can do for you; ask what you can do for your country.` + + The exported onnx model has a file size of ``113 MB`` (float32). + + +A pre-trained `low` model trained using 4xV100 32GB GPU with the following command can be found at + + +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3 +./vits/train.py \ + --world-size 4 \ + --num-epochs 1601 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --model-type low \ + --max-duration 800 +``` + +A pre-trained `medium` model trained using 4xV100 32GB GPU with the following command can be found at + +```bash +export CUDA_VISIBLE_DEVICES=4,5,6,7 +./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp-medium \ + --model-type medium \ + --max-duration 500 + +# (Note it is killed after `epoch-820.pt`) +``` diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py index 08fe7430ef..4ba88604ce 100755 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -23,7 +23,11 @@ import logging from pathlib import Path -import tacotron_cleaner.cleaners +try: + import tacotron_cleaner.cleaners +except ModuleNotFoundError as ex: + raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n") + from lhotse import CutSet, load_manifest from piper_phonemize import phonemize_espeak diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index cbf27bd423..9ed0f93fde 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -28,7 +28,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then log "Stage -1: build monotonic_align lib" if [ ! -d vits/monotonic_align/build ]; then cd vits/monotonic_align - python setup.py build_ext --inplace + python3 setup.py build_ext --inplace cd ../../ else log "monotonic_align lib already built" @@ -54,7 +54,7 @@ fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare LJSpeech manifest" # We assume that you have downloaded the LJSpeech corpus - # to $dl_dir/LJSpeech + # to $dl_dir/LJSpeech-1.1 mkdir -p data/manifests if [ ! -e data/manifests/.ljspeech.done ]; then lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests @@ -82,8 +82,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare phoneme tokens for LJSpeech" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: - # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, - # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then ./local/prepare_tokens_ljspeech.py diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 58b1663684..0740757c06 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -25,9 +25,8 @@ --exp-dir vits/exp \ --tokens data/tokens.txt -It will generate two files inside vits/exp: +It will generate one file inside vits/exp: - vits-epoch-1000.onnx - - vits-epoch-1000.int8.onnx (quantizated model) See ./test_onnx.py for how to use the exported ONNX models. """ @@ -40,7 +39,6 @@ import onnx import torch import torch.nn as nn -from onnxruntime.quantization import QuantType, quantize_dynamic from tokenizer import Tokenizer from train import get_model, get_params @@ -75,6 +73,16 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -136,7 +144,7 @@ def forward( Return a tuple containing: - audio, generated wavform tensor, (B, T_wav) """ - audio, _, _ = self.model.inference( + audio, _, _ = self.model.generator.inference( text=tokens, text_lengths=tokens_lens, noise_scale=noise_scale, @@ -198,6 +206,11 @@ def export_model_onnx( }, ) + if model.model.spks is None: + num_speakers = 1 + else: + num_speakers = model.model.spks + meta_data = { "model_type": "vits", "version": "1", @@ -206,8 +219,8 @@ def export_model_onnx( "language": "English", "voice": "en-us", # Choose your language appropriately "has_espeak": 1, - "n_speakers": 1, - "sample_rate": 22050, # Must match the real sample rate + "n_speakers": num_speakers, + "sample_rate": model.model.sampling_rate, # Must match the real sample rate } logging.info(f"meta_data: {meta_data}") @@ -233,14 +246,13 @@ def main(): load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model = model.generator model.to("cpu") model.eval() model = OnnxModel(model=model) num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"generator parameters: {num_param}") + logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") suffix = f"epoch-{params.epoch}" @@ -256,18 +268,6 @@ def main(): ) logging.info(f"Exported generator to {model_filename}") - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" - quantize_dynamic( - model_input=model_filename, - model_output=model_filename_int8, - weight_type=QuantType.QUInt8, - ) - if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index 66c8cedb19..b9add9e828 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -189,7 +189,7 @@ def __init__( self.upsample_factor = int(np.prod(decoder_upsample_scales)) self.spks = None if spks is not None and spks > 1: - assert global_channels > 0 + assert global_channels > 0, global_channels self.spks = spks self.global_emb = torch.nn.Embedding(spks, global_channels) self.spk_embed_dim = None diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 9e7c71c6dc..7be76e3151 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -72,6 +72,16 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -94,6 +104,7 @@ def infer_dataset( tokenizer: Used to convert text to phonemes. """ + # Background worker save audios to disk. def _save_worker( batch_size: int, diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py index 2b35654f51..5dc3641e59 100644 --- a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py +++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py @@ -10,7 +10,11 @@ import numpy as np import torch -from numba import njit, prange + +try: + from numba import njit, prange +except ModuleNotFoundError as ex: + raise RuntimeError(f"{ex}/nPlease run\n pip install numba") try: from .core import maximum_path_c diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py new file mode 100755 index 0000000000..1de10f012b --- /dev/null +++ b/egs/ljspeech/TTS/vits/test_model.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tokenizer import Tokenizer +from train import get_model, get_params +from vits import VITS + + +def test_model_type(model_type): + tokens = "./data/tokens.txt" + + params = get_params() + + tokenizer = Tokenizer(tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_type = model_type + + model = get_model(params) + generator = model.generator + + num_param = sum([p.numel() for p in generator.parameters()]) + print( + f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M" + ) + + +def main(): + test_model_type("high") # 35.63 M + test_model_type("low") # 7.55 M + test_model_type("medium") # 23.61 M + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 4f46e8e6c5..b3805fadb3 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -54,6 +54,20 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--text", + type=str, + default="Ask not what your country can do for you; ask what you can do for your country.", + help="Text to generate speech for", + ) + + parser.add_argument( + "--output-filename", + type=str, + default="test_onnx.wav", + help="Filename to save the generated wave file.", + ) + return parser @@ -61,7 +75,7 @@ class OnnxModel: def __init__(self, model_filename: str): session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 + session_opts.intra_op_num_threads = 1 self.session_opts = session_opts @@ -72,6 +86,9 @@ def __init__(self, model_filename: str): ) logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + metadata = self.model.get_modelmeta().custom_metadata_map + self.sample_rate = int(metadata["sample_rate"]) + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: """ Args: @@ -101,13 +118,14 @@ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Ten def main(): args = get_parser().parse_args() + logging.info(vars(args)) tokenizer = Tokenizer(args.tokens) logging.info("About to create onnx model") model = OnnxModel(args.model_filename) - text = "I went there to see the land, the people and how their system works, end quote." + text = args.text tokens = tokenizer.texts_to_token_ids( [text], intersperse_blank=True, add_sos=True, add_eos=True ) @@ -115,8 +133,9 @@ def main(): tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) audio = model(tokens, tokens_lens) # (1, T') - torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) - logging.info("Saved to test_onnx.wav") + output_filename = args.output_filename + torchaudio.save(output_filename, audio, sample_rate=model.sample_rate) + logging.info(f"Saved to {output_filename}") if __name__ == "__main__": diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index fcbae7103f..9b21ed9cb0 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -92,9 +92,9 @@ def forward( x_lengths (Tensor): Length tensor (B,). Returns: - Tensor: Encoded hidden representation (B, attention_dim, T_text). - Tensor: Projected mean tensor (B, attention_dim, T_text). - Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Encoded hidden representation (B, embed_dim, T_text). + Tensor: Projected mean tensor (B, embed_dim, T_text). + Tensor: Projected scale tensor (B, embed_dim, T_text). Tensor: Mask tensor for input tensor (B, 1, T_text). """ @@ -108,6 +108,7 @@ def forward( # encoder assume the channel last (B, T_text, embed_dim) x = self.encoder(x, key_padding_mask=pad_mask) + # Note: attention_dim == embed_dim # convert the channel first (B, embed_dim, T_text) x = x.transpose(1, 2) diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 9a5a9090ec..3c9046adde 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -18,7 +18,15 @@ from typing import Dict, List import tacotron_cleaner.cleaners -from piper_phonemize import phonemize_espeak + +try: + from piper_phonemize import phonemize_espeak +except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease run\n" + "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" + ) + from utils import intersperse diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 6589b75ff6..34b943765a 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -153,6 +153,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -189,15 +199,6 @@ def get_params() -> AttributeDict: - feature_dim: The model input dim. It has to match the one used in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. """ params = AttributeDict( { @@ -278,6 +279,7 @@ def get_model(params: AttributeDict) -> nn.Module: vocab_size=params.vocab_size, feature_dim=params.feature_dim, sampling_rate=params.sampling_rate, + model_type=params.model_type, mel_loss_params=mel_loss_params, lambda_adv=params.lambda_adv, lambda_mel=params.lambda_mel, @@ -363,7 +365,7 @@ def train_one_epoch( model.train() device = model.device if isinstance(model, DDP) else next(model.parameters()).device - # used to summary the stats over iterations in one epoch + # used to track the stats over iterations in one epoch tot_loss = MetricsTracker() saved_bad_model = False diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index 8ff868bc8b..e1a9c7b3ca 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -255,6 +255,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, shuffle=False, ) logging.info("About to create valid dataloader") @@ -294,6 +295,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: test_sampler = DynamicBucketingSampler( cuts, max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, shuffle=False, ) logging.info("About to create test dataloader") diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index b4f0c21e6d..0b9575cbde 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -5,6 +5,7 @@ """VITS module for GAN-TTS task.""" +import copy from typing import Any, Dict, Optional, Tuple import torch @@ -38,6 +39,36 @@ "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA } +LOW_CONFIG = { + "hidden_channels": 96, + "decoder_upsample_scales": (8, 8, 4), + "decoder_channels": 256, + "decoder_upsample_kernel_sizes": (16, 16, 8), + "decoder_resblock_kernel_sizes": (3, 5, 7), + "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), + "text_encoder_cnn_module_kernel": 3, +} + +MEDIUM_CONFIG = { + "hidden_channels": 192, + "decoder_upsample_scales": (8, 8, 4), + "decoder_channels": 256, + "decoder_upsample_kernel_sizes": (16, 16, 8), + "decoder_resblock_kernel_sizes": (3, 5, 7), + "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), + "text_encoder_cnn_module_kernel": 3, +} + +HIGH_CONFIG = { + "hidden_channels": 192, + "decoder_upsample_scales": (8, 8, 2, 2), + "decoder_channels": 512, + "decoder_upsample_kernel_sizes": (16, 16, 4, 4), + "decoder_resblock_kernel_sizes": (3, 7, 11), + "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + "text_encoder_cnn_module_kernel": 5, +} + class VITS(nn.Module): """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" @@ -49,6 +80,7 @@ def __init__( feature_dim: int = 513, sampling_rate: int = 22050, generator_type: str = "vits_generator", + model_type: str = "", generator_params: Dict[str, Any] = { "hidden_channels": 192, "spks": None, @@ -155,12 +187,13 @@ def __init__( """Initialize VITS module. Args: - idim (int): Input vocabrary size. + idim (int): Input vocabulary size. odim (int): Acoustic feature dimension. The actual output channels will be 1 since VITS is the end-to-end text-to-wave model but for the compatibility odim is used to indicate the acoustic feature dimension. sampling_rate (int): Sampling rate, not used for the training but it will be referred in saving waveform during the inference. + model_type (str): If not empty, must be one of: low, medium, high generator_type (str): Generator type. generator_params (Dict[str, Any]): Parameter dict for generator. discriminator_type (str): Discriminator type. @@ -181,6 +214,24 @@ def __init__( """ super().__init__() + generator_params = copy.deepcopy(generator_params) + discriminator_params = copy.deepcopy(discriminator_params) + generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params) + discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params) + feat_match_loss_params = copy.deepcopy(feat_match_loss_params) + mel_loss_params = copy.deepcopy(mel_loss_params) + + if model_type != "": + assert model_type in ("low", "medium", "high"), model_type + if model_type == "low": + generator_params.update(LOW_CONFIG) + elif model_type == "medium": + generator_params.update(MEDIUM_CONFIG) + elif model_type == "high": + generator_params.update(HIGH_CONFIG) + else: + raise ValueError(f"Unknown model_type: ${model_type}") + # define modules generator_class = AVAILABLE_GENERATERS[generator_type] if generator_type == "vits_generator": diff --git a/egs/mdcc/ASR/README.md b/egs/mdcc/ASR/README.md new file mode 100644 index 0000000000..112845b734 --- /dev/null +++ b/egs/mdcc/ASR/README.md @@ -0,0 +1,19 @@ +# Introduction + +Multi-Domain Cantonese Corpus (MDCC), consists of 73.6 hours of clean read speech paired with +transcripts, collected from Cantonese audiobooks from Hong Kong. It comprises philosophy, +politics, education, culture, lifestyle and family domains, covering a wide range of topics. + +Manuscript can be found at: https://arxiv.org/abs/2201.02419 + +# Transducers + + + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|-----------------------------| +| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 | + +The decoder is modified from the paper +[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/mdcc/ASR/RESULTS.md b/egs/mdcc/ASR/RESULTS.md new file mode 100644 index 0000000000..ff7ddc9579 --- /dev/null +++ b/egs/mdcc/ASR/RESULTS.md @@ -0,0 +1,41 @@ +## Results + +#### Zipformer + +See + +[./zipformer](./zipformer) + +##### normal-scaled model, number of model parameters: 74470867, i.e., 74.47 M + +| | test | valid | comment | +|------------------------|------|-------|-----------------------------------------| +| greedy search | 7.45 | 7.51 | --epoch 45 --avg 35 | +| modified beam search | 6.68 | 6.73 | --epoch 45 --avg 35 | +| fast beam search | 7.22 | 7.28 | --epoch 45 --avg 35 | + +The training command: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer/train.py \ + --world-size 4 \ + --start-epoch 1 \ + --num-epochs 50 \ + --use-fp16 1 \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 +``` + +The decoding command: + +``` + ./zipformer/decode.py \ + --epoch 45 \ + --avg 35 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search # modified_beam_search +``` + +The pretrained model is available at: https://huggingface.co/zrjin/icefall-asr-mdcc-zipformer-2024-03-11/ \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_hlg.py b/egs/mdcc/ASR/local/compile_hlg.py new file mode 120000 index 0000000000..471aa7fb40 --- /dev/null +++ b/egs/mdcc/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_hlg_using_openfst.py b/egs/mdcc/ASR/local/compile_hlg_using_openfst.py new file mode 120000 index 0000000000..d34edd7f30 --- /dev/null +++ b/egs/mdcc/ASR/local/compile_hlg_using_openfst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg_using_openfst.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compile_lg.py b/egs/mdcc/ASR/local/compile_lg.py new file mode 120000 index 0000000000..462d6d3fb9 --- /dev/null +++ b/egs/mdcc/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/compute_fbank_mdcc.py b/egs/mdcc/ASR/local/compute_fbank_mdcc.py new file mode 100755 index 0000000000..647b211270 --- /dev/null +++ b/egs/mdcc/ASR/local/compute_fbank_mdcc.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the aishell dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_mdcc( + num_mel_bins: int = 80, + perturb_speed: bool = False, + whisper_fbank: bool = False, + output_dir: str = "data/fbank", +): + src_dir = Path("data/manifests") + output_dir = Path(output_dir) + num_jobs = min(15, os.cpu_count()) + + dataset_parts = ( + "train", + "valid", + "test", + ) + prefix = "mdcc" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition and perturb_speed: + logging.info("Doing speed perturb") + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. Default: data/fbank.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + compute_fbank_mdcc( + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, + ) diff --git a/egs/mdcc/ASR/local/display_manifest_statistics.py b/egs/mdcc/ASR/local/display_manifest_statistics.py new file mode 100755 index 0000000000..27cf8c9439 --- /dev/null +++ b/egs/mdcc/ASR/local/display_manifest_statistics.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/fbank/mdcc_cuts_train.jsonl.gz" + path = "./data/fbank/mdcc_cuts_valid.jsonl.gz" + path = "./data/fbank/mdcc_cuts_test.jsonl.gz" + + cuts = load_manifest_lazy(path) + cuts.describe(full=True) + + +if __name__ == "__main__": + main() + +""" +data/fbank/mdcc_cuts_train.jsonl.gz (with speed perturbation) +_________________________________________ +_ Cuts count: _ 195360 +_________________________________________ +_ Total duration (hh:mm:ss) _ 173:44:59 +_________________________________________ +_ mean _ 3.2 +_________________________________________ +_ std _ 2.1 +_________________________________________ +_ min _ 0.2 +_________________________________________ +_ 25% _ 1.8 +_________________________________________ +_ 50% _ 2.7 +_________________________________________ +_ 75% _ 4.0 +_________________________________________ +_ 99% _ 11.0 _ +_________________________________________ +_ 99.5% _ 12.4 _ +_________________________________________ +_ 99.9% _ 14.8 _ +_________________________________________ +_ max _ 16.7 _ +_________________________________________ +_ Recordings available: _ 195360 _ +_________________________________________ +_ Features available: _ 195360 _ +_________________________________________ +_ Supervisions available: _ 195360 _ +_________________________________________ + +data/fbank/mdcc_cuts_valid.jsonl.gz +________________________________________ +_ Cuts count: _ 5663 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 05:03:12 _ +________________________________________ +_ mean _ 3.2 _ +________________________________________ +_ std _ 2.0 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 1.8 _ +________________________________________ +_ 50% _ 2.7 _ +________________________________________ +_ 75% _ 4.0 _ +________________________________________ +_ 99% _ 10.9 _ +________________________________________ +_ 99.5% _ 12.3 _ +________________________________________ +_ 99.9% _ 14.4 _ +________________________________________ +_ max _ 14.8 _ +________________________________________ +_ Recordings available: _ 5663 _ +________________________________________ +_ Features available: _ 5663 _ +________________________________________ +_ Supervisions available: _ 5663 _ +________________________________________ + +data/fbank/mdcc_cuts_test.jsonl.gz +________________________________________ +_ Cuts count: _ 12492 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 11:00:31 _ +________________________________________ +_ mean _ 3.2 _ +________________________________________ +_ std _ 2.0 _ +________________________________________ +_ min _ 0.2 _ +________________________________________ +_ 25% _ 1.8 _ +________________________________________ +_ 50% _ 2.7 _ +________________________________________ +_ 75% _ 4.0 _ +________________________________________ +_ 99% _ 10.5 _ +________________________________________ +_ 99.5% _ 12.1 _ +________________________________________ +_ 99.9% _ 14.0 _ +________________________________________ +_ max _ 14.8 _ +________________________________________ +_ Recordings available: _ 12492 _ +________________________________________ +_ Features available: _ 12492 _ +________________________________________ +_ Supervisions available: _ 12492 _ +________________________________________ + +""" diff --git a/egs/mdcc/ASR/local/prepare_char.py b/egs/mdcc/ASR/local/prepare_char.py new file mode 120000 index 0000000000..42743b5449 --- /dev/null +++ b/egs/mdcc/ASR/local/prepare_char.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_char_lm_training_data.py b/egs/mdcc/ASR/local/prepare_char_lm_training_data.py new file mode 120000 index 0000000000..2374cafddb --- /dev/null +++ b/egs/mdcc/ASR/local/prepare_char_lm_training_data.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char_lm_training_data.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_lang.py b/egs/mdcc/ASR/local/prepare_lang.py new file mode 120000 index 0000000000..bee8d5f036 --- /dev/null +++ b/egs/mdcc/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/prepare_lang_fst.py b/egs/mdcc/ASR/local/prepare_lang_fst.py new file mode 120000 index 0000000000..c5787c5340 --- /dev/null +++ b/egs/mdcc/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/mdcc/ASR/local/preprocess_mdcc.py b/egs/mdcc/ASR/local/preprocess_mdcc.py new file mode 100755 index 0000000000..cd0dc7de82 --- /dev/null +++ b/egs/mdcc/ASR/local/preprocess_mdcc.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a text file "data/lang_char/text" as input, the file consist of +lines each containing a transcript, applies text norm and generates the following +files in the directory "data/lang_char": + - text_norm + - words.txt + - words_no_ids.txt + - text_words_segmentation +""" + +import argparse +import logging +from pathlib import Path +from typing import List + +import pycantonese +from tqdm.auto import tqdm + +from icefall.utils import is_cjk + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Prepare char lexicon", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input-file", + "-i", + default="data/lang_char/text", + type=str, + help="The input text file", + ) + parser.add_argument( + "--output-dir", + "-o", + default="data/lang_char", + type=str, + help="The output directory", + ) + return parser + + +def get_norm_lines(lines: List[str]) -> List[str]: + def _text_norm(text: str) -> str: + # to cope with the protocol for transcription: + # When taking notes, the annotators adhere to the following guidelines: + # 1) If the audio contains pure music, the annotators mark the label + # "(music)" in the file name of its transcript. 2) If the utterance + # contains one or several sentences with background music or noise, the + # annotators mark the label "(music)" before each sentence in the transcript. + # 3) The annotators use {} symbols to enclose words they are uncertain + # about, for example, {梁佳佳},我是{}人. + + # here we manually fix some errors in the transcript + + return ( + text.strip() + .replace("(music)", "") + .replace("(music", "") + .replace("{", "") + .replace("}", "") + .replace("BB所以就指腹為親喇", "BB 所以就指腹為親喇") + .upper() + ) + + return [_text_norm(line) for line in lines] + + +def get_word_segments(lines: List[str]) -> List[str]: + # the current pycantonese segmenter does not handle the case when the input + # is code switching, so we need to handle it separately + + new_lines = [] + + for line in tqdm(lines, desc="Segmenting lines"): + try: + # code switching + if len(line.strip().split(" ")) > 1: + segments = [] + for segment in line.strip().split(" "): + if segment.strip() == "": + continue + try: + if not is_cjk(segment[0]): # en segment + segments.append(segment) + else: # zh segment + segments.extend(pycantonese.segment(segment)) + except Exception as e: + logging.error(f"Failed to process segment: {segment}") + raise e + new_lines.append(" ".join(segments) + "\n") + # not code switching + else: + new_lines.append(" ".join(pycantonese.segment(line)) + "\n") + except Exception as e: + logging.error(f"Failed to process line: {line}") + raise e + return new_lines + + +def get_words(lines: List[str]) -> List[str]: + words = set() + for line in tqdm(lines, desc="Getting words"): + words.update(line.strip().split(" ")) + return list(words) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + + input_file = Path(args.input_file) + output_dir = Path(args.output_dir) + + assert output_dir.is_dir(), f"{output_dir} does not exist" + assert input_file.is_file(), f"{input_file} does not exist" + + lines = input_file.read_text(encoding="utf-8").strip().split("\n") + + norm_lines = get_norm_lines(lines) + with open(output_dir / "text_norm", "w+", encoding="utf-8") as f: + f.writelines([line + "\n" for line in norm_lines]) + + text_words_segments = get_word_segments(norm_lines) + with open(output_dir / "text_words_segmentation", "w+", encoding="utf-8") as f: + f.writelines(text_words_segments) + + words = get_words(text_words_segments)[1:] # remove "\n" from words + with open(output_dir / "words_no_ids.txt", "w+", encoding="utf-8") as f: + f.writelines([word + "\n" for word in sorted(words)]) + + words = ( + ["", "!SIL", "", ""] + + sorted(words) + + ["#0", "", "<\s>"] + ) + + with open(output_dir / "words.txt", "w+", encoding="utf-8") as f: + f.writelines([f"{word} {i}\n" for i, word in enumerate(words)]) diff --git a/egs/mdcc/ASR/local/text2segments.py b/egs/mdcc/ASR/local/text2segments.py new file mode 100755 index 0000000000..8ce7ab7e58 --- /dev/null +++ b/egs/mdcc/ASR/local/text2segments.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) +# 2022 Xiaomi Corp. (authors: Weiji Zhuang) +# 2024 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input "text", which refers to the transcript file for +MDCC: + - text +and generates the output file text_word_segmentation which is implemented +with word segmenting: + - text_words_segmentation +""" + +import argparse +from typing import List + +import pycantonese +from tqdm.auto import tqdm + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Cantonese Word Segmentation for text", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input-file", + "-i", + default="data/lang_char/text", + type=str, + help="the input text file for MDCC", + ) + parser.add_argument( + "--output-file", + "-o", + default="data/lang_char/text_words_segmentation", + type=str, + help="the text implemented with words segmenting for MDCC", + ) + + return parser + + +def get_word_segments(lines: List[str]) -> List[str]: + return [ + " ".join(pycantonese.segment(line)) + "\n" + for line in tqdm(lines, desc="Segmenting lines") + ] + + +def main(): + parser = get_parser() + args = parser.parse_args() + + input_file = args.input_file + output_file = args.output_file + + with open(input_file, "r", encoding="utf-8") as fr: + lines = fr.readlines() + + new_lines = get_word_segments(lines) + + with open(output_file, "w", encoding="utf-8") as fw: + fw.writelines(new_lines) + + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/local/text2token.py b/egs/mdcc/ASR/local/text2token.py new file mode 120000 index 0000000000..81e459d69e --- /dev/null +++ b/egs/mdcc/ASR/local/text2token.py @@ -0,0 +1 @@ +../../../aidatatang_200zh/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/mdcc/ASR/prepare.sh b/egs/mdcc/ASR/prepare.sh new file mode 100755 index 0000000000..f4d9bc47e1 --- /dev/null +++ b/egs/mdcc/ASR/prepare.sh @@ -0,0 +1,308 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 +perturb_speed=true + + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/mdcc +# |-- README.md +# |-- audio/ +# |-- clip_info_rthk.csv +# |-- cnt_asr_metadata_full.csv +# |-- cnt_asr_test_metadata.csv +# |-- cnt_asr_train_metadata.csv +# |-- cnt_asr_valid_metadata.csv +# |-- data_statistic.py +# |-- length +# |-- podcast_447_2021.csv +# |-- test.txt +# |-- transcription/ +# `-- words_length +# You can download them from: +# https://drive.google.com/file/d/1epfYMMhXdBKA6nxPgUugb2Uj4DllSxkn/view?usp=drive_link +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "stage 0: Download data" + + # If you have pre-downloaded it to /path/to/mdcc, + # you can create a symlink + # + # ln -sfv /path/to/mdcc $dl_dir/mdcc + # + # The directory structure is + # mdcc/ + # |-- README.md + # |-- audio/ + # |-- clip_info_rthk.csv + # |-- cnt_asr_metadata_full.csv + # |-- cnt_asr_test_metadata.csv + # |-- cnt_asr_train_metadata.csv + # |-- cnt_asr_valid_metadata.csv + # |-- data_statistic.py + # |-- length + # |-- podcast_447_2021.csv + # |-- test.txt + # |-- transcription/ + # `-- words_length + + if [ ! -d $dl_dir/mdcc/audio ]; then + lhotse download mdcc $dl_dir + + # this will download and unzip dataset.zip to $dl_dir/ + + mv $dl_dir/dataset $dl_dir/mdcc + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/musan + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare MDCC manifest" + # We assume that you have downloaded the MDCC corpus + # to $dl_dir/mdcc + if [ ! -f data/manifests/.mdcc_manifests.done ]; then + log "Might take 40 minutes to traverse the directory." + mkdir -p data/manifests + lhotse prepare mdcc $dl_dir/mdcc data/manifests + touch data/manifests/.mdcc_manifests.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + if [ ! -f data/manifests/.musan_manifests.done ]; then + log "It may take 6 minutes" + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan_manifests.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for MDCC" + if [ ! -f data/fbank/.mdcc.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_mdcc.py --perturb-speed ${perturb_speed} + touch data/fbank/.mdcc.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + if [ ! -f data/fbank/.msuan.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_musan.py + touch data/fbank/.msuan.done + fi +fi + +lang_char_dir=data/lang_char +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare char based lang" + mkdir -p $lang_char_dir + + # Prepare text. + # Note: in Linux, you can install jq with the following command: + # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + # 2. chmod +x ./jq + # 3. cp jq /usr/bin + if [ ! -f $lang_char_dir/text ]; then + gunzip -c data/manifests/mdcc_supervisions_train.jsonl.gz \ + |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ + > $lang_char_dir/train_text + + cat $lang_char_dir/train_text > $lang_char_dir/text + + gunzip -c data/manifests/mdcc_supervisions_test.jsonl.gz \ + |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ + > $lang_char_dir/valid_text + + cat $lang_char_dir/valid_text >> $lang_char_dir/text + + gunzip -c data/manifests/mdcc_supervisions_valid.jsonl.gz \ + |jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \ + > $lang_char_dir/test_text + + cat $lang_char_dir/test_text >> $lang_char_dir/text + fi + + if [ ! -f $lang_char_dir/text_words_segmentation ]; then + ./local/preprocess_mdcc.py --input-file $lang_char_dir/text \ + --output-dir $lang_char_dir + + mv $lang_char_dir/text $lang_char_dir/_text + cp $lang_char_dir/text_words_segmentation $lang_char_dir/text + fi + + if [ ! -f $lang_char_dir/tokens.txt ]; then + ./local/prepare_char.py --lang-dir $lang_char_dir + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare G" + + mkdir -p data/lm + + # Train LM on transcripts + if [ ! -f data/lm/3-gram.unpruned.arpa ]; then + python3 ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text $lang_char_dir/text_words_segmentation \ + -lm data/lm/3-gram.unpruned.arpa + fi + + # We assume you have installed kaldilm, if not, please install + # it using: pip install kaldilm + if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="$lang_char_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt + fi + + if [ ! -f $lang_char_dir/HLG.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_char_dir \ + --ngram-G ./data/lm/G_3_gram_char.fst.txt + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compile LG & HLG" + + ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char + ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Generate LM training data" + + log "Processing char based data" + out_dir=data/lm_training_char + mkdir -p $out_dir $dl_dir/lm + + if [ ! -f $dl_dir/lm/mdcc-train-word.txt ]; then + ./local/text2segments.py --input-file $lang_char_dir/train_text \ + --output-file $dl_dir/lm/mdcc-train-word.txt + fi + + # training words + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/mdcc-train-word.txt \ + --lm-archive $out_dir/lm_data.pt + + # valid words + if [ ! -f $dl_dir/lm/mdcc-valid-word.txt ]; then + ./local/text2segments.py --input-file $lang_char_dir/valid_text \ + --output-file $dl_dir/lm/mdcc-valid-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/mdcc-valid-word.txt \ + --lm-archive $out_dir/lm_data_valid.pt + + # test words + if [ ! -f $dl_dir/lm/mdcc-test-word.txt ]; then + ./local/text2segments.py --input-file $lang_char_dir/test_text \ + --output-file $dl_dir/lm/mdcc-test-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/mdcc-test-word.txt \ + --lm-archive $out_dir/lm_data_test.pt +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Sort LM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of tokens + # in a sentence. + + out_dir=data/lm_training_char + mkdir -p $out_dir + ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/ + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data_valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data_test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Train RNN LM model" + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 1 \ + --num-epochs 20 \ + --use-fp16 0 \ + --embedding-dim 512 \ + --hidden-dim 512 \ + --num-layers 2 \ + --batch-size 400 \ + --exp-dir rnnlm_char/exp \ + --lm-data $out_dir/sorted_lm_data.pt \ + --lm-data-valid $out_dir/sorted_lm_data-valid.pt \ + --vocab-size 4336 \ + --master-port 12345 +fi diff --git a/egs/mdcc/ASR/shared b/egs/mdcc/ASR/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/mdcc/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/__init__.py b/egs/mdcc/ASR/zipformer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/mdcc/ASR/zipformer/asr_datamodule.py b/egs/mdcc/ASR/zipformer/asr_datamodule.py new file mode 100644 index 0000000000..1f49b65206 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,382 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2024 Xiaomi Corporation (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class MdccAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders( + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures() + ), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest_lazy( + self.args.manifest_dir / "mdcc_cuts_train.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get valid cuts") + return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_valid.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_test.jsonl.gz") diff --git a/egs/mdcc/ASR/zipformer/beam_search.py b/egs/mdcc/ASR/zipformer/beam_search.py new file mode 120000 index 0000000000..e24eca39f2 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/decode.py b/egs/mdcc/ASR/zipformer/decode.py new file mode 100755 index 0000000000..ce104baf72 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/decode.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Mingshuang Luo, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import MdccAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + sentence = "".join([lexicon.word_table[i] for i in hyp]) + hyps.append(list(sentence)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MdccAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + mdcc = MdccAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + valid_cuts = mdcc.valid_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_dl = mdcc.valid_dataloaders(valid_cuts) + + test_cuts = mdcc.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = mdcc.test_dataloaders(test_cuts) + + test_sets = ["valid", "test"] + test_dls = [valid_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/zipformer/decode_stream.py b/egs/mdcc/ASR/zipformer/decode_stream.py new file mode 120000 index 0000000000..b8d8ddfc4c --- /dev/null +++ b/egs/mdcc/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/decoder.py b/egs/mdcc/ASR/zipformer/decoder.py new file mode 120000 index 0000000000..5a8018680d --- /dev/null +++ b/egs/mdcc/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/encoder_interface.py b/egs/mdcc/ASR/zipformer/encoder_interface.py new file mode 120000 index 0000000000..c2eaca6712 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-ctc.py b/egs/mdcc/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 0000000000..f9d7563520 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 120000 index 0000000000..652346001e --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx-streaming.py b/egs/mdcc/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 0000000000..2962eb7847 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export-onnx.py b/egs/mdcc/ASR/zipformer/export-onnx.py new file mode 120000 index 0000000000..70a15683c2 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/export.py b/egs/mdcc/ASR/zipformer/export.py new file mode 120000 index 0000000000..dfc1bec080 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/joiner.py b/egs/mdcc/ASR/zipformer/joiner.py new file mode 120000 index 0000000000..5b8a36332e --- /dev/null +++ b/egs/mdcc/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/model.py b/egs/mdcc/ASR/zipformer/model.py new file mode 120000 index 0000000000..cd7e07d72b --- /dev/null +++ b/egs/mdcc/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/onnx_check.py b/egs/mdcc/ASR/zipformer/onnx_check.py new file mode 120000 index 0000000000..f3dd420046 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/onnx_decode.py b/egs/mdcc/ASR/zipformer/onnx_decode.py new file mode 100755 index 0000000000..1ed4a9fa11 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/onnx_decode.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. +""" + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import MdccAsrDataModule +from lhotse.cut import Cut +from onnx_pretrained import OnnxModel, greedy_search + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: k2.SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + Mapping ids to tokens. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [[token_table[h] for h in hyp] for hyp in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: k2.SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + Mapping ids to tokens. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = list(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MdccAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = k2.SymbolTable.from_file(args.tokens) + assert token_table[0] == "" + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + + mdcc = MdccAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + valid_cuts = mdcc.valid_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_dl = mdcc.valid_dataloaders(valid_cuts) + + test_cuts = mdcc.test_net_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = mdcc.test_dataloaders(test_cuts) + + test_sets = ["valid", "test"] + test_dl = [valid_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/zipformer/optim.py b/egs/mdcc/ASR/zipformer/optim.py new file mode 120000 index 0000000000..5eaa3cffd4 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/scaling.py b/egs/mdcc/ASR/zipformer/scaling.py new file mode 120000 index 0000000000..6f398f431d --- /dev/null +++ b/egs/mdcc/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/scaling_converter.py b/egs/mdcc/ASR/zipformer/scaling_converter.py new file mode 120000 index 0000000000..b0ecee05e1 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/streaming_beam_search.py b/egs/mdcc/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 0000000000..b1ed545579 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/streaming_decode.py b/egs/mdcc/ASR/zipformer/streaming_decode.py new file mode 100755 index 0000000000..dadb0b55fd --- /dev/null +++ b/egs/mdcc/ASR/zipformer/streaming_decode.py @@ -0,0 +1,881 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +from asr_datamodule import MdccAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="Path to the lang dir(containing lexicon, tokens, etc.)", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + encoder_out=encoder_out, + streams=decode_streams, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + lexicon: + The Lexicon. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + list(decode_streams[i].ground_truth.strip()), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + key = f"greedy_search_{key}" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_{key}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}_{key}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MdccAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + mdcc = MdccAsrDataModule(args) + + valid_cuts = mdcc.valid_cuts() + test_cuts = mdcc.test_cuts() + + test_sets = ["valid", "test"] + test_cuts = [valid_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + ) + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/zipformer/subsampling.py b/egs/mdcc/ASR/zipformer/subsampling.py new file mode 120000 index 0000000000..01ae9002c6 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py new file mode 100755 index 0000000000..2fae668444 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/train.py @@ -0,0 +1,1345 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --exp-dir zipformer/exp \ + --max-duration 350 + +# For mix precision training: + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import MdccAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="""Maximum left-contexts for causal training, measured in frames which will + be converted to a number of chunks. If splitting into chunks, + chunk left-context frames will be chosen randomly from this list; else not relevant.""", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + mdcc = MdccAsrDataModule(args) + + train_cuts = mdcc.train_cuts() + valid_cuts = mdcc.valid_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = mdcc.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) + + valid_dl = mdcc.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + The compiler to encode texts to ids. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = supervisions["text"] + y = graph_compiler.texts_to_ids(texts) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + MdccAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/mdcc/ASR/zipformer/zipformer.py b/egs/mdcc/ASR/zipformer/zipformer.py new file mode 120000 index 0000000000..23011dda71 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py index 662ae01c51..489b38e657 100644 --- a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py @@ -216,7 +216,7 @@ def train_dataloaders( logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/requirements.txt b/requirements.txt index e64afd1eec..6bafa6aca3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -kaldifst +kaldifst>1.7.0 kaldilm kaldialign num2words @@ -14,4 +14,7 @@ onnxruntime==1.16.3 # style check session: black==22.3.0 isort==5.10.1 -flake8==5.0.4 \ No newline at end of file +flake8==5.0.4 + +# cantonese word segment support +pycantonese==3.4.0 \ No newline at end of file