From c3a50f218af0ada36662fd24bac72f02a5af16f7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 15 Sep 2023 21:31:20 +0800 Subject: [PATCH] add example for silero vad --- sherpa-onnx/csrc/session.cc | 4 + sherpa-onnx/csrc/session.h | 3 + .../csrc/sherpa-onnx-vad-microphone.cc | 48 +++- sherpa-onnx/csrc/silero-vad-model-config.cc | 20 +- sherpa-onnx/csrc/silero-vad-model-config.h | 10 +- sherpa-onnx/csrc/silero-vad-model.cc | 221 +++++++++++++++++- sherpa-onnx/csrc/vad-model-config.cc | 6 +- sherpa-onnx/csrc/vad-model-config.h | 10 +- 8 files changed, 298 insertions(+), 24 deletions(-) diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 80c9471d3..fe747740f 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -76,4 +76,8 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); } +Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 42c93e0be..f0f25b236 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -10,6 +10,7 @@ #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" +#include "sherpa-onnx/csrc/vad-model-config.h" namespace sherpa_onnx { @@ -20,6 +21,8 @@ Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); + +Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc index 4e0fba704..e7bc92931 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc @@ -12,10 +12,7 @@ #include #include "portaudio.h" // NOLINT -#include "sherpa-onnx/csrc/display.h" #include "sherpa-onnx/csrc/microphone.h" -#include "sherpa-onnx/csrc/parse-options.h" -#include "sherpa-onnx/csrc/vad-model-config.h" #include "sherpa-onnx/csrc/vad-model.h" bool stop = false; @@ -28,12 +25,20 @@ static int32_t RecordCallback(const void *input_buffer, const PaStreamCallbackTimeInfo * /*time_info*/, PaStreamCallbackFlags /*status_flags*/, void *user_data) { + int32_t window_size = *reinterpret_cast(user_data); + std::lock_guard lock(mutex); - queue.emplace( + std::vector samples( reinterpret_cast(input_buffer), reinterpret_cast(input_buffer) + frames_per_buffer); + if (!queue.empty() && queue.back().size() < window_size) { + queue.back().insert(queue.back().end(), samples.begin(), samples.end()); + } else { + queue.push(std::move(samples)); + } + return stop ? paComplete : paContinue; } @@ -109,13 +114,16 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx 0, // frames per buffer paClipOff, // we won't output out of range samples // so don't bother clipping them - RecordCallback, nullptr); + RecordCallback, &config.silero_vad.window_size); if (err != paNoError) { fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); exit(EXIT_FAILURE); } err = Pa_StartStream(stream); + + auto vad_model = sherpa_onnx::VadModel::Create(config); + fprintf(stderr, "Started\n"); if (err != paNoError) { @@ -123,16 +131,40 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx exit(EXIT_FAILURE); } + int32_t speech_count = 0; + int32_t non_speech_count = 0; while (!stop) { { std::lock_guard lock(mutex); - while (!queue.empty()) { - fprintf(stderr, "%d\n", (int)queue.size()); + while (!queue.empty() && + queue.front().size() >= config.silero_vad.window_size) { + bool is_speech = + vad_model->IsSpeech(queue.front().data(), queue.front().size()); + queue.pop(); + + if (is_speech) { + speech_count += 1; + non_speech_count = 0; + } else { + speech_count = 0; + non_speech_count += 1; + } + + if (speech_count == 1) { + static int32_t k = 0; + ++k; + fprintf(stderr, "Detected speech: %d\n", k); + } + + if (non_speech_count == 1) { + static int32_t k = 0; + ++k; + fprintf(stderr, "Detected non-speech: %d\n", k); + } } } Pa_Sleep(100); // sleep for 100ms - stop = true; } err = Pa_CloseStream(stream); diff --git a/sherpa-onnx/csrc/silero-vad-model-config.cc b/sherpa-onnx/csrc/silero-vad-model-config.cc index 7a5633836..8419265fe 100644 --- a/sherpa-onnx/csrc/silero-vad-model-config.cc +++ b/sherpa-onnx/csrc/silero-vad-model-config.cc @@ -12,17 +12,22 @@ namespace sherpa_onnx { void SileroVadModelConfig::Register(ParseOptions *po) { po->Register("silero-vad-model", &model, "Path to silero VAD ONNX model."); - po->Register("silero-vad-prob", &prob, + po->Register("silero-vad-threshold", &threshold, "Speech threshold. Silero VAD outputs speech probabilities for " "each audio chunk, probabilities ABOVE this value are " "considered as SPEECH. It is better to tune this parameter for " "each dataset separately, but lazy " "0.5 is pretty good for most datasets."); + po->Register( "silero-vad-min-silence-duration", &min_silence_duration, "In seconds. In the end of each speech chunk wait for " "--silero-vad-min-silence-duration seconds before separating it"); + po->Register("silero-vad-min-speech-duration", &min_speech_duration, + "In seconds. In the end of each silence chunk wait for " + "--silero-vad-min-speech-duration seconds before separating it"); + po->Register( "silero-vad-window-size", &window_size, "In samples. Audio chunks of --silero-vad-window-size samples are fed " @@ -43,15 +48,17 @@ bool SileroVadModelConfig::Validate() const { return false; } - if (prob < 0.01) { + if (threshold < 0.01) { SHERPA_ONNX_LOGE( - "Please use a larger value for --silero-vad-prob. Given: %f", prob); + "Please use a larger value for --silero-vad-threshold. Given: %f", + threshold); return false; } - if (prob >= 1) { + if (threshold >= 1) { SHERPA_ONNX_LOGE( - "Please use a smaller value for --silero-vad-prob. Given: %f", prob); + "Please use a smaller value for --silero-vad-threshold. Given: %f", + threshold); return false; } @@ -63,8 +70,9 @@ std::string SileroVadModelConfig::ToString() const { os << "SilerVadModelConfig("; os << "model=\"" << model << "\", "; - os << "prob=" << prob << ", "; + os << "threshold=" << threshold << ", "; os << "min_silence_duration=" << min_silence_duration << ", "; + os << "min_speech_duration=" << min_speech_duration << ", "; os << "window_size=" << window_size << ")"; return os.str(); diff --git a/sherpa-onnx/csrc/silero-vad-model-config.h b/sherpa-onnx/csrc/silero-vad-model-config.h index 131554aba..9757820a6 100644 --- a/sherpa-onnx/csrc/silero-vad-model-config.h +++ b/sherpa-onnx/csrc/silero-vad-model-config.h @@ -17,16 +17,18 @@ struct SileroVadModelConfig { // // The predicted probability of a segment is larger than this // value, then it is classified as speech. - float prob = 0.5; + float threshold = 0.5; - float min_silence_duration = 0.1; // in seconds + float min_silence_duration = 0.5; // in seconds + + float min_speech_duration = 0.25; // in seconds // 512, 1024, 1536 samples for 16000 Hz // 256, 512, 768 samples for 800 Hz - int window_size = 1536; // in samples + int window_size = 512; // in samples // support only 16000 and 8000 - int32_t sample_rate = 16000; + int32_t sample_rate = 16000; // not exposed to users SileroVadModelConfig() = default; diff --git a/sherpa-onnx/csrc/silero-vad-model.cc b/sherpa-onnx/csrc/silero-vad-model.cc index 33d32d12e..a9c8a33b0 100644 --- a/sherpa-onnx/csrc/silero-vad-model.cc +++ b/sherpa-onnx/csrc/silero-vad-model.cc @@ -4,18 +4,233 @@ #include "sherpa-onnx/csrc/silero-vad-model.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" + namespace sherpa_onnx { class SileroVadModel::Impl { public: - Impl(const VadModelConfig &config) : config_(config) {} + Impl(const VadModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config.silero_vad.model); + Init(buf.data(), buf.size()); + + sample_rate_ = config.silero_vad.sample_rate; + min_silence_samples_ = + sample_rate_ * config_.silero_vad.min_silence_duration; + + min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration; + } + + void Reset() { + // 2 - number of LSTM layer + // 1 - batch size + // 64 - hidden dim + std::array shape{2, 1, 64}; + + Ort::Value h = + Ort::Value::CreateTensor(allocator_, shape.data(), shape.size()); + + Ort::Value c = + Ort::Value::CreateTensor(allocator_, shape.data(), shape.size()); + + Fill(&h, 0); + Fill(&c, 0); + + states_.clear(); + + states_.reserve(2); + states_.push_back(std::move(h)); + states_.push_back(std::move(c)); + + triggered_ = false; + current_sample_ = 0; + temp_start_ = 0; + temp_end_ = 0; + } + + bool IsSpeech(const float *samples, int32_t n) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape = {1, n}; + + Ort::Value x = + Ort::Value::CreateTensor(memory_info, const_cast(samples), n, + x_shape.data(), x_shape.size()); + + int64_t sr_shape = 1; + Ort::Value sr = + Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); + + std::array inputs = {std::move(x), std::move(sr), + std::move(states_[0]), + std::move(states_[1])}; + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + states_[0] = std::move(out[1]); + states_[1] = std::move(out[2]); + + float prob = out[0].GetTensorData()[0]; + + float threshold = config_.silero_vad.threshold; + + current_sample_ += config_.silero_vad.window_size; + + if (prob > threshold && temp_end_ != 0) { + temp_end_ = 0; + } + + if (prob > threshold && temp_start_ == 0) { + // start speaking, but we constraint that it must satisfy + // min_speech_duration + temp_start_ = current_sample_; + return false; + } + + if (prob > threshold && temp_start_ != 0 && !triggered_) { + if (current_sample_ - temp_start_ < min_speech_samples_) { + return false; + } + + triggered_ = true; + + return true; + } - void Reset() {} + if ((prob < threshold) && !triggered_) { + // silence + return false; + } - bool IsSpeech(const float *samples, int32_t n) { return true; } + if ((prob > threshold - 0.15) && triggered_) { + // speaking + return true; + } + + if ((prob > threshold) && !triggered_) { + // start speaking + triggered_ = true; + + return true; + } + + if ((prob < threshold) && triggered_) { + // stop to speak + if (temp_end_ == 0) { + temp_end_ = current_sample_; + } + + if (current_sample_ - temp_end_ < min_silence_samples_) { + // continue speaking + return true; + } + // stopped speaking + temp_start_ = 0; + temp_end_ = 0; + triggered_ = false; + return false; + } + + return false; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + Check(); + + Reset(); + } + + void Check() { + if (input_names_.size() != 4) { + SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", + static_cast(input_names_.size())); + exit(-1); + } + + if (input_names_[0] != "input") { + SHERPA_ONNX_LOGE("Input[0]: %s. Expected: input", + input_names_[0].c_str()); + exit(-1); + } + + if (input_names_[1] != "sr") { + SHERPA_ONNX_LOGE("Input[1]: %s. Expected: sr", input_names_[1].c_str()); + exit(-1); + } + + if (input_names_[2] != "h") { + SHERPA_ONNX_LOGE("Input[2]: %s. Expected: h", input_names_[2].c_str()); + exit(-1); + } + + if (input_names_[3] != "c") { + SHERPA_ONNX_LOGE("Input[3]: %s. Expected: c", input_names_[3].c_str()); + exit(-1); + } + + // Now for outputs + if (output_names_.size() != 3) { + SHERPA_ONNX_LOGE("Expect 3 outputs. Given: %d", + static_cast(output_names_.size())); + exit(-1); + } + + if (output_names_[0] != "output") { + SHERPA_ONNX_LOGE("Output[0]: %s. Expected: output", + output_names_[0].c_str()); + exit(-1); + } + + if (output_names_[1] != "hn") { + SHERPA_ONNX_LOGE("Output[1]: %s. Expected: sr", output_names_[1].c_str()); + exit(-1); + } + + if (output_names_[2] != "cn") { + SHERPA_ONNX_LOGE("Output[2]: %s. Expected: sr", output_names_[2].c_str()); + exit(-1); + } + } private: VadModelConfig config_; + + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + std::vector states_; + int64_t sample_rate_; + int32_t min_silence_samples_; + int32_t min_speech_samples_; + + bool triggered_ = false; + int32_t current_sample_ = 0; + int32_t temp_start_ = 0; + int32_t temp_end_ = 0; }; SileroVadModel::SileroVadModel(const VadModelConfig &config) diff --git a/sherpa-onnx/csrc/vad-model-config.cc b/sherpa-onnx/csrc/vad-model-config.cc index 39d45e5d6..3187285fe 100644 --- a/sherpa-onnx/csrc/vad-model-config.cc +++ b/sherpa-onnx/csrc/vad-model-config.cc @@ -18,6 +18,9 @@ void VadModelConfig::Register(ParseOptions *po) { po->Register("vad-provider", &provider, "Specify a provider to run the VAD model. Supported values: " "cpu, cuda, coreml"); + + po->Register("vad-debug", &debug, + "true to display debug information when loading vad models"); } bool VadModelConfig::Validate() const { return silero_vad.Validate(); } @@ -28,7 +31,8 @@ std::string VadModelConfig::ToString() const { os << "VadModelConfig("; os << "silero_vad=" << silero_vad.ToString() << ", "; os << "num_threads=" << num_threads << ", "; - os << "provider=\"" << provider << "\")"; + os << "provider=\"" << provider << "\", "; + os << "debug=" << (debug ? "True" : "False") << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/vad-model-config.h b/sherpa-onnx/csrc/vad-model-config.h index 3bfce0592..e7ba58c30 100644 --- a/sherpa-onnx/csrc/vad-model-config.h +++ b/sherpa-onnx/csrc/vad-model-config.h @@ -17,11 +17,17 @@ struct VadModelConfig { int32_t num_threads = 1; std::string provider = "cpu"; + // true to show debug information when loading models + bool debug = false; + VadModelConfig() = default; VadModelConfig(const SileroVadModelConfig &silero_vad, int32_t num_threads, - const std::string &provider) - : silero_vad(silero_vad), num_threads(num_threads), provider(provider) {} + const std::string &provider, bool debug) + : silero_vad(silero_vad), + num_threads(num_threads), + provider(provider), + debug(debug) {} void Register(ParseOptions *po); bool Validate() const;