From 2e8b3212108045f6ddfc8629fc1db16f689df9b0 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 31 Jan 2024 17:23:42 +0800
Subject: [PATCH] Add fine-tuned whisper model on aishell (#565)

See also https://github.com/k2-fsa/icefall/pull/1466
---
 .github/workflows/export-whisper-to-onnx.yaml | 20 ++++++++++++++++---
 scripts/whisper/export-onnx.py                | 17 +++++++++++++++-
 scripts/whisper/test.py                       |  6 +++---
 3 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/export-whisper-to-onnx.yaml b/.github/workflows/export-whisper-to-onnx.yaml
index b755b386f..19237cba8 100644
--- a/.github/workflows/export-whisper-to-onnx.yaml
+++ b/.github/workflows/export-whisper-to-onnx.yaml
@@ -15,9 +15,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [macos-latest]
+        os: [ubuntu-latest]
         # model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
-        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium"]
+        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"]
         python-version: ["3.8"]

     steps:
@@ -49,9 +49,19 @@
          elif [[ $model == distil-small.en ]]; then
            wget -q -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
            ls -lh
+          elif [[ $model == medium-aishell ]]; then
+            wget -q -O medium-aishell.pt https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt
+            ls -lh
          fi

          python3 ./export-onnx.py --model ${{ matrix.model }}
          # python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
+          #
+          if [[ $model == medium-aishell ]]; then
+            ls -lh *.onnx
+            rm -fv medium-aishell-encoder.onnx
+            rm -fv medium-aishell-decoder.onnx
+          fi

          ls -lh
@@ -59,6 +69,7 @@
          ls -lh distil*original-model.bin || true
          rm -rf ~/.cache/whisper
          rm -f distil*original-model.bin
+          rm -f medium-aishell.pt

          src=sherpa-onnx-whisper-${{ matrix.model }}
@@ -132,7 +143,10 @@
            git config --global user.name "Fangjun Kuang"

            GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
-            rm -rf huggingface/*
+
+            if [[ $model != medium-aishell ]]; then
+              rm -rf huggingface/*
+            fi

            if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
              mv $src.tar* ./huggingface
diff --git a/scripts/whisper/export-onnx.py b/scripts/whisper/export-onnx.py
index 4c85834dc..1bfe03d0f 100755
--- a/scripts/whisper/export-onnx.py
+++ b/scripts/whisper/export-onnx.py
@@ -44,7 +44,9 @@ def get_args():
             "tiny", "tiny.en", "base", "base.en",
             "small", "small.en", "medium", "medium.en",
             "large", "large-v1", "large-v2",
-            "distil-medium.en", "distil-small.en", "distil-large-v2"
+            "distil-medium.en", "distil-small.en", "distil-large-v2",
+            # for fine-tuned models from icefall
+            "medium-aishell",
         ],
         # fmt: on
     )
@@ -340,6 +342,19 @@ def main():
                 """
             )
         model = whisper.load_model(filename)
+    elif name == "medium-aishell":
+        filename = "./medium-aishell.pt"
+        if not Path(filename).is_file():
+            raise ValueError(
+                """
+                Please go to https://huggingface.co/yuekai/icefall_asr_aishell_whisper/tree/main/exp_medium
+                to download whisper-medium-aishell1-epoch-10-avg-4.pt
+                You can use the following command to do that:
+
+                wget -O medium-aishell.pt https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt
+                """
+            )
+        model = whisper.load_model(filename)
     else:
         model = whisper.load_model(name)
     print(model.dims)
diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py
index 6941d347a..014a19e6a 100755
--- a/scripts/whisper/test.py
+++ b/scripts/whisper/test.py
@@ -257,9 +257,9 @@ def compute_features(filename: str) -> torch.Tensor:
     mel = (log_spec + 4.0) / 4.0
     # mel (T, 80)

-    # We pad 50 frames at the end so that it is able to detect eot
-    # You can use another value instead of 50.
-    mel = torch.nn.functional.pad(mel, (0, 0, 0, 1000), "constant", 0)
+    # We pad 1500 frames at the end so that the model is able to detect eot.
+    # You can use another value instead of 1500.
+    mel = torch.nn.functional.pad(mel, (0, 0, 0, 1500), "constant", 0)

     # Note that if it throws for a multilingual model,
     # please use a larger value, say 300
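
A quick sanity check for the padding change in scripts/whisper/test.py: torch.nn.functional.pad
reads its pad tuple from the last dimension backwards, so (0, 0, 0, 1500) leaves the 80 mel bins
untouched and appends 1500 zero frames on the time axis. The sketch below is illustrative only;
the (T, 80) shape and the 1500-frame count come from the patch, while the concrete value T = 300
is made up:

    import torch

    # mel has shape (T, 80): T time frames, 80 mel bins, as in compute_features().
    mel = torch.rand(300, 80)  # T = 300 is an arbitrary example value

    # The pad tuple is interpreted from the last dimension backwards:
    # (left, right) for dim -1 (mel bins), then (before, after) for dim -2 (time).
    # (0, 0, 0, 1500) therefore pads nothing on the mel-bin axis and appends
    # 1500 zero frames after the existing T frames.
    padded = torch.nn.functional.pad(mel, (0, 0, 0, 1500), "constant", 0)

    assert padded.shape == (300 + 1500, 80)
    assert torch.equal(padded[:300], mel)    # original frames are unchanged
    assert bool((padded[300:] == 0).all())   # appended frames are all zeros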