Skip to content

Commit

Permalink
Cherry-pick b8f373b (#17895) (#17926)
Browse files Browse the repository at this point in the history
Merge NPU Device Selection changes
git cherry-pick b8f373b

Changes already present in DmlPrototype-2023_10_02.

Co-authored-by: Sheil Kumar <[email protected]>
  • Loading branch information
smk2007 and Sheil Kumar authored Oct 12, 2023
1 parent cfab945 commit 6062096
Show file tree
Hide file tree
Showing 7 changed files with 457 additions and 40 deletions.
4 changes: 2 additions & 2 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1314,13 +1314,13 @@ if (onnxruntime_USE_DML)
if (GDK_PLATFORM STREQUAL Scarlett)
target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs})
else()
target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib)
target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib)
endif()

target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib)

if (NOT GDK_PLATFORM)
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199")
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199")
endif()

target_compile_definitions(onnxruntime_providers_dml
Expand Down
32 changes: 32 additions & 0 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,31 @@ typedef struct IDMLDevice IDMLDevice;
extern "C" {
#endif

/**
 * Performance preference carried in OrtDmlDeviceOptions::Preference, which is
 * handed to SessionOptionsAppendExecutionProvider_DML2 when the DirectML
 * execution provider is created.
 */
enum OrtDmlPerformancePreference {
Default = 0,
HighPerformance = 1,
MinimumPower = 2
};

/**
 * Bit-flag device filter carried in OrtDmlDeviceOptions::Filter.
 * Gpu and Npu each occupy one bit and may be combined with the operators
 * below; Any has every bit of the underlying uint32_t set.
 */
enum OrtDmlDeviceFilter : uint32_t {
  Any = 0xffffffff,
  Gpu = 1 << 0,
  Npu = 1 << 1,
};

// Bitwise operators for OrtDmlDeviceFilter.
// All arithmetic is done in the enum's underlying type (uint32_t) rather than signed int, so that
// values with the top bit set (e.g. Any) never go through a signed conversion.
inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return static_cast<OrtDmlDeviceFilter>(~static_cast<uint32_t>(a)); }
inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return static_cast<OrtDmlDeviceFilter>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b)); }
inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return static_cast<OrtDmlDeviceFilter>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b)); }
inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return static_cast<OrtDmlDeviceFilter>(static_cast<uint32_t>(a) ^ static_cast<uint32_t>(b)); }
// The compound forms are built on the value operators above instead of reinterpreting the enum
// object as an int lvalue, which avoids type punning; each returns a reference to its left operand.
inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return a = a | b; }
inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return a = a & b; }
inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return a = a ^ b; }

/**
 * Device options passed to SessionOptionsAppendExecutionProvider_DML2.
 */
struct OrtDmlDeviceOptions {
// Performance preference; one of the OrtDmlPerformancePreference values.
OrtDmlPerformancePreference Preference;
// Device filter; OrtDmlDeviceFilter flags, which may be combined with the bitwise operators.
OrtDmlDeviceFilter Filter;
};

/**
* [[deprecated]]
* This export is deprecated.
Expand Down Expand Up @@ -99,6 +124,13 @@ struct OrtDmlApi {
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);

/**
* SessionOptionsAppendExecutionProvider_DML2
* Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
* (default, high performance, or minimum power) and a device filter (any, GPU, or NPU).
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
};

#ifdef __cplusplus
Expand Down
Loading

0 comments on commit 6062096

Please sign in to comment.