diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc
index dc7d20913b73d..d86d5bbefe603 100644
--- a/onnxruntime/core/providers/dml/dml_provider_factory.cc
+++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc
@@ -95,10 +95,299 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
   return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId);
 }
 
+// Returns the adapter's DXCoreAdapterProperty::IsHardware value; throws if the query fails.
+static bool IsHardwareAdapter(IDXCoreAdapter* adapter) {
+  bool is_hardware = false;
+  THROW_IF_FAILED(adapter->GetProperty(
+      DXCoreAdapterProperty::IsHardware,
+      &is_hardware));
+  return is_hardware;
+}
+
+// A GPU is a hardware adapter that supports the D3D12 graphics attribute.
+static bool IsGPU(IDXCoreAdapter* compute_adapter) {
+  // Only considering hardware adapters
+  if (!IsHardwareAdapter(compute_adapter)) {
+    return false;
+  }
+  return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
+}
+
+// An NPU is a hardware adapter that does not support the D3D12 graphics attribute.
+static bool IsNPU(IDXCoreAdapter* compute_adapter) {
+  // Only considering hardware adapters
+  if (!IsHardwareAdapter(compute_adapter)) {
+    return false;
+  }
+  return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS));
+}
+
+enum class DeviceType { GPU, NPU, BadDevice };
+
+// Classifies an adapter as GPU or NPU when the filter allows that type, otherwise BadDevice.
+static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFilter filter) {
+  auto allow_gpus = (filter & OrtDmlDeviceFilter::Gpu) == OrtDmlDeviceFilter::Gpu;
+  if (IsGPU(adapter) && allow_gpus) {
+    return DeviceType::GPU;
+  }
+
+  auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
+  if (IsNPU(adapter) && allow_npus) {
+    return DeviceType::NPU;
+  }
+
+  return DeviceType::BadDevice;
+}
+
+// Struct for holding each adapter
+struct AdapterInfo {
+  ComPtr<IDXCoreAdapter> Adapter;
+  DeviceType Type;  // GPU or NPU
+};
+
+// Returns the list of DXCore adapters that support the D3D12 core compute attribute.
+static ComPtr<IDXCoreAdapterList> EnumerateDXCoreAdapters(IDXCoreAdapterFactory* adapter_factory) {
+  ComPtr<IDXCoreAdapterList> adapter_list;
+
+  // use_dxcore_workload_enumeration should be determined by QI
+  constexpr bool use_dxcore_workload_enumeration = false;
+  if (!use_dxcore_workload_enumeration) {
+    // Get a list of all the adapters that support compute
+    GUID attributes[]{ DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
+    ORT_THROW_IF_FAILED(
+        adapter_factory->CreateAdapterList(_countof(attributes),
+                                           attributes,
+                                           adapter_list.GetAddressOf()));
+  }
+
+  return adapter_list;
+}
+
+// Sorts the DXCore list in place by HighPerformance, or by MinimumPower when that is requested.
+static void SortDXCoreAdaptersByPreference(
+    IDXCoreAdapterList* adapter_list,
+    OrtDmlPerformancePreference preference) {
+  if (adapter_list->GetAdapterCount() <= 1) {
+    return;
+  }
+
+  // A single preference is passed to IDXCoreAdapterList::Sort.
+  // DML prefers the HighPerformance adapter by default
+  std::array<DXCoreAdapterPreference, 1> adapter_list_preferences = {
+      DXCoreAdapterPreference::HighPerformance
+  };
+
+  // If callers specify minimum power change the DXCore sort policy
+  // NOTE DXCoreAdapterPreference does not apply to mixed adapter lists - only to GPU lists
+  if (preference == OrtDmlPerformancePreference::MinimumPower) {
+    adapter_list_preferences[0] = DXCoreAdapterPreference::MinimumPower;
+  }
+
+  ORT_THROW_IF_FAILED(adapter_list->Sort(
+      static_cast<uint32_t>(adapter_list_preferences.size()),
+      adapter_list_preferences.data()));
+}
+
+// Collects, in list order, the adapters whose device type is permitted by the filter.
+static std::vector<AdapterInfo> FilterDXCoreAdapters(
+    IDXCoreAdapterList* adapter_list,
+    OrtDmlDeviceFilter filter) {
+  auto adapter_infos = std::vector<AdapterInfo>();
+  const uint32_t count = adapter_list->GetAdapterCount();
+  for (uint32_t i = 0; i < count; ++i) {
+    ComPtr<IDXCoreAdapter> candidate_adapter;
+    ORT_THROW_IF_FAILED(adapter_list->GetAdapter(i, candidate_adapter.GetAddressOf()));
+
+    // Add the adapters that are valid based on the device filter (GPU, NPU, or Both)
+    auto adapter_type = FilterAdapterTypeQuery(candidate_adapter.Get(), filter);
+    if (adapter_type != DeviceType::BadDevice) {
+      adapter_infos.push_back(AdapterInfo{candidate_adapter, adapter_type});
+    }
+  }
+
+  return adapter_infos;
+}
+
+// Orders a mixed GPU/NPU list by type: NPUs first for MinimumPower, GPUs first otherwise.
+static void SortHeterogenousDXCoreAdapterList(
+    std::vector<AdapterInfo>& adapter_infos,
+    OrtDmlDeviceFilter filter,
+    OrtDmlPerformancePreference preference) {
+  if (adapter_infos.size() <= 1) {
+    return;
+  }
+
+  // When considering both GPUs and NPUs sort them by performance preference
+  // of Default (Gpus first), HighPerformance (GPUs first), or LowPower (NPUs first)
+  auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
+  auto only_npus = filter == OrtDmlDeviceFilter::Npu;
+  if (!keep_npus || only_npus) {
+    return;
+  }
+
+  struct SortingPolicy {
+    // default is false because GPUs are considered higher priority in
+    // a mixed adapter environment
+    bool npus_first_ = false;
+
+    SortingPolicy(bool npus_first = false) : npus_first_(npus_first) { }
+
+    bool operator()(const AdapterInfo& a, const AdapterInfo& b) const {
+      return npus_first_ ? a.Type > b.Type : a.Type < b.Type;  // DeviceType::GPU < DeviceType::NPU
+    }
+  };
+
+  auto npus_first = (preference == OrtDmlPerformancePreference::MinimumPower);
+  auto policy = SortingPolicy(npus_first);
+  std::stable_sort(adapter_infos.begin(), adapter_infos.end(), policy);  // keep DXCore order within a type
+}
+
 std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id) {
   return Create(device_id, /*skip_software_device_check*/ false);
 }
 
+// Enumerates, sorts and filters the DXCore adapters, then builds the factory from that list.
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromOptions(
+    OrtDmlDeviceOptions* device_options) {
+  OrtDmlPerformancePreference preference = device_options->Preference;
+  OrtDmlDeviceFilter filter = device_options->Filter;
+
+  // Create DXCore Adapter Factory
+  ComPtr<IDXCoreAdapterFactory> adapter_factory;
+  ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf()));
+
+  // Get all DML compatible DXCore adapters
+  ComPtr<IDXCoreAdapterList> adapter_list;
+  adapter_list = EnumerateDXCoreAdapters(adapter_factory.Get());
+
+  if (adapter_list->GetAdapterCount() == 0) {
+    ORT_THROW("No GPUs or NPUs detected.");
+  }
+
+  // Sort the adapter list to honor DXCore hardware ordering
+  SortDXCoreAdaptersByPreference(adapter_list.Get(), preference);
+
+  std::vector<AdapterInfo> adapter_infos;
+  constexpr bool use_dxcore_workload_enumeration = false;  // should be determined by QI
+  if (!use_dxcore_workload_enumeration) {
+    // Filter all DXCore adapters to hardware type specified by the device filter
+    adapter_infos = FilterDXCoreAdapters(adapter_list.Get(), filter);
+    if (adapter_infos.size() == 0) {
+      ORT_THROW("No devices detected that match the filter criteria.");
+    }
+  }
+
+  // DXCore Sort ignores NPUs. When both GPUs and NPUs are present, manually sort them.
+  SortHeterogenousDXCoreAdapterList(adapter_infos, filter, preference);
+
+  // Extract just the adapters
+  auto adapters = std::vector<ComPtr<IDXCoreAdapter>>(adapter_infos.size());
+  std::transform(
+      adapter_infos.begin(), adapter_infos.end(),
+      adapters.begin(),
+      [](auto& a){ return a.Adapter; });
+
+  return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters));
+}
+
+// Reads "PerformancePreference"; empty optional if the key is absent, throws on an unknown value.
+static std::optional<OrtDmlPerformancePreference> ParsePerformancePreference(const ProviderOptions& provider_options)
+{
+  static const std::string PerformancePreference = "PerformancePreference";
+  static const std::string Default = "Default";
+  static const std::string HighPerformance = "HighPerformance";
+  static const std::string MinimumPower = "MinimumPower";
+
+  auto preference_it = provider_options.find(PerformancePreference);
+  if (preference_it != provider_options.end()) {
+    if (preference_it->second == Default) {
+      return OrtDmlPerformancePreference::Default;
+    }
+
+    if (preference_it->second == HighPerformance) {
+      return OrtDmlPerformancePreference::HighPerformance;
+    }
+
+    if (preference_it->second == MinimumPower) {
+      return OrtDmlPerformancePreference::MinimumPower;
+    }
+
+    ORT_THROW("Invalid PerformancePreference provided for DirectML EP device selection.");
+  }
+
+  return {};
+}
+
+// Reads "Filter"; empty optional if the key is absent, throws on an unknown value.
+static std::optional<OrtDmlDeviceFilter> ParseFilter(const ProviderOptions& provider_options)
+{
+  static const std::string Filter = "Filter";
+  static const std::string Any = "Any";
+  static const std::string Gpu = "Gpu";
+  static const std::string Npu = "Npu";
+
+  auto filter_it = provider_options.find(Filter);
+  if (filter_it != provider_options.end()) {
+    if (filter_it->second == Any) {
+      return OrtDmlDeviceFilter::Any;
+    }
+
+    if (filter_it->second == Gpu) {
+      return OrtDmlDeviceFilter::Gpu;
+    }
+
+    if (filter_it->second == Npu) {
+      return OrtDmlDeviceFilter::Npu;
+    }
+
+    ORT_THROW("Invalid Filter provided for DirectML EP device selection.");
+  }
+
+  return {};
+}
+
+// Reads "device_id" via std::stoi; empty optional if the key is absent or its value is empty.
+static std::optional<int> ParseDeviceId(const ProviderOptions& provider_options)
+{
+  static const std::string DeviceId = "device_id";
+
+  auto device_id_it = provider_options.find(DeviceId);
+  if (device_id_it != provider_options.end()) {
+    if (!device_id_it->second.empty()) {
+      return std::stoi(device_id_it->second);
+    }
+  }
+
+  return {};
+}
+
+// An explicit device_id wins; otherwise select by preference (default Default) and filter (default Gpu).
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromProviderOptions(
+    const ProviderOptions& provider_options) {
+  auto device_id = ParseDeviceId(provider_options);
+  if (device_id.has_value())
+  {
+    return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value());
+  }
+
+  auto preference = ParsePerformancePreference(provider_options);
+  if (!preference.has_value())
+  {
+    preference = OrtDmlPerformancePreference::Default;
+  }
+
+  auto filter = ParseFilter(provider_options);
+  if (!filter.has_value())
+  {
+    filter = OrtDmlDeviceFilter::Gpu;
+  }
+
+  OrtDmlDeviceOptions device_options;
+  device_options.Preference = preference.value();
+  device_options.Filter = filter.value();
+  return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options);
+}
+
 Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Device(
     int device_id, bool skip_software_device_check) {
 #ifdef _GAMING_XBOX
@@ -173,7 +462,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int
   return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
 }
 
-std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateDXCore(
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromAdapterList(
     std::vector<ComPtr<IDXCoreAdapter>>&& dxcore_devices) {
   // Choose the first device from the list since it's the highest priority
   auto dxcore_device = dxcore_devices[0];
@@ -229,136 +518,15 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) {
   API_IMPL_END
 }
 
-static bool IsHardwareAdapter(IDXCoreAdapter* adapter) {
-  bool is_hardware = false;
-  THROW_IF_FAILED(adapter->GetProperty(
-      DXCoreAdapterProperty::IsHardware,
-      &is_hardware));
-  return is_hardware;
-}
-
-static bool IsGPU(IDXCoreAdapter* compute_adapter) {
-  // Only considering hardware adapters
-  if (!IsHardwareAdapter(compute_adapter))
-    return false;
-  return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
-}
-
-static bool IsNPU(IDXCoreAdapter* compute_adapter) {
-  // Only considering hardware adapters
-  if (!IsHardwareAdapter(compute_adapter))
-    return false;
-  return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS));
-}
-
-enum class DeviceType { GPU, NPU, BadDevice };
-
-static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFilter filter)
-{
-  if ((filter & OrtDmlDeviceFilter::Gpu) == OrtDmlDeviceFilter::Gpu) {
-    return DeviceType::GPU;
-  }
-  if ((filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu) {
-    return DeviceType::NPU;
-  }
-  return DeviceType::BadDevice;
-}
-
 ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) {
 API_IMPL_BEGIN
-  OrtDmlPerformancePreference preference = device_options->Preference;
-  OrtDmlDeviceFilter filter = device_options->Filter;
-
-  // Create DXCore Adapter Factory
-  ComPtr<IDXCoreAdapterFactory> adapter_factory;
-  ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf()));
-
-  // Get a list of all the adapters that support compute
-  ComPtr<IDXCoreAdapterList> adapter_list;
-  GUID attributes[]{ DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
-  ORT_THROW_IF_FAILED(
-      adapter_factory->CreateAdapterList(_countof(attributes),
-                                         attributes,
-                                         adapter_list.GetAddressOf()));
-
-  // Move this to another static helper... add coments about the use case...
-  // DML prefers the HighPerformance adapter by default
-  std::array<DXCoreAdapterPreference, 1> adapter_list_preferences = {
-      DXCoreAdapterPreference::HighPerformance
-  };
-
-  // If callers specify minimum power change the DXCore sort policy
-  // NOTE DXCoreAdapterPrefernce does not apply to mixed adapter lists - only to GPU lists
-  if (preference == OrtDmlPerformancePreference::MinimumPower)
-  {
-    adapter_list_preferences[0] = DXCoreAdapterPreference::MinimumPower;
-  }
-
-  ORT_THROW_IF_FAILED(adapter_list->Sort(
-      static_cast<uint32_t>(adapter_list_preferences.size()),
-      adapter_list_preferences.data()));
-
-  // Struct for holding each adapter
-  struct AdapterInfo {
-    ComPtr<IDXCoreAdapter> Adapter;
-    DeviceType Type; // GPU or NPU
-  };
-
-  auto adapter_infos = std::vector<AdapterInfo>();
-  const uint32_t count = adapter_list->GetAdapterCount();
-  for (uint32_t i = 0; i < count; ++i)
-  {
-    ComPtr<IDXCoreAdapter> candidate_adapter;
-    ORT_THROW_IF_FAILED(adapter_list->GetAdapter(i, candidate_adapter.GetAddressOf()));
-
-    // Add the adapters that are valid based on the device filter (GPU, NPU, or Both)
-    auto adapter_type = FilterAdapterTypeQuery(candidate_adapter.Get(), filter);
-    if (adapter_type != DeviceType::BadDevice)
-    {
-      adapter_infos.push_back(AdapterInfo{candidate_adapter, adapter_type});
-    }
-  }
-
-  // When considering both GPUs and NPUs sort them by performance preference
-  // of Default (Gpus first), HighPerformance (GPUs first), or LowPower (NPUs first)
-  auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
-  auto only_npus = filter == OrtDmlDeviceFilter::Npu;
-  if (keep_npus && !only_npus)
-  {
-    struct SortingPolicy
-    {
-      // default is false because GPUs are considered higher priority in
-      // a mixed adapter environment
-      bool npus_first_ = false;
-
-      SortingPolicy(bool npus_first = false) : npus_first_(npus_first)
-      {
-      }
-
-      bool operator()(const AdapterInfo& a, const AdapterInfo& b)
-      {
-        return npus_first_ ? a.Type < b.Type : a.Type > b.Type;
-      }
-    };
-
-    auto npus_first = (preference == OrtDmlPerformancePreference::MinimumPower);
-    auto policy = SortingPolicy(npus_first);
-    std::sort(adapter_infos.begin(), adapter_infos.end(), policy);
-  }
-
-  // Extract just the adapters
-  std::vector<ComPtr<IDXCoreAdapter>> adapters(adapter_infos.size());
-  std::transform(
-      adapter_infos.begin(), adapter_infos.end(),
-      adapters.begin(),
-      [](auto& a){ return a.Adapter; });
-
-  // check if num adapters is 0 (no adapters?)
-  // return the create function for a dxcore device
-  options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::CreateDXCore(
-      std::move(adapters)));
+#ifdef USE_DML
+  auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options);
+  // Append the DirectML provider factory to the session options
+  options->provider_factories.push_back(factory);
+#endif  // USE_DML
   return nullptr;
-API_IMPL_END
+  API_IMPL_END
 }
 
 ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
index 1fee916c0303c..4e13330a4cd71 100644
--- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
+++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
@@ -7,6 +7,7 @@
 #include
 #include
 
+#include "core/framework/provider_options.h"
 #include "core/providers/providers.h"
 #include "core/providers/dml/dml_provider_factory.h"
 
@@ -18,8 +19,14 @@ namespace onnxruntime {
 struct DMLProviderFactoryCreator {
   static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
   static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
-  static std::shared_ptr<IExecutionProviderFactory> CreateDXCore(
+
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromProviderOptions(
+      const ProviderOptions& provider_options_map);
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromOptions(OrtDmlDeviceOptions* device_options);
+
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromAdapterList(
       std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices);
+
   static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
   static Microsoft::WRL::ComPtr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12_device);
 };
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 50c9f6681a0c8..09a619204e100 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -11,6 +11,9 @@
 #include "core/session/onnxruntime_c_api.h"
 #include "core/session/ort_apis.h"
 #include "core/providers/openvino/openvino_provider_factory_creator.h"
+#if defined(USE_DML)
+#include "core/providers/dml/dml_provider_factory_creator.h"
+#endif
 
 using namespace onnxruntime;
 
@@ -67,7 +70,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
                                  (std::string(provider_name) + " execution provider is not supported in this build. ").c_str());
   };
 
-  if (strcmp(provider_name, "QNN") == 0) {
+  if (strcmp(provider_name, "DirectML") == 0) {
+#if defined(USE_DML)
+    options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options));
+#else
+    status = create_not_supported_status();
+#endif
+  } else if (strcmp(provider_name, "QNN") == 0) {
 #if defined(USE_QNN)
     options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value)));
 #else
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 907ea0ec41e23..4fc164668a4b9 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -887,18 +887,10 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
 #endif
   } else if (type == kDmlExecutionProvider) {
 #ifdef USE_DML
-    int device_id = 0;
-    auto it = provider_options_map.find(type);
-    if (it != provider_options_map.end()) {
-      for (auto option : it->second) {
-        if (option.first == "device_id") {
-          if (!option.second.empty()) {
-            device_id = std::stoi(option.second);
-          }
-        }
-      }
-    }
-    return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider();
+    auto cit = provider_options_map.find(type);
+    return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions(
+               cit == provider_options_map.end() ? ProviderOptions{} : cit->second)
+        ->CreateProvider();
 #endif
   } else if (type == kNnapiExecutionProvider) {
 #if defined(USE_NNAPI)
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 57b2403e23a37..75e250771bf44 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -662,7 +662,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
 #endif
   } else if (provider_name == onnxruntime::kDmlExecutionProvider) {
 #ifdef USE_DML
-    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));
+    std::unordered_map<std::string, std::string> dml_options;
+    dml_options["PerformancePreference"] = "HighPerformance";
+    dml_options["Filter"] = "Gpu";
+    session_options.AppendExecutionProvider("DirectML", dml_options);
 #else
     ORT_THROW("DirectML is not supported in this build\n");
 #endif