Skip to content

Commit

Permalink
Use the most performant adapter for DML (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola authored Apr 30, 2024
1 parent 2f619f9 commit 639d176
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ if(USE_DML)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${WIL_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>/directx)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib)
target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib)
target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib dxcore.lib dxguid.lib)
target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib dxcore.lib dxguid.lib)

get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps ABSOLUTE)
set(DXC_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.Direct3D.DXC.1.7.2308.12)
Expand Down
26 changes: 25 additions & 1 deletion src/dml/dml_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,29 @@

namespace DmlHelpers {

static ComPtr<IDXCoreAdapter> CreatePerformantAdapter() {
ComPtr<IDXCoreAdapterFactory> adapter_factory;
THROW_IF_FAILED(DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf()));

ComPtr<IDXCoreAdapterList> adapter_list;
THROW_IF_FAILED(adapter_factory->CreateAdapterList(
1,
&DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE,
adapter_list.GetAddressOf()));

// We prefer the hightest performance adapter
std::array<DXCoreAdapterPreference, 1> adapter_list_preferences = {DXCoreAdapterPreference::HighPerformance};

THROW_IF_FAILED(adapter_list->Sort(
static_cast<uint32_t>(adapter_list_preferences.size()),
adapter_list_preferences.data()));

ComPtr<IDXCoreAdapter> performant_adapter;
THROW_IF_FAILED(adapter_list->GetAdapter(0, performant_adapter.GetAddressOf()));

return performant_adapter;
}

DmlObjects CreateDmlObjects() {
D3D12_COMMAND_QUEUE_DESC command_queue_description = {
D3D12_COMMAND_LIST_TYPE_COMPUTE,
Expand All @@ -19,7 +42,8 @@ DmlObjects CreateDmlObjects() {

DmlObjects dml_objects;

THROW_IF_FAILED(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device)));
auto adapter = CreatePerformantAdapter();
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device)));
THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandQueue(&command_queue_description, IID_PPV_ARGS(&dml_objects.command_queue)));
THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&dml_objects.command_allocator)));
THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, dml_objects.command_allocator.Get(), nullptr, IID_PPV_ARGS(&dml_objects.command_list)));
Expand Down

0 comments on commit 639d176

Please sign in to comment.