Skip to content

Commit

Permalink
Remove the 30-second constraint from whisper. (k2-fsa#471)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 7, 2023
1 parent eaa1b3e commit 6a5fe3f
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 78 deletions.
36 changes: 10 additions & 26 deletions .github/scripts/test-offline-whisper.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ which $EXE
names=(
tiny.en
base.en
# small.en
# medium.en
small.en
medium.en
tiny
base
small
medium
)

for name in ${names[@]}; do
Expand All @@ -33,8 +37,8 @@ for name in ${names[@]}; do
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
git lfs pull --include "*.ort"
ls -lh *.{onnx,ort}
# git lfs pull --include "*.ort"
ls -lh *.onnx
popd

log "test fp32 onnx"
Expand All @@ -43,6 +47,7 @@ for name in ${names[@]}; do
--tokens=$repo/${name}-tokens.txt \
--whisper-encoder=$repo/${name}-encoder.onnx \
--whisper-decoder=$repo/${name}-decoder.onnx \
--whisper-tail-paddings=500 \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
Expand All @@ -54,28 +59,7 @@ for name in ${names[@]}; do
--tokens=$repo/${name}-tokens.txt \
--whisper-encoder=$repo/${name}-encoder.int8.onnx \
--whisper-decoder=$repo/${name}-decoder.int8.onnx \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

log "test fp32 ort"

time $EXE \
--tokens=$repo/${name}-tokens.txt \
--whisper-encoder=$repo/${name}-encoder.ort \
--whisper-decoder=$repo/${name}-decoder.ort \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

log "test int8 ort"

time $EXE \
--tokens=$repo/${name}-tokens.txt \
--whisper-encoder=$repo/${name}-encoder.int8.ort \
--whisper-decoder=$repo/${name}-decoder.int8.ort \
--whisper-tail-paddings=500 \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
Expand Down
60 changes: 40 additions & 20 deletions .github/workflows/export-whisper-to-onnx.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [macos-latest]
os: [ubuntu-latest]
model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
python-version: ["3.8"]

Expand Down Expand Up @@ -44,49 +44,69 @@ jobs:
ls -lh
fi
python3 ./export-onnx.py --model ${{ matrix.model }}
python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
# python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
ls -lh
if [[ $model != distil-medium.en ]]; then
ls -lh ~/.cache/whisper
fi
src=sherpa-onnx-whisper-${{ matrix.model }}
mkdir $src
cp *.onnx $src/
cp *tokens.txt $src
cd $src
mkdir -p test_wavs
cd test_wavs
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
cd ../..
mv $src ../..
cd ../..
echo "--------------------"
ls -lh
ls -lh $src
echo "--------------------"
tar cjvf ./$src.tar.bz2 $src
- name: Release
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: asr-models

- name: Publish ${{ matrix.model }} to huggingface
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
model=${{ matrix.model }}
cd scripts/whisper
src=sherpa-onnx-whisper-${{ matrix.model }}
git config --global user.email "[email protected]"
git config --global user.name "Fangjun Kuang"
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
rm -rf huggingface/*
cp *.onnx ./huggingface
cp *.ort ./huggingface
cp *tokens.txt ./huggingface
cp -av $src/* ./huggingface/
cd huggingface
if [[ $model == distil-medium.en ]]; then
mkdir test_wavs
cd test_wavs
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
git add .
cd ..
fi
git status
ls -lh
git lfs track "*.onnx"
git lfs track "*.ort"
# git lfs track "*.ort"
git add .
git commit -m "upload ${{ matrix.model }}"
git push https://csukuangfj:[email protected]/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
20 changes: 10 additions & 10 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ jobs:
name: release-static
path: build/bin/*

- name: Test offline Whisper
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
readelf -d build/bin/sherpa-onnx-offline
.github/scripts/test-offline-whisper.sh
- name: Test online CTC
shell: bash
run: |
Expand Down Expand Up @@ -139,16 +149,6 @@ jobs:
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
readelf -d build/bin/sherpa-onnx-offline
.github/scripts/test-offline-whisper.sh
- name: Test offline transducer
shell: bash
run: |
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ jobs:
.github/scripts/test-online-paraformer.sh
- name: Test offline Whisper for windows x86
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-whisper.sh
# - name: Test offline Whisper for windows x86
# shell: bash
# run: |
# export PATH=$PWD/build/bin/Release:$PATH
# export EXE=sherpa-onnx-offline.exe
#
# .github/scripts/test-offline-whisper.sh

- name: Test offline CTC for windows x86
shell: bash
Expand Down
53 changes: 46 additions & 7 deletions scripts/whisper/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
Thanks to https://github.com/TadaoYamaoka
for making the onnx export script public.
Note that we have removed the 30 seconds constraint from whisper. You can
use any T <= 30.
"""

import argparse
Expand All @@ -17,6 +20,7 @@

import onnx
import torch
import torch.nn.functional as F
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import Tensor, nn

Expand Down Expand Up @@ -65,6 +69,39 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
onnx.save(model, filename)


def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
    """Drop-in replacement for ``AudioEncoder.forward`` without the fixed-length check.

    The stock implementation asserts that the number of frames produced by the
    two convolutions equals the length of ``self.positional_embedding``. Here
    the input may be shorter: only the leading ``x.shape[1]`` rows of the
    positional embedding are added.

    Args:
      x:
        torch.Tensor, shape = (batch_size, n_mels, n_ctx).
        The mel spectrogram of the audio.

    Returns:
      The output of ``self.ln_post`` applied after all ``self.blocks``.

    Raises:
      AssertionError: if the feature dimension after the convolutions differs
        from ``self.positional_embedding.shape[1]``, or if there are more
        frames than rows in ``self.positional_embedding``.
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    # Move the frame axis in front of the feature axis.
    x = x.permute(0, 2, 1)

    # For reference, the original code required an exact shape match:
    #
    #   assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
    #   x = (x + self.positional_embedding).to(x.dtype)
    #
    # The feature dimension must still match exactly ...
    assert (
        x.shape[2] == self.positional_embedding.shape[1]
    ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
    # ... but the number of frames only has to fit inside the positional
    # embedding table, since we slice it below. Requiring equality here would
    # re-introduce the fixed-length constraint this function exists to remove.
    assert (
        x.shape[1] <= self.positional_embedding.shape[0]
    ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
    x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x


AudioEncoder.forward = modified_audio_encoder_forward


class AudioEncoderTensorCache(nn.Module):
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
super().__init__()
Expand Down Expand Up @@ -279,6 +316,7 @@ def main():
model = whisper.load_model(filename)
else:
model = whisper.load_model(name)
print(model.dims)

print(
f"number of model parameters: {name}",
Expand Down Expand Up @@ -311,19 +349,20 @@ def main():
assert mel.shape == (batch_size, 80, 30 * 100)

encoder = AudioEncoderTensorCache(model.encoder, model.decoder)

n_layer_cross_k, n_layer_cross_v = encoder(mel)
assert n_layer_cross_k.shape == (
model.dims.n_text_layer,
batch_size,
model.dims.n_audio_ctx,
model.dims.n_text_state,
), n_layer_cross_k.shape
), (n_layer_cross_k.shape, model.dims)
assert n_layer_cross_v.shape == (
model.dims.n_text_layer,
batch_size,
model.dims.n_audio_ctx,
model.dims.n_text_state,
), n_layer_cross_v.shape
), (n_layer_cross_v.shape, model.dims)

encoder_filename = f"{name}-encoder.onnx"
torch.onnx.export(
Expand All @@ -334,9 +373,9 @@ def main():
input_names=["mel"],
output_names=["n_layer_cross_k", "n_layer_cross_v"],
dynamic_axes={
"mel": {0: "n_audio"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio"},
"n_layer_cross_v": {1: "n_audio"},
"mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
)

Expand Down Expand Up @@ -461,8 +500,8 @@ def main():
"tokens": {0: "n_audio", 1: "n_tokens"},
"in_n_layer_self_k_cache": {1: "n_audio"},
"in_n_layer_self_v_cache": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio"},
"n_layer_cross_v": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
)

Expand Down
15 changes: 14 additions & 1 deletion scripts/whisper/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,21 @@ def compute_features(filename: str) -> torch.Tensor:
log_spec = torch.clamp(features, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
mel = (log_spec + 4.0) / 4.0
# mel (T, 80)

# We pad 50 frames at the end so that it is able to detect eot
# You can use another value instead of 50.
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)
# Note that if it throws for a multilingual model,
# please use a larger value, say 300

target = 3000
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
if mel.shape[0] > target:
mel = mel[:target]

# We don't need to pad it to 30 seconds now!
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)

mel = mel.t().unsqueeze(0)

return mel
Expand Down
24 changes: 22 additions & 2 deletions sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,35 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {

NormalizeFeatures(f.data(), num_frames, feat_dim);

std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
// Note that 50 is an empirical value.
// see also ../../scripts/whisper/test.py
//
// You can replace 50 by other values, say, 100.
//
// Since we have removed the 30 seconds constraint, we need
// tail_padding_frames so that whisper is able to detect the eot token.
int32_t tail_padding_frames = 50;
if (model_->IsMultiLingual()) {
// 300 is an empirical value. If it throws, please use a larger value.
tail_padding_frames = 300;
}

if (config_.model_config.whisper.tail_paddings > 0) {
tail_padding_frames = config_.model_config.whisper.tail_paddings;
}

int32_t actual_frames =
std::min(num_frames + tail_padding_frames, max_num_frames);

std::array<int64_t, 3> shape{1, actual_frames, feat_dim};

Ort::Value mel = Ort::Value::CreateTensor<float>(
model_->Allocator(), shape.data(), shape.size());
float *p_mel = mel.GetTensorMutableData<float>();
std::copy(f.begin(), f.end(), p_mel);

memset(p_mel + f.size(), 0,
(max_num_frames - num_frames) * feat_dim * sizeof(float));
(actual_frames - num_frames) * feat_dim * sizeof(float));
mel = Transpose12(model_->Allocator(), &mel);

try {
Expand Down
Loading

0 comments on commit 6a5fe3f

Please sign in to comment.