Skip to content

Commit

Permalink
fixed decoder method to take states of previous chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
sangeet2020 committed May 23, 2024
1 parent afb10d4 commit 7800cc0
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 49 deletions.
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
if (!config.model_config.transducer.encoder.empty()) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(config.model_config.transducer.decoder);
auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();
Expand Down
28 changes: 15 additions & 13 deletions sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar

#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
Expand All @@ -24,7 +25,6 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/utils.h"
Expand Down Expand Up @@ -80,6 +80,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {

// Create a fresh online stream and seed it with the model's initial
// encoder states so the first chunk decodes from a clean context.
std::unique_ptr<OnlineStream> CreateStream() const override {
  auto s = std::make_unique<OnlineStream>(config_.feat_config);
  s->SetStates(model_->GetInitStates());
  InitOnlineStream(s.get());
  return s;
}
Expand Down Expand Up @@ -120,27 +121,27 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
features_vec.size(), x_shape.data(),
x_shape.size());

std::array<int64_t, 1> processed_frames_shape{
static_cast<int64_t>(all_processed_frames.size())};

Ort::Value processed_frames = Ort::Value::CreateTensor(
memory_info, all_processed_frames.data(), all_processed_frames.size(),
processed_frames_shape.data(), processed_frames_shape.size());

auto states = model_->StackStates(states_vec);

auto [t, ns] = model_->RunEncoder(std::move(x), std::move(states),
std::move(processed_frames));
int32_t num_states = states.size();
auto t = model_->RunEncoder(std::move(x), std::move(states));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] encoder_out_length, int64 tensor, (batch_size,)

std::vector<Ort::Value> out_states;
out_states.reserve(num_states);

for (int32_t k = 1; k != num_states + 1; ++k) {
out_states.push_back(std::move(t[k]));
}

Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);

// defined in online-transducer-greedy-search-nemo-decoder.h
std::vector<OnlineTransducerDecoderResult> results = decoder_-> Decode(std::move(encoder_out), std::move(t[1]));
decoder_-> Decode(std::move(encoder_out), std::move(t[1]),
std::move(out_states), &results, ss, n);

std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(ns);
model_->UnStackStates(out_states);

for (int32_t i = 0; i != n; ++i) {
ss[i]->SetResult(results[i]);
Expand Down Expand Up @@ -187,6 +188,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
symbol_table_.NumSymbols(), vocab_size);
exit(-1);
}

}

private:
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}

// Store the NeMo decoder states for this stream. The tensors are taken
// by value and moved into the pimpl, which keeps ownership.
void OnlineStream::SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
  impl_->SetNeMoDecoderStates(std::move(decoder_states));
}

// Return a mutable reference to the NeMo decoder states held by the pimpl.
std::vector<Ort::Value> &OnlineStream::GetNeMoDecoderStates() {
  return impl_->GetNeMoDecoderStates();
}
Expand Down
84 changes: 62 additions & 22 deletions sherpa-onnx/csrc/online-transducer-nemo-model.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/online-transducer-nemo-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar

#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"

Expand Down Expand Up @@ -145,29 +146,51 @@ class OnlineTransducerNeMoModel::Impl {
return ans;
}

std::pair<std::vector<Ort::Value>, std::vector<Ort::Value>> RunEncoder(Ort::Value features,
std::vector<Ort::Value> states,
Ort::Value /* processed_frames */) {
std::vector<Ort::Value> encoder_inputs;
encoder_inputs.reserve(1 + states.size());
std::vector<Ort::Value> RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) {
Ort::Value &cache_last_channel = states[0];
Ort::Value &cache_last_time = states[1];
Ort::Value &cache_last_channel_len = states[2];

encoder_inputs.push_back(std::move(features));
for (auto &v : states) {
encoder_inputs.push_back(std::move(v));
}
int32_t batch_size = features.GetTensorTypeAndShapeInfo().GetShape()[0];

std::array<int64_t, 1> length_shape{batch_size};

Ort::Value length = Ort::Value::CreateTensor<int64_t>(
allocator_, length_shape.data(), length_shape.size());

int64_t *p_length = length.GetTensorMutableData<int64_t>();

std::fill(p_length, p_length + batch_size, ChunkSize());

auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
encoder_inputs.size(), encoder_output_names_ptr_.data(),
encoder_output_names_ptr_.size());
// (B, T, C) -> (B, C, T)
features = Transpose12(allocator_, &features);

std::vector<Ort::Value> next_states;
next_states.reserve(states.size());
std::array<Ort::Value, 5> inputs = {
std::move(features), View(&length), std::move(cache_last_channel),
std::move(cache_last_time), std::move(cache_last_channel_len)};

auto out =
encoder_sess_->Run({}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
// out[0]: logit
// out[1] logit_length
// out[2:] states_next
//
// we need to remove out[1]

std::vector<Ort::Value> ans;
ans.reserve(out.size() - 1);

for (int32_t i = 0; i != out.size(); ++i) {
if (i == 1) {
continue;
}

for (int32_t i = 1; i != static_cast<int32_t>(encoder_out.size()); ++i) {
next_states.push_back(std::move(encoder_out[i]));
ans.push_back(std::move(out[i]));
}
return {std::move(encoder_out), std::move(next_states)};

return ans;
}

std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
Expand Down Expand Up @@ -250,6 +273,20 @@ class OnlineTransducerNeMoModel::Impl {

std::string FeatureNormalizationMethod() const { return normalize_type_; }

// Return a vector containing 3 tensors
// - cache_last_channel
// - cache_last_time_
// - cache_last_channel_len
std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans;
ans.reserve(3);
ans.push_back(View(&cache_last_channel_));
ans.push_back(View(&cache_last_time_));
ans.push_back(View(&cache_last_channel_len_));

return ans;
}

private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
Expand Down Expand Up @@ -409,11 +446,10 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(

OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default;

std::pair<std::vector<Ort::Value>, std::vector<Ort::Value>>
std::vector<Ort::Value>
OnlineTransducerNeMoModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states,
Ort::Value processed_frames) const {
return impl_->RunEncoder(std::move(features), std::move(states), std::move(processed_frames));
std::vector<Ort::Value> states) const {
return impl_->RunEncoder(std::move(features), std::move(states));
}

std::pair<Ort::Value, std::vector<Ort::Value>>
Expand Down Expand Up @@ -459,4 +495,8 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
return impl_->FeatureNormalizationMethod();
}

// Forward to Impl::GetInitStates(): returns the 3 initial encoder cache
// tensors (cache_last_channel, cache_last_time, cache_last_channel_len).
std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const {
  return impl_->GetInitStates();
}

} // namespace sherpa_onnx
24 changes: 11 additions & 13 deletions sherpa-onnx/csrc/online-transducer-nemo-model.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// sherpa-onnx/csrc/online-transducer-nemo-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar

#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_

Expand Down Expand Up @@ -58,26 +60,22 @@ class OnlineTransducerNeMoModel {
std::vector<std::vector<Ort::Value>> UnStackStates(
const std::vector<Ort::Value> &states) const;

// /** Get the initial encoder states.
// *
// * @return Return the initial encoder state.
// */
// std::vector<Ort::Value> GetEncoderInitStates() = 0;
// A list of 3 tensors:
// - cache_last_channel
// - cache_last_time
// - cache_last_channel_len
std::vector<Ort::Value> GetInitStates() const;

/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param states Encoder state of the previous chunk. It is changed in-place.
* @param processed_frames Processed frames before subsampling. It is a 1-D
* tensor with data type int64_t.
*
* @param states It is from GetInitStates() or returned from this method.
*
* @return Return a vector containing the encoder output
*         (a tensor of shape (N, T', encoder_out_dim)) followed by
*         the encoder states for the next chunk.
*/
std::pair<std::vector<Ort::Value>, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) const; // NOLINT
std::vector<Ort::Value> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states) const; // NOLINT

/** Run the decoder network.
*
Expand Down

0 comments on commit 7800cc0

Please sign in to comment.