diff --git a/CMakeLists.txt b/CMakeLists.txt index 60076364d..4b9f83ec1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ if(ENABLE_TESTS) endif() endif() +find_package(Threads REQUIRED) if(WIN32) add_library(onnxruntime-genai SHARED ${generator_srcs} "${GENERATORS_ROOT}/dll/onnxruntime-genai.rc") @@ -104,6 +105,8 @@ target_include_directories(onnxruntime-genai-static PUBLIC ${onnxruntime_extensi target_link_libraries(onnxruntime-genai PRIVATE onnxruntime_extensions) target_link_libraries(onnxruntime-genai-static PUBLIC onnxruntime_extensions) target_link_directories(onnxruntime-genai PRIVATE ${ORT_LIB_DIR}) +target_link_libraries(onnxruntime-genai PRIVATE Threads::Threads) +target_link_libraries(onnxruntime-genai-static PUBLIC Threads::Threads) # we keep the shared libraries disconnected on Android as they will come from separate AARs and we don't want to force # the ORT version to match in both. diff --git a/src/config.cpp b/src/config.cpp index 459da0d68..5e6ad7c03 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -271,6 +271,22 @@ struct Pipeline_Element : JSON::Element { PipelineModelObject_Element object_{v_}; }; +struct SlidingWindow_Element : JSON::Element { + explicit SlidingWindow_Element(std::optional& v) : v_{v} {} + + void OnValue(std::string_view name, JSON::Value value) override { + if (name == "window_size") { + v_->window_size = static_cast(JSON::Get(value)); + } else if (name == "pad_value") { + v_->pad_value = static_cast(JSON::Get(value)); + } else + throw JSON::unknown_value_error{}; + } + + private: + std::optional& v_; +}; + struct Decoder_Element : JSON::Element { explicit Decoder_Element(Config::Model::Decoder& v) : v_{v} {} @@ -301,6 +317,10 @@ struct Decoder_Element : JSON::Element { if (name == "outputs") { return outputs_; } + if (name == "sliding_window") { + v_.sliding_window = Config::Model::Decoder::SlidingWindow{}; + return sliding_window_; + } throw JSON::unknown_value_error{}; } @@ -316,6 +336,7 @@ struct Decoder_Element : JSON::Element { Inputs_Element inputs_{v_.inputs}; Outputs_Element outputs_{v_.outputs}; Pipeline_Element pipeline_{v_.pipeline}; + SlidingWindow_Element sliding_window_{v_.sliding_window}; }; struct VisionInputs_Element : JSON::Element { diff --git a/src/config.h b/src/config.h index cd7cffd8f..5ba21683c 100644 --- a/src/config.h +++ b/src/config.h @@ -107,6 +107,12 @@ struct Config { int num_hidden_layers{}; int head_size{}; + struct SlidingWindow { // Sliding window parameters for models that process input prompt in chunks + int window_size{}; // The size of the window to slide over the input prompt + int pad_value{}; // The key-value cache padding value to use for the sliding window for inactive tokens + }; + std::optional sliding_window; + struct Inputs { std::string input_ids{Defaults::InputIdsName}; std::string embeddings{"inputs_embeds"}; diff --git a/src/generators.cpp b/src/generators.cpp index cd4660121..9415cb1db 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -3,6 +3,7 @@ #include "generators.h" #include "sequences.h" +#include "models/env_utils.h" #include "models/model.h" #include "models/decoder_only.h" #include "search.h" @@ -45,8 +46,14 @@ void OnCudaError(cudaError_t error) { assert(false); } static bool _ = (Ort::InitApi(), false); +static OrtLoggingLevel GetDefaultOrtLoggingLevel() { + bool ort_verbose_logging = false; + GetEnvironmentVariable("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging); + return ort_verbose_logging ? 
OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE : OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
+}
+
 OrtGlobals::OrtGlobals()
-    : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {
+    : env_{OrtEnv::Create(GetDefaultOrtLoggingLevel())} {
   auto arena_config = OrtArenaCfg::Create(0, -1, -1, -1);
   Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()};
   env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config);
@@ -170,6 +177,8 @@ std::string to_string(DeviceType device_type) {
       return "DirectML";
     case DeviceType::WEBGPU:
       return "WebGpu";
+    case DeviceType::QNN:
+      return "QnnWithSharedMemory";
   }
   throw std::runtime_error("Unknown device type");
 }
@@ -266,16 +275,24 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
   }
 }
 
-DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(const cpu_span<const int32_t> input_ids) {
-  auto input_ids_device = state_->params_->p_device->Allocate<int32_t>(input_ids.size());
+DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(cpu_span<const int32_t> input_ids) {
+  size_t padded_input_ids_size = input_ids.size();
+  if (model_->config_->model.decoder.sliding_window.has_value()) {
+    // If the model has a sliding window, pad the input_ids to the next multiple of the window size
+    // so that the input_ids can be divided into window-size chunks.
+    const auto window_size = model_->config_->model.decoder.sliding_window->window_size;
+    padded_input_ids_size = ((input_ids.size() + window_size - 1) / window_size) * window_size;
+  }
+  auto input_ids_device = state_->params_->p_device->Allocate<int32_t>(padded_input_ids_size);
   auto cpu_span = input_ids_device.CpuSpan();
-  std::copy(input_ids.begin(), input_ids.end(), cpu_span.begin());
+  std::fill_n(cpu_span.begin(), padded_input_ids_size - input_ids.size(), model_->config_->model.pad_token_id);
+  std::copy_backward(input_ids.begin(), input_ids.end(), cpu_span.end());
   input_ids_device.CopyCpuToDevice();
   return input_ids_device;
 }
 
 // TODO(aciddelgado): Remove this function once SetInputs is moved to generator
-void Generator::AuxAppendTokens(const cpu_span<const int32_t> input_ids) {
+void Generator::AuxAppendTokens(cpu_span<const int32_t> input_ids) {
   ThrowErrorIfSessionTerminated(state_->session_terminated_);
   if (input_ids.size() == 0)
     throw std::runtime_error("input_ids is empty");
@@ -288,7 +305,7 @@ void Generator::AuxAppendTokens(const cpu_span<const int32_t> input_ids) {
   ComputeLogits(input_ids_device);
 }
 
-void Generator::AppendTokens(const cpu_span<const int32_t> input_ids) {
+void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
   ThrowErrorIfSessionTerminated(state_->session_terminated_);
   if (input_ids.size() == 0)
     throw std::runtime_error("input_ids is empty");
@@ -297,6 +314,12 @@ void Generator::AppendTokens(const cpu_span<const int32_t> input_ids) {
   if (search_->GetSequenceLength() != 0 && state_->params_->search.batch_size > 1)
     throw std::runtime_error("AppendTokens can only be called once for batch_size > 1. 
To call AppendTokens again, use RewindToLength(0)"); + constexpr std::array devices_supporting_continuous_decoding{DeviceType::CPU, DeviceType::CUDA}; + if (search_->GetSequenceLength() != 0 && + std::none_of(devices_supporting_continuous_decoding.begin(), devices_supporting_continuous_decoding.end(), + [this](DeviceType device_type) { return device_type == state_->params_->device_type; })) + throw std::runtime_error("Continuous decoding is not supported on the selected device type: " + to_string(state_->params_->device_type)); + if (last_action_ == Action::generated) { ComputeLogits(search_->GetNextTokens()); } diff --git a/src/generators.h b/src/generators.h index 5cc68dad9..28a71e580 100644 --- a/src/generators.h +++ b/src/generators.h @@ -60,6 +60,7 @@ enum struct DeviceType { CUDA, DML, WEBGPU, + QNN, }; std::string to_string(DeviceType device_type); @@ -111,7 +112,7 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone() const; - void AppendTokens(const cpu_span input_ids); + void AppendTokens(cpu_span input_ids); void GenerateNextToken(); void RewindToLength(size_t new_length); // Rewind state to new_length DeviceSpan GetLogits(); @@ -127,8 +128,8 @@ struct Generator : LeakChecked { bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio private: - DeviceSpan AllocateInputIdsOnDevice(const cpu_span input_ids); - void AuxAppendTokens(const cpu_span input_ids); + DeviceSpan AllocateInputIdsOnDevice(cpu_span input_ids); + void AuxAppendTokens(cpu_span input_ids); void ComputeLogits(DeviceSpan next_tokens); enum Action { standard, // Default, set in any other case generated, // Set after GenerateNextToken diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp index 597ded8de..11056bfde 100644 --- a/src/models/debugging.cpp +++ b/src/models/debugging.cpp @@ -39,15 +39,18 @@ std::ostream& operator<<(std::ostream& stream, Ort::BFloat16_t v) { template void DumpSpan(std::ostream& stream, std::span values) { + // If type is uint8_t or int8_t cast to int so it displays as an int vs a char + using DisplayType = std::conditional_t || std::is_same_v, int, T>; + if (values.size() <= c_value_count) { for (auto v : values) - stream << v << ' '; + stream << static_cast(v) << ' '; } else { for (size_t i = 0; i < c_value_count / 2; i++) - stream << values[i] << ' '; + stream << static_cast(values[i]) << ' '; stream << "... 
"; for (size_t i = values.size() - c_value_count / 2; i < values.size(); i++) - stream << values[i] << ' '; + stream << static_cast(values[i]) << ' '; } } diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index a7836b170..27b006a0a 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -29,10 +29,10 @@ struct DecoderOnly_State : State { const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; - InputIDs input_ids_{*this}; + DefaultInputIDs input_ids_{*this}; Logits logits_{*this}; - KV_Cache kv_cache_{*this}; - PositionInputs position_inputs_; + DefaultKeyValueCache kv_cache_{*this}; + DefaultPositionInputs position_inputs_; ExtraInputs extra_inputs_{*this}; }; diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index d21f77600..b18fea179 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -58,7 +58,7 @@ bool IntermediatePipelineState::HasOutput(std::string_view name) const { } bool IntermediatePipelineState::SupportsPrimaryDevice() const { - if (model_.device_type_ == DeviceType::CPU) { + if (model_.device_type_ == DeviceType::CPU || model_.device_type_ == DeviceType::QNN) { return true; } else if (model_.device_type_ == DeviceType::CUDA) { if (!model_.config_->model.decoder.pipeline[id_].session_options.has_value()) { @@ -91,13 +91,14 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode const GeneratorParams& params) : State{params, model}, model_{model}, - position_inputs_{model, *this, sequence_lengths} { - input_ids_.Add(); - position_inputs_.Add(); + input_ids_{CreateInputIDs(*this)}, + key_value_cache_{CreateKeyValueCache(*this)}, + position_inputs_{CreatePositionInputs(*this, sequence_lengths)} { + input_ids_->Add(); + position_inputs_->Add(); logits_.Add(); - if (KV_Cache::IsCacheNeeded(model)) { - kv_cache_ = std::make_unique(*this); - kv_cache_->Add(); + if (key_value_cache_) { + key_value_cache_->Add(); } extra_inputs_.Add(); @@ -106,10 +107,8 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode } } -DeviceSpan DecoderOnlyPipelineState::Run(int total_length, DeviceSpan& next_tokens, - DeviceSpan next_indices) { - UpdateInputsOutputs(next_tokens, next_indices, total_length); - +void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan& next_tokens, + DeviceSpan next_indices) { for (auto& pipeline_state : pipeline_states_) { if (first_run_ && !model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_prompt) { continue; @@ -218,6 +217,28 @@ DeviceSpan DecoderOnlyPipelineState::Run(int total_length, DeviceSpan DecoderOnlyPipelineState::Run(int total_length, DeviceSpan& next_tokens, + DeviceSpan next_indices) { + UpdateInputsOutputs(next_tokens, next_indices, total_length); + + size_t num_chunks{1}; + if (first_run_ && model_.config_->model.decoder.sliding_window.has_value()) { + int window_size = model_.config_->model.decoder.sliding_window->window_size; + num_chunks = (next_tokens.size() + window_size - 1) / window_size; + } + + for (size_t i = 0; i < num_chunks; ++i) { + RunPipeline(total_length, next_tokens, next_indices); + + if (model_.config_->model.decoder.sliding_window.has_value() && i < num_chunks - 1) { + // Sliding the window over the input_ids, key_cache, and value_cache, position_ids, and attention_mask + input_ids_->Update(next_tokens); + key_value_cache_->Update(next_indices, total_length); + position_inputs_->Update(next_tokens, total_length, 
static_cast(input_ids_->GetShape()[1])); + } + } // Clear the outputs of the pipeline models that are only run on prompt since this cannot happen earlier. if (!first_run_) { @@ -239,10 +260,10 @@ DeviceSpan DecoderOnlyPipelineState::Run(int total_length, DeviceSpan& next_tokens, DeviceSpan beam_indices, int total_length) { - input_ids_.Update(next_tokens); - size_t new_length = input_ids_.GetShape()[1]; - position_inputs_.Update(next_tokens, total_length, static_cast(new_length)); - if (kv_cache_) kv_cache_->Update(beam_indices, total_length); + input_ids_->Update(next_tokens); + size_t new_length = input_ids_->GetShape()[1]; + position_inputs_->Update(next_tokens, total_length, static_cast(new_length)); + if (key_value_cache_) key_value_cache_->Update(beam_indices, total_length); logits_.Update(next_tokens, new_length); } diff --git a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h index b778ca92b..f9e9c4d29 100644 --- a/src/models/decoder_only_pipeline.h +++ b/src/models/decoder_only_pipeline.h @@ -58,6 +58,9 @@ struct DecoderOnlyPipelineState : State { OrtValue* GetOutput(const char* name) override; + void RunPipeline(int total_length, DeviceSpan& next_tokens, + DeviceSpan next_indices); + private: void UpdateInputsOutputs(DeviceSpan& next_tokens, DeviceSpan next_indices, int total_length); @@ -68,10 +71,10 @@ struct DecoderOnlyPipelineState : State { // Stores all the outputs from the previous pipeline state(s) std::unordered_map> ortvalue_store_; - InputIDs input_ids_{*this}; + std::unique_ptr input_ids_; Logits logits_{*this}; - std::unique_ptr kv_cache_; - PositionInputs position_inputs_; + std::unique_ptr key_value_cache_; + std::unique_ptr position_inputs_; ExtraInputs extra_inputs_{*this}; }; diff --git a/src/models/gpt.h b/src/models/gpt.h index bbd250b24..f50e51406 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -27,10 +27,10 @@ struct Gpt_State : State { const Gpt_Model& model_; - InputIDs input_ids_{*this}; + DefaultInputIDs input_ids_{*this}; Logits logits_{*this}; - KV_Cache_Combined kv_cache_{*this}; - PositionInputs position_inputs_; + CombinedKeyValueCache kv_cache_{*this}; + DefaultPositionInputs position_inputs_; ExtraInputs extra_inputs_{*this}; }; } // namespace Generators diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 71f051bfc..f99907b59 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -5,7 +5,7 @@ namespace Generators { -InputIDs::InputIDs(State& state) +DefaultInputIDs::DefaultInputIDs(State& state) : state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); shape_ = {state_.params_->BatchBeamSize(), 0}; @@ -41,7 +41,7 @@ InputIDs::InputIDs(State& state) } } -void InputIDs::Add() { +void DefaultInputIDs::Add() { input_index_ = state_.inputs_.size(); state_.inputs_.push_back(value_.get()); @@ -55,7 +55,7 @@ void InputIDs::Add() { } } -void InputIDs::Update(DeviceSpan& new_tokens) { +void DefaultInputIDs::Update(DeviceSpan& new_tokens) { const auto get_unpadded_sequence_length = [](std::span input_ids, int32_t pad_token_id) { int32_t seq_length = 0; @@ -191,4 +191,71 @@ void InputIDs::Update(DeviceSpan& new_tokens) { is_prompt_ = false; } +WindowedInputIDs::WindowedInputIDs(State& state) : state_{state} { + name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); + + if (!model_.config_->model.decoder.sliding_window.has_value()) { + throw std::runtime_error("Sliding a window over input_ids requires sliding_window to be set in the genai_config.json."); + } + + 
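  // Editor's note (illustrative, not part of this change): the sliding_window block that this check
  // expects would live under the decoder section of genai_config.json. A minimal sketch, assuming the
  // usual "model" / "decoder" nesting and with sample values for window_size and pad_value:
  //   "model": { "decoder": { "sliding_window": { "window_size": 64, "pad_value": 0 } } }
  // The field names match what Decoder_Element / SlidingWindow_Element parse in src/config.cpp above;
  // the surrounding nesting and the numeric values shown here are assumptions.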
if (state_.params_->BatchBeamSize() != 1) { + throw std::runtime_error("Batch beam size must be 1 for sliding a window over input_ids."); + } + + window_size_ = model_.config_->model.decoder.sliding_window->window_size; + shape_ = {1, model_.config_->model.decoder.sliding_window->window_size}; + type_ = model_.session_info_->GetInputDataType(name_); + + if (type_ != Ort::TypeToTensorType) { + throw std::runtime_error("WindowedInputIDs only supports int32_t input_ids."); + } +} + +void WindowedInputIDs::Add() { + input_index_ = state_.inputs_.size(); + + state_.inputs_.push_back(value_.get()); + state_.input_names_.push_back(name_); +} + +void WindowedInputIDs::Update(DeviceSpan& new_tokens) { + if (window_index_ == 0) { + num_windows_ = (new_tokens.size() + window_size_ - 1) / window_size_; + + value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); + + // new_tokens will always be padded so that it's size is a multiple of window_size_ + // new_tokens -> [0, a, b, c, d, e] + // window_size = 3, num_windows = 2, pad_token = 0 + // window_index = 0, value_ -> [0, a, b] + std::copy_n(new_tokens.Span().begin(), window_size_, value_->GetTensorMutableData()); + } else if (window_index_ < num_windows_) { + // new_tokens -> [a, b, c, d, e] + // window_size = 3, num_windows = 2 + // window_index = 1, value_ -> [c, d, e] + std::copy_n(new_tokens.Span().begin() + window_index_ * window_size_, window_size_, value_->GetTensorMutableData()); + } else { + // All prompt token chunks have been processed. Now we process the tokens generated by the model. + // new_tokens -> [f] + assert(new_tokens.size() == 1); + if (shape_[1] != 1) { + shape_[1] = 1; + value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); + } + + value_->GetTensorMutableData()[0] = new_tokens.Span()[0]; + } + + state_.inputs_[input_index_] = value_.get(); + window_index_++; +} + +std::unique_ptr CreateInputIDs(State& state) { + if (state.model_.config_->model.decoder.sliding_window.has_value()) { + return std::make_unique(state); + } else { + return std::make_unique(state); + } +} + } // namespace Generators diff --git a/src/models/input_ids.h b/src/models/input_ids.h index dd364212f..d7a229911 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -5,18 +5,25 @@ namespace Generators { struct InputIDs { - InputIDs(State& state); - InputIDs(const InputIDs&) = delete; - InputIDs& operator=(const InputIDs&) = delete; + virtual ~InputIDs() = default; + virtual void Add() = 0; + virtual std::array GetShape() const = 0; + virtual void Update(DeviceSpan& next_tokens) = 0; +}; + +struct DefaultInputIDs : InputIDs { + DefaultInputIDs(State& state); + DefaultInputIDs(const DefaultInputIDs&) = delete; + DefaultInputIDs& operator=(const DefaultInputIDs&) = delete; // Register input_ids as ORT session input. // Called only once during initialization of state. - void Add(); + void Add() override; // Resize input_ids based on size of next_tokens. // Update value with next_tokens. - void Update(DeviceSpan& next_tokens); + void Update(DeviceSpan& next_tokens) override; - auto& GetShape() const { return shape_; } + std::array GetShape() const override { return shape_; } const char* name_; OrtValue* Get() { return value_.get(); } @@ -45,4 +52,36 @@ struct InputIDs { std::unique_ptr past_sequence_length_; }; +// Certain models can only process a fixed number of tokens at a time. 
+// For example, given a prompt with 120 tokens, and a model that can only process 20 tokens at a time, +// this class will split the prompt into 6 windows of 20 tokens each. +// At each update step, the next window of tokens is processed. +// This is done until all windows have been processed before switching to the model-generated tokens +// which are processed one token at a time. +// In contrast, DefaultInputIDs processes all prompt tokens at once. +struct WindowedInputIDs : public InputIDs { + WindowedInputIDs(State& state); + WindowedInputIDs(const WindowedInputIDs&) = delete; + WindowedInputIDs& operator=(const WindowedInputIDs&) = delete; + + void Add() override; + void Update(DeviceSpan& next_tokens) override; + std::array GetShape() const override { return shape_; } + + private: + State& state_; + const Model& model_{state_.model_}; + size_t input_index_{~0U}; + size_t window_size_{}; + size_t num_windows_{}; + size_t window_index_{}; + const char* name_; + std::array shape_{}; + ONNXTensorElementDataType type_; + + std::unique_ptr value_; +}; + +std::unique_ptr CreateInputIDs(State& state); + } // namespace Generators diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 5cb66ade3..0a47355a9 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -1,6 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include "../generators.h" #include "model.h" #include "kv_cache.h" +#include "threadpool.h" namespace Generators { @@ -17,9 +21,17 @@ std::string ComposeKeyValueName(const std::string& template_string, int index) { return std::string(key_value_name); } +int64_t ElementCountFromShape(const std::array& shape) { + return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies()); +} + +bool IsCacheNeeded(const Model& model) { + return model.session_info_->HasInput(ComposeKeyValueName(model.config_->model.decoder.inputs.past_key_names, 0)); +} + } // namespace -KV_Cache_Combined::KV_Cache_Combined(State& state) +CombinedKeyValueCache::CombinedKeyValueCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, shape_{2, state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0, model_.config_->model.decoder.head_size} { @@ -42,7 +54,7 @@ KV_Cache_Combined::KV_Cache_Combined(State& state) } } -void KV_Cache_Combined::Add() { +void CombinedKeyValueCache::Add() { input_index_ = state_.inputs_.size(); output_index_ = state_.outputs_.size(); @@ -54,7 +66,7 @@ void KV_Cache_Combined::Add() { } } -void KV_Cache_Combined::Update(DeviceSpan beam_indices, int total_length) { +void CombinedKeyValueCache::Update(DeviceSpan beam_indices, int total_length) { assert(state_.params_->search.num_beams == 1 || !beam_indices.empty()); // We require beam_indices if we're a beam search if (!is_first_update_) { @@ -77,7 +89,7 @@ void KV_Cache_Combined::Update(DeviceSpan beam_indices, int total_lengt is_first_update_ = false; } -void KV_Cache_Combined::RewindTo(size_t index) { +void CombinedKeyValueCache::RewindTo(size_t index) { if (shape_[3] <= static_cast(index)) { throw std::runtime_error("Requested length of rewind is greater than the current length."); } @@ -96,7 +108,7 @@ void KV_Cache_Combined::RewindTo(size_t index) { } template -void KV_Cache_Combined::RewindPastTensorsTo(size_t index) { +void CombinedKeyValueCache::RewindPastTensorsTo(size_t index) { assert(index > 0 && shape_[3] >= static_cast(index)); std::array new_shape = 
shape_; new_shape[3] = static_cast(index); @@ -131,7 +143,7 @@ void KV_Cache_Combined::RewindPastTensorsTo(size_t index) { // Copy present state to past state reordered by the beam_indices template -void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices_device, int index) { +void CombinedKeyValueCache::PickPastState(DeviceSpan beam_indices_device, int index) { std::span beam_indices = beam_indices_device.CopyDeviceToCpu(); auto block_size_per_beam = shape_[2] * shape_[3] * shape_[4]; auto past_key_size = shape_[1] * block_size_per_beam; @@ -172,7 +184,7 @@ void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices_device, i pasts_[index] = std::move(past); } -void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices, int index) { +void CombinedKeyValueCache::PickPastState(DeviceSpan beam_indices, int index) { if (type_ == Ort::TypeToTensorType) { PickPastState(beam_indices, index); } else { @@ -180,11 +192,7 @@ void KV_Cache_Combined::PickPastState(DeviceSpan beam_indices, int inde } } -bool KV_Cache::IsCacheNeeded(const Model& model) { - return model.session_info_->HasInput(ComposeKeyValueName(model.config_->model.decoder.inputs.past_key_names, 0)); -} - -KV_Cache::KV_Cache(State& state) +DefaultKeyValueCache::DefaultKeyValueCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, past_present_share_buffer_{state_.params_->search.past_present_share_buffer && (state_.params_->search.num_beams == 1 || model_.config_->model.type == "whisper")}, @@ -249,7 +257,7 @@ KV_Cache::KV_Cache(State& state) } } -void KV_Cache::AddEncoder() { +void DefaultKeyValueCache::AddEncoder() { // We don't set the input_index_ & output_index_ because the encoder step only runs once, there's no update for (int i = 0; i < layer_count_ * 2; ++i) { @@ -258,7 +266,7 @@ void KV_Cache::AddEncoder() { } } -void KV_Cache::Add() { +void DefaultKeyValueCache::Add() { input_index_ = state_.inputs_.size(); output_index_ = state_.outputs_.size(); @@ -277,7 +285,7 @@ void KV_Cache::Add() { } } -void KV_Cache::Update(DeviceSpan beam_indices, int total_length) { +void DefaultKeyValueCache::Update(DeviceSpan beam_indices, int total_length) { // If we're sharing past & present buffers there is nothing to do here, so early exit if (past_present_share_buffer_) return; @@ -302,7 +310,7 @@ void KV_Cache::Update(DeviceSpan beam_indices, int total_length) { is_first_update_ = false; } -void KV_Cache::RewindTo(size_t index) { +void DefaultKeyValueCache::RewindTo(size_t index) { if (past_present_share_buffer_) { return; } else if (shape_[2] <= static_cast(index)) { @@ -323,7 +331,7 @@ void KV_Cache::RewindTo(size_t index) { } template -void KV_Cache::RewindPastTensorsTo(size_t index) { +void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) { assert(index > 0 && shape_[2] >= static_cast(index) && !past_present_share_buffer_); std::array new_shape = shape_; new_shape[2] = static_cast(index); @@ -358,7 +366,7 @@ void KV_Cache::RewindPastTensorsTo(size_t index) { // Copy present state to past state reordered by the beam_indices template -void KV_Cache::PickPastState(DeviceSpan beam_indices_device, int index) { +void DefaultKeyValueCache::PickPastState(DeviceSpan beam_indices_device, int index) { std::span beam_indices = beam_indices_device.Span(); auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3]; auto element_count = shape_[0] * block_size_per_beam; @@ -390,7 +398,7 @@ void KV_Cache::PickPastState(DeviceSpan beam_indices_device, int index) pasts_[index] = 
std::move(past_value); } -void KV_Cache::PickPastState(DeviceSpan beam_indices, int index) { +void DefaultKeyValueCache::PickPastState(DeviceSpan beam_indices, int index) { if (type_ == Ort::TypeToTensorType) { PickPastState(beam_indices, index); } else { @@ -398,7 +406,7 @@ void KV_Cache::PickPastState(DeviceSpan beam_indices, int index) { } } -Cross_Cache::Cross_Cache(State& state) +CrossCache::CrossCache(State& state) : state_{state}, layer_count_{model_.config_->model.decoder.num_hidden_layers}, shape_{state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 1500, model_.config_->model.decoder.head_size} { @@ -421,18 +429,259 @@ Cross_Cache::Cross_Cache(State& state) } } -void Cross_Cache::AddOutputs() { +void CrossCache::AddOutputs() { for (int i = 0; i < layer_count_ * 2; ++i) { state_.outputs_.push_back(values_[i].get()); state_.output_names_.push_back(output_name_strings_[i].c_str()); } } -void Cross_Cache::AddInputs() { +void CrossCache::AddInputs() { for (int i = 0; i < layer_count_ * 2; ++i) { state_.inputs_.push_back(values_[i].get()); state_.input_names_.push_back(input_name_strings_[i].c_str()); } } +WindowedKeyValueCache::WindowedKeyValueCache(State& state) + : state_{state}, + layer_count_{model_.config_->model.decoder.num_hidden_layers}, + window_size_{model_.config_->model.decoder.sliding_window->window_size}, + key_cache_shape_in_{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.decoder.head_size, model_.config_->model.context_length - window_size_}, + key_cache_shape_out_{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.decoder.head_size, window_size_}, + value_cache_shape_in_{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.context_length - window_size_, model_.config_->model.decoder.head_size}, + value_cache_shape_out_{model_.config_->model.decoder.num_key_value_heads, 1, + window_size_, model_.config_->model.decoder.head_size} { + if (layer_count_ == 0) { + throw std::runtime_error("Expected there to be at least 1 layer in the model. Actual: " + + std::to_string(layer_count_) + ". Please check the num_hidden_layers attribute in the model configuration."); + } + for (int i = 0; i < layer_count_; ++i) { + input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_key_names, i)); + input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_value_names, i)); + + output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_key_names, i)); + output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_value_names, i)); + } + + type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); + if (type_ != Ort::TypeToTensorType) { + throw std::runtime_error("Expected input data type to be uint8_t for WindowedKeyValueCache. 
Actual: " + + std::to_string(type_)); + } + + for (int i = 0; i < layer_count_; ++i) { + key_caches_in_.push_back( + OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_in_, type_)); + std::fill_n(key_caches_in_[i]->GetTensorMutableData(), + ElementCountFromShape(key_cache_shape_in_), + static_cast(model_.config_->model.decoder.sliding_window->pad_value)); + + value_caches_in_.push_back( + OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_in_, type_)); + std::fill_n(value_caches_in_[i]->GetTensorMutableData(), + ElementCountFromShape(value_cache_shape_in_), + static_cast(model_.config_->model.decoder.sliding_window->pad_value)); + + key_caches_out_.push_back( + OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_out_, type_)); + value_caches_out_.push_back( + OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_out_, type_)); + } +} + +void WindowedKeyValueCache::Add() { + input_index_ = state_.inputs_.size(); + output_index_ = state_.outputs_.size(); + + for (size_t layer_idx = 0; layer_idx < layer_count_; ++layer_idx) { + state_.inputs_.push_back(key_caches_in_[layer_idx].get()); + state_.input_names_.push_back(input_name_strings_[2 * layer_idx].c_str()); + + state_.inputs_.push_back(value_caches_in_[layer_idx].get()); + state_.input_names_.push_back(input_name_strings_[2 * layer_idx + 1].c_str()); + + state_.outputs_.push_back(key_caches_out_[layer_idx].get()); + state_.output_names_.push_back(output_name_strings_[2 * layer_idx].c_str()); + + state_.outputs_.push_back(value_caches_out_[layer_idx].get()); + state_.output_names_.push_back(output_name_strings_[2 * layer_idx + 1].c_str()); + } +} + +void WindowedKeyValueCache::Slide() { + ThreadPool thread_pool{static_cast(layer_count_)}; + thread_pool.Compute([&](size_t layer_idx) { + uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData(); + uint8_t* key_cache_out_data = key_caches_out_[layer_idx]->GetTensorMutableData(); + + int64_t num_key_cache_chunks = key_cache_shape_in_[0] * key_cache_shape_in_[2]; + for (int64_t j = 0; j < num_key_cache_chunks; ++j) { + { + cpu_span key_cache_dst(key_cache_in_data + j * key_cache_shape_in_[3], + key_cache_shape_in_[3] - window_size_); + cpu_span key_cache_src(key_cache_in_data + j * key_cache_shape_in_[3] + window_size_, + key_cache_shape_in_[3] - window_size_); + std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin()); + } + { + cpu_span key_cache_dst(key_cache_in_data + j * key_cache_shape_in_[3] + key_cache_shape_in_[3] - window_size_, + window_size_); + cpu_span key_cache_src(key_cache_out_data + j * key_cache_shape_out_[3], + window_size_); + std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin()); + } + } + + uint8_t* value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData(); + uint8_t* value_cache_out_data = value_caches_out_[layer_idx]->GetTensorMutableData(); + + for (int64_t j = 0; j < value_cache_shape_in_[0]; ++j) { + { + cpu_span value_cache_dst(value_cache_in_data + (j * value_cache_shape_in_[2] * value_cache_shape_in_[3]), + (value_cache_shape_in_[2] - window_size_) * value_cache_shape_in_[3]); + cpu_span value_cache_src(value_cache_in_data + (j * value_cache_shape_in_[2] * value_cache_shape_in_[3]) + + (window_size_ * value_cache_shape_in_[3]), + (value_cache_shape_in_[2] - window_size_) * value_cache_shape_in_[3]); + std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); + } + { + cpu_span 
value_cache_dst(value_cache_in_data + (j * value_cache_shape_in_[2] * value_cache_shape_in_[3]) + + ((value_cache_shape_in_[2] - window_size_) * value_cache_shape_in_[3]), + window_size_ * value_cache_shape_in_[3]); + cpu_span value_cache_src(value_cache_out_data + (j * value_cache_shape_out_[2] * value_cache_shape_out_[3]), + window_size_ * value_cache_shape_out_[3]); + std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); + } + } + }); +} + +void WindowedKeyValueCache::Update(DeviceSpan beam_indices, int current_length) { + if (is_first_update_) { + num_windows_ = (current_length + window_size_ - 1) / window_size_; + is_first_update_ = false; + window_index_++; + return; + } else if (window_size_ == 1 || window_index_ < num_windows_) { + Slide(); + window_index_++; + return; + } + + // Transition from prompt processing to token generation. + // Concatenate the last window_size_ elements to the end of the cache + + // key_caches_in_ = Concat(key_caches_in_[:, :, :, 1:], key_caches_out_) + // [num_key_value_heads, 1, head_size, context_length-1] = [num_key_value_heads, 1, head_size, context_length - window_size_ - 1] + + // [num_key_value_heads, 1, head_size, window_size_] + // value_cache = Concat(value_caches_in_[:, :, 1:, :], value_caches_out_) + // [num_key_value_heads, 1, context_length - 1, head_size] = [num_key_value_heads, 1, context_length - window_size_ - 1, head_size] + + // [num_key_value_heads, 1, window_size_, head_size] + + int updated_window_size = 1; + auto updated_key_cache_shape_in = std::array{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.decoder.head_size, + model_.config_->model.context_length - updated_window_size}; + + auto updated_value_cache_shape_in = std::array{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.context_length - updated_window_size, + model_.config_->model.decoder.head_size}; + + auto updated_key_cache_shape_out = std::array{model_.config_->model.decoder.num_key_value_heads, 1, + model_.config_->model.decoder.head_size, + updated_window_size}; + + auto updated_value_cache_shape_out = std::array{model_.config_->model.decoder.num_key_value_heads, 1, + updated_window_size, + model_.config_->model.decoder.head_size}; + + ThreadPool thread_pool{static_cast(layer_count_)}; + thread_pool.Compute([&](size_t layer_idx) { + std::unique_ptr key_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_in, type_); + + uint8_t* key_cache_data = key_cache->GetTensorMutableData(); + uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData(); + uint8_t* key_cache_out_data = key_caches_out_[layer_idx]->GetTensorMutableData(); + + int64_t num_key_cache_chunks = updated_key_cache_shape_in[0] * updated_key_cache_shape_in[2]; + for (int64_t j = 0; j < num_key_cache_chunks; ++j) { + { + cpu_span key_cache_dst(key_cache_data + j * updated_key_cache_shape_in[3], + updated_key_cache_shape_in[3] - updated_window_size); + cpu_span key_cache_src(key_cache_in_data + j * key_cache_shape_in_[3] + updated_window_size, + key_cache_shape_in_[3] - updated_window_size); + std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin()); + } + { + cpu_span key_cache_dst(key_cache_data + j * updated_key_cache_shape_in[3] + + key_cache_shape_in_[3] - updated_window_size, + window_size_); + cpu_span key_cache_src(key_cache_out_data + j * key_cache_shape_out_[3], + window_size_); + std::copy(key_cache_src.begin(), key_cache_src.end(), 
key_cache_dst.begin()); + } + } + + key_caches_in_[layer_idx] = std::move(key_cache); + key_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_out, type_); + + std::unique_ptr value_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_in, type_); + + uint8_t* value_cache_data = value_cache->GetTensorMutableData(); + uint8_t* value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData(); + uint8_t* value_cache_out_data = value_caches_out_[layer_idx]->GetTensorMutableData(); + + for (int64_t j = 0; j < updated_value_cache_shape_in[0]; ++j) { + { + cpu_span value_cache_dst(value_cache_data + (j * updated_value_cache_shape_in[2] * updated_value_cache_shape_in[3]), + (value_cache_shape_in_[2] - updated_window_size) * updated_value_cache_shape_in[3]); + cpu_span value_cache_src(value_cache_in_data + (j * value_cache_shape_in_[2] * value_cache_shape_in_[3]) + + (updated_window_size * value_cache_shape_in_[3]), + (value_cache_shape_in_[2] - updated_window_size) * value_cache_shape_in_[3]); + std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); + } + { + cpu_span value_cache_dst(value_cache_data + (j * updated_value_cache_shape_in[2] * updated_value_cache_shape_in[3]) + + ((value_cache_shape_in_[2] - updated_window_size) * updated_value_cache_shape_in[3]), + value_cache_shape_out_[2] * value_cache_shape_out_[3]); + cpu_span value_cache_src(value_cache_out_data + (j * value_cache_shape_out_[2] * value_cache_shape_out_[3]), + value_cache_shape_out_[2] * value_cache_shape_out_[3]); + std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); + } + } + + value_caches_in_[layer_idx] = std::move(value_cache); + value_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_); + }); + + window_size_ = 1; + key_cache_shape_in_ = updated_key_cache_shape_in; + value_cache_shape_in_ = updated_value_cache_shape_in; + key_cache_shape_out_ = updated_key_cache_shape_out; + value_cache_shape_out_ = updated_value_cache_shape_out; + + for (size_t layer_idx = 0; layer_idx < layer_count_; ++layer_idx) { + state_.inputs_[input_index_ + 2 * layer_idx] = key_caches_in_[layer_idx].get(); + state_.inputs_[input_index_ + 2 * layer_idx + 1] = value_caches_in_[layer_idx].get(); + state_.outputs_[output_index_ + 2 * layer_idx] = key_caches_out_[layer_idx].get(); + state_.outputs_[output_index_ + 2 * layer_idx + 1] = value_caches_out_[layer_idx].get(); + } +} + +std::unique_ptr CreateKeyValueCache(State& state) { + if (!IsCacheNeeded(state.model_)) { + return nullptr; + } + + if (state.model_.config_->model.decoder.sliding_window) { + return std::make_unique(state); + } else { + return std::make_unique(state); + } +} + } // namespace Generators diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 39cbf2e7c..0e871d938 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -4,18 +4,29 @@ namespace Generators { -struct KV_Cache_Combined { - KV_Cache_Combined(State& state); +struct KeyValueCache { + virtual ~KeyValueCache() = default; + virtual void Add() = 0; + virtual void AddEncoder() = 0; + virtual void Update(DeviceSpan beam_indices, int total_length) = 0; + virtual void RewindTo(size_t index) = 0; +}; + +struct CombinedKeyValueCache : KeyValueCache { + CombinedKeyValueCache(State& state); - void Add(); // Add to state inputs/outputs - void Update(DeviceSpan beam_indices, int total_length); - void 
RewindTo(size_t index); + void Add() override; // Add to state inputs/outputs + void AddEncoder() override { + throw std::runtime_error("CombinedKeyValueCache does not support AddEncoder."); + }; + void Update(DeviceSpan beam_indices, int total_length) override; + void RewindTo(size_t index) override; + private: template void PickPastState(DeviceSpan beam_indices, int index); void PickPastState(DeviceSpan beam_indices, int index); - private: template void RewindPastTensorsTo(size_t index); @@ -34,23 +45,22 @@ struct KV_Cache_Combined { std::vector input_name_strings_, output_name_strings_; }; -struct KV_Cache { - KV_Cache(State& state); +struct DefaultKeyValueCache : KeyValueCache { + DefaultKeyValueCache(State& state); - static bool IsCacheNeeded(const Model& model); - - void AddEncoder(); // If model has an initial encoder step, this is used + void AddEncoder() override; // If model has an initial encoder step, this is used // Register input_ids as ORT session input. // Called only once during initialization of state. - void Add(); + void Add() override; // Move present to past. Prepare present output for next generation iteration. - void Update(DeviceSpan beam_indices, int total_length); - void RewindTo(size_t index); + void Update(DeviceSpan beam_indices, int total_length) override; + void RewindTo(size_t index) override; + + private: template void PickPastState(DeviceSpan beam_indices, int index); void PickPastState(DeviceSpan beam_indices, int index); - private: template void RewindPastTensorsTo(size_t index); @@ -71,9 +81,9 @@ struct KV_Cache { std::vector sb_kv_caches_; }; -// Very similar to the KV_Cache, but is only created once at the encoder step, then used without modification for every decoder step -struct Cross_Cache { - Cross_Cache(State& state); +// Very similar to the DefaultKeyValueCache, but is only created once at the encoder step, then used without modification for every decoder step +struct CrossCache { + CrossCache(State& state); void AddOutputs(); void AddInputs(); @@ -89,4 +99,41 @@ struct Cross_Cache { std::vector> values_; std::vector input_name_strings_, output_name_strings_; }; + +struct WindowedKeyValueCache : KeyValueCache { + WindowedKeyValueCache(State& state); + + void Add() override; + void AddEncoder() override { + throw std::runtime_error("WindowedKeyValueCache does not support AddEncoder."); + }; + void Update(DeviceSpan beam_indices, int current_length) override; + void RewindTo(size_t index) override { + throw std::runtime_error("WindowedKeyValueCache does not support RewindTo."); + } + + private: + void Slide(); + + State& state_; + const Model& model_{state_.model_}; + int layer_count_{}; + int window_size_{}; + size_t num_windows_{}; + size_t window_index_{}; + size_t input_index_{~0U}, output_index_{~0U}; + + std::array key_cache_shape_in_, key_cache_shape_out_; + std::array value_cache_shape_in_, value_cache_shape_out_; + ONNXTensorElementDataType type_; + + std::vector> key_caches_in_, value_caches_in_; + std::vector> key_caches_out_, value_caches_out_; + std::vector input_name_strings_, output_name_strings_; + + bool is_first_update_{true}; +}; + +std::unique_ptr CreateKeyValueCache(State& state); + } // namespace Generators diff --git a/src/models/model.cpp b/src/models/model.cpp index 4430fdb41..47f6dc4bd 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -288,29 +288,42 @@ Model::Model(std::unique_ptr config) : config_{std::move(config)} { Model::~Model() = default; -void Model::InitDeviceAllocator([[maybe_unused]] OrtSession& 
session) { +void Model::InitDeviceAllocator(OrtSession& session) { allocator_device_ = &allocator_cpu_; + allocator_kvcache_ = &allocator_cpu_; #if USE_CUDA if (device_type_ == DeviceType::CUDA) { allocator_device_ = GetCudaAllocator(session); + allocator_kvcache_ = allocator_device_; } #endif + #if USE_DML if (device_type_ == DeviceType::DML) { memory_info_device_ = OrtMemoryInfo::Create("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); - dml_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_); - allocator_device_ = dml_owned_allocator_.get(); + owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); + allocator_device_ = owned_allocator_device_.get(); + allocator_kvcache_ = allocator_device_; } #endif - allocator_kvcache_ = allocator_device_; + #if USE_WEBGPU if (device_type_ == DeviceType::WEBGPU) { // for webgpu we only use device memory for kv_cache memory_info_device_ = OrtMemoryInfo::Create("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); - webgpu_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_); - allocator_kvcache_ = webgpu_owned_allocator_.get(); + owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); + allocator_kvcache_ = owned_allocator_device_.get(); } #endif + + if (device_type_ == DeviceType::QNN) { + memory_info_device_ = OrtMemoryInfo::Create("QnnHtpShared", OrtAllocatorType::OrtDeviceAllocator, 0, + OrtMemType::OrtMemTypeDefault); + owned_allocator_device_ = Ort::Allocator::Create(session, *memory_info_device_); + allocator_device_ = owned_allocator_device_.get(); + allocator_kvcache_ = allocator_device_; + } + session_info_ = std::make_unique(session); captured_graph_pool_ = std::make_shared(config_.get(), session_info_.get(), allocator_device_); } @@ -500,6 +513,15 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ for (auto& option : provider_options.options) { opts.emplace(option.first, option.second); } + + // TODO set device_type_ in a less hacky way. + // now, all QNN EP enable_htp_shared_memory_allocator option values had better be consistent... + // on the other hand, not sure if is_primary_session_options is the right thing to check here. 
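      // Editor's note (illustrative, not part of this change): the option inspected below would
      // typically arrive via a QNN provider_options entry in genai_config.json, along the lines of
      //   "provider_options": [ { "qnn": { "enable_htp_shared_memory_allocator": "1" } } ]
      // (the JSON shape is an assumption). When the value is "1", device_type_ becomes DeviceType::QNN,
      // which in turn selects the "QnnHtpShared" allocator registered in Model::InitDeviceAllocator.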
+ if (const auto opt_it = opts.find("enable_htp_shared_memory_allocator"); + opt_it != opts.end() && opt_it->second == "1") { + device_type_ = DeviceType::QNN; + } + session_options.AppendExecutionProvider("QNN", opts); } else if (provider_options.name == "webgpu") { device_type_ = DeviceType::WEBGPU; @@ -685,6 +707,7 @@ std::unique_ptr Model::ExpandInputs(std::unique_ptr& input, switch (device_type_) { case DeviceType::WEBGPU: case DeviceType::DML: + case DeviceType::QNN: // DML and WebGpu doesn't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs case DeviceType::CPU: for (int i = 0; i < batch_size; i++) { diff --git a/src/models/model.h b/src/models/model.h index bec481f03..9fdd7ac11 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -181,15 +181,11 @@ struct Model : std::enable_shared_from_this, LeakChecked { std::unique_ptr dml_execution_context_; std::unique_ptr dml_readback_heap_; ComPtr dml_device_; - std::unique_ptr dml_owned_allocator_; -#endif -#if USE_WEBGPU - std::unique_ptr webgpu_owned_allocator_; - std::unique_ptr webgpu_io_binding_; -#endif -#if USE_DML || USE_WEBGPU - std::unique_ptr memory_info_device_; #endif + + std::unique_ptr owned_allocator_device_{}; // nullptr if n/a + std::unique_ptr memory_info_device_{}; // nullptr if n/a + std::shared_ptr captured_graph_pool_; std::map> pipeline_session_options_; }; diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index aa94da9d1..3cfa1bdfc 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -42,7 +42,7 @@ struct EmbeddingState : State { const MultiModalVisionModel& model_; int64_t num_image_tokens_; - InputIDs input_ids_{*this}; // Model input + DefaultInputIDs input_ids_{*this}; // Model input ImageFeatures image_features_{*this, ImageFeatures::Mode::Input, // Optional model input model_.config_->model.embedding.inputs.image_features, num_image_tokens_}; @@ -89,9 +89,9 @@ struct DecoderState : State { const CapturedGraphInfo* captured_graph_info_; Embeddings inputs_embeds_{*this, Embeddings::Mode::Input, // Model input model_.config_->model.decoder.inputs.embeddings}; - PositionInputs position_inputs_; // Model input - KV_Cache kv_cache_{*this}; // Model input - Logits logits_{*this}; // Model output + DefaultPositionInputs position_inputs_; // Model input + DefaultKeyValueCache kv_cache_{*this}; // Model input + Logits logits_{*this}; // Model output }; struct MultiModalPipelineState : State { diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index a6180e4ee..fde1ed7a9 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -9,7 +9,7 @@ namespace Generators { -PositionInputs::PositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk) +DefaultPositionInputs::DefaultPositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk) : model_{model}, state_{state} { has_mask_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.attention_mask); @@ -59,7 +59,7 @@ PositionInputs::PositionInputs(const Model& model, State& state, DeviceSpan& next_tokens, int total_length, int new_length) { +void DefaultPositionInputs::Update(DeviceSpan next_tokens, int total_length, int new_length) { if (has_posid_input_) { // Initialize on first update if (is_first_update_) { @@ -96,7 +96,7 @@ void PositionInputs::Update(const DeviceSpan& next_tokens, int total_le is_first_update_ = false; } -void 
PositionInputs::RewindTo(size_t index) { +void DefaultPositionInputs::RewindTo(size_t index) { // Reset the state of the position inputs if (index == 0) { is_first_update_ = true; @@ -108,18 +108,18 @@ void PositionInputs::RewindTo(size_t index) { RewindMask(index); #endif } else - throw std::runtime_error("PositionInputs::RewindTo - Unsupported batch size"); + throw std::runtime_error("DefaultPositionInputs::RewindTo - Unsupported batch size"); } } -void PositionInputs::AddAttentionMask() { +void DefaultPositionInputs::AddAttentionMask() { mask_input_index_ = state_.inputs_.size(); state_.inputs_.push_back(attention_mask_.get()); state_.input_names_.push_back(model_.config_->model.decoder.inputs.attention_mask.c_str()); } -void PositionInputs::AddPositionIDs() { +void DefaultPositionInputs::AddPositionIDs() { posid_input_index_ = state_.inputs_.size(); state_.inputs_.push_back(position_ids_.get()); @@ -127,7 +127,7 @@ void PositionInputs::AddPositionIDs() { } #if USE_CUDA || USE_DML -void PositionInputs::CopyNextPositionIDsToCurrent() { +void DefaultPositionInputs::CopyNextPositionIDsToCurrent() { #if USE_CUDA assert(model_.device_type_ == DeviceType::CUDA); cudaMemcpyAsync(position_ids_->GetTensorMutableRawData(), @@ -149,7 +149,7 @@ void PositionInputs::CopyNextPositionIDsToCurrent() { } #endif -void PositionInputs::CreateNextPositionIDsTensor() { +void DefaultPositionInputs::CreateNextPositionIDsTensor() { if (!sb_position_ids_) { if (position_ids_shape_[1] == 1 && position_ids_next_) { position_ids_ = std::move(position_ids_next_); @@ -167,11 +167,9 @@ void PositionInputs::CreateNextPositionIDsTensor() { } } -void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { +void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); - if (DeviceType::DML == model_.device_type_ && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - DML does not support continuous decoding."); + throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); // Reallocate position_ids when new_kv_length changes if (position_ids_shape_[1] != new_kv_length) { @@ -206,7 +204,7 @@ void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { } } -void PositionInputs::CreateNextAttentionMaskTensor(int total_length) { +void DefaultPositionInputs::CreateNextAttentionMaskTensor(int total_length) { if (!sb_attention_mask_) { attention_mask_shape_[1] = total_length; attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); @@ -232,18 +230,19 @@ void PositionInputs::CreateNextAttentionMaskTensor(int total_length) { } } -void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { +void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); + throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); if (DeviceType::DML == model_.device_type_ && !(total_length == 0 || new_kv_length == 1)) - throw 
std::runtime_error("PositionInputs::UpdatePositionIDs - DML does not support continuous decoding."); + throw std::runtime_error("DefaultPositionInputs::UpdatePositionIDs - DML does not support continuous decoding."); CreateNextAttentionMaskTensor(total_length); state_.inputs_[mask_input_index_] = attention_mask_.get(); switch (model_.device_type_) { case DeviceType::WEBGPU: - case DeviceType::CPU: { + case DeviceType::CPU: + case DeviceType::QNN: { type_ == Ort::TypeToTensorType ? UpdateAttentionMaskImpl(total_length) : UpdateAttentionMaskImpl(total_length); break; @@ -280,7 +279,7 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { } #endif default: - throw std::runtime_error("PositionInputs::Update - Unsupported device type"); + throw std::runtime_error("DefaultPositionInputs::Update - Unsupported device type"); } #if USE_DML if (model_.device_type_ != DeviceType::DML) { @@ -294,7 +293,7 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { } template -void PositionInputs::CreateAndInitializePositionIDs(const DeviceSpan& next_tokens, std::array shape) { +void DefaultPositionInputs::CreateAndInitializePositionIDs(DeviceSpan next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); @@ -324,7 +323,7 @@ void PositionInputs::CreateAndInitializePositionIDs(const DeviceSpan& n } template -void PositionInputs::CreateAndInitializeAttentionMask(const DeviceSpan& next_tokens, std::array shape) { +void DefaultPositionInputs::CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. 
// Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); @@ -348,14 +347,14 @@ void PositionInputs::CreateAndInitializeAttentionMask(const DeviceSpan& } template -void PositionInputs::InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk) { +void DefaultPositionInputs::InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk) { for (int i = 0; i < shape[0] * state_.params_->search.num_beams; i++) { sequence_lengths_unk[i] = 0; } } template -void PositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_length) { +void DefaultPositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_length) { auto* data = position_ids_->GetTensorMutableData(); if (position_ids_shape_[0] == 1) { // For batch size == 1 we calculate position ids with total length and new kv length for continuous decoding @@ -369,7 +368,7 @@ void PositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_length) } #if USE_DML -void PositionInputs::UpdatePositionIDsImplDML() { +void DefaultPositionInputs::UpdatePositionIDsImplDML() { ComPtr target_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource)); @@ -388,7 +387,7 @@ void PositionInputs::UpdatePositionIDsImplDML() { #endif template -void PositionInputs::UpdateAttentionMaskImpl(int total_length) { +void DefaultPositionInputs::UpdateAttentionMaskImpl(int total_length) { auto* data = attention_mask_next_->GetTensorMutableData(); auto* old_data = attention_mask_->GetTensorData(); if (attention_mask_shape_[0] == 1) { @@ -407,7 +406,7 @@ void PositionInputs::UpdateAttentionMaskImpl(int total_length) { } #if USE_DML -void PositionInputs::UpdateAttentionMaskImplDML(int total_length) { +void DefaultPositionInputs::UpdateAttentionMaskImplDML(int total_length) { ComPtr attention_mask_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource)); ComPtr attention_mask_next_resource; @@ -442,7 +441,7 @@ void PositionInputs::UpdateAttentionMaskImplDML(int total_length) { #endif #if USE_CUDA -void PositionInputs::RewindMask(size_t index) { +void DefaultPositionInputs::RewindMask(size_t index) { if (sb_attention_mask_ && !is_first_mask_update_) { int past_length = static_cast(index); int max_length = static_cast(state_.params_->search.max_length); @@ -458,4 +457,149 @@ void PositionInputs::RewindMask(size_t index) { } #endif +WindowedPositionInputs::WindowedPositionInputs(State& state) + : state_{state} { + has_posid_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.position_ids); + has_mask_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.attention_mask); + + if (has_posid_input_ || has_mask_input_) { + if (!model_.config_->model.decoder.sliding_window.has_value()) { + throw std::runtime_error("Sliding a window over position_ids and attention_mask requires sliding_window to be set in the genai_config.json."); + } + window_size_ = model_.config_->model.decoder.sliding_window->window_size; + } + + if (has_posid_input_) { + position_ids_type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.position_ids); + if (position_ids_type_ != Ort::TypeToTensorType) + throw 
std::runtime_error("WindowedPositionInputs only supports int32_t position_ids"); + + position_ids_shape_ = {1, model_.config_->model.decoder.sliding_window->window_size}; + } + + if (has_mask_input_) { + attention_mask_type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.attention_mask); + if (attention_mask_type_ != Ort::TypeToTensorType) + throw std::runtime_error("WindowedPositionInputs only supports int32_t attention_mask"); + + attention_mask_shape_ = {1, model_.config_->model.context_length}; + } +} + +void WindowedPositionInputs::Add() { + if (has_posid_input_) { + position_ids_index_ = state_.inputs_.size(); + state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str()); + state_.inputs_.push_back(position_ids_.get()); + } + + if (has_mask_input_) { + attention_mask_index_ = state_.inputs_.size(); + state_.input_names_.push_back(model_.config_->model.decoder.inputs.attention_mask.c_str()); + state_.inputs_.push_back(attention_mask_.get()); + } +} + +void WindowedPositionInputs::Update(DeviceSpan next_tokens, int total_length, int new_length) { + if (window_index_ == 0) { + num_windows_ = (next_tokens.size() + window_size_ - 1) / window_size_; + if (has_posid_input_) { + position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, position_ids_shape_, position_ids_type_); + + // next_tokens will always be padded so that it's size is a multiple of window_size_ + // next_tokens -> [0, a, b, c, d, e] + // window_size = 3, num_windows = 2, pad_token = 0 + // window_index = 0, position_ids_ -> [0, 0, 1] + auto* position_ids_data = position_ids_->GetTensorMutableData(); + for (int i = 0, j = 0; i < position_ids_shape_[1]; i++) { + if (next_tokens.Span()[i] == model_.config_->model.pad_token_id) { + position_ids_data[i] = 0; + } else { + position_ids_data[i] = j++; + } + } + } + + if (has_mask_input_) { + attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, attention_mask_shape_, attention_mask_type_); + + // next_tokens will always be padded so that it's size is a multiple of window_size_ + // next_tokens -> [0, a, b, c, d, e] + // window_size = 3, num_windows = 2, pad_token = 0 + // window_index = 0, attention_mask_ -> ([0] * context_length - window_size_) + [0, 1, 1] + auto* attention_mask_data = attention_mask_->GetTensorMutableData(); + std::fill_n(attention_mask_data, attention_mask_shape_[1] - window_size_, 0); + for (size_t i = 0; i < window_size_; i++) { + attention_mask_data[attention_mask_shape_[1] - window_size_ + i] = next_tokens.CpuSpan()[i] == model_.config_->model.pad_token_id ? 
+      for (size_t i = 0; i < window_size_; i++) {
+        if (attention_mask_data[attention_mask_shape_[1] - window_size_ + i] == 1) {
+          attention_mask_backward_offset_ = attention_mask_shape_[1] - window_size_ + i - 1;
+          break;
+        }
+      }
+    }
+  } else if (window_index_ < num_windows_) {
+    if (has_posid_input_) {
+      // next_tokens will always be padded so that its size is a multiple of window_size_
+      // next_tokens -> [0, a, b, c, d, e]
+      // window_size = 3, num_windows = 2, pad_token = 0
+      // window_index = 1, position_ids_ -> [2, 3, 4]
+      auto* position_ids_data = position_ids_->GetTensorMutableData<int32_t>();
+      const auto last_position = position_ids_data[window_size_ - 1];
+      std::iota(position_ids_data, position_ids_data + window_size_, last_position + 1);
+    }
+
+    if (has_mask_input_) {
+      // next_tokens will always be padded so that its size is a multiple of window_size_
+      // next_tokens -> [0, a, b, c, d, e]
+      // window_size = 3, num_windows = 2, pad_token = 0
+      // window_index = 1, attention_mask_ -> ([0] * (context_length - 2 * window_size_)) + [0, 1, 1, 1, 1, 1]
+      auto* attention_mask_data = attention_mask_->GetTensorMutableData<int32_t>();
+      std::fill_n(attention_mask_data + attention_mask_backward_offset_ - window_size_ + 1, window_size_, 1);
+      attention_mask_backward_offset_ -= window_size_;
+    }
+  } else {
+    // All prompt token chunks have been processed. Now we process the tokens generated by the model.
+    if (has_posid_input_) {
+      // next_tokens -> [f]
+      // position_ids_ -> [5]
+      const auto last_position = position_ids_->GetTensorData<int32_t>()[position_ids_shape_[1] - 1];
+      if (position_ids_shape_[1] != 1) {
+        position_ids_shape_[1] = 1;
+        position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, position_ids_shape_, position_ids_type_);
+      }
+      position_ids_->GetTensorMutableData<int32_t>()[0] = last_position + 1;
+    }
+
+    if (has_mask_input_) {
+      // next_tokens -> [f]
+      // attention_mask_ -> ([0] * (context_length - 2 * window_size_ - 1)) + [0, 1, 1, 1, 1, 1, 1]
+      attention_mask_->GetTensorMutableData<int32_t>()[attention_mask_backward_offset_] = 1;
+      if (attention_mask_backward_offset_ > 0) {
+        attention_mask_backward_offset_ -= 1;
+      }
+    }
+  }
+
+  if (has_posid_input_) {
+    state_.inputs_[position_ids_index_] = position_ids_.get();
+  }
+
+  if (has_mask_input_) {
+    state_.inputs_[attention_mask_index_] = attention_mask_.get();
+  }
+
+  window_index_++;
+}
+
+std::unique_ptr<PositionInputs> CreatePositionInputs(State& state, DeviceSpan<int32_t> sequence_lengths) {
+  if (state.model_.config_->model.decoder.sliding_window.has_value()) {
+    return std::make_unique<WindowedPositionInputs>(state);
+  } else {
+    return std::make_unique<DefaultPositionInputs>(state.model_, state, sequence_lengths);
+  }
+}
+
 }  // namespace Generators
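Reviewer note: below is a minimal standalone sketch, not part of the patch, of the per-window position-id arithmetic implemented in WindowedPositionInputs::Update above. The window size, pad id, and token values are invented; the prompt is assumed to already be padded to a multiple of the window size, as the comments in Update state.

// Standalone illustration only -- not repository code.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t pad_token_id = 0;  // assumed pad value
  const size_t window_size = 3;    // assumed sliding_window.window_size
  // Prompt already padded at the front to a multiple of window_size.
  const std::vector<int32_t> padded_tokens = {pad_token_id, 11, 12, 13, 14, 15};

  const size_t num_windows = (padded_tokens.size() + window_size - 1) / window_size;
  int32_t next_position = 0;
  for (size_t w = 0; w < num_windows; ++w) {
    std::printf("window %zu:", w);
    for (size_t i = 0; i < window_size; ++i) {
      const int32_t token = padded_tokens[w * window_size + i];
      // Pad tokens get position 0; real tokens continue the running count,
      // giving [0, 0, 1] for window 0 and [2, 3, 4] for window 1 here.
      const int32_t position = (token == pad_token_id) ? 0 : next_position++;
      std::printf(" %d", position);
    }
    std::printf("\n");
  }
  return 0;
}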
diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h
index 64a38779c..4365e0ee4 100644
--- a/src/models/position_inputs.h
+++ b/src/models/position_inputs.h
@@ -10,12 +10,19 @@ namespace Generators {
 
 struct PositionInputs {
-  PositionInputs(const Model& model, State& state, DeviceSpan<int32_t> sequence_lengths_unk);
+  virtual ~PositionInputs() = default;
+  virtual void Add() = 0;
+  virtual void Update(DeviceSpan<int32_t> next_tokens, int total_length, int new_length) = 0;
+  virtual void RewindTo(size_t index) = 0;
+};
+
+struct DefaultPositionInputs : PositionInputs {
+  DefaultPositionInputs(const Model& model, State& state, DeviceSpan<int32_t> sequence_lengths_unk);
 
-  void Add();
-  void Update(const DeviceSpan<int32_t>& next_tokens, int total_length, int new_length);
+  void Add() override;
+  void Update(DeviceSpan<int32_t> next_tokens, int total_length, int new_length) override;
 
-  void RewindTo(size_t index);
+  void RewindTo(size_t index) override;
 
  private:
   void AddAttentionMask();
@@ -30,9 +37,9 @@ struct PositionInputs {
   template <typename T>
   void InitializeSequenceLengths(std::array<int64_t, 2> shape, cpu_span<int32_t> sequence_lengths_unk);
   template <typename T>
-  void CreateAndInitializePositionIDs(const DeviceSpan<int32_t>& next_tokens, std::array<int64_t, 2> shape);
+  void CreateAndInitializePositionIDs(DeviceSpan<int32_t> next_tokens, std::array<int64_t, 2> shape);
   template <typename T>
-  void CreateAndInitializeAttentionMask(const DeviceSpan<int32_t>& next_tokens, std::array<int64_t, 2> shape);
+  void CreateAndInitializeAttentionMask(DeviceSpan<int32_t> next_tokens, std::array<int64_t, 2> shape);
 
   template <typename T>
   void UpdatePositionIDsImpl(int total_length, int new_kv_length);
@@ -86,4 +93,49 @@ struct PositionInputs {
 #endif
 };
 
+// Certain models can only process a fixed number of tokens at a time.
+// For example, given a prompt with 120 tokens and a model that can only process 20 tokens at a time,
+// this class will split the position ids into 6 windows of 20 tokens each.
+// At each update step, the next window of position ids is prepared.
+// This continues until all windows have been processed, after which the class switches to the token-generation
+// phase, where position ids are prepared one id at a time.
+// This class also prepares the attention mask for each iteration. The attention mask buffer is allocated just
+// once and reused at each iteration by setting the mask to 1 for the current window's tokens and for previously
+// active window tokens. In contrast, DefaultPositionInputs processes all position ids at once.
+struct WindowedPositionInputs : PositionInputs {
+  WindowedPositionInputs(State& state);
+  WindowedPositionInputs(const WindowedPositionInputs&) = delete;
+  WindowedPositionInputs& operator=(const WindowedPositionInputs&) = delete;
+
+  void Add() override;
+  void Update(DeviceSpan<int32_t> next_tokens, int total_length, int new_length) override;
+  void RewindTo(size_t index) override {
+    throw std::runtime_error("WindowedPositionInputs does not support RewindTo.");
+  };
+
+ private:
+  State& state_;
+  const Model& model_{state_.model_};
+
+  bool has_mask_input_{};
+  bool has_posid_input_{};
+
+  std::array<int64_t, 2> position_ids_shape_{};
+  ONNXTensorElementDataType position_ids_type_{};
+  std::unique_ptr<OrtValue> position_ids_;
+  std::array<int64_t, 2> attention_mask_shape_{};
+  ONNXTensorElementDataType attention_mask_type_{};
+  std::unique_ptr<OrtValue> attention_mask_;
+  size_t attention_mask_backward_offset_{~0U};
+
+  size_t attention_mask_index_{~0U};
+  size_t position_ids_index_{~0U};
+
+  size_t window_size_{};
+  size_t num_windows_{};
+  size_t window_index_{};
+};
+
+std::unique_ptr<PositionInputs> CreatePositionInputs(State& state, DeviceSpan<int32_t> sequence_lengths);
+
 }  // namespace Generators
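Reviewer note: the class comment above describes reusing a single attention-mask buffer across windows. The standalone sketch below, not part of the patch, walks through that right-to-left fill with invented context_length and window_size values, mirroring the [0, a, b, c, d, e] example in position_inputs.cpp.

// Standalone illustration only -- not repository code.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

static void print(const std::vector<int32_t>& mask) {
  for (int32_t v : mask) std::printf("%d ", v);
  std::printf("\n");
}

int main() {
  const size_t context_length = 10;  // assumed model.context_length
  const size_t window_size = 3;      // assumed sliding_window.window_size
  std::vector<int32_t> mask(context_length, 0);

  // Window 0: tokens [pad, a, b] occupy the rightmost window_size slots -> [0, 1, 1].
  mask[context_length - 2] = 1;
  mask[context_length - 1] = 1;
  // The backward offset ends up one slot to the left of the first active (1) entry.
  size_t backward_offset = context_length - window_size;  // index 7 here
  print(mask);  // 0 0 0 0 0 0 0 0 1 1

  // Window 1: the next window_size slots, ending at backward_offset, become active.
  std::fill_n(mask.begin() + (backward_offset - window_size + 1), window_size, 1);
  backward_offset -= window_size;
  print(mask);  // 0 0 0 0 0 1 1 1 1 1

  // Token-generation phase: one additional slot becomes active per Update call.
  mask[backward_offset] = 1;
  if (backward_offset > 0) --backward_offset;
  print(mask);  // 0 0 0 0 1 1 1 1 1 1
  return 0;
}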
diff --git a/src/models/threadpool.cpp b/src/models/threadpool.cpp
new file mode 100644
index 000000000..1ac56f08a
--- /dev/null
+++ b/src/models/threadpool.cpp
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "threadpool.h"
+
+namespace Generators {
+
+ThreadPool::ThreadPool(size_t num_threads) : num_threads_{num_threads} {}
+
+void ThreadPool::Compute(const std::function<void(size_t)>& func) {
+  for (size_t i = 0; i < num_threads_; ++i) {
+    threads_.emplace_back([&, i] { func(i); });
+  }
+
+  for (auto& thread : threads_) {
+    thread.join();
+  }
+
+  threads_.clear();
+}
+
+}  // namespace Generators
\ No newline at end of file
diff --git a/src/models/threadpool.h b/src/models/threadpool.h
new file mode 100644
index 000000000..81fe98ae7
--- /dev/null
+++ b/src/models/threadpool.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <functional>
+#include <thread>
+#include <vector>
+
+namespace Generators {
+
+struct ThreadPool {
+  ThreadPool(size_t num_threads);
+
+  void Compute(const std::function<void(size_t)>& func);
+
+ private:
+  size_t num_threads_;
+  std::vector<std::thread> threads_;
+};
+
+}  // namespace Generators
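Reviewer note: the new ThreadPool spawns num_threads threads on every Compute call, passes each thread its index, and joins them all before returning (it is not a persistent pool). A hedged usage sketch follows; it is not from the repository, and the include path and work split are assumptions.

// Standalone illustration only -- not repository code.
#include <cstddef>
#include <vector>

#include "threadpool.h"  // assumes src/models is on the include path

int main() {
  const size_t num_threads = 4;
  std::vector<int> results(num_threads, 0);

  Generators::ThreadPool pool{num_threads};
  // Compute blocks until every spawned thread has joined, so writing to
  // disjoint elements of `results` from the callback is safe.
  pool.Compute([&](size_t thread_id) {
    results[thread_id] = static_cast<int>(thread_id) * 2;
  });
  return 0;
}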
diff --git a/src/models/whisper.h b/src/models/whisper.h
index ab7e508d6..34eecd0ff 100644
--- a/src/models/whisper.h
+++ b/src/models/whisper.h
@@ -34,10 +34,10 @@ struct Whisper_State : State {
     Decoder,
   } run_state_{RunState::Encoder_Decoder_Init};
 
-  InputIDs decoder_input_ids_{*this};
+  DefaultInputIDs decoder_input_ids_{*this};
   Logits logits_{*this};
-  KV_Cache kv_cache_{*this};
-  Cross_Cache cross_cache_{*this};
+  DefaultKeyValueCache kv_cache_{*this};
+  CrossCache cross_cache_{*this};
   std::unique_ptr<OrtValue> encoder_input_ids_;
   std::unique_ptr<OrtValue> encoder_hidden_states_;
diff --git a/src/ort_genai.h b/src/ort_genai.h
index b9e2ab9d6..8b4a026d5 100644
--- a/src/ort_genai.h
+++ b/src/ort_genai.h
@@ -163,7 +163,7 @@ struct OgaSequences : OgaAbstract {
   std::span<const int32_t> Get(size_t index) const {
     return {SequenceData(index), SequenceCount(index)};
   }
-  void Append(const std::span<const int32_t>& sequence) {
+  void Append(std::span<const int32_t> sequence) {
     OgaCheckResult(OgaAppendTokenSequence(sequence.data(), sequence.size(), this));
   }
   void Append(const std::vector<int32_t>& sequence) {
@@ -280,7 +280,7 @@ struct OgaGenerator : OgaAbstract {
     OgaCheckResult(OgaGenerator_AppendTokenSequences(this, &sequences));
   }
 
-  void AppendTokens(int32_t* input_ids, size_t input_ids_count) {
+  void AppendTokens(const int32_t* input_ids, size_t input_ids_count) {
     OgaCheckResult(OgaGenerator_AppendTokens(this, input_ids, input_ids_count));
   }
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index af2ee4b11..3f2c9d750 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -297,10 +297,10 @@ OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerator* oga_gene
   OGA_CATCH
 }
 
-OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count) {
+OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count) {
   OGA_TRY
   auto& generator = *reinterpret_cast<Generators::Generator*>(oga_generator);
-  generator.AppendTokens(Generators::cpu_span<int32_t>(input_ids, input_ids_count));
+  generator.AppendTokens(Generators::cpu_span<const int32_t>(input_ids, input_ids_count));
   return nullptr;
   OGA_CATCH
 }
diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h
index 8e68d09ab..b970ecbaf 100644
--- a/src/ort_genai_c.h
+++ b/src/ort_genai_c.h
@@ -325,7 +325,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerato
 * \param[in] input_ids_count The number of input ids to add (batch_size * sequence_length).
 * \return OgaResult containing the error message if the setting of the input ids failed.
 */
-OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count);
+OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count);
 
 /**
 * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator.
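Reviewer note: a hedged caller-side sketch, not from the repository, showing what the const-qualified OgaGenerator_AppendTokens signature enables. The helper function is hypothetical, and the generator is assumed to have been created elsewhere with the usual C API calls.

// Standalone illustration only -- not repository code.
#include <cstdint>
#include <vector>

#include "ort_genai_c.h"

// `generator` is assumed to be an already-created OgaGenerator*.
void AppendPrompt(OgaGenerator* generator, const std::vector<int32_t>& prompt_tokens) {
  // With the old int32_t* parameter, passing data() from a const vector required a
  // const_cast or a copy; with const int32_t* it compiles as-is.
  OgaResult* result = OgaGenerator_AppendTokens(generator, prompt_tokens.data(), prompt_tokens.size());
  if (result != nullptr) {
    // Inspect/report the error as needed, then release it.
    OgaDestroyResult(result);
  }
}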