From 3c9d563366afffcaf7cc598be890970aa35a3de6 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 7 Oct 2024 10:40:00 -0700 Subject: [PATCH 01/18] Updates for decoder-pipeline to work with other split models --- src/config.cpp | 7 + src/config.h | 1 + src/models/decoder_only_pipeline.cpp | 10 +- src/models/decoder_only_pipeline.h | 1 + src/models/kv_cache.cpp | 214 +++++++++++++++++++++++++++ src/models/kv_cache.h | 26 ++++ 6 files changed, 257 insertions(+), 2 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 0c6de4d69..cd9f30377 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -330,6 +330,13 @@ struct Decoder_Element : JSON::Element { throw JSON::unknown_value_error{}; } + void OnBool(std::string_view name, bool value) override { + if (name == "sliding_window_key_value_cache") { + v_.sliding_window_key_value_cache = value; + } else + throw JSON::unknown_value_error{}; + } + private: Config::Model::Decoder& v_; SessionOptions_Element session_options_{v_.session_options}; diff --git a/src/config.h b/src/config.h index cd7cffd8f..c8fbe403c 100644 --- a/src/config.h +++ b/src/config.h @@ -106,6 +106,7 @@ struct Config { int num_key_value_heads{}; int num_hidden_layers{}; int head_size{}; + bool sliding_window_key_value_cache{false}; struct Inputs { std::string input_ids{Defaults::InputIdsName}; diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index d21f77600..bc1f3190d 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -96,8 +96,13 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode position_inputs_.Add(); logits_.Add(); if (KV_Cache::IsCacheNeeded(model)) { - kv_cache_ = std::make_unique(*this); - kv_cache_->Add(); + if (model.config_->model.decoder.sliding_window_key_value_cache) { + sliding_window_key_value_cache_ = std::make_unique(*this); + sliding_window_key_value_cache_->Add(); + } else { + kv_cache_ = std::make_unique(*this); + kv_cache_->Add(); + } } extra_inputs_.Add(); @@ -243,6 +248,7 @@ void DecoderOnlyPipelineState::UpdateInputsOutputs(DeviceSpan& next_tok size_t new_length = input_ids_.GetShape()[1]; position_inputs_.Update(next_tokens, total_length, static_cast(new_length)); if (kv_cache_) kv_cache_->Update(beam_indices, total_length); + if (sliding_window_key_value_cache_) sliding_window_key_value_cache_->Update(beam_indices, total_length); logits_.Update(next_tokens, new_length); } diff --git a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h index b778ca92b..3cbb81ab8 100644 --- a/src/models/decoder_only_pipeline.h +++ b/src/models/decoder_only_pipeline.h @@ -71,6 +71,7 @@ struct DecoderOnlyPipelineState : State { InputIDs input_ids_{*this}; Logits logits_{*this}; std::unique_ptr kv_cache_; + std::unique_ptr sliding_window_key_value_cache_; PositionInputs position_inputs_; ExtraInputs extra_inputs_{*this}; }; diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 5cb66ade3..34ed17db3 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -17,6 +17,10 @@ std::string ComposeKeyValueName(const std::string& template_string, int index) { return std::string(key_value_name); } +int64_t ElementCountFromShape(const std::array& shape) { + return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies()); +} + } // namespace KV_Cache_Combined::KV_Cache_Combined(State& state) @@ -435,4 +439,214 @@ void Cross_Cache::AddInputs() { } } 
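+// A sketch of the per-layer cache layout implied by the shapes below
+// (batch size 1):
+//   key in:    [num_key_value_heads, 1, head_size, context_length - window_size]
+//   key out:   [num_key_value_heads, 1, head_size, window_size]
+//   value in:  [num_key_value_heads, 1, context_length - window_size, head_size]
+//   value out: [num_key_value_heads, 1, window_size, head_size]
+// Each run consumes window_size tokens; Slide() shifts the oldest window_size
+// entries out of the "in" tensors and appends the "out" tensors at the end.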
+SlidingWindowKeyValueCache::SlidingWindowKeyValueCache(State& state) + : state_{state}, + layer_count_{model_.config_->model.decoder.num_hidden_layers}, + window_size_{128}, + key_cache_shape_in_{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.decoder.head_size, model_.config_->model.context_length - window_size_}, + key_cache_shape_out_{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.decoder.head_size, window_size_}, + value_cache_shape_in_{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.context_length - window_size_, model_.config_->model.decoder.head_size}, + value_cache_shape_out_{model_.config_->model.decoder.num_key_value_heads, 1, + window_size_, model_.config_->model.decoder.head_size} { + for (int i = 0; i < layer_count_; ++i) { + input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_key_names, i)); + input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_value_names, i)); + + output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_key_names, i)); + output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_value_names, i)); + } + + type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); + + for (int i = 0; i < layer_count_; ++i) { + key_caches_in_.push_back( + OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_in_, type_)); + std::fill_n(key_caches_in_[i]->GetTensorMutableData(), + ElementCountFromShape(key_cache_shape_in_), static_cast(0)); + + value_caches_in_.push_back( + OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_in_, type_)); + std::fill_n(value_caches_in_[i]->GetTensorMutableData(), + ElementCountFromShape(value_cache_shape_in_), static_cast(0)); + + key_caches_out_.push_back( + OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_out_, type_)); + value_caches_out_.push_back( + OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_out_, type_)); + } +} + +void SlidingWindowKeyValueCache::Add() { + input_index_ = state_.inputs_.size(); + output_index_ = state_.outputs_.size(); + + for (size_t i = 0; i < layer_count_; ++i) { + state_.inputs_.push_back(key_caches_in_[i].get()); + state_.input_names_.push_back(input_name_strings_[2 * i].c_str()); + + state_.inputs_.push_back(value_caches_in_[i].get()); + state_.input_names_.push_back(input_name_strings_[2 * i + 1].c_str()); + + state_.outputs_.push_back(key_caches_out_[i].get()); + state_.output_names_.push_back(output_name_strings_[2 * i].c_str()); + + state_.outputs_.push_back(value_caches_out_[i].get()); + state_.output_names_.push_back(output_name_strings_[2 * i + 1].c_str()); + } +} + +void SlidingWindowKeyValueCache::Slide() { + for (size_t i = 0; i < layer_count_; ++i) { + uint8_t* key_cache_in_data = key_caches_in_[i]->GetTensorMutableData(); + uint8_t* key_cache_out_data = key_caches_out_[i]->GetTensorMutableData(); + + int64_t num_key_cache_chunks = key_cache_shape_in_[0] * key_cache_shape_in_[2]; + for (int64_t j = 0; j < num_key_cache_chunks; ++j) { + { + cpu_span key_cache_dst(key_cache_in_data + j * key_cache_shape_in_[3], + key_cache_shape_in_[3] - window_size_); + cpu_span key_cache_src(key_cache_in_data + j * key_cache_shape_in_[3] + window_size_, + key_cache_shape_in_[3] - window_size_); + std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin()); + } + { + 
cpu_span key_cache_dst(key_cache_in_data + j * key_cache_shape_in_[3] + key_cache_shape_in_[3] - window_size_, + window_size_); + cpu_span key_cache_src(key_cache_out_data + j * key_cache_shape_out_[3], + window_size_); + std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin()); + } + } + + uint8_t* value_cache_in_data = value_caches_in_[i]->GetTensorMutableData(); + uint8_t* value_cache_out_data = value_caches_out_[i]->GetTensorMutableData(); + + for (int64_t j = 0; j < value_cache_shape_in_[0]; ++j) { + { + cpu_span value_cache_dst(value_cache_in_data + (j * value_cache_shape_in_[2] * value_cache_shape_in_[3]), + (value_cache_shape_in_[2] - window_size_) * value_cache_shape_in_[3]); + cpu_span value_cache_src(value_cache_in_data + (j * value_cache_shape_in_[2] * value_cache_shape_in_[3]) + + (window_size_ * value_cache_shape_in_[3]), + (value_cache_shape_in_[2] - window_size_) * value_cache_shape_in_[3]); + std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); + } + { + cpu_span value_cache_dst(value_cache_in_data + (j * value_cache_shape_in_[2] * value_cache_shape_in_[3]) + + ((value_cache_shape_in_[2] - window_size_) * value_cache_shape_in_[3]), + window_size_ * value_cache_shape_in_[3]); + cpu_span value_cache_src(value_cache_out_data + (j * value_cache_shape_out_[2] * value_cache_shape_out_[3]), + window_size_ * value_cache_shape_out_[3]); + std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); + } + } + } +} + +void SlidingWindowKeyValueCache::Update(std::span beam_indices, int current_length) { + if (window_size_ == 1) { + Slide(); + return; + } + + // No sliding needed. But we need to concatenate the last window_size_ elements to the end of the cache + + // key_caches_in_ = Concat(key_caches_in_[:, :, :, 1:], key_caches_out_) + // [num_key_value_heads, 1, head_size, context_length-1] = [num_key_value_heads, 1, head_size, context_length - window_size_ - 1] + + // [num_key_value_heads, 1, head_size, window_size_] + // value_cache = Concat(value_caches_in_[:, :, 1:, :], value_caches_out_) + // [num_key_value_heads, 1, context_length - 1, head_size] = [num_key_value_heads, 1, context_length - window_size_ - 1, head_size] + + // [num_key_value_heads, 1, window_size_, head_size] + + int updated_window_size = 1; + auto updated_key_cache_shape_in = std::array{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.decoder.head_size, + model_.config_->model.context_length - updated_window_size}; + + auto updated_value_cache_shape_in = std::array{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.context_length - updated_window_size, + model_.config_->model.decoder.head_size}; + + auto updated_key_cache_shape_out = std::array{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.decoder.head_size, + updated_window_size}; + + auto updated_value_cache_shape_out = std::array{model_.config_->model.decoder.num_key_value_heads, 1, + updated_window_size, + model_.config_->model.decoder.head_size}; + + for (size_t i = 0; i < layer_count_; ++i) { + std::unique_ptr key_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_in, type_); + + uint8_t* key_cache_data = key_cache->GetTensorMutableData(); + uint8_t* key_cache_in_data = key_caches_in_[i]->GetTensorMutableData(); + uint8_t* key_cache_out_data = key_caches_out_[i]->GetTensorMutableData(); + + int64_t num_key_cache_chunks = updated_key_cache_shape_in[0] * 
updated_key_cache_shape_in[2]; + for (int64_t j = 0; j < num_key_cache_chunks; ++j) { + { + cpu_span key_cache_dst(key_cache_data + j * updated_key_cache_shape_in[3], + updated_key_cache_shape_in[3] - updated_window_size); + cpu_span key_cache_src(key_cache_in_data + j * key_cache_shape_in_[3] + updated_window_size, + key_cache_shape_in_[3] - updated_window_size); + std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin()); + } + { + cpu_span key_cache_dst(key_cache_data + j * updated_key_cache_shape_in[3] + + key_cache_shape_in_[3] - updated_window_size, + window_size_); + cpu_span key_cache_src(key_cache_out_data + j * key_cache_shape_out_[3], + window_size_); + std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin()); + } + } + + key_caches_in_[i] = std::move(key_cache); + key_caches_out_[i] = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_out, type_); + + std::unique_ptr value_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_in, type_); + + uint8_t* value_cache_data = value_cache->GetTensorMutableData(); + uint8_t* value_cache_in_data = value_caches_in_[i]->GetTensorMutableData(); + uint8_t* value_cache_out_data = value_caches_out_[i]->GetTensorMutableData(); + + for (int64_t j = 0; j < updated_value_cache_shape_in[0]; ++j) { + { + cpu_span value_cache_dst(value_cache_data + (j * updated_value_cache_shape_in[2] * updated_value_cache_shape_in[3]), + (value_cache_shape_in_[2] - updated_window_size) * updated_value_cache_shape_in[3]); + cpu_span value_cache_src(value_cache_in_data + (j * value_cache_shape_out_[2] * value_cache_shape_out_[3]) + + (updated_window_size * value_cache_shape_out_[3]), + (value_cache_shape_in_[2] - updated_window_size) * value_cache_shape_in_[3]); + std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); + } + { + cpu_span value_cache_dst(value_cache_data + (j * updated_value_cache_shape_in[2] * updated_value_cache_shape_in[3]) + + ((value_cache_shape_in_[2] - updated_window_size) * updated_value_cache_shape_in[3]), + window_size_ * value_cache_shape_out_[3]); + cpu_span value_cache_src(value_cache_out_data + (j * value_cache_shape_out_[2] * value_cache_shape_out_[3]), + window_size_ * value_cache_shape_out_[3]); + std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); + } + } + + value_caches_in_[i] = std::move(value_cache); + value_caches_out_[i] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_); + } + + window_size_ = 1; + key_cache_shape_in_ = updated_key_cache_shape_in; + value_cache_shape_in_ = updated_value_cache_shape_in; + key_cache_shape_out_ = updated_key_cache_shape_out; + value_cache_shape_out_ = updated_value_cache_shape_out; + + for (size_t i = 0; i < layer_count_; ++i) { + state_.inputs_[input_index_ + 2 * i] = key_caches_in_[i].get(); + state_.inputs_[input_index_ + 2 * i + 1] = value_caches_in_[i].get(); + state_.outputs_[output_index_ + 2 * i] = key_caches_out_[i].get(); + state_.outputs_[output_index_ + 2 * i + 1] = value_caches_out_[i].get(); + } +} + } // namespace Generators diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 39cbf2e7c..f96037d7d 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -89,4 +89,30 @@ struct Cross_Cache { std::vector> values_; std::vector input_name_strings_, output_name_strings_; }; + +struct SlidingWindowKeyValueCache { + SlidingWindowKeyValueCache(State& state); + + void Add(); + void 
Update(std::span beam_indices, int current_length); + + private: + void Slide(); + void Shift(); + void Concat(); + + State& state_; + const Model& model_{state_.model_}; + int layer_count_; + int window_size_; + size_t input_index_{~0U}, output_index_{~0U}; + + std::array key_cache_shape_in_, key_cache_shape_out_; + std::array value_cache_shape_in_, value_cache_shape_out_; + ONNXTensorElementDataType type_; + + std::vector> key_caches_in_, value_caches_in_; + std::vector> key_caches_out_, value_caches_out_; + std::vector input_name_strings_, output_name_strings_; +}; } // namespace Generators From 4212d19b79a1f9564d85dd961c67a1d79b3f562a Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 11 Nov 2024 13:38:54 -0800 Subject: [PATCH 02/18] Sync changes with main --- src/models/kv_cache.cpp | 2 +- src/models/kv_cache.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 34ed17db3..f05fa4bd9 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -545,7 +545,7 @@ void SlidingWindowKeyValueCache::Slide() { } } -void SlidingWindowKeyValueCache::Update(std::span beam_indices, int current_length) { +void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int current_length) { if (window_size_ == 1) { Slide(); return; diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index f96037d7d..1aadf2389 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -94,7 +94,7 @@ struct SlidingWindowKeyValueCache { SlidingWindowKeyValueCache(State& state); void Add(); - void Update(std::span beam_indices, int current_length); + void Update(DeviceSpan beam_indices, int current_length); private: void Slide(); From 5e4ab3debacef1d6807da34370353de2aaf122ab Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 13 Nov 2024 14:12:26 -0800 Subject: [PATCH 03/18] Allow adjustments to the sliding window kv cache --- src/config.cpp | 28 +++++++++--- src/config.h | 7 ++- src/models/decoder_only_pipeline.cpp | 2 +- src/models/kv_cache.cpp | 68 +++++++++++++++------------- src/models/kv_cache.h | 2 - 5 files changed, 65 insertions(+), 42 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index cd9f30377..564da98cd 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -286,6 +286,22 @@ struct Pipeline_Element : JSON::Element { PipelineModelObject_Element object_{v_}; }; +struct SlidingWindowKeyValueCache_Element : JSON::Element { + explicit SlidingWindowKeyValueCache_Element(std::optional& v) : v_{v} {} + + void OnNumber(std::string_view name, double value) override { + if (name == "window_size") { + v_->window_size = static_cast(value); + } else if (name == "pad_value") { + v_->pad_value = static_cast(value); + } else + throw JSON::unknown_value_error{}; + } + + private: + std::optional& v_; +}; + struct Decoder_Element : JSON::Element { explicit Decoder_Element(Config::Model::Decoder& v) : v_{v} {} @@ -321,6 +337,10 @@ struct Decoder_Element : JSON::Element { if (name == "outputs") { return outputs_; } + if (name == "sliding_window_key_value_cache") { + v_.sliding_window_key_value_cache = Config::Model::Decoder::SlidingWindowKeyValueCache{}; + return sliding_window_key_value_cache_; + } throw JSON::unknown_value_error{}; } @@ -330,19 +350,13 @@ struct Decoder_Element : JSON::Element { throw JSON::unknown_value_error{}; } - void OnBool(std::string_view name, bool value) override { - if (name == "sliding_window_key_value_cache") { - v_.sliding_window_key_value_cache = value; - } else - throw 
JSON::unknown_value_error{}; - } - private: Config::Model::Decoder& v_; SessionOptions_Element session_options_{v_.session_options}; Inputs_Element inputs_{v_.inputs}; Outputs_Element outputs_{v_.outputs}; Pipeline_Element pipeline_{v_.pipeline}; + SlidingWindowKeyValueCache_Element sliding_window_key_value_cache_{v_.sliding_window_key_value_cache}; }; struct VisionInputs_Element : JSON::Element { diff --git a/src/config.h b/src/config.h index c8fbe403c..931e3f6bd 100644 --- a/src/config.h +++ b/src/config.h @@ -106,7 +106,12 @@ struct Config { int num_key_value_heads{}; int num_hidden_layers{}; int head_size{}; - bool sliding_window_key_value_cache{false}; + + struct SlidingWindowKeyValueCache { + int window_size{128}; + int pad_value{}; + }; + std::optional sliding_window_key_value_cache; struct Inputs { std::string input_ids{Defaults::InputIdsName}; diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index bc1f3190d..0b7d59da6 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -96,7 +96,7 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode position_inputs_.Add(); logits_.Add(); if (KV_Cache::IsCacheNeeded(model)) { - if (model.config_->model.decoder.sliding_window_key_value_cache) { + if (model.config_->model.decoder.sliding_window_key_value_cache.has_value()) { sliding_window_key_value_cache_ = std::make_unique(*this); sliding_window_key_value_cache_->Add(); } else { diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index f05fa4bd9..7430876ad 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -442,7 +442,7 @@ void Cross_Cache::AddInputs() { SlidingWindowKeyValueCache::SlidingWindowKeyValueCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, - window_size_{128}, + window_size_{model_.config_->model.decoder.sliding_window_key_value_cache->window_size}, key_cache_shape_in_{model_.config_->model.decoder.num_key_value_heads, 1, model_.config_->model.decoder.head_size, model_.config_->model.context_length - window_size_}, key_cache_shape_out_{model_.config_->model.decoder.num_key_value_heads, 1, @@ -460,17 +460,23 @@ SlidingWindowKeyValueCache::SlidingWindowKeyValueCache(State& state) } type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); + if (type_ != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) { + throw std::runtime_error("Expected input data type to be uint8_t for SlidingWindowKeyValueCache. 
Actual: " + + std::to_string(type_)); + } for (int i = 0; i < layer_count_; ++i) { key_caches_in_.push_back( OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_in_, type_)); std::fill_n(key_caches_in_[i]->GetTensorMutableData(), - ElementCountFromShape(key_cache_shape_in_), static_cast(0)); + ElementCountFromShape(key_cache_shape_in_), + static_cast(model_.config_->model.decoder.sliding_window_key_value_cache->pad_value)); value_caches_in_.push_back( OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_in_, type_)); std::fill_n(value_caches_in_[i]->GetTensorMutableData(), - ElementCountFromShape(value_cache_shape_in_), static_cast(0)); + ElementCountFromShape(value_cache_shape_in_), + static_cast(model_.config_->model.decoder.sliding_window_key_value_cache->pad_value)); key_caches_out_.push_back( OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_out_, type_)); @@ -483,25 +489,25 @@ void SlidingWindowKeyValueCache::Add() { input_index_ = state_.inputs_.size(); output_index_ = state_.outputs_.size(); - for (size_t i = 0; i < layer_count_; ++i) { - state_.inputs_.push_back(key_caches_in_[i].get()); - state_.input_names_.push_back(input_name_strings_[2 * i].c_str()); + for (size_t layer_idx = 0; layer_idx < layer_count_; ++layer_idx) { + state_.inputs_.push_back(key_caches_in_[layer_idx].get()); + state_.input_names_.push_back(input_name_strings_[2 * layer_idx].c_str()); - state_.inputs_.push_back(value_caches_in_[i].get()); - state_.input_names_.push_back(input_name_strings_[2 * i + 1].c_str()); + state_.inputs_.push_back(value_caches_in_[layer_idx].get()); + state_.input_names_.push_back(input_name_strings_[2 * layer_idx + 1].c_str()); - state_.outputs_.push_back(key_caches_out_[i].get()); - state_.output_names_.push_back(output_name_strings_[2 * i].c_str()); + state_.outputs_.push_back(key_caches_out_[layer_idx].get()); + state_.output_names_.push_back(output_name_strings_[2 * layer_idx].c_str()); - state_.outputs_.push_back(value_caches_out_[i].get()); - state_.output_names_.push_back(output_name_strings_[2 * i + 1].c_str()); + state_.outputs_.push_back(value_caches_out_[layer_idx].get()); + state_.output_names_.push_back(output_name_strings_[2 * layer_idx + 1].c_str()); } } void SlidingWindowKeyValueCache::Slide() { - for (size_t i = 0; i < layer_count_; ++i) { - uint8_t* key_cache_in_data = key_caches_in_[i]->GetTensorMutableData(); - uint8_t* key_cache_out_data = key_caches_out_[i]->GetTensorMutableData(); + for (size_t layer_idx = 0; layer_idx < layer_count_; ++layer_idx) { + uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData(); + uint8_t* key_cache_out_data = key_caches_out_[layer_idx]->GetTensorMutableData(); int64_t num_key_cache_chunks = key_cache_shape_in_[0] * key_cache_shape_in_[2]; for (int64_t j = 0; j < num_key_cache_chunks; ++j) { @@ -521,8 +527,8 @@ void SlidingWindowKeyValueCache::Slide() { } } - uint8_t* value_cache_in_data = value_caches_in_[i]->GetTensorMutableData(); - uint8_t* value_cache_out_data = value_caches_out_[i]->GetTensorMutableData(); + uint8_t* value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData(); + uint8_t* value_cache_out_data = value_caches_out_[layer_idx]->GetTensorMutableData(); for (int64_t j = 0; j < value_cache_shape_in_[0]; ++j) { { @@ -577,12 +583,12 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int cu updated_window_size, model_.config_->model.decoder.head_size}; - for (size_t i = 0; i < layer_count_; ++i) { + for (size_t layer_idx = 
0; layer_idx < layer_count_; ++layer_idx) { std::unique_ptr key_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_in, type_); uint8_t* key_cache_data = key_cache->GetTensorMutableData(); - uint8_t* key_cache_in_data = key_caches_in_[i]->GetTensorMutableData(); - uint8_t* key_cache_out_data = key_caches_out_[i]->GetTensorMutableData(); + uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData(); + uint8_t* key_cache_out_data = key_caches_out_[layer_idx]->GetTensorMutableData(); int64_t num_key_cache_chunks = updated_key_cache_shape_in[0] * updated_key_cache_shape_in[2]; for (int64_t j = 0; j < num_key_cache_chunks; ++j) { @@ -603,14 +609,14 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int cu } } - key_caches_in_[i] = std::move(key_cache); - key_caches_out_[i] = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_out, type_); + key_caches_in_[layer_idx] = std::move(key_cache); + key_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_out, type_); std::unique_ptr value_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_in, type_); uint8_t* value_cache_data = value_cache->GetTensorMutableData(); - uint8_t* value_cache_in_data = value_caches_in_[i]->GetTensorMutableData(); - uint8_t* value_cache_out_data = value_caches_out_[i]->GetTensorMutableData(); + uint8_t* value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData(); + uint8_t* value_cache_out_data = value_caches_out_[layer_idx]->GetTensorMutableData(); for (int64_t j = 0; j < updated_value_cache_shape_in[0]; ++j) { { @@ -631,8 +637,8 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int cu } } - value_caches_in_[i] = std::move(value_cache); - value_caches_out_[i] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_); + value_caches_in_[layer_idx] = std::move(value_cache); + value_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_); } window_size_ = 1; @@ -641,11 +647,11 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int cu key_cache_shape_out_ = updated_key_cache_shape_out; value_cache_shape_out_ = updated_value_cache_shape_out; - for (size_t i = 0; i < layer_count_; ++i) { - state_.inputs_[input_index_ + 2 * i] = key_caches_in_[i].get(); - state_.inputs_[input_index_ + 2 * i + 1] = value_caches_in_[i].get(); - state_.outputs_[output_index_ + 2 * i] = key_caches_out_[i].get(); - state_.outputs_[output_index_ + 2 * i + 1] = value_caches_out_[i].get(); + for (size_t layer_idx = 0; layer_idx < layer_count_; ++layer_idx) { + state_.inputs_[input_index_ + 2 * layer_idx] = key_caches_in_[layer_idx].get(); + state_.inputs_[input_index_ + 2 * layer_idx + 1] = value_caches_in_[layer_idx].get(); + state_.outputs_[output_index_ + 2 * layer_idx] = key_caches_out_[layer_idx].get(); + state_.outputs_[output_index_ + 2 * layer_idx + 1] = value_caches_out_[layer_idx].get(); } } diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 1aadf2389..bb48c99fd 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -98,8 +98,6 @@ struct SlidingWindowKeyValueCache { private: void Slide(); - void Shift(); - void Concat(); State& state_; const Model& model_{state_.model_}; From 0a0f98d6d7991bacd18d37265690a9593169a415 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 6 Nov 2024 
12:35:53 -0800 Subject: [PATCH 04/18] enable setting default ORT logging level to verbose with ORTGENAI_ORT_VERBOSE_LOGGING environment variable --- src/generators.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/generators.cpp b/src/generators.cpp index eff98f5ca..b0a65356a 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -3,6 +3,7 @@ #include "generators.h" #include "sequences.h" +#include "models/env_utils.h" #include "models/model.h" #include "models/decoder_only.h" #include "search.h" @@ -45,8 +46,14 @@ void OnCudaError(cudaError_t error) { assert(false); } static bool _ = (Ort::InitApi(), false); +static OrtLoggingLevel GetDefaultOrtLoggingLevel() { + bool ort_verbose_logging = false; + GetEnvironmentVariable("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging); + return ort_verbose_logging ? OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE : OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR; +} + OrtGlobals::OrtGlobals() - : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} { + : env_{OrtEnv::Create(GetDefaultOrtLoggingLevel())} { auto arena_config = OrtArenaCfg::Create(0, -1, -1, -1); Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()}; env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config); From 7a331ef161a2ddc8383a94a08647e46973c94ddc Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:58:42 -0800 Subject: [PATCH 05/18] hack to run with qnn shared memory allocator --- src/generators.cpp | 2 ++ src/generators.h | 1 + src/models/decoder_only_pipeline.cpp | 2 +- src/models/model.cpp | 34 +++++++++++++++++++++++----- src/models/model.h | 12 ++++------ src/models/position_inputs.cpp | 3 ++- 6 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index b0a65356a..09da6d187 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -177,6 +177,8 @@ std::string to_string(DeviceType device_type) { return "DirectML"; case DeviceType::WEBGPU: return "WebGpu"; + case DeviceType::QNN_WITH_SHARED_MEMORY: + return "QnnWithSharedMemory"; } throw std::runtime_error("Unknown device type"); } diff --git a/src/generators.h b/src/generators.h index 0b6dc7cfc..2ccc1f0c0 100644 --- a/src/generators.h +++ b/src/generators.h @@ -60,6 +60,7 @@ enum struct DeviceType { CUDA, DML, WEBGPU, + QNN_WITH_SHARED_MEMORY, }; std::string to_string(DeviceType device_type); diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 0b7d59da6..b5fa8b2fe 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -58,7 +58,7 @@ bool IntermediatePipelineState::HasOutput(std::string_view name) const { } bool IntermediatePipelineState::SupportsPrimaryDevice() const { - if (model_.device_type_ == DeviceType::CPU) { + if (model_.device_type_ == DeviceType::CPU || model_.device_type_ == DeviceType::QNN_WITH_SHARED_MEMORY) { return true; } else if (model_.device_type_ == DeviceType::CUDA) { if (!model_.config_->model.decoder.pipeline[id_].session_options.has_value()) { diff --git a/src/models/model.cpp b/src/models/model.cpp index 4430fdb41..a465e4471 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -288,29 +288,41 @@ Model::Model(std::unique_ptr config) : config_{std::move(config)} { Model::~Model() = default; -void Model::InitDeviceAllocator([[maybe_unused]] OrtSession& session) { +void Model::InitDeviceAllocator(OrtSession& session) { allocator_device_ = 
&allocator_cpu_; #if USE_CUDA if (device_type_ == DeviceType::CUDA) { allocator_device_ = GetCudaAllocator(session); + allocator_kvcache_ = allocator_device_; } #endif + #if USE_DML if (device_type_ == DeviceType::DML) { memory_info_device_ = OrtMemoryInfo::Create("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); - dml_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_); - allocator_device_ = dml_owned_allocator_.get(); + owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); + allocator_device_ = owned_allocator_device_.get(); + allocator_kvcache_ = allocator_device_; } #endif - allocator_kvcache_ = allocator_device_; + #if USE_WEBGPU if (device_type_ == DeviceType::WEBGPU) { // for webgpu we only use device memory for kv_cache memory_info_device_ = OrtMemoryInfo::Create("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); - webgpu_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_); - allocator_kvcache_ = webgpu_owned_allocator_.get(); + owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); + allocator_kvcache_ = owned_allocator_device_.get(); } #endif + + if (device_type_ == DeviceType::QNN_WITH_SHARED_MEMORY) { + memory_info_device_ = OrtMemoryInfo::Create("QnnHtpShared", OrtAllocatorType::OrtDeviceAllocator, 0, + OrtMemType::OrtMemTypeDefault); + owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); + allocator_device_ = owned_allocator_device_.get(); + allocator_kvcache_ = allocator_device_; + } + session_info_ = std::make_unique(session); captured_graph_pool_ = std::make_shared(config_.get(), session_info_.get(), allocator_device_); } @@ -500,6 +512,15 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ for (auto& option : provider_options.options) { opts.emplace(option.first, option.second); } + + // TODO set device_type_ in a less hacky way. + // now, all QNN EP enable_htp_shared_memory_allocator option values had better be consistent... + // on the other hand, not sure if is_primary_session_options is the right thing to check here. 
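+      // For reference, this option is expected to arrive through the session's
+      // provider_options in genai_config.json, e.g.:
+      //   "provider_options": [ { "qnn": { "enable_htp_shared_memory_allocator": "1" } } ]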
+ if (const auto opt_it = opts.find("enable_htp_shared_memory_allocator"); + opt_it != opts.end() && opt_it->second == "1") { + device_type_ = DeviceType::QNN_WITH_SHARED_MEMORY; + } + session_options.AppendExecutionProvider("QNN", opts); } else if (provider_options.name == "webgpu") { device_type_ = DeviceType::WEBGPU; @@ -685,6 +706,7 @@ std::unique_ptr Model::ExpandInputs(std::unique_ptr& input, switch (device_type_) { case DeviceType::WEBGPU: case DeviceType::DML: + case DeviceType::QNN_WITH_SHARED_MEMORY: // DML and WebGpu doesn't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs case DeviceType::CPU: for (int i = 0; i < batch_size; i++) { diff --git a/src/models/model.h b/src/models/model.h index bec481f03..9fdd7ac11 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -181,15 +181,11 @@ struct Model : std::enable_shared_from_this, LeakChecked { std::unique_ptr dml_execution_context_; std::unique_ptr dml_readback_heap_; ComPtr dml_device_; - std::unique_ptr dml_owned_allocator_; -#endif -#if USE_WEBGPU - std::unique_ptr webgpu_owned_allocator_; - std::unique_ptr webgpu_io_binding_; -#endif -#if USE_DML || USE_WEBGPU - std::unique_ptr memory_info_device_; #endif + + std::unique_ptr owned_allocator_device_{}; // nullptr if n/a + std::unique_ptr memory_info_device_{}; // nullptr if n/a + std::shared_ptr captured_graph_pool_; std::map> pipeline_session_options_; }; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index a6180e4ee..dd3c77011 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -243,7 +243,8 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { switch (model_.device_type_) { case DeviceType::WEBGPU: - case DeviceType::CPU: { + case DeviceType::CPU: + case DeviceType::QNN_WITH_SHARED_MEMORY: { type_ == Ort::TypeToTensorType ? UpdateAttentionMaskImpl(total_length) : UpdateAttentionMaskImpl(total_length); break; From aca6622b38f7a7b79565f12cc1500433533733c8 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 14 Nov 2024 13:10:32 -0800 Subject: [PATCH 06/18] Make kv cache updates parallel --- src/models/kv_cache.cpp | 14 ++++++++++---- src/models/threadpool.cpp | 20 ++++++++++++++++++++ src/models/threadpool.h | 20 ++++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 src/models/threadpool.cpp create mode 100644 src/models/threadpool.h diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 7430876ad..c00acc398 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -1,6 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
 #include "../generators.h"
 #include "model.h"
 #include "kv_cache.h"
+#include "threadpool.h"
 
 namespace Generators {
 
@@ -505,7 +509,8 @@ void SlidingWindowKeyValueCache::Add() {
 }
 
 void SlidingWindowKeyValueCache::Slide() {
-  for (size_t layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+  ThreadPool thread_pool{static_cast<size_t>(layer_count_)};
+  thread_pool.Compute([&](size_t layer_idx) {
     uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData<uint8_t>();
     uint8_t* key_cache_out_data = key_caches_out_[layer_idx]->GetTensorMutableData<uint8_t>();
 
@@ -548,7 +553,7 @@ void SlidingWindowKeyValueCache::Slide() {
       std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin());
     }
   }
-  }
+  });
 }
 
 void SlidingWindowKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int current_length) {
@@ -583,7 +588,8 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int cu
                                                        updated_window_size,
                                                        model_.config_->model.decoder.head_size};
 
-  for (size_t layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+  ThreadPool thread_pool{static_cast<size_t>(layer_count_)};
+  thread_pool.Compute([&](size_t layer_idx) {
     std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_in, type_);
 
@@ -639,7 +645,7 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int cu
     value_caches_in_[layer_idx] = std::move(value_cache);
     value_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_);
-  }
+  });
 
   window_size_ = 1;
diff --git a/src/models/threadpool.cpp b/src/models/threadpool.cpp
new file mode 100644
index 000000000..941bac1d8
--- /dev/null
+++ b/src/models/threadpool.cpp
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "threadpool.h"
+
+namespace Generators {
+
+ThreadPool::ThreadPool(size_t num_threads) : num_threads_{num_threads} {}
+
+void ThreadPool::Compute(const std::function<void(size_t)>& func) {
+  for (size_t i = 0; i < num_threads_; ++i) {
+    threads_.emplace_back([&, i] { func(i); });
+  }
+
+  for (auto& thread : threads_) {
+    thread.join();
+  }
+}
+
+} // namespace Generators
\ No newline at end of file
diff --git a/src/models/threadpool.h b/src/models/threadpool.h
new file mode 100644
index 000000000..81fe98ae7
--- /dev/null
+++ b/src/models/threadpool.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
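+
+// Minimal fork-join helper: Compute(func) spawns one std::thread per index in
+// [0, num_threads) and joins them all before returning. Instances are meant to
+// be short-lived (threads_ is never cleared between calls), so callers
+// construct a fresh ThreadPool for each parallel region, e.g. one per KV cache
+// slide or update.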
+
+#include <functional>
+#include <thread>
+#include <vector>
+
+namespace Generators {
+
+struct ThreadPool {
+  ThreadPool(size_t num_threads);
+
+  void Compute(const std::function<void(size_t)>& func);
+
+ private:
+  size_t num_threads_;
+  std::vector<std::thread> threads_;
+};
+
+} // namespace Generators

From ad737df7649ec38f6c74f106150b5ba76ec07eb4 Mon Sep 17 00:00:00 2001
From: Baiju Meswani
Date: Mon, 2 Dec 2024 14:48:48 -0800
Subject: [PATCH 07/18] Support num tokens > 128

---
 src/generators.cpp                   | 14 ++-
 src/generators.h                     |  4 +-
 src/models/debugging.cpp             | 16 +++-
 src/models/decoder_only_pipeline.cpp | 41 ++++++---
 src/models/decoder_only_pipeline.h   |  7 +-
 src/models/input_ids.cpp             | 67 ++++++++++++++
 src/models/input_ids.h               | 40 ++++++++-
 src/models/kv_cache.cpp              | 13 +--
 src/models/kv_cache.h                |  9 +-
 src/models/position_inputs.cpp       | 128 ++++++++++++++++++++++++++-
 src/models/position_inputs.h         | 55 ++++++++++--
 src/ort_genai.h                      |  4 +-
 src/ort_genai_c.cpp                  |  6 +-
 src/ort_genai_c.h                    |  2 +-
 14 files changed, 356 insertions(+), 50 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index 09da6d187..0f4ad776e 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -275,15 +275,21 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
   }
 }
 
-DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(const cpu_span<const int32_t> input_ids) {
-  auto input_ids_device = state_->params_->p_device->Allocate<int32_t>(input_ids.size());
+DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(cpu_span<const int32_t> input_ids) {
+  size_t input_ids_size = input_ids.size();
+  if (model_->config_->model.decoder.sliding_window_key_value_cache.has_value()) {
+    const auto window_size = model_->config_->model.decoder.sliding_window_key_value_cache->window_size;
+    input_ids_size = ((input_ids.size() + window_size - 1) / window_size) * window_size;
+  }
+  auto input_ids_device = state_->params_->p_device->Allocate<int32_t>(input_ids_size);
   auto cpu_span = input_ids_device.CpuSpan();
-  std::copy(input_ids.begin(), input_ids.end(), cpu_span.begin());
+  std::fill_n(cpu_span.begin(), input_ids_size, model_->config_->model.pad_token_id);
+  std::copy_backward(input_ids.begin(), input_ids.end(), cpu_span.end());
   input_ids_device.CopyCpuToDevice();
   return input_ids_device;
 }
 
-void Generator::AppendTokens(const cpu_span<const int32_t> input_ids) {
+void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
   ThrowErrorIfSessionTerminated(state_->session_terminated_);
   if (input_ids.size() == 0)
     throw std::runtime_error("input_ids is empty");
diff --git a/src/generators.h b/src/generators.h
index 2ccc1f0c0..95739bb4e 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -112,7 +112,7 @@ struct Generator : LeakChecked<Generator> {
   Generator(const Model& model, const GeneratorParams& params);
 
   bool IsDone() const;
-  void AppendTokens(const cpu_span<const int32_t> input_ids);
+  void AppendTokens(cpu_span<const int32_t> input_ids);
   void GenerateNextToken();
   void RewindToLength(size_t new_length);  // Rewind state to new_length
   DeviceSpan<float> GetLogits();
@@ -128,7 +128,7 @@ struct Generator : LeakChecked<Generator> {
   bool computed_logits_{};  // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
 
 private:
-  DeviceSpan<int32_t> AllocateInputIdsOnDevice(const cpu_span<const int32_t> input_ids);
+  DeviceSpan<int32_t> AllocateInputIdsOnDevice(cpu_span<const int32_t> input_ids);
   void ComputeLogits(DeviceSpan<int32_t> next_tokens);
   enum Action { standard,   // Default, set in any other case
                 generated,  // Set after GenerateNextToken
diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp
index 597ded8de..00420d613 100644
--- a/src/models/debugging.cpp
+++ b/src/models/debugging.cpp
@@ -43,11 +43,19 @@ void DumpSpan(std::ostream& stream, std::span<T> values) {
     for (auto v : values)
       stream << v << ' ';
   } else {
-    for (size_t i = 0; i < c_value_count / 2; i++)
-      stream << values[i] << ' ';
+    for (size_t i = 0; i < c_value_count / 2; i++) {
+      if constexpr (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value)
+        stream << static_cast<int>(values[i]) << ' ';
+      else
+        stream << values[i] << ' ';
+    }
     stream << "... ";
-    for (size_t i = values.size() - c_value_count / 2; i < values.size(); i++)
-      stream << values[i] << ' ';
+    for (size_t i = values.size() - c_value_count / 2; i < values.size(); i++) {
+      if constexpr (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value)
+        stream << static_cast<int>(values[i]) << ' ';
+      else
+        stream << values[i] << ' ';
+    }
   }
 }
 
diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp
index b5fa8b2fe..21ca2bb24 100644
--- a/src/models/decoder_only_pipeline.cpp
+++ b/src/models/decoder_only_pipeline.cpp
@@ -91,9 +91,10 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
                                                    const GeneratorParams& params)
     : State{params, model},
       model_{model},
-      position_inputs_{model, *this, sequence_lengths} {
-  input_ids_.Add();
-  position_inputs_.Add();
+      input_ids_{CreateInputIDs(*this)},
+      position_inputs_{CreatePositionInputs(*this, sequence_lengths)} {
+  input_ids_->Add();
+  position_inputs_->Add();
   logits_.Add();
   if (KV_Cache::IsCacheNeeded(model)) {
     if (model.config_->model.decoder.sliding_window_key_value_cache.has_value()) {
@@ -111,10 +112,8 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
   }
 }
 
-DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
-                                                DeviceSpan<int32_t> next_indices) {
-  UpdateInputsOutputs(next_tokens, next_indices, total_length);
-
+void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>& next_tokens,
+                                           DeviceSpan<int32_t> next_indices) {
   for (auto& pipeline_state : pipeline_states_) {
     if (first_run_ && !model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_prompt) {
      continue;
@@ -223,6 +222,28 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<in
+}
+
+DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
+                                                DeviceSpan<int32_t> next_indices) {
+  UpdateInputsOutputs(next_tokens, next_indices, total_length);
+
+  size_t num_chunks{1};
+  if (first_run_ && sliding_window_key_value_cache_) {
+    int window_size = model_.config_->model.decoder.sliding_window_key_value_cache->window_size;
+    num_chunks = (next_tokens.size() + window_size - 1) / window_size;
+  }
+
+  for (size_t i = 0; i < num_chunks; ++i) {
+    RunPipeline(total_length, next_tokens, next_indices);
+
+    if (sliding_window_key_value_cache_ && i < num_chunks - 1) {
+      sliding_window_key_value_cache_->Slide();
+      input_ids_->Update(next_tokens);
+      size_t new_length = input_ids_->GetShape()[1];
+      position_inputs_->Update(next_tokens, total_length, static_cast<int>(new_length));
+    }
+  }
 
   // Clear the outputs of the pipeline models that are only run on prompt since this cannot happen earlier.
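+  // (When a sliding-window cache is in use, a single Run call may invoke
+  // RunPipeline once per prompt chunk, as the num_chunks loop above shows,
+  // so this cleanup can only happen once the chunk loop has finished.)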
   if (!first_run_) {
@@ -244,9 +265,9 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tok
 
 void DecoderOnlyPipelineState::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> beam_indices,
                                                    int total_length) {
-  input_ids_.Update(next_tokens);
-  size_t new_length = input_ids_.GetShape()[1];
-  position_inputs_.Update(next_tokens, total_length, static_cast<int>(new_length));
+  input_ids_->Update(next_tokens);
+  size_t new_length = input_ids_->GetShape()[1];
+  position_inputs_->Update(next_tokens, total_length, static_cast<int>(new_length));
   if (kv_cache_) kv_cache_->Update(beam_indices, total_length);
   if (sliding_window_key_value_cache_) sliding_window_key_value_cache_->Update(beam_indices, total_length);
   logits_.Update(next_tokens, new_length);
diff --git a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h
index 3cbb81ab8..0a7125c80 100644
--- a/src/models/decoder_only_pipeline.h
+++ b/src/models/decoder_only_pipeline.h
@@ -58,6 +58,9 @@ struct DecoderOnlyPipelineState : State {
 
   OrtValue* GetOutput(const char* name) override;
 
+  void RunPipeline(int total_length, DeviceSpan<int32_t>& next_tokens,
+                   DeviceSpan<int32_t> next_indices);
+
  private:
   void UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices, int total_length);
 
@@ -68,11 +71,11 @@ struct DecoderOnlyPipelineState : State {
   // Stores all the outputs from the previous pipeline state(s)
   std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_;
 
-  InputIDs input_ids_{*this};
+  std::unique_ptr<InputIDsInterface> input_ids_;
   Logits logits_{*this};
   std::unique_ptr<KV_Cache> kv_cache_;
   std::unique_ptr<SlidingWindowKeyValueCache> sliding_window_key_value_cache_;
-  PositionInputs position_inputs_;
+  std::unique_ptr<PositionInputsInterface> position_inputs_;
   ExtraInputs extra_inputs_{*this};
 };
 
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 71f051bfc..6c9521025 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -191,4 +191,71 @@ void InputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
   is_prompt_ = false;
 }
 
+SlidingWindowInputIDs::SlidingWindowInputIDs(State& state) : state_{state} {
+  name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
+
+  if (!model_.config_->model.decoder.sliding_window_key_value_cache.has_value()) {
+    throw std::runtime_error("Sliding a window over input_ids requires sliding_window_key_value_cache to be set in the config.");
+  }
+
+  if (state_.params_->BatchBeamSize() != 1) {
+    throw std::runtime_error("Batch beam size must be 1 for sliding a window over input_ids.");
+  }
+
+  window_size_ = model_.config_->model.decoder.sliding_window_key_value_cache->window_size;
+  shape_ = {1, model_.config_->model.decoder.sliding_window_key_value_cache->window_size};
+  type_ = model_.session_info_->GetInputDataType(name_);
+
+  if (type_ != Ort::TypeToTensorType<int32_t>) {
+    throw std::runtime_error("SlidingWindowInputIDs only supports int32_t input_ids.");
+  }
+}
+
+void SlidingWindowInputIDs::Add() {
+  input_index_ = state_.inputs_.size();
+
+  state_.inputs_.push_back(value_.get());
+  state_.input_names_.push_back(name_);
+}
+
+void SlidingWindowInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
+  if (window_index_ == 0) {
+    num_windows_ = (new_tokens.size() + window_size_ - 1) / window_size_;
+
+    value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+
+    // next_tokens will always be padded so that its size is a multiple of window_size_
+    // next_tokens -> [0, a, b, c, d, e]
+    // window_size = 3, num_windows = 2
+    // window_index = 0, value_ -> [0, a, b]
+    std::copy_n(new_tokens.Span().begin(), window_size_, value_->GetTensorMutableData<int32_t>());
+  } else if (window_index_ < num_windows_) {
+    // next_tokens -> [a, b, c, d,
e] + // window_size = 3, num_windows = 2 + // window_index = 1, value_ -> [c, d, e] + std::copy_n(new_tokens.Span().begin() + window_index_ * window_size_, window_size_, value_->GetTensorMutableData()); + } else { + // All prompt token windows have been processed. Now we process the tokens generated by the model. + // next_tokens -> [f] + assert(new_tokens.size() == 1); + if (shape_[1] != 1) { + shape_[1] = 1; + value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + } + + value_->GetTensorMutableData()[0] = new_tokens.Span()[0]; + } + + state_.inputs_[input_index_] = value_.get(); + window_index_++; +} + +std::unique_ptr CreateInputIDs(State& state) { + if (state.model_.config_->model.decoder.sliding_window_key_value_cache.has_value()) { + return std::make_unique(state); + } else { + return std::make_unique(state); + } +} + } // namespace Generators diff --git a/src/models/input_ids.h b/src/models/input_ids.h index dd364212f..57f91438d 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -4,19 +4,26 @@ namespace Generators { -struct InputIDs { +struct InputIDsInterface { + virtual ~InputIDsInterface() = default; + virtual void Add() = 0; + virtual std::array GetShape() const = 0; + virtual void Update(DeviceSpan& next_tokens) = 0; +}; + +struct InputIDs : InputIDsInterface { InputIDs(State& state); InputIDs(const InputIDs&) = delete; InputIDs& operator=(const InputIDs&) = delete; // Register input_ids as ORT session input. // Called only once during initialization of state. - void Add(); + void Add() override; // Resize input_ids based on size of next_tokens. // Update value with next_tokens. - void Update(DeviceSpan& next_tokens); + void Update(DeviceSpan& next_tokens) override; - auto& GetShape() const { return shape_; } + std::array GetShape() const override { return shape_; } const char* name_; OrtValue* Get() { return value_.get(); } @@ -45,4 +52,29 @@ struct InputIDs { std::unique_ptr past_sequence_length_; }; +struct SlidingWindowInputIDs : public InputIDsInterface { + SlidingWindowInputIDs(State& state); + SlidingWindowInputIDs(const SlidingWindowInputIDs&) = delete; + SlidingWindowInputIDs& operator=(const SlidingWindowInputIDs&) = delete; + + void Add() override; + void Update(DeviceSpan& next_tokens) override; + std::array GetShape() const override { return shape_; } + + private: + State& state_; + const Model& model_{state_.model_}; + size_t input_index_{~0U}; + size_t window_size_{0}; + size_t num_windows_{1}; + size_t window_index_{0}; + const char* name_; + std::array shape_{}; + ONNXTensorElementDataType type_; + + std::unique_ptr value_; +}; + +std::unique_ptr CreateInputIDs(State& state); + } // namespace Generators diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index c00acc398..13fc6e0c2 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -557,7 +557,10 @@ void SlidingWindowKeyValueCache::Slide() { } void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int current_length) { - if (window_size_ == 1) { + if (is_first_update_) { + is_first_update_ = false; + return; + } else if (window_size_ == 1) { Slide(); return; } @@ -628,17 +631,17 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int cu { cpu_span value_cache_dst(value_cache_data + (j * updated_value_cache_shape_in[2] * updated_value_cache_shape_in[3]), (value_cache_shape_in_[2] - updated_window_size) * updated_value_cache_shape_in[3]); - cpu_span value_cache_src(value_cache_in_data + (j * value_cache_shape_out_[2] * 
value_cache_shape_out_[3]) + - (updated_window_size * value_cache_shape_out_[3]), + cpu_span value_cache_src(value_cache_in_data + (j * value_cache_shape_in_[2] * value_cache_shape_in_[3]) + + (updated_window_size * value_cache_shape_in_[3]), (value_cache_shape_in_[2] - updated_window_size) * value_cache_shape_in_[3]); std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); } { cpu_span value_cache_dst(value_cache_data + (j * updated_value_cache_shape_in[2] * updated_value_cache_shape_in[3]) + ((value_cache_shape_in_[2] - updated_window_size) * updated_value_cache_shape_in[3]), - window_size_ * value_cache_shape_out_[3]); + value_cache_shape_out_[2] * value_cache_shape_out_[3]); cpu_span value_cache_src(value_cache_out_data + (j * value_cache_shape_out_[2] * value_cache_shape_out_[3]), - window_size_ * value_cache_shape_out_[3]); + value_cache_shape_out_[2] * value_cache_shape_out_[3]); std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); } } diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index bb48c99fd..1fa271cb1 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -95,14 +95,13 @@ struct SlidingWindowKeyValueCache { void Add(); void Update(DeviceSpan beam_indices, int current_length); - - private: void Slide(); + private: State& state_; const Model& model_{state_.model_}; - int layer_count_; - int window_size_; + int layer_count_{0}; + int window_size_{0}; size_t input_index_{~0U}, output_index_{~0U}; std::array key_cache_shape_in_, key_cache_shape_out_; @@ -112,5 +111,7 @@ struct SlidingWindowKeyValueCache { std::vector> key_caches_in_, value_caches_in_; std::vector> key_caches_out_, value_caches_out_; std::vector input_name_strings_, output_name_strings_; + + bool is_first_update_{true}; }; } // namespace Generators diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index dd3c77011..dae0a393a 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -68,7 +68,7 @@ void PositionInputs::Add() { } } -void PositionInputs::Update(const DeviceSpan& next_tokens, int total_length, int new_length) { +void PositionInputs::Update(DeviceSpan next_tokens, int total_length, int new_length) { if (has_posid_input_) { // Initialize on first update if (is_first_update_) { @@ -295,7 +295,7 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { } template -void PositionInputs::CreateAndInitializePositionIDs(const DeviceSpan& next_tokens, std::array shape) { +void PositionInputs::CreateAndInitializePositionIDs(DeviceSpan next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); @@ -325,7 +325,7 @@ void PositionInputs::CreateAndInitializePositionIDs(const DeviceSpan& n } template -void PositionInputs::CreateAndInitializeAttentionMask(const DeviceSpan& next_tokens, std::array shape) { +void PositionInputs::CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. 
// Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); @@ -459,4 +459,126 @@ void PositionInputs::RewindMask(size_t index) { } #endif +SlidingWindowPositionInputs::SlidingWindowPositionInputs(State& state) + : state_{state} { + has_posid_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.position_ids); + has_mask_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.attention_mask); + + if (!model_.config_->model.decoder.sliding_window_key_value_cache.has_value()) { + throw std::runtime_error("Sliding a window over input_ids requires sliding_window_key_value_cache to be set in the config."); + } + + window_size_ = model_.config_->model.decoder.sliding_window_key_value_cache->window_size; + + if (has_posid_input_) { + position_ids_type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.position_ids); + if (position_ids_type_ != Ort::TypeToTensorType) + throw std::runtime_error("SlidingWindowPositionInputs only supports int64_t position_ids"); + + position_ids_shape_ = {1, model_.config_->model.decoder.sliding_window_key_value_cache->window_size}; + } + + if (has_mask_input_) { + auto attention_mask_type = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.attention_mask); + if (attention_mask_type != Ort::TypeToTensorType) + throw std::runtime_error("SlidingWindowPositionInputs only supports float attention_mask"); + + attention_mask_shape_ = {1, model_.config_->model.context_length}; + } +} + +void SlidingWindowPositionInputs::Add() { + if (has_posid_input_) { + position_ids_index_ = state_.inputs_.size(); + state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str()); + state_.inputs_.push_back(position_ids_.get()); + } + + if (has_mask_input_) { + attention_mask_index_ = state_.inputs_.size(); + state_.input_names_.push_back(model_.config_->model.decoder.inputs.attention_mask.c_str()); + state_.inputs_.push_back(attention_mask_.get()); + } +} + +void SlidingWindowPositionInputs::Update(DeviceSpan next_tokens, int total_length, int new_length) { + if (window_index_ == 0) { + num_windows_ = (next_tokens.size() + window_size_ - 1) / window_size_; + if (has_posid_input_) { + position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, position_ids_type_); + + auto* position_ids_data = position_ids_->GetTensorMutableData(); + for (int i = 0, j = 0; i < position_ids_shape_[1]; i++) { + if (next_tokens.CpuSpan()[i] == model_.config_->model.pad_token_id) { + position_ids_data[i] = 0; + } else { + position_ids_data[i] = j++; + } + } + } + + if (has_mask_input_) { + attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, Ort::TypeToTensorType); + auto* attention_mask_data = attention_mask_->GetTensorMutableData(); + std::fill(attention_mask_data, attention_mask_data + attention_mask_shape_[1], 0.0f); + for (size_t i = 0; i < window_size_; i++) { + attention_mask_data[attention_mask_shape_[1] - window_size_ + i] = next_tokens.CpuSpan()[i] == model_.config_->model.pad_token_id ? 
0.0f : 1.0f; + } + for (size_t i = 0; i < window_size_; i++) { + if (attention_mask_data[attention_mask_shape_[1] - window_size_ + i] == 1.0f) { + attention_mask_backward_offset_ = attention_mask_shape_[1] - window_size_ + i - 1; + break; + } + } + } + } else if (window_index_ < num_windows_) { + if (has_posid_input_) { + auto* position_ids_data = position_ids_->GetTensorMutableData(); + const auto last_position = position_ids_data[window_size_ - 1]; + std::iota(position_ids_data, position_ids_data + window_size_, last_position + 1); + } + + if (has_mask_input_) { + auto* attention_mask_data = attention_mask_->GetTensorMutableData(); + std::fill_n(attention_mask_data + attention_mask_backward_offset_ - window_size_ + 1, window_size_, 1.0f); + + attention_mask_backward_offset_ -= window_size_; + } + } else { + if (has_posid_input_) { + const auto last_position = position_ids_->GetTensorData()[position_ids_shape_[1] - 1]; + if (position_ids_shape_[1] != 1) { + position_ids_shape_[1] = 1; + position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, position_ids_type_); + } + position_ids_->GetTensorMutableData()[0] = last_position + 1; + } + + if (has_mask_input_) { + attention_mask_->GetTensorMutableData()[attention_mask_backward_offset_] = 1.0f; + if (attention_mask_backward_offset_ > 0) { + attention_mask_backward_offset_ -= 1; + } + } + } + + if (has_posid_input_) { + state_.inputs_[position_ids_index_] = position_ids_.get(); + } + + if (has_mask_input_) { + state_.inputs_[attention_mask_index_] = attention_mask_.get(); + } + + window_index_++; +} + +std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths) { + if (state.model_.config_->model.decoder.sliding_window_key_value_cache.has_value()) { + return std::make_unique(state); + } else { + return std::make_unique(state.model_, state, sequence_lengths); + } +} + } // namespace Generators diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index 64a38779c..fa3f2b176 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -9,13 +9,20 @@ namespace Generators { -struct PositionInputs { +struct PositionInputsInterface { + virtual ~PositionInputsInterface() = default; + virtual void Add() = 0; + virtual void Update(DeviceSpan next_tokens, int total_length, int new_length) = 0; + virtual void RewindTo(size_t index) = 0; +}; + +struct PositionInputs : PositionInputsInterface { PositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk); - void Add(); - void Update(const DeviceSpan& next_tokens, int total_length, int new_length); + void Add() override; + void Update(DeviceSpan next_tokens, int total_length, int new_length) override; - void RewindTo(size_t index); + void RewindTo(size_t index) override; private: void AddAttentionMask(); @@ -30,9 +37,9 @@ struct PositionInputs { template void InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk); template - void CreateAndInitializePositionIDs(const DeviceSpan& next_tokens, std::array shape); + void CreateAndInitializePositionIDs(DeviceSpan next_tokens, std::array shape); template - void CreateAndInitializeAttentionMask(const DeviceSpan& next_tokens, std::array shape); + void CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape); template void UpdatePositionIDsImpl(int total_length, int new_kv_length); @@ -86,4 +93,40 @@ struct PositionInputs { #endif }; +struct SlidingWindowPositionInputs : PositionInputsInterface { + 
SlidingWindowPositionInputs(State& state); + SlidingWindowPositionInputs(const SlidingWindowPositionInputs&) = delete; + SlidingWindowPositionInputs& operator=(const SlidingWindowPositionInputs&) = delete; + + void Add() override; + void Update(DeviceSpan next_tokens, int total_length, int new_length) override; + void RewindTo(size_t index) override { + throw std::runtime_error("SlidingWindowPositionInputs does not support RewindTo."); + }; + + private: + State& state_; + const Model& model_{state_.model_}; + + bool has_mask_input_{}; + bool has_posid_input_{}; + + std::array position_ids_shape_{}; + ONNXTensorElementDataType position_ids_type_{}; + std::unique_ptr position_ids_; + std::array attention_mask_shape_{}; + ONNXTensorElementDataType attention_mask_type_{}; + std::unique_ptr attention_mask_; + size_t attention_mask_backward_offset_{~0U}; + + size_t attention_mask_index_{~0U}; + size_t position_ids_index_{~0U}; + + size_t window_size_{0}; + size_t num_windows_{1}; + size_t window_index_{0}; +}; + +std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths); + } // namespace Generators diff --git a/src/ort_genai.h b/src/ort_genai.h index 3ae36f7b0..bad2159f4 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -158,7 +158,7 @@ struct OgaSequences : OgaAbstract { std::span Get(size_t index) const { return {SequenceData(index), SequenceCount(index)}; } - void Append(const std::span& sequence) { + void Append(std::span sequence) { OgaCheckResult(OgaAppendTokenSequence(sequence.data(), sequence.size(), this)); } void Append(const std::vector& sequence) { @@ -275,7 +275,7 @@ struct OgaGenerator : OgaAbstract { OgaCheckResult(OgaGenerator_AppendTokenSequences(this, &sequences)); } - void AppendTokens(int32_t* input_ids, size_t input_ids_count) { + void AppendTokens(const int32_t* input_ids, size_t input_ids_count) { OgaCheckResult(OgaGenerator_AppendTokens(this, input_ids, input_ids_count)); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index af2ee4b11..609c20f60 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -292,15 +292,15 @@ OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerator* oga_gene } auto input_ids = Generators::PadInputs(span_sequences, generator.model_->config_->model.pad_token_id); - generator.AppendTokens(input_ids); + generator.AppendTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); return nullptr; OGA_CATCH } -OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count) { +OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count) { OGA_TRY auto& generator = *reinterpret_cast(oga_generator); - generator.AppendTokens(Generators::cpu_span(input_ids, input_ids_count)); + generator.AppendTokens(Generators::cpu_span(input_ids, input_ids_count)); return nullptr; OGA_CATCH } diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 1d38eb689..faea86f97 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -313,7 +313,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerato * \param[in] input_ids_count The number of input ids to add (batch_size * sequence_length). * \return OgaResult containing the error message if the setting of the input ids failed. 
*/ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count); /* * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. From 5ef96580624c5c5d4698a3186027be81fe83d1c3 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 10 Dec 2024 13:11:06 -0800 Subject: [PATCH 08/18] Documentation and create kv cache interface --- src/config.cpp | 14 ++--- src/config.h | 4 +- src/generators.cpp | 6 ++- src/models/decoder_only_pipeline.cpp | 35 +++++++------ src/models/decoder_only_pipeline.h | 3 +- src/models/input_ids.cpp | 20 ++++---- src/models/kv_cache.cpp | 28 +++++++--- src/models/kv_cache.h | 45 +++++++++++----- src/models/position_inputs.cpp | 77 ++++++++++++++++++---------- 9 files changed, 148 insertions(+), 84 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 564da98cd..a66f96ce3 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -286,8 +286,8 @@ struct Pipeline_Element : JSON::Element { PipelineModelObject_Element object_{v_}; }; -struct SlidingWindowKeyValueCache_Element : JSON::Element { - explicit SlidingWindowKeyValueCache_Element(std::optional& v) : v_{v} {} +struct SlidingWindow_Element : JSON::Element { + explicit SlidingWindow_Element(std::optional& v) : v_{v} {} void OnNumber(std::string_view name, double value) override { if (name == "window_size") { @@ -299,7 +299,7 @@ struct SlidingWindowKeyValueCache_Element : JSON::Element { } private: - std::optional& v_; + std::optional& v_; }; struct Decoder_Element : JSON::Element { @@ -337,9 +337,9 @@ struct Decoder_Element : JSON::Element { if (name == "outputs") { return outputs_; } - if (name == "sliding_window_key_value_cache") { - v_.sliding_window_key_value_cache = Config::Model::Decoder::SlidingWindowKeyValueCache{}; - return sliding_window_key_value_cache_; + if (name == "sliding_window") { + v_.sliding_window = Config::Model::Decoder::SlidingWindow{}; + return sliding_window_; } throw JSON::unknown_value_error{}; } @@ -356,7 +356,7 @@ struct Decoder_Element : JSON::Element { Inputs_Element inputs_{v_.inputs}; Outputs_Element outputs_{v_.outputs}; Pipeline_Element pipeline_{v_.pipeline}; - SlidingWindowKeyValueCache_Element sliding_window_key_value_cache_{v_.sliding_window_key_value_cache}; + SlidingWindow_Element sliding_window_{v_.sliding_window}; }; struct VisionInputs_Element : JSON::Element { diff --git a/src/config.h b/src/config.h index 931e3f6bd..1d0a2907b 100644 --- a/src/config.h +++ b/src/config.h @@ -107,11 +107,11 @@ struct Config { int num_hidden_layers{}; int head_size{}; - struct SlidingWindowKeyValueCache { + struct SlidingWindow { int window_size{128}; int pad_value{}; }; - std::optional sliding_window_key_value_cache; + std::optional sliding_window; struct Inputs { std::string input_ids{Defaults::InputIdsName}; diff --git a/src/generators.cpp b/src/generators.cpp index 0f4ad776e..02b79fc73 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -277,8 +277,10 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ DeviceSpan Generator::AllocateInputIdsOnDevice(cpu_span input_ids) { size_t input_ids_size = input_ids.size(); - if (model_->config_->model.decoder.sliding_window_key_value_cache.has_value()) { - const auto window_size = 
model_->config_->model.decoder.sliding_window_key_value_cache->window_size; + if (model_->config_->model.decoder.sliding_window.has_value()) { + // If the model has a sliding window, pad the input_ids to the next multiple of the window size + // so that the input_ids can be divided into window-size chunks. + const auto window_size = model_->config_->model.decoder.sliding_window->window_size; input_ids_size = ((input_ids.size() + window_size - 1) / window_size) * window_size; } auto input_ids_device = state_->params_->p_device->Allocate(input_ids_size); diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 21ca2bb24..410a3b875 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -92,19 +92,23 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode : State{params, model}, model_{model}, input_ids_{CreateInputIDs(*this)}, + key_value_cache_{CreateKeyValueCache(*this)}, position_inputs_{CreatePositionInputs(*this, sequence_lengths)} { input_ids_->Add(); position_inputs_->Add(); logits_.Add(); - if (KV_Cache::IsCacheNeeded(model)) { - if (model.config_->model.decoder.sliding_window_key_value_cache.has_value()) { - sliding_window_key_value_cache_ = std::make_unique(*this); - sliding_window_key_value_cache_->Add(); - } else { - kv_cache_ = std::make_unique(*this); - kv_cache_->Add(); - } + if (key_value_cache_) { + key_value_cache_->Add(); } + // if (KV_Cache::IsCacheNeeded(model)) { + // if (model.config_->model.decoder.sliding_window.has_value()) { + // sliding_window_key_value_cache_ = std::make_unique(*this); + // sliding_window_key_value_cache_->Add(); + // } else { + // kv_cache_ = std::make_unique(*this); + // kv_cache_->Add(); + // } + // } extra_inputs_.Add(); for ([[maybe_unused]] const auto& pipeline_model : model_.config_->model.decoder.pipeline) { @@ -229,19 +233,19 @@ DeviceSpan DecoderOnlyPipelineState::Run(int total_length, DeviceSpanmodel.decoder.sliding_window_key_value_cache->window_size; + if (first_run_ && model_.config_->model.decoder.sliding_window.has_value()) { int window_size = model_.config_->model.decoder.sliding_window->window_size; num_chunks = (next_tokens.size() + window_size - 1) / window_size; } for (size_t i = 0; i < num_chunks; ++i) { RunPipeline(total_length, next_tokens, next_indices); - if (sliding_window_key_value_cache_ && i < num_chunks - 1) { - sliding_window_key_value_cache_->Slide(); + if (model_.config_->model.decoder.sliding_window.has_value() && i < num_chunks - 1) { + // Slide the window over the input_ids, key_cache, value_cache, position_ids, and attention_mask. input_ids_->Update(next_tokens); - size_t new_length = input_ids_->GetShape()[1]; - position_inputs_->Update(next_tokens, total_length, static_cast(new_length)); + key_value_cache_->Update(next_indices, total_length); + position_inputs_->Update(next_tokens, total_length, static_cast(input_ids_->GetShape()[1])); } } @@ -268,8 +272,7 @@ void DecoderOnlyPipelineState::UpdateInputsOutputs(DeviceSpan& next_tok input_ids_->Update(next_tokens); size_t new_length = input_ids_->GetShape()[1]; position_inputs_->Update(next_tokens, total_length, static_cast(new_length)); - if (kv_cache_) kv_cache_->Update(beam_indices, total_length); - if (sliding_window_key_value_cache_) sliding_window_key_value_cache_->Update(beam_indices, total_length); + if (key_value_cache_) key_value_cache_->Update(beam_indices, total_length); logits_.Update(next_tokens, new_length); } diff --git
a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h index 0a7125c80..726828587 100644 --- a/src/models/decoder_only_pipeline.h +++ b/src/models/decoder_only_pipeline.h @@ -73,8 +73,7 @@ struct DecoderOnlyPipelineState : State { std::unique_ptr input_ids_; Logits logits_{*this}; - std::unique_ptr kv_cache_; - std::unique_ptr sliding_window_key_value_cache_; + std::unique_ptr key_value_cache_; std::unique_ptr position_inputs_; ExtraInputs extra_inputs_{*this}; }; diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 6c9521025..2f2975d7d 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -194,16 +194,16 @@ void InputIDs::Update(DeviceSpan& new_tokens) { SlidingWindowInputIDs::SlidingWindowInputIDs(State& state) : state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); - if (!model_.config_->model.decoder.sliding_window_key_value_cache.has_value()) { - throw std::runtime_error("Sliding a window over input_ids requires sliding_window_key_value_cache to be set in the config."); + if (!model_.config_->model.decoder.sliding_window.has_value()) { + throw std::runtime_error("Sliding a window over input_ids requires sliding_window to be set in the genai_config.json."); } if (state_.params_->BatchBeamSize() != 1) { throw std::runtime_error("Batch beam size must be 1 for sliding a window over input_ids."); } - window_size_ = model_.config_->model.decoder.sliding_window_key_value_cache->window_size; - shape_ = {1, model_.config_->model.decoder.sliding_window_key_value_cache->window_size}; + window_size_ = model_.config_->model.decoder.sliding_window->window_size; + shape_ = {1, model_.config_->model.decoder.sliding_window->window_size}; type_ = model_.session_info_->GetInputDataType(name_); if (type_ != Ort::TypeToTensorType) { @@ -222,11 +222,11 @@ void SlidingWindowInputIDs::Update(DeviceSpan& new_tokens) { if (window_index_ == 0) { num_windows_ = (new_tokens.size() + window_size_ - 1) / window_size_; - value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); // next_tokens will always be padded so that its size is a multiple of window_size_ // next_tokens -> [0, a, b, c, d, e] - // window_size = 3, num_windows = 2 + // window_size = 3, num_windows = 2, pad_token = 0 // window_index = 0, value_ -> [0, a, b] std::copy_n(new_tokens.Span().begin(), window_size_, value_->GetTensorMutableData()); } else if (window_index_ < num_windows_) { // next_tokens -> [0, a, b, c, d, e] // window_index = 1, value_ -> [c, d, e] std::copy_n(new_tokens.Span().begin() + window_index_ * window_size_, window_size_, value_->GetTensorMutableData()); } else { - // All prompt token windows have been processed. Now we process the tokens generated by the model. + // All prompt token chunks have been processed. Now we process the tokens generated by the model.
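// (Aside, not part of the patch: the chunk arithmetic above as a standalone
// sketch with illustrative names, assuming the prompt has already been padded
// to a multiple of the window size:
//
//   size_t num_windows = (padded_prompt.size() + window_size - 1) / window_size;
//   for (size_t w = 0; w < num_windows; ++w)
//     std::copy_n(padded_prompt.begin() + w * window_size, window_size, window.begin());
//
// For [0, a, b, c, d, e] with window_size = 3 this feeds [0, a, b] and then
// [c, d, e]; after that, every Update receives a single generated token.)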
// next_tokens -> [f] assert(new_tokens.size() == 1); if (shape_[1] != 1) { shape_[1] = 1; - value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); } - value_->GetTensorMutableData()[0] = new_tokens.Span()[0]; + value_->GetTensorMutableData()[0] = new_tokens.Span().front(); } state_.inputs_[input_index_] = value_.get(); @@ -251,7 +251,7 @@ void SlidingWindowInputIDs::Update(DeviceSpan& new_tokens) { } std::unique_ptr CreateInputIDs(State& state) { - if (state.model_.config_->model.decoder.sliding_window_key_value_cache.has_value()) { + if (state.model_.config_->model.decoder.sliding_window.has_value()) { return std::make_unique(state); } else { return std::make_unique(state); diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 13fc6e0c2..009b7b1b6 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -188,7 +188,7 @@ void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices, int inde } } -bool KV_Cache::IsCacheNeeded(const Model& model) { +bool KeyValueCacheInterface::IsCacheNeeded(const Model& model) { return model.session_info_->HasInput(ComposeKeyValueName(model.config_->model.decoder.inputs.past_key_names, 0)); } @@ -446,7 +446,7 @@ void Cross_Cache::AddInputs() { SlidingWindowKeyValueCache::SlidingWindowKeyValueCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, - window_size_{model_.config_->model.decoder.sliding_window_key_value_cache->window_size}, + window_size_{model_.config_->model.decoder.sliding_window->window_size}, key_cache_shape_in_{model_.config_->model.decoder.num_key_value_heads, 1, model_.config_->model.decoder.head_size, model_.config_->model.context_length - window_size_}, key_cache_shape_out_{model_.config_->model.decoder.num_key_value_heads, 1, @@ -474,13 +474,13 @@ SlidingWindowKeyValueCache::SlidingWindowKeyValueCache(State& state) OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_in_, type_)); std::fill_n(key_caches_in_[i]->GetTensorMutableData(), ElementCountFromShape(key_cache_shape_in_), - static_cast(model_.config_->model.decoder.sliding_window_key_value_cache->pad_value)); + static_cast(model_.config_->model.decoder.sliding_window->pad_value)); value_caches_in_.push_back( OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_in_, type_)); std::fill_n(value_caches_in_[i]->GetTensorMutableData(), ElementCountFromShape(value_cache_shape_in_), - static_cast(model_.config_->model.decoder.sliding_window_key_value_cache->pad_value)); + static_cast(model_.config_->model.decoder.sliding_window->pad_value)); key_caches_out_.push_back( OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_out_, type_)); @@ -558,14 +558,18 @@ void SlidingWindowKeyValueCache::Slide() { void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int current_length) { if (is_first_update_) { + num_windows_ = (current_length + window_size_ - 1) / window_size_; is_first_update_ = false; + window_index_++; return; - } else if (window_size_ == 1) { + } else if (window_size_ == 1 || window_index_ < num_windows_) { Slide(); + window_index_++; return; } - // No sliding needed. But we need to concatenate the last window_size_ elements to the end of the cache + // Transition from prompt processing to token generation. 
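// (Aside, not part of the patch: on one contiguous cache row, the concatenation
// described next amounts to the following, with illustrative names:
//
//   std::vector<uint8_t> next(row_in.size() + window_size - 1);
//   std::copy(row_in.begin() + 1, row_in.end(), next.begin());            // drop the oldest slot
//   std::copy(row_out.begin(), row_out.end(), next.end() - window_size);  // append the last window
//
// so a row of length context_length - window_size grows to context_length - 1.)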
+ // Concatenate the last window_size_ elements to the end of the cache // key_caches_in_ = Concat(key_caches_in_[:, :, :, 1:], key_caches_out_) // [num_key_value_heads, 1, head_size, context_length-1] = [num_key_value_heads, 1, head_size, context_length - window_size_ - 1] + @@ -664,4 +668,16 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int cu } } +std::unique_ptr CreateKeyValueCache(State& state) { + if (!KeyValueCacheInterface::IsCacheNeeded(state.model_)) { + return nullptr; + } + + if (state.model_.config_->model.decoder.sliding_window) { + return std::make_unique(state); + } else { + return std::make_unique(state); + } +} + } // namespace Generators diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 1fa271cb1..7b6693f63 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -4,6 +4,16 @@ namespace Generators { +struct KeyValueCacheInterface { + virtual ~KeyValueCacheInterface() = default; + virtual void Add() = 0; + virtual void AddEncoder() = 0; + virtual void Update(DeviceSpan beam_indices, int total_length) = 0; + virtual void RewindTo(size_t index) = 0; + + static bool IsCacheNeeded(const Model& model); +}; + struct KV_Cache_Combined { KV_Cache_Combined(State& state); @@ -34,23 +44,22 @@ struct KV_Cache_Combined { std::vector input_name_strings_, output_name_strings_; }; -struct KV_Cache { +struct KV_Cache : KeyValueCacheInterface { KV_Cache(State& state); - static bool IsCacheNeeded(const Model& model); - - void AddEncoder(); // If model has an initial encoder step, this is used + void AddEncoder() override; // If model has an initial encoder step, this is used // Register input_ids as ORT session input. // Called only once during initialization of state. - void Add(); + void Add() override; // Move present to past. Prepare present output for next generation iteration. 
- void Update(DeviceSpan beam_indices, int total_length); - void RewindTo(size_t index); + void Update(DeviceSpan beam_indices, int total_length) override; + void RewindTo(size_t index) override; + + private: template void PickPastState(DeviceSpan beam_indices, int index); void PickPastState(DeviceSpan beam_indices, int index); - private: template void RewindPastTensorsTo(size_t index); @@ -90,18 +99,27 @@ struct Cross_Cache { std::vector input_name_strings_, output_name_strings_; }; -struct SlidingWindowKeyValueCache { +struct SlidingWindowKeyValueCache : KeyValueCacheInterface { SlidingWindowKeyValueCache(State& state); - void Add(); - void Update(DeviceSpan beam_indices, int current_length); - void Slide(); + void Add() override; + void AddEncoder() override { + throw std::runtime_error("SlidingWindowKeyValueCache does not support AddEncoder."); + }; + void Update(DeviceSpan beam_indices, int current_length) override; + void RewindTo(size_t index) override { + throw std::runtime_error("SlidingWindowKeyValueCache does not support RewindTo."); + } private: + void Slide(); + State& state_; const Model& model_{state_.model_}; int layer_count_{0}; int window_size_{0}; + size_t num_windows_{1}; + size_t window_index_{0}; size_t input_index_{~0U}, output_index_{~0U}; std::array key_cache_shape_in_, key_cache_shape_out_; @@ -114,4 +132,7 @@ struct SlidingWindowKeyValueCache { bool is_first_update_{true}; }; + +std::unique_ptr CreateKeyValueCache(State& state); + } // namespace Generators diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index dae0a393a..c1dcb3379 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -464,24 +464,25 @@ SlidingWindowPositionInputs::SlidingWindowPositionInputs(State& state) has_posid_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.position_ids); has_mask_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.attention_mask); - if (!model_.config_->model.decoder.sliding_window_key_value_cache.has_value()) { - throw std::runtime_error("Sliding a window over input_ids requires sliding_window_key_value_cache to be set in the config."); + if (has_posid_input_ || has_mask_input_) { + if (!model_.config_->model.decoder.sliding_window.has_value()) { + throw std::runtime_error("Sliding a window over position_ids and attention_mask requires sliding_window to be set in the genai_config.json."); + } + window_size_ = model_.config_->model.decoder.sliding_window->window_size; } - window_size_ = model_.config_->model.decoder.sliding_window_key_value_cache->window_size; - if (has_posid_input_) { position_ids_type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.position_ids); - if (position_ids_type_ != Ort::TypeToTensorType) - throw std::runtime_error("SlidingWindowPositionInputs only supports int64_t position_ids"); + if (position_ids_type_ != Ort::TypeToTensorType) + throw std::runtime_error("SlidingWindowPositionInputs only supports int32_t position_ids"); - position_ids_shape_ = {1, model_.config_->model.decoder.sliding_window_key_value_cache->window_size}; + position_ids_shape_ = {1, model_.config_->model.decoder.sliding_window->window_size}; } if (has_mask_input_) { - auto attention_mask_type = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.attention_mask); - if (attention_mask_type != Ort::TypeToTensorType) - throw std::runtime_error("SlidingWindowPositionInputs only supports float attention_mask"); + 
attention_mask_type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.attention_mask); + if (attention_mask_type_ != Ort::TypeToTensorType) + throw std::runtime_error("SlidingWindowPositionInputs only supports int32_t attention_mask"); attention_mask_shape_ = {1, model_.config_->model.context_length}; } @@ -505,11 +506,15 @@ void SlidingWindowPositionInputs::Update(DeviceSpan next_tokens, int to if (window_index_ == 0) { num_windows_ = (next_tokens.size() + window_size_ - 1) / window_size_; if (has_posid_input_) { - position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, position_ids_type_); + position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, position_ids_shape_, position_ids_type_); - auto* position_ids_data = position_ids_->GetTensorMutableData(); + // next_tokens will always be padded so that its size is a multiple of window_size_ + // next_tokens -> [0, a, b, c, d, e] + // window_size = 3, num_windows = 2, pad_token = 0 + // window_index = 0, position_ids_ -> [0, 0, 1] + auto* position_ids_data = position_ids_->GetTensorMutableData(); for (int i = 0, j = 0; i < position_ids_shape_[1]; i++) { - if (next_tokens.CpuSpan()[i] == model_.config_->model.pad_token_id) { + if (next_tokens.Span()[i] == model_.config_->model.pad_token_id) { position_ids_data[i] = 0; } else { position_ids_data[i] = j++; @@ -518,14 +523,19 @@ void SlidingWindowPositionInputs::Update(DeviceSpan next_tokens, int to } if (has_mask_input_) { - attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, Ort::TypeToTensorType); - auto* attention_mask_data = attention_mask_->GetTensorMutableData(); - std::fill(attention_mask_data, attention_mask_data + attention_mask_shape_[1], 0.0f); + attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, attention_mask_shape_, attention_mask_type_); + + // next_tokens will always be padded so that its size is a multiple of window_size_ + // next_tokens -> [0, a, b, c, d, e] + // window_size = 3, num_windows = 2, pad_token = 0 + // window_index = 0, attention_mask_ -> ([0] * (context_length - window_size_)) + [0, 1, 1] + auto* attention_mask_data = attention_mask_->GetTensorMutableData(); + std::fill(attention_mask_data, attention_mask_data + attention_mask_shape_[1], 0); for (size_t i = 0; i < window_size_; i++) { - attention_mask_data[attention_mask_shape_[1] - window_size_ + i] = next_tokens.CpuSpan()[i] == model_.config_->model.pad_token_id ? 0.0f : 1.0f; + attention_mask_data[attention_mask_shape_[1] - window_size_ + i] = next_tokens.CpuSpan()[i] == model_.config_->model.pad_token_id ?
0 : 1; } for (size_t i = 0; i < window_size_; i++) { - if (attention_mask_data[attention_mask_shape_[1] - window_size_ + i] == 1.0f) { + if (attention_mask_data[attention_mask_shape_[1] - window_size_ + i] == 1) { attention_mask_backward_offset_ = attention_mask_shape_[1] - window_size_ + i - 1; break; } @@ -533,29 +543,42 @@ void SlidingWindowPositionInputs::Update(DeviceSpan next_tokens, int to } } else if (window_index_ < num_windows_) { if (has_posid_input_) { - auto* position_ids_data = position_ids_->GetTensorMutableData(); + // next_tokens will always be padded so that its size is a multiple of window_size_ + // next_tokens -> [0, a, b, c, d, e] + // window_size = 3, num_windows = 2, pad_token = 0 + // window_index = 1, position_ids_ -> [2, 3, 4] + + auto* position_ids_data = position_ids_->GetTensorMutableData(); const auto last_position = position_ids_data[window_size_ - 1]; std::iota(position_ids_data, position_ids_data + window_size_, last_position + 1); } if (has_mask_input_) { - auto* attention_mask_data = attention_mask_->GetTensorMutableData(); - std::fill_n(attention_mask_data + attention_mask_backward_offset_ - window_size_ + 1, window_size_, 1.0f); - + // next_tokens will always be padded so that its size is a multiple of window_size_ + // next_tokens -> [0, a, b, c, d, e] + // window_size = 3, num_windows = 2, pad_token = 0 + // window_index = 1, attention_mask_ -> ([0] * (context_length - 2 * window_size_)) + [0, 1, 1, 1, 1, 1] + auto* attention_mask_data = attention_mask_->GetTensorMutableData(); + std::fill_n(attention_mask_data + attention_mask_backward_offset_ - window_size_ + 1, window_size_, 1); attention_mask_backward_offset_ -= window_size_; } } else { + // All prompt token chunks have been processed. Now we process the tokens generated by the model.
if (has_posid_input_) { - const auto last_position = position_ids_->GetTensorData()[position_ids_shape_[1] - 1]; + // next_tokens -> [f] + // position_ids_ -> [5] + const auto last_position = position_ids_->GetTensorData()[position_ids_shape_[1] - 1]; if (position_ids_shape_[1] != 1) { position_ids_shape_[1] = 1; - position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, position_ids_type_); + position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, position_ids_shape_, position_ids_type_); } - position_ids_->GetTensorMutableData()[0] = last_position + 1; + position_ids_->GetTensorMutableData()[0] = last_position + 1; } if (has_mask_input_) { - attention_mask_->GetTensorMutableData()[attention_mask_backward_offset_] = 1.0f; + // next_tokens -> [f] + // attention_mask_ -> ([0] * (context_length - 2 * window_size_ - 1)) + [0, 1, 1, 1, 1, 1, 1] + attention_mask_->GetTensorMutableData()[attention_mask_backward_offset_] = 1; if (attention_mask_backward_offset_ > 0) { attention_mask_backward_offset_ -= 1; } @@ -574,7 +597,7 @@ void SlidingWindowPositionInputs::Update(DeviceSpan next_tokens, int to } std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths) { - if (state.model_.config_->model.decoder.sliding_window_key_value_cache.has_value()) { + if (state.model_.config_->model.decoder.sliding_window.has_value()) { return std::make_unique(state); } else { return std::make_unique(state.model_, state, sequence_lengths); From a68827d6e914bdfb0d41df6b066a8d5340f30ed2 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 10 Dec 2024 22:03:08 +0000 Subject: [PATCH 09/18] Always assign allocator_kvcache_ --- src/models/model.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/models/model.cpp b/src/models/model.cpp index a465e4471..85fffbec2 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -290,6 +290,7 @@ Model::~Model() = default; void Model::InitDeviceAllocator(OrtSession& session) { allocator_device_ = &allocator_cpu_; + allocator_kvcache_ = &allocator_cpu_; #if USE_CUDA if (device_type_ == DeviceType::CUDA) { allocator_device_ = GetCudaAllocator(session); From 2d2c3fbdc3c0431a413fbf01b41c99b73ec541bf Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 10 Dec 2024 14:51:04 -0800 Subject: [PATCH 10/18] Avoid using front() --- src/models/decoder_only_pipeline.cpp | 9 --------- src/models/input_ids.cpp | 2 +- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 410a3b875..e7f7218e6 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -100,15 +100,6 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode if (key_value_cache_) { key_value_cache_->Add(); } - // if (KV_Cache::IsCacheNeeded(model)) { - // if (model.config_->model.decoder.sliding_window.has_value()) { - // sliding_window_key_value_cache_ = std::make_unique(*this); - // sliding_window_key_value_cache_->Add(); - // } else { - // kv_cache_ = std::make_unique(*this); - // kv_cache_->Add(); - // } - // } extra_inputs_.Add(); for ([[maybe_unused]] const auto& pipeline_model : model_.config_->model.decoder.pipeline) { diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 2f2975d7d..afd307d73 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -243,7 +243,7 @@ void SlidingWindowInputIDs::Update(DeviceSpan& new_tokens) { value_ =
OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); } - value_->GetTensorMutableData()[0] = new_tokens.Span().front(); + value_->GetTensorMutableData()[0] = new_tokens.Span()[0]; } state_.inputs_[input_index_] = value_.get(); From 3acbfc0bf22956b57d9a6c89d01f3aacaaffa1d0 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 10 Dec 2024 15:07:47 -0800 Subject: [PATCH 11/18] link against pthreads --- cmake/cxx_standard.cmake | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmake/cxx_standard.cmake b/cmake/cxx_standard.cmake index 9ae00e5cb..713cb9600 100644 --- a/cmake/cxx_standard.cmake +++ b/cmake/cxx_standard.cmake @@ -12,3 +12,7 @@ else () message("Test is using C++20") set(CMAKE_CXX_STANDARD 20) endif () + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") +endif () From c5ee9c06d3da176a3cf354f2bfe5a0a1efa4996c Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 11 Dec 2024 13:48:22 -0800 Subject: [PATCH 12/18] Address pull-request review comments --- src/models/debugging.cpp | 21 ++++----- src/models/decoder_only.h | 6 +-- src/models/decoder_only_pipeline.h | 6 +-- src/models/gpt.h | 6 +-- src/models/input_ids.cpp | 20 ++++---- src/models/input_ids.h | 28 ++++++------ src/models/kv_cache.cpp | 52 ++++++++++----------- src/models/kv_cache.h | 32 ++++++------- src/models/multi_modal_vision_model.h | 8 ++-- src/models/position_inputs.cpp | 66 +++++++++++++-------------- src/models/position_inputs.h | 26 +++++------ src/models/threadpool.cpp | 2 + src/models/whisper.h | 4 +- 13 files changed, 137 insertions(+), 140 deletions(-) diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp index 00420d613..11056bfde 100644 --- a/src/models/debugging.cpp +++ b/src/models/debugging.cpp @@ -39,23 +39,18 @@ std::ostream& operator<<(std::ostream& stream, Ort::BFloat16_t v) { template void DumpSpan(std::ostream& stream, std::span values) { + // If type is uint8_t or int8_t cast to int so it displays as an int vs a char + using DisplayType = std::conditional_t || std::is_same_v, int, T>; + if (values.size() <= c_value_count) { for (auto v : values) - stream << v << ' '; + stream << static_cast(v) << ' '; } else { - for (size_t i = 0; i < c_value_count / 2; i++) { - if constexpr (std::is_same::value || std::is_same::value) - stream << static_cast(values[i]) << ' '; - else - stream << values[i] << ' '; - } + for (size_t i = 0; i < c_value_count / 2; i++) + stream << static_cast(values[i]) << ' '; stream << "... 
"; - for (size_t i = values.size() - c_value_count / 2; i < values.size(); i++) { - if constexpr (std::is_same::value || std::is_same::value) - stream << static_cast(values[i]) << ' '; - else - stream << values[i] << ' '; - } + for (size_t i = values.size() - c_value_count / 2; i < values.size(); i++) + stream << static_cast(values[i]) << ' '; } } diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index a25e37df8..355f89a6d 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -29,10 +29,10 @@ struct DecoderOnly_State : State { const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; - InputIDs input_ids_{*this}; + InputIDsDefault input_ids_{*this}; Logits logits_{*this}; - KV_Cache kv_cache_{*this}; - PositionInputs position_inputs_; + KeyValueCacheDefault kv_cache_{*this}; + PositionInputsDefault position_inputs_; ExtraInputs extra_inputs_{*this}; }; diff --git a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h index 726828587..f9e9c4d29 100644 --- a/src/models/decoder_only_pipeline.h +++ b/src/models/decoder_only_pipeline.h @@ -71,10 +71,10 @@ struct DecoderOnlyPipelineState : State { // Stores all the outputs from the previous pipeline state(s) std::unordered_map> ortvalue_store_; - std::unique_ptr input_ids_; + std::unique_ptr input_ids_; Logits logits_{*this}; - std::unique_ptr key_value_cache_; - std::unique_ptr position_inputs_; + std::unique_ptr key_value_cache_; + std::unique_ptr position_inputs_; ExtraInputs extra_inputs_{*this}; }; diff --git a/src/models/gpt.h b/src/models/gpt.h index bbd250b24..4ee52102b 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -27,10 +27,10 @@ struct Gpt_State : State { const Gpt_Model& model_; - InputIDs input_ids_{*this}; + InputIDsDefault input_ids_{*this}; Logits logits_{*this}; - KV_Cache_Combined kv_cache_{*this}; - PositionInputs position_inputs_; + KeyValueCacheDefault_Combined kv_cache_{*this}; + PositionInputsDefault position_inputs_; ExtraInputs extra_inputs_{*this}; }; } // namespace Generators diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index afd307d73..546efd97a 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -5,7 +5,7 @@ namespace Generators { -InputIDs::InputIDs(State& state) +InputIDsDefault::InputIDsDefault(State& state) : state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); shape_ = {state_.params_->BatchBeamSize(), 0}; @@ -41,7 +41,7 @@ InputIDs::InputIDs(State& state) } } -void InputIDs::Add() { +void InputIDsDefault::Add() { input_index_ = state_.inputs_.size(); state_.inputs_.push_back(value_.get()); @@ -55,7 +55,7 @@ void InputIDs::Add() { } } -void InputIDs::Update(DeviceSpan& new_tokens) { +void InputIDsDefault::Update(DeviceSpan& new_tokens) { const auto get_unpadded_sequence_length = [](std::span input_ids, int32_t pad_token_id) { int32_t seq_length = 0; @@ -191,7 +191,7 @@ void InputIDs::Update(DeviceSpan& new_tokens) { is_prompt_ = false; } -SlidingWindowInputIDs::SlidingWindowInputIDs(State& state) : state_{state} { +WindowedInputIDs::WindowedInputIDs(State& state) : state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); if (!model_.config_->model.decoder.sliding_window.has_value()) { @@ -207,18 +207,18 @@ SlidingWindowInputIDs::SlidingWindowInputIDs(State& state) : state_{state} { type_ = model_.session_info_->GetInputDataType(name_); if (type_ != Ort::TypeToTensorType) { - throw std::runtime_error("SlidingWindowInputIDs only supports 
int32_t input_ids."); + throw std::runtime_error("WindowedInputIDs only supports int32_t input_ids."); } } -void SlidingWindowInputIDs::Add() { +void WindowedInputIDs::Add() { input_index_ = state_.inputs_.size(); state_.inputs_.push_back(value_.get()); state_.input_names_.push_back(name_); } -void SlidingWindowInputIDs::Update(DeviceSpan& new_tokens) { +void WindowedInputIDs::Update(DeviceSpan& new_tokens) { if (window_index_ == 0) { num_windows_ = (new_tokens.size() + window_size_ - 1) / window_size_; @@ -250,11 +250,11 @@ void SlidingWindowInputIDs::Update(DeviceSpan& new_tokens) { window_index_++; } -std::unique_ptr CreateInputIDs(State& state) { +std::unique_ptr CreateInputIDs(State& state) { if (state.model_.config_->model.decoder.sliding_window.has_value()) { - return std::make_unique(state); + return std::make_unique(state); } else { - return std::make_unique(state); + return std::make_unique(state); } } diff --git a/src/models/input_ids.h b/src/models/input_ids.h index 57f91438d..ee16e34d8 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -4,17 +4,17 @@ namespace Generators { -struct InputIDsInterface { - virtual ~InputIDsInterface() = default; +struct InputIDs { + virtual ~InputIDs() = default; virtual void Add() = 0; virtual std::array GetShape() const = 0; virtual void Update(DeviceSpan& next_tokens) = 0; }; -struct InputIDs : InputIDsInterface { - InputIDs(State& state); - InputIDs(const InputIDs&) = delete; - InputIDs& operator=(const InputIDs&) = delete; +struct InputIDsDefault : InputIDs { + InputIDsDefault(State& state); + InputIDsDefault(const InputIDsDefault&) = delete; + InputIDsDefault& operator=(const InputIDsDefault&) = delete; // Register input_ids as ORT session input. // Called only once during initialization of state. 
@@ -52,10 +52,10 @@ struct InputIDs : InputIDsInterface { std::unique_ptr past_sequence_length_; }; -struct SlidingWindowInputIDs : public InputIDsInterface { - SlidingWindowInputIDs(State& state); - SlidingWindowInputIDs(const SlidingWindowInputIDs&) = delete; - SlidingWindowInputIDs& operator=(const SlidingWindowInputIDs&) = delete; +struct WindowedInputIDs : public InputIDs { + WindowedInputIDs(State& state); + WindowedInputIDs(const WindowedInputIDs&) = delete; + WindowedInputIDs& operator=(const WindowedInputIDs&) = delete; void Add() override; void Update(DeviceSpan& next_tokens) override; @@ -65,9 +65,9 @@ struct SlidingWindowInputIDs : public InputIDsInterface { State& state_; const Model& model_{state_.model_}; size_t input_index_{~0U}; - size_t window_size_{0}; - size_t num_windows_{1}; - size_t window_index_{0}; + size_t window_size_{}; + size_t num_windows_{}; + size_t window_index_{}; const char* name_; std::array shape_{}; ONNXTensorElementDataType type_; @@ -75,6 +75,6 @@ struct SlidingWindowInputIDs : public InputIDsInterface { std::unique_ptr value_; }; -std::unique_ptr CreateInputIDs(State& state); +std::unique_ptr CreateInputIDs(State& state); } // namespace Generators diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 009b7b1b6..734d07f8e 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -27,7 +27,7 @@ int64_t ElementCountFromShape(const std::array& shape) { } // namespace -KV_Cache_Combined::KV_Cache_Combined(State& state) +KeyValueCacheDefault_Combined::KeyValueCacheDefault_Combined(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, shape_{2, state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0, model_.config_->model.decoder.head_size} { @@ -50,7 +50,7 @@ KV_Cache_Combined::KV_Cache_Combined(State& state) } } -void KV_Cache_Combined::Add() { +void KeyValueCacheDefault_Combined::Add() { input_index_ = state_.inputs_.size(); output_index_ = state_.outputs_.size(); @@ -62,7 +62,7 @@ void KV_Cache_Combined::Add() { } } -void KV_Cache_Combined::Update(DeviceSpan beam_indices, int total_length) { +void KeyValueCacheDefault_Combined::Update(DeviceSpan beam_indices, int total_length) { assert(state_.params_->search.num_beams == 1 || !beam_indices.empty()); // We require beam_indices if we're a beam search if (!is_first_update_) { @@ -85,7 +85,7 @@ void KV_Cache_Combined::Update(DeviceSpan beam_indices, int total_lengt is_first_update_ = false; } -void KV_Cache_Combined::RewindTo(size_t index) { +void KeyValueCacheDefault_Combined::RewindTo(size_t index) { if (shape_[3] <= static_cast(index)) { throw std::runtime_error("Requested length of rewind is greater than the current length."); } @@ -104,7 +104,7 @@ void KV_Cache_Combined::RewindTo(size_t index) { } template -void KV_Cache_Combined::RewindPastTensorsTo(size_t index) { +void KeyValueCacheDefault_Combined::RewindPastTensorsTo(size_t index) { assert(index > 0 && shape_[3] >= static_cast(index)); std::array new_shape = shape_; new_shape[3] = static_cast(index); @@ -139,7 +139,7 @@ void KV_Cache_Combined::RewindPastTensorsTo(size_t index) { // Copy present state to past state reordered by the beam_indices template -void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices_device, int index) { +void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan beam_indices_device, int index) { std::span beam_indices = beam_indices_device.CopyDeviceToCpu(); auto block_size_per_beam = shape_[2] * shape_[3] * 
shape_[4]; auto past_key_size = shape_[1] * block_size_per_beam; @@ -180,7 +180,7 @@ void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices_device, i pasts_[index] = std::move(past); } -void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices, int index) { +void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan beam_indices, int index) { if (type_ == Ort::TypeToTensorType) { PickPastState(beam_indices, index); } else { @@ -188,11 +188,11 @@ void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices, int inde } } -bool KeyValueCacheInterface::IsCacheNeeded(const Model& model) { +bool KeyValueCache::IsCacheNeeded(const Model& model) { return model.session_info_->HasInput(ComposeKeyValueName(model.config_->model.decoder.inputs.past_key_names, 0)); } -KV_Cache::KV_Cache(State& state) +KeyValueCacheDefault::KeyValueCacheDefault(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, past_present_share_buffer_{state_.params_->search.past_present_share_buffer && (state_.params_->search.num_beams == 1 || model_.config_->model.type == "whisper")}, @@ -257,7 +257,7 @@ KV_Cache::KV_Cache(State& state) } } -void KV_Cache::AddEncoder() { +void KeyValueCacheDefault::AddEncoder() { // We don't set the input_index_ & output_index_ because the encoder step only runs once, there's no update for (int i = 0; i < layer_count_ * 2; ++i) { @@ -266,7 +266,7 @@ void KV_Cache::AddEncoder() { } } -void KV_Cache::Add() { +void KeyValueCacheDefault::Add() { input_index_ = state_.inputs_.size(); output_index_ = state_.outputs_.size(); @@ -285,7 +285,7 @@ void KV_Cache::Add() { } } -void KV_Cache::Update(DeviceSpan beam_indices, int total_length) { +void KeyValueCacheDefault::Update(DeviceSpan beam_indices, int total_length) { // If we're sharing past & present buffers there is nothing to do here, so early exit if (past_present_share_buffer_) return; @@ -310,7 +310,7 @@ void KV_Cache::Update(DeviceSpan beam_indices, int total_length) { is_first_update_ = false; } -void KV_Cache::RewindTo(size_t index) { +void KeyValueCacheDefault::RewindTo(size_t index) { if (past_present_share_buffer_) { return; } else if (shape_[2] <= static_cast(index)) { @@ -331,7 +331,7 @@ void KV_Cache::RewindTo(size_t index) { } template -void KV_Cache::RewindPastTensorsTo(size_t index) { +void KeyValueCacheDefault::RewindPastTensorsTo(size_t index) { assert(index > 0 && shape_[2] >= static_cast(index) && !past_present_share_buffer_); std::array new_shape = shape_; new_shape[2] = static_cast(index); @@ -366,7 +366,7 @@ void KV_Cache::RewindPastTensorsTo(size_t index) { // Copy present state to past state reordered by the beam_indices template -void KV_Cache::PickPastState(DeviceSpan beam_indices_device, int index) { +void KeyValueCacheDefault::PickPastState(DeviceSpan beam_indices_device, int index) { std::span beam_indices = beam_indices_device.Span(); auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3]; auto element_count = shape_[0] * block_size_per_beam; @@ -398,7 +398,7 @@ void KV_Cache::PickPastState(DeviceSpan beam_indices_device, int index) pasts_[index] = std::move(past_value); } -void KV_Cache::PickPastState(DeviceSpan beam_indices, int index) { +void KeyValueCacheDefault::PickPastState(DeviceSpan beam_indices, int index) { if (type_ == Ort::TypeToTensorType) { PickPastState(beam_indices, index); } else { @@ -443,7 +443,7 @@ void Cross_Cache::AddInputs() { } } -SlidingWindowKeyValueCache::SlidingWindowKeyValueCache(State& state) 
+WindowedKeyValueCache::WindowedKeyValueCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, window_size_{model_.config_->model.decoder.sliding_window->window_size}, @@ -464,8 +464,8 @@ SlidingWindowKeyValueCache::SlidingWindowKeyValueCache(State& state) } type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); - if (type_ != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) { - throw std::runtime_error("Expected input data type to be uint8_t for SlidingWindowKeyValueCache. Actual: " + + if (type_ != Ort::TypeToTensorType) { + throw std::runtime_error("Expected input data type to be uint8_t for WindowedKeyValueCache. Actual: " + std::to_string(type_)); } @@ -489,7 +489,7 @@ SlidingWindowKeyValueCache::SlidingWindowKeyValueCache(State& state) } } -void SlidingWindowKeyValueCache::Add() { +void WindowedKeyValueCache::Add() { input_index_ = state_.inputs_.size(); output_index_ = state_.outputs_.size(); @@ -508,7 +508,7 @@ void SlidingWindowKeyValueCache::Add() { } } -void SlidingWindowKeyValueCache::Slide() { +void WindowedKeyValueCache::Slide() { ThreadPool thread_pool{static_cast(layer_count_)}; thread_pool.Compute([&](size_t layer_idx) { uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData(); @@ -556,7 +556,7 @@ void SlidingWindowKeyValueCache::Slide() { }); } -void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int current_length) { +void WindowedKeyValueCache::Update(DeviceSpan beam_indices, int current_length) { if (is_first_update_) { num_windows_ = (current_length + window_size_ - 1) / window_size_; is_first_update_ = false; @@ -668,15 +668,15 @@ void SlidingWindowKeyValueCache::Update(DeviceSpan beam_indices, int cu } } -std::unique_ptr CreateKeyValueCache(State& state) { - if (!KeyValueCacheInterface::IsCacheNeeded(state.model_)) { +std::unique_ptr CreateKeyValueCache(State& state) { + if (!KeyValueCache::IsCacheNeeded(state.model_)) { return nullptr; } if (state.model_.config_->model.decoder.sliding_window) { - return std::make_unique(state); + return std::make_unique(state); } else { - return std::make_unique(state); + return std::make_unique(state); } } diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 7b6693f63..5699e2e49 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -4,8 +4,8 @@ namespace Generators { -struct KeyValueCacheInterface { - virtual ~KeyValueCacheInterface() = default; +struct KeyValueCache { + virtual ~KeyValueCache() = default; virtual void Add() = 0; virtual void AddEncoder() = 0; virtual void Update(DeviceSpan beam_indices, int total_length) = 0; @@ -14,8 +14,8 @@ struct KeyValueCacheInterface { static bool IsCacheNeeded(const Model& model); }; -struct KV_Cache_Combined { - KV_Cache_Combined(State& state); +struct KeyValueCacheDefault_Combined { + KeyValueCacheDefault_Combined(State& state); void Add(); // Add to state inputs/outputs void Update(DeviceSpan beam_indices, int total_length); @@ -44,8 +44,8 @@ struct KV_Cache_Combined { std::vector input_name_strings_, output_name_strings_; }; -struct KV_Cache : KeyValueCacheInterface { - KV_Cache(State& state); +struct KeyValueCacheDefault : KeyValueCache { + KeyValueCacheDefault(State& state); void AddEncoder() override; // If model has an initial encoder step, this is used // Register input_ids as ORT session input. 
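An aside on the Slide() parallelization above: each decoder layer owns disjoint key and value tensors, so the per-layer shifts are independent and can run concurrently. Below is a standalone analogue of the ThreadPool::Compute call, sketched with std::thread under that assumption; it is illustrative, not the project's actual threadpool.cpp.

#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Runs fn(0) .. fn(count - 1) concurrently and joins them all, the way
// Slide() fans out one task per decoder layer.
void ComputeParallel(size_t count, const std::function<void(size_t)>& fn) {
  std::vector<std::thread> workers;
  workers.reserve(count);
  for (size_t i = 0; i < count; ++i)
    workers.emplace_back([&fn, i] { fn(i); });
  for (auto& worker : workers)
    worker.join();
}

// Usage mirroring the patch:
//   ComputeParallel(layer_count, [&](size_t layer) { /* shift this layer's key/value caches */ });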
@@ -80,7 +80,7 @@ struct KV_Cache : KeyValueCacheInterface { std::vector sb_kv_caches_; }; -// Very similar to the KV_Cache, but is only created once at the encoder step, then used without modification for every decoder step +// Very similar to the KeyValueCacheDefault, but is only created once at the encoder step, then used without modification for every decoder step struct Cross_Cache { Cross_Cache(State& state); @@ -99,16 +99,16 @@ struct Cross_Cache { std::vector input_name_strings_, output_name_strings_; }; -struct SlidingWindowKeyValueCache : KeyValueCacheInterface { - SlidingWindowKeyValueCache(State& state); +struct WindowedKeyValueCache : KeyValueCache { + WindowedKeyValueCache(State& state); void Add() override; void AddEncoder() override { - throw std::runtime_error("SlidingWindowKeyValueCache does not support AddEncoder."); + throw std::runtime_error("WindowedKeyValueCache does not support AddEncoder."); }; void Update(DeviceSpan beam_indices, int current_length) override; void RewindTo(size_t index) override { - throw std::runtime_error("SlidingWindowKeyValueCache does not support RewindTo."); + throw std::runtime_error("WindowedKeyValueCache does not support RewindTo."); } private: @@ -116,10 +116,10 @@ struct SlidingWindowKeyValueCache : KeyValueCacheInterface { State& state_; const Model& model_{state_.model_}; - int layer_count_{0}; - int window_size_{0}; - size_t num_windows_{1}; - size_t window_index_{0}; + int layer_count_{}; + int window_size_{}; + size_t num_windows_{}; + size_t window_index_{}; size_t input_index_{~0U}, output_index_{~0U}; std::array key_cache_shape_in_, key_cache_shape_out_; @@ -133,6 +133,6 @@ struct SlidingWindowKeyValueCache : KeyValueCacheInterface { bool is_first_update_{true}; }; -std::unique_ptr CreateKeyValueCache(State& state); +std::unique_ptr CreateKeyValueCache(State& state); } // namespace Generators diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index aa94da9d1..2fde849d4 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -42,7 +42,7 @@ struct EmbeddingState : State { const MultiModalVisionModel& model_; int64_t num_image_tokens_; - InputIDs input_ids_{*this}; // Model input + InputIDsDefault input_ids_{*this}; // Model input ImageFeatures image_features_{*this, ImageFeatures::Mode::Input, // Optional model input model_.config_->model.embedding.inputs.image_features, num_image_tokens_}; @@ -89,9 +89,9 @@ struct DecoderState : State { const CapturedGraphInfo* captured_graph_info_; Embeddings inputs_embeds_{*this, Embeddings::Mode::Input, // Model input model_.config_->model.decoder.inputs.embeddings}; - PositionInputs position_inputs_; // Model input - KV_Cache kv_cache_{*this}; // Model input - Logits logits_{*this}; // Model output + PositionInputsDefault position_inputs_; // Model input + KeyValueCacheDefault kv_cache_{*this}; // Model input + Logits logits_{*this}; // Model output }; struct MultiModalPipelineState : State { diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index c1dcb3379..4e29f0b73 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -9,7 +9,7 @@ namespace Generators { -PositionInputs::PositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk) +PositionInputsDefault::PositionInputsDefault(const Model& model, State& state, DeviceSpan sequence_lengths_unk) : model_{model}, state_{state} { has_mask_input_ = 
model_.session_info_->HasInput(model_.config_->model.decoder.inputs.attention_mask); @@ -59,7 +59,7 @@ PositionInputs::PositionInputs(const Model& model, State& state, DeviceSpan next_tokens, int total_length, int new_length) { +void PositionInputsDefault::Update(DeviceSpan next_tokens, int total_length, int new_length) { if (has_posid_input_) { // Initialize on first update if (is_first_update_) { @@ -96,7 +96,7 @@ void PositionInputs::Update(DeviceSpan next_tokens, int total_length, i is_first_update_ = false; } -void PositionInputs::RewindTo(size_t index) { +void PositionInputsDefault::RewindTo(size_t index) { // Reset the state of the position inputs if (index == 0) { is_first_update_ = true; @@ -108,18 +108,18 @@ void PositionInputs::RewindTo(size_t index) { RewindMask(index); #endif } else - throw std::runtime_error("PositionInputs::RewindTo - Unsupported batch size"); + throw std::runtime_error("PositionInputsDefault::RewindTo - Unsupported batch size"); } } -void PositionInputs::AddAttentionMask() { +void PositionInputsDefault::AddAttentionMask() { mask_input_index_ = state_.inputs_.size(); state_.inputs_.push_back(attention_mask_.get()); state_.input_names_.push_back(model_.config_->model.decoder.inputs.attention_mask.c_str()); } -void PositionInputs::AddPositionIDs() { +void PositionInputsDefault::AddPositionIDs() { posid_input_index_ = state_.inputs_.size(); state_.inputs_.push_back(position_ids_.get()); @@ -127,7 +127,7 @@ void PositionInputs::AddPositionIDs() { } #if USE_CUDA || USE_DML -void PositionInputs::CopyNextPositionIDsToCurrent() { +void PositionInputsDefault::CopyNextPositionIDsToCurrent() { #if USE_CUDA assert(model_.device_type_ == DeviceType::CUDA); cudaMemcpyAsync(position_ids_->GetTensorMutableRawData(), @@ -149,7 +149,7 @@ void PositionInputs::CopyNextPositionIDsToCurrent() { } #endif -void PositionInputs::CreateNextPositionIDsTensor() { +void PositionInputsDefault::CreateNextPositionIDsTensor() { if (!sb_position_ids_) { if (position_ids_shape_[1] == 1 && position_ids_next_) { position_ids_ = std::move(position_ids_next_); @@ -167,11 +167,11 @@ void PositionInputs::CreateNextPositionIDsTensor() { } } -void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { +void PositionInputsDefault::UpdatePositionIDs(int total_length, int new_kv_length) { if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); + throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); if (DeviceType::DML == model_.device_type_ && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - DML does not support continuous decoding."); + throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - DML does not support continuous decoding."); // Reallocate position_ids when new_kv_length changes if (position_ids_shape_[1] != new_kv_length) { @@ -206,7 +206,7 @@ void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { } } -void PositionInputs::CreateNextAttentionMaskTensor(int total_length) { +void PositionInputsDefault::CreateNextAttentionMaskTensor(int total_length) { if (!sb_attention_mask_) { attention_mask_shape_[1] = total_length; attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); @@ -232,11 +232,11 @@ void 
PositionInputs::CreateNextAttentionMaskTensor(int total_length) { } } -void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { +void PositionInputsDefault::UpdateAttentionMask(int total_length, int new_kv_length) { if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); + throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); if (DeviceType::DML == model_.device_type_ && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - DML does not support continuous decoding."); + throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - DML does not support continuous decoding."); CreateNextAttentionMaskTensor(total_length); state_.inputs_[mask_input_index_] = attention_mask_.get(); @@ -281,7 +281,7 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { } #endif default: - throw std::runtime_error("PositionInputs::Update - Unsupported device type"); + throw std::runtime_error("PositionInputsDefault::Update - Unsupported device type"); } #if USE_DML if (model_.device_type_ != DeviceType::DML) { @@ -295,7 +295,7 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { } template -void PositionInputs::CreateAndInitializePositionIDs(DeviceSpan next_tokens, std::array shape) { +void PositionInputsDefault::CreateAndInitializePositionIDs(DeviceSpan next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); @@ -325,7 +325,7 @@ void PositionInputs::CreateAndInitializePositionIDs(DeviceSpan next_tok } template -void PositionInputs::CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape) { +void PositionInputsDefault::CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. 
// Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); @@ -349,14 +349,14 @@ void PositionInputs::CreateAndInitializeAttentionMask(DeviceSpan next_t } template -void PositionInputs::InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk) { +void PositionInputsDefault::InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk) { for (int i = 0; i < shape[0] * state_.params_->search.num_beams; i++) { sequence_lengths_unk[i] = 0; } } template -void PositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_length) { +void PositionInputsDefault::UpdatePositionIDsImpl(int total_length, int new_kv_length) { auto* data = position_ids_->GetTensorMutableData(); if (position_ids_shape_[0] == 1) { // For batch size == 1 we calculate position ids with total length and new kv length for continuous decoding @@ -370,7 +370,7 @@ void PositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_length) } #if USE_DML -void PositionInputs::UpdatePositionIDsImplDML() { +void PositionInputsDefault::UpdatePositionIDsImplDML() { ComPtr target_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource)); @@ -389,7 +389,7 @@ void PositionInputs::UpdatePositionIDsImplDML() { #endif template -void PositionInputs::UpdateAttentionMaskImpl(int total_length) { +void PositionInputsDefault::UpdateAttentionMaskImpl(int total_length) { auto* data = attention_mask_next_->GetTensorMutableData(); auto* old_data = attention_mask_->GetTensorData(); if (attention_mask_shape_[0] == 1) { @@ -408,7 +408,7 @@ void PositionInputs::UpdateAttentionMaskImpl(int total_length) { } #if USE_DML -void PositionInputs::UpdateAttentionMaskImplDML(int total_length) { +void PositionInputsDefault::UpdateAttentionMaskImplDML(int total_length) { ComPtr attention_mask_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource)); ComPtr attention_mask_next_resource; @@ -443,7 +443,7 @@ void PositionInputs::UpdateAttentionMaskImplDML(int total_length) { #endif #if USE_CUDA -void PositionInputs::RewindMask(size_t index) { +void PositionInputsDefault::RewindMask(size_t index) { if (sb_attention_mask_ && !is_first_mask_update_) { int past_length = static_cast(index); int max_length = static_cast(state_.params_->search.max_length); @@ -459,7 +459,7 @@ void PositionInputs::RewindMask(size_t index) { } #endif -SlidingWindowPositionInputs::SlidingWindowPositionInputs(State& state) +WindowedPositionInputs::WindowedPositionInputs(State& state) : state_{state} { has_posid_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.position_ids); has_mask_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.attention_mask); @@ -474,7 +474,7 @@ SlidingWindowPositionInputs::SlidingWindowPositionInputs(State& state) if (has_posid_input_) { position_ids_type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.position_ids); if (position_ids_type_ != Ort::TypeToTensorType) - throw std::runtime_error("SlidingWindowPositionInputs only supports int32_t position_ids"); + throw std::runtime_error("WindowedPositionInputs only supports int32_t position_ids"); position_ids_shape_ = {1, 
model_.config_->model.decoder.sliding_window->window_size}; } @@ -482,13 +482,13 @@ SlidingWindowPositionInputs::SlidingWindowPositionInputs(State& state) if (has_mask_input_) { attention_mask_type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.attention_mask); if (attention_mask_type_ != Ort::TypeToTensorType) - throw std::runtime_error("SlidingWindowPositionInputs only supports int32_t attention_mask"); + throw std::runtime_error("WindowedPositionInputs only supports int32_t attention_mask"); attention_mask_shape_ = {1, model_.config_->model.context_length}; } } -void SlidingWindowPositionInputs::Add() { +void WindowedPositionInputs::Add() { if (has_posid_input_) { position_ids_index_ = state_.inputs_.size(); state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str()); @@ -502,7 +502,7 @@ void SlidingWindowPositionInputs::Add() { } } -void SlidingWindowPositionInputs::Update(DeviceSpan next_tokens, int total_length, int new_length) { +void WindowedPositionInputs::Update(DeviceSpan next_tokens, int total_length, int new_length) { if (window_index_ == 0) { num_windows_ = (next_tokens.size() + window_size_ - 1) / window_size_; if (has_posid_input_) { @@ -596,11 +596,11 @@ void SlidingWindowPositionInputs::Update(DeviceSpan next_tokens, int to window_index_++; } -std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths) { +std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths) { if (state.model_.config_->model.decoder.sliding_window.has_value()) { - return std::make_unique(state); + return std::make_unique(state); } else { - return std::make_unique(state.model_, state, sequence_lengths); + return std::make_unique(state.model_, state, sequence_lengths); } } diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index fa3f2b176..564da5033 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -9,15 +9,15 @@ namespace Generators { -struct PositionInputsInterface { - virtual ~PositionInputsInterface() = default; +struct PositionInputs { + virtual ~PositionInputs() = default; virtual void Add() = 0; virtual void Update(DeviceSpan next_tokens, int total_length, int new_length) = 0; virtual void RewindTo(size_t index) = 0; }; -struct PositionInputs : PositionInputsInterface { - PositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk); +struct PositionInputsDefault : PositionInputs { + PositionInputsDefault(const Model& model, State& state, DeviceSpan sequence_lengths_unk); void Add() override; void Update(DeviceSpan next_tokens, int total_length, int new_length) override; @@ -93,15 +93,15 @@ struct PositionInputs : PositionInputsInterface { #endif }; -struct SlidingWindowPositionInputs : PositionInputsInterface { - SlidingWindowPositionInputs(State& state); - SlidingWindowPositionInputs(const SlidingWindowPositionInputs&) = delete; - SlidingWindowPositionInputs& operator=(const SlidingWindowPositionInputs&) = delete; +struct WindowedPositionInputs : PositionInputs { + WindowedPositionInputs(State& state); + WindowedPositionInputs(const WindowedPositionInputs&) = delete; + WindowedPositionInputs& operator=(const WindowedPositionInputs&) = delete; void Add() override; void Update(DeviceSpan next_tokens, int total_length, int new_length) override; void RewindTo(size_t index) override { - throw std::runtime_error("SlidingWindowPositionInputs does not support RewindTo."); + throw std::runtime_error("WindowedPositionInputs does not 
support RewindTo."); }; private: @@ -122,11 +122,11 @@ struct SlidingWindowPositionInputs : PositionInputsInterface { size_t attention_mask_index_{~0U}; size_t position_ids_index_{~0U}; - size_t window_size_{0}; - size_t num_windows_{1}; - size_t window_index_{0}; + size_t window_size_{}; + size_t num_windows_{}; + size_t window_index_{}; }; -std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths); +std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths); } // namespace Generators diff --git a/src/models/threadpool.cpp b/src/models/threadpool.cpp index 941bac1d8..1ac56f08a 100644 --- a/src/models/threadpool.cpp +++ b/src/models/threadpool.cpp @@ -15,6 +15,8 @@ void ThreadPool::Compute(const std::function& func) { for (auto& thread : threads_) { thread.join(); } + + threads_.clear(); } } // namespace Generators \ No newline at end of file diff --git a/src/models/whisper.h b/src/models/whisper.h index ab7e508d6..0528f4b25 100644 --- a/src/models/whisper.h +++ b/src/models/whisper.h @@ -34,9 +34,9 @@ struct Whisper_State : State { Decoder, } run_state_{RunState::Encoder_Decoder_Init}; - InputIDs decoder_input_ids_{*this}; + InputIDsDefault decoder_input_ids_{*this}; Logits logits_{*this}; - KV_Cache kv_cache_{*this}; + KeyValueCacheDefault kv_cache_{*this}; Cross_Cache cross_cache_{*this}; std::unique_ptr encoder_input_ids_; std::unique_ptr encoder_hidden_states_; From 0ccc6683be8650ecf6b2358cf43d78a84be5f2ac Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 12 Dec 2024 23:51:28 +0000 Subject: [PATCH 13/18] Address pull-request review comments --- CMakeLists.txt | 3 +++ cmake/cxx_standard.cmake | 4 ---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 60076364d..4b9f83ec1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ if(ENABLE_TESTS) endif() endif() +find_package(Threads REQUIRED) if(WIN32) add_library(onnxruntime-genai SHARED ${generator_srcs} "${GENERATORS_ROOT}/dll/onnxruntime-genai.rc") @@ -104,6 +105,8 @@ target_include_directories(onnxruntime-genai-static PUBLIC ${onnxruntime_extensi target_link_libraries(onnxruntime-genai PRIVATE onnxruntime_extensions) target_link_libraries(onnxruntime-genai-static PUBLIC onnxruntime_extensions) target_link_directories(onnxruntime-genai PRIVATE ${ORT_LIB_DIR}) +target_link_libraries(onnxruntime-genai PRIVATE Threads::Threads) +target_link_libraries(onnxruntime-genai-static PUBLIC Threads::Threads) # we keep the shared libraries disconnected on Android as they will come from separate AARs and we don't want to force # the ORT version to match in both. 
diff --git a/cmake/cxx_standard.cmake b/cmake/cxx_standard.cmake index 713cb9600..9ae00e5cb 100644 --- a/cmake/cxx_standard.cmake +++ b/cmake/cxx_standard.cmake @@ -12,7 +12,3 @@ else () message("Test is using C++20") set(CMAKE_CXX_STANDARD 20) endif () - -if (CMAKE_SYSTEM_NAME STREQUAL "Linux") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") -endif () From 11dbed2f7db6b07a45227455c84aeee4ad118c12 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 16 Dec 2024 09:37:52 -0800 Subject: [PATCH 14/18] Throw meaningful exception when user tries continuous decoding --- src/generators.cpp | 6 ++++++ src/models/position_inputs.cpp | 2 -- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 02b79fc73..74572506d 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -300,6 +300,12 @@ void Generator::AppendTokens(cpu_span input_ids) { if (search_->GetSequenceLength() != 0 && state_->params_->search.batch_size > 1) throw std::runtime_error("AppendTokens can only be called once for batch_size > 1. To call AppendTokens again, use RewindToLength(0)"); + constexpr std::array devices_supporting_continuous_decoding{DeviceType::CPU, DeviceType::CUDA}; + if (search_->GetSequenceLength() != 0 && + std::none_of(devices_supporting_continuous_decoding.begin(), devices_supporting_continuous_decoding.end(), + [this](DeviceType device_type) { return device_type == state_->params_->device_type; })) + throw std::runtime_error("Continuous decoding is not supported on the selected device type: " + to_string(state_->params_->device_type)); + if (last_action_ == Action::generated) { ComputeLogits(search_->GetNextTokens()); } diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 4e29f0b73..1b46e497c 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -170,8 +170,6 @@ void PositionInputsDefault::CreateNextPositionIDsTensor() { void PositionInputsDefault::UpdatePositionIDs(int total_length, int new_kv_length) { if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); - if (DeviceType::DML == model_.device_type_ && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - DML does not support continuous decoding."); From 43af9aa290c85303f3aab4535b5ce393e0bc2774 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 17 Dec 2024 10:14:03 -0800 Subject: [PATCH 15/18] Address pull request review comments --- src/config.h | 6 +++--- src/generators.cpp | 10 +++++----- src/generators.h | 2 +- src/models/decoder_only_pipeline.cpp | 2 +- src/models/input_ids.cpp | 8 ++++---- src/models/input_ids.h | 7 +++++++ src/models/kv_cache.cpp | 14 +++++++++----- src/models/kv_cache.h | 15 ++++++++------- src/models/model.cpp | 6 +++--- src/models/position_inputs.cpp | 4 ++-- src/models/position_inputs.h | 9 +++++++++ src/ort_genai_c.cpp | 2 +- 12 files changed, 53 insertions(+), 32 deletions(-) diff --git a/src/config.h b/src/config.h index 1d0a2907b..5ba21683c 100644 --- a/src/config.h +++ b/src/config.h @@ -107,9 +107,9 @@ struct Config { int num_hidden_layers{}; int head_size{}; - struct SlidingWindow { - int window_size{128}; - int pad_value{}; + struct SlidingWindow { // Sliding window parameters for models that process
input prompt in chunks + int window_size{}; // The size of the window to slide over the input prompt + int pad_value{}; // The key-value cache padding value to use for the sliding window for inactive tokens }; std::optional sliding_window; diff --git a/src/generators.cpp b/src/generators.cpp index 74572506d..7746a7ce2 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -276,16 +276,16 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ } DeviceSpan Generator::AllocateInputIdsOnDevice(cpu_span input_ids) { - size_t input_ids_size = input_ids.size(); + size_t padded_input_ids_size = input_ids.size(); if (model_->config_->model.decoder.sliding_window.has_value()) { // If the model has a sliding window, pad the input_ids to the next multiple of the window size // so that the input_ids can be divided into window size chunks. const auto window_size = model_->config_->model.decoder.sliding_window->window_size; - input_ids_size = ((input_ids.size() + window_size - 1) / window_size) * window_size; + padded_input_ids_size = ((input_ids.size() + window_size - 1) / window_size) * window_size; } - auto input_ids_device = state_->params_->p_device->Allocate(input_ids_size); + auto input_ids_device = state_->params_->p_device->Allocate(padded_input_ids_size); auto cpu_span = input_ids_device.CpuSpan(); - std::fill_n(cpu_span.begin(), input_ids_size, model_->config_->model.pad_token_id); + std::fill_n(cpu_span.begin(), padded_input_ids_size - input_ids.size(), model_->config_->model.pad_token_id); std::copy_backward(input_ids.begin(), input_ids.end(), cpu_span.end()); input_ids_device.CopyCpuToDevice(); return input_ids_device; diff --git a/src/generators.h b/src/generators.h index 95739bb4e..1d2f5b33e 100644 --- a/src/generators.h +++ b/src/generators.h @@ -60,7 +60,7 @@ enum struct DeviceType { CUDA, DML, WEBGPU, - QNN_WITH_SHARED_MEMORY, + QNN, }; std::string to_string(DeviceType device_type); diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index e7f7218e6..b18fea179 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -58,7 +58,7 @@ bool IntermediatePipelineState::HasOutput(std::string_view name) const { } bool IntermediatePipelineState::SupportsPrimaryDevice() const { - if (model_.device_type_ == DeviceType::CPU || model_.device_type_ == DeviceType::QNN_WITH_SHARED_MEMORY) { + if (model_.device_type_ == DeviceType::CPU || model_.device_type_ == DeviceType::QNN) { return true; } else if (model_.device_type_ == DeviceType::CUDA) { if (!model_.config_->model.decoder.pipeline[id_].session_options.has_value()) { diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 546efd97a..2b74cc491 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -224,19 +224,19 @@ void WindowedInputIDs::Update(DeviceSpan& new_tokens) { value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); - // next_tokens will always be padded so that it's size is a multiple of window_size_ - // next_tokens -> [0, a, b, c, d, e] + // new_tokens will always be padded so that its size is a multiple of window_size_ + // new_tokens -> [0, a, b, c, d, e] // window_size = 3, num_windows = 2, pad_token = 0
// window_index = 0, value_ -> [0, a, b] std::copy_n(new_tokens.Span().begin(), window_size_, value_->GetTensorMutableData()); } else if (window_index_ < num_windows_) { - // next_tokens -> [a, b, c, d, e] + // new_tokens -> [a, b, c, d, e] // window_size = 3, num_windows = 2 // window_index = 1, value_ -> [c, d, e] std::copy_n(new_tokens.Span().begin() + window_index_ * window_size_, window_size_, value_->GetTensorMutableData()); } else { // All prompt token chunks have been processed. Now we process the tokens generated by the model. - // next_tokens -> [f] + // new_tokens -> [f] assert(new_tokens.size() == 1); if (shape_[1] != 1) { shape_[1] = 1; diff --git a/src/models/input_ids.h b/src/models/input_ids.h index ee16e34d8..b4710b32a 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -52,6 +52,13 @@ struct InputIDsDefault : InputIDs { std::unique_ptr past_sequence_length_; }; +// Certain models can only process a fixed number of tokens at a time. +// For example, given a prompt with 120 tokens, and a model that can only process 20 tokens at a time, +// this class will split the prompt into 6 windows of 20 tokens each. +// At each update step, the next window of tokens is processed. +// This is done until all windows have been processed before switching to the model-generated tokens +// which are processed one token at a time. +// In contrast, InputIDsDefault processes all prompt tokens at once. struct WindowedInputIDs : public InputIDs { WindowedInputIDs(State& state); WindowedInputIDs(const WindowedInputIDs&) = delete; diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 734d07f8e..e1660a55d 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -25,6 +25,10 @@ int64_t ElementCountFromShape(const std::array& shape) { return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies()); } +bool IsCacheNeeded(const Model& model) { + return model.session_info_->HasInput(ComposeKeyValueName(model.config_->model.decoder.inputs.past_key_names, 0)); +} + } // namespace KeyValueCacheDefault_Combined::KeyValueCacheDefault_Combined(State& state) @@ -188,10 +192,6 @@ void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan beam_indic } } -bool KeyValueCache::IsCacheNeeded(const Model& model) { - return model.session_info_->HasInput(ComposeKeyValueName(model.config_->model.decoder.inputs.past_key_names, 0)); -} - KeyValueCacheDefault::KeyValueCacheDefault(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, @@ -455,6 +455,10 @@ WindowedKeyValueCache::WindowedKeyValueCache(State& state) model_.config_->model.context_length - window_size_, model_.config_->model.decoder.head_size}, value_cache_shape_out_{model_.config_->model.decoder.num_key_value_heads, 1, window_size_, model_.config_->model.decoder.head_size} { + if (layer_count_ == 0) { + throw std::runtime_error("Expected there to be at least 1 layer in the model. Actual: " + + std::to_string(layer_count_) + ". 
Please check the num_hidden_layers attribute in the model configuration."); + } for (int i = 0; i < layer_count_; ++i) { input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_key_names, i)); input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_value_names, i)); @@ -669,7 +673,7 @@ void WindowedKeyValueCache::Update(DeviceSpan beam_indices, int current } std::unique_ptr CreateKeyValueCache(State& state) { - if (!KeyValueCache::IsCacheNeeded(state.model_)) { + if (!IsCacheNeeded(state.model_)) { return nullptr; } diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 5699e2e49..2c9c8e46a 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -10,22 +10,23 @@ struct KeyValueCache { virtual void AddEncoder() = 0; virtual void Update(DeviceSpan beam_indices, int total_length) = 0; virtual void RewindTo(size_t index) = 0; - - static bool IsCacheNeeded(const Model& model); }; -struct KeyValueCacheDefault_Combined { +struct KeyValueCacheDefault_Combined : KeyValueCache { KeyValueCacheDefault_Combined(State& state); - void Add(); // Add to state inputs/outputs - void Update(DeviceSpan beam_indices, int total_length); - void RewindTo(size_t index); + void Add() override; // Add to state inputs/outputs + void AddEncoder() override { + throw std::runtime_error("KeyValueCacheDefault_Combined does not support AddEncoder."); + }; + void Update(DeviceSpan beam_indices, int total_length) override; + void RewindTo(size_t index) override; + private: template void PickPastState(DeviceSpan beam_indices, int index); void PickPastState(DeviceSpan beam_indices, int index); - private: template void RewindPastTensorsTo(size_t index); diff --git a/src/models/model.cpp b/src/models/model.cpp index 85fffbec2..47f6dc4bd 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -316,7 +316,7 @@ void Model::InitDeviceAllocator(OrtSession& session) { } #endif - if (device_type_ == DeviceType::QNN_WITH_SHARED_MEMORY) { + if (device_type_ == DeviceType::QNN) { memory_info_device_ = OrtMemoryInfo::Create("QnnHtpShared", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); @@ -519,7 +519,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ // on the other hand, not sure if is_primary_session_options is the right thing to check here. 
if (const auto opt_it = opts.find("enable_htp_shared_memory_allocator"); opt_it != opts.end() && opt_it->second == "1") { - device_type_ = DeviceType::QNN_WITH_SHARED_MEMORY; + device_type_ = DeviceType::QNN; } session_options.AppendExecutionProvider("QNN", opts); @@ -707,7 +707,7 @@ std::unique_ptr Model::ExpandInputs(std::unique_ptr& input, switch (device_type_) { case DeviceType::WEBGPU: case DeviceType::DML: - case DeviceType::QNN_WITH_SHARED_MEMORY: + case DeviceType::QNN: // DML and WebGpu doesn't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs case DeviceType::CPU: for (int i = 0; i < batch_size; i++) { diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 1b46e497c..46e4c3114 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -242,7 +242,7 @@ void PositionInputsDefault::UpdateAttentionMask(int total_length, int new_kv_len switch (model_.device_type_) { case DeviceType::WEBGPU: case DeviceType::CPU: - case DeviceType::QNN_WITH_SHARED_MEMORY: { + case DeviceType::QNN: { type_ == Ort::TypeToTensorType ? UpdateAttentionMaskImpl(total_length) : UpdateAttentionMaskImpl(total_length); break; @@ -528,7 +528,7 @@ void WindowedPositionInputs::Update(DeviceSpan next_tokens, int total_l // window_size = 3, num_windows = 2, pad_token = 0 // window_index = 0, attention_mask_ -> ([0] * context_length - window_size_) + [0, 1, 1] auto* attention_mask_data = attention_mask_->GetTensorMutableData(); - std::fill(attention_mask_data, attention_mask_data + attention_mask_shape_[1], 0); + std::fill_n(attention_mask_data, attention_mask_shape_[1] - window_size_, 0); for (size_t i = 0; i < window_size_; i++) { attention_mask_data[attention_mask_shape_[1] - window_size_ + i] = next_tokens.CpuSpan()[i] == model_.config_->model.pad_token_id ? 0 : 1; } diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index 564da5033..c598e4310 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -93,6 +93,15 @@ struct PositionInputsDefault : PositionInputs { #endif }; +// Certain models can only process a fixed number of tokens at a time. +// For example, given a prompt with 120 tokens, and a model that can only process 20 tokens at a time, +// this class will split the position ids into 6 windows of 20 tokens each. +// At each update step, the next window of position ids is prepared. +// This is done until all windows have been processed before switching to the model-generation phase +// where position ids are prepared one id at a time. +// This class will also prepare the attention mask for each iteration. The attention mask buffer is allocated just +// once and reused for each iteration by setting the mask to 1 for current window tokens and previously active window tokens +// In contrast, PositionInputsDefault processes all position ids at once. 
struct WindowedPositionInputs : PositionInputs { WindowedPositionInputs(State& state); WindowedPositionInputs(const WindowedPositionInputs&) = delete; WindowedPositionInputs& operator=(const WindowedPositionInputs&) = delete; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 609c20f60..3f2c9d750 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -292,7 +292,7 @@ OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerator* oga_gene } auto input_ids = Generators::PadInputs(span_sequences, generator.model_->config_->model.pad_token_id); - generator.AppendTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); + generator.AppendTokens(input_ids); return nullptr; OGA_CATCH } From 2a00a26bb85ba12b6e900be02b94ca4a35d7584b Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 18 Dec 2024 10:21:12 -0800 Subject: [PATCH 16/18] Rename InputIDsDefault to DefaultInputIDs --- src/models/decoder_only.h | 2 +- src/models/gpt.h | 2 +- src/models/input_ids.cpp | 8 ++++---- src/models/input_ids.h | 10 +++++----- src/models/multi_modal_vision_model.h | 2 +- src/models/whisper.h | 2 +- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index 1ea2deb2b..d919c4cb3 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -29,7 +29,7 @@ struct DecoderOnly_State : State { const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; - InputIDsDefault input_ids_{*this}; + DefaultInputIDs input_ids_{*this}; Logits logits_{*this}; KeyValueCacheDefault kv_cache_{*this}; PositionInputsDefault position_inputs_; diff --git a/src/models/gpt.h b/src/models/gpt.h index 4ee52102b..627739c4d 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -27,7 +27,7 @@ struct Gpt_State : State { const Gpt_Model& model_; - InputIDsDefault input_ids_{*this}; + DefaultInputIDs input_ids_{*this}; Logits logits_{*this}; KeyValueCacheDefault_Combined kv_cache_{*this}; PositionInputsDefault position_inputs_; diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 2b74cc491..f99907b59 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -5,7 +5,7 @@ namespace Generators { -InputIDsDefault::InputIDsDefault(State& state) +DefaultInputIDs::DefaultInputIDs(State& state) : state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); shape_ = {state_.params_->BatchBeamSize(), 0}; @@ -41,7 +41,7 @@ InputIDsDefault::InputIDsDefault(State& state) } } -void InputIDsDefault::Add() { +void DefaultInputIDs::Add() { input_index_ = state_.inputs_.size(); state_.inputs_.push_back(value_.get()); @@ -55,7 +55,7 @@ void InputIDsDefault::Add() { } } -void InputIDsDefault::Update(DeviceSpan& new_tokens) { +void DefaultInputIDs::Update(DeviceSpan& new_tokens) { const auto get_unpadded_sequence_length = [](std::span input_ids, int32_t pad_token_id) { int32_t seq_length = 0; @@ -254,7 +254,7 @@ std::unique_ptr CreateInputIDs(State& state) { if (state.model_.config_->model.decoder.sliding_window.has_value()) { return std::make_unique(state); } else { - return std::make_unique(state); + return std::make_unique(state); } } diff --git a/src/models/input_ids.h b/src/models/input_ids.h index b4710b32a..d7a229911 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -11,10 +11,10 @@ struct InputIDs { virtual void Update(DeviceSpan& next_tokens) = 0; }; -struct InputIDsDefault : InputIDs { - InputIDsDefault(State& state); - InputIDsDefault(const InputIDsDefault&) = delete; - InputIDsDefault& operator=(const InputIDsDefault&) = delete;
+struct DefaultInputIDs : InputIDs { + DefaultInputIDs(State& state); + DefaultInputIDs(const DefaultInputIDs&) = delete; + DefaultInputIDs& operator=(const DefaultInputIDs&) = delete; // Register input_ids as ORT session input. // Called only once during initialization of state. @@ -58,7 +58,7 @@ struct InputIDsDefault : InputIDs { // At each update step, the next window of tokens is processed. // This is done until all windows have been processed before switching to the model-generated tokens // which are processed one token at a time. -// In contrast, InputIDsDefault processes all prompt tokens at once. +// In contrast, DefaultInputIDs processes all prompt tokens at once. struct WindowedInputIDs : public InputIDs { WindowedInputIDs(State& state); WindowedInputIDs(const WindowedInputIDs&) = delete; diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index 2fde849d4..92690049b 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -42,7 +42,7 @@ struct EmbeddingState : State { const MultiModalVisionModel& model_; int64_t num_image_tokens_; - InputIDsDefault input_ids_{*this}; // Model input + DefaultInputIDs input_ids_{*this}; // Model input ImageFeatures image_features_{*this, ImageFeatures::Mode::Input, // Optional model input model_.config_->model.embedding.inputs.image_features, num_image_tokens_}; diff --git a/src/models/whisper.h b/src/models/whisper.h index 0528f4b25..b286ae0c1 100644 --- a/src/models/whisper.h +++ b/src/models/whisper.h @@ -34,7 +34,7 @@ struct Whisper_State : State { Decoder, } run_state_{RunState::Encoder_Decoder_Init}; - InputIDsDefault decoder_input_ids_{*this}; + DefaultInputIDs decoder_input_ids_{*this}; Logits logits_{*this}; KeyValueCacheDefault kv_cache_{*this}; Cross_Cache cross_cache_{*this}; From ae436ef122d5edf3ade81b19bb1c8573feeb24da Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 18 Dec 2024 10:38:12 -0800 Subject: [PATCH 17/18] More merge conflicts --- src/config.cpp | 6 +++--- src/generators.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 27b11d9a4..5e6ad7c03 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -274,11 +274,11 @@ struct Pipeline_Element : JSON::Element { struct SlidingWindow_Element : JSON::Element { explicit SlidingWindow_Element(std::optional& v) : v_{v} {} - void OnNumber(std::string_view name, double value) override { + void OnValue(std::string_view name, JSON::Value value) override { if (name == "window_size") { - v_->window_size = static_cast(value); + v_->window_size = static_cast(JSON::Get(value)); } else if (name == "pad_value") { - v_->pad_value = static_cast(value); + v_->pad_value = static_cast(JSON::Get(value)); } else throw JSON::unknown_value_error{}; } diff --git a/src/generators.h b/src/generators.h index 8fc57bd45..28a71e580 100644 --- a/src/generators.h +++ b/src/generators.h @@ -129,7 +129,7 @@ struct Generator : LeakChecked { private: DeviceSpan AllocateInputIdsOnDevice(cpu_span input_ids); - void AuxAppendTokens(const cpu_span input_ids); + void AuxAppendTokens(cpu_span input_ids); void ComputeLogits(DeviceSpan next_tokens); enum Action { standard, // Default, set in any other case generated, // Set after GenerateNextToken From 5a8fae9b6bff94a5b0e5784c0d91228e2bc594d5 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 18 Dec 2024 10:55:49 -0800 Subject: [PATCH 18/18] Rename structs --- src/models/decoder_only.h | 4 +-- src/models/gpt.h | 4 +-- 
src/models/kv_cache.cpp | 38 ++++++++++---------- src/models/kv_cache.h | 16 ++++----- src/models/multi_modal_vision_model.h | 4 +-- src/models/position_inputs.cpp | 50 +++++++++++++-------------- src/models/position_inputs.h | 6 ++-- src/models/whisper.h | 4 +-- 8 files changed, 63 insertions(+), 63 deletions(-) diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index d919c4cb3..27b006a0a 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -31,8 +31,8 @@ struct DecoderOnly_State : State { DefaultInputIDs input_ids_{*this}; Logits logits_{*this}; - KeyValueCacheDefault kv_cache_{*this}; - PositionInputsDefault position_inputs_; + DefaultKeyValueCache kv_cache_{*this}; + DefaultPositionInputs position_inputs_; ExtraInputs extra_inputs_{*this}; }; diff --git a/src/models/gpt.h b/src/models/gpt.h index 627739c4d..f50e51406 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -29,8 +29,8 @@ struct Gpt_State : State { DefaultInputIDs input_ids_{*this}; Logits logits_{*this}; - KeyValueCacheDefault_Combined kv_cache_{*this}; - PositionInputsDefault position_inputs_; + CombinedKeyValueCache kv_cache_{*this}; + DefaultPositionInputs position_inputs_; ExtraInputs extra_inputs_{*this}; }; } // namespace Generators diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index e1660a55d..0a47355a9 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -31,7 +31,7 @@ bool IsCacheNeeded(const Model& model) { } // namespace -KeyValueCacheDefault_Combined::KeyValueCacheDefault_Combined(State& state) +CombinedKeyValueCache::CombinedKeyValueCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, shape_{2, state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0, model_.config_->model.decoder.head_size} { @@ -54,7 +54,7 @@ KeyValueCacheDefault_Combined::KeyValueCacheDefault_Combined(State& state) } } -void KeyValueCacheDefault_Combined::Add() { +void CombinedKeyValueCache::Add() { input_index_ = state_.inputs_.size(); output_index_ = state_.outputs_.size(); @@ -66,7 +66,7 @@ void KeyValueCacheDefault_Combined::Add() { } } -void KeyValueCacheDefault_Combined::Update(DeviceSpan beam_indices, int total_length) { +void CombinedKeyValueCache::Update(DeviceSpan beam_indices, int total_length) { assert(state_.params_->search.num_beams == 1 || !beam_indices.empty()); // We require beam_indices if we're a beam search if (!is_first_update_) { @@ -89,7 +89,7 @@ void KeyValueCacheDefault_Combined::Update(DeviceSpan beam_indices, int is_first_update_ = false; } -void KeyValueCacheDefault_Combined::RewindTo(size_t index) { +void CombinedKeyValueCache::RewindTo(size_t index) { if (shape_[3] <= static_cast(index)) { throw std::runtime_error("Requested length of rewind is greater than the current length."); } @@ -108,7 +108,7 @@ void KeyValueCacheDefault_Combined::RewindTo(size_t index) { } template -void KeyValueCacheDefault_Combined::RewindPastTensorsTo(size_t index) { +void CombinedKeyValueCache::RewindPastTensorsTo(size_t index) { assert(index > 0 && shape_[3] >= static_cast(index)); std::array new_shape = shape_; new_shape[3] = static_cast(index); @@ -143,7 +143,7 @@ void KeyValueCacheDefault_Combined::RewindPastTensorsTo(size_t index) { // Copy present state to past state reordered by the beam_indices template -void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan beam_indices_device, int index) { +void CombinedKeyValueCache::PickPastState(DeviceSpan beam_indices_device, int 
index) { std::span beam_indices = beam_indices_device.CopyDeviceToCpu(); auto block_size_per_beam = shape_[2] * shape_[3] * shape_[4]; auto past_key_size = shape_[1] * block_size_per_beam; @@ -184,7 +184,7 @@ void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan beam_indic pasts_[index] = std::move(past); } -void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan beam_indices, int index) { +void CombinedKeyValueCache::PickPastState(DeviceSpan beam_indices, int index) { if (type_ == Ort::TypeToTensorType) { PickPastState(beam_indices, index); } else { @@ -192,7 +192,7 @@ void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan beam_indic } } -KeyValueCacheDefault::KeyValueCacheDefault(State& state) +DefaultKeyValueCache::DefaultKeyValueCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, past_present_share_buffer_{state_.params_->search.past_present_share_buffer && (state_.params_->search.num_beams == 1 || model_.config_->model.type == "whisper")}, @@ -257,7 +257,7 @@ KeyValueCacheDefault::KeyValueCacheDefault(State& state) } } -void KeyValueCacheDefault::AddEncoder() { +void DefaultKeyValueCache::AddEncoder() { // We don't set the input_index_ & output_index_ because the encoder step only runs once, there's no update for (int i = 0; i < layer_count_ * 2; ++i) { @@ -266,7 +266,7 @@ void KeyValueCacheDefault::AddEncoder() { } } -void KeyValueCacheDefault::Add() { +void DefaultKeyValueCache::Add() { input_index_ = state_.inputs_.size(); output_index_ = state_.outputs_.size(); @@ -285,7 +285,7 @@ void KeyValueCacheDefault::Add() { } } -void KeyValueCacheDefault::Update(DeviceSpan beam_indices, int total_length) { +void DefaultKeyValueCache::Update(DeviceSpan beam_indices, int total_length) { // If we're sharing past & present buffers there is nothing to do here, so early exit if (past_present_share_buffer_) return; @@ -310,7 +310,7 @@ void KeyValueCacheDefault::Update(DeviceSpan beam_indices, int total_le is_first_update_ = false; } -void KeyValueCacheDefault::RewindTo(size_t index) { +void DefaultKeyValueCache::RewindTo(size_t index) { if (past_present_share_buffer_) { return; } else if (shape_[2] <= static_cast(index)) { @@ -331,7 +331,7 @@ void KeyValueCacheDefault::RewindTo(size_t index) { } template -void KeyValueCacheDefault::RewindPastTensorsTo(size_t index) { +void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) { assert(index > 0 && shape_[2] >= static_cast(index) && !past_present_share_buffer_); std::array new_shape = shape_; new_shape[2] = static_cast(index); @@ -366,7 +366,7 @@ void KeyValueCacheDefault::RewindPastTensorsTo(size_t index) { // Copy present state to past state reordered by the beam_indices template -void KeyValueCacheDefault::PickPastState(DeviceSpan beam_indices_device, int index) { +void DefaultKeyValueCache::PickPastState(DeviceSpan beam_indices_device, int index) { std::span beam_indices = beam_indices_device.Span(); auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3]; auto element_count = shape_[0] * block_size_per_beam; @@ -398,7 +398,7 @@ void KeyValueCacheDefault::PickPastState(DeviceSpan beam_indices_device pasts_[index] = std::move(past_value); } -void KeyValueCacheDefault::PickPastState(DeviceSpan beam_indices, int index) { +void DefaultKeyValueCache::PickPastState(DeviceSpan beam_indices, int index) { if (type_ == Ort::TypeToTensorType) { PickPastState(beam_indices, index); } else { @@ -406,7 +406,7 @@ void KeyValueCacheDefault::PickPastState(DeviceSpan beam_indices, int i } } 
-Cross_Cache::Cross_Cache(State& state) +CrossCache::CrossCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, shape_{state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 1500, model_.config_->model.decoder.head_size} { @@ -429,14 +429,14 @@ Cross_Cache::Cross_Cache(State& state) } } -void Cross_Cache::AddOutputs() { +void CrossCache::AddOutputs() { for (int i = 0; i < layer_count_ * 2; ++i) { state_.outputs_.push_back(values_[i].get()); state_.output_names_.push_back(output_name_strings_[i].c_str()); } } -void Cross_Cache::AddInputs() { +void CrossCache::AddInputs() { for (int i = 0; i < layer_count_ * 2; ++i) { state_.inputs_.push_back(values_[i].get()); state_.input_names_.push_back(input_name_strings_[i].c_str()); @@ -680,7 +680,7 @@ std::unique_ptr CreateKeyValueCache(State& state) { if (state.model_.config_->model.decoder.sliding_window) { return std::make_unique(state); } else { - return std::make_unique(state); + return std::make_unique(state); } } diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 2c9c8e46a..0e871d938 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -12,12 +12,12 @@ struct KeyValueCache { virtual void RewindTo(size_t index) = 0; }; -struct KeyValueCacheDefault_Combined : KeyValueCache { - KeyValueCacheDefault_Combined(State& state); +struct CombinedKeyValueCache : KeyValueCache { + CombinedKeyValueCache(State& state); void Add() override; // Add to state inputs/outputs void AddEncoder() override { - throw std::runtime_error("KeyValueCacheDefault_Combined does not support AddEncoder."); + throw std::runtime_error("CombinedKeyValueCache does not support AddEncoder."); }; void Update(DeviceSpan beam_indices, int total_length) override; void RewindTo(size_t index) override; @@ -45,8 +45,8 @@ struct KeyValueCacheDefault_Combined : KeyValueCache { std::vector input_name_strings_, output_name_strings_; }; -struct KeyValueCacheDefault : KeyValueCache { - KeyValueCacheDefault(State& state); +struct DefaultKeyValueCache : KeyValueCache { + DefaultKeyValueCache(State& state); void AddEncoder() override; // If model has an initial encoder step, this is used // Register input_ids as ORT session input. 
@@ -81,9 +81,9 @@ struct KeyValueCacheDefault : KeyValueCache { std::vector sb_kv_caches_; }; -// Very similar to the KeyValueCacheDefault, but is only created once at the encoder step, then used without modification for every decoder step -struct Cross_Cache { - Cross_Cache(State& state); +// Very similar to the DefaultKeyValueCache, but is only created once at the encoder step, then used without modification for every decoder step +struct CrossCache { + CrossCache(State& state); void AddOutputs(); void AddInputs(); diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index 92690049b..3cfa1bdfc 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -89,8 +89,8 @@ struct DecoderState : State { const CapturedGraphInfo* captured_graph_info_; Embeddings inputs_embeds_{*this, Embeddings::Mode::Input, // Model input model_.config_->model.decoder.inputs.embeddings}; - PositionInputsDefault position_inputs_; // Model input - KeyValueCacheDefault kv_cache_{*this}; // Model input + DefaultPositionInputs position_inputs_; // Model input + DefaultKeyValueCache kv_cache_{*this}; // Model input Logits logits_{*this}; // Model output }; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 46e4c3114..fde1ed7a9 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -9,7 +9,7 @@ namespace Generators { -PositionInputsDefault::PositionInputsDefault(const Model& model, State& state, DeviceSpan sequence_lengths_unk) +DefaultPositionInputs::DefaultPositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk) : model_{model}, state_{state} { has_mask_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.attention_mask); @@ -59,7 +59,7 @@ PositionInputsDefault::PositionInputsDefault(const Model& model, State& state, D } } -void PositionInputsDefault::Add() { +void DefaultPositionInputs::Add() { if (has_posid_input_) { AddPositionIDs(); } @@ -68,7 +68,7 @@ void PositionInputsDefault::Add() { } } -void PositionInputsDefault::Update(DeviceSpan next_tokens, int total_length, int new_length) { +void DefaultPositionInputs::Update(DeviceSpan next_tokens, int total_length, int new_length) { if (has_posid_input_) { // Initialize on first update if (is_first_update_) { @@ -96,7 +96,7 @@ void PositionInputsDefault::Update(DeviceSpan next_tokens, int total_le is_first_update_ = false; } -void PositionInputsDefault::RewindTo(size_t index) { +void DefaultPositionInputs::RewindTo(size_t index) { // Reset the state of the position inputs if (index == 0) { is_first_update_ = true; @@ -108,18 +108,18 @@ void PositionInputsDefault::RewindTo(size_t index) { RewindMask(index); #endif } else - throw std::runtime_error("PositionInputsDefault::RewindTo - Unsupported batch size"); + throw std::runtime_error("DefaultPositionInputs::RewindTo - Unsupported batch size"); } } -void PositionInputsDefault::AddAttentionMask() { +void DefaultPositionInputs::AddAttentionMask() { mask_input_index_ = state_.inputs_.size(); state_.inputs_.push_back(attention_mask_.get()); state_.input_names_.push_back(model_.config_->model.decoder.inputs.attention_mask.c_str()); } -void PositionInputsDefault::AddPositionIDs() { +void DefaultPositionInputs::AddPositionIDs() { posid_input_index_ = state_.inputs_.size(); state_.inputs_.push_back(position_ids_.get()); @@ -127,7 +127,7 @@ void PositionInputsDefault::AddPositionIDs() { } #if USE_CUDA || USE_DML -void 
PositionInputsDefault::CopyNextPositionIDsToCurrent() { +void DefaultPositionInputs::CopyNextPositionIDsToCurrent() { #if USE_CUDA assert(model_.device_type_ == DeviceType::CUDA); cudaMemcpyAsync(position_ids_->GetTensorMutableRawData(), @@ -149,7 +149,7 @@ void PositionInputsDefault::CopyNextPositionIDsToCurrent() { } #endif -void PositionInputsDefault::CreateNextPositionIDsTensor() { +void DefaultPositionInputs::CreateNextPositionIDsTensor() { if (!sb_position_ids_) { if (position_ids_shape_[1] == 1 && position_ids_next_) { position_ids_ = std::move(position_ids_next_); @@ -167,9 +167,9 @@ void PositionInputsDefault::CreateNextPositionIDsTensor() { } } -void PositionInputsDefault::UpdatePositionIDs(int total_length, int new_kv_length) { +void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); + throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); // Reallocate position_ids when new_kv_length changes if (position_ids_shape_[1] != new_kv_length) { @@ -204,7 +204,7 @@ void PositionInputsDefault::UpdatePositionIDs(int total_length, int new_kv_lengt } } -void PositionInputsDefault::CreateNextAttentionMaskTensor(int total_length) { +void DefaultPositionInputs::CreateNextAttentionMaskTensor(int total_length) { if (!sb_attention_mask_) { attention_mask_shape_[1] = total_length; attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); @@ -230,11 +230,11 @@ void PositionInputsDefault::CreateNextAttentionMaskTensor(int total_length) { } } -void PositionInputsDefault::UpdateAttentionMask(int total_length, int new_kv_length) { +void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); + throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); if (DeviceType::DML == model_.device_type_ && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputsDefault::UpdatePositionIDs - DML does not support continuous decoding."); + throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - DML does not support continuous decoding."); CreateNextAttentionMaskTensor(total_length); state_.inputs_[mask_input_index_] = attention_mask_.get(); @@ -279,7 +279,7 @@ void PositionInputsDefault::UpdateAttentionMask(int total_length, int new_kv_len } #endif default: - throw std::runtime_error("PositionInputsDefault::Update - Unsupported device type"); + throw std::runtime_error("DefaultPositionInputs::Update - Unsupported device type"); } #if USE_DML if (model_.device_type_ != DeviceType::DML) { @@ -293,7 +293,7 @@ void PositionInputsDefault::UpdateAttentionMask(int total_length, int new_kv_len } template -void PositionInputsDefault::CreateAndInitializePositionIDs(DeviceSpan next_tokens, std::array shape) { +void DefaultPositionInputs::CreateAndInitializePositionIDs(DeviceSpan next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. 
// Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); @@ -323,7 +323,7 @@ void PositionInputsDefault::CreateAndInitializePositionIDs(DeviceSpan n } template -void PositionInputsDefault::CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape) { +void DefaultPositionInputs::CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); @@ -347,14 +347,14 @@ void PositionInputsDefault::CreateAndInitializeAttentionMask(DeviceSpan } template -void PositionInputsDefault::InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk) { +void DefaultPositionInputs::InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk) { for (int i = 0; i < shape[0] * state_.params_->search.num_beams; i++) { sequence_lengths_unk[i] = 0; } } template -void PositionInputsDefault::UpdatePositionIDsImpl(int total_length, int new_kv_length) { +void DefaultPositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_length) { auto* data = position_ids_->GetTensorMutableData(); if (position_ids_shape_[0] == 1) { // For batch size == 1 we calculate position ids with total length and new kv length for continuous decoding @@ -368,7 +368,7 @@ void PositionInputsDefault::UpdatePositionIDsImpl(int total_length, int new_kv_l } #if USE_DML -void PositionInputsDefault::UpdatePositionIDsImplDML() { +void DefaultPositionInputs::UpdatePositionIDsImplDML() { ComPtr target_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource)); @@ -387,7 +387,7 @@ void PositionInputsDefault::UpdatePositionIDsImplDML() { #endif template -void PositionInputsDefault::UpdateAttentionMaskImpl(int total_length) { +void DefaultPositionInputs::UpdateAttentionMaskImpl(int total_length) { auto* data = attention_mask_next_->GetTensorMutableData(); auto* old_data = attention_mask_->GetTensorData(); if (attention_mask_shape_[0] == 1) { @@ -406,7 +406,7 @@ void PositionInputsDefault::UpdateAttentionMaskImpl(int total_length) { } #if USE_DML -void PositionInputsDefault::UpdateAttentionMaskImplDML(int total_length) { +void DefaultPositionInputs::UpdateAttentionMaskImplDML(int total_length) { ComPtr attention_mask_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource)); ComPtr attention_mask_next_resource; @@ -441,7 +441,7 @@ void PositionInputsDefault::UpdateAttentionMaskImplDML(int total_length) { #endif #if USE_CUDA -void PositionInputsDefault::RewindMask(size_t index) { +void DefaultPositionInputs::RewindMask(size_t index) { if (sb_attention_mask_ && !is_first_mask_update_) { int past_length = static_cast(index); int max_length = static_cast(state_.params_->search.max_length); @@ -598,7 +598,7 @@ std::unique_ptr CreatePositionInputs(State& state, DeviceSpanmodel.decoder.sliding_window.has_value()) { return std::make_unique(state); } else { - return std::make_unique(state.model_, state, sequence_lengths); + return std::make_unique(state.model_, state, sequence_lengths); } } diff --git 
a/src/models/position_inputs.h b/src/models/position_inputs.h index c598e4310..4365e0ee4 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -16,8 +16,8 @@ struct PositionInputs { virtual void RewindTo(size_t index) = 0; }; -struct PositionInputsDefault : PositionInputs { - PositionInputsDefault(const Model& model, State& state, DeviceSpan sequence_lengths_unk); +struct DefaultPositionInputs : PositionInputs { + DefaultPositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk); void Add() override; void Update(DeviceSpan next_tokens, int total_length, int new_length) override; @@ -101,7 +101,7 @@ struct PositionInputsDefault : PositionInputs { // where position ids are prepared one id at a time. // This class will also prepare the attention mask for each iteration. The attention mask buffer is allocated just // once and reused for each iteration by setting the mask to 1 for current window tokens and previously active window tokens -// In contrast, PositionInputsDefault processes all position ids at once. +// In contrast, DefaultPositionInputs processes all position ids at once. struct WindowedPositionInputs : PositionInputs { WindowedPositionInputs(State& state); WindowedPositionInputs(const WindowedPositionInputs&) = delete; diff --git a/src/models/whisper.h b/src/models/whisper.h index b286ae0c1..34eecd0ff 100644 --- a/src/models/whisper.h +++ b/src/models/whisper.h @@ -36,8 +36,8 @@ struct Whisper_State : State { DefaultInputIDs decoder_input_ids_{*this}; Logits logits_{*this}; - KeyValueCacheDefault kv_cache_{*this}; - Cross_Cache cross_cache_{*this}; + DefaultKeyValueCache kv_cache_{*this}; + CrossCache cross_cache_{*this}; std::unique_ptr encoder_input_ids_; std::unique_ptr encoder_hidden_states_;
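Endnote on the windowing scheme: the comments added in these patches (WindowedInputIDs, WindowedPositionInputs) and Generator::AllocateInputIdsOnDevice all rely on the same arithmetic of left-padding the prompt to a multiple of window_size and then consuming one window per update. The sketch below is illustrative only; the variable names and the printing loop are not part of the library, and it simply mirrors the [0, a, b] / [c, d, e] example from the input_ids.cpp comments:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::size_t window_size = 3;
  const int32_t pad_token = 0;
  const std::vector<int32_t> prompt = {10, 11, 12, 13, 14};  // five prompt tokens

  // Round the prompt length up to the next multiple of window_size: 5 -> 6.
  const std::size_t padded_size =
      ((prompt.size() + window_size - 1) / window_size) * window_size;
  const std::size_t num_windows = padded_size / window_size;

  // Left-pad so the padding lands in the first window, mirroring
  // Generator::AllocateInputIdsOnDevice in the diff above.
  std::vector<int32_t> padded(padded_size, pad_token);
  std::copy_backward(prompt.begin(), prompt.end(), padded.end());

  // Each Update step consumes the next window-sized chunk.
  for (std::size_t window_index = 0; window_index < num_windows; ++window_index) {
    std::cout << "window " << window_index << ":";
    for (std::size_t i = 0; i < window_size; ++i) {
      std::cout << ' ' << padded[window_index * window_size + i];
    }
    std::cout << '\n';
  }
}

With window_size = 3 and a five-token prompt this prints "window 0: 0 10 11" and "window 1: 12 13 14", matching the num_windows = 2 example in the diff.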