Skip to content

Commit

Permalink
Add streaming ASR support for HarmonyOS.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Nov 26, 2024
1 parent 298b6b6 commit 2b1b057
Show file tree
Hide file tree
Showing 35 changed files with 367 additions and 206 deletions.
28 changes: 27 additions & 1 deletion sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct SherpaOnnxDisplay {

#define SHERPA_ONNX_OR(x, y) (x ? x : y)

const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
static sherpa_onnx::OnlineRecognizerConfig GetOnlineRecognizerConfig(
const SherpaOnnxOnlineRecognizerConfig *config) {
sherpa_onnx::OnlineRecognizerConfig recognizer_config;

Expand Down Expand Up @@ -151,9 +151,21 @@ const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, "");

if (config->model_config.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", recognizer_config.ToString().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", recognizer_config.ToString().c_str());
#endif
}

return recognizer_config;
}

const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
const SherpaOnnxOnlineRecognizerConfig *config) {
sherpa_onnx::OnlineRecognizerConfig recognizer_config =
GetOnlineRecognizerConfig(config);

if (!recognizer_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in config!");
return nullptr;
Expand Down Expand Up @@ -1876,6 +1888,20 @@ SherpaOnnxOfflineSpeakerDiarizationProcessWithCallbackNoArg(

#ifdef __OHOS__

const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizerOHOS(
const SherpaOnnxOnlineRecognizerConfig *config,
NativeResourceManager *mgr) {
sherpa_onnx::OnlineRecognizerConfig recognizer_config =
GetOnlineRecognizerConfig(config);

SherpaOnnxOnlineRecognizer *recognizer = new SherpaOnnxOnlineRecognizer;

recognizer->impl =
std::make_unique<sherpa_onnx::OnlineRecognizer>(mgr, recognizer_config);

return recognizer;
}

const SherpaOnnxOfflineRecognizer *SherpaOnnxCreateOfflineRecognizerOHOS(
const SherpaOnnxOfflineRecognizerConfig *config,
NativeResourceManager *mgr) {
Expand Down
7 changes: 7 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,13 @@ SHERPA_ONNX_API void SherpaOnnxOfflineSpeakerDiarizationDestroyResult(
// It is for HarmonyOS
typedef struct NativeResourceManager NativeResourceManager;

/// @param config Config for the recognizer.
/// @return Return a pointer to the recognizer. The user has to invoke
// SherpaOnnxDestroyOnlineRecognizer() to free it to avoid memory leak.
SHERPA_ONNX_API const SherpaOnnxOnlineRecognizer *
SherpaOnnxCreateOnlineRecognizerOHOS(
const SherpaOnnxOnlineRecognizerConfig *config, NativeResourceManager *mgr);

/// @param config Config for the recognizer.
/// @return Return a pointer to the recognizer. The user has to invoke
// SherpaOnnxDestroyOfflineRecognizer() to free it to avoid memory
Expand Down
6 changes: 3 additions & 3 deletions sherpa-onnx/csrc/offline-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"

#include <string>
#include <strstream>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include <strstream>

#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#elif __OHOS__
#include <strstream>
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

Expand Down
27 changes: 24 additions & 3 deletions sherpa-onnx/csrc/online-conformer-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
Expand Down Expand Up @@ -50,9 +54,9 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
}
}

#if __ANDROID_API__ >= 9
template <typename Manager>
OnlineConformerTransducerModel::OnlineConformerTransducerModel(
AAssetManager *mgr, const OnlineModelConfig &config)
Manager *mgr, const OnlineModelConfig &config)
: env_(ORT_LOGGING_LEVEL_ERROR),
config_(config),
sess_opts_(GetSessionOptions(config)),
Expand All @@ -72,7 +76,6 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
InitJoiner(buf.data(), buf.size());
}
}
#endif

void OnlineConformerTransducerModel::InitEncoder(void *model_data,
size_t model_data_length) {
Expand All @@ -91,7 +94,11 @@ void OnlineConformerTransducerModel::InitEncoder(void *model_data,
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s", os.str().c_str());
#endif
}

Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
Expand Down Expand Up @@ -121,7 +128,11 @@ void OnlineConformerTransducerModel::InitDecoder(void *model_data,
std::ostringstream os;
os << "---decoder---\n";
PrintModelMetadata(os, meta_data);
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s", os.str().c_str());
#endif
}

Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
Expand Down Expand Up @@ -273,4 +284,14 @@ Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out,
return std::move(logit[0]);
}

#if __ANDROID_API__ >= 9
template OnlineConformerTransducerModel::OnlineConformerTransducerModel(
AAssetManager *mgr, const OnlineModelConfig &config);
#endif

#if __OHOS__
template OnlineConformerTransducerModel::OnlineConformerTransducerModel(
NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif

} // namespace sherpa_onnx
11 changes: 2 additions & 9 deletions sherpa-onnx/csrc/online-conformer-transducer-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
Expand All @@ -25,10 +20,8 @@ class OnlineConformerTransducerModel : public OnlineTransducerModel {
public:
explicit OnlineConformerTransducerModel(const OnlineModelConfig &config);

#if __ANDROID_API__ >= 9
OnlineConformerTransducerModel(AAssetManager *mgr,
const OnlineModelConfig &config);
#endif
template <typename Manager>
OnlineConformerTransducerModel(Manager *mgr, const OnlineModelConfig &config);

std::vector<Ort::Value> StackStates(
const std::vector<std::vector<Ort::Value>> &states) const override;
Expand Down
23 changes: 20 additions & 3 deletions sherpa-onnx/csrc/online-ctc-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@
#include <sstream>
#include <string>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
Expand All @@ -31,10 +40,9 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
}
}

#if __ANDROID_API__ >= 9

template <typename Manager>
std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
AAssetManager *mgr, const OnlineModelConfig &config) {
Manager *mgr, const OnlineModelConfig &config) {
if (!config.wenet_ctc.model.empty()) {
return std::make_unique<OnlineWenetCtcModel>(mgr, config);
} else if (!config.zipformer2_ctc.model.empty()) {
Expand All @@ -46,6 +54,15 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
exit(-1);
}
}

#if __ANDROID_API__ >= 9
template std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
AAssetManager *mgr, const OnlineModelConfig &config);
#endif

#if __OHOS__
template std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif

} // namespace sherpa_onnx
10 changes: 2 additions & 8 deletions sherpa-onnx/csrc/online-ctc-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"

Expand All @@ -25,10 +20,9 @@ class OnlineCtcModel {
static std::unique_ptr<OnlineCtcModel> Create(
const OnlineModelConfig &config);

#if __ANDROID_API__ >= 9
template <typename Manager>
static std::unique_ptr<OnlineCtcModel> Create(
AAssetManager *mgr, const OnlineModelConfig &config);
#endif
Manager *mgr, const OnlineModelConfig &config);

// Return a list of tensors containing the initial states
virtual std::vector<Ort::Value> GetInitStates() const = 0;
Expand Down
23 changes: 20 additions & 3 deletions sherpa-onnx/csrc/online-lstm-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
Expand Down Expand Up @@ -48,9 +52,9 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
}
}

#if __ANDROID_API__ >= 9
template <typename Manager>
OnlineLstmTransducerModel::OnlineLstmTransducerModel(
AAssetManager *mgr, const OnlineModelConfig &config)
Manager *mgr, const OnlineModelConfig &config)
: env_(ORT_LOGGING_LEVEL_ERROR),
config_(config),
sess_opts_(GetSessionOptions(config)),
Expand All @@ -70,7 +74,6 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
InitJoiner(buf.data(), buf.size());
}
}
#endif

void OnlineLstmTransducerModel::InitEncoder(void *model_data,
size_t model_data_length) {
Expand All @@ -89,7 +92,11 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data,
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s", os.str().c_str());
#endif
}

Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
Expand Down Expand Up @@ -261,4 +268,14 @@ Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out,
return std::move(logit[0]);
}

#if __ANDROID_API__ >= 9
template OnlineLstmTransducerModel::OnlineLstmTransducerModel(
AAssetManager *mgr, const OnlineModelConfig &config);
#endif

#if __OHOS__
template OnlineLstmTransducerModel::OnlineLstmTransducerModel(
NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif

} // namespace sherpa_onnx
11 changes: 2 additions & 9 deletions sherpa-onnx/csrc/online-lstm-transducer-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
Expand All @@ -24,10 +19,8 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
public:
explicit OnlineLstmTransducerModel(const OnlineModelConfig &config);

#if __ANDROID_API__ >= 9
OnlineLstmTransducerModel(AAssetManager *mgr,
const OnlineModelConfig &config);
#endif
template <typename Manager>
OnlineLstmTransducerModel(Manager *mgr, const OnlineModelConfig &config);

std::vector<Ort::Value> StackStates(
const std::vector<std::vector<Ort::Value>> &states) const override;
Expand Down
Loading

0 comments on commit 2b1b057

Please sign in to comment.