diff --git a/src/models/captured_graph_pool.cpp b/src/models/captured_graph_pool.cpp
index 2862ac8d9..140f2a8cd 100644
--- a/src/models/captured_graph_pool.cpp
+++ b/src/models/captured_graph_pool.cpp
@@ -54,7 +54,9 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
   new_captured_graph->sb_input_ids_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);

 #if USE_DML
-  new_captured_graph->sb_input_ids_int32_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
+  if (model.device_type_ == DeviceType::DML) {
+    new_captured_graph->sb_input_ids_int32_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
+  }
 #endif

   // Create the static buffers for the cache
@@ -76,7 +78,9 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,

 #if USE_DML
     // DML currently needs an additional static buffer for the mask
-    new_captured_graph->sb_attention_mask_next_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
+    if (model.device_type_ == DeviceType::DML) {
+      new_captured_graph->sb_attention_mask_next_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
+    }
 #endif
   }
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 3cf44326c..6d281d247 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -32,7 +32,9 @@ InputIDs::InputIDs(const Model& model, State& state)
     sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();

 #if USE_DML
-    sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
+    if (model_.device_type_ == DeviceType::DML) {
+      sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
+    }
 #endif
   }
 }
@@ -52,13 +54,17 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
       value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);

 #if USE_DML
-      value_int32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+      if (model_.device_type_ == DeviceType::DML) {
+        value_int32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+      }
 #endif
     } else {
       value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_);

 #if USE_DML
-      value_int32_ = sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>::type);
+      if (model_.device_type_ == DeviceType::DML) {
+        value_int32_ = sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>::type);
+      }
 #endif
     }
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index 4c34f4b76..b85dfcff0 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -44,17 +44,18 @@ RoamingArray<float> Logits::Get() {
   // Convert from float16 to float32 if necessary
   if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type) {
 #if USE_DML
-    DmlHelpers::DmlCastInputToOutput(
-        model_.GetDmlExecutionContext(),
-        *model_.allocator_device_,
-        *value16_,
-        value32_,
-        model_.GetDmlDevice(),
-        model_.GetOrtDmlApi(),
-        logits_cast_command_list_state_);
-#else
-    ConvertFp16ToFp32(*model_.allocator_device_, *value16_, value32_, model_.device_type_, model_.cuda_stream_);
+    if (model_.device_type_ == DeviceType::DML) {
+      DmlHelpers::DmlCastInputToOutput(
+          model_.GetDmlExecutionContext(),
+          *model_.allocator_device_,
+          *value16_,
+          value32_,
+          model_.GetDmlDevice(),
+          model_.GetOrtDmlApi(),
+          logits_cast_command_list_state_);
+    } else
 #endif
+      ConvertFp16ToFp32(*model_.allocator_device_, *value16_, value32_, model_.device_type_, model_.cuda_stream_);
   }

   // First iteration? Then copy the logits over to a {batch_beams, 1, vocab_size} tensor
@@ -73,7 +74,9 @@ RoamingArray<float> Logits::Get() {

 #if USE_DML
     // DML doesn't support on-device scoring yet, so we need to download some data to the CPU
-    value32_cpu_ = OrtValue::CreateTensor<float>(model_.allocator_cpu_, shape_);
+    if (model_.device_type_ == DeviceType::DML) {
+      value32_cpu_ = OrtValue::CreateTensor<float>(model_.allocator_cpu_, shape_);
+    }
 #endif

     size_t vocab_index = 0;  // Simpler math to have this index go up by vocab_size for every logit chunk we process
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 6558be2cf..9226e5c05 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -8,7 +8,7 @@
 #include "decoder_only.h"
 #include "whisper.h"
 #include "kernels.h"
-#ifdef USE_DML
+#if USE_DML
 #include
 #include "dml_provider_factory.h"
 #include "../dml/dml_smart_container.h"
@@ -347,7 +347,7 @@ void Model::CreateSessionOptions() {
     Ort::ThrowOnError(Ort::api->UpdateROCMProviderOptions(&ort_provider_options, keys.data(), values.data(), keys.size()));

     ort_options.AppendExecutionProvider_ROCM(ort_provider_options);
-#ifdef USE_DML
+#if USE_DML
   } else if (provider_options.name == "dml") {
     dml_objects_ = DmlHelpers::CreateDmlObjects();
@@ -442,7 +442,7 @@ void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out,
-#ifdef USE_DML
+#if USE_DML
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ ... @@ PositionInputs::PositionInputs
     sb_attention_mask_ = state_.GetCapturedGraphInfo()->sb_attention_mask_.get();

 #if USE_DML
-    sb_attention_mask_next_ = state_.GetCapturedGraphInfo()->sb_attention_mask_next_.get();
+    if (model_.device_type_ == DeviceType::DML) {
+      sb_attention_mask_next_ = state_.GetCapturedGraphInfo()->sb_attention_mask_next_.get();
+    }
 #endif
   }
 }
@@ -301,7 +303,7 @@ void PositionInputs::UpdateAttentionMask(int current_length) {
       throw std::runtime_error("PositionIDs::Update - Unsupported device type");
   }

-#ifndef USE_DML
+#if !USE_DML
   attention_mask_ = std::move(attention_mask_next_);
 #endif
diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h
index 4c714d78d..259f5c0c2 100644
--- a/src/models/position_inputs.h
+++ b/src/models/position_inputs.h
@@ -57,7 +57,7 @@ struct PositionInputs {
   bool is_first_posid_update_{true};
   bool is_first_mask_update_{true};

-#ifdef USE_DML
+#if USE_DML
   std::optional dml_update_mask_kernel_;
   StaticBuffer* sb_attention_mask_next_{};
   std::optional dml_update_position_ids_kernel_;
diff --git a/src/python/python.cpp b/src/python/python.cpp
index 8ea4321b6..69d5caef5 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -247,18 +247,18 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
   m.def("set_log_options", &SetLogOptions);

   m.def("is_cuda_available", []() {
-#ifdef USE_CUDA
+#if USE_CUDA
     return true;
 #else
-  return false;
+    return false;
 #endif
   });

   m.def("is_dml_available", []() {
-#ifdef USE_DML
+#if USE_DML
     return true;
 #else
-  return false;
+    return false;
 #endif
   });
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index 3984c148b..1e17fcc67 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -90,7 +90,7 @@ TEST(CAPITests, EndToEndPhiBatch) {
 }

 // DML doesn't support GPT attention
-#ifndef USE_DML
+#if !USE_DML
 TEST(CAPITests, GreedySearchGptFp32CAPI) {
   std::vector<int64_t> input_ids_shape{2, 4};
   std::vector<int32_t> input_ids{0, 0, 0, 52, 0, 0, 195, 731};
diff --git a/test/model_tests.cpp b/test/model_tests.cpp
index be7c2d7fd..edeeb4ea4 100644
--- a/test/model_tests.cpp
+++ b/test/model_tests.cpp
@@ -21,7 +21,7 @@ static const std::pair c_tiny_gpt2_model_paths[] = {
 };

 // DML doesn't support GPT attention
-#ifndef USE_DML
+#if !USE_DML
 TEST(ModelTests, GreedySearchGptFp32) {
   std::vector<int64_t> input_ids_shape{2, 4};
   std::vector<int32_t> input_ids{0, 0, 0, 52, 0, 0, 195, 731};
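
A note on the pattern above (not part of the patch): building with USE_DML only means DML support is compiled in; at runtime the model may still execute on CPU or CUDA, so each DML-only allocation is now gated on a runtime device check in addition to the compile-time guard. Below is a minimal, self-contained sketch of that pattern; `DeviceType`, `StaticBuffer`, and `CapturedGraphInfo` here are simplified stand-ins, not the real onnxruntime-genai definitions.

```cpp
// Illustrative only: shows the "compile-time #if USE_DML + runtime DeviceType::DML" guard.
#include <cstddef>
#include <iostream>
#include <memory>

#ifndef USE_DML
#define USE_DML 0  // set to 1 for a DML-enabled build, e.g. -DUSE_DML=1
#endif

enum class DeviceType { CPU, CUDA, DML };

struct StaticBuffer {
  explicit StaticBuffer(std::size_t max_beam_batch_size) : size{max_beam_batch_size} {}
  std::size_t size;
};

struct CapturedGraphInfo {
  std::unique_ptr<StaticBuffer> sb_input_ids_;
#if USE_DML
  // Compiled in for DML-enabled builds, but only allocated when DML is the active device.
  std::unique_ptr<StaticBuffer> sb_input_ids_int32_;
#endif
};

CapturedGraphInfo ReserveCapturedGraph(DeviceType device_type, std::size_t max_beam_batch_size) {
  (void)device_type;  // unused when USE_DML is 0
  CapturedGraphInfo graph;
  graph.sb_input_ids_ = std::make_unique<StaticBuffer>(max_beam_batch_size);
#if USE_DML
  // A DML-enabled build can still run on CPU or CUDA, so pair the compile-time
  // guard with a runtime device check before allocating the DML-only buffer.
  if (device_type == DeviceType::DML) {
    graph.sb_input_ids_int32_ = std::make_unique<StaticBuffer>(max_beam_batch_size);
  }
#endif
  return graph;
}

int main() {
  auto graph = ReserveCapturedGraph(DeviceType::CPU, 8);
  std::cout << "input_ids buffer size: " << graph.sb_input_ids_->size << "\n";
  // When the active device is CPU, sb_input_ids_int32_ stays null even in a DML-enabled build.
}
```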