diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index e9bd54cd3..7b9223c62 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -74,6 +74,7 @@ set(sources transpose.cc unbind.cc utils.cc + vad-model-config.cc wave-reader.cc ) diff --git a/sherpa-onnx/csrc/silero-vad-model-config.cc b/sherpa-onnx/csrc/silero-vad-model-config.cc index 9a0034874..1cf5df747 100644 --- a/sherpa-onnx/csrc/silero-vad-model-config.cc +++ b/sherpa-onnx/csrc/silero-vad-model-config.cc @@ -9,7 +9,7 @@ namespace sherpa_onnx { -void SilerVadModelConfig::Register(ParseOptions *po) { +void SileroVadModelConfig::Register(ParseOptions *po) { po->Register("silero-vad-model", &model, "Path to silero VAD ONNX model."); po->Register("silero-vad-prob", &prob, @@ -32,7 +32,7 @@ void SilerVadModelConfig::Register(ParseOptions *po) { "perfomance!"); } -bool SilerVadModelConfig::Validate() const { +bool SileroVadModelConfig::Validate() const { if (!FileExists(model)) { SHERPA_ONNX_LOGE("Silero vad model file %s does not exist", model.c_str()); return false; @@ -53,7 +53,7 @@ bool SilerVadModelConfig::Validate() const { return true; } -std::string SilerVadModelConfig::ToString() const { +std::string SileroVadModelConfig::ToString() const { std::ostringstream os; os << "SilerVadModelConfig("; diff --git a/sherpa-onnx/csrc/silero-vad-model-config.h b/sherpa-onnx/csrc/silero-vad-model-config.h index 1091c786b..3c9653dd7 100644 --- a/sherpa-onnx/csrc/silero-vad-model-config.h +++ b/sherpa-onnx/csrc/silero-vad-model-config.h @@ -10,7 +10,7 @@ namespace sherpa_onnx { -struct SilerVadModelConfig { +struct SileroVadModelConfig { std::string model; // threshold to classify a segment as speech @@ -25,6 +25,8 @@ struct SilerVadModelConfig { // 256, 512, 768 samples for 800 Hz int window_size = 1536; // in samples + SileroVadModelConfig() = default; + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/vad-model-config.cc 
b/sherpa-onnx/csrc/vad-model-config.cc new file mode 100644 index 000000000..39d45e5d6 --- /dev/null +++ b/sherpa-onnx/csrc/vad-model-config.cc @@ -0,0 +1,36 @@ +// sherpa-onnx/csrc/vad-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/vad-model-config.h" + +#include <sstream> +#include <string> + +namespace sherpa_onnx { + +void VadModelConfig::Register(ParseOptions *po) { + silero_vad.Register(po); + + po->Register("vad-num-threads", &num_threads, + "Number of threads to run the VAD model"); + + po->Register("vad-provider", &provider, + "Specify a provider to run the VAD model. Supported values: " + "cpu, cuda, coreml"); +} + +bool VadModelConfig::Validate() const { return silero_vad.Validate(); } + +std::string VadModelConfig::ToString() const { + std::ostringstream os; + + os << "VadModelConfig("; + os << "silero_vad=" << silero_vad.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/vad-model-config.h b/sherpa-onnx/csrc/vad-model-config.h new file mode 100644 index 000000000..3bfce0592 --- /dev/null +++ b/sherpa-onnx/csrc/vad-model-config.h @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_ + +#include <string> + +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/silero-vad-model-config.h" + +namespace sherpa_onnx { + +struct VadModelConfig { + SileroVadModelConfig silero_vad; + + int32_t num_threads = 1; + std::string provider = "cpu"; + + VadModelConfig() = default; + + VadModelConfig(const SileroVadModelConfig &silero_vad, int32_t num_threads, + const std::string &provider) + : silero_vad(silero_vad), num_threads(num_threads), provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() 
const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_