Merge branch 'main' into kvaishnavi/rotemb-in-gqa
kunal-vaishnavi committed Apr 2, 2024
2 parents 39765f1 + 6ad63e1, commit a6de684
Showing 4 changed files with 44 additions and 26 deletions.
8 changes: 8 additions & 0 deletions src/models/model.cpp
```diff
@@ -210,6 +210,14 @@ SessionInfo::SessionInfo(OrtSession& session) {
   }
 }
 
+bool SessionInfo::HasInput(const std::string& name) const {
+  return inputs_.find(name) != inputs_.end();
+}
+
+bool SessionInfo::HasOutput(const std::string& name) const {
+  return outputs_.find(name) != outputs_.end();
+}
+
 ONNXTensorElementDataType SessionInfo::GetInputDataType(const std::string& name) const {
   auto result = inputs_.find(name);
   if (result == inputs_.end())
```
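The new `HasInput`/`HasOutput` queries let callers probe a session for an optional tensor before asking for its type. A minimal usage sketch, assuming the `SessionInfo` declaration from src/models/model.h and the `Generators` namespace; `DescribeOptionalInput` and the include path are illustrative, not part of this commit:

```cpp
#include <iostream>
#include <string>

#include "models/model.h"  // assumed include path for SessionInfo

// Hypothetical helper: probe before querying, since GetInputDataType
// presumably fails for a name the session does not expose.
void DescribeOptionalInput(const Generators::SessionInfo& info, const std::string& name) {
  if (!info.HasInput(name)) {
    std::cout << name << ": not exported by this model\n";
    return;
  }
  // ONNXTensorElementDataType is a plain C enum, so it prints as an int.
  std::cout << name << ": element type " << info.GetInputDataType(name) << '\n';
}
```

This is the same pattern position_ids.cpp adopts below: probe once in the constructor, remember the answer in `has_position_ids_`, and branch on it everywhere else.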
3 changes: 3 additions & 0 deletions src/models/model.h
```diff
@@ -88,6 +88,9 @@ struct Tokenizer : std::enable_shared_from_this<Tokenizer> {
 struct SessionInfo {
   SessionInfo(OrtSession& session);
 
+  bool HasInput(const std::string& name) const;
+  bool HasOutput(const std::string& name) const;
+
   ONNXTensorElementDataType GetInputDataType(const std::string& name) const;
   ONNXTensorElementDataType GetOutputDataType(const std::string& name) const;
 
```
58 changes: 32 additions & 26 deletions src/models/position_ids.cpp
```diff
@@ -8,7 +8,9 @@ namespace Generators {
 PositionIDs::PositionIDs(const Model& model, State& state, RoamingArray<int32_t>& sequence_lengths_unk)
     : model_{model},
       state_{state} {
-  type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.position_ids);
+  has_position_ids_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.position_ids);
+  type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.attention_mask);
 
   if (type_ != Ort::TypeToTensorType<int32_t>::type && type_ != Ort::TypeToTensorType<int64_t>::type)
     throw std::runtime_error("position_ids & attention_mask only support int32 or int64 types");
 
```
```diff
@@ -33,38 +35,42 @@ PositionIDs::PositionIDs(const Model& model, State& state, RoamingArray<int32_t>& sequence_lengths_unk)
 void PositionIDs::Add() {
   input_index_ = state_.inputs_.size();
 
-  state_.inputs_.push_back(position_ids_.get());
-  state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str());
+  if (has_position_ids_) {
+    state_.inputs_.push_back(position_ids_.get());
+    state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str());
+  }
 
   state_.inputs_.push_back(attention_mask_.get());
   state_.input_names_.push_back(model_.config_->model.decoder.inputs.attention_mask.c_str());
 }
 
 void PositionIDs::Update(int current_length) {
-  // Reallocate position_ids for the 2nd and onward shape
-  if (position_ids_next_) {
-    position_ids_ = std::move(position_ids_next_);
-    position_ids_shape_[1] = 1;
-    state_.inputs_[input_index_] = position_ids_.get();
-  } else {  // Just incrementing existing position IDs
-    switch (model_.device_type_) {
-      case DeviceType::CPU: {
-        if (type_ == Ort::TypeToTensorType<int32_t>::type)
-          UpdatePositionIDs<int32_t>();
-        else
-          UpdatePositionIDs<int64_t>();
-        break;
-      }
+  if (has_position_ids_) {
+    // Reallocate position_ids for the 2nd and onward shape
+    if (position_ids_next_) {
+      position_ids_ = std::move(position_ids_next_);
+      position_ids_shape_[1] = 1;
+      state_.inputs_[input_index_] = position_ids_.get();
+    } else {  // Just incrementing existing position IDs
+      switch (model_.device_type_) {
+        case DeviceType::CPU: {
+          if (type_ == Ort::TypeToTensorType<int32_t>::type)
+            UpdatePositionIDs<int32_t>();
+          else
+            UpdatePositionIDs<int64_t>();
+          break;
+        }
 #if USE_CUDA
-      case DeviceType::CUDA:
-        if (type_ == Ort::TypeToTensorType<int32_t>::type)
-          cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int32_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
-        else
-          cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
-        break;
+        case DeviceType::CUDA:
+          if (type_ == Ort::TypeToTensorType<int32_t>::type)
+            cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int32_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
+          else
+            cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
+          break;
 #endif
-      default:
-        throw std::runtime_error("PositionIDs::Update - Unsupported device type");
-    }
-  }
+        default:
+          throw std::runtime_error("PositionIDs::Update - Unsupported device type");
+      }
+    }
+  }
 
```

```diff
@@ -95,7 +101,7 @@ void PositionIDs::Update(int current_length) {
       throw std::runtime_error("PositionIDs::Update - Unsupported device type");
   }
   attention_mask_ = std::move(next_attention_mask);
-  state_.inputs_[input_index_ + 1] = attention_mask_.get();
+  state_.inputs_[input_index_ + has_position_ids_] = attention_mask_.get();
   }
 }
```
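Note the indexing trick in the last hunk: `input_index_ + has_position_ids_` relies on `bool` promoting to 0 or 1 in arithmetic, so attention_mask is updated at slot `input_index_` when position_ids is absent and at `input_index_ + 1` when it is present. A standalone sketch of just that arithmetic (values are illustrative, not from the repository):

```cpp
#include <cassert>
#include <cstddef>

int main() {
  std::size_t input_index = 5;  // where this component began appending inputs

  bool has_position_ids = true;
  // position_ids occupies slot 5, so the mask sits one past it.
  assert(input_index + has_position_ids == 6);

  has_position_ids = false;
  // position_ids was never pushed, so the mask takes the first slot.
  assert(input_index + has_position_ids == 5);
}
```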

1 change: 1 addition & 0 deletions src/models/position_ids.h
```diff
@@ -21,6 +21,7 @@ struct PositionIDs {
   State& state_;
   size_t input_index_{~0U};
   ONNXTensorElementDataType type_;  // Common type for position_ids and attention_mask
+  bool has_position_ids_;
 
   std::array<int64_t, 2> position_ids_shape_{};  // {params.batch_size*params.beam_size, params.sequence_length}
   std::unique_ptr<OrtValue> position_ids_;
```
