Skip to content

Commit

Permalink
Ignore software adapters from DML device enumeration and use DXGI enu…
Browse files Browse the repository at this point in the history
…meration instead of dxcore (#377)
  • Loading branch information
PatriceVignola authored May 1, 2024
1 parent a652001 commit d7d82a3
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 21 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ if(USE_DML)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${WIL_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>/directx)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib dxcore.lib dxguid.lib)
target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib dxcore.lib dxguid.lib)
target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib dxcore.lib dxguid.lib dxgi.lib)
target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib dxcore.lib dxguid.lib dxgi.lib)

get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps ABSOLUTE)
set(DXC_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.Direct3D.DXC.1.7.2308.12)
Expand Down
2 changes: 2 additions & 0 deletions src/dml/dml_adapter_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
enum class VendorID {
Undefined = 0,
Intel = 0x8086,
Microsoft = 0x1414,
};

// Retrieves information from a DXCore or DXGI adapter.
Expand All @@ -27,4 +28,5 @@ class AdapterInfo {
void Initialize(IDXCoreAdapter* adapter);

::VendorID vendor_id_;
uint32_t device_id_;
};
86 changes: 67 additions & 19 deletions src/dml/dml_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,80 @@
#include <stdexcept>
#include <dxcore.h>
#include <dxcore_interface.h>
#include <dxgi1_6.h>
#include "dml_helpers.h"
#include "dml_adapter_info.h"

namespace DmlHelpers {

static ComPtr<IDXCoreAdapter> CreatePerformantAdapter() {
ComPtr<IDXCoreAdapterFactory> adapter_factory;
THROW_IF_FAILED(DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf()));

ComPtr<IDXCoreAdapterList> adapter_list;
THROW_IF_FAILED(adapter_factory->CreateAdapterList(
1,
&DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE,
adapter_list.GetAddressOf()));

// We prefer the hightest performance adapter
std::array<DXCoreAdapterPreference, 1> adapter_list_preferences = {DXCoreAdapterPreference::HighPerformance};

THROW_IF_FAILED(adapter_list->Sort(
static_cast<uint32_t>(adapter_list_preferences.size()),
adapter_list_preferences.data()));
static bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
DXGI_ADAPTER_DESC1 desc = {};
THROW_IF_FAILED(adapter->GetDesc1(&desc));

// See here for documentation on filtering WARP adapter:
// https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8
const bool is_basic_render_driver_vendor_id = desc.VendorId == static_cast<UINT>(VendorID::Microsoft);
const bool is_basic_render_driver_device_id = desc.DeviceId == 0x8c;
return desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE || (is_basic_render_driver_vendor_id && is_basic_render_driver_device_id);
};

static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
ComPtr<IDXGIFactory4> dxgi_factory;
THROW_IF_FAILED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory)));

std::vector<ComPtr<IDXGIAdapter1>> adapter_infos;

ComPtr<IDXGIFactory6> dxgi_factory6;
if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6))) {
// Enumerate adapters by performance. This only works in Windows 10 Version 1803 and later.
ComPtr<IDXGIAdapter1> adapter;
for (uint32_t adapter_index = 0;
dxgi_factory6->EnumAdapterByGpuPreference(
adapter_index,
DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE,
IID_PPV_ARGS(&adapter)) != DXGI_ERROR_NOT_FOUND;
adapter_index++) {
// Since we enumerate by performance, we can ignore everything that comes after the first software adapter, which includes the IDD
// adapters. This is necessary for now because IDD (e.g. remote desktop) adapters don't have the DXGI_ADAPTER_FLAG_SOFTWARE flag,
// even though they run on software.
if (IsSoftwareAdapter(adapter.Get())) {
break;
}

// Make sure that we are able to create the device
ComPtr<ID3D12Device> d3d12_device;
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device)));

if (d3d12_device) {
adapter_infos.emplace_back(std::move(adapter));
}
}
} else {
// Enumerate adapters without ordering.
ComPtr<IDXGIAdapter1> adapter;
for (uint32_t adapter_index = 0; dxgi_factory->EnumAdapters1(adapter_index, &adapter) != DXGI_ERROR_NOT_FOUND; adapter_index++) {
// We can't assume the ordering of hardware and software adapters, so keep looping. This path should only execute on Windows 10
// version 1709 or earlier; IDD (e.g. remote desktop) adapters do not exist when taking this code path.
if (IsSoftwareAdapter(adapter.Get())) {
continue;
}

// Make sure that we are able to create the device
ComPtr<ID3D12Device> d3d12_device;
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device)));

if (d3d12_device) {
adapter_infos.emplace_back(std::move(adapter));
}
}
}

ComPtr<IDXCoreAdapter> performant_adapter;
THROW_IF_FAILED(adapter_list->GetAdapter(0, performant_adapter.GetAddressOf()));
return adapter_infos;
}

return performant_adapter;
static ComPtr<IDXGIAdapter1> CreatePerformantAdapter() {
auto filtered_adapters = EnumerateAdapters();
return filtered_adapters.front();
}

DmlObjects CreateDmlObjects() {
Expand Down

0 comments on commit d7d82a3

Please sign in to comment.