Skip to content

Commit

Permalink
Support coqui-ai/TTS VITS models using Characters.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Dec 6, 2023
1 parent 23cf92d commit 7501ae9
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 26 deletions.
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ set(sources
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
offline-tts-character-frontend.cc
offline-wenet-ctc-model-config.cc
offline-wenet-ctc-model.cc
offline-whisper-greedy-search-decoder.cc
Expand Down
189 changes: 189 additions & 0 deletions sherpa-onnx/csrc/offline-tts-character-frontend.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// sherpa-onnx/csrc/offline-tts-character-frontend.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#if __ANDROID_API__ >= 9
#include <strstream>

#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include <algorithm>
#include <cctype>
#include <codecvt>
#include <fstream>
#include <locale>
#include <sstream>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
std::unordered_map<char32_t, int32_t> token2id;

std::string line;

std::string sym;
std::u32string s;
int32_t id;
while (std::getline(is, line)) {
std::istringstream iss(line);
iss >> sym;
if (iss.eof()) {
id = atoi(sym.c_str());
sym = " ";
} else {
iss >> id;
}

// eat the trailing \r\n on windows
iss >> std::ws;
if (!iss.eof()) {
SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str());
exit(-1);
}

// Form models from coqui-ai/TTS, we have saved the IDs of the following
// symbols in OfflineTtsVitsModelMetaData, so it is safe to skip them here.
if (sym == "<PAD>" || sym == "<EOS>" || sym == "<BOS>" || sym == "<BLNK>") {
continue;
}

s = conv.from_bytes(sym);
if (s.size() != 1) {
SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d",
line.c_str(), static_cast<int32_t>(s.size()));
exit(-1);
}

char32_t c = s[0];

if (token2id.count(c)) {
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
sym.c_str(), line.c_str(), token2id.at(c));
exit(-1);
}

token2id.insert({c, id});
}

return token2id;
}

OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
const std::string &tokens, const OfflineTtsVitsModelMetaData &meta_data)
: meta_data_(meta_data) {
std::ifstream is(tokens);
token2id_ = ReadTokens(is);
}

#if __ANDROID_API__ >= 9
OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
AAssetManager *mgr, const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data)
: meta_data_(meta_data) {
auto buf = ReadFile(mgr, tokens);
std::istrstream is(buf.data(), buf.size());
token2id_ = ReadTokens(is);
}

#endif

std::vector<std::vector<int64_t>>
OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
const std::string &text, const std::string &voice /*= ""*/) const {
// see
// https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87
int32_t use_eos_bos = meta_data_.use_eos_bos;
int32_t bos_id = meta_data_.bos_id;
int32_t eos_id = meta_data_.eos_id;
int32_t blank_id = meta_data_.blank_id;
int32_t add_blank = meta_data_.add_blank;

// Note: No need to convert text to lowercase since tokens.txt
// is assumed to contain both lowercase and uppercase tokens.
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
std::u32string s = conv.from_bytes(text);

std::vector<std::vector<int64_t>> ans;

std::vector<int64_t> this_sentence;
if (add_blank) {
if (use_eos_bos) {
this_sentence.push_back(bos_id);
}

this_sentence.push_back(blank_id);

for (char32_t c : s) {
if (token2id_.count(c)) {
this_sentence.push_back(token2id_.at(c));
this_sentence.push_back(blank_id);
} else {
SHERPA_ONNX_LOGE("Skip unknown character. Unicode codepoint: \\U+%04x.",
static_cast<uint32_t>(c));
}

if (c == '.' || c == ':' || c == '?' || c == '!') {
// end of a sentence
if (use_eos_bos) {
this_sentence.push_back(eos_id);
}

ans.push_back(std::move(this_sentence));

// re-initialize this_sentence
if (use_eos_bos) {
this_sentence.push_back(bos_id);
}
this_sentence.push_back(blank_id);
}
}

if (use_eos_bos) {
this_sentence.push_back(eos_id);
}

if (this_sentence.size() > 1 + use_eos_bos) {
ans.push_back(std::move(this_sentence));
}
} else {
// not adding blank
if (use_eos_bos) {
this_sentence.push_back(bos_id);
}

for (char32_t c : s) {
if (token2id_.count(c)) {
this_sentence.push_back(token2id_.at(c));
}

if (c == '.' || c == ':' || c == '?' || c == '!') {
// end of a sentence
if (use_eos_bos) {
this_sentence.push_back(eos_id);
}

ans.push_back(std::move(this_sentence));

// re-initialize this_sentence
if (use_eos_bos) {
this_sentence.push_back(bos_id);
}
}
}

if (this_sentence.size() > 1) {
ans.push_back(std::move(this_sentence));
}
}

return ans;
}

} // namespace sherpa_onnx
54 changes: 54 additions & 0 deletions sherpa-onnx/csrc/offline-tts-character-frontend.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// sherpa-onnx/csrc/offline-tts-character-frontend.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"

namespace sherpa_onnx {

class OfflineTtsCharacterFrontend : public OfflineTtsFrontend {
public:
OfflineTtsCharacterFrontend(const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data);

#if __ANDROID_API__ >= 9
OfflineTtsCharacterFrontend(AAssetManager *mgr, const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data);

#endif
/** Convert a string to token IDs.
*
* @param text The input text.
* Example 1: "This is the first sample sentence; this is the
* second one." Example 2: "这是第一句。这是第二句。"
* @param voice Optional. It is for espeak-ng.
*
* @return Return a vector-of-vector of token IDs. Each subvector contains
* a sentence that can be processed independently.
* If a frontend does not support splitting the text into
* sentences, the resulting vector contains only one subvector.
*/
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;

private:
OfflineTtsVitsModelMetaData meta_data_;
std::unordered_map<char32_t, int32_t> token2id_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_
32 changes: 27 additions & 5 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
Expand Down Expand Up @@ -116,7 +117,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
return {};
}

if (meta_data.add_blank && config_.model.vits.data_dir.empty()) {
// TODO(fangjun): add blank inside the frontend, not here
if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
meta_data.frontend != "characters") {
for (auto &k : x) {
k = AddBlank(k);
}
Expand Down Expand Up @@ -195,12 +198,22 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
void InitFrontend(AAssetManager *mgr) {
const auto &meta_data = model_->GetMetaData();

if ((meta_data.is_piper || meta_data.is_coqui) &&
!config_.model.vits.data_dir.empty()) {
if (meta_data.frontend == "characters") {
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
mgr, config_.model.vits.tokens, meta_data);
} else if ((meta_data.is_piper || meta_data.is_coqui) &&
!config_.model.vits.data_dir.empty()) {
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
mgr, config_.model.vits.tokens, config_.model.vits.data_dir,
meta_data);
} else {
if (config_.model.vits.lexicon.empty()) {
SHERPA_ONNX_LOGE(
"Not a model using characters as modeling unit. Please provide "
"--vits-lexicon if you leave --vits-data-dir empty");
exit(-1);
}

frontend_ = std::make_unique<Lexicon>(
mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
meta_data.punctuations, meta_data.language, config_.model.debug);
Expand All @@ -211,12 +224,21 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
void InitFrontend() {
const auto &meta_data = model_->GetMetaData();

if ((meta_data.is_piper || meta_data.is_coqui) &&
!config_.model.vits.data_dir.empty()) {
if (meta_data.frontend == "characters") {
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
config_.model.vits.tokens, meta_data);
} else if ((meta_data.is_piper || meta_data.is_coqui) &&
!config_.model.vits.data_dir.empty()) {
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
config_.model.vits.tokens, config_.model.vits.data_dir,
model_->GetMetaData());
} else {
if (config_.model.vits.lexicon.empty()) {
SHERPA_ONNX_LOGE(
"Not a model using characters as modeling unit. Please provide "
"--vits-lexicon if you leave --vits-data-dir empty");
exit(-1);
}
frontend_ = std::make_unique<Lexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
meta_data.punctuations, meta_data.language, config_.model.debug);
Expand Down
14 changes: 1 addition & 13 deletions sherpa-onnx/csrc/offline-tts-vits-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,7 @@ bool OfflineTtsVitsModelConfig::Validate() const {
return false;
}

if (data_dir.empty()) {
if (lexicon.empty()) {
SHERPA_ONNX_LOGE(
"Please provide --vits-lexicon if you leave --vits-data-dir empty");
return false;
}

if (!FileExists(lexicon)) {
SHERPA_ONNX_LOGE("--vits-lexicon: %s does not exist", lexicon.c_str());
return false;
}

} else {
if (!data_dir.empty()) {
if (!FileExists(data_dir + "/phontab")) {
SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test",
data_dir.c_str());
Expand Down
15 changes: 10 additions & 5 deletions sherpa-onnx/csrc/offline-tts-vits-model-metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@

namespace sherpa_onnx {

// If you are not sure what each field means, please
// have a look of the Python file in the model directory that
// you have downloaded.
struct OfflineTtsVitsModelMetaData {
int32_t sample_rate;
int32_t sample_rate = 0;
int32_t add_blank = 0;
int32_t num_speakers = 0;

std::string punctuations;
std::string language;
std::string voice;

bool is_piper = false;
bool is_coqui = false;

Expand All @@ -27,6 +26,12 @@ struct OfflineTtsVitsModelMetaData {
int32_t bos_id = 0;
int32_t eos_id = 0;
int32_t use_eos_bos = 0;
int32_t pad_id = 0;

std::string punctuations;
std::string language;
std::string voice;
std::string frontend; // characters
};

} // namespace sherpa_onnx
Expand Down
Loading

0 comments on commit 7501ae9

Please sign in to comment.