Skip to content

Commit

Permalink
Add fine-tuned whisper model on aishell (#565)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jan 31, 2024
1 parent 0b18ccf commit 2e8b321
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
20 changes: 17 additions & 3 deletions .github/workflows/export-whisper-to-onnx.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [macos-latest]
os: [ubuntu-latest]
# model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium"]
model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"]
python-version: ["3.8"]

steps:
Expand Down Expand Up @@ -49,16 +49,27 @@ jobs:
elif [[ $model == distil-small.en ]]; then
wget -q -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
ls -lh
elif [[ $model == medium-aishell ]]; then
wget -q -O medium-aishell.pt https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt
ls -lh
fi
python3 ./export-onnx.py --model ${{ matrix.model }}
# python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
#
if [[ $model == medium-aishell ]]; then
ls -lh *.onnx
rm -fv medium-aishell-encoder.onnx
rm -fv medium-aishell-decoder.onnx
fi
ls -lh
ls -lh ~/.cache/whisper || true
ls -lh distil*original-model.bin || true
rm -rf ~/.cache/whisper
rm -f distil*original-model.bin
rm -f medium-aishell.pt
src=sherpa-onnx-whisper-${{ matrix.model }}
Expand Down Expand Up @@ -132,7 +143,10 @@ jobs:
git config --global user.name "Fangjun Kuang"
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
rm -rf huggingface/*
if [[ $model != medium-aishell ]]; then
rm -rf huggingface/*
fi
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
mv $src.tar* ./huggingface
Expand Down
17 changes: 16 additions & 1 deletion scripts/whisper/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def get_args():
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2",
"distil-medium.en", "distil-small.en", "distil-large-v2"
"distil-medium.en", "distil-small.en", "distil-large-v2",
# for fine-tuned models from icefall
"medium-aishell",
],
# fmt: on
)
Expand Down Expand Up @@ -340,6 +342,19 @@ def main():
"""
)
model = whisper.load_model(filename)
elif name == "medium-aishell":
filename = "./medium-aishell.pt"
if not Path(filename).is_file():
raise ValueError(
"""
Please go to https://huggingface.co/yuekai/icefall_asr_aishell_whisper/tree/main/exp_medium
to download whisper-medium-aishell1-epoch-10-avg-4.pt
You can use the following command to do that:
wget -O medium-aishell.pt https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt
"""
)
model = whisper.load_model(filename)
else:
model = whisper.load_model(name)
print(model.dims)
Expand Down
6 changes: 3 additions & 3 deletions scripts/whisper/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ def compute_features(filename: str) -> torch.Tensor:
mel = (log_spec + 4.0) / 4.0
# mel (T, 80)

# We pad 50 frames at the end so that it is able to detect eot
# You can use another value instead of 50.
mel = torch.nn.functional.pad(mel, (0, 0, 0, 1000), "constant", 0)
# We pad 1500 frames at the end so that it is able to detect eot
# You can use another value instead of 1500.
mel = torch.nn.functional.pad(mel, (0, 0, 0, 1500), "constant", 0)
# Note that if it throws for a multilingual model,
# please use a larger value, say 2000

Expand Down

0 comments on commit 2e8b321

Please sign in to comment.