diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 176ff5c80..406dd69b8 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -7,7 +7,7 @@ namespace Generators { Logits::Logits(const Model& model, State& state) : model_{model}, state_{state}, - shape_{state_.search_params_.batch_size * state_.search_params_.num_beams, state_.search_params_.sequence_length, state_.search_params_.vocab_size}, + shape_{static_cast(state_.search_params_.batch_size) * state_.search_params_.num_beams, state_.search_params_.sequence_length, state_.search_params_.vocab_size}, type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} { value_ = OrtValue::CreateTensor(*model.allocator_device_, shape_, type_);