From 3d8516930f2d27ee2d8a3f336ec31da81f1308cd Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 16 Jan 2024 10:14:04 +0800 Subject: [PATCH] Fix code style --- .github/scripts/test-kws.sh | 4 ++-- sherpa-onnx/csrc/CMakeLists.txt | 2 +- .../csrc/keyword-spotter-transducer-impl.h | 14 ++++++------- sherpa-onnx/csrc/online-stream.cc | 14 ++++++------- sherpa-onnx/csrc/online-stream.h | 8 ++++---- ...transducer-modified-beam-search-decoder.cc | 8 ++++---- ...coder.cc => transducer-keyword-decoder.cc} | 19 +++++++++--------- ...decoder.h => transducer-keyword-decoder.h} | 20 +++++++++---------- 8 files changed, 44 insertions(+), 45 deletions(-) rename sherpa-onnx/csrc/{transducer-keywords-decoder.cc => transducer-keyword-decoder.cc} (92%) rename sherpa-onnx/csrc/{transducer-keywords-decoder.h => transducer-keyword-decoder.h} (68%) diff --git a/.github/scripts/test-kws.sh b/.github/scripts/test-kws.sh index 232deb6aa..710a193fc 100755 --- a/.github/scripts/test-kws.sh +++ b/.github/scripts/test-kws.sh @@ -36,7 +36,7 @@ time $EXE \ --keywords-file=$repo/test_wavs/test_keywords.txt \ --max-active-paths=4 \ --num-threads=4 \ - $repo/test_wavs/3.wav $reop/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav + $repo/test_wavs/3.wav $repo/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav rm -rf $repo @@ -63,6 +63,6 @@ time $EXE \ --keywords-file=$repo/test_wavs/test_keywords.txt \ --max-active-paths=4 \ --num-threads=4 \ - $repo/test_wavs/0.wav $reop/test_wavs/1.wav + $repo/test_wavs/0.wav $repo/test_wavs/1.wav rm -rf $repo diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 7fe5417f4..efb8e5d2b 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -89,7 +89,7 @@ set(sources stack.cc symbol-table.cc text-utils.cc - transducer-keywords-decoder.cc + transducer-keyword-decoder.cc transpose.cc unbind.cc utils.cc diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index fb1d02257..ce708ccd4 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -25,12 +25,12 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-transducer-model.h" #include "sherpa-onnx/csrc/symbol-table.h" -#include "sherpa-onnx/csrc/transducer-keywords-decoder.h" +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" #include "sherpa-onnx/csrc/utils.h" namespace sherpa_onnx { -static KeywordResult Convert(const TransducerKeywordsResult &src, +static KeywordResult Convert(const TransducerKeywordResult &src, const SymbolTable &sym_table, float frame_shift_ms, int32_t subsampling_factor, int32_t frames_since_start) { @@ -74,7 +74,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { InitKeywords(); - decoder_ = std::make_unique( + decoder_ = std::make_unique( model_.get(), config_.max_active_paths, config_.num_trailing_blanks, unk_id_); } @@ -91,7 +91,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { InitKeywords(mgr); - decoder_ = std::make_unique( + decoder_ = std::make_unique( model_.get(), config_.max_active_paths, config_.num_trailing_blanks, unk_id_); } @@ -188,7 +188,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { int32_t feature_dim = ss[0]->FeatureDim(); - std::vector results(n); + std::vector results(n); std::vector features_vec(n * chunk_size * feature_dim); std::vector> states_vec(n); std::vector all_processed_frames(n); @@ -244,7 +244,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { } KeywordResult GetResult(OnlineStream *s) const override { - TransducerKeywordsResult decoder_result = s->GetKeywordResult(true); + TransducerKeywordResult decoder_result = s->GetKeywordResult(true); // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; @@ -313,7 +313,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { std::vector keywords_; ContextGraphPtr keywords_graph_; std::unique_ptr model_; - std::unique_ptr decoder_; + std::unique_ptr decoder_; SymbolTable sym_; int32_t unk_id_ = -1; }; diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 6a74491ff..aaddfb545 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -51,10 +51,10 @@ class OnlineStream::Impl { OnlineTransducerDecoderResult &GetResult() { return result_; } - void SetKeywordResult(const TransducerKeywordsResult &r) { + void SetKeywordResult(const TransducerKeywordResult &r) { keyword_result_ = r; } - TransducerKeywordsResult &GetKeywordResult(bool remove_duplicates) { + TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) { if (remove_duplicates) { if (!prev_keyword_result_.timestamps.empty() && !keyword_result_.timestamps.empty() && @@ -112,9 +112,9 @@ class OnlineStream::Impl { int32_t start_frame_index_ = 0; // never reset int32_t segment_ = 0; OnlineTransducerDecoderResult result_; - TransducerKeywordsResult prev_keyword_result_; - TransducerKeywordsResult keyword_result_; - TransducerKeywordsResult empty_keyword_result_; + TransducerKeywordResult prev_keyword_result_; + TransducerKeywordResult keyword_result_; + TransducerKeywordResult empty_keyword_result_; OnlineCtcDecoderResult ctc_result_; std::vector states_; // states for transducer or ctc models std::vector paraformer_feat_cache_; @@ -171,11 +171,11 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { return impl_->GetResult(); } -void OnlineStream::SetKeywordResult(const TransducerKeywordsResult &r) { +void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) { impl_->SetKeywordResult(r); } -TransducerKeywordsResult &OnlineStream::GetKeywordResult( +TransducerKeywordResult &OnlineStream::GetKeywordResult( bool remove_duplicates /*=false*/) { return impl_->GetKeywordResult(remove_duplicates); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index 3cfab81a6..175a5f719 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -14,11 +14,11 @@ #include "sherpa-onnx/csrc/online-ctc-decoder.h" #include "sherpa-onnx/csrc/online-paraformer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" -#include "sherpa-onnx/csrc/transducer-keywords-decoder.h" +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" namespace sherpa_onnx { -class TransducerKeywordsResult; +class TransducerKeywordResult; class OnlineStream { public: explicit OnlineStream(const FeatureExtractorConfig &config = {}, @@ -78,8 +78,8 @@ class OnlineStream { void SetResult(const OnlineTransducerDecoderResult &r); OnlineTransducerDecoderResult &GetResult(); - void SetKeywordResult(const TransducerKeywordsResult &r); - TransducerKeywordsResult &GetKeywordResult(bool remove_duplicates = false); + void SetKeywordResult(const TransducerKeywordResult &r); + TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false); void SetCtcResult(const OnlineCtcDecoderResult &r); OnlineCtcDecoderResult &GetCtcResult(); diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index 476428aa9..2694b4bd1 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( encoder_out.GetTensorTypeAndShapeInfo().GetShape(); if (encoder_out_shape[0] != result->size()) { - fprintf(stderr, - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", - static_cast(encoder_out_shape[0]), - static_cast(result->size())); + SHERPA_ONNX_LOGE( + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", + static_cast(encoder_out_shape[0]), + static_cast(result->size())); exit(-1); } diff --git a/sherpa-onnx/csrc/transducer-keywords-decoder.cc b/sherpa-onnx/csrc/transducer-keyword-decoder.cc similarity index 92% rename from sherpa-onnx/csrc/transducer-keywords-decoder.cc rename to sherpa-onnx/csrc/transducer-keyword-decoder.cc index 2bb33d33a..ef8314ed8 100644 --- a/sherpa-onnx/csrc/transducer-keywords-decoder.cc +++ b/sherpa-onnx/csrc/transducer-keyword-decoder.cc @@ -2,8 +2,6 @@ // // Copyright (c) 2023-2024 Xiaomi Corporation -#include "sherpa-onnx/csrc/transducer-keywords-decoder.h" - #include #include #include @@ -11,13 +9,14 @@ #include "sherpa-onnx/csrc/log.h" #include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" namespace sherpa_onnx { -TransducerKeywordsResult TransducerKeywordsDecoder::GetEmptyResult() const { +TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const { int32_t context_size = model_->ContextSize(); int32_t blank_id = 0; // always 0 - TransducerKeywordsResult r; + TransducerKeywordResult r; std::vector blanks(context_size, -1); blanks.back() = blank_id; @@ -26,17 +25,17 @@ TransducerKeywordsResult TransducerKeywordsDecoder::GetEmptyResult() const { return r; } -void TransducerKeywordsDecoder::Decode( +void TransducerKeywordDecoder::Decode( Ort::Value encoder_out, OnlineStream **ss, - std::vector *result) { + std::vector *result) { std::vector encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); if (encoder_out_shape[0] != result->size()) { - fprintf(stderr, - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", - static_cast(encoder_out_shape[0]), - static_cast(result->size())); + SHERPA_ONNX_LOGE( + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", + static_cast(encoder_out_shape[0]), + static_cast(result->size())); exit(-1); } diff --git a/sherpa-onnx/csrc/transducer-keywords-decoder.h b/sherpa-onnx/csrc/transducer-keyword-decoder.h similarity index 68% rename from sherpa-onnx/csrc/transducer-keywords-decoder.h rename to sherpa-onnx/csrc/transducer-keyword-decoder.h index 63bfa6176..a352abcc8 100644 --- a/sherpa-onnx/csrc/transducer-keywords-decoder.h +++ b/sherpa-onnx/csrc/transducer-keyword-decoder.h @@ -2,8 +2,8 @@ // // Copyright (c) 2023-2024 Xiaomi Corporation -#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORDS_DECODER_H_ -#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORDS_DECODER_H_ +#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_ +#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_ #include #include @@ -14,7 +14,7 @@ namespace sherpa_onnx { -struct TransducerKeywordsResult { +struct TransducerKeywordResult { /// Number of frames after subsampling we have decoded so far int32_t frame_offset = 0; @@ -34,20 +34,20 @@ struct TransducerKeywordsResult { Hypotheses hyps; }; -class TransducerKeywordsDecoder { +class TransducerKeywordDecoder { public: - TransducerKeywordsDecoder(OnlineTransducerModel *model, - int32_t max_active_paths, - int32_t num_trailing_blanks, int32_t unk_id) + TransducerKeywordDecoder(OnlineTransducerModel *model, + int32_t max_active_paths, + int32_t num_trailing_blanks, int32_t unk_id) : model_(model), max_active_paths_(max_active_paths), num_trailing_blanks_(num_trailing_blanks), unk_id_(unk_id) {} - TransducerKeywordsResult GetEmptyResult() const; + TransducerKeywordResult GetEmptyResult() const; void Decode(Ort::Value encoder_out, OnlineStream **ss, - std::vector *result); + std::vector *result); private: OnlineTransducerModel *model_; // Not owned @@ -59,4 +59,4 @@ class TransducerKeywordsDecoder { } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORDS_DECODER_H_ +#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_