From 73f842552a6d3f8d3c461b6066e4bb2ae52c7c11 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 29 Apr 2024 18:26:33 -0700 Subject: [PATCH] Use the most performant adapter for DML (#333) --- CMakeLists.txt | 4 ++-- src/dml/dml_helpers.cpp | 26 +++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index acf8f22f6..064bef25c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,8 +150,8 @@ if(USE_DML) target_include_directories(onnxruntime-genai-static PUBLIC $) target_include_directories(onnxruntime-genai-static PUBLIC $/directx) target_include_directories(onnxruntime-genai-static PUBLIC $) - target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib) - target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib) + target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib dxcore.lib dxguid.lib) + target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib dxcore.lib dxguid.lib) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps ABSOLUTE) set(DXC_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.Direct3D.DXC.1.7.2308.12) diff --git a/src/dml/dml_helpers.cpp b/src/dml/dml_helpers.cpp index e7a0c2f08..2dcd0267f 100644 --- a/src/dml/dml_helpers.cpp +++ b/src/dml/dml_helpers.cpp @@ -9,6 +9,29 @@ namespace DmlHelpers { +static ComPtr CreatePerformantAdapter() { + ComPtr adapter_factory; + THROW_IF_FAILED(DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf())); + + ComPtr adapter_list; + THROW_IF_FAILED(adapter_factory->CreateAdapterList( + 1, + &DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE, + adapter_list.GetAddressOf())); + + // We prefer the hightest performance adapter + std::array adapter_list_preferences = {DXCoreAdapterPreference::HighPerformance}; + + THROW_IF_FAILED(adapter_list->Sort( + static_cast(adapter_list_preferences.size()), + adapter_list_preferences.data())); + + ComPtr performant_adapter; + THROW_IF_FAILED(adapter_list->GetAdapter(0, performant_adapter.GetAddressOf())); + + return performant_adapter; +} + DmlObjects CreateDmlObjects() { D3D12_COMMAND_QUEUE_DESC command_queue_description = { D3D12_COMMAND_LIST_TYPE_COMPUTE, @@ -19,7 +42,8 @@ DmlObjects CreateDmlObjects() { DmlObjects dml_objects; - THROW_IF_FAILED(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device))); + auto adapter = CreatePerformantAdapter(); + THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device))); THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandQueue(&command_queue_description, IID_PPV_ARGS(&dml_objects.command_queue))); THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&dml_objects.command_allocator))); THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, dml_objects.command_allocator.Get(), nullptr, IID_PPV_ARGS(&dml_objects.command_list)));