diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index 0b57e3f46..c42184124 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -88,8 +88,8 @@ static std::vector ConvertTokensToIds( Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, const std::string &punctuations, const std::string &language, - bool debug /*= false*/, bool is_piper /*= false*/) - : debug_(debug), is_piper_(is_piper) { + bool debug /*= false*/) + : debug_(debug) { InitLanguage(language); { @@ -108,9 +108,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, #if __ANDROID_API__ >= 9 Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, const std::string &tokens, const std::string &punctuations, - const std::string &language, bool debug /*= false*/, - bool is_piper /*= false*/) - : debug_(debug), is_piper_(is_piper) { + const std::string &language, bool debug /*= false*/ + ) + : debug_(debug) { InitLanguage(language); { @@ -132,16 +132,10 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, std::vector> Lexicon::ConvertTextToTokenIds( const std::string &text, const std::string & /*voice*/ /*= ""*/) const { switch (language_) { - case Language::kEnglish: - return ConvertTextToTokenIdsEnglish(text); - case Language::kGerman: - return ConvertTextToTokenIdsGerman(text); - case Language::kSpanish: - return ConvertTextToTokenIdsSpanish(text); - case Language::kFrench: - return ConvertTextToTokenIdsFrench(text); case Language::kChinese: return ConvertTextToTokenIdsChinese(text); + case Language::kNotChinese: + return ConvertTextToTokenIdsNotChinese(text); default: SHERPA_ONNX_LOGE("Unknown language: %d", static_cast(language_)); exit(-1); @@ -197,7 +191,8 @@ std::vector> Lexicon::ConvertTextToTokenIdsChinese( fprintf(stderr, "\n"); } - std::vector ans; + std::vector> ans; + std::vector this_sentence; int32_t blank = -1; if (token2id_.count(" ")) { @@ -212,15 +207,32 @@ std::vector> Lexicon::ConvertTextToTokenIdsChinese( } if (sil != -1) { - ans.push_back(sil); + this_sentence.push_back(sil); } for (const auto &w : words) { - if (punctuations_.count(w)) { - if (token2id_.count(w)) { - ans.push_back(token2id_.at(w)); - } else if (sil != -1) { - ans.push_back(sil); + if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" || + w == "。" || w == ";" || w == "!" || w == "?" || w == ":" || + w == "”" || + // not sentence break + w == "," || w == "“" || w == "," || w == "、") { + if (punctuations_.count(w)) { + if (token2id_.count(w)) { + this_sentence.push_back(token2id_.at(w)); + } else if (sil != -1) { + this_sentence.push_back(sil); + } + } + + if (w != "," && w != "“" && w != "," && w != "、") { + if (eos != -1) { + this_sentence.push_back(eos); + } + ans.push_back(std::move(this_sentence)); + + if (sil != -1) { + this_sentence.push_back(sil); + } } continue; } @@ -231,24 +243,26 @@ std::vector> Lexicon::ConvertTextToTokenIdsChinese( } const auto &token_ids = word2ids_.at(w); - ans.insert(ans.end(), token_ids.begin(), token_ids.end()); + this_sentence.insert(this_sentence.end(), token_ids.begin(), + token_ids.end()); if (blank != -1) { - ans.push_back(blank); + this_sentence.push_back(blank); } } if (sil != -1) { - ans.push_back(sil); + this_sentence.push_back(sil); } if (eos != -1) { - ans.push_back(eos); + this_sentence.push_back(eos); } + ans.push_back(std::move(this_sentence)); - return {ans}; + return ans; } -std::vector> Lexicon::ConvertTextToTokenIdsEnglish( +std::vector> Lexicon::ConvertTextToTokenIdsNotChinese( const std::string &_text) const { std::string text(_text); ToLowerCase(&text); @@ -271,14 +285,22 @@ std::vector> Lexicon::ConvertTextToTokenIdsEnglish( int32_t blank = token2id_.at(" "); - std::vector ans; - if (is_piper_ && token2id_.count("^")) { - ans.push_back(token2id_.at("^")); // sos - } + std::vector> ans; + std::vector this_sentence; for (const auto &w : words) { - if (punctuations_.count(w)) { - ans.push_back(token2id_.at(w)); + if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" || + // not sentence break + w == ",") { + if (punctuations_.count(w)) { + this_sentence.push_back(token2id_.at(w)); + } + + if (w != ",") { + this_sentence.push_back(blank); + ans.push_back(std::move(this_sentence)); + } + continue; } @@ -288,20 +310,21 @@ std::vector> Lexicon::ConvertTextToTokenIdsEnglish( } const auto &token_ids = word2ids_.at(w); - ans.insert(ans.end(), token_ids.begin(), token_ids.end()); - ans.push_back(blank); + this_sentence.insert(this_sentence.end(), token_ids.begin(), + token_ids.end()); + this_sentence.push_back(blank); } - if (!ans.empty()) { + if (!this_sentence.empty()) { // remove the last blank - ans.resize(ans.size() - 1); + this_sentence.resize(this_sentence.size() - 1); } - if (is_piper_ && token2id_.count("$")) { - ans.push_back(token2id_.at("$")); // eos + if (!this_sentence.empty()) { + ans.push_back(std::move(this_sentence)); } - return {ans}; + return ans; } void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); } @@ -309,16 +332,10 @@ void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); } void Lexicon::InitLanguage(const std::string &_lang) { std::string lang(_lang); ToLowerCase(&lang); - if (lang == "english") { - language_ = Language::kEnglish; - } else if (lang == "german") { - language_ = Language::kGerman; - } else if (lang == "spanish") { - language_ = Language::kSpanish; - } else if (lang == "french") { - language_ = Language::kFrench; - } else if (lang == "chinese") { + if (lang == "chinese") { language_ = Language::kChinese; + } else if (!lang.empty()) { + language_ = Language::kNotChinese; } else { SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str()); exit(-1); diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h index 197b0afe6..97b0ff7ba 100644 --- a/sherpa-onnx/csrc/lexicon.h +++ b/sherpa-onnx/csrc/lexicon.h @@ -29,35 +29,19 @@ class Lexicon : public OfflineTtsFrontend { // Note: for models from piper, we won't use this class. Lexicon(const std::string &lexicon, const std::string &tokens, const std::string &punctuations, const std::string &language, - bool debug = false, bool is_piper = false); + bool debug = false); #if __ANDROID_API__ >= 9 Lexicon(AAssetManager *mgr, const std::string &lexicon, const std::string &tokens, const std::string &punctuations, - const std::string &language, bool debug = false, - bool is_piper = false); + const std::string &language, bool debug = false); #endif std::vector> ConvertTextToTokenIds( const std::string &text, const std::string &voice = "") const override; private: - std::vector> ConvertTextToTokenIdsGerman( - const std::string &text) const { - return ConvertTextToTokenIdsEnglish(text); - } - - std::vector> ConvertTextToTokenIdsSpanish( - const std::string &text) const { - return ConvertTextToTokenIdsEnglish(text); - } - - std::vector> ConvertTextToTokenIdsFrench( - const std::string &text) const { - return ConvertTextToTokenIdsEnglish(text); - } - - std::vector> ConvertTextToTokenIdsEnglish( + std::vector> ConvertTextToTokenIdsNotChinese( const std::string &text) const; std::vector> ConvertTextToTokenIdsChinese( @@ -70,10 +54,7 @@ class Lexicon : public OfflineTtsFrontend { private: enum class Language { - kEnglish, - kGerman, - kSpanish, - kFrench, + kNotChinese, kChinese, kUnknown, }; @@ -84,7 +65,6 @@ class Lexicon : public OfflineTtsFrontend { std::unordered_map token2id_; Language language_; bool debug_; - bool is_piper_; // for Chinese polyphones std::unique_ptr pattern_; diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index a0a7e163b..bb6555700 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -195,8 +195,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } else { frontend_ = std::make_unique( mgr, config_.model.vits.lexicon, config_.model.vits.tokens, - model_->Punctuations(), model_->Language(), config_.model.debug, - model_->IsPiper()); + model_->Punctuations(), model_->Language(), config_.model.debug); } } #endif @@ -208,8 +207,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } else { frontend_ = std::make_unique( config_.model.vits.lexicon, config_.model.vits.tokens, - model_->Punctuations(), model_->Language(), config_.model.debug, - model_->IsPiper()); + model_->Punctuations(), model_->Language(), config_.model.debug); } }