From 83d0bc0e234c733e4ac64160d8b30abd20375113 Mon Sep 17 00:00:00 2001 From: hiedean Date: Sat, 18 Nov 2023 20:56:28 +0800 Subject: [PATCH] Replace Clone() with View() --- ...ffline-transducer-modified-beam-search-decoder.cc | 2 +- .../csrc/online-conformer-transducer-model.cc | 2 +- sherpa-onnx/csrc/online-lstm-transducer-model.cc | 2 +- sherpa-onnx/csrc/online-rnn-lm.cc | 8 ++++---- .../csrc/online-transducer-greedy-search-decoder.cc | 8 +++++--- ...online-transducer-modified-beam-search-decoder.cc | 2 +- sherpa-onnx/csrc/online-wenet-ctc-model.cc | 12 ++++++------ .../csrc/online-zipformer-transducer-model.cc | 2 +- .../csrc/online-zipformer2-transducer-model.cc | 2 +- 9 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc index e845b31384..142acb4acb 100644 --- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( // now cur_encoder_out is of shape (num_hyps, joiner_dim) Ort::Value logit = model_->RunJoiner( - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + std::move(cur_encoder_out), View(&decoder_out)); float *p_logit = logit.GetTensorMutableData(); LogSoftmax(p_logit, vocab_size, num_hyps); diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.cc b/sherpa-onnx/csrc/online-conformer-transducer-model.cc index 58cbce01cb..6e7659e79e 100644 --- a/sherpa-onnx/csrc/online-conformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-conformer-transducer-model.cc @@ -259,7 +259,7 @@ Ort::Value OnlineConformerTransducerModel::RunDecoder( Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { std::array joiner_input = {std::move(encoder_out), - std::move(decoder_out)}; + View(&decoder_out)}; auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), joiner_input.size(), joiner_output_names_ptr_.data(), diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 4a0e838da2..799867e48a 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -250,7 +250,7 @@ Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) { Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { std::array joiner_input = {std::move(encoder_out), - std::move(decoder_out)}; + View(&decoder_out)}; auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), joiner_input.size(), joiner_output_names_ptr_.data(), diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc index 29b150e45b..a26948d55e 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.cc +++ b/sherpa-onnx/csrc/online-rnn-lm.cc @@ -67,13 +67,13 @@ class OnlineRnnLM::Impl { return {std::move(out[0]), std::move(next_states)}; } - std::pair> GetInitStates() const { + std::pair> GetInitStates() { std::vector ans; ans.reserve(init_states_.size()); - for (const auto &s : init_states_) { - ans.emplace_back(Clone(allocator_, &s)); + for (auto &s : init_states_) { + ans.emplace_back(View(&s)); } - return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)}; + return {std::move(View(&init_scores_.value)), std::move(ans)}; } private: diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index c2fc1103da..132aa87d25 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( } if (is_batch_decoder_out_cached) { auto &r = result->front(); - std::vector decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); + std::vector decoder_out_shape = + r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); decoder_out_shape[0] = batch_size; - decoder_out = Ort::Value::CreateTensor(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size()); + decoder_out = Ort::Value::CreateTensor(model_->Allocator(), + decoder_out_shape.data(), decoder_out_shape.size()); UseCachedDecoderOut(*result, &decoder_out); } else { Ort::Value decoder_input = model_->BuildDecoderInput(*result); @@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( Ort::Value cur_encoder_out = GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); Ort::Value logit = model_->RunJoiner( - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + std::move(cur_encoder_out), View(&decoder_out)); const float *p_logit = logit.GetTensorData(); diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index a98f19dad4..a02e345036 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); Ort::Value logit = model_->RunJoiner( - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + std::move(cur_encoder_out), View(&decoder_out)); float *p_logit = logit.GetTensorMutableData(); LogSoftmax(p_logit, vocab_size, num_hyps); diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.cc b/sherpa-onnx/csrc/online-wenet-ctc-model.cc index 5d7e90964c..b340d30544 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/online-wenet-ctc-model.cc @@ -73,9 +73,9 @@ class OnlineWenetCtcModel::Impl { std::array inputs = {std::move(x), View(&offset), View(&required_cache_size_tensor_), - std::move(attn_cache), - std::move(conv_cache), - std::move(attn_mask)}; + View(&attn_cache), + View(&conv_cache), + View(&attn_mask)}; auto out = sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), @@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl { // - attn_cache // - conv_cache // - offset - std::vector GetInitStates() const { + std::vector GetInitStates() { std::vector ans; ans.reserve(3); - ans.push_back(Clone(Allocator(), &attn_cache_)); - ans.push_back(Clone(Allocator(), &conv_cache_)); + ans.push_back(View(&attn_cache_)); + ans.push_back(View(&conv_cache_)); int64_t offset_shape = 1; diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 31234ae74a..5e8d297b28 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -465,7 +465,7 @@ Ort::Value OnlineZipformerTransducerModel::RunDecoder( Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { std::array joiner_input = {std::move(encoder_out), - std::move(decoder_out)}; + View(&decoder_out)}; auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), joiner_input.size(), joiner_output_names_ptr_.data(), diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index e818b0bc98..d8044fcbd8 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -454,7 +454,7 @@ Ort::Value OnlineZipformer2TransducerModel::RunDecoder( Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { std::array joiner_input = {std::move(encoder_out), - std::move(decoder_out)}; + View(&decoder_out)}; auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), joiner_input.size(), joiner_output_names_ptr_.data(),