Skip to content

Commit

Permalink
Android JNI support for speaker diarization (#1421)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 12, 2024
1 parent 5e273c5 commit 94b26ff
Show file tree
Hide file tree
Showing 18 changed files with 116 additions and 2 deletions.
14 changes: 14 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,18 @@ OfflineSpeakerDiarizationImpl::Create(
return nullptr;
}

#if __ANDROID_API__ >= 9
// Android-only factory overload.
//
// Forwards `mgr` and `config` to OfflineSpeakerDiarizationPyannoteImpl when
// config.segmentation.pyannote.model is non-empty. Otherwise logs an error
// and returns nullptr.
std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
    AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) {
  // Guard clause: without a segmentation model there is nothing to build.
  if (config.segmentation.pyannote.model.empty()) {
    SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");

    return nullptr;
  }

  return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(mgr, config);
}
#endif

} // namespace sherpa_onnx
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include <functional>
#include <memory>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
namespace sherpa_onnx {

Expand All @@ -16,6 +21,11 @@ class OfflineSpeakerDiarizationImpl {
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
const OfflineSpeakerDiarizationConfig &config);

#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config);
#endif

virtual ~OfflineSpeakerDiarizationImpl() = default;

virtual int32_t SampleRate() const = 0;
Expand Down
16 changes: 16 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "Eigen/Dense"
#include "sherpa-onnx/csrc/fast-clustering.h"
#include "sherpa-onnx/csrc/math.h"
Expand Down Expand Up @@ -65,6 +70,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
Init();
}

#if __ANDROID_API__ >= 9
// Constructs the pyannote-based implementation using an Android asset
// manager. `mgr` is forwarded to both the segmentation model and the
// embedding extractor; the clustering object is built from
// config_.clustering. Init() is called after all members are initialized.
//
// NOTE(review): the initializers after config_ read the member config_, not
// the `config` parameter, so they rely on config_ being declared before those
// members in the class -- confirm.
OfflineSpeakerDiarizationPyannoteImpl(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config)
: config_(config),
segmentation_model_(mgr, config_.segmentation),
embedding_extractor_(mgr, config_.embedding),
clustering_(std::make_unique<FastClustering>(config_.clustering)) {
Init();
}
#endif

int32_t SampleRate() const override {
const auto &meta_data = segmentation_model_.GetModelMetaData();

Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}

#if __ANDROID_API__ >= 9
// Android-only constructor: delegates to
// OfflineSpeakerDiarizationImpl::Create(mgr, config) and stores the result in
// impl_.
//
// NOTE(review): Create() returns nullptr when no segmentation model is
// specified, which would leave impl_ null -- confirm callers handle that.
OfflineSpeakerDiarization::OfflineSpeakerDiarization(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {}
#endif

OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;

int32_t OfflineSpeakerDiarization::SampleRate() const {
Expand Down
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
#include <memory>
#include <string>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/fast-clustering-config.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
Expand Down Expand Up @@ -57,6 +62,11 @@ class OfflineSpeakerDiarization {
explicit OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config);

#if __ANDROID_API__ >= 9
OfflineSpeakerDiarization(AAssetManager *mgr,
const OfflineSpeakerDiarizationConfig &config);
#endif

~OfflineSpeakerDiarization();

// Expected sample rate of the input audio samples
Expand Down
18 changes: 18 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl {
Init(buf.data(), buf.size());
}

#if __ANDROID_API__ >= 9
// Android-only constructor.
//
// Sets up the ONNX Runtime environment (error-level logging) and session
// options, then obtains the model named by config_.pyannote.model via
// ReadFile(mgr, ...) and passes the resulting buffer to Init().
Impl(AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
    : config_(config),
      env_(ORT_LOGGING_LEVEL_ERROR),
      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
  // Fetch the model contents through the asset manager, then hand the
  // buffer (pointer + length) to Init().
  auto model_bytes = ReadFile(mgr, config_.pyannote.model);
  Init(model_bytes.data(), model_bytes.size());
}
#endif

const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
const {
return meta_data_;
Expand Down Expand Up @@ -92,6 +103,13 @@ OfflineSpeakerSegmentationPyannoteModel::
const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
// Android-only constructor: creates the private Impl from the given asset
// manager and segmentation model config.
OfflineSpeakerSegmentationPyannoteModel::
OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OfflineSpeakerSegmentationPyannoteModel::
~OfflineSpeakerSegmentationPyannoteModel() = default;

Expand Down
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

#include <memory>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
Expand All @@ -17,6 +22,11 @@ class OfflineSpeakerSegmentationPyannoteModel {
explicit OfflineSpeakerSegmentationPyannoteModel(
const OfflineSpeakerSegmentationModelConfig &config);

#if __ANDROID_API__ >= 9
OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config);
#endif

~OfflineSpeakerSegmentationPyannoteModel();

const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/sherpa-onnx-vad-microphone-offline-asr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ to download models for offline ASR.
}

while (!vad->Empty()) {
auto &segment = vad->Front();
const auto &segment = vad->Front();
auto s = recognizer.CreateStream();
s->AcceptWaveform(sample_rate, segment.samples.data(),
segment.samples.size());
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/audio-tagging.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif

Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/jni/keyword-spotter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto kws = new sherpa_onnx::KeywordSpotter(
#if __ANDROID_API__ >= 9
mgr,
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/jni/offline-punctuation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto model = new sherpa_onnx::OfflinePunctuation(
#if __ANDROID_API__ >= 9
mgr,
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/jni/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env,
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto model = new sherpa_onnx::OfflineRecognizer(
#if __ANDROID_API__ >= 9
mgr,
Expand Down
19 changes: 18 additions & 1 deletion sherpa-onnx/jni/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,24 @@ SHERPA_ONNX_EXTERN_C
// JNI entry point backing OfflineSpeakerDiarization.newFromAsset().
//
// Converts the Java config object into a native
// OfflineSpeakerDiarizationConfig and constructs a heap-allocated
// sherpa_onnx::OfflineSpeakerDiarization from it. On Android the Java
// AssetManager is first converted to a native AAssetManager, which is passed
// to the constructor.
//
// @param env            JNI environment pointer.
// @param asset_manager  Java AssetManager object; only used on Android.
// @param _config        Java-side configuration object.
// @return The address of the newly created native object as a jlong, or 0 if
//         the native asset manager could not be obtained.
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  // A leftover placeholder `return 0;` used to sit here and made everything
  // below unreachable, so the function always returned a null handle.
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (!mgr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    return 0;
  }
#endif

  auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

  // The raw pointer is the handle returned to Java.
  // NOTE(review): presumably released by a matching JNI delete function --
  // confirm against the rest of this file.
  auto sd = new sherpa_onnx::OfflineSpeakerDiarization(
#if __ANDROID_API__ >= 9
      mgr,
#endif
      config);

  return reinterpret_cast<jlong>(sd);
}

SHERPA_ONNX_EXTERN_C
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env,
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetConfig(env, _config);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/speaker-embedding-extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/spoken-language-identification.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif

Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/jni/voice-activity-detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto model = new sherpa_onnx::VoiceActivityDetector(
#if __ANDROID_API__ >= 9
mgr,
Expand Down

0 comments on commit 94b26ff

Please sign in to comment.