Skip to content

Commit

Permalink
Replace Clone() with View()
Browse files Browse the repository at this point in the history
  • Loading branch information
hiedean committed Nov 17, 2023
1 parent 1a6a41e commit 3467c5b
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-conformer-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ Ort::Value OnlineConformerTransducerModel::RunDecoder(
Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
View(&decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-lstm-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
View(&decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
std::move(cur_encoder_out), View(&decoder_out));

const float *p_logit = logit.GetTensorData<float>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
cur_encoder_out =
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
std::move(cur_encoder_out), View(&decoder_out));

float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-zipformer-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ Ort::Value OnlineZipformerTransducerModel::RunDecoder(
Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
View(&decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ Ort::Value OnlineZipformer2TransducerModel::RunDecoder(
Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
View(&decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
Expand Down

0 comments on commit 3467c5b

Please sign in to comment.