From 4f04fb843b32160d8adb622ce19c0d2600a00c80 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Wed, 13 Mar 2024 16:35:48 +0100 Subject: [PATCH] surface dithering constant, 0.0 disables dithering - currently, dithering is not yet implemented in https://github.com/csukuangfj/kaldi-native-fbank - i can port it there from kaldi --- cmake/kaldi-native-fbank.cmake | 1 + sherpa-onnx/csrc/features.cc | 14 +++++++++++--- sherpa-onnx/csrc/features.h | 8 ++++++++ sherpa-onnx/csrc/keyword-spotter-transducer-impl.h | 4 ++++ .../csrc/online-recognizer-transducer-impl.h | 7 ++++--- .../csrc/online-zipformer2-transducer-model.h | 4 +++- sherpa-onnx/python/csrc/features.cc | 6 ++++-- .../python/sherpa_onnx/online_recognizer.py | 7 +++++++ 8 files changed, 42 insertions(+), 9 deletions(-) diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index de2b854333..040d180a8e 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -1,6 +1,7 @@ function(download_kaldi_native_fbank) include(FetchContent) + # TODO: update is required, so that dithering works... (it was missing) set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.7.tar.gz") set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.7.tar.gz") set(kaldi_native_fbank_HASH "SHA256=e78fd9d481d83d7d6d1be0012752e6531cb614e030558a3491e3c033cb8e0e4e") diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 90fc55772e..7e510361a8 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -30,7 +30,14 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { "Low cutoff frequency for mel bins"); po->Register("high-freq", &high_freq, - "High cutoff frequency for mel bins (if <= 0, offset from Nyquist)"); + "High cutoff frequency for mel bins " + "(if <= 0, offset from Nyquist)"); + + po->Register("dither", &dither, + "Dithering constant (0.0 means no dither). " + "By default the audio samples are in range [-1,+1], " + "so 0.00003 is a good value, " + "equivalent to the default 1.0 from kaldi"); } std::string FeatureExtractorConfig::ToString() const { @@ -40,7 +47,8 @@ std::string FeatureExtractorConfig::ToString() const { os << "sampling_rate=" << sampling_rate << ", "; os << "feature_dim=" << feature_dim << ", "; os << "low_freq=" << low_freq << ", "; - os << "high_freq=" << high_freq << ")"; + os << "high_freq=" << high_freq << ", "; + os << "dither=" << dither << ")"; return os.str(); } @@ -48,7 +56,7 @@ std::string FeatureExtractorConfig::ToString() const { class FeatureExtractor::Impl { public: explicit Impl(const FeatureExtractorConfig &config) : config_(config) { - opts_.frame_opts.dither = 0; + opts_.frame_opts.dither = config.dither; opts_.frame_opts.snip_edges = config.snip_edges; opts_.frame_opts.samp_freq = config.sampling_rate; opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index d03fbaa0ca..68f4348ed0 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -34,6 +34,14 @@ struct FeatureExtractorConfig { // https://github.com/k2-fsa/sherpa-onnx/issues/514 float high_freq = -400.0f; + // dithering constant, useful for signals with hard-zeroes in non-speech parts + // this prevents large negative values in log-mel filterbanks + // + // In k2, audio samples are in range [-1..+1], in kaldi the range was [-32k..+32k], + // so the value 0.00003 is equivalent to kaldi default 1.0 + // + float dither = 0.0f; // dithering disabled by default + // Set internally by some models, e.g., paraformer sets it to false. // This parameter is not exposed to users from the commandline // If true, the feature extractor expects inputs to be normalized to diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index ef22a99848..a8a8242ccb 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { unk_id_ = sym_[""]; } + model_->SetFeatureDim(config.feat_config.feature_dim); + InitKeywords(); decoder_ = std::make_unique( @@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { unk_id_ = sym_[""]; } + model_->SetFeatureDim(config.feat_config.feature_dim); + InitKeywords(mgr); decoder_ = std::make_unique( diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 3fc018daa4..0fa3acac4e 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -86,13 +86,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { - - model_->SetFeatureDim(config.feat_config.feature_dim); - if (sym_.contains("")) { unk_id_ = sym_[""]; } + model_->SetFeatureDim(config.feat_config.feature_dim); + if (config.decoding_method == "modified_beam_search") { if (!config_.hotwords_file.empty()) { InitHotwords(); @@ -126,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { unk_id_ = sym_[""]; } + model_->SetFeatureDim(config.feat_config.feature_dim); + if (config.decoding_method == "modified_beam_search") { #if 0 // TODO(fangjun): Implement it diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h index 4c10cba7dc..acad451702 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h @@ -37,7 +37,9 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { std::vector GetEncoderInitStates() override; - void SetFeatureDim(int32_t feature_dim) override { feature_dim_ = feature_dim; } + void SetFeatureDim(int32_t feature_dim) override { + feature_dim_ = feature_dim; + } std::pair> RunEncoder( Ort::Value features, std::vector states, diff --git a/sherpa-onnx/python/csrc/features.cc b/sherpa-onnx/python/csrc/features.cc index 106a8e0d37..333c6b6758 100644 --- a/sherpa-onnx/python/csrc/features.cc +++ b/sherpa-onnx/python/csrc/features.cc @@ -11,15 +11,17 @@ namespace sherpa_onnx { static void PybindFeatureExtractorConfig(py::module *m) { using PyClass = FeatureExtractorConfig; py::class_(*m, "FeatureExtractorConfig") - .def(py::init(), + .def(py::init(), py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80, py::arg("low_freq") = 20.0f, - py::arg("high_freq") = -400.0f) + py::arg("high_freq") = -400.0f, + py::arg("dither") = 0.0f) .def_readwrite("sampling_rate", &PyClass::sampling_rate) .def_readwrite("feature_dim", &PyClass::feature_dim) .def_readwrite("low_freq", &PyClass::low_freq) .def_readwrite("high_freq", &PyClass::high_freq) + .def_readwrite("dither", &PyClass::high_freq) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index a5350d90b6..1050433999 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -43,6 +43,7 @@ def from_transducer( feature_dim: int = 80, low_freq: float = 20.0, high_freq: float = -400.0, + dither: float = 0.0, enable_endpoint_detection: bool = False, rule1_min_trailing_silence: float = 2.4, rule2_min_trailing_silence: float = 1.2, @@ -87,6 +88,11 @@ def from_transducer( high_freq: High cutoff frequency for mel bins in feature extraction (if <= 0, offset from Nyquist) + dither: + Dithering constant (0.0 means no dither). + By default the audio samples are in range [-1,+1], + so dithering constant 0.00003 is a good value, + equivalent to the default 1.0 from kaldi enable_endpoint_detection: True to enable endpoint detection. False to disable endpoint detection. @@ -149,6 +155,7 @@ def from_transducer( feature_dim=feature_dim, low_freq=low_freq, high_freq=high_freq, + dither=dither, ) endpoint_config = EndpointConfig(