Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API for NPU Device Selection in the DML EP #17612

Merged
merged 36 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
4679d85
Add API
nums11 Sep 11, 2023
552a97b
Updates
nums11 Sep 15, 2023
3625fe3
Updates
nums11 Sep 19, 2023
4d1b6ae
Remove comments
nums11 Sep 19, 2023
d3cd274
Remove new lines
nums11 Sep 19, 2023
6595e84
Use vectors
nums11 Sep 19, 2023
d8d3d8c
Change device option names
nums11 Sep 19, 2023
a4498f6
Explicitly assign flag values
nums11 Sep 19, 2023
9e1fff5
Remove extra code
nums11 Sep 19, 2023
f103283
Fix function casing
nums11 Sep 19, 2023
1ecd83c
Add IsGPU and IsNPU
nums11 Sep 19, 2023
86c9ec3
Move hardware adapter check
nums11 Sep 19, 2023
3f67ae4
Use snake case
nums11 Sep 19, 2023
bd8200c
Pass by pointer
nums11 Sep 19, 2023
9a32189
Change implementation
nums11 Sep 20, 2023
ff87aa2
Change API Name
nums11 Sep 21, 2023
1acc7c8
Use Enum instead of Enum class
nums11 Sep 21, 2023
566f309
Use statics
nums11 Sep 21, 2023
7160188
Update sorting policy
nums11 Sep 21, 2023
de6aa17
Updates
nums11 Sep 21, 2023
ebdb6b0
Link DXCore
nums11 Sep 21, 2023
4f2349f
Combine Create functions
nums11 Sep 22, 2023
142150f
Merge branch 'main' into npu_with_dml_ep
Oct 6, 2023
ffa7d7e
refactor
Oct 9, 2023
1dc12ef
extra param unused
Oct 10, 2023
1b6653a
snake case
Oct 10, 2023
41a7e4f
nullcheck device options
Oct 10, 2023
c7cdb56
&
Oct 10, 2023
02808d6
add todo
Oct 10, 2023
16f857a
Merge branch 'main' into npu_with_dml_ep
Oct 10, 2023
df4f554
non-const
Oct 10, 2023
463f3db
lint
Oct 10, 2023
76b6cf1
ifdef dml and readd missing dxcore link
Oct 10, 2023
3d4267e
DirectML->DML
Oct 10, 2023
e5bf35f
lintrunner
Oct 11, 2023
b535f4b
lint
Oct 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@
extern "C" {
#endif

enum class OrtDmlPerformancePreference {
nums11 marked this conversation as resolved.
Show resolved Hide resolved
Default,
nums11 marked this conversation as resolved.
Show resolved Hide resolved
HighPerformance, // Default
LowPower
};

enum class OrtDmlDeviceFilter {
nums11 marked this conversation as resolved.
Show resolved Hide resolved
nums11 marked this conversation as resolved.
Show resolved Hide resolved
None,
Gpu, // Default
Npu
};

struct OrtDmlDeviceOptions {
OrtDmlPerformancePreference p;
nums11 marked this conversation as resolved.
Show resolved Hide resolved
OrtDmlDeviceFilter f;
};

/**
* [[deprecated]]
* This export is deprecated.
Expand Down Expand Up @@ -82,6 +99,7 @@
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML1, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);


/**
nums11 marked this conversation as resolved.
Show resolved Hide resolved
* CreateGPUAllocationFromD3DResource
* This API creates a DML EP resource based on a user-specified D3D12 resource.
Expand All @@ -94,11 +112,37 @@
*/
ORT_API2_STATUS(FreeGPUAllocation, _In_ void* dml_resource);


/**
nums11 marked this conversation as resolved.
Show resolved Hide resolved
* GetD3D12ResourceFromAllocation
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);


//ORT_API2_STATUS(FreeGPUAllocation_2, _In_ void* dml_resource);

Check warning on line 123 in include/onnxruntime/core/providers/dml/dml_provider_factory.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/providers/dml/dml_provider_factory.h#L123

Should have a space between // and comment [whitespace/comments] [4]
Raw output
include/onnxruntime/core/providers/dml/dml_provider_factory.h:123:  Should have a space between // and comment  [whitespace/comments] [4]

//ORT_API2_STATUS(GetNPUDevice, _Out_ int* device_id);

//ORT_API2_STATUS(GetMLDevice, _In_ bool npu_first, _Out_ int* device_ids);

// enum class Ordering {
// SAME, // DXGI/DXCore as is today...
// NONE, // Take the first one
// GPUS_FIRST,
// NPUS_FIRST
// };

// Ordering foo[3][3] = {
// // Default, HighPerformance, LowPower
// /*None*/ { Ordering::GPUS_FIRST, Ordering::GPUS_FIRST, Ordering::NPUS_FIRST },
// /*Gpu*/ { Ordering::SAME, Ordering::SAME , Ordering::SAME },
// /*Npu*/ { Ordering::NONE, Ordering::NONE , Ordering::NONE },

// }
nums11 marked this conversation as resolved.
Show resolved Hide resolved

// null means default
ORT_API2_STATUS(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
nums11 marked this conversation as resolved.
Show resolved Hide resolved
};

#ifdef __cplusplus
Expand Down
174 changes: 173 additions & 1 deletion onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <dxcore.h>

#include <DirectML.h>
#ifndef _GAMING_XBOX
#include <dxgi1_4.h>
Expand Down Expand Up @@ -172,6 +174,41 @@
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
}

std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateDXCore(ComPtr<IDXCoreAdapter> dxcore_device) {
// Create D3D12 Device from DXCore Adapter
ComPtr<ID3D12Device> d3d12_device;
ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));

// Create DML Device from D3D12 device
D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;

ComPtr<ID3D12CommandQueue> cmd_queue;
ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));

DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE;

// In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled
#if _DEBUG && !_GAMING_XBOX
ComPtr<ID3D12DebugDevice> debug_device;
(void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure
const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr);

if (is_d3d12_debug_layer_enabled) {
flags |= DML_CREATE_DEVICE_FLAG_DEBUG;
}
#endif

ComPtr<IDMLDevice> dml_device;
ORT_THROW_IF_FAILED(DMLCreateDevice1(d3d12_device.Get(),
flags,
DML_FEATURE_LEVEL_5_0,
IID_PPV_ARGS(&dml_device)));

return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
nums11 marked this conversation as resolved.
Show resolved Hide resolved
}

} // namespace onnxruntime

// [[deprecated]]
Expand All @@ -196,6 +233,18 @@
return nullptr;
}

bool isHardwareAdapter(ComPtr<IDXCoreAdapter> adapter) {
nums11 marked this conversation as resolved.
Show resolved Hide resolved
bool isHardware{ false };
nums11 marked this conversation as resolved.
Show resolved Hide resolved
THROW_IF_FAILED(adapter->GetProperty(
DXCoreAdapterProperty::IsHardware,
&isHardware));
return isHardware;
}

bool supportsGraphics(ComPtr<IDXCoreAdapter> adapter) {
nums11 marked this conversation as resolved.
Show resolved Hide resolved
nums11 marked this conversation as resolved.
Show resolved Hide resolved
nums11 marked this conversation as resolved.
Show resolved Hide resolved
return adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
}

ORT_API_STATUS_IMPL(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource) {
API_IMPL_BEGIN
#ifdef USE_DML
Expand All @@ -216,6 +265,128 @@
API_IMPL_END
}

ORT_API_STATUS_IMPL(FreeGPUAllocation_2, _In_ void* ptr) {
API_IMPL_BEGIN
#ifdef USE_DML
Dml::FreeGPUAllocation(ptr);
#endif // USE_DML
return nullptr;
API_IMPL_END

Check warning on line 274 in onnxruntime/core/providers/dml/dml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory.cc#L274

Should have a space between // and comment [whitespace/comments] [4]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory.cc:274:  Should have a space between // and comment  [whitespace/comments] [4]
}

Check warning on line 275 in onnxruntime/core/providers/dml/dml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory.cc#L275

Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory.cc:275:  Line ends in whitespace.  Consider deleting these extra spaces.  [whitespace/end_of_line] [4]
nums11 marked this conversation as resolved.
Show resolved Hide resolved

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts) {
API_IMPL_BEGIN
//options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id));

OrtDmlPerformancePreference p = device_opts->p;
OrtDmlDeviceFilter f = device_opts->f;

// Create DXCore Adapter Factory
ComPtr<IDXCoreAdapterFactory> adapterFactory;
nums11 marked this conversation as resolved.
Show resolved Hide resolved
ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapterFactory.GetAddressOf()));

// Get a list of all the adapters that support compute
ComPtr<IDXCoreAdapterList> d3D12CoreComputeAdapters;
nums11 marked this conversation as resolved.
Show resolved Hide resolved
GUID attributes[]{ DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
ORT_THROW_IF_FAILED(
adapterFactory->CreateAdapterList(_countof(attributes),
attributes,
d3D12CoreComputeAdapters.GetAddressOf()));

const uint32_t count{ d3D12CoreComputeAdapters->GetAdapterCount() };
int compute_only_device_id = -1;

// will hold all the adapters in the order specified by the device options
ComPtr<IDXCoreAdapter>* ordered_adapters =
(ComPtr<IDXCoreAdapter>*) malloc(count * sizeof(ComPtr<IDXCoreAdapter>));
nums11 marked this conversation as resolved.
Show resolved Hide resolved
int ordered_adapter_count = 0;

// Used in the case that both GPU and NPU adapters are considered
ComPtr<IDXCoreAdapter>* gpu_adapters =
(ComPtr<IDXCoreAdapter>*) malloc(count * sizeof(ComPtr<IDXCoreAdapter>));
int gpu_adapter_count = 0;
ComPtr<IDXCoreAdapter>* npu_adapters =
(ComPtr<IDXCoreAdapter>*) malloc(count * sizeof(ComPtr<IDXCoreAdapter>));
nums11 marked this conversation as resolved.
Show resolved Hide resolved
int npu_adapter_count = 0;

// Iterate through all compute capable adapters
for (uint32_t i = 0; i < count; ++i)
{
ComPtr<IDXCoreAdapter> candidateAdapter;
nums11 marked this conversation as resolved.
Show resolved Hide resolved
ORT_THROW_IF_FAILED(
d3D12CoreComputeAdapters->GetAdapter(i, candidateAdapter.GetAddressOf()));

Check warning on line 318 in onnxruntime/core/providers/dml/dml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory.cc#L318

At least two spaces is best between code and comments [whitespace/comments] [2]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory.cc:318:  At least two spaces is best between code and comments  [whitespace/comments] [2]
// Only considering hardware adapters
if (!isHardwareAdapter(candidateAdapter))
continue;

nums11 marked this conversation as resolved.
Show resolved Hide resolved
if (f == OrtDmlDeviceFilter::Gpu) // consider GPUs only
{
if (supportsGraphics(candidateAdapter)) {

Check warning on line 325 in onnxruntime/core/providers/dml/dml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory.cc#L325

At least two spaces is best between code and comments [whitespace/comments] [2]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory.cc:325:  At least two spaces is best between code and comments  [whitespace/comments] [2]
ordered_adapters[ordered_adapter_count] = candidateAdapter;
ordered_adapter_count++;
}
}
else if(f == OrtDmlDeviceFilter::Npu) // consider NPUs only
{
if (!supportsGraphics(candidateAdapter)) {

Check warning on line 332 in onnxruntime/core/providers/dml/dml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory.cc#L332

At least two spaces is best between code and comments [whitespace/comments] [2]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory.cc:332:  At least two spaces is best between code and comments  [whitespace/comments] [2]
ordered_adapters[ordered_adapter_count] = candidateAdapter;
ordered_adapter_count++;
}
}
else // consider both GPUs and NPUs
{
if (supportsGraphics(candidateAdapter)) {
gpu_adapters[gpu_adapter_count] = candidateAdapter;
gpu_adapter_count++;
} else {
npu_adapters[npu_adapter_count] = candidateAdapter;
npu_adapter_count++;

Check warning on line 344 in onnxruntime/core/providers/dml/dml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory.cc#L344

At least two spaces is best between code and comments [whitespace/comments] [2]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory.cc:344:  At least two spaces is best between code and comments  [whitespace/comments] [2]
}
}
nums11 marked this conversation as resolved.
Show resolved Hide resolved
}

if (f == OrtDmlDeviceFilter::None) // considering both NPUs and GPUs
{
// order the adapters based on the performance preference

Check warning on line 351 in onnxruntime/core/providers/dml/dml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory.cc#L351

At least two spaces is best between code and comments [whitespace/comments] [2]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory.cc:351:  At least two spaces is best between code and comments  [whitespace/comments] [2]
ComPtr<IDXCoreAdapter>* high_pri_adapters;
ComPtr<IDXCoreAdapter>* low_pri_adapters;
int num_high_pri_adapters;
int num_low_pri_adapters;
if (p == OrtDmlPerformancePreference::LowPower) // NPUs first
{
high_pri_adapters = npu_adapters;

Check warning on line 358 in onnxruntime/core/providers/dml/dml_provider_factory.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory.cc#L358

At least two spaces is best between code and comments [whitespace/comments] [2]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory.cc:358:  At least two spaces is best between code and comments  [whitespace/comments] [2]
num_high_pri_adapters = npu_adapter_count;
low_pri_adapters = gpu_adapters;
num_low_pri_adapters = gpu_adapter_count;
}
else // GPUs first
{
high_pri_adapters = gpu_adapters;
num_high_pri_adapters = gpu_adapter_count;
low_pri_adapters = npu_adapters;
num_low_pri_adapters = npu_adapter_count;
}

for (int i = 0; i < num_high_pri_adapters; i++) {
ordered_adapters[ordered_adapter_count] = high_pri_adapters[i];
ordered_adapter_count++;
}
for (int i = 0; i < num_low_pri_adapters; i++) {
ordered_adapters[ordered_adapter_count] = low_pri_adapters[i];
ordered_adapter_count++;
}
}
nums11 marked this conversation as resolved.
Show resolved Hide resolved
nums11 marked this conversation as resolved.
Show resolved Hide resolved

// check if ordered_adapter count is 0 (no adapters?)
ComPtr<IDXCoreAdapter> highest_pri_adapter = ordered_adapters[0];
// return the create function for a dxcore device
options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::CreateDXCore(
nums11 marked this conversation as resolved.
Show resolved Hide resolved
highest_pri_adapter));
nums11 marked this conversation as resolved.
Show resolved Hide resolved
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
API_IMPL_BEGIN
#ifdef USE_DML
Expand All @@ -238,7 +409,8 @@
&OrtSessionOptionsAppendExecutionProviderEx_DML,
&CreateGPUAllocationFromD3DResource,
&FreeGPUAllocation,
&GetD3D12ResourceFromAllocation
&GetD3D12ResourceFromAllocation,
&OrtSessionOptionsAppendExecutionProvider_DML2
nums11 marked this conversation as resolved.
Show resolved Hide resolved
};

const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/dml/dml_provider_factory_creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
#include "core/providers/providers.h"
#include "core/providers/dml/dml_provider_factory.h"

#include <dxcore.h>

Check warning on line 13 in onnxruntime/core/providers/dml/dml_provider_factory_creator.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/dml_provider_factory_creator.h#L13

Found C system header after other header. Should be: dml_provider_factory_creator.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/core/providers/dml/dml_provider_factory_creator.h:13:  Found C system header after other header. Should be: dml_provider_factory_creator.h, c system, c++ system, other.  [build/include_order] [4]

namespace onnxruntime {

struct DMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
static std::shared_ptr<IExecutionProviderFactory> CreateDXCore(Microsoft::WRL::ComPtr<IDXCoreAdapter> dxcore_device);
static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
};
} // namespace onnxruntime
Loading