From 117cd7bb8c262580718d87e17f01510c7ead3f92 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 12 Jul 2024 23:47:39 +0800 Subject: [PATCH] Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114) --- .github/workflows/export-whisper-to-onnx.yaml | 63 +++++++++---------- CHANGELOG.md | 3 +- CMakeLists.txt | 2 +- cmake/kaldi-native-fbank.cmake | 16 ++--- .../non-streaming-asr/pubspec.yaml | 2 +- dart-api-examples/streaming-asr/pubspec.yaml | 2 +- dart-api-examples/tts/pubspec.yaml | 2 +- dart-api-examples/vad/pubspec.yaml | 2 +- flutter-examples/streaming_asr/pubspec.yaml | 4 +- flutter-examples/tts/pubspec.yaml | 2 +- flutter/sherpa_onnx/pubspec.yaml | 12 ++-- .../ios/sherpa_onnx_ios.podspec | 2 +- .../macos/sherpa_onnx_macos.podspec | 2 +- nodejs-addon-examples/package.json | 2 +- scripts/dart/sherpa-onnx-pubspec.yaml | 2 +- scripts/whisper/.gitignore | 6 ++ scripts/whisper/export-onnx.py | 48 ++++++++++++-- scripts/whisper/test.py | 37 +++++++---- .../csrc/offline-recognizer-whisper-impl.h | 4 +- sherpa-onnx/csrc/offline-stream.cc | 12 ++-- sherpa-onnx/csrc/offline-stream.h | 5 +- sherpa-onnx/csrc/offline-whisper-model.cc | 6 ++ sherpa-onnx/csrc/offline-whisper-model.h | 1 + 23 files changed, 152 insertions(+), 85 deletions(-) diff --git a/.github/workflows/export-whisper-to-onnx.yaml b/.github/workflows/export-whisper-to-onnx.yaml index 2dd6aa283..070bc2333 100644 --- a/.github/workflows/export-whisper-to-onnx.yaml +++ b/.github/workflows/export-whisper-to-onnx.yaml @@ -15,9 +15,9 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest] - # model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"] - model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"] + os: [macos-latest] + model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell", "large", "large-v1", "large-v2", "distil-large-v2"] + # model: ["large", "large-v1", "large-v2", "large-v3", "distil-large-v2"] python-version: ["3.8"] steps: @@ -32,7 +32,7 @@ jobs: shell: bash run: | python3 -m pip install torch==1.13.0 torchaudio==0.13.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html - python3 -m pip install openai-whisper==20230314 onnxruntime onnx + python3 -m pip install openai-whisper==20231117 onnxruntime onnx soundfile librosa - name: export ${{ matrix.model }} shell: bash @@ -62,7 +62,6 @@ jobs: rm -fv medium-aishell-decoder.onnx fi - ls -lh ls -lh ~/.cache/whisper || true @@ -74,7 +73,8 @@ jobs: src=sherpa-onnx-whisper-${{ matrix.model }} cd .. - mv whisper $src + mkdir $src + mv -v whisper/$model* $src/ echo "------------------------------" @@ -97,19 +97,16 @@ jobs: ls -lh $src echo "--------------------" - if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then - #tar cvjf - $src | split --bytes=1024MB - $src.tar.bz2. - tar cvjf $src.tar.bz2 $src - split -b 1G $src.tar.bz2 $src.tar.bz2. - rm $src.tar.bz2 - # cat $src.tar.gz.* | tar xjf - + if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then + echo "Don't release model to github for large models. $model" else tar cvjf $src.tar.bz2 $src fi - ls -lh + ls -lh - name: Release + if: matrix.model != 'large' && matrix.model != 'large-v1' && matrix.model != 'large-v2' && matrix.model != 'large-v3' && matrix.model != 'distil-large-v2' uses: svenstaro/upload-release-action@v2 with: file_glob: true @@ -119,19 +116,6 @@ jobs: repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} tag: asr-models - - name: Test ${{ matrix.model }} - shell: bash - run: | - python3 -m pip install kaldi-native-fbank - git checkout . - model=${{ matrix.model }} - src=sherpa-onnx-whisper-$model - python3 scripts/whisper/test.py \ - --encoder $src/$model-encoder.int8.onnx \ - --decoder $src/$model-decoder.int8.onnx \ - --tokens $src/$model-tokens.txt \ - $src/test_wavs/0.wav - - name: Publish ${{ matrix.model }} to huggingface shell: bash env: @@ -144,27 +128,36 @@ jobs: export GIT_CLONE_PROTECTION_ACTIVE=false - GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface + export GIT_LFS_SKIP_SMUDGE=1 + + git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface if [[ $model != medium-aishell ]]; then rm -rf huggingface/* fi - if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then - mv $src.tar* ./huggingface - else - cp -v $src/*.onnx ./huggingface - cp -v $src/*tokens* ./huggingface - cp -av $src/test_wavs ./huggingface - fi + cp -av $src/* ./huggingface/ cd huggingface git status ls -lh - git lfs track "*gz*" git lfs track "*onnx*" + git lfs track "*weights*" git add . git commit -m "upload ${{ matrix.model }}" git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main + + - name: Test ${{ matrix.model }} + shell: bash + run: | + python3 -m pip install kaldi-native-fbank + git checkout . + model=${{ matrix.model }} + src=sherpa-onnx-whisper-$model + time python3 scripts/whisper/test.py \ + --encoder $src/$model-encoder.onnx \ + --decoder $src/$model-decoder.onnx \ + --tokens $src/$model-tokens.txt \ + $src/test_wavs/0.wav diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bcffaf1c..90e1a2721 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ -## 1.10.14 (to-be-released) +## 1.10.14 +* Support whisper large v3 * Update onnxruntime from v1.18.0 to v1.18.1 * Fix invalid utf8 sequence from Whisper for Dart API. diff --git a/CMakeLists.txt b/CMakeLists.txt index 203b8a569..b6c2a2ff8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ project(sherpa-onnx) # ./nodejs-addon-examples # ./dart-api-examples/ # ./CHANGELOG.md -set(SHERPA_ONNX_VERSION "1.10.13") +set(SHERPA_ONNX_VERSION "1.10.14") # Disable warning about # diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index ec2add8b1..2d87b6a8b 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -1,9 +1,9 @@ function(download_kaldi_native_fbank) include(FetchContent) - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz") - set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.3.tar.gz") - set(kaldi_native_fbank_HASH "SHA256=335fe1daf1b9bfb2a7b6bf03b64c4c4686c39077c57fb8058c02611981676638") + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz") + set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz") + set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9") set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) # If you don't have access to the Internet, # please pre-download kaldi-native-fbank set(possible_file_locations - $ENV{HOME}/Downloads/kaldi-native-fbank-1.19.3.tar.gz - ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.19.3.tar.gz - ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.19.3.tar.gz - /tmp/kaldi-native-fbank-1.19.3.tar.gz - /star-fj/fangjun/download/github/kaldi-native-fbank-1.19.3.tar.gz + $ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz + ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz + ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz + /tmp/kaldi-native-fbank-1.20.0.tar.gz + /star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/dart-api-examples/non-streaming-asr/pubspec.yaml b/dart-api-examples/non-streaming-asr/pubspec.yaml index db5fd01cc..81476e150 100644 --- a/dart-api-examples/non-streaming-asr/pubspec.yaml +++ b/dart-api-examples/non-streaming-asr/pubspec.yaml @@ -10,7 +10,7 @@ environment: # Add regular dependencies here. dependencies: - sherpa_onnx: ^1.10.13 + sherpa_onnx: ^1.10.14 path: ^1.9.0 args: ^2.5.0 diff --git a/dart-api-examples/streaming-asr/pubspec.yaml b/dart-api-examples/streaming-asr/pubspec.yaml index 338788695..cf9be993f 100644 --- a/dart-api-examples/streaming-asr/pubspec.yaml +++ b/dart-api-examples/streaming-asr/pubspec.yaml @@ -11,7 +11,7 @@ environment: # Add regular dependencies here. dependencies: - sherpa_onnx: ^1.10.13 + sherpa_onnx: ^1.10.14 path: ^1.9.0 args: ^2.5.0 diff --git a/dart-api-examples/tts/pubspec.yaml b/dart-api-examples/tts/pubspec.yaml index c915a167e..cde8ede5b 100644 --- a/dart-api-examples/tts/pubspec.yaml +++ b/dart-api-examples/tts/pubspec.yaml @@ -8,7 +8,7 @@ environment: # Add regular dependencies here. dependencies: - sherpa_onnx: ^1.10.13 + sherpa_onnx: ^1.10.14 path: ^1.9.0 args: ^2.5.0 diff --git a/dart-api-examples/vad/pubspec.yaml b/dart-api-examples/vad/pubspec.yaml index f82700f29..4bf5a3344 100644 --- a/dart-api-examples/vad/pubspec.yaml +++ b/dart-api-examples/vad/pubspec.yaml @@ -9,7 +9,7 @@ environment: sdk: ^3.4.0 dependencies: - sherpa_onnx: ^1.10.13 + sherpa_onnx: ^1.10.14 path: ^1.9.0 args: ^2.5.0 diff --git a/flutter-examples/streaming_asr/pubspec.yaml b/flutter-examples/streaming_asr/pubspec.yaml index db647fe24..596c40b72 100644 --- a/flutter-examples/streaming_asr/pubspec.yaml +++ b/flutter-examples/streaming_asr/pubspec.yaml @@ -5,7 +5,7 @@ description: > publish_to: 'none' -version: 1.10.13 +version: 1.10.14 topics: - speech-recognition @@ -30,7 +30,7 @@ dependencies: record: ^5.1.0 url_launcher: ^6.2.6 - sherpa_onnx: ^1.10.13 + sherpa_onnx: ^1.10.14 # sherpa_onnx: # path: ../../flutter/sherpa_onnx diff --git a/flutter-examples/tts/pubspec.yaml b/flutter-examples/tts/pubspec.yaml index b14acada9..4ddaafd6e 100644 --- a/flutter-examples/tts/pubspec.yaml +++ b/flutter-examples/tts/pubspec.yaml @@ -17,7 +17,7 @@ dependencies: cupertino_icons: ^1.0.6 path_provider: ^2.1.3 path: ^1.9.0 - sherpa_onnx: ^1.10.13 + sherpa_onnx: ^1.10.14 url_launcher: ^6.2.6 audioplayers: ^5.0.0 diff --git a/flutter/sherpa_onnx/pubspec.yaml b/flutter/sherpa_onnx/pubspec.yaml index 4a8f7f04d..cbaf7c8e5 100644 --- a/flutter/sherpa_onnx/pubspec.yaml +++ b/flutter/sherpa_onnx/pubspec.yaml @@ -17,7 +17,7 @@ topics: - voice-activity-detection # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec -version: 1.10.13 +version: 1.10.14 homepage: https://github.com/k2-fsa/sherpa-onnx @@ -30,19 +30,19 @@ dependencies: flutter: sdk: flutter - sherpa_onnx_android: ^1.10.13 + sherpa_onnx_android: ^1.10.14 # path: ../sherpa_onnx_android - sherpa_onnx_macos: ^1.10.13 + sherpa_onnx_macos: ^1.10.14 # path: ../sherpa_onnx_macos - sherpa_onnx_linux: ^1.10.13 + sherpa_onnx_linux: ^1.10.14 # path: ../sherpa_onnx_linux # - sherpa_onnx_windows: ^1.10.13 + sherpa_onnx_windows: ^1.10.14 # path: ../sherpa_onnx_windows - sherpa_onnx_ios: ^1.10.13 + sherpa_onnx_ios: ^1.10.14 # sherpa_onnx_ios: # path: ../sherpa_onnx_ios diff --git a/flutter/sherpa_onnx_ios/ios/sherpa_onnx_ios.podspec b/flutter/sherpa_onnx_ios/ios/sherpa_onnx_ios.podspec index 3555e88aa..2527ff042 100644 --- a/flutter/sherpa_onnx_ios/ios/sherpa_onnx_ios.podspec +++ b/flutter/sherpa_onnx_ios/ios/sherpa_onnx_ios.podspec @@ -7,7 +7,7 @@ # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c Pod::Spec.new do |s| s.name = 'sherpa_onnx_ios' - s.version = '1.10.13' + s.version = '1.10.14' s.summary = 'A new Flutter FFI plugin project.' s.description = <<-DESC A new Flutter FFI plugin project. diff --git a/flutter/sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec b/flutter/sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec index 9df7691fd..60bbdc791 100644 --- a/flutter/sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec +++ b/flutter/sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec @@ -4,7 +4,7 @@ # Pod::Spec.new do |s| s.name = 'sherpa_onnx_macos' - s.version = '1.10.13' + s.version = '1.10.14' s.summary = 'sherpa-onnx Flutter FFI plugin project.' s.description = <<-DESC sherpa-onnx Flutter FFI plugin project. diff --git a/nodejs-addon-examples/package.json b/nodejs-addon-examples/package.json index 807926201..3877faf5e 100644 --- a/nodejs-addon-examples/package.json +++ b/nodejs-addon-examples/package.json @@ -1,5 +1,5 @@ { "dependencies": { - "sherpa-onnx-node": "^1.10.13" + "sherpa-onnx-node": "^1.10.14" } } diff --git a/scripts/dart/sherpa-onnx-pubspec.yaml b/scripts/dart/sherpa-onnx-pubspec.yaml index f2a7dffda..73ff2f14a 100644 --- a/scripts/dart/sherpa-onnx-pubspec.yaml +++ b/scripts/dart/sherpa-onnx-pubspec.yaml @@ -17,7 +17,7 @@ topics: - voice-activity-detection # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec -version: 1.10.13 +version: 1.10.14 homepage: https://github.com/k2-fsa/sherpa-onnx diff --git a/scripts/whisper/.gitignore b/scripts/whisper/.gitignore index fbe9a87e7..98383be8e 100644 --- a/scripts/whisper/.gitignore +++ b/scripts/whisper/.gitignore @@ -2,3 +2,9 @@ *.config *.ort *-tokens.txt +*.bias +*.weights +*.weight +*.*embedding +_Const* +onnx__* diff --git a/scripts/whisper/export-onnx.py b/scripts/whisper/export-onnx.py index 1bfe03d0f..382e9a381 100755 --- a/scripts/whisper/export-onnx.py +++ b/scripts/whisper/export-onnx.py @@ -32,6 +32,9 @@ TextDecoder, ) +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + def get_args(): parser = argparse.ArgumentParser() @@ -43,8 +46,9 @@ def get_args(): choices=[ "tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", - "large", "large-v1", "large-v2", + "large", "large-v1", "large-v2", "large-v3", "distil-medium.en", "distil-small.en", "distil-large-v2", + # "distil-large-v3", # distil-large-v3 is not supported! # for fine-tuned models from icefall "medium-aishell", ], @@ -63,12 +67,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]): Key-value pairs. """ model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + for key, value in meta_data.items(): meta = model.metadata_props.add() meta.key = key meta.value = str(value) - onnx.save(model, filename) + if "large" in filename: + external_filename = filename.split(".onnx")[0] + onnx.save( + model, + filename, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_filename + ".weights", + ) + else: + onnx.save(model, filename) def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor): @@ -376,7 +394,9 @@ def main(): # write tokens - tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual) + tokenizer = whisper.tokenizer.get_tokenizer( + model.is_multilingual, num_languages=model.num_languages + ) model.eval() print(model.dims) @@ -384,10 +404,15 @@ def main(): audio = whisper.pad_or_trim(audio) assert audio.shape == (16000 * 30,), audio.shape - # make log-Mel spectrogram and move to the same device as the model - mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0) + if args.model in ("large", "large-v3"): + n_mels = 128 + else: + n_mels = 80 + mel = ( + whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0) + ) batch_size = 1 - assert mel.shape == (batch_size, 80, 30 * 100) + assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape encoder = AudioEncoderTensorCache(model.encoder, model.decoder) @@ -546,6 +571,17 @@ def main(): }, ) + if "large" in args.model: + decoder_external_filename = decoder_filename.split(".onnx")[0] + decoder_model = onnx.load(decoder_filename) + onnx.save( + decoder_model, + decoder_filename, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=decoder_external_filename + ".weights", + ) + if "large" in args.model: # it causes errors for large models, so skip it. return diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py index 014a19e6a..160933043 100755 --- a/scripts/whisper/test.py +++ b/scripts/whisper/test.py @@ -9,9 +9,10 @@ from typing import Tuple import kaldi_native_fbank as knf +import numpy as np import onnxruntime as ort +import soundfile as sf import torch -import torchaudio def get_args(): @@ -98,7 +99,6 @@ def init_encoder(self, encoder: str): self.blank = int(meta["blank_id"]) self.sot_sequence = list(map(int, meta["sot_sequence"].split(","))) - self.sot_sequence.append(self.no_timestamps) self.all_language_tokens = list( @@ -226,7 +226,18 @@ def load_tokens(filename): return tokens -def compute_features(filename: str) -> torch.Tensor: +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def compute_features(filename: str, dim: int = 80) -> torch.Tensor: """ Args: filename: @@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor: Returns: Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features. """ - wave, sample_rate = torchaudio.load(filename) - audio = wave[0].contiguous() # only use the first channel + wave, sample_rate = load_audio(filename) if sample_rate != 16000: - audio = torchaudio.functional.resample( - audio, orig_freq=sample_rate, new_freq=16000 - ) + import librosa + + wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000) + sample_rate = 16000 features = [] - online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions()) - online_whisper_fbank.accept_waveform(16000, audio.numpy()) + opts = knf.WhisperFeatureOptions() + opts.dim = dim + online_whisper_fbank = knf.OnlineWhisperFbank(opts) + online_whisper_fbank.accept_waveform(16000, wave) online_whisper_fbank.input_finished() for i in range(online_whisper_fbank.num_frames_ready): f = online_whisper_fbank.get_frame(i) @@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor: def main(): args = get_args() - mel = compute_features(args.sound_file) model = OnnxModel(args.encoder, args.decoder) + dim = 80 if "large-v3" not in args.encoder else 128 + mel = compute_features(args.sound_file, dim=dim) n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel) @@ -313,6 +327,7 @@ def main(): n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache() + print(model.sot_sequence) tokens = torch.tensor([model.sot_sequence], dtype=torch.int64) offset = torch.zeros(1, dtype=torch.int64) logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder( diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index 358917608..e56f07550 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -88,7 +88,9 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { } std::unique_ptr CreateStream() const override { - return std::make_unique(WhisperTag{}); + WhisperTag tag; + tag.dim = model_->FeatureDim(); + return std::make_unique(tag); } void DecodeStreams(OfflineStream **ss, int32_t n) const override { diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 79bdb5c54..c7c1dc0c2 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -97,12 +97,16 @@ class OfflineStream::Impl { } } - explicit Impl(WhisperTag /*tag*/) { + explicit Impl(WhisperTag tag) { config_.normalize_samples = true; opts_.frame_opts.samp_freq = 16000; - opts_.mel_opts.num_bins = 80; // not used - whisper_fbank_ = - std::make_unique(opts_.frame_opts); + opts_.mel_opts.num_bins = tag.dim; + + knf::WhisperFeatureOptions whisper_opts; + whisper_opts.frame_opts = opts_.frame_opts; + whisper_opts.dim = tag.dim; + + whisper_fbank_ = std::make_unique(whisper_opts); config_.sampling_rate = opts_.frame_opts.samp_freq; } diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 9df46d04e..e3c346fc4 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -35,7 +35,10 @@ struct OfflineRecognitionResult { std::string AsJsonString() const; }; -struct WhisperTag {}; +struct WhisperTag { + int32_t dim = 80; +}; + struct CEDTag {}; class OfflineStream { diff --git a/sherpa-onnx/csrc/offline-whisper-model.cc b/sherpa-onnx/csrc/offline-whisper-model.cc index f73234d95..7812e1d09 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.cc +++ b/sherpa-onnx/csrc/offline-whisper-model.cc @@ -217,6 +217,8 @@ class OfflineWhisperModel::Impl { int32_t VocabSize() const { return n_vocab_; } + int32_t FeatureDim() const { return n_mels_; } + int32_t Translate() const { return translate_; } bool IsMultiLingual() const { return is_multilingual_; } @@ -242,6 +244,7 @@ class OfflineWhisperModel::Impl { } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(n_mels_, "n_mels"); SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer"); SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx"); SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state"); @@ -316,6 +319,7 @@ class OfflineWhisperModel::Impl { std::unordered_map id2lang_; // model meta data + int32_t n_mels_ = 80; int32_t n_text_layer_ = 0; int32_t n_text_ctx_ = 0; int32_t n_text_state_ = 0; @@ -414,6 +418,8 @@ int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); } int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); } +int32_t OfflineWhisperModel::FeatureDim() const { return impl_->FeatureDim(); } + int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); } bool OfflineWhisperModel::IsMultiLingual() const { diff --git a/sherpa-onnx/csrc/offline-whisper-model.h b/sherpa-onnx/csrc/offline-whisper-model.h index 892af24af..866714bc5 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.h +++ b/sherpa-onnx/csrc/offline-whisper-model.h @@ -102,6 +102,7 @@ class OfflineWhisperModel { int32_t SOT() const; int32_t TextCtx() const; int32_t VocabSize() const; + int32_t FeatureDim() const; int32_t Translate() const; bool IsMultiLingual() const;