Skip to content

Commit

Permalink
Support playing as it is generating for Android (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 9, 2023
1 parent cae0231 commit 0f053d8
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager
import android.media.MediaPlayer
import android.media.*
import android.net.Uri
import android.os.Bundle
import android.util.Log
Expand All @@ -23,6 +23,10 @@ class MainActivity : AppCompatActivity() {
private lateinit var generate: Button
private lateinit var play: Button

// see
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
private lateinit var track: AudioTrack

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
Expand All @@ -31,6 +35,10 @@ class MainActivity : AppCompatActivity() {
initTts()
Log.i(TAG, "Finish initializing TTS")

Log.i(TAG, "Start to initialize AudioTrack")
initAudioTrack()
Log.i(TAG, "Finish initializing AudioTrack")

text = findViewById(R.id.text)
sid = findViewById(R.id.sid)
speed = findViewById(R.id.speed)
Expand All @@ -51,6 +59,33 @@ class MainActivity : AppCompatActivity() {
play.isEnabled = false
}

private fun initAudioTrack() {
val sampleRate = tts.sampleRate()
val bufLength = (sampleRate * 0.1).toInt()
Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")

val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
.build()

val format = AudioFormat.Builder()
.setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
.setSampleRate(sampleRate)
.build()

track = AudioTrack(
attr, format, bufLength, AudioTrack.MODE_STREAM,
AudioManager.AUDIO_SESSION_ID_GENERATE
)
track.play()
}

// this function is called from C++
private fun callback(samples: FloatArray) {
track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
}

private fun onClickGenerate() {
val sidInt = sid.text.toString().toIntOrNull()
if (sidInt == null || sidInt < 0) {
Expand Down Expand Up @@ -79,16 +114,28 @@ class MainActivity : AppCompatActivity() {
return
}

play.isEnabled = false
val audio = tts.generate(text = textStr, sid = sidInt, speed = speedFloat)
track.pause()
track.flush()
track.play()

val filename = application.filesDir.absolutePath + "/generated.wav"
val ok = audio.samples.size > 0 && audio.save(filename)
if (ok) {
play.isEnabled = true
// Play automatically after generation
onClickPlay()
}
play.isEnabled = false
Thread {
val audio = tts.generateWithCallback(
text = textStr,
sid = sidInt,
speed = speedFloat,
callback = this::callback
)

val filename = application.filesDir.absolutePath + "/generated.wav"
val ok = audio.samples.size > 0 && audio.save(filename)
if (ok) {
runOnUiThread {
play.isEnabled = true
track.stop()
}
}
}.start()
}

private fun onClickPlay() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class OfflineTts(
}
}

fun sampleRate() = getSampleRate(ptr)

fun generate(
text: String,
sid: Int = 0,
Expand All @@ -66,6 +68,19 @@ class OfflineTts(
)
}

fun generateWithCallback(
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
): GeneratedAudio {
var objArray = generateWithCallbackImpl(ptr, text = text, sid = sid, speed = speed, callback=callback)
return GeneratedAudio(
samples = objArray[0] as FloatArray,
sampleRate = objArray[1] as Int
)
}

fun allocate(assetManager: AssetManager? = null) {
if (ptr == 0L) {
if (assetManager != null) {
Expand Down Expand Up @@ -97,6 +112,7 @@ class OfflineTts(
): Long

private external fun delete(ptr: Long)
private external fun getSampleRate(ptr: Long): Int

// The returned array has two entries:
// - the first entry is an 1-D float array containing audio samples.
Expand All @@ -109,6 +125,14 @@ class OfflineTts(
speed: Float = 1.0f
): Array<Any>

external fun generateWithCallbackImpl(
ptr: Long,
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
): Array<Any>

companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
Expand Down
6 changes: 5 additions & 1 deletion kotlin-api-examples/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

fun callback(samples: FloatArray): Unit {
println("callback got called with ${samples.size} samples");
}

fun main() {
testTts()
testAsr()
Expand All @@ -22,7 +26,7 @@ fun testTts() {
)
)
val tts = OfflineTts(config=config)
val audio = tts.generate(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”")
val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback)
audio.save(filename="test-en.wav")
}

Expand Down
102 changes: 51 additions & 51 deletions scripts/apk/generate-tts-apk-script.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,57 +172,57 @@ def get_vits_models() -> List[TtsModel]:
lang="zh",
rule_fsts="vits-zh-aishell3/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-doom",
model_name="doom.onnx",
lang="zh",
rule_fsts="vits-zh-hf-doom/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-echo",
model_name="echo.onnx",
lang="zh",
rule_fsts="vits-zh-hf-echo/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-zenyatta",
model_name="zenyatta.onnx",
lang="zh",
rule_fsts="vits-zh-hf-zenyatta/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-abyssinvoker",
model_name="abyssinvoker.onnx",
lang="zh",
rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-keqing",
model_name="keqing.onnx",
lang="zh",
rule_fsts="vits-zh-hf-keqing/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-eula",
model_name="eula.onnx",
lang="zh",
rule_fsts="vits-zh-hf-eula/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-bronya",
model_name="bronya.onnx",
lang="zh",
rule_fsts="vits-zh-hf-bronya/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-theresa",
model_name="theresa.onnx",
lang="zh",
rule_fsts="vits-zh-hf-theresa/rule.fst",
),
# TtsModel(
# model_dir="vits-zh-hf-doom",
# model_name="doom.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-doom/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-echo",
# model_name="echo.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-echo/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-zenyatta",
# model_name="zenyatta.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-zenyatta/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-abyssinvoker",
# model_name="abyssinvoker.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-keqing",
# model_name="keqing.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-keqing/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-eula",
# model_name="eula.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-eula/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-bronya",
# model_name="bronya.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-bronya/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-theresa",
# model_name="theresa.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-theresa/rule.fst",
# ),
# English (US)
TtsModel(model_dir="vits-vctk", model_name="vits-vctk.onnx", lang="en"),
TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
# TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
# fmt: on
]

Expand All @@ -238,8 +238,8 @@ def main():
template = environment.from_string(s)
d = dict()

# all_model_list = get_vits_models()
all_model_list = get_piper_models()
all_model_list = get_vits_models()
all_model_list += get_piper_models()
all_model_list += get_coqui_models()

num_models = len(all_model_list)
Expand Down
57 changes: 53 additions & 4 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
#include "jni.h" // NOLINT

#include <fstream>
#include <functional>
#include <strstream>
#include <utility>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include <fstream>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
Expand Down Expand Up @@ -502,11 +504,14 @@ class SherpaOnnxOfflineTts {
explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config)
: tts_(config) {}

GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0) const {
return tts_.Generate(text, sid, speed);
GeneratedAudio Generate(
const std::string &text, int64_t sid = 0, float speed = 1.0,
std::function<void(const float *, int32_t)> callback = nullptr) const {
return tts_.Generate(text, sid, speed, callback);
}

int32_t SampleRate() const { return tts_.SampleRate(); }

private:
OfflineTts tts_;
};
Expand Down Expand Up @@ -628,6 +633,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete(
delete reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
return reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)
->SampleRate();
}

// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
static jobject NewInteger(JNIEnv *env, int32_t value) {
Expand Down Expand Up @@ -663,6 +675,43 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
return obj_arr;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid,
jfloat speed, jobject callback) {
const char *p_text = env->GetStringUTFChars(text, nullptr);
SHERPA_ONNX_LOGE("string is: %s", p_text);

std::function<void(const float *, int32_t)> callback_wrapper =
[env, callback](const float *samples, int32_t n) {
jclass cls = env->GetObjectClass(callback);
jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");

jfloatArray samples_arr = env->NewFloatArray(n);
env->SetFloatArrayRegion(samples_arr, 0, n, samples);
env->CallVoidMethod(callback, mid, samples_arr);
};

auto audio =
reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)->Generate(
p_text, sid, speed, callback_wrapper);

jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
audio.samples.data());

jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);

env->SetObjectArrayElement(obj_arr, 0, samples_arr);
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));

env->ReleaseStringUTFChars(text, p_text);

return obj_arr;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
Expand Down

0 comments on commit 0f053d8

Please sign in to comment.