Skip to content

Commit

Permalink
fix the encoder states dimension mismatch
Browse files Browse the repository at this point in the history
- we use less features 80->56
- we pass `feature_dim` to `OnlineZipformer2TransducerModel::GetEncoderInitStates()` via new method `OnlineTransducerModel::SetFeatureDim(.)`
- the value 19 corresponding to 80-dim features was previously hard-coded, in this PR it is imported from the `FeatureExtractorConfig`
  • Loading branch information
KarelVesely84 committed Mar 13, 2024
1 parent 25cec5c commit 20f7aca
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 5 deletions.
8 changes: 4 additions & 4 deletions sherpa-onnx/csrc/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");

po->Regiseter("low-freq", &low_freq,
"Low cutoff frequency for mel bins");
po->Register("low-freq", &low_freq,
"Low cutoff frequency for mel bins");

po->Regiseter("high-freq", &high_freq,
"High cutoff frequency for mel bins (if <= 0, offset from Nyquist)");
po->Register("high-freq", &high_freq,
"High cutoff frequency for mel bins (if <= 0, offset from Nyquist)");
}

std::string FeatureExtractorConfig::ToString() const {
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {

model_->SetFeatureDim(config.feat_config.feature_dim);

if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
Expand Down
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/online-transducer-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ class OnlineTransducerModel {
*/
virtual std::vector<Ort::Value> GetEncoderInitStates() = 0;

/** Set feature dim.
*
* This is used in `OnlineZipformer2TransducerModel`,
* to pass `feature_dim` for `GetEncoderInitStates()`.
*
* This has to be called before GetEncoderInitStates(), so the `encoder_embed`
* init state has the correct `embed_dim` of its output.
*/
virtual void SetFeatureDim(int32_t feature_dim) { }

/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
Expand Down
5 changes: 4 additions & 1 deletion sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,10 @@ OnlineZipformer2TransducerModel::GetEncoderInitStates() {
}

{
std::array<int64_t, 4> s{1, 128, 3, 19};
SHERPA_ONNX_CHECK_NE(feature_dim_, 0);
int32_t embed_dim = (((feature_dim_ - 1) / 2) - 1) / 2;
std::array<int64_t, 4> s{1, 128, 3, embed_dim};

auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
ans.push_back(std::move(v));
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/online-zipformer2-transducer-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {

std::vector<Ort::Value> GetEncoderInitStates() override;

void SetFeatureDim(int32_t feature_dim) override { feature_dim_ = feature_dim; }

std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) override;
Expand Down Expand Up @@ -101,6 +103,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {

int32_t context_size_ = 0;
int32_t vocab_size_ = 0;
int32_t feature_dim_ = 0;
};

} // namespace sherpa_onnx
Expand Down

0 comments on commit 20f7aca

Please sign in to comment.