Add C++ runtime support for MeloTTS English model
csukuangfj committed Nov 3, 2024
1 parent bde6287 commit cbef4ff
Showing 5 changed files with 66 additions and 18 deletions.
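The English MeloTTS model needs only a lexicon and a tokens file; unlike the Chinese+English model, it has no jieba dict_dir. Below is a minimal sketch of how such a model could be loaded with the existing sherpa-onnx offline TTS C++ API; the file names and the speaker id are placeholders, not part of this commit.

#include "sherpa-onnx/csrc/offline-tts.h"

int main() {
  sherpa_onnx::OfflineTtsConfig config;

  // Placeholder paths for the English MeloTTS model files.
  config.model.vits.model = "./model.onnx";
  config.model.vits.lexicon = "./lexicon.txt";
  config.model.vits.tokens = "./tokens.txt";
  // No dict_dir is set, so the lexicon-only frontend added in this commit
  // is selected for the English model.
  config.model.num_threads = 2;
  config.model.debug = false;

  sherpa_onnx::OfflineTts tts(config);

  auto audio = tts.Generate("This is a test of the English MeloTTS model.",
                            /*sid=*/0, /*speed=*/1.0);

  // audio.samples holds float PCM samples at audio.sample_rate Hz.
  return audio.samples.empty() ? 1 : 0;
}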
67 changes: 53 additions & 14 deletions sherpa-onnx/csrc/melo-tts-lexicon.cc
@@ -48,6 +48,20 @@ class MeloTtsLexicon::Impl {
}
}

Impl(const std::string &lexicon, const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
: meta_data_(meta_data), debug_(debug) {
{
std::ifstream is(tokens);
InitTokens(is);
}

{
std::ifstream is(lexicon);
InitLexicon(is);
}
}

std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
std::string text = ToLowerCase(_text);
// see
@@ -65,21 +79,39 @@ class MeloTtsLexicon::Impl {
s = std::regex_replace(s, punct_re4, "!");

std::vector<std::string> words;
bool is_hmm = true;
jieba_->Cut(text, words, is_hmm);

if (debug_) {
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());

std::ostringstream os;
std::string sep = "";
for (const auto &w : words) {
os << sep << w;
sep = "_";
}
if (jieba_) {
bool is_hmm = true;
jieba_->Cut(text, words, is_hmm);

if (debug_) {
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());

std::ostringstream os;
std::string sep = "";
for (const auto &w : words) {
os << sep << w;
sep = "_";
}

SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
}
} else {
words = SplitUtf8(text);

if (debug_) {
fprintf(stderr, "Input text in string (lowercase): %s\n", text.c_str());
fprintf(stderr, "Input text in bytes (lowercase):");
for (uint8_t c : text) {
fprintf(stderr, " %02x", c);
}
fprintf(stderr, "\n");
fprintf(stderr, "After splitting to words:");
for (const auto &w : words) {
fprintf(stderr, " %s", w.c_str());
}
fprintf(stderr, "\n");
}
}

std::vector<TokenIDs> ans;
@@ -241,6 +273,7 @@ class MeloTtsLexicon::Impl {
{std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
}

// For Chinese+English MeloTTS
word2ids_["呣"] = word2ids_["母"];
word2ids_["嗯"] = word2ids_["恩"];
}
@@ -268,6 +301,12 @@ MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
debug)) {}

MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data,
bool debug)
: impl_(std::make_unique<Impl>(lexicon, tokens, meta_data, debug)) {}

std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*unused_voice = ""*/) const {
return impl_->ConvertTextToTokenIds(text);
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/melo-tts-lexicon.h
@@ -22,6 +22,9 @@ class MeloTtsLexicon : public OfflineTtsFrontend {
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);

MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);

std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text,
const std::string &unused_voice = "") const override;
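For reference, a minimal sketch of exercising the new two-argument constructor directly; in practice OfflineTtsVitsImpl builds this object from the model's metadata, and the paths and default-initialized metadata below are placeholders.

#include "sherpa-onnx/csrc/melo-tts-lexicon.h"

void Demo() {
  // Normally filled from the ONNX model's metadata; default values are used
  // here only to keep the sketch self-contained.
  sherpa_onnx::OfflineTtsVitsModelMetaData meta;

  sherpa_onnx::MeloTtsLexicon lexicon("./lexicon.txt", "./tokens.txt", meta,
                                      /*debug=*/true);

  // Each returned TokenIDs entry carries the token ids and tones for a
  // chunk of the input text.
  auto token_ids = lexicon.ConvertTextToTokenIds("How are you?");
  (void)token_ids;
}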
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
@@ -349,6 +349,10 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
config_.model.vits.lexicon, config_.model.vits.tokens,
config_.model.vits.dict_dir, model_->GetMetaData(),
config_.model.debug);
} else if (meta_data.is_melo_tts && meta_data.language == "English") {
frontend_ = std::make_unique<MeloTtsLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
model_->GetMetaData(), config_.model.debug);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
frontend_ = std::make_unique<JiebaLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-model.cc
@@ -46,8 +46,10 @@ class OfflineTtsVitsModel::Impl {
}

Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
// For MeloTTS, we hardcode sid to the one contained in the meta data
sid = meta_data_.speaker_id;
if (meta_data_.num_speakers == 1) {
// For MeloTTS, we hardcode sid to the one contained in the meta data
sid = meta_data_.speaker_id;
}

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
4 changes: 2 additions & 2 deletions sherpa-onnx/csrc/onnx-utils.cc
@@ -408,10 +408,10 @@ std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
// For other versions, we may need to change it
#if ORT_API_VERSION >= 12
auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
return v.get();
return v ? v.get() : "";
#else
auto v = meta_data.LookupCustomMetadataMap(key, allocator);
std::string ans = v;
std::string ans = v ? v : "";
allocator->Free(allocator, v);
return ans;
#endif
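With this change, a missing key yields an empty string instead of dereferencing a null pointer. A short sketch of how a caller can probe an optional key; the key name "language" and the surrounding session variable are illustrative only.

// Assumes an existing Ort::Session named sess.
Ort::ModelMetadata meta_data = sess.GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator;

std::string language =
    LookupCustomModelMetaData(meta_data, "language", allocator);
if (language.empty()) {
  // Key absent (or empty): fall back to a default instead of crashing.
}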
