-
Notifications
You must be signed in to change notification settings - Fork 477
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support coqui-ai/TTS VITS models using Characters.
- Loading branch information
1 parent
23cf92d
commit 7501ae9
Showing
8 changed files
with
298 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
// sherpa-onnx/csrc/offline-tts-character-frontend.cc | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#if __ANDROID_API__ >= 9 | ||
#include <strstream> | ||
|
||
#include "android/asset_manager.h" | ||
#include "android/asset_manager_jni.h" | ||
#endif | ||
#include <algorithm> | ||
#include <cctype> | ||
#include <codecvt> | ||
#include <fstream> | ||
#include <locale> | ||
#include <sstream> | ||
#include <utility> | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h" | ||
#include "sherpa-onnx/csrc/onnx-utils.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) { | ||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv; | ||
std::unordered_map<char32_t, int32_t> token2id; | ||
|
||
std::string line; | ||
|
||
std::string sym; | ||
std::u32string s; | ||
int32_t id; | ||
while (std::getline(is, line)) { | ||
std::istringstream iss(line); | ||
iss >> sym; | ||
if (iss.eof()) { | ||
id = atoi(sym.c_str()); | ||
sym = " "; | ||
} else { | ||
iss >> id; | ||
} | ||
|
||
// eat the trailing \r\n on windows | ||
iss >> std::ws; | ||
if (!iss.eof()) { | ||
SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str()); | ||
exit(-1); | ||
} | ||
|
||
// Form models from coqui-ai/TTS, we have saved the IDs of the following | ||
// symbols in OfflineTtsVitsModelMetaData, so it is safe to skip them here. | ||
if (sym == "<PAD>" || sym == "<EOS>" || sym == "<BOS>" || sym == "<BLNK>") { | ||
continue; | ||
} | ||
|
||
s = conv.from_bytes(sym); | ||
if (s.size() != 1) { | ||
SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d", | ||
line.c_str(), static_cast<int32_t>(s.size())); | ||
exit(-1); | ||
} | ||
|
||
char32_t c = s[0]; | ||
|
||
if (token2id.count(c)) { | ||
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", | ||
sym.c_str(), line.c_str(), token2id.at(c)); | ||
exit(-1); | ||
} | ||
|
||
token2id.insert({c, id}); | ||
} | ||
|
||
return token2id; | ||
} | ||
|
||
// Build a frontend from a token table stored in a file.
//
// @param tokens     Path to the token table file.
// @param meta_data  Model metadata; a copy is kept in meta_data_.
//
// Terminates the process via exit(-1) if the file cannot be opened, matching
// how ReadTokens() reports malformed input. Without this check an unreadable
// path would silently produce an empty token table.
OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
    const std::string &tokens, const OfflineTtsVitsModelMetaData &meta_data)
    : meta_data_(meta_data) {
  std::ifstream is(tokens);
  if (!is) {
    SHERPA_ONNX_LOGE("Failed to open tokens file: %s", tokens.c_str());
    exit(-1);
  }

  token2id_ = ReadTokens(is);
}
|
||
#if __ANDROID_API__ >= 9 | ||
// Build a frontend from a token table stored as an Android asset.
//
// @param mgr        Asset manager handed to ReadFile() to load the asset.
// @param tokens     Name of the asset holding the token table.
// @param meta_data  Model metadata; a copy is kept in meta_data_.
OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
    AAssetManager *mgr, const std::string &tokens,
    const OfflineTtsVitsModelMetaData &meta_data)
    : meta_data_(meta_data) {
  // ReadFile() is declared elsewhere; presumably it returns the asset's raw
  // bytes in a contiguous buffer -- confirm against onnx-utils.h.
  auto contents = ReadFile(mgr, tokens);

  // Parse straight out of the buffer through a stream over its bytes.
  std::istrstream token_stream(contents.data(), contents.size());
  token2id_ = ReadTokens(token_stream);
}
|
||
#endif | ||
|
||
std::vector<std::vector<int64_t>> | ||
OfflineTtsCharacterFrontend::ConvertTextToTokenIds( | ||
const std::string &text, const std::string &voice /*= ""*/) const { | ||
// see | ||
// https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 | ||
int32_t use_eos_bos = meta_data_.use_eos_bos; | ||
int32_t bos_id = meta_data_.bos_id; | ||
int32_t eos_id = meta_data_.eos_id; | ||
int32_t blank_id = meta_data_.blank_id; | ||
int32_t add_blank = meta_data_.add_blank; | ||
|
||
// Note: No need to convert text to lowercase since tokens.txt | ||
// is assumed to contain both lowercase and uppercase tokens. | ||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv; | ||
std::u32string s = conv.from_bytes(text); | ||
|
||
std::vector<std::vector<int64_t>> ans; | ||
|
||
std::vector<int64_t> this_sentence; | ||
if (add_blank) { | ||
if (use_eos_bos) { | ||
this_sentence.push_back(bos_id); | ||
} | ||
|
||
this_sentence.push_back(blank_id); | ||
|
||
for (char32_t c : s) { | ||
if (token2id_.count(c)) { | ||
this_sentence.push_back(token2id_.at(c)); | ||
this_sentence.push_back(blank_id); | ||
} else { | ||
SHERPA_ONNX_LOGE("Skip unknown character. Unicode codepoint: \\U+%04x.", | ||
static_cast<uint32_t>(c)); | ||
} | ||
|
||
if (c == '.' || c == ':' || c == '?' || c == '!') { | ||
// end of a sentence | ||
if (use_eos_bos) { | ||
this_sentence.push_back(eos_id); | ||
} | ||
|
||
ans.push_back(std::move(this_sentence)); | ||
|
||
// re-initialize this_sentence | ||
if (use_eos_bos) { | ||
this_sentence.push_back(bos_id); | ||
} | ||
this_sentence.push_back(blank_id); | ||
} | ||
} | ||
|
||
if (use_eos_bos) { | ||
this_sentence.push_back(eos_id); | ||
} | ||
|
||
if (this_sentence.size() > 1 + use_eos_bos) { | ||
ans.push_back(std::move(this_sentence)); | ||
} | ||
} else { | ||
// not adding blank | ||
if (use_eos_bos) { | ||
this_sentence.push_back(bos_id); | ||
} | ||
|
||
for (char32_t c : s) { | ||
if (token2id_.count(c)) { | ||
this_sentence.push_back(token2id_.at(c)); | ||
} | ||
|
||
if (c == '.' || c == ':' || c == '?' || c == '!') { | ||
// end of a sentence | ||
if (use_eos_bos) { | ||
this_sentence.push_back(eos_id); | ||
} | ||
|
||
ans.push_back(std::move(this_sentence)); | ||
|
||
// re-initialize this_sentence | ||
if (use_eos_bos) { | ||
this_sentence.push_back(bos_id); | ||
} | ||
} | ||
} | ||
|
||
if (this_sentence.size() > 1) { | ||
ans.push_back(std::move(this_sentence)); | ||
} | ||
} | ||
|
||
return ans; | ||
} | ||
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// sherpa-onnx/csrc/offline-tts-character-frontend.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ | ||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ | ||
#include <cstdint> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
|
||
#if __ANDROID_API__ >= 9 | ||
#include "android/asset_manager.h" | ||
#include "android/asset_manager_jni.h" | ||
#endif | ||
|
||
#include "sherpa-onnx/csrc/offline-tts-frontend.h" | ||
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
class OfflineTtsCharacterFrontend : public OfflineTtsFrontend { | ||
public: | ||
OfflineTtsCharacterFrontend(const std::string &tokens, | ||
const OfflineTtsVitsModelMetaData &meta_data); | ||
|
||
#if __ANDROID_API__ >= 9 | ||
OfflineTtsCharacterFrontend(AAssetManager *mgr, const std::string &tokens, | ||
const OfflineTtsVitsModelMetaData &meta_data); | ||
|
||
#endif | ||
/** Convert a string to token IDs. | ||
* | ||
* @param text The input text. | ||
* Example 1: "This is the first sample sentence; this is the | ||
* second one." Example 2: "这是第一句。这是第二句。" | ||
* @param voice Optional. It is for espeak-ng. | ||
* | ||
* @return Return a vector-of-vector of token IDs. Each subvector contains | ||
* a sentence that can be processed independently. | ||
* If a frontend does not support splitting the text into | ||
* sentences, the resulting vector contains only one subvector. | ||
*/ | ||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds( | ||
const std::string &text, const std::string &voice = "") const override; | ||
|
||
private: | ||
OfflineTtsVitsModelMetaData meta_data_; | ||
std::unordered_map<char32_t, int32_t> token2id_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.