From 6059f34ed1189506da9780a17735c104bca5ace4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Sep 2023 15:48:06 +0800 Subject: [PATCH 1/4] small fixes --- sherpa-onnx/csrc/offline-paraformer-decoder.h | 3 +- ...ffline-paraformer-greedy-search-decoder.cc | 31 ++++++++++--------- ...offline-paraformer-greedy-search-decoder.h | 3 +- sherpa-onnx/csrc/offline-paraformer-model.cc | 13 +++----- sherpa-onnx/csrc/offline-paraformer-model.h | 11 ++++--- .../csrc/offline-recognizer-paraformer-impl.h | 10 ++++-- 6 files changed, 41 insertions(+), 30 deletions(-) diff --git a/sherpa-onnx/csrc/offline-paraformer-decoder.h b/sherpa-onnx/csrc/offline-paraformer-decoder.h index 1b783e88d..46d5b0ad2 100644 --- a/sherpa-onnx/csrc/offline-paraformer-decoder.h +++ b/sherpa-onnx/csrc/offline-paraformer-decoder.h @@ -28,7 +28,8 @@ class OfflineParaformerDecoder { * @return Return a vector of size `N` containing the decoded results. */ virtual std::vector Decode( - Ort::Value log_probs, Ort::Value token_num) = 0; + Ort::Value log_probs, Ort::Value token_num, + Ort::Value us_cif_peak = Ort::Value(nullptr)) = 0; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc index 619b33495..919be7426 100644 --- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc @@ -10,26 +10,29 @@ namespace sherpa_onnx { std::vector -OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs, - Ort::Value /*token_num*/) { - std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); +OfflineParaformerGreedySearchDecoder::Decode( + Ort::Value /*log_probs*/, Ort::Value token_num, + Ort::Value us_cif_peak /*=Ort::Value(nullptr)*/ +) { + std::vector shape = token_num.GetTensorTypeAndShapeInfo().GetShape(); int32_t batch_size = shape[0]; - int32_t num_tokens = shape[1]; - int32_t vocab_size = shape[2]; + int32_t max_num_tokens = shape[1]; std::vector results(batch_size); - for (int32_t i = 0; i != batch_size; ++i) { - const float *p = - log_probs.GetTensorData() + i * num_tokens * vocab_size; - for (int32_t k = 0; k != num_tokens; ++k) { - auto max_idx = static_cast( - std::distance(p, std::max_element(p, p + vocab_size))); - if (max_idx == eos_id_) break; + if (!us_cif_peak) { + // when timestamp is enabled, the data type of token_num is int32_t + const int64_t *p_token = token_num.GetTensorData(); - results[i].tokens.push_back(max_idx); + for (int32_t i = 0; i != batch_size; ++i, p_token += max_num_tokens) { + for (int32_t k = 0; k != max_num_tokens; ++k) { + int32_t t = p_token[k]; + if (t == eos_id_) { + break; + } - p += vocab_size; + results[i].tokens.push_back(t); + } } } diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h index 1f48e8c84..eba3fc04b 100644 --- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h @@ -17,7 +17,8 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder { : eos_id_(eos_id) {} std::vector Decode( - Ort::Value log_probs, Ort::Value /*token_num*/) override; + Ort::Value log_probs, Ort::Value token_num, + Ort::Value us_cif_peak = Ort::Value(nullptr)) override; private: int32_t eos_id_; diff --git a/sherpa-onnx/csrc/offline-paraformer-model.cc b/sherpa-onnx/csrc/offline-paraformer-model.cc index 614b2cc61..874374f18 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model.cc +++ b/sherpa-onnx/csrc/offline-paraformer-model.cc @@ -36,16 +36,13 @@ class OfflineParaformerModel::Impl { } #endif - std::pair Forward(Ort::Value features, - Ort::Value features_length) { + std::vector Forward(Ort::Value features, + Ort::Value features_length) { std::array inputs = {std::move(features), std::move(features_length)}; - auto out = - sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), - output_names_ptr_.data(), output_names_ptr_.size()); - - return {std::move(out[0]), std::move(out[1])}; + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); } int32_t VocabSize() const { return vocab_size_; } @@ -119,7 +116,7 @@ OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr, OfflineParaformerModel::~OfflineParaformerModel() = default; -std::pair OfflineParaformerModel::Forward( +std::vector OfflineParaformerModel::Forward( Ort::Value features, Ort::Value features_length) { return impl_->Forward(std::move(features), std::move(features_length)); } diff --git a/sherpa-onnx/csrc/offline-paraformer-model.h b/sherpa-onnx/csrc/offline-paraformer-model.h index 1fe7e84d5..d5c2329f6 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model.h +++ b/sherpa-onnx/csrc/offline-paraformer-model.h @@ -5,7 +5,6 @@ #define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ #include -#include #include #if __ANDROID_API__ >= 9 @@ -35,13 +34,17 @@ class OfflineParaformerModel { * valid frames in `features` before padding. * Its dtype is int32_t. * - * @return Return a pair containing: + * @return Return a vector containing: * - log_probs: A 3-D tensor of shape (N, T', vocab_size) * - token_num: A 1-D tensor of shape (N, T') containing number * of valid tokens in each utterance. Its dtype is int64_t. + * If it is a model supporting timestamps, then there are additional two + * outputs: + * - us_alphas + * - us_cif_peak */ - std::pair Forward(Ort::Value features, - Ort::Value features_length); + std::vector Forward(Ort::Value features, + Ort::Value features_length); /** Return the vocabulary size of the model */ diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h index 9d7186124..ea4494851 100644 --- a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h @@ -184,7 +184,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { // i.e., -23.025850929940457f Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0); - std::pair t{nullptr, nullptr}; + std::vector t; try { t = model_->Forward(std::move(x), std::move(x_length)); } catch (const Ort::Exception &ex) { @@ -193,7 +193,13 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { return; } - auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); + std::vector results; + if (t.size() == 2) { + results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + } else { + results = + decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3])); + } for (int32_t i = 0; i != n; ++i) { auto r = Convert(results[i], symbol_table_); From 1b4fde6f7747fb00d6ab8e4a9412141f3993ae48 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Sep 2023 15:54:50 +0800 Subject: [PATCH 2/4] refactor --- ...ffline-paraformer-greedy-search-decoder.cc | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc index 919be7426..7cd5fa08f 100644 --- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc @@ -11,28 +11,29 @@ namespace sherpa_onnx { std::vector OfflineParaformerGreedySearchDecoder::Decode( - Ort::Value /*log_probs*/, Ort::Value token_num, + Ort::Value log_probs, Ort::Value /*token_num*/, Ort::Value us_cif_peak /*=Ort::Value(nullptr)*/ ) { - std::vector shape = token_num.GetTensorTypeAndShapeInfo().GetShape(); + std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); int32_t batch_size = shape[0]; - int32_t max_num_tokens = shape[1]; + int32_t num_tokens = shape[1]; + int32_t vocab_size = shape[2]; std::vector results(batch_size); - if (!us_cif_peak) { - // when timestamp is enabled, the data type of token_num is int32_t - const int64_t *p_token = token_num.GetTensorData(); + for (int32_t i = 0; i != batch_size; ++i) { + const float *p = + log_probs.GetTensorData() + i * num_tokens * vocab_size; + for (int32_t k = 0; k != num_tokens; ++k) { + auto max_idx = static_cast( + std::distance(p, std::max_element(p, p + vocab_size))); + if (max_idx == eos_id_) { + break; + } - for (int32_t i = 0; i != batch_size; ++i, p_token += max_num_tokens) { - for (int32_t k = 0; k != max_num_tokens; ++k) { - int32_t t = p_token[k]; - if (t == eos_id_) { - break; - } + results[i].tokens.push_back(max_idx); - results[i].tokens.push_back(t); - } + p += vocab_size; } } From b6d80570bc068324467380b3b60af46a392a1d2b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Sep 2023 16:20:39 +0800 Subject: [PATCH 3/4] add timestamps support --- .github/scripts/test-offline-transducer.sh | 27 ++++++++++++++++ sherpa-onnx/csrc/offline-paraformer-decoder.h | 5 +++ ...ffline-paraformer-greedy-search-decoder.cc | 31 +++++++++++++++++++ .../csrc/offline-recognizer-paraformer-impl.h | 1 + 4 files changed, 64 insertions(+) diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh index 9fc1107ab..8cb3c5e8d 100755 --- a/.github/scripts/test-offline-transducer.sh +++ b/.github/scripts/test-offline-transducer.sh @@ -123,3 +123,30 @@ time $EXE \ $repo/test_wavs/8k.wav rm -rf $repo + +log "------------------------------------------------------------" +log "Run Paraformer (Chinese) with timestamps" +log "------------------------------------------------------------" + +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14 +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" +ls -lh *.onnx +popd + +time $EXE \ + --tokens=$repo/tokens.txt \ + --paraformer=$repo/model.int8.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/8k.wav + +rm -rf $repo diff --git a/sherpa-onnx/csrc/offline-paraformer-decoder.h b/sherpa-onnx/csrc/offline-paraformer-decoder.h index 46d5b0ad2..2effdfe1b 100644 --- a/sherpa-onnx/csrc/offline-paraformer-decoder.h +++ b/sherpa-onnx/csrc/offline-paraformer-decoder.h @@ -14,6 +14,11 @@ namespace sherpa_onnx { struct OfflineParaformerDecoderResult { /// The decoded token IDs std::vector tokens; + + // it contains the start time of each token in seconds + // + // len(timestamps) == len(tokens) + std::vector timestamps; }; class OfflineParaformerDecoder { diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc index 7cd5fa08f..95b582ca0 100644 --- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc @@ -7,6 +7,8 @@ #include #include +#include "sherpa-onnx/csrc/macros.h" + namespace sherpa_onnx { std::vector @@ -35,6 +37,35 @@ OfflineParaformerGreedySearchDecoder::Decode( p += vocab_size; } + + if (us_cif_peak) { + int32_t dim = us_cif_peak.GetTensorTypeAndShapeInfo().GetShape()[1]; + + const auto *peak = us_cif_peak.GetTensorData() + i * dim; + std::vector timestamps; + timestamps.reserve(results[i].tokens.size()); + + // 10.0: frameshift is 10 milliseconds + // 6: LfrWindowSize + // 3: us_cif_peak is upsampled by a factor of 3 + // 1000: milliseconds to seconds + float scale = 10.0 * 6 / 3 / 1000; + + for (int32_t k = 0; k != dim; ++k) { + if (peak[k] > 1 - 1e-4) { + timestamps.push_back(k * scale); + } + } + timestamps.pop_back(); + + if (timestamps.size() == results[i].tokens.size()) { + results[i].timestamps = std::move(timestamps); + } else { + SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i, + static_cast(results[i].tokens.size()), + static_cast(timestamps.size())); + } + } } return results; diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h index ea4494851..3c96f03bc 100644 --- a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h @@ -31,6 +31,7 @@ static OfflineRecognitionResult Convert( const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) { OfflineRecognitionResult r; r.tokens.reserve(src.tokens.size()); + r.timestamps = src.timestamps; std::string text; From 42f042d8f97163048c8a8a53ac176ffa328e8bee Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Sep 2023 17:44:34 +0800 Subject: [PATCH 4/4] Update C API to include timestamps for offline ASR --- sherpa-onnx/c-api/c-api.cc | 11 ++++ sherpa-onnx/c-api/c-api.h | 8 +++ ...ffline-paraformer-greedy-search-decoder.cc | 1 + sherpa-onnx/csrc/offline-paraformer-model.cc | 1 + swift-api-examples/SherpaOnnx.swift | 17 ++++++ .../decode-file-non-streaming.swift | 56 ++++++++++++++----- 6 files changed, 80 insertions(+), 14 deletions(-) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 6a989d542..520c47542 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -353,11 +353,22 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( std::copy(text.begin(), text.end(), const_cast(r->text)); const_cast(r->text)[text.size()] = 0; + if (!result.timestamps.empty()) { + r->timestamps = new float[result.timestamps.size()]; + std::copy(result.timestamps.begin(), result.timestamps.end(), + r->timestamps); + r->count = result.timestamps.size(); + } else { + r->timestamps = nullptr; + r->count = 0; + } + return r; } void DestroyOfflineRecognizerResult( const SherpaOnnxOfflineRecognizerResult *r) { delete[] r->text; + delete[] r->timestamps; delete r; } diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index d669bce27..71aa56426 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -408,6 +408,14 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { const char *text; + + // Pointer to continuous memory which holds timestamps + // + // It is NULL if the model does not support timestamps + float *timestamps; + + // number of entries in timestamps + int32_t count; // TODO(fangjun): Add more fields } SherpaOnnxOfflineRecognizerResult; diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc index 95b582ca0..c1d89a3ab 100644 --- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc @@ -5,6 +5,7 @@ #include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" #include +#include #include #include "sherpa-onnx/csrc/macros.h" diff --git a/sherpa-onnx/csrc/offline-paraformer-model.cc b/sherpa-onnx/csrc/offline-paraformer-model.cc index 874374f18..ce1851062 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model.cc +++ b/sherpa-onnx/csrc/offline-paraformer-model.cc @@ -6,6 +6,7 @@ #include #include +#include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index 524499b2a..72c497cf8 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -349,6 +349,23 @@ class SherpaOnnxOfflineRecongitionResult { return String(cString: result.pointee.text) } + var count: Int32 { + return result.pointee.count + } + + var timestamps: [Float] { + if let p = result.pointee.timestamps { + var timestamps: [Float] = [] + for index in 0..!) { self.result = result } diff --git a/swift-api-examples/decode-file-non-streaming.swift b/swift-api-examples/decode-file-non-streaming.swift index a9485c5fd..6d0b4e8b5 100644 --- a/swift-api-examples/decode-file-non-streaming.swift +++ b/swift-api-examples/decode-file-non-streaming.swift @@ -13,21 +13,45 @@ extension AVAudioPCMBuffer { } func run() { - let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx" - let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx" - let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt" - let whisperConfig = sherpaOnnxOfflineWhisperModelConfig( - encoder: encoder, - decoder: decoder - ) + var recognizer: SherpaOnnxOfflineRecognizer + var modelConfig: SherpaOnnxOfflineModelConfig + var modelType = "whisper" + // modelType = "paraformer" - let modelConfig = sherpaOnnxOfflineModelConfig( - tokens: tokens, - whisper: whisperConfig, - debug: 0, - modelType: "whisper" - ) + if modelType == "whisper" { + let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx" + let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx" + let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt" + + let whisperConfig = sherpaOnnxOfflineWhisperModelConfig( + encoder: encoder, + decoder: decoder + ) + + modelConfig = sherpaOnnxOfflineModelConfig( + tokens: tokens, + whisper: whisperConfig, + debug: 0, + modelType: "whisper" + ) + } else if modelType == "paraformer" { + let model = "./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx" + let tokens = "./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt" + let paraformerConfig = sherpaOnnxOfflineParaformerModelConfig( + model: model + ) + + modelConfig = sherpaOnnxOfflineModelConfig( + tokens: tokens, + paraformer: paraformerConfig, + debug: 0, + modelType: "paraformer" + ) + } else { + print("Please specify a supported modelType \(modelType)") + return + } let featConfig = sherpaOnnxFeatureConfig( sampleRate: 16000, @@ -38,7 +62,7 @@ func run() { modelConfig: modelConfig ) - let recognizer = SherpaOnnxOfflineRecognizer(config: &config) + recognizer = SherpaOnnxOfflineRecognizer(config: &config) let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" let fileURL: NSURL = NSURL(fileURLWithPath: filePath) @@ -55,6 +79,10 @@ func run() { let array: [Float]! = audioFileBuffer?.array() let result = recognizer.decode(samples: array, sampleRate: Int(audioFormat.sampleRate)) print("\nresult is:\n\(result.text)") + if result.timestamps.count != 0 { + print("\ntimestamps is:\n\(result.timestamps)") + } + } @main