From 2b1b05752b9be78ae051b661ac4edf880484771e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Nov 2024 18:34:27 +0800 Subject: [PATCH] Add streaming ASR support for HarmonyOS. --- sherpa-onnx/c-api/c-api.cc | 28 +++++++++++++++- sherpa-onnx/c-api/c-api.h | 7 ++++ sherpa-onnx/csrc/offline-recognizer-impl.cc | 6 ++-- .../csrc/online-conformer-transducer-model.cc | 27 +++++++++++++-- .../csrc/online-conformer-transducer-model.h | 11 ++----- sherpa-onnx/csrc/online-ctc-model.cc | 23 +++++++++++-- sherpa-onnx/csrc/online-ctc-model.h | 10 ++---- .../csrc/online-lstm-transducer-model.cc | 23 +++++++++++-- .../csrc/online-lstm-transducer-model.h | 11 ++----- sherpa-onnx/csrc/online-nemo-ctc-model.cc | 30 +++++++++++++---- sherpa-onnx/csrc/online-nemo-ctc-model.h | 10 ++---- sherpa-onnx/csrc/online-paraformer-model.cc | 30 +++++++++++++---- sherpa-onnx/csrc/online-paraformer-model.h | 10 ++---- sherpa-onnx/csrc/online-recognizer-ctc-impl.h | 5 ++- sherpa-onnx/csrc/online-recognizer-impl.cc | 33 ++++++++++++++----- sherpa-onnx/csrc/online-recognizer-impl.h | 14 +++----- .../csrc/online-recognizer-paraformer-impl.h | 6 ++-- .../csrc/online-recognizer-transducer-impl.h | 15 +++------ .../online-recognizer-transducer-nemo-impl.h | 10 ++---- sherpa-onnx/csrc/online-recognizer.cc | 24 ++++++++++++-- sherpa-onnx/csrc/online-recognizer.h | 10 ++---- sherpa-onnx/csrc/online-transducer-model.cc | 21 ++++++++++-- sherpa-onnx/csrc/online-transducer-model.h | 10 ++---- .../csrc/online-transducer-nemo-model.cc | 30 +++++++++++++---- .../csrc/online-transducer-nemo-model.h | 11 ++----- sherpa-onnx/csrc/online-wenet-ctc-model.cc | 30 +++++++++++++---- sherpa-onnx/csrc/online-wenet-ctc-model.h | 10 ++---- .../csrc/online-zipformer-transducer-model.cc | 23 +++++++++++-- .../csrc/online-zipformer-transducer-model.h | 11 ++----- .../csrc/online-zipformer2-ctc-model.cc | 28 ++++++++++++---- .../csrc/online-zipformer2-ctc-model.h | 10 ++---- .../online-zipformer2-transducer-model.cc | 23 +++++++++++-- .../csrc/online-zipformer2-transducer-model.h | 10 ++---- sherpa-onnx/csrc/symbol-table.cc | 6 ++-- sherpa-onnx/python/csrc/vad-model.cc | 7 ++-- 35 files changed, 367 insertions(+), 206 deletions(-) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index c9dafe84b..fbc3a010a 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -56,7 +56,7 @@ struct SherpaOnnxDisplay { #define SHERPA_ONNX_OR(x, y) (x ? x : y) -const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( +static sherpa_onnx::OnlineRecognizerConfig GetOnlineRecognizerConfig( const SherpaOnnxOnlineRecognizerConfig *config) { sherpa_onnx::OnlineRecognizerConfig recognizer_config; @@ -151,9 +151,21 @@ const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, ""); if (config->model_config.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", recognizer_config.ToString().c_str()); +#else SHERPA_ONNX_LOGE("%s\n", recognizer_config.ToString().c_str()); +#endif } + return recognizer_config; +} + +const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( + const SherpaOnnxOnlineRecognizerConfig *config) { + sherpa_onnx::OnlineRecognizerConfig recognizer_config = + GetOnlineRecognizerConfig(config); + if (!recognizer_config.Validate()) { SHERPA_ONNX_LOGE("Errors in config!"); return nullptr; @@ -1876,6 +1888,20 @@ SherpaOnnxOfflineSpeakerDiarizationProcessWithCallbackNoArg( #ifdef __OHOS__ +const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizerOHOS( + const SherpaOnnxOnlineRecognizerConfig *config, + NativeResourceManager *mgr) { + sherpa_onnx::OnlineRecognizerConfig recognizer_config = + GetOnlineRecognizerConfig(config); + + SherpaOnnxOnlineRecognizer *recognizer = new SherpaOnnxOnlineRecognizer; + + recognizer->impl = + std::make_unique(mgr, recognizer_config); + + return recognizer; +} + const SherpaOnnxOfflineRecognizer *SherpaOnnxCreateOfflineRecognizerOHOS( const SherpaOnnxOfflineRecognizerConfig *config, NativeResourceManager *mgr) { diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 31e86b8e5..1feaab306 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1527,6 +1527,13 @@ SHERPA_ONNX_API void SherpaOnnxOfflineSpeakerDiarizationDestroyResult( // It is for HarmonyOS typedef struct NativeResourceManager NativeResourceManager; +/// @param config Config for the recognizer. +/// @return Return a pointer to the recognizer. The user has to invoke +// SherpaOnnxDestroyOnlineRecognizer() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaOnnxOnlineRecognizer * +SherpaOnnxCreateOnlineRecognizerOHOS( + const SherpaOnnxOnlineRecognizerConfig *config, NativeResourceManager *mgr); + /// @param config Config for the recognizer. /// @return Return a pointer to the recognizer. The user has to invoke // SherpaOnnxDestroyOfflineRecognizer() to free it to avoid memory diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 2a7a8dab9..b3789849c 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -5,17 +5,17 @@ #include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include +#include #include #include #if __ANDROID_API__ >= 9 -#include #include "android/asset_manager.h" #include "android/asset_manager_jni.h" -#elif __OHOS__ -#include +#endif +#if __OHOS__ #include "rawfile/raw_file_manager.h" #endif diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.cc b/sherpa-onnx/csrc/online-conformer-transducer-model.cc index 2bceffc7d..519d1a935 100644 --- a/sherpa-onnx/csrc/online-conformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-conformer-transducer-model.cc @@ -17,6 +17,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" @@ -50,9 +54,9 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel( } } -#if __ANDROID_API__ >= 9 +template OnlineConformerTransducerModel::OnlineConformerTransducerModel( - AAssetManager *mgr, const OnlineModelConfig &config) + Manager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_ERROR), config_(config), sess_opts_(GetSessionOptions(config)), @@ -72,7 +76,6 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel( InitJoiner(buf.data(), buf.size()); } } -#endif void OnlineConformerTransducerModel::InitEncoder(void *model_data, size_t model_data_length) { @@ -91,7 +94,11 @@ void OnlineConformerTransducerModel::InitEncoder(void *model_data, std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -121,7 +128,11 @@ void OnlineConformerTransducerModel::InitDecoder(void *model_data, std::ostringstream os; os << "---decoder---\n"; PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -273,4 +284,14 @@ Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out, return std::move(logit[0]); } +#if __ANDROID_API__ >= 9 +template OnlineConformerTransducerModel::OnlineConformerTransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineConformerTransducerModel::OnlineConformerTransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.h b/sherpa-onnx/csrc/online-conformer-transducer-model.h index bcf9e6eda..5c901b87a 100644 --- a/sherpa-onnx/csrc/online-conformer-transducer-model.h +++ b/sherpa-onnx/csrc/online-conformer-transducer-model.h @@ -10,11 +10,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -25,10 +20,8 @@ class OnlineConformerTransducerModel : public OnlineTransducerModel { public: explicit OnlineConformerTransducerModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineConformerTransducerModel(AAssetManager *mgr, - const OnlineModelConfig &config); -#endif + template + OnlineConformerTransducerModel(Manager *mgr, const OnlineModelConfig &config); std::vector StackStates( const std::vector> &states) const override; diff --git a/sherpa-onnx/csrc/online-ctc-model.cc b/sherpa-onnx/csrc/online-ctc-model.cc index a3a071a72..649c6b507 100644 --- a/sherpa-onnx/csrc/online-ctc-model.cc +++ b/sherpa-onnx/csrc/online-ctc-model.cc @@ -9,6 +9,15 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-nemo-ctc-model.h" #include "sherpa-onnx/csrc/online-wenet-ctc-model.h" @@ -31,10 +40,9 @@ std::unique_ptr OnlineCtcModel::Create( } } -#if __ANDROID_API__ >= 9 - +template std::unique_ptr OnlineCtcModel::Create( - AAssetManager *mgr, const OnlineModelConfig &config) { + Manager *mgr, const OnlineModelConfig &config) { if (!config.wenet_ctc.model.empty()) { return std::make_unique(mgr, config); } else if (!config.zipformer2_ctc.model.empty()) { @@ -46,6 +54,15 @@ std::unique_ptr OnlineCtcModel::Create( exit(-1); } } + +#if __ANDROID_API__ >= 9 +template std::unique_ptr OnlineCtcModel::Create( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr OnlineCtcModel::Create( + NativeResourceManager *mgr, const OnlineModelConfig &config); #endif } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-ctc-model.h b/sherpa-onnx/csrc/online-ctc-model.h index 17721752d..bd01bc543 100644 --- a/sherpa-onnx/csrc/online-ctc-model.h +++ b/sherpa-onnx/csrc/online-ctc-model.h @@ -8,11 +8,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-model-config.h" @@ -25,10 +20,9 @@ class OnlineCtcModel { static std::unique_ptr Create( const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 + template static std::unique_ptr Create( - AAssetManager *mgr, const OnlineModelConfig &config); -#endif + Manager *mgr, const OnlineModelConfig &config); // Return a list of tensors containing the initial states virtual std::vector GetInitStates() const = 0; diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index b9ef9ca5b..91b499fdf 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -16,6 +16,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" @@ -48,9 +52,9 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( } } -#if __ANDROID_API__ >= 9 +template OnlineLstmTransducerModel::OnlineLstmTransducerModel( - AAssetManager *mgr, const OnlineModelConfig &config) + Manager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_ERROR), config_(config), sess_opts_(GetSessionOptions(config)), @@ -70,7 +74,6 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( InitJoiner(buf.data(), buf.size()); } } -#endif void OnlineLstmTransducerModel::InitEncoder(void *model_data, size_t model_data_length) { @@ -89,7 +92,11 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data, std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -261,4 +268,14 @@ Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out, return std::move(logit[0]); } +#if __ANDROID_API__ >= 9 +template OnlineLstmTransducerModel::OnlineLstmTransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineLstmTransducerModel::OnlineLstmTransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.h b/sherpa-onnx/csrc/online-lstm-transducer-model.h index 24119f240..64ade36d0 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.h +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.h @@ -9,11 +9,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -24,10 +19,8 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { public: explicit OnlineLstmTransducerModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineLstmTransducerModel(AAssetManager *mgr, - const OnlineModelConfig &config); -#endif + template + OnlineLstmTransducerModel(Manager *mgr, const OnlineModelConfig &config); std::vector StackStates( const std::vector> &states) const override; diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model.cc b/sherpa-onnx/csrc/online-nemo-ctc-model.cc index 172ee69f4..716c7ee7e 100644 --- a/sherpa-onnx/csrc/online-nemo-ctc-model.cc +++ b/sherpa-onnx/csrc/online-nemo-ctc-model.cc @@ -13,6 +13,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" @@ -36,8 +40,8 @@ class OnlineNeMoCtcModel::Impl { } } -#if __ANDROID_API__ >= 9 - Impl(AAssetManager *mgr, const OnlineModelConfig &config) + template + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), @@ -47,7 +51,6 @@ class OnlineNeMoCtcModel::Impl { Init(buf.data(), buf.size()); } } -#endif std::vector Forward(Ort::Value x, std::vector states) { @@ -202,7 +205,11 @@ class OnlineNeMoCtcModel::Impl { if (config_.debug) { std::ostringstream os; PrintModelMetadata(os, meta_data); - SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -286,11 +293,10 @@ class OnlineNeMoCtcModel::Impl { OnlineNeMoCtcModel::OnlineNeMoCtcModel(const OnlineModelConfig &config) : impl_(std::make_unique(config)) {} -#if __ANDROID_API__ >= 9 -OnlineNeMoCtcModel::OnlineNeMoCtcModel(AAssetManager *mgr, +template +OnlineNeMoCtcModel::OnlineNeMoCtcModel(Manager *mgr, const OnlineModelConfig &config) : impl_(std::make_unique(mgr, config)) {} -#endif OnlineNeMoCtcModel::~OnlineNeMoCtcModel() = default; @@ -323,4 +329,14 @@ std::vector> OnlineNeMoCtcModel::UnStackStates( return impl_->UnStackStates(std::move(states)); } +#if __ANDROID_API__ >= 9 +template OnlineNeMoCtcModel::OnlineNeMoCtcModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineNeMoCtcModel::OnlineNeMoCtcModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model.h b/sherpa-onnx/csrc/online-nemo-ctc-model.h index c8dd182e8..4e5f820b6 100644 --- a/sherpa-onnx/csrc/online-nemo-ctc-model.h +++ b/sherpa-onnx/csrc/online-nemo-ctc-model.h @@ -8,11 +8,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-ctc-model.h" #include "sherpa-onnx/csrc/online-model-config.h" @@ -23,9 +18,8 @@ class OnlineNeMoCtcModel : public OnlineCtcModel { public: explicit OnlineNeMoCtcModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineNeMoCtcModel(AAssetManager *mgr, const OnlineModelConfig &config); -#endif + template + OnlineNeMoCtcModel(Manager *mgr, const OnlineModelConfig &config); ~OnlineNeMoCtcModel() override; diff --git a/sherpa-onnx/csrc/online-paraformer-model.cc b/sherpa-onnx/csrc/online-paraformer-model.cc index d7d2e436d..b21bb9bcc 100644 --- a/sherpa-onnx/csrc/online-paraformer-model.cc +++ b/sherpa-onnx/csrc/online-paraformer-model.cc @@ -13,6 +13,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" @@ -38,8 +42,8 @@ class OnlineParaformerModel::Impl { } } -#if __ANDROID_API__ >= 9 - Impl(AAssetManager *mgr, const OnlineModelConfig &config) + template + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), @@ -54,7 +58,6 @@ class OnlineParaformerModel::Impl { InitDecoder(buf.data(), buf.size()); } } -#endif std::vector ForwardEncoder(Ort::Value features, Ort::Value features_length) { @@ -123,7 +126,11 @@ class OnlineParaformerModel::Impl { if (config_.debug) { std::ostringstream os; PrintModelMetadata(os, meta_data); - SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -191,11 +198,10 @@ class OnlineParaformerModel::Impl { OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config) : impl_(std::make_unique(config)) {} -#if __ANDROID_API__ >= 9 -OnlineParaformerModel::OnlineParaformerModel(AAssetManager *mgr, +template +OnlineParaformerModel::OnlineParaformerModel(Manager *mgr, const OnlineModelConfig &config) : impl_(std::make_unique(mgr, config)) {} -#endif OnlineParaformerModel::~OnlineParaformerModel() = default; @@ -246,4 +252,14 @@ OrtAllocator *OnlineParaformerModel::Allocator() const { return impl_->Allocator(); } +#if __ANDROID_API__ >= 9 +template OnlineParaformerModel::OnlineParaformerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineParaformerModel::OnlineParaformerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-paraformer-model.h b/sherpa-onnx/csrc/online-paraformer-model.h index 3c018a72d..cbf2d7157 100644 --- a/sherpa-onnx/csrc/online-paraformer-model.h +++ b/sherpa-onnx/csrc/online-paraformer-model.h @@ -8,11 +8,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-model-config.h" @@ -22,9 +17,8 @@ class OnlineParaformerModel { public: explicit OnlineParaformerModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineParaformerModel(AAssetManager *mgr, const OnlineModelConfig &config); -#endif + template + OnlineParaformerModel(Manager *mgr, const OnlineModelConfig &config); ~OnlineParaformerModel(); diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 76452138d..3560b1ab7 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -88,8 +88,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { InitDecoder(); } -#if __ANDROID_API__ >= 9 - explicit OnlineRecognizerCtcImpl(AAssetManager *mgr, + template + explicit OnlineRecognizerCtcImpl(Manager *mgr, const OnlineRecognizerConfig &config) : OnlineRecognizerImpl(mgr, config), config_(config), @@ -104,7 +104,6 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { InitDecoder(); } -#endif std::unique_ptr CreateStream() const override { auto stream = std::make_unique(config_.feat_config); diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index cf59cb539..27168b0f6 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -4,15 +4,18 @@ #include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include #include #if __ANDROID_API__ >= 9 -#include - #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "fst/extensions/far/far.h" #include "kaldifst/csrc/kaldi-fst-io.h" #include "sherpa-onnx/csrc/macros.h" @@ -61,9 +64,9 @@ std::unique_ptr OnlineRecognizerImpl::Create( exit(-1); } -#if __ANDROID_API__ >= 9 +template std::unique_ptr OnlineRecognizerImpl::Create( - AAssetManager *mgr, const OnlineRecognizerConfig &config) { + Manager *mgr, const OnlineRecognizerConfig &config) { if (!config.model_config.transducer.encoder.empty()) { Ort::Env env(ORT_LOGGING_LEVEL_ERROR); @@ -97,7 +100,6 @@ std::unique_ptr OnlineRecognizerImpl::Create( SHERPA_ONNX_LOGE("Please specify a model"); exit(-1); } -#endif OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config) : config_(config) { @@ -143,8 +145,8 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config) } } -#if __ANDROID_API__ >= 9 -OnlineRecognizerImpl::OnlineRecognizerImpl(AAssetManager *mgr, +template +OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr, const OnlineRecognizerConfig &config) : config_(config) { if (!config.rule_fsts.empty()) { @@ -189,7 +191,6 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(AAssetManager *mgr, } // for (const auto &f : files) } // if (!config.rule_fars.empty()) } -#endif std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( std::string text) const { @@ -202,4 +203,20 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( return text; } +#if __ANDROID_API__ >= 9 +template OnlineRecognizerImpl::OnlineRecognizerImpl( + AAssetManager *mgr, const OnlineRecognizerConfig &config); + +template std::unique_ptr OnlineRecognizerImpl::Create( + AAssetManager *mgr, const OnlineRecognizerConfig &config); +#endif + +#if __OHOS__ +template OnlineRecognizerImpl::OnlineRecognizerImpl( + NativeResourceManager *mgr, const OnlineRecognizerConfig &config); + +template std::unique_ptr OnlineRecognizerImpl::Create( + NativeResourceManager *mgr, const OnlineRecognizerConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h index 8b569f3af..b7bda7862 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-impl.h @@ -9,11 +9,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "kaldifst/csrc/text-normalizer.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-recognizer.h" @@ -28,13 +23,12 @@ class OnlineRecognizerImpl { static std::unique_ptr Create( const OnlineRecognizerConfig &config); -#if __ANDROID_API__ >= 9 - OnlineRecognizerImpl(AAssetManager *mgr, - const OnlineRecognizerConfig &config); + template + OnlineRecognizerImpl(Manager *mgr, const OnlineRecognizerConfig &config); + template static std::unique_ptr Create( - AAssetManager *mgr, const OnlineRecognizerConfig &config); -#endif + Manager *mgr, const OnlineRecognizerConfig &config); virtual ~OnlineRecognizerImpl() = default; diff --git a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h index 8ef66ea74..1e02fe519 100644 --- a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h @@ -120,8 +120,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { config_.feat_config.normalize_samples = false; } -#if __ANDROID_API__ >= 9 - explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr, + template + explicit OnlineRecognizerParaformerImpl(Manager *mgr, const OnlineRecognizerConfig &config) : OnlineRecognizerImpl(mgr, config), config_(config), @@ -138,7 +138,7 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { // [-32768, 32767], so we set normalize_samples to false config_.feat_config.normalize_samples = false; } -#endif + OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) = delete; diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 475a90185..2eac3cf84 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -14,11 +14,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-lm.h" @@ -130,8 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } } -#if __ANDROID_API__ >= 9 - explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr, + template + explicit OnlineRecognizerTransducerImpl(Manager *mgr, const OnlineRecognizerConfig &config) : OnlineRecognizerImpl(mgr, config), config_(config), @@ -178,7 +173,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { exit(-1); } } -#endif std::unique_ptr CreateStream() const override { auto stream = @@ -429,8 +423,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { hotwords_, config_.hotwords_score, boost_scores_); } -#if __ANDROID_API__ >= 9 - void InitHotwords(AAssetManager *mgr) { + template + void InitHotwords(Manager *mgr) { // each line in hotwords_file contains space-separated words auto buf = ReadFile(mgr, config_.hotwords_file); @@ -452,7 +446,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { hotwords_graph_ = std::make_shared( hotwords_, config_.hotwords_score, boost_scores_); } -#endif void InitHotwordsFromBufStr() { // each line in hotwords_file contains space-separated words diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 0a09fdb01..a3f2756c8 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -16,11 +16,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-recognizer-impl.h" #include "sherpa-onnx/csrc/online-recognizer.h" @@ -65,9 +60,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { PostInit(); } -#if __ANDROID_API__ >= 9 + template explicit OnlineRecognizerTransducerNeMoImpl( - AAssetManager *mgr, const OnlineRecognizerConfig &config) + Manager *mgr, const OnlineRecognizerConfig &config) : OnlineRecognizerImpl(mgr, config), config_(config), symbol_table_(mgr, config.model_config.tokens), @@ -85,7 +80,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { PostInit(); } -#endif std::unique_ptr CreateStream() const override { auto stream = std::make_unique(config_.feat_config); diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index ba7cba8ec..4ccc939da 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -13,6 +13,15 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/online-recognizer-impl.h" #include "sherpa-onnx/csrc/text-utils.h" @@ -197,11 +206,10 @@ std::string OnlineRecognizerConfig::ToString() const { OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) : impl_(OnlineRecognizerImpl::Create(config)) {} -#if __ANDROID_API__ >= 9 -OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr, +template +OnlineRecognizer::OnlineRecognizer(Manager *mgr, const OnlineRecognizerConfig &config) : impl_(OnlineRecognizerImpl::Create(mgr, config)) {} -#endif OnlineRecognizer::~OnlineRecognizer() = default; @@ -238,4 +246,14 @@ bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const { void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); } +#if __ANDROID_API__ >= 9 +template OnlineRecognizer::OnlineRecognizer( + AAssetManager *mgr, const OnlineRecognizerConfig &config); +#endif + +#if __OHOS__ +template OnlineRecognizer::OnlineRecognizer( + NativeResourceManager *mgr, const OnlineRecognizerConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 45e0f4237..8854fbd22 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -9,11 +9,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/endpoint.h" #include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" @@ -149,9 +144,8 @@ class OnlineRecognizer { public: explicit OnlineRecognizer(const OnlineRecognizerConfig &config); -#if __ANDROID_API__ >= 9 - OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config); -#endif + template + OnlineRecognizer(Manager *mgr, const OnlineRecognizerConfig &config); ~OnlineRecognizer(); diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 51a9aef3c..a38b5893f 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -9,6 +9,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include #include #include @@ -49,7 +53,11 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, if (debug) { std::ostringstream os; PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; @@ -155,9 +163,9 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput( return decoder_input; } -#if __ANDROID_API__ >= 9 +template std::unique_ptr OnlineTransducerModel::Create( - AAssetManager *mgr, const OnlineModelConfig &config) { + Manager *mgr, const OnlineModelConfig &config) { if (!config.model_type.empty()) { const auto &model_type = config.model_type; if (model_type == "conformer") { @@ -195,6 +203,15 @@ std::unique_ptr OnlineTransducerModel::Create( // unreachable code return nullptr; } + +#if __ANDROID_API__ >= 9 +template std::unique_ptr OnlineTransducerModel::Create( + Manager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr OnlineTransducerModel::Create( + NativeResourceManager *mgr, const OnlineModelConfig &config); #endif } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-model.h b/sherpa-onnx/csrc/online-transducer-model.h index 3dacb4a50..f6404eccd 100644 --- a/sherpa-onnx/csrc/online-transducer-model.h +++ b/sherpa-onnx/csrc/online-transducer-model.h @@ -8,11 +8,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/hypothesis.h" #include "sherpa-onnx/csrc/online-model-config.h" @@ -30,10 +25,9 @@ class OnlineTransducerModel { static std::unique_ptr Create( const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 + template static std::unique_ptr Create( - AAssetManager *mgr, const OnlineModelConfig &config); -#endif + Manager *mgr, const OnlineModelConfig &config); /** Stack a list of individual states into a batch. * diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 264593a1c..73c23fe31 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -20,6 +20,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" @@ -54,8 +58,8 @@ class OnlineTransducerNeMoModel::Impl { } } -#if __ANDROID_API__ >= 9 - Impl(AAssetManager *mgr, const OnlineModelConfig &config) + template + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), @@ -75,7 +79,6 @@ class OnlineTransducerNeMoModel::Impl { InitJoiner(buf.data(), buf.size()); } } -#endif std::vector RunEncoder(Ort::Value features, std::vector states) { @@ -302,7 +305,11 @@ class OnlineTransducerNeMoModel::Impl { std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); - SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -460,11 +467,10 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( const OnlineModelConfig &config) : impl_(std::make_unique(config)) {} -#if __ANDROID_API__ >= 9 +template OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( - AAssetManager *mgr, const OnlineModelConfig &config) + Manager *mgr, const OnlineModelConfig &config) : impl_(std::make_unique(mgr, config)) {} -#endif OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; @@ -528,4 +534,14 @@ std::vector> OnlineTransducerNeMoModel::UnStackStates( return impl_->UnStackStates(std::move(states)); } +#if __ANDROID_API__ >= 9 +template OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h index e12814cc0..98390bb15 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -11,11 +11,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-model-config.h" @@ -28,10 +23,8 @@ class OnlineTransducerNeMoModel { public: explicit OnlineTransducerNeMoModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineTransducerNeMoModel(AAssetManager *mgr, - const OnlineModelConfig &config); -#endif + template + OnlineTransducerNeMoModel(Manager *mgr, const OnlineModelConfig &config); ~OnlineTransducerNeMoModel(); // A list of 3 tensors: diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.cc b/sherpa-onnx/csrc/online-wenet-ctc-model.cc index cce322aa4..bf468484c 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/online-wenet-ctc-model.cc @@ -13,6 +13,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" @@ -33,8 +37,8 @@ class OnlineWenetCtcModel::Impl { } } -#if __ANDROID_API__ >= 9 - Impl(AAssetManager *mgr, const OnlineModelConfig &config) + template + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), @@ -44,7 +48,6 @@ class OnlineWenetCtcModel::Impl { Init(buf.data(), buf.size()); } } -#endif std::vector Forward(Ort::Value x, std::vector states) { @@ -139,7 +142,11 @@ class OnlineWenetCtcModel::Impl { if (config_.debug) { std::ostringstream os; PrintModelMetadata(os, meta_data); - SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -212,11 +219,10 @@ class OnlineWenetCtcModel::Impl { OnlineWenetCtcModel::OnlineWenetCtcModel(const OnlineModelConfig &config) : impl_(std::make_unique(config)) {} -#if __ANDROID_API__ >= 9 -OnlineWenetCtcModel::OnlineWenetCtcModel(AAssetManager *mgr, +template +OnlineWenetCtcModel::OnlineWenetCtcModel(Manager *mgr, const OnlineModelConfig &config) : impl_(std::make_unique(mgr, config)) {} -#endif OnlineWenetCtcModel::~OnlineWenetCtcModel() = default; @@ -258,4 +264,14 @@ std::vector> OnlineWenetCtcModel::UnStackStates( return ans; } +#if __ANDROID_API__ >= 9 +template OnlineWenetCtcModel::OnlineWenetCtcModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineWenetCtcModel::OnlineWenetCtcModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.h b/sherpa-onnx/csrc/online-wenet-ctc-model.h index 1be1034cc..28458b68b 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model.h +++ b/sherpa-onnx/csrc/online-wenet-ctc-model.h @@ -8,11 +8,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-ctc-model.h" #include "sherpa-onnx/csrc/online-model-config.h" @@ -23,9 +18,8 @@ class OnlineWenetCtcModel : public OnlineCtcModel { public: explicit OnlineWenetCtcModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineWenetCtcModel(AAssetManager *mgr, const OnlineModelConfig &config); -#endif + template + OnlineWenetCtcModel(Manager *mgr, const OnlineModelConfig &config); ~OnlineWenetCtcModel() override; diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 36e2d9dbd..437600262 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -17,6 +17,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" @@ -50,9 +54,9 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( } } -#if __ANDROID_API__ >= 9 +template OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( - AAssetManager *mgr, const OnlineModelConfig &config) + Manager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_ERROR), config_(config), sess_opts_(GetSessionOptions(config)), @@ -72,7 +76,6 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( InitJoiner(buf.data(), buf.size()); } } -#endif void OnlineZipformerTransducerModel::InitEncoder(void *model_data, size_t model_data_length) { @@ -91,7 +94,11 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -480,4 +487,14 @@ Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out, return std::move(logit[0]); } +#if __ANDROID_API__ >= 9 +template OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.h b/sherpa-onnx/csrc/online-zipformer-transducer-model.h index b2b7da040..9e4368a69 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.h @@ -9,11 +9,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -24,10 +19,8 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { public: explicit OnlineZipformerTransducerModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineZipformerTransducerModel(AAssetManager *mgr, - const OnlineModelConfig &config); -#endif + template + OnlineZipformerTransducerModel(Manager *mgr, const OnlineModelConfig &config); std::vector StackStates( const std::vector> &states) const override; diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc index 8f0708ad1..298b90522 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc @@ -15,6 +15,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" @@ -37,8 +41,8 @@ class OnlineZipformer2CtcModel::Impl { } } -#if __ANDROID_API__ >= 9 - Impl(AAssetManager *mgr, const OnlineModelConfig &config) + template + Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), @@ -48,7 +52,6 @@ class OnlineZipformer2CtcModel::Impl { Init(buf.data(), buf.size()); } } -#endif std::vector Forward(Ort::Value features, std::vector states) { @@ -255,7 +258,11 @@ class OnlineZipformer2CtcModel::Impl { std::ostringstream os; os << "---zipformer2_ctc---\n"; PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -415,11 +422,10 @@ OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( const OnlineModelConfig &config) : impl_(std::make_unique(config)) {} -#if __ANDROID_API__ >= 9 +template OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( - AAssetManager *mgr, const OnlineModelConfig &config) + Manager *mgr, const OnlineModelConfig &config) : impl_(std::make_unique(mgr, config)) {} -#endif OnlineZipformer2CtcModel::~OnlineZipformer2CtcModel() = default; @@ -458,4 +464,14 @@ std::vector> OnlineZipformer2CtcModel::UnStackStates( return impl_->UnStackStates(std::move(states)); } +#if __ANDROID_API__ >= 9 +template OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.h b/sherpa-onnx/csrc/online-zipformer2-ctc-model.h index 11b59e2bb..32ddf2122 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model.h +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.h @@ -8,11 +8,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-ctc-model.h" #include "sherpa-onnx/csrc/online-model-config.h" @@ -23,9 +18,8 @@ class OnlineZipformer2CtcModel : public OnlineCtcModel { public: explicit OnlineZipformer2CtcModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineZipformer2CtcModel(AAssetManager *mgr, const OnlineModelConfig &config); -#endif + template + OnlineZipformer2CtcModel(Manager *mgr, const OnlineModelConfig &config); ~OnlineZipformer2CtcModel() override; diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index 03c68474c..85a32ec3c 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -19,6 +19,10 @@ #include "android/asset_manager_jni.h" #endif +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" @@ -54,9 +58,9 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( } } -#if __ANDROID_API__ >= 9 +template OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( - AAssetManager *mgr, const OnlineModelConfig &config) + Manager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_ERROR), config_(config), encoder_sess_opts_(GetSessionOptions(config)), @@ -78,7 +82,6 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( InitJoiner(buf.data(), buf.size()); } } -#endif void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, size_t model_data_length) { @@ -97,7 +100,11 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -474,4 +481,14 @@ Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out, return std::move(logit[0]); } +#if __ANDROID_API__ >= 9 +template OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h index aa0f46f81..93b124f61 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h @@ -9,11 +9,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -24,10 +19,9 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { public: explicit OnlineZipformer2TransducerModel(const OnlineModelConfig &config); -#if __ANDROID_API__ >= 9 - OnlineZipformer2TransducerModel(AAssetManager *mgr, + template + OnlineZipformer2TransducerModel(Manager *mgr, const OnlineModelConfig &config); -#endif std::vector StackStates( const std::vector> &states) const override; diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 723fec68b..3455d2007 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -8,16 +8,16 @@ #include #include #include +#include #include #if __ANDROID_API__ >= 9 -#include #include "android/asset_manager.h" #include "android/asset_manager_jni.h" -#elif __OHOS__ -#include +#endif +#if __OHOS__ #include "rawfile/raw_file_manager.h" #endif diff --git a/sherpa-onnx/python/csrc/vad-model.cc b/sherpa-onnx/python/csrc/vad-model.cc index f304fd0ae..743628d38 100644 --- a/sherpa-onnx/python/csrc/vad-model.cc +++ b/sherpa-onnx/python/csrc/vad-model.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/python/csrc/vad-model.h" +#include #include #include "sherpa-onnx/csrc/vad-model.h" @@ -13,8 +14,10 @@ namespace sherpa_onnx { void PybindVadModel(py::module *m) { using PyClass = VadModel; py::class_(*m, "VadModel") - .def_static("create", &PyClass::Create, py::arg("config"), - py::call_guard()) + .def_static("create", + (std::unique_ptr(*)(const VadModelConfig &))( + &PyClass::Create), + py::arg("config"), py::call_guard()) .def("reset", &PyClass::Reset, py::call_guard()) .def( "is_speech",