Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API for NPU Device Selection in the DML EP #17612

Merged
merged 36 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
4679d85
Add API
nums11 Sep 11, 2023
552a97b
Updates
nums11 Sep 15, 2023
3625fe3
Updates
nums11 Sep 19, 2023
4d1b6ae
Remove comments
nums11 Sep 19, 2023
d3cd274
Remove new lines
nums11 Sep 19, 2023
6595e84
Use vectors
nums11 Sep 19, 2023
d8d3d8c
Change device option names
nums11 Sep 19, 2023
a4498f6
Explicitly assign flag values
nums11 Sep 19, 2023
9e1fff5
Remove extra code
nums11 Sep 19, 2023
f103283
Fix function casing
nums11 Sep 19, 2023
1ecd83c
Add IsGPU and IsNPU
nums11 Sep 19, 2023
86c9ec3
Move hardware adapter check
nums11 Sep 19, 2023
3f67ae4
Use snake case
nums11 Sep 19, 2023
bd8200c
Pass by pointer
nums11 Sep 19, 2023
9a32189
Change implementation
nums11 Sep 20, 2023
ff87aa2
Change API Name
nums11 Sep 21, 2023
1acc7c8
Use Enum instead of Enum class
nums11 Sep 21, 2023
566f309
Use statics
nums11 Sep 21, 2023
7160188
Update sorting policy
nums11 Sep 21, 2023
de6aa17
Updates
nums11 Sep 21, 2023
ebdb6b0
Link DXCore
nums11 Sep 21, 2023
4f2349f
Combine Create functions
nums11 Sep 22, 2023
142150f
Merge branch 'main' into npu_with_dml_ep
Oct 6, 2023
ffa7d7e
refactor
Oct 9, 2023
1dc12ef
extra param unused
Oct 10, 2023
1b6653a
snake case
Oct 10, 2023
41a7e4f
nullcheck device options
Oct 10, 2023
c7cdb56
&
Oct 10, 2023
02808d6
add todo
Oct 10, 2023
16f857a
Merge branch 'main' into npu_with_dml_ep
Oct 10, 2023
df4f554
non-const
Oct 10, 2023
463f3db
lint
Oct 10, 2023
76b6cf1
ifdef dml and readd missing dxcore link
Oct 10, 2023
3d4267e
DirectML->DML
Oct 10, 2023
e5bf35f
lintrunner
Oct 11, 2023
b535f4b
lint
Oct 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@
extern "C" {
#endif

/**
 * Performance preference supplied through OrtDmlDeviceOptions::perf_pref.
 * SessionOptionsAppendExecutionProvider_DML2 only consults it when the device filter is Both, where it
 * decides whether GPUs or NPUs are ranked first among the candidate adapters.
 */
enum OrtDmlPerformancePreference {
/* No explicit preference; intended to rank GPUs ahead of NPUs. */
Default = 0,
/* Intended to rank GPUs ahead of NPUs. */
HighPerformance = 1,
/* Intended to rank NPUs ahead of GPUs. */
LowPower = 2
};

/**
 * Selects which kinds of DXCore adapters SessionOptionsAppendExecutionProvider_DML2 may pick from.
 * Zero is not a valid value: the enumerators start at 1.
 */
enum OrtDmlDeviceFilter {
/* Consider only adapters classified as GPUs. */
Gpu = 1,
/* Consider only adapters classified as NPUs. */
Npu = 2,
/* Consider GPUs and NPUs; the numeric value equals Gpu | Npu. */
Both = 3
nums11 marked this conversation as resolved.
Show resolved Hide resolved
};

/**
 * Device-selection options consumed by SessionOptionsAppendExecutionProvider_DML2.
 */
struct OrtDmlDeviceOptions {
/* Ranking preference between GPUs and NPUs; only consulted when dev_filter is Both. */
OrtDmlPerformancePreference perf_pref;
/* Which adapter kinds are eligible (Gpu, Npu, or Both). */
OrtDmlDeviceFilter dev_filter;
};

/**
* [[deprecated]]
* This export is deprecated.
Expand Down Expand Up @@ -99,6 +116,13 @@
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);

/**
* SessionOptionsAppendExecutionProvider_DML2
* Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
* (default, high performance, or low power) and a device filter (GPU, NPU, or both).
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);

Check warning on line 125 in include/onnxruntime/core/providers/dml/dml_provider_factory.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/providers/dml/dml_provider_factory.h#L125

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
include/onnxruntime/core/providers/dml/dml_provider_factory.h:125:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
};

#ifdef __cplusplus
Expand Down
139 changes: 139 additions & 0 deletions onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <dxcore.h>
#include <vector>

#include <DirectML.h>
#ifndef _GAMING_XBOX
#include <dxgi1_4.h>
Expand Down Expand Up @@ -172,6 +175,41 @@
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
}

// Builds a DML execution provider factory for the given DXCore adapter.
// Steps: D3D12 device from the adapter -> direct command queue -> DML device -> provider factory.
// Any failing HRESULT along the way is raised through ORT_THROW_IF_FAILED.
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateDXCore(ComPtr<IDXCoreAdapter> dxcore_device) {
  // D3D12 device on top of the DXCore adapter; feature level 11_0 is the minimum requested.
  ComPtr<ID3D12Device> device;
  ORT_THROW_IF_FAILED(D3D12CreateDevice(
      dxcore_device.Get(),
      D3D_FEATURE_LEVEL_11_0,
      IID_GRAPHICS_PPV_ARGS(device.ReleaseAndGetAddressOf())));

  // Direct command queue with the GPU timeout disabled.
  D3D12_COMMAND_QUEUE_DESC queue_desc = {};
  queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
  queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;

  ComPtr<ID3D12CommandQueue> queue;
  ORT_THROW_IF_FAILED(device->CreateCommandQueue(
      &queue_desc,
      IID_GRAPHICS_PPV_ARGS(queue.ReleaseAndGetAddressOf())));

  DML_CREATE_DEVICE_FLAGS dml_flags = DML_CREATE_DEVICE_FLAG_NONE;

  // Debug builds only: turn on the DML debug layer when the D3D12 debug layer is active.
  // The device exposing ID3D12DebugDevice is used as the indicator for that.
#if _DEBUG && !_GAMING_XBOX
  ComPtr<ID3D12DebugDevice> debug_probe;
  (void)device->QueryInterface(IID_PPV_ARGS(&debug_probe));  // a failed query simply leaves the pointer null
  if (debug_probe != nullptr) {
    dml_flags |= DML_CREATE_DEVICE_FLAG_DEBUG;
  }
#endif

  // DML device on the D3D12 device, requiring at least DML feature level 5.0.
  ComPtr<IDMLDevice> dml;
  ORT_THROW_IF_FAILED(DMLCreateDevice1(device.Get(), dml_flags, DML_FEATURE_LEVEL_5_0, IID_PPV_ARGS(&dml)));

  return CreateExecutionProviderFactory_DML(dml.Get(), queue.Get());
}

} // namespace onnxruntime

// [[deprecated]]
Expand Down Expand Up @@ -216,6 +254,107 @@
API_IMPL_END
}

// Returns true when DXCore reports the adapter's IsHardware property as set.
// `adapter` must be non-null. A failing property query is raised as an exception.
bool IsHardwareAdapter(IDXCoreAdapter* adapter) {
  bool is_hardware = false;
  // Use ORT_THROW_IF_FAILED (not the unprefixed THROW_IF_FAILED) so this check reports failures the same
  // way as the other HRESULT checks in this file.
  ORT_THROW_IF_FAILED(adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &is_hardware));
  return is_hardware;
}

// An adapter is treated as a GPU when it is a hardware adapter that supports the D3D12 graphics attribute.
bool IsGPU(IDXCoreAdapter* compute_adapter) {
  // Short-circuit: the attribute is only queried for hardware adapters.
  return IsHardwareAdapter(compute_adapter) &&
         compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
}

// An adapter is treated as an NPU when it is a hardware adapter that does NOT support the D3D12 graphics
// attribute.
bool IsNPU(IDXCoreAdapter* compute_adapter) {
  // Short-circuit: the attribute is only queried for hardware adapters.
  return IsHardwareAdapter(compute_adapter) &&
         !compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
}
nums11 marked this conversation as resolved.
Show resolved Hide resolved

/**
 * Appends a DML execution provider to `options`, picking the adapter through DXCore.
 *
 * All DXCore adapters that support D3D12 core compute are enumerated and classified as GPU or NPU.
 * Adapters not admitted by `device_opts->dev_filter` are discarded. When the filter is Both, the
 * remaining adapters are ranked by `device_opts->perf_pref`: NPUs first for LowPower, GPUs first
 * otherwise; within one kind the DXCore enumeration order is preserved. The first adapter after
 * ranking is used to create the provider factory.
 *
 * Fails (through the exception-to-status translation of API_IMPL_BEGIN/API_IMPL_END) when
 *  - `options` or `device_opts` is null, or `dev_filter` is not Gpu, Npu or Both (E_INVALIDARG);
 *  - no adapter matches the filter (HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
 *  - any DXCore / D3D12 / DML call fails.
 */
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts) {
  API_IMPL_BEGIN
  // Both pointers are dereferenced below; report a bad argument instead of crashing on null.
  const HRESULT args_hr = (options != nullptr && device_opts != nullptr) ? S_OK : E_INVALIDARG;
  ORT_THROW_IF_FAILED(args_hr);

  const OrtDmlPerformancePreference perf_pref = device_opts->perf_pref;
  const OrtDmlDeviceFilter dev_filter = device_opts->dev_filter;

  // The filter is caller-supplied; only the three declared values are meaningful (0 is not a valid filter).
  const bool filter_is_valid = dev_filter == OrtDmlDeviceFilter::Gpu ||
                               dev_filter == OrtDmlDeviceFilter::Npu ||
                               dev_filter == OrtDmlDeviceFilter::Both;
  const HRESULT filter_hr = filter_is_valid ? S_OK : E_INVALIDARG;
  ORT_THROW_IF_FAILED(filter_hr);

  // Create DXCore Adapter Factory
  ComPtr<IDXCoreAdapterFactory> adapter_factory;
  ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf()));

  // Get a list of all the adapters that support compute
  ComPtr<IDXCoreAdapterList> compute_adapters;
  GUID attributes[]{DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE};
  ORT_THROW_IF_FAILED(
      adapter_factory->CreateAdapterList(_countof(attributes),
                                         attributes,
                                         compute_adapters.GetAddressOf()));

  const uint32_t count{compute_adapters->GetAdapterCount()};

  // Classification of an enumerated adapter; BadDevice marks adapters rejected by the filter.
  enum class DeviceType { GPU,
                          NPU,
                          BadDevice };
  struct AdapterInfo {
    ComPtr<IDXCoreAdapter> Adapter;
    DeviceType Type;  // GPU or NPU
  };

  // Maps an adapter to its type, or to BadDevice when the filter excludes it.
  const bool accept_gpus = dev_filter != OrtDmlDeviceFilter::Npu;
  const bool accept_npus = dev_filter != OrtDmlDeviceFilter::Gpu;
  auto classify = [accept_gpus, accept_npus](IDXCoreAdapter* adapter) -> DeviceType {
    if (accept_gpus && IsGPU(adapter)) {
      return DeviceType::GPU;
    }
    if (accept_npus && IsNPU(adapter)) {
      return DeviceType::NPU;
    }
    return DeviceType::BadDevice;
  };

  // Iterate through all compute capable adapters and keep the ones admitted by the filter
  std::vector<AdapterInfo> selected_adapters;
  selected_adapters.reserve(count);
  for (uint32_t i = 0; i < count; ++i) {
    ComPtr<IDXCoreAdapter> candidate;
    ORT_THROW_IF_FAILED(compute_adapters->GetAdapter(i, candidate.GetAddressOf()));

    const DeviceType type = classify(candidate.Get());
    if (type != DeviceType::BadDevice) {
      selected_adapters.push_back(AdapterInfo{candidate, type});
    }
  }

  // When considering both GPUs and NPUs, rank them by performance preference:
  // Default (GPUs first), HighPerformance (GPUs first), or LowPower (NPUs first).
  if (dev_filter == OrtDmlDeviceFilter::Both) {
    const DeviceType preferred =
        (perf_pref == OrtDmlPerformancePreference::LowPower) ? DeviceType::NPU : DeviceType::GPU;
    // "a before b" only when a is the preferred kind and b is not. stable_sort keeps the DXCore
    // enumeration order among adapters of the same kind.
    std::stable_sort(selected_adapters.begin(), selected_adapters.end(),
                     [preferred](const AdapterInfo& a, const AdapterInfo& b) {
                       return a.Type == preferred && b.Type != preferred;
                     });
  }

  // Nothing matched the filter: fail cleanly instead of indexing an empty vector.
  const HRESULT found_hr = selected_adapters.empty() ? HRESULT_FROM_WIN32(ERROR_NOT_FOUND) : S_OK;
  ORT_THROW_IF_FAILED(found_hr);

  // Create the provider factory for the highest-priority adapter
  ComPtr<IDXCoreAdapter> highest_pri_adapter = selected_adapters.front().Adapter;
  options->provider_factories.push_back(
      onnxruntime::DMLProviderFactoryCreator::CreateDXCore(highest_pri_adapter));
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
API_IMPL_BEGIN
#ifdef USE_DML
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/dml/dml_provider_factory_creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
#include "core/providers/providers.h"
#include "core/providers/dml/dml_provider_factory.h"

#include <dxcore.h>

Check warning on line 13 in onnxruntime/core/providers/dml/dml_provider_factory_creator.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory_creator.h#L13

Found C system header after other header. Should be: dml_provider_factory_creator.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory_creator.h:13:  Found C system header after other header. Should be: dml_provider_factory_creator.h, c system, c++ system, other.  [build/include_order] [4]

namespace onnxruntime {

// Static entry points for creating the DML execution provider factory.
struct DMLProviderFactoryCreator {
// Creates a factory for the adapter identified by device_id.
// NOTE(review): presumably an adapter index in enumeration order — confirm against the definition.
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
// Same as above with control over the software-device check.
// NOTE(review): exact semantics of skip_software_device_check are defined in the .cc file — confirm.
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
// Creates a factory from a DXCore adapter: builds a D3D12 device on the adapter, a direct command
// queue and a DML device, then wraps them in an execution provider factory.
static std::shared_ptr<IExecutionProviderFactory> CreateDXCore(Microsoft::WRL::ComPtr<IDXCoreAdapter> dxcore_device);
// Returns a D3D12 device for the adapter identified by device_id.
// NOTE(review): implementation not visible here — confirm failure behaviour before relying on it.
static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
};
} // namespace onnxruntime
Loading