From 86baf43c6bc65e6eaa4263eb71d90d2364ec7deb Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 6 Nov 2023 10:38:40 +0800 Subject: [PATCH] support reading rule FST for Android TTS (#410) --- .github/workflows/apk-tts.yaml | 21 +++++-- .github/workflows/apk.yaml | 10 ++++ .../com/k2fsa/sherpa/onnx/MainActivity.kt | 9 ++- .../main/java/com/k2fsa/sherpa/onnx/Tts.kt | 8 ++- cmake/kaldifst.cmake | 16 ++--- scripts/apk/build-apk-tts.sh.in | 8 ++- scripts/apk/generate-tts-apk-script.py | 59 ++++++++++++++++++- sherpa-onnx/csrc/lexicon.cc | 14 ++++- sherpa-onnx/csrc/offline-tts-vits-impl.h | 16 ++++- sherpa-onnx/jni/jni.cc | 7 +++ 10 files changed, 143 insertions(+), 25 deletions(-) diff --git a/.github/workflows/apk-tts.yaml b/.github/workflows/apk-tts.yaml index 68ca8d76d..ffd7ab1d5 100644 --- a/.github/workflows/apk-tts.yaml +++ b/.github/workflows/apk-tts.yaml @@ -34,6 +34,11 @@ jobs: with: fetch-depth: 0 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: ${{ matrix.os }}-android + - name: Display NDK HOME shell: bash run: | @@ -61,6 +66,10 @@ jobs: - name: build APK shell: bash run: | + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" + cmake --version + export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME ./build-apk-tts.sh @@ -70,12 +79,14 @@ jobs: ls -lh ./apks/ du -h -d1 . - # - uses: actions/upload-artifact@v3 - # with: - # name: tts-apk - # path: ./apks/*.apk + - uses: actions/upload-artifact@v3 + if: false + with: + name: tts-apk + path: ./apks/*.apk - name: Publish to huggingface + if: true env: HF_TOKEN: ${{ secrets.HF_TOKEN }} uses: nick-fields/retry@v2 @@ -92,7 +103,9 @@ jobs: git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface cd huggingface + git fetch git pull + git merge -m "merge remote" --ff origin main mkdir -p tts cp -v ../apks/*.apk ./tts/ diff --git a/.github/workflows/apk.yaml b/.github/workflows/apk.yaml index a9d01a107..e8eae56e1 100644 --- a/.github/workflows/apk.yaml +++ b/.github/workflows/apk.yaml @@ -28,6 +28,12 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: ${{ matrix.os }}-android + - name: Display NDK HOME shell: bash run: | @@ -37,6 +43,10 @@ jobs: - name: build APK shell: bash run: | + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" + cmake --version + export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME ./build-apk-vad.sh ./build-apk-two-pass.sh diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index e88500214..72bba99e1 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -101,12 +101,14 @@ class MainActivity : AppCompatActivity() { fun initTts() { var modelDir :String? var modelName :String? + var ruleFsts: String? // The purpose of such a design is to make the CI test easier // Please see // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/apk/generate-tts-apk-script.py modelDir = null modelName = null + ruleFsts = null // Example 1: // modelDir = "vits-vctk" @@ -116,7 +118,12 @@ class MainActivity : AppCompatActivity() { // modelDir = "vits-piper-en_US-lessac-medium" // modelName = "en_US-lessac-medium.onnx" - val config = getOfflineTtsConfig(modelDir = modelDir!!, modelName = modelName!!)!! + // Example 3: + // modelDir = "vits-zh-aishell3" + // modelName = "vits-aishell3.onnx" + // ruleFsts = "vits-zh-aishell3/rule.fst" + + val config = getOfflineTtsConfig(modelDir = modelDir!!, modelName = modelName!!, ruleFsts = ruleFsts ?: "")!! tts = OfflineTts(assetManager = application.assets, config = config) } } diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt index eed9a5934..cf6b1e254 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt @@ -21,6 +21,7 @@ data class OfflineTtsModelConfig( data class OfflineTtsConfig( var model: OfflineTtsModelConfig, + var ruleFsts: String = "", ) class GeneratedAudio( @@ -116,7 +117,7 @@ class OfflineTts( // please refer to // https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html // to download models -fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig? { +fun getOfflineTtsConfig(modelDir: String, modelName: String, ruleFsts: String): OfflineTtsConfig? { return OfflineTtsConfig( model = OfflineTtsModelConfig( vits = OfflineTtsVitsModelConfig( @@ -125,8 +126,9 @@ fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig? tokens = "$modelDir/tokens.txt" ), numThreads = 2, - debug = false, + debug = true, provider = "cpu", - ) + ), + ruleFsts=ruleFsts, ) } diff --git a/cmake/kaldifst.cmake b/cmake/kaldifst.cmake index 3038f3bd8..19f3aa4b4 100644 --- a/cmake/kaldifst.cmake +++ b/cmake/kaldifst.cmake @@ -1,18 +1,18 @@ function(download_kaldifst) include(FetchContent) - set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.8.tar.gz") - set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.8.tar.gz") - set(kaldifst_HASH "SHA256=94613923568ef9a240ba1059b8b9dfe3082daad794934635d99e66248a6687b5") + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.9.tar.gz") + set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.9.tar.gz") + set(kaldifst_HASH "SHA256=8c653021491dca54c38ab659565edfab391418a79ae87099257863cd5664dd39") # If you don't have access to the Internet, # please pre-download kaldifst set(possible_file_locations - $ENV{HOME}/Downloads/kaldifst-1.7.8.tar.gz - ${PROJECT_SOURCE_DIR}/kaldifst-1.7.8.tar.gz - ${PROJECT_BINARY_DIR}/kaldifst-1.7.8.tar.gz - /tmp/kaldifst-1.7.8.tar.gz - /star-fj/fangjun/download/github/kaldifst-1.7.8.tar.gz + $ENV{HOME}/Downloads/kaldifst-1.7.9.tar.gz + ${PROJECT_SOURCE_DIR}/kaldifst-1.7.9.tar.gz + ${PROJECT_BINARY_DIR}/kaldifst-1.7.9.tar.gz + /tmp/kaldifst-1.7.9.tar.gz + /star-fj/fangjun/download/github/kaldifst-1.7.9.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/scripts/apk/build-apk-tts.sh.in b/scripts/apk/build-apk-tts.sh.in index 05bfa2fd7..e6c873575 100644 --- a/scripts/apk/build-apk-tts.sh.in +++ b/scripts/apk/build-apk-tts.sh.in @@ -8,7 +8,7 @@ # Inside the $ANDROID_NDK directory, you can find a binary ndk-build # and some other files like the file "build/cmake/android.toolchain.cmake" -set -e +set -ex log() { # This function is from espnet @@ -43,6 +43,7 @@ wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/$model_name wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/lexicon.txt wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/tokens.txt wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/MODEL_CARD 2>/dev/null || true +wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/rule.fst 2>/dev/null || true popd # Now we are at the project root directory @@ -51,6 +52,11 @@ git checkout . pushd android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt +{% if tts_model.rule_fsts %} + rule_fsts={{ tts_model.rule_fsts }} + sed -i.bak s%"ruleFsts = null"%"ruleFsts = \"$rule_fsts\""% ./MainActivity.kt +{% endif %} + git diff popd diff --git a/scripts/apk/generate-tts-apk-script.py b/scripts/apk/generate-tts-apk-script.py index 073f16187..43df6535a 100755 --- a/scripts/apk/generate-tts-apk-script.py +++ b/scripts/apk/generate-tts-apk-script.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 +import argparse from dataclasses import dataclass +from typing import List, Optional import jinja2 -from typing import List -import argparse def get_args(): @@ -29,12 +29,65 @@ class TtsModel: model_dir: str model_name: str lang: str # en, zh, fr, de, etc. + rule_fsts: Optional[List[str]] = (None,) def get_all_models() -> List[TtsModel]: return [ + # Chinese + TtsModel( + model_dir="vits-zh-aishell3", + model_name="vits-aishell3.onnx", + lang="zh", + rule_fsts="vits-zh-aishell3/rule.fst", + ), + TtsModel( + model_dir="vits-zh-hf-doom", + model_name="doom.onnx", + lang="zh", + rule_fsts="vits-zh-hf-doom/rule.fst", + ), + TtsModel( + model_dir="vits-zh-hf-echo", + model_name="echo.onnx", + lang="zh", + rule_fsts="vits-zh-hf-echo/rule.fst", + ), + TtsModel( + model_dir="vits-zh-hf-zenyatta", + model_name="zenyatta.onnx", + lang="zh", + rule_fsts="vits-zh-hf-zenyatta/rule.fst", + ), + TtsModel( + model_dir="vits-zh-hf-abyssinvoker", + model_name="abyssinvoker.onnx", + lang="zh", + rule_fsts="vits-zh-hf-abyssinvoker/rule.fst", + ), + TtsModel( + model_dir="vits-zh-hf-keqing", + model_name="keqing.onnx", + lang="zh", + rule_fsts="vits-zh-hf-keqing/rule.fst", + ), + TtsModel( + model_dir="vits-zh-hf-eula", + model_name="eula.onnx", + lang="zh", + rule_fsts="vits-zh-hf-eula/rule.fst", + ), + TtsModel( + model_dir="vits-zh-hf-bronya", + model_name="bronya.onnx", + lang="zh", + rule_fsts="vits-zh-hf-bronya/rule.fst", + ), TtsModel( - model_dir="vits-zh-aishell3", model_name="vits-aishell3.onnx", lang="zh" + model_dir="vits-zh-hf-theresa", + model_name="theresa.onnx", + lang="zh", + rule_fsts="vits-zh-hf-theresa/rule.fst", ), # English (US) # fmt: off diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index fce7bc8de..a0da2fa4a 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -196,8 +196,14 @@ std::vector Lexicon::ConvertTextToTokenIdsChinese( std::vector ans; - auto sil = token2id_.at("sil"); - auto eos = token2id_.at("eos"); + int32_t sil = -1; + int32_t eos = -1; + if (token2id_.count("sil")) { + sil = token2id_.at("sil"); + eos = token2id_.at("eos"); + } else { + sil = 0; + } ans.push_back(sil); @@ -216,7 +222,9 @@ std::vector Lexicon::ConvertTextToTokenIdsChinese( ans.insert(ans.end(), token_ids.begin(), token_ids.end()); } ans.push_back(sil); - ans.push_back(eos); + if (eos != -1) { + ans.push_back(eos); + } return ans; } diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index d93f53400..da5435af7 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -10,15 +10,17 @@ #include #if __ANDROID_API__ >= 9 +#include + #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif - #include "kaldifst/csrc/text-normalizer.h" #include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-tts-impl.h" #include "sherpa-onnx/csrc/offline-tts-vits-model.h" +#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -52,7 +54,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { model_->Punctuations(), model_->Language(), config.model.debug, model_->IsPiper()) { if (!config.rule_fsts.empty()) { - SHERPA_ONNX_LOGE("TODO(fangjun): Implement rule FST for Android"); + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + tn_list_.push_back(std::make_unique(is)); + } } } #endif diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 299f7d751..0a039f649 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -566,6 +566,13 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { ans.model.provider = p; env->ReleaseStringUTFChars(s, p); + // for ruleFsts + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fsts = p; + env->ReleaseStringUTFChars(s, p); + return ans; }