Skip to content

Commit

Permalink
Fix code style
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Jan 16, 2024
1 parent 46ea698 commit 3d85169
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 45 deletions.
4 changes: 2 additions & 2 deletions .github/scripts/test-kws.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ time $EXE \
--keywords-file=$repo/test_wavs/test_keywords.txt \
--max-active-paths=4 \
--num-threads=4 \
$repo/test_wavs/3.wav $reop/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav
$repo/test_wavs/3.wav $repo/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav

rm -rf $repo

Expand All @@ -63,6 +63,6 @@ time $EXE \
--keywords-file=$repo/test_wavs/test_keywords.txt \
--max-active-paths=4 \
--num-threads=4 \
$repo/test_wavs/0.wav $reop/test_wavs/1.wav
$repo/test_wavs/0.wav $repo/test_wavs/1.wav

rm -rf $repo
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ set(sources
stack.cc
symbol-table.cc
text-utils.cc
transducer-keywords-decoder.cc
transducer-keyword-decoder.cc
transpose.cc
unbind.cc
utils.cc
Expand Down
14 changes: 7 additions & 7 deletions sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transducer-keywords-decoder.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
#include "sherpa-onnx/csrc/utils.h"

namespace sherpa_onnx {

static KeywordResult Convert(const TransducerKeywordsResult &src,
static KeywordResult Convert(const TransducerKeywordResult &src,
const SymbolTable &sym_table, float frame_shift_ms,
int32_t subsampling_factor,
int32_t frames_since_start) {
Expand Down Expand Up @@ -74,7 +74,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {

InitKeywords();

decoder_ = std::make_unique<TransducerKeywordsDecoder>(
decoder_ = std::make_unique<TransducerKeywordDecoder>(
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
unk_id_);
}
Expand All @@ -91,7 +91,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {

InitKeywords(mgr);

decoder_ = std::make_unique<TransducerKeywordsDecoder>(
decoder_ = std::make_unique<TransducerKeywordDecoder>(
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
unk_id_);
}
Expand Down Expand Up @@ -188,7 +188,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {

int32_t feature_dim = ss[0]->FeatureDim();

std::vector<TransducerKeywordsResult> results(n);
std::vector<TransducerKeywordResult> results(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);
Expand Down Expand Up @@ -244,7 +244,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
}

KeywordResult GetResult(OnlineStream *s) const override {
TransducerKeywordsResult decoder_result = s->GetKeywordResult(true);
TransducerKeywordResult decoder_result = s->GetKeywordResult(true);

// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
Expand Down Expand Up @@ -313,7 +313,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
std::vector<std::string> keywords_;
ContextGraphPtr keywords_graph_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<TransducerKeywordsDecoder> decoder_;
std::unique_ptr<TransducerKeywordDecoder> decoder_;
SymbolTable sym_;
int32_t unk_id_ = -1;
};
Expand Down
14 changes: 7 additions & 7 deletions sherpa-onnx/csrc/online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ class OnlineStream::Impl {

OnlineTransducerDecoderResult &GetResult() { return result_; }

void SetKeywordResult(const TransducerKeywordsResult &r) {
void SetKeywordResult(const TransducerKeywordResult &r) {
keyword_result_ = r;
}
TransducerKeywordsResult &GetKeywordResult(bool remove_duplicates) {
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) {
if (remove_duplicates) {
if (!prev_keyword_result_.timestamps.empty() &&
!keyword_result_.timestamps.empty() &&
Expand Down Expand Up @@ -112,9 +112,9 @@ class OnlineStream::Impl {
int32_t start_frame_index_ = 0; // never reset
int32_t segment_ = 0;
OnlineTransducerDecoderResult result_;
TransducerKeywordsResult prev_keyword_result_;
TransducerKeywordsResult keyword_result_;
TransducerKeywordsResult empty_keyword_result_;
TransducerKeywordResult prev_keyword_result_;
TransducerKeywordResult keyword_result_;
TransducerKeywordResult empty_keyword_result_;
OnlineCtcDecoderResult ctc_result_;
std::vector<Ort::Value> states_; // states for transducer or ctc models
std::vector<float> paraformer_feat_cache_;
Expand Down Expand Up @@ -171,11 +171,11 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}

void OnlineStream::SetKeywordResult(const TransducerKeywordsResult &r) {
void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) {
impl_->SetKeywordResult(r);
}

TransducerKeywordsResult &OnlineStream::GetKeywordResult(
TransducerKeywordResult &OnlineStream::GetKeywordResult(
bool remove_duplicates /*=false*/) {
return impl_->GetKeywordResult(remove_duplicates);
}
Expand Down
8 changes: 4 additions & 4 deletions sherpa-onnx/csrc/online-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/transducer-keywords-decoder.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"

namespace sherpa_onnx {

class TransducerKeywordsResult;
class TransducerKeywordResult;
class OnlineStream {
public:
explicit OnlineStream(const FeatureExtractorConfig &config = {},
Expand Down Expand Up @@ -78,8 +78,8 @@ class OnlineStream {
void SetResult(const OnlineTransducerDecoderResult &r);
OnlineTransducerDecoderResult &GetResult();

void SetKeywordResult(const TransducerKeywordsResult &r);
TransducerKeywordsResult &GetKeywordResult(bool remove_duplicates = false);
void SetKeywordResult(const TransducerKeywordResult &r);
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false);

void SetCtcResult(const OnlineCtcDecoderResult &r);
OnlineCtcDecoderResult &GetCtcResult();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
encoder_out.GetTensorTypeAndShapeInfo().GetShape();

if (encoder_out_shape[0] != result->size()) {
fprintf(stderr,
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
SHERPA_ONNX_LOGE(
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
exit(-1);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@
//
// Copyright (c) 2023-2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/transducer-keywords-decoder.h"

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"

namespace sherpa_onnx {

TransducerKeywordsResult TransducerKeywordsDecoder::GetEmptyResult() const {
TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
int32_t blank_id = 0; // always 0
TransducerKeywordsResult r;
TransducerKeywordResult r;
std::vector<int64_t> blanks(context_size, -1);
blanks.back() = blank_id;

Expand All @@ -26,17 +25,17 @@ TransducerKeywordsResult TransducerKeywordsDecoder::GetEmptyResult() const {
return r;
}

void TransducerKeywordsDecoder::Decode(
void TransducerKeywordDecoder::Decode(
Ort::Value encoder_out, OnlineStream **ss,
std::vector<TransducerKeywordsResult> *result) {
std::vector<TransducerKeywordResult> *result) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();

if (encoder_out_shape[0] != result->size()) {
fprintf(stderr,
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
SHERPA_ONNX_LOGE(
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
exit(-1);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
//
// Copyright (c) 2023-2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORDS_DECODER_H_
#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORDS_DECODER_H_
#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_

#include <string>
#include <utility>
Expand All @@ -14,7 +14,7 @@

namespace sherpa_onnx {

struct TransducerKeywordsResult {
struct TransducerKeywordResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;

Expand All @@ -34,20 +34,20 @@ struct TransducerKeywordsResult {
Hypotheses hyps;
};

class TransducerKeywordsDecoder {
class TransducerKeywordDecoder {
public:
TransducerKeywordsDecoder(OnlineTransducerModel *model,
int32_t max_active_paths,
int32_t num_trailing_blanks, int32_t unk_id)
TransducerKeywordDecoder(OnlineTransducerModel *model,
int32_t max_active_paths,
int32_t num_trailing_blanks, int32_t unk_id)
: model_(model),
max_active_paths_(max_active_paths),
num_trailing_blanks_(num_trailing_blanks),
unk_id_(unk_id) {}

TransducerKeywordsResult GetEmptyResult() const;
TransducerKeywordResult GetEmptyResult() const;

void Decode(Ort::Value encoder_out, OnlineStream **ss,
std::vector<TransducerKeywordsResult> *result);
std::vector<TransducerKeywordResult> *result);

private:
OnlineTransducerModel *model_; // Not owned
Expand All @@ -59,4 +59,4 @@ class TransducerKeywordsDecoder {

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORDS_DECODER_H_
#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_

0 comments on commit 3d85169

Please sign in to comment.