Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Dec 29, 2023
1 parent 539dd71 commit d6124dc
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 64 deletions.
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
}

KeywordResult GetResult(OnlineStream *s) const override {
TransducerKeywordsResult decoder_result = s->GetKeywordResult();
TransducerKeywordsResult decoder_result = s->GetKeywordResult(true);

// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
Expand Down
23 changes: 20 additions & 3 deletions sherpa-onnx/csrc/online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,21 @@ class OnlineStream::Impl {
void SetKeywordResult(const TransducerKeywordsResult &r) {
keyword_result_ = r;
}
TransducerKeywordsResult &GetKeywordResult() { return keyword_result_; }
TransducerKeywordsResult &GetKeywordResult(bool remove_duplicates) {
if (remove_duplicates) {
if (!prev_keyword_result_.timestamps.empty() &&
!keyword_result_.timestamps.empty() &&
keyword_result_.timestamps[0] <=
prev_keyword_result_.timestamps.back()) {
return empty_keyword_result_;
} else {
prev_keyword_result_ = keyword_result_;
}
return keyword_result_;
} else {
return keyword_result_;
}
}

OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }

Expand Down Expand Up @@ -98,7 +112,9 @@ class OnlineStream::Impl {
int32_t start_frame_index_ = 0; // never reset
int32_t segment_ = 0;
OnlineTransducerDecoderResult result_;
TransducerKeywordsResult prev_keyword_result_;
TransducerKeywordsResult keyword_result_;
TransducerKeywordsResult empty_keyword_result_;
OnlineCtcDecoderResult ctc_result_;
std::vector<Ort::Value> states_; // states for transducer or ctc models
std::vector<float> paraformer_feat_cache_;
Expand Down Expand Up @@ -159,8 +175,9 @@ void OnlineStream::SetKeywordResult(const TransducerKeywordsResult &r) {
impl_->SetKeywordResult(r);
}

TransducerKeywordsResult &OnlineStream::GetKeywordResult() {
return impl_->GetKeywordResult();
TransducerKeywordsResult &OnlineStream::GetKeywordResult(
bool remove_duplicates /*=false*/) {
return impl_->GetKeywordResult(remove_duplicates);
}

OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class OnlineStream {
OnlineTransducerDecoderResult &GetResult();

void SetKeywordResult(const TransducerKeywordsResult &r);
TransducerKeywordsResult &GetKeywordResult();
TransducerKeywordsResult &GetKeywordResult(bool remove_duplicates = false);

void SetCtcResult(const OnlineCtcDecoderResult &r);
OnlineCtcDecoderResult &GetCtcResult();
Expand Down
69 changes: 10 additions & 59 deletions sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include <stdio.h>

#include <chrono> // NOLINT
#include <iomanip>
#include <iostream>
#include <string>
Expand All @@ -18,8 +17,7 @@

typedef struct {
std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
float duration;
float elapsed_seconds;
std::string filename;
} Stream;

int main(int32_t argc, char *argv[]) {
Expand All @@ -35,37 +33,12 @@ int main(int32_t argc, char *argv[]) {
--joiner=/path/to/joiner.onnx \
--provider=cpu \
--num-threads=2 \
--decoding-method=greedy_search \
--keywords-file=keywords.txt \
/path/to/foo.wav [bar.wav foobar.wav ...]
(2) Streaming zipformer2 CTC
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
./bin/sherpa-onnx \
--debug=1 \
--zipformer2-ctc-model=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav
(3) Streaming paraformer
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
./bin/sherpa-onnx \
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \
--paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
Note: It supports decoding multiple files in batches
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Valid values for provider: cpu (default), cuda, coreml.
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
Expand Down Expand Up @@ -97,9 +70,6 @@ for a list of pre-trained models to download.

std::vector<Stream> ss;

const auto begin = std::chrono::steady_clock::now();
std::vector<float> durations;

for (int32_t i = 1; i <= po.NumArgs(); ++i) {
const std::string wav_filename = po.GetArg(i);
int32_t sampling_rate = -1;
Expand All @@ -113,8 +83,6 @@ for a list of pre-trained models to download.
return -1;
}

const float duration = samples.size() / static_cast<float>(sampling_rate);

auto s = keywordspotter.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

Expand All @@ -125,7 +93,7 @@ for a list of pre-trained models to download.

// Call InputFinished() to indicate that no audio samples are available
s->InputFinished();
ss.push_back({std::move(s), duration, 0});
ss.push_back({std::move(s), wav_filename});
}

std::vector<sherpa_onnx::OnlineStream *> ready_streams;
Expand All @@ -135,37 +103,20 @@ for a list of pre-trained models to download.
const auto p_ss = s.online_stream.get();
if (keywordspotter.IsReady(p_ss)) {
ready_streams.push_back(p_ss);
} else if (s.elapsed_seconds == 0) {
const auto end = std::chrono::steady_clock::now();
const float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
s.elapsed_seconds = elapsed_seconds;
}
std::ostringstream os;
const auto r = keywordspotter.GetResult(p_ss);
if (!r.keyword.empty()) {
os << s.filename << "\n";
os << r.AsJsonString() << "\n\n";
std::cerr << os.str();
}
}

if (ready_streams.empty()) {
break;
}

keywordspotter.DecodeStreams(ready_streams.data(), ready_streams.size());
}

std::ostringstream os;
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
const auto &s = ss[i - 1];
const float rtf = s.elapsed_seconds / s.duration;

os << po.GetArg(i) << "\n";
os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
<< ", Real time factor (RTF): " << rtf << "\n";
const auto r = keywordspotter.GetResult(s.online_stream.get());
os << r.keyword << "\n";
os << r.AsJsonString() << "\n\n";
}

std::cerr << os.str();

return 0;
}

0 comments on commit d6124dc

Please sign in to comment.