diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
index d5bebd58f..024f236ff 100644
--- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
+++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
@@ -166,7 +166,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
   }
 
   KeywordResult GetResult(OnlineStream *s) const override {
-    TransducerKeywordsResult decoder_result = s->GetKeywordResult();
+    TransducerKeywordsResult decoder_result = s->GetKeywordResult(true);
 
     // TODO(fangjun): Remember to change these constants if needed
     int32_t frame_shift_ms = 10;
diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc
index e92c02531..6a74491ff 100644
--- a/sherpa-onnx/csrc/online-stream.cc
+++ b/sherpa-onnx/csrc/online-stream.cc
@@ -54,7 +54,21 @@ class OnlineStream::Impl {
   void SetKeywordResult(const TransducerKeywordsResult &r) {
     keyword_result_ = r;
   }
-  TransducerKeywordsResult &GetKeywordResult() { return keyword_result_; }
+  // If remove_duplicates is true, a result that does not start after the
+  // end of the previously returned one is considered a repeat and an empty
+  // result is returned instead.
+  TransducerKeywordsResult &GetKeywordResult(bool remove_duplicates) {
+    if (!remove_duplicates) return keyword_result_;
+
+    const auto &prev_ts = prev_keyword_result_.timestamps;
+    const auto &cur_ts = keyword_result_.timestamps;
+    if (!prev_ts.empty() && !cur_ts.empty() && cur_ts[0] <= prev_ts.back()) {
+      return empty_keyword_result_;
+    }
+
+    prev_keyword_result_ = keyword_result_;
+    return keyword_result_;
+  }
 
   OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
 
@@ -98,7 +112,9 @@ class OnlineStream::Impl {
   int32_t start_frame_index_ = 0;  // never reset
   int32_t segment_ = 0;
   OnlineTransducerDecoderResult result_;
+  TransducerKeywordsResult prev_keyword_result_;
   TransducerKeywordsResult keyword_result_;
+  TransducerKeywordsResult empty_keyword_result_;
   OnlineCtcDecoderResult ctc_result_;
   std::vector states_;  // states for transducer or ctc models
   std::vector paraformer_feat_cache_;
@@ -159,8 +175,9 @@ void OnlineStream::SetKeywordResult(const TransducerKeywordsResult &r) {
   impl_->SetKeywordResult(r);
 }
 
-TransducerKeywordsResult &OnlineStream::GetKeywordResult() {
-  return impl_->GetKeywordResult();
+TransducerKeywordsResult &OnlineStream::GetKeywordResult(
+    bool remove_duplicates /*=false*/) {
+  return impl_->GetKeywordResult(remove_duplicates);
 }
 
 OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h
index df65ca669..3cfab81a6 100644
--- a/sherpa-onnx/csrc/online-stream.h
+++ b/sherpa-onnx/csrc/online-stream.h
@@ -79,7 +79,7 @@ class OnlineStream {
   OnlineTransducerDecoderResult &GetResult();
 
   void SetKeywordResult(const TransducerKeywordsResult &r);
-  TransducerKeywordsResult &GetKeywordResult();
+  TransducerKeywordsResult &GetKeywordResult(bool remove_duplicates = false);
 
   void SetCtcResult(const OnlineCtcDecoderResult &r);
   OnlineCtcDecoderResult &GetCtcResult();
diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
index 076d43eef..f28239683 100644
--- a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
@@ -4,7 +4,6 @@
 
 #include
 
-#include // NOLINT
 #include
 #include
 #include
@@ -18,8 +17,7 @@
 
 typedef struct {
   std::unique_ptr online_stream;
-  float duration;
-  float elapsed_seconds;
+  std::string filename;
 } Stream;
 
 int main(int32_t argc, char *argv[]) {
@@ -35,37 +33,12 @@ int main(int32_t argc, char *argv[]) {
     --joiner=/path/to/joiner.onnx \
     --provider=cpu \
     --num-threads=2 \
-    --decoding-method=greedy_search \
+    --keywords-file=keywords.txt \
     /path/to/foo.wav [bar.wav foobar.wav ...]
 
-(2) Streaming zipformer2 CTC
-
-  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
-  tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
-
-  ./bin/sherpa-onnx \
-    --debug=1 \
-    --zipformer2-ctc-model=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
-    --tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
-    ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
-    ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
-    ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav
-
-(3) Streaming paraformer
-
-  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
-  tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
-
-  ./bin/sherpa-onnx \
-    --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
-    --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \
-    --paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \
-    ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
-
 Note: It supports decoding multiple files in batches
 
 Default value for num_threads is 2.
-Valid values for decoding_method: greedy_search (default), modified_beam_search.
 Valid values for provider: cpu (default), cuda, coreml.
 foo.wav should be of single channel, 16-bit PCM encoded wave file; its sampling
 rate can be arbitrary and does not need to be 16kHz.
@@ -97,9 +70,6 @@ for a list of pre-trained models to download.
 
   std::vector ss;
 
-  const auto begin = std::chrono::steady_clock::now();
-  std::vector durations;
-
   for (int32_t i = 1; i <= po.NumArgs(); ++i) {
     const std::string wav_filename = po.GetArg(i);
     int32_t sampling_rate = -1;
@@ -113,8 +83,6 @@ for a list of pre-trained models to download.
       return -1;
     }
 
-    const float duration = samples.size() / static_cast(sampling_rate);
-
     auto s = keywordspotter.CreateStream();
     s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
 
@@ -125,7 +93,7 @@ for a list of pre-trained models to download.
     // Call InputFinished() to indicate that no audio samples are available
     s->InputFinished();
 
-    ss.push_back({std::move(s), duration, 0});
+    ss.push_back({std::move(s), wav_filename});
   }
 
   std::vector ready_streams;
@@ -135,37 +103,20 @@ for a list of pre-trained models to download.
       const auto p_ss = s.online_stream.get();
       if (keywordspotter.IsReady(p_ss)) {
         ready_streams.push_back(p_ss);
-      } else if (s.elapsed_seconds == 0) {
-        const auto end = std::chrono::steady_clock::now();
-        const float elapsed_seconds =
-            std::chrono::duration_cast(end - begin)
-                .count() /
-            1000.;
-        s.elapsed_seconds = elapsed_seconds;
+      }
+
+      // Poll on every pass; print only when a keyword is reported for this
+      // stream, together with the file it came from.
+      const auto r = keywordspotter.GetResult(p_ss);
+      if (!r.keyword.empty()) {
+        std::cerr << s.filename << "\n" << r.AsJsonString() << "\n\n";
       }
     }
 
     if (ready_streams.empty()) {
       break;
     }
-
     keywordspotter.DecodeStreams(ready_streams.data(), ready_streams.size());
   }
-
-  std::ostringstream os;
-  for (int32_t i = 1; i <= po.NumArgs(); ++i) {
-    const auto &s = ss[i - 1];
-    const float rtf = s.elapsed_seconds / s.duration;
-
-    os << po.GetArg(i) << "\n";
-    os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
-       << ", Real time factor (RTF): " << rtf << "\n";
-    const auto r = keywordspotter.GetResult(s.online_stream.get());
-    os << r.keyword << "\n";
-    os << r.AsJsonString() << "\n\n";
-  }
-
-  std::cerr << os.str();
-
   return 0;
 }