From 7501ae9eaeae8d1fa7c86289f9bcb5666dd3bcdd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 6 Dec 2023 13:57:21 +0800 Subject: [PATCH] Support coqui-ai/TTS VITS models using Characters. --- sherpa-onnx/csrc/CMakeLists.txt | 1 + .../csrc/offline-tts-character-frontend.cc | 189 ++++++++++++++++++ .../csrc/offline-tts-character-frontend.h | 54 +++++ sherpa-onnx/csrc/offline-tts-vits-impl.h | 32 ++- .../csrc/offline-tts-vits-model-config.cc | 14 +- .../csrc/offline-tts-vits-model-metadata.h | 15 +- sherpa-onnx/csrc/offline-tts-vits-model.cc | 18 +- sherpa-onnx/csrc/piper-phonemize-lexicon.cc | 1 - 8 files changed, 298 insertions(+), 26 deletions(-) create mode 100644 sherpa-onnx/csrc/offline-tts-character-frontend.cc create mode 100644 sherpa-onnx/csrc/offline-tts-character-frontend.h diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index c2cd2bca5..e9717996f 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -41,6 +41,7 @@ set(sources offline-transducer-model-config.cc offline-transducer-model.cc offline-transducer-modified-beam-search-decoder.cc + offline-tts-character-frontend.cc offline-wenet-ctc-model-config.cc offline-wenet-ctc-model.cc offline-whisper-greedy-search-decoder.cc diff --git a/sherpa-onnx/csrc/offline-tts-character-frontend.cc b/sherpa-onnx/csrc/offline-tts-character-frontend.cc new file mode 100644 index 000000000..befcb7bb8 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-character-frontend.cc @@ -0,0 +1,189 @@ +// sherpa-onnx/csrc/offline-tts-character-frontend.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#if __ANDROID_API__ >= 9 +#include + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif +#include +#include +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-tts-character-frontend.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +static std::unordered_map ReadTokens(std::istream &is) { + std::wstring_convert, char32_t> conv; + std::unordered_map token2id; + + std::string line; + + std::string sym; + std::u32string s; + int32_t id; + while (std::getline(is, line)) { + std::istringstream iss(line); + iss >> sym; + if (iss.eof()) { + id = atoi(sym.c_str()); + sym = " "; + } else { + iss >> id; + } + + // eat the trailing \r\n on windows + iss >> std::ws; + if (!iss.eof()) { + SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str()); + exit(-1); + } + + // Form models from coqui-ai/TTS, we have saved the IDs of the following + // symbols in OfflineTtsVitsModelMetaData, so it is safe to skip them here. + if (sym == "" || sym == "" || sym == "" || sym == "") { + continue; + } + + s = conv.from_bytes(sym); + if (s.size() != 1) { + SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d", + line.c_str(), static_cast(s.size())); + exit(-1); + } + + char32_t c = s[0]; + + if (token2id.count(c)) { + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", + sym.c_str(), line.c_str(), token2id.at(c)); + exit(-1); + } + + token2id.insert({c, id}); + } + + return token2id; +} + +OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( + const std::string &tokens, const OfflineTtsVitsModelMetaData &meta_data) + : meta_data_(meta_data) { + std::ifstream is(tokens); + token2id_ = ReadTokens(is); +} + +#if __ANDROID_API__ >= 9 +OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( + AAssetManager *mgr, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data) + : meta_data_(meta_data) { + auto buf = ReadFile(mgr, tokens); + std::istrstream is(buf.data(), buf.size()); + token2id_ = ReadTokens(is); +} + +#endif + +std::vector> +OfflineTtsCharacterFrontend::ConvertTextToTokenIds( + const std::string &text, const std::string &voice /*= ""*/) const { + // see + // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 + int32_t use_eos_bos = meta_data_.use_eos_bos; + int32_t bos_id = meta_data_.bos_id; + int32_t eos_id = meta_data_.eos_id; + int32_t blank_id = meta_data_.blank_id; + int32_t add_blank = meta_data_.add_blank; + + // Note: No need to convert text to lowercase since tokens.txt + // is assumed to contain both lowercase and uppercase tokens. + std::wstring_convert, char32_t> conv; + std::u32string s = conv.from_bytes(text); + + std::vector> ans; + + std::vector this_sentence; + if (add_blank) { + if (use_eos_bos) { + this_sentence.push_back(bos_id); + } + + this_sentence.push_back(blank_id); + + for (char32_t c : s) { + if (token2id_.count(c)) { + this_sentence.push_back(token2id_.at(c)); + this_sentence.push_back(blank_id); + } else { + SHERPA_ONNX_LOGE("Skip unknown character. Unicode codepoint: \\U+%04x.", + static_cast(c)); + } + + if (c == '.' || c == ':' || c == '?' || c == '!') { + // end of a sentence + if (use_eos_bos) { + this_sentence.push_back(eos_id); + } + + ans.push_back(std::move(this_sentence)); + + // re-initialize this_sentence + if (use_eos_bos) { + this_sentence.push_back(bos_id); + } + this_sentence.push_back(blank_id); + } + } + + if (use_eos_bos) { + this_sentence.push_back(eos_id); + } + + if (this_sentence.size() > 1 + use_eos_bos) { + ans.push_back(std::move(this_sentence)); + } + } else { + // not adding blank + if (use_eos_bos) { + this_sentence.push_back(bos_id); + } + + for (char32_t c : s) { + if (token2id_.count(c)) { + this_sentence.push_back(token2id_.at(c)); + } + + if (c == '.' || c == ':' || c == '?' || c == '!') { + // end of a sentence + if (use_eos_bos) { + this_sentence.push_back(eos_id); + } + + ans.push_back(std::move(this_sentence)); + + // re-initialize this_sentence + if (use_eos_bos) { + this_sentence.push_back(bos_id); + } + } + } + + if (this_sentence.size() > 1) { + ans.push_back(std::move(this_sentence)); + } + } + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-character-frontend.h b/sherpa-onnx/csrc/offline-tts-character-frontend.h new file mode 100644 index 000000000..d56ea3125 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-character-frontend.h @@ -0,0 +1,54 @@ +// sherpa-onnx/csrc/offline-tts-character-frontend.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/offline-tts-frontend.h" +#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" + +namespace sherpa_onnx { + +class OfflineTtsCharacterFrontend : public OfflineTtsFrontend { + public: + OfflineTtsCharacterFrontend(const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data); + +#if __ANDROID_API__ >= 9 + OfflineTtsCharacterFrontend(AAssetManager *mgr, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data); + +#endif + /** Convert a string to token IDs. + * + * @param text The input text. + * Example 1: "This is the first sample sentence; this is the + * second one." Example 2: "这是第一句。这是第二句。" + * @param voice Optional. It is for espeak-ng. + * + * @return Return a vector-of-vector of token IDs. Each subvector contains + * a sentence that can be processed independently. + * If a frontend does not support splitting the text into + * sentences, the resulting vector contains only one subvector. + */ + std::vector> ConvertTextToTokenIds( + const std::string &text, const std::string &voice = "") const override; + + private: + OfflineTtsVitsModelMetaData meta_data_; + std::unordered_map token2id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index f1c043204..4ac12bab6 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -18,6 +18,7 @@ #include "kaldifst/csrc/text-normalizer.h" #include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-tts-character-frontend.h" #include "sherpa-onnx/csrc/offline-tts-frontend.h" #include "sherpa-onnx/csrc/offline-tts-impl.h" #include "sherpa-onnx/csrc/offline-tts-vits-model.h" @@ -116,7 +117,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { return {}; } - if (meta_data.add_blank && config_.model.vits.data_dir.empty()) { + // TODO(fangjun): add blank inside the frontend, not here + if (meta_data.add_blank && config_.model.vits.data_dir.empty() && + meta_data.frontend != "characters") { for (auto &k : x) { k = AddBlank(k); } @@ -195,12 +198,22 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { void InitFrontend(AAssetManager *mgr) { const auto &meta_data = model_->GetMetaData(); - if ((meta_data.is_piper || meta_data.is_coqui) && - !config_.model.vits.data_dir.empty()) { + if (meta_data.frontend == "characters") { + frontend_ = std::make_unique( + mgr, config_.model.vits.tokens, meta_data); + } else if ((meta_data.is_piper || meta_data.is_coqui) && + !config_.model.vits.data_dir.empty()) { frontend_ = std::make_unique( mgr, config_.model.vits.tokens, config_.model.vits.data_dir, meta_data); } else { + if (config_.model.vits.lexicon.empty()) { + SHERPA_ONNX_LOGE( + "Not a model using characters as modeling unit. Please provide " + "--vits-lexicon if you leave --vits-data-dir empty"); + exit(-1); + } + frontend_ = std::make_unique( mgr, config_.model.vits.lexicon, config_.model.vits.tokens, meta_data.punctuations, meta_data.language, config_.model.debug); @@ -211,12 +224,21 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { void InitFrontend() { const auto &meta_data = model_->GetMetaData(); - if ((meta_data.is_piper || meta_data.is_coqui) && - !config_.model.vits.data_dir.empty()) { + if (meta_data.frontend == "characters") { + frontend_ = std::make_unique( + config_.model.vits.tokens, meta_data); + } else if ((meta_data.is_piper || meta_data.is_coqui) && + !config_.model.vits.data_dir.empty()) { frontend_ = std::make_unique( config_.model.vits.tokens, config_.model.vits.data_dir, model_->GetMetaData()); } else { + if (config_.model.vits.lexicon.empty()) { + SHERPA_ONNX_LOGE( + "Not a model using characters as modeling unit. Please provide " + "--vits-lexicon if you leave --vits-data-dir empty"); + exit(-1); + } frontend_ = std::make_unique( config_.model.vits.lexicon, config_.model.vits.tokens, meta_data.punctuations, meta_data.language, config_.model.debug); diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc index 22ccec354..3d35726fe 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc @@ -44,19 +44,7 @@ bool OfflineTtsVitsModelConfig::Validate() const { return false; } - if (data_dir.empty()) { - if (lexicon.empty()) { - SHERPA_ONNX_LOGE( - "Please provide --vits-lexicon if you leave --vits-data-dir empty"); - return false; - } - - if (!FileExists(lexicon)) { - SHERPA_ONNX_LOGE("--vits-lexicon: %s does not exist", lexicon.c_str()); - return false; - } - - } else { + if (!data_dir.empty()) { if (!FileExists(data_dir + "/phontab")) { SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test", data_dir.c_str()); diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h b/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h index 9356519aa..60e375540 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h @@ -10,15 +10,14 @@ namespace sherpa_onnx { +// If you are not sure what each field means, please +// have a look of the Python file in the model directory that +// you have downloaded. struct OfflineTtsVitsModelMetaData { - int32_t sample_rate; + int32_t sample_rate = 0; int32_t add_blank = 0; int32_t num_speakers = 0; - std::string punctuations; - std::string language; - std::string voice; - bool is_piper = false; bool is_coqui = false; @@ -27,6 +26,12 @@ struct OfflineTtsVitsModelMetaData { int32_t bos_id = 0; int32_t eos_id = 0; int32_t use_eos_bos = 0; + int32_t pad_id = 0; + + std::string punctuations; + std::string language; + std::string voice; + std::string frontend; // characters }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index b0604a6b5..d3672ed6b 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -87,13 +87,18 @@ class OfflineTtsVitsModel::Impl { SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations, "punctuation", ""); SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice", ""); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.frontend, "frontend", + ""); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.blank_id, "blank_id", 0); SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.bos_id, "bos_id", 0); SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.eos_id, "eos_id", 0); SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_eos_bos, "use_eos_bos", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.pad_id, "pad_id", 0); std::string comment; SHERPA_ONNX_READ_META_DATA_STR(comment, "comment"); @@ -142,16 +147,25 @@ class OfflineTtsVitsModel::Impl { Ort::Value sid_tensor = Ort::Value::CreateTensor(memory_info, &sid, 1, &sid_shape, 1); + int64_t lang_id_shape = 1; + int64_t lang_id = 0; + Ort::Value lang_id_tensor = + Ort::Value::CreateTensor(memory_info, &lang_id, 1, &lang_id_shape, 1); + std::vector inputs; - inputs.reserve(4); + inputs.reserve(5); inputs.push_back(std::move(x)); inputs.push_back(std::move(x_length)); inputs.push_back(std::move(scales_tensor)); - if (input_names_.size() == 4 && input_names_.back() == "sid") { + if (input_names_.size() >= 4 && input_names_[3] == "sid") { inputs.push_back(std::move(sid_tensor)); } + if (input_names_.size() >= 5 && input_names_[4] == "langid") { + inputs.push_back(std::move(lang_id_tensor)); + } + auto out = sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), output_names_ptr_.data(), output_names_ptr_.size()); diff --git a/sherpa-onnx/csrc/piper-phonemize-lexicon.cc b/sherpa-onnx/csrc/piper-phonemize-lexicon.cc index 20512e6cb..476e872d5 100644 --- a/sherpa-onnx/csrc/piper-phonemize-lexicon.cc +++ b/sherpa-onnx/csrc/piper-phonemize-lexicon.cc @@ -123,7 +123,6 @@ static std::vector CoquiPhonemesToIds( int32_t blank_id = meta_data.blank_id; int32_t add_blank = meta_data.add_blank; int32_t comma_id = token2id.at(','); - SHERPA_ONNX_LOGE("comma id: %d", comma_id); std::vector ans; if (add_blank) {