Skip to content

Commit

Permalink
Fix adapter enumeration when using RDP (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola authored Sep 11, 2020
1 parent bb434f5 commit 6f013cc
Showing 1 changed file with 37 additions and 22 deletions.
59 changes: 37 additions & 22 deletions tensorflow/core/common_runtime/dml/dml_adapter_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,25 +168,33 @@ uint64_t DmlAdapterImpl::QueryAvailableDedicatedMemory() const {
}

std::vector<DmlAdapterImpl> EnumerateAdapterImpls() {
ComPtr<IDXGIFactory4> dxgi_factory;
ComPtr<IDXGIFactory6> dxgi_factory;
DML_CHECK_SUCCEEDED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory)));

const D3D_FEATURE_LEVEL min_feature_level = D3D_FEATURE_LEVEL_11_0;
std::vector<DmlAdapterImpl> adapter_infos;

uint32_t adapter_index = 0;
ComPtr<IDXGIAdapter1> adapter;
while (SUCCEEDED(dxgi_factory->EnumAdapters1(adapter_index, &adapter))) {
while (dxgi_factory->EnumAdapterByGpuPreference(
adapter_index, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE,
IID_PPV_ARGS(&adapter)) != DXGI_ERROR_NOT_FOUND) {
DXGI_ADAPTER_DESC1 desc = {};
DML_CHECK_SUCCEEDED(adapter->GetDesc1(&desc));

// Ignore software devices like WARP and msbda
if (!(desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE)) {
HRESULT hr = D3D12CreateDevice(adapter.Get(), min_feature_level,
__uuidof(ID3D12Device), nullptr);
if (SUCCEEDED(hr)) {
adapter_infos.emplace_back(adapter.Get());
}
// Since we enumerate by performance, we can ignore everything that comes
// after the first software adapter, which includes the IDD adapters. This
// is necessary for now because IDD adapters don't have the
// DXGI_ADAPTER_FLAG_SOFTWARE flag, even though they run on software.
// TFDML #21433167
if (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) {
break;
}

HRESULT hr = D3D12CreateDevice(adapter.Get(), min_feature_level,
__uuidof(ID3D12Device), nullptr);
if (SUCCEEDED(hr)) {
adapter_infos.emplace_back(adapter.Get());
}

++adapter_index;
Expand Down Expand Up @@ -260,10 +268,10 @@ std::vector<DmlAdapterImpl> EnumerateAdapterImpls() {
DML_CHECK_SUCCEEDED(adapter_factory->CreateAdapterList(
1, &dxcore_adapter, IID_PPV_ARGS(&adapter_list)));

// Sort the adapters so that hardware is selected first (for when driver_type
// == D3D_DRIVER_TYPE_UNKNOWN)
// Sort the adapters so that performant adapters are selected first
DXCoreAdapterPreference sort_preferences[] = {
DXCoreAdapterPreference::Hardware};
DXCoreAdapterPreference::HighPerformance,
};

DML_CHECK_SUCCEEDED(adapter_list->Sort(
static_cast<uint32_t>(ABSL_ARRAYSIZE(sort_preferences)),
Expand All @@ -279,18 +287,25 @@ std::vector<DmlAdapterImpl> EnumerateAdapterImpls() {
DML_CHECK_SUCCEEDED(adapter->GetProperty(DXCoreAdapterProperty::IsHardware,
&is_hardware_adapter));

if (is_hardware_adapter) {
DmlAdapterImpl adapter_impl(adapter.Get());
// Since we enumerate by performance, we can ignore everything that comes
// after the first software adapter, which includes the IDD adapters. This
// is necessary for now because IDD adapters are considered hardware
// adapters, even though they run on software.
// TFDML #21433167
if (!is_hardware_adapter) {
break;
}

D3D_FEATURE_LEVEL feature_level = adapter_impl.IsComputeOnly()
? D3D_FEATURE_LEVEL_1_0_CORE
: D3D_FEATURE_LEVEL_11_0;
DmlAdapterImpl adapter_impl(adapter.Get());

HRESULT hr = D3D12CreateDevice(adapter.Get(), feature_level,
__uuidof(ID3D12Device), nullptr);
if (SUCCEEDED(hr)) {
adapter_infos.push_back(std::move(adapter_impl));
}
D3D_FEATURE_LEVEL feature_level = adapter_impl.IsComputeOnly()
? D3D_FEATURE_LEVEL_1_0_CORE
: D3D_FEATURE_LEVEL_11_0;

HRESULT hr = D3D12CreateDevice(adapter.Get(), feature_level,
__uuidof(ID3D12Device), nullptr);
if (SUCCEEDED(hr)) {
adapter_infos.push_back(std::move(adapter_impl));
}
}

Expand Down

0 comments on commit 6f013cc

Please sign in to comment.