From 6f013cc78da6de2bad4e787ef633348203f6b1a0 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 11 Sep 2020 14:25:02 -0700 Subject: [PATCH] Fix adapter enumeration when using RDP (#23) --- .../common_runtime/dml/dml_adapter_impl.cc | 59 ++++++++++++------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/common_runtime/dml/dml_adapter_impl.cc b/tensorflow/core/common_runtime/dml/dml_adapter_impl.cc index e05350c469d..9d5387ae6ac 100644 --- a/tensorflow/core/common_runtime/dml/dml_adapter_impl.cc +++ b/tensorflow/core/common_runtime/dml/dml_adapter_impl.cc @@ -168,7 +168,7 @@ uint64_t DmlAdapterImpl::QueryAvailableDedicatedMemory() const { } std::vector EnumerateAdapterImpls() { - ComPtr dxgi_factory; + ComPtr dxgi_factory; DML_CHECK_SUCCEEDED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory))); const D3D_FEATURE_LEVEL min_feature_level = D3D_FEATURE_LEVEL_11_0; @@ -176,17 +176,25 @@ std::vector EnumerateAdapterImpls() { uint32_t adapter_index = 0; ComPtr adapter; - while (SUCCEEDED(dxgi_factory->EnumAdapters1(adapter_index, &adapter))) { + while (dxgi_factory->EnumAdapterByGpuPreference( + adapter_index, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, + IID_PPV_ARGS(&adapter)) != DXGI_ERROR_NOT_FOUND) { DXGI_ADAPTER_DESC1 desc = {}; DML_CHECK_SUCCEEDED(adapter->GetDesc1(&desc)); - // Ignore software devices like WARP and msbda - if (!(desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE)) { - HRESULT hr = D3D12CreateDevice(adapter.Get(), min_feature_level, - __uuidof(ID3D12Device), nullptr); - if (SUCCEEDED(hr)) { - adapter_infos.emplace_back(adapter.Get()); - } + // Since we enumerate by performance, we can ignore everything that comes + // after the first software adapter, which includes the IDD adapters. This + // is necessary for now because IDD adapters don't have the + // DXGI_ADAPTER_FLAG_SOFTWARE flag, even though they run on software. + // TFDML #21433167 + if (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) { + break; + } + + HRESULT hr = D3D12CreateDevice(adapter.Get(), min_feature_level, + __uuidof(ID3D12Device), nullptr); + if (SUCCEEDED(hr)) { + adapter_infos.emplace_back(adapter.Get()); } ++adapter_index; @@ -260,10 +268,10 @@ std::vector EnumerateAdapterImpls() { DML_CHECK_SUCCEEDED(adapter_factory->CreateAdapterList( 1, &dxcore_adapter, IID_PPV_ARGS(&adapter_list))); - // Sort the adapters so that hardware is selected first (for when driver_type - // == D3D_DRIVER_TYPE_UNKNOWN) + // Sort the adapters so that performant adapters are selected first DXCoreAdapterPreference sort_preferences[] = { - DXCoreAdapterPreference::Hardware}; + DXCoreAdapterPreference::HighPerformance, + }; DML_CHECK_SUCCEEDED(adapter_list->Sort( static_cast(ABSL_ARRAYSIZE(sort_preferences)), @@ -279,18 +287,25 @@ std::vector EnumerateAdapterImpls() { DML_CHECK_SUCCEEDED(adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &is_hardware_adapter)); - if (is_hardware_adapter) { - DmlAdapterImpl adapter_impl(adapter.Get()); + // Since we enumerate by performance, we can ignore everything that comes + // after the first software adapter, which includes the IDD adapters. This + // is necessary for now because IDD adapters are considered hardware + // adapters, even though they run on software. + // TFDML #21433167 + if (!is_hardware_adapter) { + break; + } - D3D_FEATURE_LEVEL feature_level = adapter_impl.IsComputeOnly() - ? D3D_FEATURE_LEVEL_1_0_CORE - : D3D_FEATURE_LEVEL_11_0; + DmlAdapterImpl adapter_impl(adapter.Get()); - HRESULT hr = D3D12CreateDevice(adapter.Get(), feature_level, - __uuidof(ID3D12Device), nullptr); - if (SUCCEEDED(hr)) { - adapter_infos.push_back(std::move(adapter_impl)); - } + D3D_FEATURE_LEVEL feature_level = adapter_impl.IsComputeOnly() + ? D3D_FEATURE_LEVEL_1_0_CORE + : D3D_FEATURE_LEVEL_11_0; + + HRESULT hr = D3D12CreateDevice(adapter.Get(), feature_level, + __uuidof(ID3D12Device), nullptr); + if (SUCCEEDED(hr)) { + adapter_infos.push_back(std::move(adapter_impl)); } }