From 7177e32edb20639d28e73d9e1f619cb82ec78da4 Mon Sep 17 00:00:00 2001 From: Numfor Tiapo Date: Wed, 11 Oct 2023 14:53:00 -0700 Subject: [PATCH] Add API for NPU Device Selection in the DML EP (#17612) Co-authored-by: Sheil Kumar --- cmake/onnxruntime_providers.cmake | 2 +- cmake/onnxruntime_providers_dml.cmake | 6 +- .../core/providers/dml/dml_provider_factory.h | 32 ++ .../providers/dml/dml_provider_factory.cc | 330 +++++++++++++++++- .../dml/dml_provider_factory_creator.h | 12 + .../core/session/provider_registration.cc | 12 +- .../python/onnxruntime_pybind_state.cc | 16 +- onnxruntime/test/perftest/ort_test_session.cc | 7 +- 8 files changed, 390 insertions(+), 27 deletions(-) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 980a1dfbb404a..8d3ea403fb74b 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -202,4 +202,4 @@ endif() if (onnxruntime_USE_AZURE) include(onnxruntime_providers_azure.cmake) -endif() \ No newline at end of file +endif() diff --git a/cmake/onnxruntime_providers_dml.cmake b/cmake/onnxruntime_providers_dml.cmake index 4339c42cd4254..01b0bda9fea6b 100644 --- a/cmake/onnxruntime_providers_dml.cmake +++ b/cmake/onnxruntime_providers_dml.cmake @@ -56,13 +56,13 @@ if (GDK_PLATFORM STREQUAL Scarlett) target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs}) else() - target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib) + target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib) endif() target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib) if (NOT GDK_PLATFORM) - set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199") + set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} 
/DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199") endif() target_compile_definitions(onnxruntime_providers_dml @@ -88,4 +88,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 0782d2d9ed760..dd4ffb835d51c 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -30,6 +30,31 @@ typedef struct IDMLDevice IDMLDevice; extern "C" { #endif +enum OrtDmlPerformancePreference { + Default = 0, + HighPerformance = 1, + MinimumPower = 2 +}; + +enum OrtDmlDeviceFilter : uint32_t { + Any = 0xffffffff, + Gpu = 1 << 0, + Npu = 1 << 1, +}; + +inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; } +inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a | (int)b); } +inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a & (int)b); } +inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a ^ (int)b); } +inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a |= (int)b); } +inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); } +inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); } + +struct OrtDmlDeviceOptions { + 
OrtDmlPerformancePreference Preference; + OrtDmlDeviceFilter Filter; +}; + /** * [[deprecated]] * This export is deprecated. @@ -99,6 +124,13 @@ struct OrtDmlApi { * This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP. */ ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource); + + /** + * SessionOptionsAppendExecutionProvider_DML2 + * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference + * (default, high performance, or minimum power) and a device filter (Any, GPU, or NPU). + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts); }; #ifdef __cplusplus diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index fde61e73c2124..df94326e88045 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include +#include + #include #ifndef _GAMING_XBOX #include @@ -92,12 +95,298 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId); } +static bool IsHardwareAdapter(IDXCoreAdapter* adapter) { + bool is_hardware = false; + THROW_IF_FAILED(adapter->GetProperty( + DXCoreAdapterProperty::IsHardware, + &is_hardware)); + return is_hardware; +} + +static bool IsGPU(IDXCoreAdapter* compute_adapter) { + // Only considering hardware adapters + if (!IsHardwareAdapter(compute_adapter)) { + return false; + } + return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); +} + +static bool IsNPU(IDXCoreAdapter* compute_adapter) { + // Only considering hardware adapters + if (!IsHardwareAdapter(compute_adapter)) { + return false; + } + return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)); +} + +enum class DeviceType { GPU, NPU, BadDevice }; + +static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFilter filter) { + auto allow_gpus = (filter & OrtDmlDeviceFilter::Gpu) == OrtDmlDeviceFilter::Gpu; + if (IsGPU(adapter) && allow_gpus) { + return DeviceType::GPU; + } + + auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu; + if (IsNPU(adapter) && allow_npus) { + return DeviceType::NPU; + } + + return DeviceType::BadDevice; +} + +// Struct for holding each adapter +struct AdapterInfo { + ComPtr Adapter; + DeviceType Type; // GPU or NPU +}; + +static ComPtr EnumerateDXCoreAdapters(IDXCoreAdapterFactory* adapter_factory) { + ComPtr adapter_list; + + // TODO: use_dxcore_workload_enumeration should be determined by QI + // When DXCore APIs are available QI for relevant enumeration interfaces + constexpr bool use_dxcore_workload_enumeration = false; + if (!use_dxcore_workload_enumeration) { + // Get a list of all the adapters that support compute + GUID attributes[]{ 
DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE }; + ORT_THROW_IF_FAILED( + adapter_factory->CreateAdapterList(_countof(attributes), + attributes, + adapter_list.GetAddressOf())); + } + + return adapter_list; +} + +static void SortDXCoreAdaptersByPreference( + IDXCoreAdapterList* adapter_list, + OrtDmlPerformancePreference preference) { + if (adapter_list->GetAdapterCount() <= 1) { + return; + } + + // DML prefers the HighPerformance adapter by default + std::array adapter_list_preferences = { + DXCoreAdapterPreference::HighPerformance + }; + + // If callers specify minimum power, change the DXCore sort policy + // NOTE: DXCoreAdapterPreference does not apply to mixed adapter lists - only to GPU lists + if (preference == OrtDmlPerformancePreference::MinimumPower) { + adapter_list_preferences[0] = DXCoreAdapterPreference::MinimumPower; + } + + ORT_THROW_IF_FAILED(adapter_list->Sort( + static_cast(adapter_list_preferences.size()), + adapter_list_preferences.data())); +} + +static std::vector FilterDXCoreAdapters( + IDXCoreAdapterList* adapter_list, + OrtDmlDeviceFilter filter) { + auto adapter_infos = std::vector(); + const uint32_t count = adapter_list->GetAdapterCount(); + for (uint32_t i = 0; i < count; ++i) { + ComPtr candidate_adapter; + ORT_THROW_IF_FAILED(adapter_list->GetAdapter(i, candidate_adapter.GetAddressOf())); + + // Add the adapters that are valid based on the device filter (GPU, NPU, or Both) + auto adapter_type = FilterAdapterTypeQuery(candidate_adapter.Get(), filter); + if (adapter_type != DeviceType::BadDevice) { + adapter_infos.push_back(AdapterInfo{candidate_adapter, adapter_type}); + } + } + + return adapter_infos; +} + +static void SortHeterogenousDXCoreAdapterList( + std::vector& adapter_infos, + OrtDmlDeviceFilter filter, + OrtDmlPerformancePreference preference) { + if (adapter_infos.size() <= 1) { + return; + } + + // When considering both GPUs and NPUs sort them by performance preference + // of Default (GPUs first), HighPerformance (GPUs 
first), or LowPower (NPUs first) + auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu; + auto only_npus = filter == OrtDmlDeviceFilter::Npu; + if (!keep_npus || only_npus) { + return; + } + + struct SortingPolicy { + // default is false because GPUs are considered higher priority in + // a mixed adapter environment + bool npus_first_ = false; + + SortingPolicy(bool npus_first = false) : npus_first_(npus_first) { } + + bool operator()(const AdapterInfo& a, const AdapterInfo& b) { + return npus_first_ ? a.Type < b.Type : a.Type > b.Type; + } + }; + + auto npus_first = (preference == OrtDmlPerformancePreference::MinimumPower); + auto policy = SortingPolicy(npus_first); + std::sort(adapter_infos.begin(), adapter_infos.end(), policy); +} + std::shared_ptr DMLProviderFactoryCreator::Create(int device_id) { return Create(device_id, /*skip_software_device_check*/ false); } -Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device(int device_id, bool skip_software_device_check) -{ +std::shared_ptr DMLProviderFactoryCreator::CreateFromOptions( + OrtDmlDeviceOptions* device_options) { + auto default_device_options = OrtDmlDeviceOptions { Default, Gpu }; + if (device_options == nullptr) { + device_options = &default_device_options; + } + + OrtDmlPerformancePreference preference = device_options->Preference; + OrtDmlDeviceFilter filter = device_options->Filter; + + // Create DXCore Adapter Factory + ComPtr adapter_factory; + ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf())); + + // Get all DML compatible DXCore adapters + ComPtr adapter_list; + adapter_list = EnumerateDXCoreAdapters(adapter_factory.Get()); + + if (adapter_list->GetAdapterCount() == 0) { + ORT_THROW("No GPUs or NPUs detected."); + } + + // Sort the adapter list to honor DXCore hardware ordering + SortDXCoreAdaptersByPreference(adapter_list.Get(), preference); + + // TODO: use_dxcore_workload_enumeration should be determined by QI + // When 
DXCore APIs are available QI for relevant enumeration interfaces + constexpr bool use_dxcore_workload_enumeration = false; + + std::vector adapter_infos; + if (!use_dxcore_workload_enumeration) { + // Filter all DXCore adapters to hardware type specified by the device filter + adapter_infos = FilterDXCoreAdapters(adapter_list.Get(), filter); + if (adapter_infos.size() == 0) { + ORT_THROW("No devices detected that match the filter criteria."); + } + } + + // DXCore Sort ignores NPUs. When both GPUs and NPUs are present, manually sort them. + SortHeterogenousDXCoreAdapterList(adapter_infos, filter, preference); + + // Extract just the adapters + auto adapters = std::vector>(adapter_infos.size()); + std::transform( + adapter_infos.begin(), adapter_infos.end(), + adapters.begin(), + [](auto& a){ return a.Adapter; }); + + return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters)); +} + +static std::optional ParsePerformancePreference(const ProviderOptions& provider_options) { + static const std::string PerformancePreference = "performance_preference"; + static const std::string Default = "default"; + static const std::string HighPerformance = "high_performance"; + static const std::string MinimumPower = "minimum_power"; + + auto preference_it = provider_options.find(PerformancePreference); + if (preference_it != provider_options.end()) { + if (preference_it->second == Default) { + return OrtDmlPerformancePreference::Default; + } + + if (preference_it->second == HighPerformance) { + return OrtDmlPerformancePreference::HighPerformance; + } + + if (preference_it->second == MinimumPower) { + return OrtDmlPerformancePreference::MinimumPower; + } + + ORT_THROW("Invalid PerformancePreference provided for DirectML EP device selection."); + } + + return {}; +} + +static std::optional ParseFilter(const ProviderOptions& provider_options) { + static const std::string Filter = "filter"; + static const std::string Any = "any"; + static const std::string 
Gpu = "gpu"; + static const std::string Npu = "npu"; + + auto preference_it = provider_options.find(Filter); + if (preference_it != provider_options.end()) { + if (preference_it->second == Any) { + return OrtDmlDeviceFilter::Any; + } + + if (preference_it->second == Gpu) { + return OrtDmlDeviceFilter::Gpu; + } + + if (preference_it->second == Npu) { + return OrtDmlDeviceFilter::Npu; + } + + ORT_THROW("Invalid Filter provided for DirectML EP device selection."); + } + + return {}; +} + +static std::optional ParseDeviceId(const ProviderOptions& provider_options) { + static const std::string DeviceId = "device_id"; + + auto preference_it = provider_options.find(DeviceId); + if (preference_it != provider_options.end()) { + if (!preference_it->second.empty()) { + return std::stoi(preference_it->second); + } + } + + return {}; +} + +std::shared_ptr DMLProviderFactoryCreator::CreateFromProviderOptions( + const ProviderOptions& provider_options) { + auto device_id = ParseDeviceId(provider_options); + if (device_id.has_value()) + { + return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value()); + } + + auto preference = ParsePerformancePreference(provider_options); + auto filter = ParseFilter(provider_options); + + // If no preference/filters are specified then create with default preference/filters. 
+ if (!preference.has_value() && !filter.has_value()) { + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr); + } + + if (!preference.has_value()) { + preference = OrtDmlPerformancePreference::Default; + } + + if (!filter.has_value()) { + filter = OrtDmlDeviceFilter::Gpu; + } + + OrtDmlDeviceOptions device_options; + device_options.Preference = preference.value(); + device_options.Filter = filter.value(); + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options); +} + +Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device( + int device_id, + bool skip_software_device_check) { #ifdef _GAMING_XBOX ComPtr d3d12_device; D3D12XBOX_CREATE_DEVICE_PARAMETERS params = {}; @@ -128,8 +417,7 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Devic return d3d12_device; } -Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) -{ +Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) { DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE; // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled @@ -153,9 +441,7 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID return dml_device; } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { - ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); - +std::shared_ptr CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) { D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; @@ -163,10 +449,27 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int ComPtr cmd_queue; ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); - auto dml_device = 
CreateDMLDevice(d3d12_device.Get()); + auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device); return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); } +std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { + ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); +} + +std::shared_ptr DMLProviderFactoryCreator::CreateFromAdapterList( + std::vector>&& dxcore_devices) { + // Choose the first device from the list since it's the highest priority + auto dxcore_device = dxcore_devices[0]; + + // Create D3D12 Device from DXCore Adapter + ComPtr d3d12_device; + ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); + + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); +} + } // namespace onnxruntime // [[deprecated]] @@ -211,6 +514,17 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) { API_IMPL_END } +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) { +API_IMPL_BEGIN +#ifdef USE_DML + auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options); + // return the create function for a dxcore device + options->provider_factories.push_back(factory); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) { API_IMPL_BEGIN #ifdef USE_DML diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index 574f4410fe3e3..4e13330a4cd71 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ 
b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -7,14 +7,26 @@ #include #include +#include "core/framework/provider_options.h" #include "core/providers/providers.h" #include "core/providers/dml/dml_provider_factory.h" +#include +#include + namespace onnxruntime { struct DMLProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(int device_id, bool skip_software_device_check); + + static std::shared_ptr CreateFromProviderOptions( + const ProviderOptions& provider_options_map); + static std::shared_ptr CreateFromOptions(OrtDmlDeviceOptions* device_options); + + static std::shared_ptr CreateFromAdapterList( + std::vector>&& dxcore_devices); + static Microsoft::WRL::ComPtr CreateD3D12Device(int device_id, bool skip_software_device_check); static Microsoft::WRL::ComPtr CreateDMLDevice(ID3D12Device* d3d12_device); }; diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 50c9f6681a0c8..4649ac35c3647 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -12,6 +12,10 @@ #include "core/session/ort_apis.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" +#if defined(USE_DML) +#include "core/providers/dml/dml_provider_factory_creator.h" +#endif + using namespace onnxruntime; namespace { @@ -67,7 +71,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. 
").c_str()); }; - if (strcmp(provider_name, "QNN") == 0) { + if (strcmp(provider_name, "DML") == 0) { +#if defined(USE_DML) + options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); +#else + status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "QNN") == 0) { #if defined(USE_QNN) options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value))); #else diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 95a8f59186ff5..35e03bf9eacd5 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -879,18 +879,10 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML - int device_id = 0; - auto it = provider_options_map.find(type); - if (it != provider_options_map.end()) { - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - device_id = std::stoi(option.second); - } - } - } - } - return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider(); + auto cit = provider_options_map.find(type); + return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions( + cit == provider_options_map.end() ? ProviderOptions{} : cit->second) + ->CreateProvider(); #endif } else if (type == kNnapiExecutionProvider) { #if defined(USE_NNAPI) diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 57b2403e23a37..1111a92a385fd 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -662,9 +662,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); #endif } else if (provider_name == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0)); + std::unordered_map dml_options; + dml_options["performance_preference"] = "high_performance"; + dml_options["device_filter"] = "gpu"; + session_options.AppendExecutionProvider("DML", dml_options); #else - ORT_THROW("DirectML is not supported in this build\n"); + ORT_THROW("DML is not supported in this build\n"); #endif } else if (provider_name == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL