From bdf41a02838461b4d6580790ccb60ab10becf096 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Tue, 16 Apr 2024 15:16:01 -0700
Subject: [PATCH 1/3] Add DML support

---
 src/generators.h                |  1 +
 src/models/input_ids.cpp        |  4 ++++
 src/models/kv_cache.cpp         |  2 +-
 src/models/logits.cpp           | 18 +++++++++++++++++
 src/models/model.cpp            | 35 +++++++++++++++++++++++++++++++--
 src/models/position_ids.cpp     |  8 ++++++++
 src/python/CMakeLists.txt       |  2 ++
 src/python/py/models/builder.py |  9 +++++----
 8 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/src/generators.h b/src/generators.h
index 3fb9f5201..c8668ea2e 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -42,6 +42,7 @@ using TokenSequences = std::vector<std::vector<int32_t>>;
 enum struct DeviceType {
   CPU,
   CUDA,
+  DML,
 };
 
 struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 88d2514b5..1c692c276 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -40,7 +40,11 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
   // Resize input_ids shape once if it doesn't match the decoder shape
   if (shape_[1] != 1) {
     shape_[1] = 1;
+#ifdef USE_DML
+    value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
+#else
     value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+#endif
     state_.inputs_[input_index_] = value_.get();
   }
 
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index 17515355f..fc53d062b 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -117,7 +117,7 @@ KV_Cache::KV_Cache(const Model& model, State& state)
     : model_{model},
       state_{state},
       layer_count_{model_.config_->model.decoder.num_hidden_layers},
-      past_present_share_buffer_{state_.params_->search.past_present_share_buffer && state_.params_->search.num_beams == 1 && model_.device_type_ == DeviceType::CUDA},
+      past_present_share_buffer_{state_.params_->search.past_present_share_buffer && state_.params_->search.num_beams == 1 && (model_.device_type_ == DeviceType::CUDA || model_.device_type_ == DeviceType::DML)},
       shape_{state_.params_->BatchBeamSize(), model.config_->model.decoder.num_key_value_heads, 0, model.config_->model.decoder.head_size} {
   pasts_.resize(layer_count_ * 2);
   presents_.reserve(layer_count_ * 2);
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index 721a0b0ea..8f665366c 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -9,7 +9,11 @@ Logits::Logits(const Model& model, State& state)
       state_{state},
       shape_{static_cast<int64_t>(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, state_.params_->vocab_size},
       type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
+#ifdef USE_DML
+  auto logits_tensor = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_);
+#else
   auto logits_tensor = OrtValue::CreateTensor(*model.allocator_device_, shape_, type_);
+#endif
   if (type_ == Ort::TypeToTensorType<float>::type)
     value32_ = std::move(logits_tensor);
   else
@@ -21,7 +25,11 @@ RoamingArray<float> Logits::Get() {
   // Convert from float16 to float32 if necessary
   if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type)
+#if USE_DML
+    ConvertFp16ToFp32(model_.allocator_cpu_, *value16_, value32_, model_.device_type_, model_.cuda_stream_);
+#else
     ConvertFp16ToFp32(*model_.allocator_device_, *value16_, value32_, model_.device_type_, model_.cuda_stream_);
+#endif
 
   // First iteration? Then copy the logits over to a {batch_beams, 1, vocab_size} tensor
   // We'll reuse this tensor for all future iterations
@@ -32,7 +40,13 @@ RoamingArray<float> Logits::Get() {
     const size_t num_beams = state_.params_->search.num_beams;
 
     shape_[1] = 1;
+
+#if USE_DML
+    auto value_next = OrtValue::CreateTensor<float>(model_.allocator_cpu_, shape_);
+#else
     auto value_next = OrtValue::CreateTensor<float>(*model_.allocator_device_, shape_);
+#endif
+
     auto logits_next = cpu_span<float>{value_next->GetTensorMutableData<float>(), element_count};
 
     size_t vocab_index = 0;  // Simpler math to have this index go up by vocab_size for every logit chunk we process
@@ -65,7 +79,11 @@ RoamingArray<float> Logits::Get() {
     value32_ = std::move(value_next);
 
     if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type)
+#if USE_DML
+      value16_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
+#else
       value16_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+#endif
 
     state_.outputs_[output_index_] = type_ == Ort::TypeToTensorType<float>::type ? value32_.get() : value16_.get();
   }
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 68147923f..ded331341 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -194,6 +194,22 @@ Ort::Allocator* GetCudaAllocator(OrtSession& session) {
 }
 #endif
 
+#if USE_DML
+// Since Python/Others can and will hold onto a generator object past the model object's lifetime we need to ensure
+// the allocator used is not destroyed until last. This keeps the allocator around until exit, after all other memory
+// has been destroyed.
+Ort::Allocator* GetDmlAllocator(OrtSession& session) {
+  static std::unique_ptr<OrtMemoryInfo> memory_info_dml_;
+  static std::unique_ptr<Ort::Allocator> allocator_dml_;
+
+  if (!allocator_dml_) {
+    memory_info_dml_ = OrtMemoryInfo::Create("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
+    allocator_dml_ = Ort::Allocator::Create(session, *memory_info_dml_);
+  }
+  return allocator_dml_.get();
+}
+#endif
+
 SessionInfo::SessionInfo(OrtSession& session) {
   auto input_names = session.GetInputNames();
   std::vector<ONNXTensorElementDataType> input_types(input_names.size());
@@ -244,7 +260,12 @@ void Model::InitDeviceAllocator([[maybe_unused]] OrtSession& session) {
   if (device_type_ == DeviceType::CUDA) {
     allocator_device_ = GetCudaAllocator(session);
   }
+#elif USE_DML
+  if (device_type_ == DeviceType::DML) {
+    allocator_device_ = GetDmlAllocator(session);
+  }
 #endif
+
   session_info_ = std::make_unique<SessionInfo>(session);
 }
 
@@ -323,6 +344,7 @@ void Model::CreateSessionOptions() {
       ort_options.AppendExecutionProvider_ROCM(ort_provider_options);
 #ifdef USE_DML
     } else if (provider_options.name == "dml") {
+      device_type_ = DeviceType::DML;  // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
       const OrtDmlApi* p_dml_api{};
       Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&p_dml_api)));
       if (!p_dml_api)
@@ -383,6 +405,8 @@ void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr
   auto* fp32 = p_out->GetTensorMutableData<float>();
 
   switch (device_type) {
+    case DeviceType::DML:
+      // DML doesn't currently support on-device scoring, so we fall back to the CPU
     case DeviceType::CPU:
       for (int i = 0; i < count; i++)
         fp32[i] = Float16ToFloat32(fp16[i]);
@@ -436,8 +460,9 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
   // Input shape (batch_size, sequence_length). The input is required with data type T.
   // Output shape (batch_size * num_beams, sequence_length)
 
-  // If we're on CUDA, we still want to do the copy to move the data over to CUDA memory where we will read from it later
-  if (num_beams == 1 && device_type_ == DeviceType::CPU) {
+  // If we're on CUDA, we still want to do the copy to move the data over to CUDA memory where we will read from it later.
+  // DML doesn't currently support on-device scoring, so we go the same route as the CPU
+  if (num_beams == 1 && (device_type_ == DeviceType::CPU || device_type_ == DeviceType::DML)) {
     return std::move(input);
   }
 
@@ -450,13 +475,19 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
 
   input_shape[0] *= num_beams;
 
+#if USE_DML
+  auto expanded = OrtValue::CreateTensor(allocator_cpu_, input_shape, element_type);
+#else
   auto expanded = OrtValue::CreateTensor(*allocator_device_, input_shape, element_type);
+#endif
 
   const auto* input_data = reinterpret_cast<const uint8_t*>(input->GetTensorRawData());
   auto* expanded_data = reinterpret_cast<uint8_t*>(expanded->GetTensorMutableRawData());
   auto* target = expanded_data;
 
   switch (device_type_) {
+    case DeviceType::DML:
+      // DML doesn't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs
     case DeviceType::CPU:
       for (int i = 0; i < batch_size; i++) {
         for (int j = 0; j < num_beams; j++) {
diff --git a/src/models/position_ids.cpp b/src/models/position_ids.cpp
index ec6ebd579..8b544929e 100644
--- a/src/models/position_ids.cpp
+++ b/src/models/position_ids.cpp
@@ -53,6 +53,8 @@ void PositionIDs::Update(int current_length) {
     state_.inputs_[input_index_] = position_ids_.get();
   } else {  // Just incrementing existing position IDs
     switch (model_.device_type_) {
+      case DeviceType::DML:
+        // DML doesn't support on-device position ids update yet, so we fall back to the CPU
       case DeviceType::CPU: {
         if (type_ == Ort::TypeToTensorType<int32_t>::type)
           UpdatePositionIDs<int32_t>();
@@ -79,9 +81,15 @@ void PositionIDs::Update(int current_length) {
   assert(attention_mask_shape_[1] == current_length - 1);  // We should always be growing by 1
   attention_mask_shape_[1] = current_length;
 
+#if USE_DML
+  std::unique_ptr<OrtValue> next_attention_mask = OrtValue::CreateTensor(model_.allocator_cpu_, attention_mask_shape_, type_);
+#else
   std::unique_ptr<OrtValue> next_attention_mask = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
+#endif
 
   switch (model_.device_type_) {
+    case DeviceType::DML:
+      // DML doesn't support on-device mask updating yet, so we fallback to the CPU
     case DeviceType::CPU: {
       if (type_ == Ort::TypeToTensorType<int32_t>::type)
         UpdateAttentionMask(next_attention_mask->GetTensorMutableData<int32_t>(), attention_mask_->GetTensorData<int32_t>(), current_length);
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index fdbe33023..f1e5e8720 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -29,6 +29,8 @@ if(BUILD_WHEEL)
   message("Setting up wheel files in : ${WHEEL_FILES_DIR}")
   if(USE_CUDA)
     set(TARGET_NAME "onnxruntime-genai-cuda")
+  elif(USE_DML)
+    set(TARGET_NAME "onnxruntime-genai-dml")
   else()
     set(TARGET_NAME "onnxruntime-genai")
   endif()
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index e0c6d28aa..9a0580483 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -159,12 +159,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "use_rotemb_in_attn": False,       # Use rotary embeddings within attention op (instead of a separate RotaryEmbedding op)
             "use_packed_matmul": False,        # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
         }
-        if self.ep == "cuda" and self.io_dtype == TensorProto.FLOAT16:
+        if self.ep in {"cuda", "dml"} and self.io_dtype == TensorProto.FLOAT16:
             # Change model settings for GroupQueryAttention
             self.attention_attrs["op_type"] = "GroupQueryAttention"
-            print("GroupQueryAttention (GQA) is used in this model. GQA is currently supported only for INT4 CUDA and FP16 CUDA.")
+            print("GroupQueryAttention (GQA) is used in this model. GQA is currently supported only for INT4 and FP16 on the CUDA and DML execution providers.")
 
-            self.attention_attrs["use_packed_matmul"] = self.num_attn_heads == self.num_kv_heads
+            # DML doesn't support stacked Q/K/V for GQA yet
+            self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and self.num_attn_heads == self.num_kv_heads
 
             # GQA + Rot.Emb. does not require `position ids` as input
             self.attention_attrs["use_rotemb_in_attn"] = True
@@ -1751,7 +1752,7 @@ def get_args():
         "-e",
         "--execution_provider",
         required=True,
-        choices=["cpu", "cuda"],
+        choices=["cpu", "cuda", "dml"],
         help="Execution provider to target with precision of model (e.g. FP16 CUDA, INT4 CPU)",
     )
 

From 63c3ddf17b0a1c3bb734c9cc7bad3ae0aef30f26 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Tue, 16 Apr 2024 16:17:58 -0700
Subject: [PATCH 2/3] elif -> elseif in cmake file

---
 src/python/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index ca86b77bd..bf203f50d 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -32,7 +32,7 @@ if(BUILD_WHEEL)
   message("Setting up wheel files in : ${WHEEL_FILES_DIR}")
   if(USE_CUDA)
     set(TARGET_NAME "onnxruntime-genai-cuda")
-  elif(USE_DML)
+  elseif(USE_DML)
     set(TARGET_NAME "onnxruntime-genai-dml")
   else()
     set(TARGET_NAME "onnxruntime-genai")

From 02258986712ffe4f4499978a37c3172179a3f320 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Wed, 17 Apr 2024 17:32:40 -0700
Subject: [PATCH 3/3] Fix lint

---
 src/models/logits.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index 763c55dfc..4c2d312c0 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -9,7 +9,6 @@ Logits::Logits(const Model& model, State& state)
       state_{state},
       shape_{static_cast<int64_t>(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, state_.params_->vocab_size},
       type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
-
   auto& allocator = model_.device_type_ == DeviceType::DML ? model_.allocator_cpu_ : *model_.allocator_device_;
   auto logits_tensor = OrtValue::CreateTensor(allocator, shape_, type_);
   if (type_ == Ort::TypeToTensorType<float>::type)