Support for Qualcomm AI Hub sliding window models #1138

Merged (23 commits, Dec 18, 2024)
Changes from 12 commits
Commits
3c9d563  Updates for decoder-pipeline to work with other split models (baijumeswani, Oct 7, 2024)
4212d19  Sync changes with main (baijumeswani, Nov 11, 2024)
5e4ab3d  Allow adjustments to the sliding window kv cache (baijumeswani, Nov 13, 2024)
0a0f98d  enable setting default ORT logging level to verbose with ORTGENAI_ORT… (edgchen1, Nov 6, 2024)
7a331ef  hack to run with qnn shared memory allocator (edgchen1, Nov 15, 2024)
aca6622  Make kv cache updates parallel (baijumeswani, Nov 14, 2024)
ad737df  Support num tokens > 128 (baijumeswani, Dec 2, 2024)
74c4746  Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai… (baijumeswani, Dec 9, 2024)
5ef9658  Documentation and create kv cache interface (baijumeswani, Dec 10, 2024)
a68827d  Always assign allocator_kv_cache_ (baijumeswani, Dec 10, 2024)
2d2c3fb  Avoid using front() (baijumeswani, Dec 10, 2024)
3acbfc0  link against pthreads (baijumeswani, Dec 10, 2024)
c5ee9c0  Address pull-request review comments (baijumeswani, Dec 11, 2024)
0ccc668  Address pull-request review comments (baijumeswani, Dec 12, 2024)
11dbed2  Throw meaningful exception when user tries continuous decoding (baijumeswani, Dec 16, 2024)
0130a51  Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai… (baijumeswani, Dec 16, 2024)
43af9aa  Address pull request review comments (baijumeswani, Dec 17, 2024)
49f4345  Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai… (baijumeswani, Dec 17, 2024)
d301fc5  Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai… (baijumeswani, Dec 17, 2024)
82de85a  Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai… (baijumeswani, Dec 18, 2024)
2a00a26  Rename InputIDsDefault to DefaultInputIDs (baijumeswani, Dec 18, 2024)
ae436ef  More merge conflicts (baijumeswani, Dec 18, 2024)
5a8fae9  Rename structs (baijumeswani, Dec 18, 2024)
4 changes: 4 additions & 0 deletions cmake/cxx_standard.cmake
@@ -12,3 +12,7 @@ else ()
message("Test is using C++20")
set(CMAKE_CXX_STANDARD 20)
endif ()

if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
endif ()
21 changes: 21 additions & 0 deletions src/config.cpp
@@ -286,6 +286,22 @@ struct Pipeline_Element : JSON::Element {
PipelineModelObject_Element object_{v_};
};

struct SlidingWindow_Element : JSON::Element {
explicit SlidingWindow_Element(std::optional<Config::Model::Decoder::SlidingWindow>& v) : v_{v} {}

void OnNumber(std::string_view name, double value) override {
if (name == "window_size") {
v_->window_size = static_cast<int>(value);
} else if (name == "pad_value") {
v_->pad_value = static_cast<int>(value);
} else
throw JSON::unknown_value_error{};
}

private:
std::optional<Config::Model::Decoder::SlidingWindow>& v_;
};

struct Decoder_Element : JSON::Element {
explicit Decoder_Element(Config::Model::Decoder& v) : v_{v} {}

@@ -321,6 +337,10 @@ struct Decoder_Element : JSON::Element {
if (name == "outputs") {
return outputs_;
}
if (name == "sliding_window") {
v_.sliding_window = Config::Model::Decoder::SlidingWindow{};
return sliding_window_;
}
throw JSON::unknown_value_error{};
}

@@ -336,6 +356,7 @@ struct Decoder_Element : JSON::Element {
Inputs_Element inputs_{v_.inputs};
Outputs_Element outputs_{v_.outputs};
Pipeline_Element pipeline_{v_.pipeline};
SlidingWindow_Element sliding_window_{v_.sliding_window};
};

struct VisionInputs_Element : JSON::Element {
6 changes: 6 additions & 0 deletions src/config.h
@@ -107,6 +107,12 @@ struct Config {
int num_hidden_layers{};
int head_size{};

struct SlidingWindow {
int window_size{128};
int pad_value{};
};
std::optional<SlidingWindow> sliding_window;

struct Inputs {
std::string input_ids{Defaults::InputIdsName};
std::string embeddings{"inputs_embeds"};
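
For reference, a hypothetical genai_config.json fragment that exercises the parser above would nest a sliding_window object under the decoder section; window_size and pad_value are the only keys read, and the values shown are illustrative rather than taken from a real model:

  "model": {
    "decoder": {
      "sliding_window": {
        "window_size": 128,
        "pad_value": 0
      }
    }
  }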
27 changes: 22 additions & 5 deletions src/generators.cpp
@@ -3,6 +3,7 @@

#include "generators.h"
#include "sequences.h"
#include "models/env_utils.h"
#include "models/model.h"
#include "models/decoder_only.h"
#include "search.h"
@@ -45,8 +46,14 @@ void OnCudaError(cudaError_t error) { assert(false); }

static bool _ = (Ort::InitApi(), false);

static OrtLoggingLevel GetDefaultOrtLoggingLevel() {
bool ort_verbose_logging = false;
GetEnvironmentVariable("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging);
return ort_verbose_logging ? OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE : OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
}

OrtGlobals::OrtGlobals()
- : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {
: env_{OrtEnv::Create(GetDefaultOrtLoggingLevel())} {
auto arena_config = OrtArenaCfg::Create(0, -1, -1, -1);
Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()};
env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config);
@@ -170,6 +177,8 @@ std::string to_string(DeviceType device_type) {
return "DirectML";
case DeviceType::WEBGPU:
return "WebGpu";
case DeviceType::QNN_WITH_SHARED_MEMORY:
return "QnnWithSharedMemory";
}
throw std::runtime_error("Unknown device type");
}
@@ -266,15 +275,23 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
}
}

- DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(const cpu_span<int32_t> input_ids) {
- auto input_ids_device = state_->params_->p_device->Allocate<int32_t>(input_ids.size());
DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(cpu_span<const int32_t> input_ids) {
size_t input_ids_size = input_ids.size();
if (model_->config_->model.decoder.sliding_window.has_value()) {
// If the model has a sliding window, pad the input_ids to the next multiple of the window size
// so that the input_ids can be divided into window size chunks.
const auto window_size = model_->config_->model.decoder.sliding_window->window_size;
input_ids_size = ((input_ids.size() + window_size - 1) / window_size) * window_size;
}
auto input_ids_device = state_->params_->p_device->Allocate<int32_t>(input_ids_size);
auto cpu_span = input_ids_device.CpuSpan();
- std::copy(input_ids.begin(), input_ids.end(), cpu_span.begin());
std::fill_n(cpu_span.begin(), input_ids_size, model_->config_->model.pad_token_id);
std::copy_backward(input_ids.begin(), input_ids.end(), cpu_span.end());
input_ids_device.CopyCpuToDevice();
return input_ids_device;
}

- void Generator::AppendTokens(const cpu_span<int32_t> input_ids) {
void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
ThrowErrorIfSessionTerminated(state_->session_terminated_);
if (input_ids.size() == 0)
throw std::runtime_error("input_ids is empty");
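
As a standalone illustration of the padding step above (a minimal sketch; PadToWindowMultiple is a hypothetical helper, not an API in this repository): the prompt length is rounded up to the next multiple of window_size, the buffer is filled with pad_token_id, and the real tokens are copied to the back so the padding sits at the front.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the padding performed in AllocateInputIdsOnDevice.
std::vector<int32_t> PadToWindowMultiple(const std::vector<int32_t>& input_ids,
                                         size_t window_size, int32_t pad_token_id) {
  // Round the token count up to the next multiple of window_size.
  const size_t padded_size = ((input_ids.size() + window_size - 1) / window_size) * window_size;
  std::vector<int32_t> padded(padded_size, pad_token_id);                 // pad tokens everywhere
  std::copy_backward(input_ids.begin(), input_ids.end(), padded.end());   // real tokens at the back
  return padded;
}

// Example: 200 prompt tokens with window_size 128 yield 256 values: 56 pad tokens, then the prompt.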
5 changes: 3 additions & 2 deletions src/generators.h
@@ -60,6 +60,7 @@ enum struct DeviceType {
CUDA,
DML,
WEBGPU,
QNN_WITH_SHARED_MEMORY,
};

std::string to_string(DeviceType device_type);
@@ -111,7 +112,7 @@ struct Generator : LeakChecked<Generator> {
Generator(const Model& model, const GeneratorParams& params);

bool IsDone() const;
- void AppendTokens(const cpu_span<int32_t> input_ids);
void AppendTokens(cpu_span<const int32_t> input_ids);
void GenerateNextToken();
void RewindToLength(size_t new_length); // Rewind state to new_length
DeviceSpan<float> GetLogits();
@@ -127,7 +128,7 @@
bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio

private:
- DeviceSpan<int32_t> AllocateInputIdsOnDevice(const cpu_span<int32_t> input_ids);
DeviceSpan<int32_t> AllocateInputIdsOnDevice(cpu_span<const int32_t> input_ids);
void ComputeLogits(DeviceSpan<int32_t> next_tokens);
enum Action { standard, // Default, set in any other case
generated, // Set after GenerateNextToken
16 changes: 12 additions & 4 deletions src/models/debugging.cpp
@@ -43,11 +43,19 @@ void DumpSpan(std::ostream& stream, std::span<const T> values) {
for (auto v : values)
stream << v << ' ';
} else {
- for (size_t i = 0; i < c_value_count / 2; i++)
- stream << values[i] << ' ';
for (size_t i = 0; i < c_value_count / 2; i++) {
if constexpr (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value)
stream << static_cast<int>(values[i]) << ' ';
else
stream << values[i] << ' ';
}
stream << "... ";
- for (size_t i = values.size() - c_value_count / 2; i < values.size(); i++)
- stream << values[i] << ' ';
for (size_t i = values.size() - c_value_count / 2; i < values.size(); i++) {
if constexpr (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value)
stream << static_cast<int>(values[i]) << ' ';
else
stream << values[i] << ' ';
}
}
}

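
The static_cast<int> added above matters because streaming an int8_t/uint8_t value selects the character overload of operator<< and prints a (possibly unprintable) character rather than a number; a minimal sketch:

#include <cstdint>
#include <iostream>

int main() {
  uint8_t v = 65;
  std::cout << v << '\n';                     // prints "A": uint8_t is an alias for unsigned char
  std::cout << static_cast<int>(v) << '\n';   // prints "65", which is what DumpSpan wants
}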
51 changes: 36 additions & 15 deletions src/models/decoder_only_pipeline.cpp
@@ -58,7 +58,7 @@ bool IntermediatePipelineState::HasOutput(std::string_view name) const {
}

bool IntermediatePipelineState::SupportsPrimaryDevice() const {
- if (model_.device_type_ == DeviceType::CPU) {
if (model_.device_type_ == DeviceType::CPU || model_.device_type_ == DeviceType::QNN_WITH_SHARED_MEMORY) {
return true;
} else if (model_.device_type_ == DeviceType::CUDA) {
if (!model_.config_->model.decoder.pipeline[id_].session_options.has_value()) {
@@ -91,13 +91,14 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
const GeneratorParams& params)
: State{params, model},
model_{model},
- position_inputs_{model, *this, sequence_lengths} {
- input_ids_.Add();
- position_inputs_.Add();
input_ids_{CreateInputIDs(*this)},
key_value_cache_{CreateKeyValueCache(*this)},
position_inputs_{CreatePositionInputs(*this, sequence_lengths)} {
input_ids_->Add();
position_inputs_->Add();
logits_.Add();
- if (KV_Cache::IsCacheNeeded(model)) {
- kv_cache_ = std::make_unique<KV_Cache>(*this);
- kv_cache_->Add();
if (key_value_cache_) {
key_value_cache_->Add();
}
extra_inputs_.Add();

@@ -106,10 +107,8 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
}
}

- DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
- DeviceSpan<int32_t> next_indices) {
- UpdateInputsOutputs(next_tokens, next_indices, total_length);

void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices) {
for (auto& pipeline_state : pipeline_states_) {
if (first_run_ && !model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_prompt) {
continue;
@@ -218,6 +217,28 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int
}
}
}
}

DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices) {
UpdateInputsOutputs(next_tokens, next_indices, total_length);

size_t num_chunks{1};
if (first_run_ && model_.config_->model.decoder.sliding_window.has_value()) {
int window_size = model_.config_->model.decoder.sliding_window->window_size;
num_chunks = (next_tokens.size() + window_size - 1) / window_size;
}

for (size_t i = 0; i < num_chunks; ++i) {
RunPipeline(total_length, next_tokens, next_indices);

if (model_.config_->model.decoder.sliding_window.has_value() && i < num_chunks - 1) {
// Slide the window over the input_ids, key_cache, value_cache, position_ids, and attention_mask
input_ids_->Update(next_tokens);
key_value_cache_->Update(next_indices, total_length);
position_inputs_->Update(next_tokens, total_length, static_cast<int>(input_ids_->GetShape()[1]));
}
}

// Clear the outputs of the pipeline models that are only run on prompt since this cannot happen earlier.
if (!first_run_) {
@@ -239,10 +260,10 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int

void DecoderOnlyPipelineState::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> beam_indices, int total_length) {
- input_ids_.Update(next_tokens);
- size_t new_length = input_ids_.GetShape()[1];
- position_inputs_.Update(next_tokens, total_length, static_cast<int>(new_length));
- if (kv_cache_) kv_cache_->Update(beam_indices, total_length);
input_ids_->Update(next_tokens);
size_t new_length = input_ids_->GetShape()[1];
position_inputs_->Update(next_tokens, total_length, static_cast<int>(new_length));
if (key_value_cache_) key_value_cache_->Update(beam_indices, total_length);
logits_.Update(next_tokens, new_length);
}

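
A simplified sketch of the first-run chunk schedule implemented by Run above (illustrative only; the printf calls stand in for the real pipeline run and the input_ids/KV-cache/position-input updates, and the prompt is assumed to be padded already as described in generators.cpp):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t window_size = 128;
  const size_t padded_prompt = 384;  // e.g. a 300-token prompt padded up to a multiple of 128
  const size_t num_chunks = (padded_prompt + window_size - 1) / window_size;  // 3 on the first run

  for (size_t i = 0; i < num_chunks; ++i) {
    std::printf("run pipeline on window [%zu, %zu)\n", i * window_size, (i + 1) * window_size);
    if (i + 1 < num_chunks) {
      // Between prompt chunks, input_ids, the key/value cache, and the position inputs
      // all slide forward by one window before the next pipeline run.
      std::printf("slide inputs forward by %zu tokens\n", window_size);
    }
  }
  // Later Run calls process a single generated token, so num_chunks stays at 1.
  return 0;
}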
9 changes: 6 additions & 3 deletions src/models/decoder_only_pipeline.h
@@ -58,6 +58,9 @@ struct DecoderOnlyPipelineState : State {

OrtValue* GetOutput(const char* name) override;

void RunPipeline(int total_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices);

private:
void UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices,
int total_length);
@@ -68,10 +71,10 @@ struct DecoderOnlyPipelineState : State {
// Stores all the outputs from the previous pipeline state(s)
std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_;

- InputIDs input_ids_{*this};
std::unique_ptr<InputIDsInterface> input_ids_;
Logits logits_{*this};
- std::unique_ptr<KV_Cache> kv_cache_;
- PositionInputs position_inputs_;
std::unique_ptr<KeyValueCacheInterface> key_value_cache_;
std::unique_ptr<PositionInputsInterface> position_inputs_;
ExtraInputs extra_inputs_{*this};
};

Expand Down
67 changes: 67 additions & 0 deletions src/models/input_ids.cpp
@@ -191,4 +191,71 @@ void InputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
is_prompt_ = false;
}

SlidingWindowInputIDs::SlidingWindowInputIDs(State& state) : state_{state} {
name_ = model_.config_->model.decoder.inputs.input_ids.c_str();

if (!model_.config_->model.decoder.sliding_window.has_value()) {
throw std::runtime_error("Sliding a window over input_ids requires sliding_window to be set in the genai_config.json.");
}

if (state_.params_->BatchBeamSize() != 1) {
throw std::runtime_error("Batch beam size must be 1 for sliding a window over input_ids.");
}

window_size_ = model_.config_->model.decoder.sliding_window->window_size;
shape_ = {1, model_.config_->model.decoder.sliding_window->window_size};
type_ = model_.session_info_->GetInputDataType(name_);

if (type_ != Ort::TypeToTensorType<int32_t>) {
throw std::runtime_error("SlidingWindowInputIDs only supports int32_t input_ids.");
}
}

void SlidingWindowInputIDs::Add() {
input_index_ = state_.inputs_.size();

state_.inputs_.push_back(value_.get());
state_.input_names_.push_back(name_);
}

void SlidingWindowInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
if (window_index_ == 0) {
num_windows_ = (new_tokens.size() + window_size_ - 1) / window_size_;

value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);

// next_tokens will always be padded so that its size is a multiple of window_size_
// next_tokens -> [0, a, b, c, d, e]
// window_size = 3, num_windows = 2, pad_token = 0
// window_index = 0, value_ -> [0, a, b]
std::copy_n(new_tokens.Span().begin(), window_size_, value_->GetTensorMutableData<int32_t>());
} else if (window_index_ < num_windows_) {
// next_tokens -> [a, b, c, d, e]
// window_size = 3, num_windows = 2
// window_index = 1, value_ -> [c, d, e]
std::copy_n(new_tokens.Span().begin() + window_index_ * window_size_, window_size_, value_->GetTensorMutableData<int32_t>());
} else {
// All prompt token chunks have been processed. Now we process the tokens generated by the model.
// next_tokens -> [f]
assert(new_tokens.size() == 1);
if (shape_[1] != 1) {
shape_[1] = 1;
value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
}

value_->GetTensorMutableData<int32_t>()[0] = new_tokens.Span()[0];
}

state_.inputs_[input_index_] = value_.get();
window_index_++;
}

std::unique_ptr<InputIDsInterface> CreateInputIDs(State& state) {
if (state.model_.config_->model.decoder.sliding_window.has_value()) {
return std::make_unique<SlidingWindowInputIDs>(state);
} else {
return std::make_unique<InputIDs>(state);
}
}

} // namespace Generators