Skip to content

Commit

Permalink
All model tensor inputs/outputs are now auto detected (#126)
Browse files Browse the repository at this point in the history
Removed templates from Model PositionIDs/InputIDs handlers

Remove logits type and kv type from genai-config.json
Config file will output a warning that the logits_type and kv_type are
deprecated. We'll remove all traces once we update builder.py

Also detects requesting CUDA device type when not built for CUDA and
report an early error
  • Loading branch information
RyanUnderhill authored Feb 27, 2024
1 parent b2400cb commit 538cad3
Show file tree
Hide file tree
Showing 22 changed files with 224 additions and 139 deletions.
4 changes: 0 additions & 4 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,6 @@ struct Model_Element : JSON::Element {
void OnString(std::string_view name, std::string_view value) override {
if (name == "type") {
v_.type = value;
} else if (name == "logits_type") {
v_.logits_type = TranslateTensorType(value);
} else if (name == "kv_type") {
v_.kv_type = TranslateTensorType(value);
} else
throw JSON::unknown_value_error{};
}
Expand Down
3 changes: 0 additions & 3 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ struct Config {
struct Model {
std::string type;

ONNXTensorElementDataType logits_type{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}; // float16/float32 are the valid types
ONNXTensorElementDataType kv_type{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}; // float16/float32 are the valid types

int pad_token_id{}; // The id of the padding token.
int eos_token_id{}; // The id of the end-of-stream token.
int bos_token_id{}; // The id of the beginning-of-stream token.
Expand Down
6 changes: 4 additions & 2 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,18 @@ void GeneratorParams::SetInputSequences(const TokenSequences& sequences) {

ProviderOptions GetDefaultProviderOptions([[maybe_unused]] DeviceType device_type) {
ProviderOptions options;
#if USE_CUDA
if (device_type == DeviceType::CUDA) {
#if USE_CUDA
cudaStream_t cuda_stream;
cudaStreamCreate(&cuda_stream);

auto& cuda_options = options.emplace<OrtCUDAProviderOptions>();
cuda_options.has_user_compute_stream = true;
cuda_options.user_compute_stream = cuda_stream;
}
#else
throw std::runtime_error("Trying to use cuda with the non cuda version of onnxruntime-genai");
#endif
}

return options;
}
Expand Down
1 change: 1 addition & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <set>
#include <stdexcept>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <variant>
#include <vector>
Expand Down
1 change: 0 additions & 1 deletion src/models/gpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ namespace Generators {
Gpt_Model::Gpt_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const ProviderOptions* provider_options)
: Model{std::move(config), provider_options} {
session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.decoder.filename).c_str(), session_options_.get());

InitDeviceAllocator(*session_decoder_);
}

Expand Down
4 changes: 2 additions & 2 deletions src/models/gpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ struct Gpt_State : State {
const Gpt_Model& model_;
bool first_run_{true};

InputIDs<int32_t> input_ids_{model_, *this};
InputIDs input_ids_{model_, *this};
Logits logits_{model_, *this};
KV_Cache_Combined kv_cache_{model_, *this};
PositionIDs<int32_t> position_ids_;
PositionIDs position_ids_;
};
} // namespace Generators
36 changes: 16 additions & 20 deletions src/models/input_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,48 @@

namespace Generators {

template <typename T>
InputIDs<T>::InputIDs(const Model& model, State& state)
InputIDs::InputIDs(const Model& model, State& state)
: model_{model},
state_{state} {
name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
shape_ = {state_.search_params_.batch_size, state_.search_params_.sequence_length};
type_ = model_.session_info_->GetInputDataType(name_);

// If 64-bit, convert from 32-bit to 64-bit
if constexpr (std::is_same_v<T, int64_t>) {
value_ = OrtValue::CreateTensor<int64_t>(model.allocator_cpu_, shape_);
auto* p_data = value_->GetTensorMutableData<T>();
if (type_ == Ort::TypeToTensorType<int64_t>::type) {
value_ = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_);
auto* p_data = value_->GetTensorMutableData<int64_t>();
for (auto v : state_.search_params_.input_ids) {
*p_data++ = v;
}
} else {
static_assert(std::is_same_v<T, int32_t>);
value_ = OrtValue::CreateTensor<T>(model.allocator_cpu_.GetInfo(), std::span<T>(const_cast<T*>(state_.search_params_.input_ids.data()), shape_[0] * shape_[1]), shape_);
if (type_ != Ort::TypeToTensorType<int32_t>::type)
throw std::runtime_error("InputIDs must be int64 or int32");
value_ = OrtValue::CreateTensor<int32_t>(model.allocator_cpu_.GetInfo(), std::span<int32_t>(const_cast<int32_t*>(state_.search_params_.input_ids.data()), shape_[0] * shape_[1]), shape_);
}

value_ = model_.ExpandInputs(value_, state_.search_params_.num_beams);
shape_[0] *= state_.search_params_.num_beams;
}

template <typename T>
void InputIDs<T>::Add() {
void InputIDs::Add() {
input_index_ = state_.inputs_.size();

state_.inputs_.push_back(value_.get());
state_.input_names_.push_back(name_);
}

template <typename T>
void InputIDs<T>::Update(RoamingArray<int32_t> next_tokens_unk) {
void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
// Resize input_ids shape once if it doesn't match the decoder shape
if (shape_[1] != 1) {
shape_[1] = 1;
value_ = OrtValue::CreateTensor<T>(*model_.allocator_device_, shape_);
value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
state_.inputs_[input_index_] = value_.get();
}

auto* data = value_->GetTensorMutableData<T>();
// Update input_ids with next tokens, converting from 32-bit to 64-bit
if constexpr (std::is_same_v<T, int64_t>) {
if (type_ == Ort::TypeToTensorType<int64_t>::type) {
auto* data = value_->GetTensorMutableData<int64_t>();
#if USE_CUDA
if (model_.device_type_ == DeviceType::CUDA) {
auto next_tokens = next_tokens_unk.GetGPU();
Expand All @@ -61,17 +60,14 @@ void InputIDs<T>::Update(RoamingArray<int32_t> next_tokens_unk) {
}
}
} else {
static_assert(std::is_same_v<T, int32_t>);
auto* data = value_->GetTensorMutableData<int32_t>();
#if USE_CUDA
if (model_.device_type_ == DeviceType::CUDA)
cudaMemcpyAsync(data, next_tokens_unk.GetGPU().data(), shape_[0] * sizeof(T), cudaMemcpyDeviceToDevice, model_.cuda_stream_);
cudaMemcpyAsync(data, next_tokens_unk.GetGPU().data(), shape_[0] * sizeof(int32_t), cudaMemcpyDeviceToDevice, model_.cuda_stream_);
else
#endif
memcpy(data, next_tokens_unk.GetCPU().data(), shape_[0] * sizeof(T));
memcpy(data, next_tokens_unk.GetCPU().data(), shape_[0] * sizeof(int32_t));
}
}

template struct InputIDs<int32_t>;
template struct InputIDs<int64_t>;

} // namespace Generators
2 changes: 1 addition & 1 deletion src/models/input_ids.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

namespace Generators {

template <typename T>
struct InputIDs {
InputIDs(const Model& model, State& state);

Expand All @@ -18,6 +17,7 @@ struct InputIDs {
size_t input_index_{~0U};

std::array<int64_t, 2> shape_{};
ONNXTensorElementDataType type_;
std::unique_ptr<OrtValue> value_;
};

Expand Down
54 changes: 35 additions & 19 deletions src/models/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,28 @@ KV_Cache_Combined::KV_Cache_Combined(const Model& model, State& state)
: model_{model},
state_{state},
layer_count_{model.config_->model.decoder.num_hidden_layers},
shape_{2, static_cast<int64_t>(state_.search_params_.batch_size) * state_.search_params_.num_beams, model.config_->model.decoder.num_key_value_heads, 0, model.config_->model.decoder.head_size},
empty_past_{OrtValue::CreateTensor(*model_.allocator_device_, shape_, model_.config_->model.kv_type)} {
shape_{2, static_cast<int64_t>(state_.search_params_.batch_size) * state_.search_params_.num_beams, model.config_->model.decoder.num_key_value_heads, 0, model.config_->model.decoder.head_size} {
pasts_.resize(layer_count_);
presents_.reserve(layer_count_);

shape_[3] = state_.search_params_.sequence_length;
for (int i = 0; i < layer_count_; ++i) {
presents_.push_back(OrtValue::CreateTensor(*model.allocator_device_, shape_, model_.config_->model.kv_type));

char string[64];
snprintf(string, std::size(string), model.config_->model.decoder.inputs.past_names.c_str(), i);
input_name_strings_.emplace_back(string);

snprintf(string, std::size(string), model.config_->model.decoder.outputs.present_names.c_str(), i);
output_name_strings_.emplace_back(string);
}

// Derive the KV data type from the KV input 0
type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
shape_[3] = state_.search_params_.sequence_length;

for (int i = 0; i < layer_count_; ++i) {
presents_.push_back(OrtValue::CreateTensor(*model.allocator_device_, shape_, type_));
}
}

void KV_Cache_Combined::Add() {
Expand Down Expand Up @@ -51,7 +57,7 @@ void KV_Cache_Combined::Update(std::span<const int32_t> beam_indices, int curren

shape_[3] = current_length;
for (int i = 0; i < layer_count_; i++) {
presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, model_.config_->model.kv_type);
presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
state_.inputs_[input_index_ + i] = pasts_[i].get();
state_.outputs_[output_index_ + i] = presents_[i].get();
}
Expand Down Expand Up @@ -100,7 +106,7 @@ void KV_Cache_Combined::PickPastState(std::span<const int32_t> beam_indices, int
}

void KV_Cache_Combined::PickPastState(std::span<const int32_t> beam_indices, int index) {
if (model_.config_->model.kv_type == Ort::TypeToTensorType<float>::type) {
if (type_ == Ort::TypeToTensorType<float>::type) {
PickPastState<float>(beam_indices, index);
} else {
PickPastState<Ort::Float16_t>(beam_indices, index);
Expand All @@ -111,17 +117,11 @@ KV_Cache::KV_Cache(const Model& model, State& state)
: model_{model},
state_{state},
layer_count_{model_.config_->model.decoder.num_hidden_layers},
shape_{static_cast<int64_t>(state_.search_params_.batch_size) * state_.search_params_.num_beams, model.config_->model.decoder.num_key_value_heads, 0, model.config_->model.decoder.head_size},
empty_past_{OrtValue::CreateTensor(*model_.allocator_device_, shape_, model_.config_->model.kv_type)} {
shape_{static_cast<int64_t>(state_.search_params_.batch_size) * state_.search_params_.num_beams, model.config_->model.decoder.num_key_value_heads, 0, model.config_->model.decoder.head_size} {
pasts_.resize(layer_count_ * 2);
presents_.reserve(layer_count_ * 2);

shape_[2] = state_.search_params_.sequence_length; // Set this after empty_past_ has been created with 0 for this field

for (int i = 0; i < layer_count_; ++i) {
presents_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, model_.config_->model.kv_type));
presents_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, model_.config_->model.kv_type));

char string[64];
snprintf(string, std::size(string), model.config_->model.decoder.inputs.past_key_names.c_str(), i);
input_name_strings_.emplace_back(string);
Expand All @@ -133,6 +133,17 @@ KV_Cache::KV_Cache(const Model& model, State& state)
snprintf(string, std::size(string), model.config_->model.decoder.outputs.present_value_names.c_str(), i);
output_name_strings_.emplace_back(string);
}

// Derive the KV data type from the KV input 0
type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
shape_[2] = state_.search_params_.sequence_length; // Set this after empty_past_ has been created with 0 for this field

for (int i = 0; i < layer_count_; ++i) {
presents_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
presents_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
}
}

void KV_Cache::AddEncoder() {
Expand Down Expand Up @@ -168,7 +179,7 @@ void KV_Cache::Update(std::span<const int32_t> beam_indices, int current_length)

shape_[2] = current_length;
for (int i = 0; i < layer_count_ * 2; i++) {
presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, model_.config_->model.kv_type);
presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
state_.outputs_[output_index_ + i] = presents_[i].get();
}
}
Expand Down Expand Up @@ -207,7 +218,7 @@ void KV_Cache::PickPastState(std::span<const int32_t> beam_indices, int index) {
}

void KV_Cache::PickPastState(std::span<const int32_t> beam_indices, int index) {
if (model_.config_->model.kv_type == Ort::TypeToTensorType<float>::type) {
if (type_ == Ort::TypeToTensorType<float>::type) {
PickPastState<float>(beam_indices, index);
} else {
PickPastState<Ort::Float16_t>(beam_indices, index);
Expand All @@ -222,9 +233,6 @@ Cross_Cache::Cross_Cache(const Model& model, State& state)
values_.reserve(layer_count_ * 2);

for (int i = 0; i < layer_count_; ++i) {
values_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, model_.config_->model.kv_type));
values_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, model_.config_->model.kv_type));

char string[64];
snprintf(string, std::size(string), model.config_->model.decoder.inputs.cross_past_key_names.c_str(), i);
input_name_strings_.emplace_back(string);
Expand All @@ -236,6 +244,14 @@ Cross_Cache::Cross_Cache(const Model& model, State& state)
snprintf(string, std::size(string), model.config_->model.decoder.outputs.cross_present_value_names.c_str(), i);
output_name_strings_.emplace_back(string);
}

// Derive the KV data type from the KV input 0
type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

for (int i = 0; i < layer_count_; ++i) {
values_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
values_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
}
}

void Cross_Cache::AddOutputs() {
Expand Down
3 changes: 3 additions & 0 deletions src/models/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct KV_Cache_Combined {
size_t input_index_{~0U}, output_index_{~0U};

std::array<int64_t, 5> shape_;
ONNXTensorElementDataType type_;

std::unique_ptr<OrtValue> empty_past_;
std::vector<std::unique_ptr<OrtValue>> pasts_, presents_;
Expand All @@ -42,6 +43,7 @@ struct KV_Cache {
size_t input_index_{~0U}, output_index_{~0U};

std::array<int64_t, 4> shape_;
ONNXTensorElementDataType type_;

std::unique_ptr<OrtValue> empty_past_;
std::vector<std::unique_ptr<OrtValue>> pasts_, presents_;
Expand All @@ -61,6 +63,7 @@ struct Cross_Cache {
int layer_count_;

std::array<int64_t, 4> shape_;
ONNXTensorElementDataType type_;

std::vector<std::unique_ptr<OrtValue>> values_;
std::vector<std::string> input_name_strings_, output_name_strings_;
Expand Down
4 changes: 2 additions & 2 deletions src/models/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ struct Llama_State : State {
const Llama_Model& model_;
bool first_run_{true};

InputIDs<int64_t> input_ids_{model_, *this};
InputIDs input_ids_{model_, *this};
Logits logits_{model_, *this};
KV_Cache kv_cache_{model_, *this};
PositionIDs<int64_t> position_ids_;
PositionIDs position_ids_;
};

} // namespace Generators
31 changes: 16 additions & 15 deletions src/models/logits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,44 @@ namespace Generators {

Logits::Logits(const Model& model, State& state)
: model_{model},
state_{state} {
logits_shape_ = {state_.search_params_.batch_size * state_.search_params_.num_beams, state_.search_params_.sequence_length, state_.search_params_.vocab_size};
logits_ = OrtValue::CreateTensor(*model.allocator_device_, logits_shape_, model_.config_->model.logits_type);
state_{state},
shape_{static_cast<int64_t>(state_.search_params_.batch_size) * state_.search_params_.num_beams, state_.search_params_.sequence_length, state_.search_params_.vocab_size},
type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
value_ = OrtValue::CreateTensor(*model.allocator_device_, shape_, type_);

if (model_.device_type_ == DeviceType::CPU && model_.config_->model.logits_type != Ort::TypeToTensorType<float>::type)
if (model_.device_type_ == DeviceType::CPU && type_ != Ort::TypeToTensorType<float>::type)
throw std::runtime_error("Model logits_type can only be float32 on CPU");
}

RoamingArray<float> Logits::Get() {
auto element_count = logits_->GetTensorTypeAndShapeInfo()->GetElementCount();
size_t element_count = shape_[0] * shape_[1] * shape_[2];

#if USE_CUDA
if (model_.device_type_ == DeviceType::CUDA) {
if (model_.config_->model.logits_type == Ort::TypeToTensorType<Ort::Float16_t>::type) {
ConvertFp16ToFp32(*model_.allocator_device_, model_.cuda_stream_, *logits_, logits32_);
return gpu_span<float>{logits32_->GetTensorMutableData<float>(), element_count};
if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type) {
ConvertFp16ToFp32(*model_.allocator_device_, model_.cuda_stream_, *value_, value32_);
return gpu_span<float>{value32_->GetTensorMutableData<float>(), element_count};
}
return gpu_span<float>{logits_->GetTensorMutableData<float>(), element_count};
return gpu_span<float>{value_->GetTensorMutableData<float>(), element_count};
}
#endif

return cpu_span<float>{logits_->GetTensorMutableData<float>(), element_count};
return cpu_span<float>{value_->GetTensorMutableData<float>(), element_count};
}

void Logits::Add() {
output_index_ = state_.outputs_.size();

state_.output_names_.push_back(model_.config_->model.decoder.outputs.logits.c_str());
state_.outputs_.push_back(logits_.get());
state_.outputs_.push_back(value_.get());
}

void Logits::Update() {
// Resize the logits shape once if it doesn't match the decoder shape
if (logits_shape_[1] != 1) {
logits_shape_[1] = 1;
logits_ = OrtValue::CreateTensor(*model_.allocator_device_, logits_shape_, model_.config_->model.logits_type);
state_.outputs_[output_index_] = logits_.get();
if (shape_[1] != 1) {
shape_[1] = 1;
value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
state_.outputs_[output_index_] = value_.get();
}
}

Expand Down
Loading

0 comments on commit 538cad3

Please sign in to comment.