From 6062096ab67682c3719c962f758ef2092a7ee6a4 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 12 Oct 2023 16:56:40 -0700 Subject: [PATCH] Cherry-pick b8f373b0aee086a36ea357bc2ffb2944246be15a (#17895) (#17926) Merge NPU Device Selection changes git cherry-pick b8f373b0aee086a36ea357bc2ffb2944246be15a Changes already present in DmlPrototype-2023_10_02. Co-authored-by: Sheil Kumar --- cmake/onnxruntime_providers.cmake | 4 +- .../core/providers/dml/dml_provider_factory.h | 32 ++ .../providers/dml/dml_provider_factory.cc | 407 +++++++++++++++++- .../dml/dml_provider_factory_creator.h | 19 + .../core/session/provider_registration.cc | 12 +- .../python/onnxruntime_pybind_state.cc | 16 +- onnxruntime/test/perftest/ort_test_session.cc | 7 +- 7 files changed, 457 insertions(+), 40 deletions(-) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 7fb03487a255e..047b7c1ca98d2 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1314,13 +1314,13 @@ if (onnxruntime_USE_DML) if (GDK_PLATFORM STREQUAL Scarlett) target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs}) else() - target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib) + target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib) endif() target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib) if (NOT GDK_PLATFORM) - set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199") + set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199") endif() target_compile_definitions(onnxruntime_providers_dml diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 0782d2d9ed760..dd4ffb835d51c 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -30,6 +30,31 @@ typedef struct IDMLDevice IDMLDevice; extern "C" { #endif +enum OrtDmlPerformancePreference { + Default = 0, + HighPerformance = 1, + MinimumPower = 2 +}; + +enum OrtDmlDeviceFilter : uint32_t { + Any = 0xffffffff, + Gpu = 1 << 0, + Npu = 1 << 1, +}; + +inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; } +inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a | (int)b); } +inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a & (int)b); } +inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a ^ (int)b); } +inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a |= (int)b); } +inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); } +inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); } + +struct OrtDmlDeviceOptions { + OrtDmlPerformancePreference Preference; + OrtDmlDeviceFilter Filter; +}; + /** * [[deprecated]] * This export is deprecated. @@ -99,6 +124,13 @@ struct OrtDmlApi { * This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP. */ ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource); + + /** + * SessionOptionsAppendExecutionProvider_DML2 + * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference + * (high power, low power, or defult) and a device filter (None, GPU, or NPU). + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts); }; #ifdef __cplusplus diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 6a2740e4369e9..44850712d3f3f 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #include #ifndef _GAMING_XBOX #include @@ -92,11 +95,298 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId); } +static bool IsHardwareAdapter(IDXCoreAdapter* adapter) { + bool is_hardware = false; + THROW_IF_FAILED(adapter->GetProperty( + DXCoreAdapterProperty::IsHardware, + &is_hardware)); + return is_hardware; +} + +static bool IsGPU(IDXCoreAdapter* compute_adapter) { + // Only considering hardware adapters + if (!IsHardwareAdapter(compute_adapter)) { + return false; + } + return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); +} + +static bool IsNPU(IDXCoreAdapter* compute_adapter) { + // Only considering hardware adapters + if (!IsHardwareAdapter(compute_adapter)) { + return false; + } + return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)); +} + +enum class DeviceType { GPU, NPU, BadDevice }; + +static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFilter filter) { + auto allow_gpus = (filter & OrtDmlDeviceFilter::Gpu) == OrtDmlDeviceFilter::Gpu; + if (IsGPU(adapter) && allow_gpus) { + return DeviceType::GPU; + } + + auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu; + if (IsNPU(adapter) && allow_npus) { + return DeviceType::NPU; + } + + return DeviceType::BadDevice; +} + +// Struct for holding each adapter +struct AdapterInfo { + ComPtr Adapter; + DeviceType Type; // GPU or NPU +}; + +static ComPtr EnumerateDXCoreAdapters(IDXCoreAdapterFactory* adapter_factory) { + ComPtr adapter_list; + + // TODO: use_dxcore_workload_enumeration should be determined by QI + // When DXCore APIs are available QI for relevant enumeration interfaces + constexpr bool use_dxcore_workload_enumeration = false; + if (!use_dxcore_workload_enumeration) { + // Get a list of all the adapters that support compute + GUID attributes[]{ DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE }; + ORT_THROW_IF_FAILED( + adapter_factory->CreateAdapterList(_countof(attributes), + attributes, + adapter_list.GetAddressOf())); + } + + return adapter_list; +} + +static void SortDXCoreAdaptersByPreference( + IDXCoreAdapterList* adapter_list, + OrtDmlPerformancePreference preference) { + if (adapter_list->GetAdapterCount() <= 1) { + return; + } + + // DML prefers the HighPerformance adapter by default + std::array adapter_list_preferences = { + DXCoreAdapterPreference::HighPerformance + }; + + // If callers specify minimum power change the DXCore sort policy + // NOTE DXCoreAdapterPrefernce does not apply to mixed adapter lists - only to GPU lists + if (preference == OrtDmlPerformancePreference::MinimumPower) { + adapter_list_preferences[0] = DXCoreAdapterPreference::MinimumPower; + } + + ORT_THROW_IF_FAILED(adapter_list->Sort( + static_cast(adapter_list_preferences.size()), + adapter_list_preferences.data())); +} + +static std::vector FilterDXCoreAdapters( + IDXCoreAdapterList* adapter_list, + OrtDmlDeviceFilter filter) { + auto adapter_infos = std::vector(); + const uint32_t count = adapter_list->GetAdapterCount(); + for (uint32_t i = 0; i < count; ++i) { + ComPtr candidate_adapter; + ORT_THROW_IF_FAILED(adapter_list->GetAdapter(i, candidate_adapter.GetAddressOf())); + + // Add the adapters that are valid based on the device filter (GPU, NPU, or Both) + auto adapter_type = FilterAdapterTypeQuery(candidate_adapter.Get(), filter); + if (adapter_type != DeviceType::BadDevice) { + adapter_infos.push_back(AdapterInfo{candidate_adapter, adapter_type}); + } + } + + return adapter_infos; +} + +static void SortHeterogenousDXCoreAdapterList( + std::vector& adapter_infos, + OrtDmlDeviceFilter filter, + OrtDmlPerformancePreference preference) { + if (adapter_infos.size() <= 1) { + return; + } + + // When considering both GPUs and NPUs sort them by performance preference + // of Default (Gpus first), HighPerformance (GPUs first), or LowPower (NPUs first) + auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu; + auto only_npus = filter == OrtDmlDeviceFilter::Npu; + if (!keep_npus || only_npus) { + return; + } + + struct SortingPolicy { + // default is false because GPUs are considered higher priority in + // a mixed adapter environment + bool npus_first_ = false; + + SortingPolicy(bool npus_first = false) : npus_first_(npus_first) { } + + bool operator()(const AdapterInfo& a, const AdapterInfo& b) { + return npus_first_ ? a.Type < b.Type : a.Type > b.Type; + } + }; + + auto npus_first = (preference == OrtDmlPerformancePreference::MinimumPower); + auto policy = SortingPolicy(npus_first); + std::sort(adapter_infos.begin(), adapter_infos.end(), policy); +} + std::shared_ptr DMLProviderFactoryCreator::Create(int device_id) { return Create(device_id, /*skip_software_device_check*/ false); } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { +std::shared_ptr DMLProviderFactoryCreator::CreateFromOptions( + OrtDmlDeviceOptions* device_options) { + auto default_device_options = OrtDmlDeviceOptions { Default, Gpu }; + if (device_options == nullptr) { + device_options = &default_device_options; + } + + OrtDmlPerformancePreference preference = device_options->Preference; + OrtDmlDeviceFilter filter = device_options->Filter; + + // Create DXCore Adapter Factory + ComPtr adapter_factory; + ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf())); + + // Get all DML compatible DXCore adapters + ComPtr adapter_list; + adapter_list = EnumerateDXCoreAdapters(adapter_factory.Get()); + + if (adapter_list->GetAdapterCount() == 0) { + ORT_THROW("No GPUs or NPUs detected."); + } + + // Sort the adapter list to honor DXCore hardware ordering + SortDXCoreAdaptersByPreference(adapter_list.Get(), preference); + + // TODO: use_dxcore_workload_enumeration should be determined by QI + // When DXCore APIs are available QI for relevant enumeration interfaces + constexpr bool use_dxcore_workload_enumeration = false; + + std::vector adapter_infos; + if (!use_dxcore_workload_enumeration) { + // Filter all DXCore adapters to hardware type specified by the device filter + adapter_infos = FilterDXCoreAdapters(adapter_list.Get(), filter); + if (adapter_infos.size() == 0) { + ORT_THROW("No devices detected that match the filter criteria."); + } + } + + // DXCore Sort ignores NPUs. When both GPUs and NPUs are present, manually sort them. + SortHeterogenousDXCoreAdapterList(adapter_infos, filter, preference); + + // Extract just the adapters + auto adapters = std::vector>(adapter_infos.size()); + std::transform( + adapter_infos.begin(), adapter_infos.end(), + adapters.begin(), + [](auto& a){ return a.Adapter; }); + + return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters)); +} + +static std::optional ParsePerformancePreference(const ProviderOptions& provider_options) { + static const std::string PerformancePreference = "performance_preference"; + static const std::string Default = "default"; + static const std::string HighPerformance = "high_performance"; + static const std::string MinimumPower = "minimum_power"; + + auto preference_it = provider_options.find(PerformancePreference); + if (preference_it != provider_options.end()) { + if (preference_it->second == Default) { + return OrtDmlPerformancePreference::Default; + } + + if (preference_it->second == HighPerformance) { + return OrtDmlPerformancePreference::HighPerformance; + } + + if (preference_it->second == MinimumPower) { + return OrtDmlPerformancePreference::MinimumPower; + } + + ORT_THROW("Invalid PerformancePreference provided for DirectML EP device selection."); + } + + return {}; +} + +static std::optional ParseFilter(const ProviderOptions& provider_options) { + static const std::string Filter = "filter"; + static const std::string Any = "any"; + static const std::string Gpu = "gpu"; + static const std::string Npu = "npu"; + + auto preference_it = provider_options.find(Filter); + if (preference_it != provider_options.end()) { + if (preference_it->second == Any) { + return OrtDmlDeviceFilter::Any; + } + + if (preference_it->second == Gpu) { + return OrtDmlDeviceFilter::Gpu; + } + + if (preference_it->second == Npu) { + return OrtDmlDeviceFilter::Npu; + } + + ORT_THROW("Invalid Filter provided for DirectML EP device selection."); + } + + return {}; +} + +static std::optional ParseDeviceId(const ProviderOptions& provider_options) { + static const std::string DeviceId = "device_id"; + + auto preference_it = provider_options.find(DeviceId); + if (preference_it != provider_options.end()) { + if (!preference_it->second.empty()) { + return std::stoi(preference_it->second); + } + } + + return {}; +} + +std::shared_ptr DMLProviderFactoryCreator::CreateFromProviderOptions( + const ProviderOptions& provider_options) { + auto device_id = ParseDeviceId(provider_options); + if (device_id.has_value()) + { + return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value()); + } + + auto preference = ParsePerformancePreference(provider_options); + auto filter = ParseFilter(provider_options); + + // If no preference/filters are specified then create with default preference/filters. + if (!preference.has_value() && !filter.has_value()) { + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr); + } + + if (!preference.has_value()) { + preference = OrtDmlPerformancePreference::Default; + } + + if (!filter.has_value()) { + filter = OrtDmlDeviceFilter::Gpu; + } + + OrtDmlDeviceOptions device_options; + device_options.Preference = preference.value(); + device_options.Filter = filter.value(); + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options); +} + +Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device( + int device_id, + bool skip_software_device_check) { #ifdef _GAMING_XBOX ComPtr d3d12_device; D3D12XBOX_CREATE_DEVICE_PARAMETERS params = {}; @@ -112,30 +402,27 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int ComPtr adapter; ORT_THROW_IF_FAILED(dxgi_factory->EnumAdapters1(device_id, &adapter)); - // Disallow using DML with the software adapter (Microsoft Basic Display Adapter) because CPU evaluations are much - // faster. Some scenarios though call for EP initialization without this check (as execution will not actually occur - // anyway) such as operation kernel registry enumeration for documentation purposes. - if (!skip_software_device_check) - { - ORT_THROW_HR_IF(ERROR_GRAPHICS_INVALID_DISPLAY_ADAPTER, IsSoftwareAdapter(adapter.Get())); - } + // Disallow using DML with the software adapter (Microsoft Basic Display Adapter) because CPU evaluations are much + // faster. Some scenarios though call for EP initialization without this check (as execution will not actually occur + // anyway) such as operation kernel registry enumeration for documentation purposes. + if (!skip_software_device_check) + { + ORT_THROW_HR_IF(ERROR_GRAPHICS_INVALID_DISPLAY_ADAPTER, IsSoftwareAdapter(adapter.Get())); + } - ComPtr d3d12_device; - ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); + ComPtr d3d12_device; + ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); #endif - D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; - cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; - cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; - - ComPtr cmd_queue; - ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); + return d3d12_device; +} +Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) { DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE; // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled #if _DEBUG && !_GAMING_XBOX - ComPtr debug_device; + Microsoft::WRL::ComPtr debug_device; (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr); @@ -144,15 +431,77 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int } #endif - ComPtr dml_device; - ORT_THROW_IF_FAILED(DMLCreateDevice1(d3d12_device.Get(), - flags, - DML_FEATURE_LEVEL_5_0, - IID_PPV_ARGS(&dml_device))); + Microsoft::WRL::ComPtr dml_device; + ORT_THROW_IF_FAILED(DMLCreateDevice1( + d3d12_device, + flags, + DML_FEATURE_LEVEL_5_0, + IID_PPV_ARGS(&dml_device))); + + return dml_device; +} + +static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_device) { + D3D12_FEATURE_DATA_FEATURE_LEVELS feature_levels = {}; + + D3D_FEATURE_LEVEL feature_levels_list[] = { + D3D_FEATURE_LEVEL_1_0_CORE, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_12_0, + D3D_FEATURE_LEVEL_12_1 + }; + + feature_levels.NumFeatureLevels = ARRAYSIZE(feature_levels_list); + feature_levels.pFeatureLevelsRequested = feature_levels_list; + ORT_THROW_IF_FAILED(d3d12_device->CheckFeatureSupport( + D3D12_FEATURE_FEATURE_LEVELS, + &feature_levels, + sizeof(feature_levels) + )); + + auto is_feature_level_1_0_core = (feature_levels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE); + if (is_feature_level_1_0_core) { + return D3D12_COMMAND_LIST_TYPE_COMPUTE; + } + + return D3D12_COMMAND_LIST_TYPE_DIRECT; +} + + +std::shared_ptr CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) { + D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; + cmd_queue_desc.Type = CalculateCommandListType(d3d12_device); + cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; + ComPtr cmd_queue; + ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); + + auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device); return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); } +std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { + ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); +} + +std::shared_ptr DMLProviderFactoryCreator::CreateFromAdapterList( + std::vector>&& adapters) { + // Choose the first device from the list since it's the highest priority + auto adapter = adapters[0]; + + auto feature_level = D3D_FEATURE_LEVEL_11_0; + if (IsNPU(adapter.Get())) { + feature_level = D3D_FEATURE_LEVEL_1_0_CORE; + } + + // Create D3D12 Device from DXCore Adapter + ComPtr d3d12_device; + ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), feature_level, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); +} + } // namespace onnxruntime // [[deprecated]] @@ -197,6 +546,17 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) { API_IMPL_END } +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) { +API_IMPL_BEGIN +#ifdef USE_DML + auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options); + // return the create function for a dxcore device + options->provider_factories.push_back(factory); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) { API_IMPL_BEGIN #ifdef USE_DML @@ -219,7 +579,8 @@ static constexpr OrtDmlApi ort_dml_api_10_to_x = { &OrtSessionOptionsAppendExecutionProviderEx_DML, &CreateGPUAllocationFromD3DResource, &FreeGPUAllocation, - &GetD3D12ResourceFromAllocation + &GetD3D12ResourceFromAllocation, + &OrtSessionOptionsAppendExecutionProvider_DML2, }; const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION { diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index 3c3f04561d383..330055d64fd30 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -5,11 +5,30 @@ #include +#include +#include +#include "core/framework/provider_options.h" #include "core/providers/providers.h" +#include +#include + +interface IDMLDevice; +struct OrtDmlDeviceOptions; + namespace onnxruntime { struct DMLProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(int device_id, bool skip_software_device_check); + + static std::shared_ptr CreateFromProviderOptions( + const ProviderOptions& provider_options_map); + static std::shared_ptr CreateFromOptions(OrtDmlDeviceOptions* device_options); + + static std::shared_ptr CreateFromAdapterList( + std::vector>&& dxcore_devices); + + static Microsoft::WRL::ComPtr CreateD3D12Device(int device_id, bool skip_software_device_check); + static Microsoft::WRL::ComPtr CreateDMLDevice(ID3D12Device* d3d12_device); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 8b32ec05719b6..2618a6a28127d 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -11,6 +11,10 @@ #include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" +#if defined(USE_DML) +#include "core/providers/dml/dml_provider_factory_creator.h" +#endif + using namespace onnxruntime; namespace { @@ -66,7 +70,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; - if (strcmp(provider_name, "QNN") == 0) { + if (strcmp(provider_name, "DML") == 0) { +#if defined(USE_DML) + options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); +#else + status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "QNN") == 0) { #if defined(USE_QNN) options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value))); #else diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 12b020f32b22f..ae31e7da9c23f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -797,18 +797,10 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML - int device_id = 0; - auto it = provider_options_map.find(type); - if (it != provider_options_map.end()) { - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - device_id = std::stoi(option.second); - } - } - } - } - return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider(); + auto cit = provider_options_map.find(type); + return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions( + cit == provider_options_map.end() ? ProviderOptions{} : cit->second) + ->CreateProvider(); #endif } else if (type == kNnapiExecutionProvider) { #if defined(USE_NNAPI) diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index d283d9df62d6a..e3031ccc02a50 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -677,9 +677,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0)); + std::unordered_map dml_options; + dml_options["performance_preference"] = "high_performance"; + dml_options["device_filter"] = "gpu"; + session_options.AppendExecutionProvider("DML", dml_options); #else - ORT_THROW("DirectML is not supported in this build\n"); + ORT_THROW("DML is not supported in this build\n"); #endif } else if (provider_name == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL