Skip to content

Commit

Permalink
readd npu enumeration (#18437)
Browse files Browse the repository at this point in the history
Re-add changes which were merged out...

---------

Co-authored-by: Sheil Kumar <[email protected]>
  • Loading branch information
2 people authored and raoanag committed Nov 20, 2023
1 parent 30f55fa commit a722d76
Showing 1 changed file with 39 additions and 8 deletions.
47 changes: 39 additions & 8 deletions onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,12 +466,39 @@ Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID
return dml_device;
}

static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_device) {
D3D12_FEATURE_DATA_FEATURE_LEVELS feature_levels = {};

D3D_FEATURE_LEVEL feature_levels_list[] = {
D3D_FEATURE_LEVEL_1_0_CORE,
D3D_FEATURE_LEVEL_11_0,
D3D_FEATURE_LEVEL_11_1,
D3D_FEATURE_LEVEL_12_0,
D3D_FEATURE_LEVEL_12_1
};

feature_levels.NumFeatureLevels = ARRAYSIZE(feature_levels_list);
feature_levels.pFeatureLevelsRequested = feature_levels_list;
ORT_THROW_IF_FAILED(d3d12_device->CheckFeatureSupport(
D3D12_FEATURE_FEATURE_LEVELS,
&feature_levels,
sizeof(feature_levels)
));

auto is_feature_level_1_0_core = (feature_levels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE);
if (is_feature_level_1_0_core) {
return D3D12_COMMAND_LIST_TYPE_COMPUTE;
}

return D3D12_COMMAND_LIST_TYPE_DIRECT;
}

std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(
ID3D12Device* d3d12_device,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) {
ID3D12Device* d3d12_device,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) {
D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
cmd_queue_desc.Type = CalculateCommandListType(d3d12_device);
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;

ComPtr<ID3D12CommandQueue> cmd_queue;
Expand All @@ -491,16 +518,20 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(
}

std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromAdapterList(
std::vector<ComPtr<IDXCoreAdapter>>&& dxcore_devices,
std::vector<ComPtr<IDXCoreAdapter>>&& adapters,
bool disable_metacommands,
bool enable_dynamic_graph_fusion) {
// Choose the first device from the list since it's the highest priority
auto dxcore_device = dxcore_devices[0];
auto adapter = adapters[0];

auto feature_level = D3D_FEATURE_LEVEL_11_0;
if (IsNPU(adapter.Get())) {
feature_level = D3D_FEATURE_LEVEL_1_0_CORE;
}

// Create D3D12 Device from DXCore Adapter
ComPtr<ID3D12Device> d3d12_device;
ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));

ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), feature_level, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion);
}

Expand Down

0 comments on commit a722d76

Please sign in to comment.