From d9e25ac60b1718fc7dfb9de1814c9c8473c453d8 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 10 Nov 2023 16:24:11 +0800 Subject: [PATCH] Support VITS TTS models from coqui-ai/TTS (#416) * Support VITS TTS models from coqui-ai/TTS * release v1.8.9 --- CMakeLists.txt | 2 +- sherpa-onnx/csrc/lexicon.cc | 29 ++++++++++++++++------ sherpa-onnx/csrc/offline-tts-vits-model.cc | 3 ++- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ca1af9e78..f821c5c84 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.8.8") +set(SHERPA_ONNX_VERSION "1.8.9") # Disable warning about # diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index a0da2fa4a..9cdfd8f47 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -196,20 +196,27 @@ std::vector Lexicon::ConvertTextToTokenIdsChinese( std::vector ans; + int32_t blank = -1; + if (token2id_.count(" ")) { + blank = token2id_.at(" "); + } + int32_t sil = -1; int32_t eos = -1; if (token2id_.count("sil")) { sil = token2id_.at("sil"); eos = token2id_.at("eos"); - } else { - sil = 0; } - ans.push_back(sil); + if (sil != -1) { + ans.push_back(sil); + } for (const auto &w : words) { if (punctuations_.count(w)) { - ans.push_back(sil); + if (sil != -1) { + ans.push_back(sil); + } continue; } @@ -220,11 +227,19 @@ std::vector Lexicon::ConvertTextToTokenIdsChinese( const auto &token_ids = word2ids_.at(w); ans.insert(ans.end(), token_ids.begin(), token_ids.end()); + if (blank != -1) { + ans.push_back(blank); + } + } + + if (sil != -1) { + ans.push_back(sil); } - ans.push_back(sil); + if (eos != -1) { ans.push_back(eos); } + return ans; } @@ -252,7 +267,7 @@ std::vector Lexicon::ConvertTextToTokenIdsEnglish( int32_t blank = token2id_.at(" "); std::vector ans; - if (is_piper_) { + if (is_piper_ && token2id_.count("^")) { ans.push_back(token2id_.at("^")); // sos } @@ -277,7 +292,7 @@ std::vector Lexicon::ConvertTextToTokenIdsEnglish( ans.resize(ans.size() - 1); } - if (is_piper_) { + if (is_piper_ && token2id_.count("$")) { ans.push_back(token2id_.at("$")); // eos } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index ab14b55de..dafe5052a 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -81,7 +81,8 @@ class OfflineTtsVitsModel::Impl { std::string comment; SHERPA_ONNX_READ_META_DATA_STR(comment, "comment"); - if (comment.find("piper") != std::string::npos) { + if (comment.find("piper") != std::string::npos || + comment.find("coqui") != std::string::npos) { is_piper_ = true; } }