From 1ad199ca9606320a55322f178c52d0418ba80be8 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:12:47 -0800 Subject: [PATCH] readd npu enumeration (#18437) (#18518) [Cherry pick Reviewed] Re-add changes which were merged out... --------- ### Description ### Motivation and Context Co-authored-by: Sheil Kumar Co-authored-by: Sheil Kumar --- .../providers/dml/dml_provider_factory.cc | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index d587424fe01f8..1696763c74791 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -466,12 +466,39 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID return dml_device; } +static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_device) { + D3D12_FEATURE_DATA_FEATURE_LEVELS feature_levels = {}; + + D3D_FEATURE_LEVEL feature_levels_list[] = { + D3D_FEATURE_LEVEL_1_0_CORE, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_12_0, + D3D_FEATURE_LEVEL_12_1 + }; + + feature_levels.NumFeatureLevels = ARRAYSIZE(feature_levels_list); + feature_levels.pFeatureLevelsRequested = feature_levels_list; + ORT_THROW_IF_FAILED(d3d12_device->CheckFeatureSupport( + D3D12_FEATURE_FEATURE_LEVELS, + &feature_levels, + sizeof(feature_levels) + )); + + auto is_feature_level_1_0_core = (feature_levels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE); + if (is_feature_level_1_0_core) { + return D3D12_COMMAND_LIST_TYPE_COMPUTE; + } + + return D3D12_COMMAND_LIST_TYPE_DIRECT; +} + std::shared_ptr CreateDMLDeviceAndProviderFactory( - ID3D12Device* d3d12_device, - bool disable_metacommands, - bool enable_dynamic_graph_fusion) { + ID3D12Device* d3d12_device, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; - cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + cmd_queue_desc.Type = CalculateCommandListType(d3d12_device); cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; ComPtr cmd_queue; @@ -491,16 +518,20 @@ std::shared_ptr DMLProviderFactoryCreator::Create( } std::shared_ptr DMLProviderFactoryCreator::CreateFromAdapterList( - std::vector>&& dxcore_devices, + std::vector>&& adapters, bool disable_metacommands, bool enable_dynamic_graph_fusion) { // Choose the first device from the list since it's the highest priority - auto dxcore_device = dxcore_devices[0]; + auto adapter = adapters[0]; + + auto feature_level = D3D_FEATURE_LEVEL_11_0; + if (IsNPU(adapter.Get())) { + feature_level = D3D_FEATURE_LEVEL_1_0_CORE; + } // Create D3D12 Device from DXCore Adapter ComPtr d3d12_device; - ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); - + ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), feature_level, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion); }