Skip to content

Commit

Permalink
Support distil-whisper (#411)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Nov 6, 2023
1 parent 86baf43 commit a65cdc3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
35 changes: 32 additions & 3 deletions .github/workflows/export-whisper-to-onnx.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,49 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest]
model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
python-version: ["3.8"]

steps:
- uses: actions/checkout@v4

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
shell: bash
run: |
python3 -m pip install openai-whisper torch onnxruntime onnx
python3 -m pip install torch==1.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
python3 -m pip install openai-whisper==20230314 onnxruntime onnx
- name: export ${{ matrix.model }}
shell: bash
run: |
cd scripts/whisper
model=${{ matrix.model }}
echo "model: $model"
if [[ $model == distil-medium.en ]]; then
wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
ls -lh
fi
python3 ./export-onnx.py --model ${{ matrix.model }}
python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
ls -lh
ls -lh ~/.cache/whisper
if [[ $model != distil-medium.en ]]; then
ls -lh ~/.cache/whisper
fi
- name: Publish ${{ matrix.model }} to huggingface
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
model=${{ matrix.model }}
cd scripts/whisper
git config --global user.email "[email protected]"
Expand All @@ -54,6 +71,18 @@ jobs:
cp *tokens.txt ./huggingface
cd huggingface
if [[ $model == distil-medium.en ]]; then
mkdir test_wavs
cd test_wavs
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
git add .
cd ..
fi
git status
ls -lh
git lfs track "*.onnx"
Expand Down
23 changes: 21 additions & 2 deletions scripts/whisper/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def get_args():
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2"],
"large", "large-v1", "large-v2",
"distil-medium.en",
],
# fmt: on
)
return parser.parse_args()
Expand Down Expand Up @@ -257,10 +259,27 @@ def convert_tokens(name, model):
def main():
args = get_args()
name = args.model
print(args)
print(name)

opset_version = 13

model = whisper.load_model(name)
# The distil-medium.en checkpoint is loaded from a local file that must be
# downloaded beforehand (see the message below); every other model is loaded
# by its name.
if name == "distil-medium.en":
    filename = "./distil-medium-en-original-model.bin"
    # A Path object is always truthy, so `not Path(filename)` could never
    # trigger. Check that the file actually exists on disk instead, so the
    # user gets the download instructions when it is missing.
    if not Path(filename).is_file():
        raise ValueError(
            """
Please go to https://huggingface.co/distil-whisper/distil-medium.en
to download original-model.bin
You can use the following command to do that:
wget -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
"""
        )
    model = whisper.load_model(filename)
else:
    model = whisper.load_model(name)

print(
f"number of model parameters: {name}",
sum(p.numel() for p in model.parameters()),
Expand Down

0 comments on commit a65cdc3

Please sign in to comment.