diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config
index 8eef0b5bac62b..04c49f40ecbcf 100644
--- a/.pipelines/nuget_config/x64/packages.config
+++ b/.pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@
-  <package id="Microsoft.AI.DirectML" version="1.12.0" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML.Preview" version="1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5" targetFramework="native" />
diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config
index 81f97948f1817..ad0ea2196a4d6 100644
--- a/.pipelines/nuget_config/x86/packages.config
+++ b/.pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@
-  <package id="Microsoft.AI.DirectML" version="1.12.0" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML.Preview" version="1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5" targetFramework="native" />
diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake
index f446c2be9847c..7cab60514302b 100644
--- a/cmake/external/dml.cmake
+++ b/cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
- set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.0)
+ set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.Preview.1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5)
# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 7fb03487a255e..047b7c1ca98d2 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -1314,13 +1314,13 @@ if (onnxruntime_USE_DML)
if (GDK_PLATFORM STREQUAL Scarlett)
target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs})
else()
- target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib)
+ target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib)
endif()
target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib)
if (NOT GDK_PLATFORM)
- set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199")
+ set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199")
endif()
target_compile_definitions(onnxruntime_providers_dml
diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h
index 0782d2d9ed760..dd4ffb835d51c 100644
--- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h
+++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h
@@ -30,6 +30,31 @@ typedef struct IDMLDevice IDMLDevice;
extern "C" {
#endif
+enum OrtDmlPerformancePreference {
+ Default = 0,
+ HighPerformance = 1,
+ MinimumPower = 2
+};
+
+enum OrtDmlDeviceFilter : uint32_t {
+ Any = 0xffffffff,
+ Gpu = 1 << 0,
+ Npu = 1 << 1,
+};
+
+inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; }
+inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a | (int)b); }
+inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a & (int)b); }
+inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a ^ (int)b); }
+inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a |= (int)b); }
+inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); }
+inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); }
+
+struct OrtDmlDeviceOptions {
+ OrtDmlPerformancePreference Preference;
+ OrtDmlDeviceFilter Filter;
+};
+
/**
* [[deprecated]]
* This export is deprecated.
@@ -99,6 +124,13 @@ struct OrtDmlApi {
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);
+
+ /**
+ * SessionOptionsAppendExecutionProvider_DML2
+ * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
+ * (default, high performance, or minimum power) and a device filter (Any, GPU, or NPU).
+ */
+ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
};
#ifdef __cplusplus
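
For reference, a minimal sketch of how a caller might reach the new entry point through
the OrtDmlApi struct (assumes the ORT C++ wrapper onnxruntime_cxx_api.h; error handling
abbreviated; the option structs and SessionOptionsAppendExecutionProvider_DML2 are the
declarations added above):

    #include <onnxruntime_cxx_api.h>
    #include "dml_provider_factory.h"

    // Request an NPU with a minimum-power preference.
    const OrtApi& ort_api = Ort::GetApi();
    const OrtDmlApi* dml_api = nullptr;
    Ort::ThrowOnError(ort_api.GetExecutionProviderApi(
        "DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api)));

    Ort::SessionOptions session_options;
    OrtDmlDeviceOptions device_opts = {MinimumPower, Npu};
    Ort::ThrowOnError(dml_api->SessionOptionsAppendExecutionProvider_DML2(
        session_options, &device_opts));
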
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
index 232a022d869f4..034b05e36aaaa 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
@@ -82,7 +82,10 @@ namespace Windows::AI::MachineLearning::Adapter
{
uint32_t nodeCount;
std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;
+
+ // TODO: Remove this
std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;
+
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index 570a0f82b62ff..8558e33aaa8e5 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -3,38 +3,6 @@
#pragma once
-struct DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC
-{
- const DML_TENSOR_DESC* ATensor;
- const DML_TENSOR_DESC* AScaleTensor;
- _Maybenull_ const DML_TENSOR_DESC* AZeroPointTensor;
- const DML_TENSOR_DESC* BTensor;
- const DML_TENSOR_DESC* BScaleTensor;
- _Maybenull_ const DML_TENSOR_DESC* BZeroPointTensor;
- _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
- const DML_TENSOR_DESC* OutputTensor;
-};
-const int DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT = 0x80000011;
-
-struct DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC
-{
- const DML_TENSOR_DESC* InputTensor;
- const DML_TENSOR_DESC* InputScaleTensor;
- _Maybenull_ const DML_TENSOR_DESC* InputZeroPointTensor;
- const DML_TENSOR_DESC* OutputScaleTensor;
- _Maybenull_ const DML_TENSOR_DESC* OutputZeroPointTensor;
- const DML_TENSOR_DESC* OutputTensor;
- UINT DimensionCount;
- _Field_size_(DimensionCount) const UINT* Strides;
- _Field_size_(DimensionCount) const UINT* WindowSize;
- _Field_size_(DimensionCount) const UINT* StartPadding;
- _Field_size_(DimensionCount) const UINT* EndPadding;
- _Field_size_(DimensionCount) const UINT* Dilations;
- BOOL IncludePadding;
-};
-const int DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING = 0x8000000B;
-
-
namespace ApiTraits
{
template <typename T>
@@ -2711,11 +2679,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2";
case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1";
case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1";
-#pragma warning(push)
-#pragma warning(disable: 4063)
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT";
case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION";
-#pragma warning(pop)
default:
assert(false);
return "";
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
index 93a53e8d2e05d..30f510ca7ec78 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
@@ -314,7 +314,11 @@ namespace Dml::GraphDescBuilder
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// only used for small inputs.
- uint32_t c_maxConstNodeDataSize = 64;
+
+ // TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming
+ // nodes. That would allow this size limit to be reduced while still handling the case where 1D scale and
+ // zero point values have been de-duplicated via conversion to scalars in kernels.
+ uint32_t c_maxConstNodeDataSize = 1024 * 1024;
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
@@ -405,6 +409,15 @@ namespace Dml::GraphDescBuilder
auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];
DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);
+
+ // TODO: Change as new header is ingested
+ if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
+ dmlDesc.Type = (DML_OPERATOR_TYPE) 169;
+
+ // TODO: Change as new header is ingested
+ if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
+ dmlDesc.Type = (DML_OPERATOR_TYPE) 170;
+
ComPtr<IDMLOperator> op;
ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
allocator.Reset();
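
The same remapping recurs in DmlOperator.cpp below. Until the new DirectML header is
ingested, a shared helper along these lines (hypothetical, not part of this change)
would keep the three call sites from drifting:

    // Hypothetical helper: map the preview operator enum values onto the raw
    // values (169/170) that the preview DirectML binary expects.
    inline DML_OPERATOR_TYPE RemapPreviewOperatorType(DML_OPERATOR_TYPE type)
    {
        if (type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
        {
            return (DML_OPERATOR_TYPE) 169;
        }
        if (type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
        {
            return (DML_OPERATOR_TYPE) 170;
        }
        return type;
    }
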
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
index 25c7be42d6425..8343cd1b2a465 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
@@ -6,6 +6,9 @@
namespace Dml
{
+
+ /*static*/ const uint32_t DmlOperator::zeroArray[8] = {};
+
DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo)
{
ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider));
@@ -38,10 +41,6 @@ namespace Dml
}
}
- // Create and compile the operator.
- ComPtr<IDMLOperator> dmlOperator;
- ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDesc, IID_PPV_ARGS(&dmlOperator)));
-
ComPtr<IMLOperatorKernelCreationContextPrivate> contextPrivate;
ORT_THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf()));
@@ -87,6 +86,20 @@ namespace Dml
}
else
{
+ auto operatorDescCopy = operatorDesc;
+
+ // TODO: Change as new header is ingested
+ if (operatorDescCopy.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
+ operatorDescCopy.Type = (DML_OPERATOR_TYPE) 169;
+
+ // TODO: Change as new header is ingested
+ if (operatorDescCopy.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
+ operatorDescCopy.Type = (DML_OPERATOR_TYPE) 170;
+
+ // Create and compile the operator.
+ ComPtr<IDMLOperator> dmlOperator;
+ ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDescCopy, IID_PPV_ARGS(&dmlOperator)));
+
DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags();
ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), executionFlags, IID_PPV_ARGS(&m_compiledOperator)));
@@ -792,8 +805,18 @@ namespace Dml
graphDesc.NodeCount = operatorGraphDesc.nodeCount;
for (size_t i = 0; i < graphDesc.NodeCount; ++i)
{
+ DML_OPERATOR_DESC opDesc = *operatorGraphDesc.nodesAsOpDesc[i];
+
+ // TODO: Change as new header is ingested
+ if (opDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
+ opDesc.Type = (DML_OPERATOR_TYPE) 169;
+
+ // TODO: Change as new header is ingested
+ if (opDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
+ opDesc.Type = (DML_OPERATOR_TYPE) 170;
+
// Create the operator.
- ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodesAsOpDesc[i], IID_PPV_ARGS(&dmlOperators[i])));
+ ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&opDesc, IID_PPV_ARGS(&dmlOperators[i])));
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{dmlOperators[i].Get()};
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
}
@@ -824,4 +847,80 @@ namespace Dml
graphDesc.IntermediateEdges = dmlIntermediateEdges.data();
}
+ /*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar(
+ const MLOperatorKernelCreationContext& kernelInfo,
+ const DML_TENSOR_DESC* tensor,
+ uint32_t kernelInputIndex)
+ {
+ if (!tensor)
+ {
+ return;
+ }
+
+ auto constExpTensor = kernelInfo.TryGetConstantInputTensor(kernelInputIndex);
+ if (!constExpTensor)
+ {
+ return;
+ }
+
+ uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount();
+ if (totalKernelInputElementCount <= 1)
+ {
+ return;
+ }
+
+ uint32_t elementSize = 0;
+
+ switch (constExpTensor->GetTensorDataType())
+ {
+ case MLOperatorTensorDataType::UInt8:
+ case MLOperatorTensorDataType::Int8:
+ elementSize = 1;
+ break;
+
+ case MLOperatorTensorDataType::Float16:
+ case MLOperatorTensorDataType::UInt16:
+ case MLOperatorTensorDataType::Int16:
+ elementSize = 2;
+ break;
+
+ case MLOperatorTensorDataType::/*Float32*/Float:
+ case MLOperatorTensorDataType::UInt32:
+ case MLOperatorTensorDataType::Int32:
+ elementSize = 4;
+ break;
+
+ case MLOperatorTensorDataType::/*Float64*/Double:
+ case MLOperatorTensorDataType::UInt64:
+ case MLOperatorTensorDataType::Int64:
+ elementSize = 8;
+ break;
+
+ default:
+ return;
+ }
+
+ const std::uint8_t* byteData = static_cast<const std::uint8_t*>(constExpTensor->GetByteData());
+
+ assert(tensor->Type == DML_TENSOR_TYPE_BUFFER);
+ auto* bufferTensorDesc = const_cast<DML_BUFFER_TENSOR_DESC*>(static_cast<const DML_BUFFER_TENSOR_DESC*>(tensor->Desc));
+
+ for (size_t i = 1; i < totalKernelInputElementCount; ++i)
+ {
+ if (memcmp(byteData, byteData + i * elementSize, elementSize))
+ {
+ return;
+ }
+ }
+
+ if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0]))
+ {
+ assert(false);
+ return;
+ }
+
+ bufferTensorDesc->Strides = zeroArray;
+ bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3;
+ }
+
} // namespace Dml
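
The conversion above relies on DirectML's strided buffer addressing: with all strides
zero, every logical element index resolves to byte offset 0, so a single stored value is
broadcast across the full logical shape. A sketch of the descriptor this produces for a
hypothetical all-equal constant (DirectML requires TotalTensorSizeInBytes rounded up to a
4-byte multiple, hence the (elementSize + 3) & ~3 above):

    // A [2, 3, 4, 5] FLOAT16 constant whose 120 elements are all identical
    // can be described as a single broadcast scalar:
    UINT sizes[4] = {2, 3, 4, 5};
    UINT strides[4] = {0, 0, 0, 0};  // every index maps to offset 0

    DML_BUFFER_TENSOR_DESC desc = {};
    desc.DataType = DML_TENSOR_DATA_TYPE_FLOAT16;
    desc.DimensionCount = 4;
    desc.Sizes = sizes;
    desc.Strides = strides;
    desc.TotalTensorSizeInBytes = 4;  // elementSize 2 -> (2 + 3) & ~3 == 4
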
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
index c1e8cf42a974c..3995c3309bb92 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
@@ -149,6 +149,11 @@ namespace Dml
uint32_t minDimensionCount = NchwDimensionCount
) const;
+ static void TryConvertTensorToBroadcastScalar(
+ const MLOperatorKernelCreationContext& kernelInfo,
+ const DML_TENSOR_DESC* tensor,
+ uint32_t kernelInputIndex);
+
private:
// For each input or output of the DML kernel, the corresponding input or output of the original
// kernel. Entries for unused DML inputs are nullopt.
@@ -164,6 +169,7 @@ namespace Dml
_Inout_ std::vector<DML_OUTPUT_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Inout_ std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
+ static const uint32_t zeroArray[8];
};
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp
index 60b235880e23f..20163869a2aaf 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp
@@ -111,6 +111,8 @@ class DmlOperatorBatchNormalization15 : public DmlOperator, BatchNormalizationHe
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
+ // TODO: Port this to a graph description to enable DML graph optimization
+
dml::Graph graph(m_dmlDevice.Get());
dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X];
dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale];
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
index 0eb14cea8dc9f..7336f2c7fab5d 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
@@ -585,6 +585,9 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
opDesc.ScaleTensor = &inputDescs[1];
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
+ TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo);
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp
index 7b50dfb9ff1ad..789ce5b5c56f1 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp
@@ -58,6 +58,15 @@ class DmlOperatorQLinearAdd : public DmlOperator
AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
AddDesc.OutputTensor = &outputDescs[0];
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT);
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT);
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT);
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
index 0fccedfe311c1..605e5fffb6a76 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
@@ -118,8 +118,8 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe
qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
- qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
- qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
+ qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];
+ qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
@@ -129,6 +129,12 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false);
+ TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale);
+ TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint);
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale);
+ TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint);
+
DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
index 67711fdc28b84..c97b03dc36b62 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
@@ -123,6 +123,9 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
+
+ TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ScaleTensor, tupleStartIndex + 1);
+ TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ZeroPointTensor, tupleStartIndex + 2);
dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
opDescs.push_back(&dmlOpDesc[inputIndex]);
@@ -154,6 +157,10 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
+
+ TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::YScale);
+ TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::YZeroPoint);
+
const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
opDescs.push_back(&opQuantizeDesc);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
index d45fdef3c8807..4e121a6502cba 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
@@ -117,6 +117,15 @@ class DmlOperatorQLinearConv : public DmlOperator, public ConvolutionHelperBase
convDesc.EndPadding = kernelArgs.endPadding;
convDesc.GroupCount = m_groupCount;
+ TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputScaleTensor, IN_X_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputZeroPointTensor, IN_X_ZERO_POINT);
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterScaleTensor, IN_F_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterZeroPointTensor, IN_F_ZERO_POINT);
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputScaleTensor, IN_Y_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
+
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &convDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp
index b746a0e81a5cf..b38acd8cbf978 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp
@@ -104,6 +104,15 @@ class DmlOperatorQLinearMatMul : public DmlOperator
matMulDesc.OutputZeroPointTensor = inputDescs[IN_Y_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_Y_ZERO_POINT] : nullptr;
matMulDesc.OutputTensor = &outputDescs[0];
+ TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AScaleTensor, IN_A_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AZeroPointTensor, IN_A_ZERO_POINT);
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BScaleTensor, IN_B_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BZeroPointTensor, IN_B_ZERO_POINT);
+
+ TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputScaleTensor, IN_Y_SCALE);
+ TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
+
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, &matMulDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
index 14326a46b2a64..84d86c23f71f4 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
@@ -88,6 +88,9 @@ class DmlOperatorQLinearSigmoid : public DmlOperator
dequantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::X_scale];
dequantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::X_zero_point];
dequantizeOperatorDesc.OutputTensor = &namedIntermediateOutputTensorDesc;
+
+ TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ScaleTensor, OnnxInputIndex::X_scale);
+ TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::X_zero_point);
const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDesc};
@@ -101,6 +104,10 @@ class DmlOperatorQLinearSigmoid : public DmlOperator
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::Y_scale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::Y_zero_point];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
+
+ TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::Y_scale);
+ TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::Y_zero_point);
+
const DML_OPERATOR_DESC opDesc3{DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
MLOperatorGraphDesc operatorGraphDesc = {};
diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc
index 6a2740e4369e9..424c0a43da974 100644
--- a/onnxruntime/core/providers/dml/dml_provider_factory.cc
+++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#include <algorithm>
+#include <optional>
+
#include <DirectML.h>
#ifndef _GAMING_XBOX
#include <dxgi1_4.h>
@@ -92,11 +95,298 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId);
}
+static bool IsHardwareAdapter(IDXCoreAdapter* adapter) {
+ bool is_hardware = false;
+  ORT_THROW_IF_FAILED(adapter->GetProperty(
+ DXCoreAdapterProperty::IsHardware,
+ &is_hardware));
+ return is_hardware;
+}
+
+static bool IsGPU(IDXCoreAdapter* compute_adapter) {
+ // Only considering hardware adapters
+ if (!IsHardwareAdapter(compute_adapter)) {
+ return false;
+ }
+ return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
+}
+
+static bool IsNPU(IDXCoreAdapter* compute_adapter) {
+ // Only considering hardware adapters
+ if (!IsHardwareAdapter(compute_adapter)) {
+ return false;
+ }
+ return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS));
+}
+
+enum class DeviceType { GPU, NPU, BadDevice };
+
+static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFilter filter) {
+ auto allow_gpus = (filter & OrtDmlDeviceFilter::Gpu) == OrtDmlDeviceFilter::Gpu;
+ if (IsGPU(adapter) && allow_gpus) {
+ return DeviceType::GPU;
+ }
+
+ auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
+ if (IsNPU(adapter) && allow_npus) {
+ return DeviceType::NPU;
+ }
+
+ return DeviceType::BadDevice;
+}
+
+// Struct for holding each adapter
+struct AdapterInfo {
+  ComPtr<IDXCoreAdapter> Adapter;
+ DeviceType Type; // GPU or NPU
+};
+
+static ComPtr<IDXCoreAdapterList> EnumerateDXCoreAdapters(IDXCoreAdapterFactory* adapter_factory) {
+  ComPtr<IDXCoreAdapterList> adapter_list;
+
+  // TODO: use_dxcore_workload_enumeration should be determined by QI.
+  // When DXCore APIs are available, QI for the relevant enumeration interfaces.
+ constexpr bool use_dxcore_workload_enumeration = false;
+ if (!use_dxcore_workload_enumeration) {
+ // Get a list of all the adapters that support compute
+ GUID attributes[]{ DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
+ ORT_THROW_IF_FAILED(
+ adapter_factory->CreateAdapterList(_countof(attributes),
+ attributes,
+ adapter_list.GetAddressOf()));
+ }
+
+ return adapter_list;
+}
+
+static void SortDXCoreAdaptersByPreference(
+ IDXCoreAdapterList* adapter_list,
+ OrtDmlPerformancePreference preference) {
+ if (adapter_list->GetAdapterCount() <= 1) {
+ return;
+ }
+
+ // DML prefers the HighPerformance adapter by default
+  std::array<DXCoreAdapterPreference, 1> adapter_list_preferences = {
+ DXCoreAdapterPreference::HighPerformance
+ };
+
+  // If callers specify MinimumPower, change the DXCore sort policy.
+  // NOTE: DXCoreAdapterPreference does not apply to mixed adapter lists - only to GPU lists
+ if (preference == OrtDmlPerformancePreference::MinimumPower) {
+ adapter_list_preferences[0] = DXCoreAdapterPreference::MinimumPower;
+ }
+
+ ORT_THROW_IF_FAILED(adapter_list->Sort(
+      static_cast<uint32_t>(adapter_list_preferences.size()),
+ adapter_list_preferences.data()));
+}
+
+static std::vector FilterDXCoreAdapters(
+ IDXCoreAdapterList* adapter_list,
+ OrtDmlDeviceFilter filter) {
+  auto adapter_infos = std::vector<AdapterInfo>();
+ const uint32_t count = adapter_list->GetAdapterCount();
+ for (uint32_t i = 0; i < count; ++i) {
+ ComPtr candidate_adapter;
+ ORT_THROW_IF_FAILED(adapter_list->GetAdapter(i, candidate_adapter.GetAddressOf()));
+
+ // Add the adapters that are valid based on the device filter (GPU, NPU, or Both)
+ auto adapter_type = FilterAdapterTypeQuery(candidate_adapter.Get(), filter);
+ if (adapter_type != DeviceType::BadDevice) {
+ adapter_infos.push_back(AdapterInfo{candidate_adapter, adapter_type});
+ }
+ }
+
+ return adapter_infos;
+}
+
+static void SortHeterogenousDXCoreAdapterList(
+ std::vector& adapter_infos,
+ OrtDmlDeviceFilter filter,
+ OrtDmlPerformancePreference preference) {
+ if (adapter_infos.size() <= 1) {
+ return;
+ }
+
+  // When considering both GPUs and NPUs, sort them by performance preference:
+  // Default (GPUs first), HighPerformance (GPUs first), or MinimumPower (NPUs first)
+ auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
+ auto only_npus = filter == OrtDmlDeviceFilter::Npu;
+ if (!keep_npus || only_npus) {
+ return;
+ }
+
+ struct SortingPolicy {
+ // default is false because GPUs are considered higher priority in
+ // a mixed adapter environment
+ bool npus_first_ = false;
+
+ SortingPolicy(bool npus_first = false) : npus_first_(npus_first) { }
+
+ bool operator()(const AdapterInfo& a, const AdapterInfo& b) {
+      return npus_first_ ? a.Type > b.Type : a.Type < b.Type;
+ }
+ };
+
+ auto npus_first = (preference == OrtDmlPerformancePreference::MinimumPower);
+ auto policy = SortingPolicy(npus_first);
+ std::sort(adapter_infos.begin(), adapter_infos.end(), policy);
+}
+
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id) {
return Create(device_id, /*skip_software_device_check*/ false);
}
-std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromOptions(
+ OrtDmlDeviceOptions* device_options) {
+ auto default_device_options = OrtDmlDeviceOptions { Default, Gpu };
+ if (device_options == nullptr) {
+ device_options = &default_device_options;
+ }
+
+ OrtDmlPerformancePreference preference = device_options->Preference;
+ OrtDmlDeviceFilter filter = device_options->Filter;
+
+ // Create DXCore Adapter Factory
+  ComPtr<IDXCoreAdapterFactory> adapter_factory;
+ ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf()));
+
+ // Get all DML compatible DXCore adapters
+  ComPtr<IDXCoreAdapterList> adapter_list;
+ adapter_list = EnumerateDXCoreAdapters(adapter_factory.Get());
+
+ if (adapter_list->GetAdapterCount() == 0) {
+ ORT_THROW("No GPUs or NPUs detected.");
+ }
+
+ // Sort the adapter list to honor DXCore hardware ordering
+ SortDXCoreAdaptersByPreference(adapter_list.Get(), preference);
+
+  // TODO: use_dxcore_workload_enumeration should be determined by QI.
+  // When DXCore APIs are available, QI for the relevant enumeration interfaces.
+ constexpr bool use_dxcore_workload_enumeration = false;
+
+ std::vector adapter_infos;
+ if (!use_dxcore_workload_enumeration) {
+ // Filter all DXCore adapters to hardware type specified by the device filter
+ adapter_infos = FilterDXCoreAdapters(adapter_list.Get(), filter);
+ if (adapter_infos.size() == 0) {
+ ORT_THROW("No devices detected that match the filter criteria.");
+ }
+ }
+
+ // DXCore Sort ignores NPUs. When both GPUs and NPUs are present, manually sort them.
+ SortHeterogenousDXCoreAdapterList(adapter_infos, filter, preference);
+
+ // Extract just the adapters
+  auto adapters = std::vector<ComPtr<IDXCoreAdapter>>(adapter_infos.size());
+ std::transform(
+ adapter_infos.begin(), adapter_infos.end(),
+ adapters.begin(),
+ [](auto& a){ return a.Adapter; });
+
+ return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters));
+}
+
+static std::optional<OrtDmlPerformancePreference> ParsePerformancePreference(const ProviderOptions& provider_options) {
+ static const std::string PerformancePreference = "performance_preference";
+ static const std::string Default = "default";
+ static const std::string HighPerformance = "high_performance";
+ static const std::string MinimumPower = "minimum_power";
+
+ auto preference_it = provider_options.find(PerformancePreference);
+ if (preference_it != provider_options.end()) {
+ if (preference_it->second == Default) {
+ return OrtDmlPerformancePreference::Default;
+ }
+
+ if (preference_it->second == HighPerformance) {
+ return OrtDmlPerformancePreference::HighPerformance;
+ }
+
+ if (preference_it->second == MinimumPower) {
+ return OrtDmlPerformancePreference::MinimumPower;
+ }
+
+ ORT_THROW("Invalid PerformancePreference provided for DirectML EP device selection.");
+ }
+
+ return {};
+}
+
+static std::optional<OrtDmlDeviceFilter> ParseFilter(const ProviderOptions& provider_options) {
+ static const std::string Filter = "filter";
+ static const std::string Any = "any";
+ static const std::string Gpu = "gpu";
+ static const std::string Npu = "npu";
+
+ auto preference_it = provider_options.find(Filter);
+ if (preference_it != provider_options.end()) {
+ if (preference_it->second == Any) {
+ return OrtDmlDeviceFilter::Any;
+ }
+
+ if (preference_it->second == Gpu) {
+ return OrtDmlDeviceFilter::Gpu;
+ }
+
+ if (preference_it->second == Npu) {
+ return OrtDmlDeviceFilter::Npu;
+ }
+
+ ORT_THROW("Invalid Filter provided for DirectML EP device selection.");
+ }
+
+ return {};
+}
+
+static std::optional<int> ParseDeviceId(const ProviderOptions& provider_options) {
+ static const std::string DeviceId = "device_id";
+
+ auto preference_it = provider_options.find(DeviceId);
+ if (preference_it != provider_options.end()) {
+ if (!preference_it->second.empty()) {
+ return std::stoi(preference_it->second);
+ }
+ }
+
+ return {};
+}
+
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromProviderOptions(
+ const ProviderOptions& provider_options) {
+ auto device_id = ParseDeviceId(provider_options);
+  if (device_id.has_value()) {
+ return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value());
+ }
+
+ auto preference = ParsePerformancePreference(provider_options);
+ auto filter = ParseFilter(provider_options);
+
+ // If no preference/filters are specified then create with default preference/filters.
+ if (!preference.has_value() && !filter.has_value()) {
+ return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr);
+ }
+
+ if (!preference.has_value()) {
+ preference = OrtDmlPerformancePreference::Default;
+ }
+
+ if (!filter.has_value()) {
+ filter = OrtDmlDeviceFilter::Gpu;
+ }
+
+ OrtDmlDeviceOptions device_options;
+ device_options.Preference = preference.value();
+ device_options.Filter = filter.value();
+ return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options);
+}
+
+Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Device(
+ int device_id,
+ bool skip_software_device_check) {
#ifdef _GAMING_XBOX
ComPtr<ID3D12Device> d3d12_device;
D3D12XBOX_CREATE_DEVICE_PARAMETERS params = {};
@@ -124,35 +414,65 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int
ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
#endif
- D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
- cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
- cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
-
- ComPtr<ID3D12CommandQueue> cmd_queue;
- ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
+ return d3d12_device;
+}
+Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) {
DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE;
// In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled
#if _DEBUG && !_GAMING_XBOX
- ComPtr<ID3D12DebugDevice> debug_device;
+ Microsoft::WRL::ComPtr<ID3D12DebugDevice> debug_device;
(void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure
const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr);
if (is_d3d12_debug_layer_enabled) {
flags |= DML_CREATE_DEVICE_FLAG_DEBUG;
}
#endif
- ComPtr<IDMLDevice> dml_device;
- ORT_THROW_IF_FAILED(DMLCreateDevice1(d3d12_device.Get(),
- flags,
- DML_FEATURE_LEVEL_5_0,
- IID_PPV_ARGS(&dml_device)));
+ Microsoft::WRL::ComPtr<IDMLDevice> dml_device;
+ ORT_THROW_IF_FAILED(DMLCreateDevice1(
+ d3d12_device,
+ flags,
+ DML_FEATURE_LEVEL_5_0,
+ IID_PPV_ARGS(&dml_device)));
+
+ return dml_device;
+}
+
+std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) {
+ D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
+ cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
+ cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
+
+  ComPtr<ID3D12CommandQueue> cmd_queue;
+ ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
+ auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device);
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
}
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
+  ComPtr<ID3D12Device> d3d12_device = CreateD3D12Device(device_id, skip_software_device_check);
+ return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
+}
+
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromAdapterList(
+    std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices) {
+ // Choose the first device from the list since it's the highest priority
+ auto dxcore_device = dxcore_devices[0];
+
+ // Create D3D12 Device from DXCore Adapter
+  ComPtr<ID3D12Device> d3d12_device;
+ ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
+
+ return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
+}
+
} // namespace onnxruntime
// [[deprecated]]
@@ -197,6 +517,17 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) {
API_IMPL_END
}
+ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) {
+API_IMPL_BEGIN
+#ifdef USE_DML
+ auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options);
+  // Register the factory created from the DXCore device options
+ options->provider_factories.push_back(factory);
+#endif // USE_DML
+ return nullptr;
+ API_IMPL_END
+}
+
ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
API_IMPL_BEGIN
#ifdef USE_DML
diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
index 3c3f04561d383..330055d64fd30 100644
--- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
+++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
@@ -5,11 +5,30 @@
#include <memory>
+#include <string>
+#include <vector>
+#include "core/framework/provider_options.h"
#include "core/providers/providers.h"
+#include <dxcore.h>
+#include <wrl/client.h>
+
+interface IDMLDevice;
+struct OrtDmlDeviceOptions;
+
namespace onnxruntime {
struct DMLProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
+
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromProviderOptions(
+      const ProviderOptions& provider_options_map);
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromOptions(OrtDmlDeviceOptions* device_options);
+
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromAdapterList(
+      std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices);
+
+  static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
+  static Microsoft::WRL::ComPtr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12_device);
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 8b32ec05719b6..2618a6a28127d 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -11,6 +11,10 @@
#include "core/session/onnxruntime_c_api.h"
#include "core/session/ort_apis.h"
+#if defined(USE_DML)
+#include "core/providers/dml/dml_provider_factory_creator.h"
+#endif
+
using namespace onnxruntime;
namespace {
@@ -66,7 +70,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
(std::string(provider_name) + " execution provider is not supported in this build. ").c_str());
};
- if (strcmp(provider_name, "QNN") == 0) {
+ if (strcmp(provider_name, "DML") == 0) {
+#if defined(USE_DML)
+ options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options));
+#else
+ status = create_not_supported_status();
+#endif
+ } else if (strcmp(provider_name, "QNN") == 0) {
#if defined(USE_QNN)
options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value)));
#else
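
With this registration in place, the DML EP can also be selected through the generic
string-based options path. A sketch using the public C++ API (option keys as defined by
the parsers added in dml_provider_factory.cc above; supplying a "device_id" option
instead routes through the legacy DXGI-based Create path):

    #include <string>
    #include <unordered_map>
    #include <onnxruntime_cxx_api.h>

    Ort::SessionOptions session_options;
    std::unordered_map<std::string, std::string> dml_options = {
        {"performance_preference", "minimum_power"},  // "default" | "high_performance" | "minimum_power"
        {"filter", "npu"},                            // "any" | "gpu" | "npu"
    };
    session_options.AppendExecutionProvider("DML", dml_options);
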
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 12b020f32b22f..ae31e7da9c23f 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#endif
} else if (type == kDmlExecutionProvider) {
#ifdef USE_DML
- int device_id = 0;
- auto it = provider_options_map.find(type);
- if (it != provider_options_map.end()) {
- for (auto option : it->second) {
- if (option.first == "device_id") {
- if (!option.second.empty()) {
- device_id = std::stoi(option.second);
- }
- }
- }
- }
- return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider();
+ auto cit = provider_options_map.find(type);
+ return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions(
+ cit == provider_options_map.end() ? ProviderOptions{} : cit->second)
+ ->CreateProvider();
#endif
} else if (type == kNnapiExecutionProvider) {
#if defined(USE_NNAPI)
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index d283d9df62d6a..e3031ccc02a50 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -677,9 +677,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
#endif
} else if (provider_name == onnxruntime::kDmlExecutionProvider) {
#ifdef USE_DML
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));
+      std::unordered_map<std::string, std::string> dml_options;
+ dml_options["performance_preference"] = "high_performance";
+ dml_options["device_filter"] = "gpu";
+ session_options.AppendExecutionProvider("DML", dml_options);
#else
- ORT_THROW("DirectML is not supported in this build\n");
+ ORT_THROW("DML is not supported in this build\n");
#endif
} else if (provider_name == onnxruntime::kAclExecutionProvider) {
#ifdef USE_ACL
diff --git a/packages.config b/packages.config
index b2c918c414ccc..54c8c14872fc1 100644
--- a/packages.config
+++ b/packages.config
@@ -1,6 +1,6 @@
-  <package id="Microsoft.AI.DirectML" version="1.12.0" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML.Preview" version="1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5" targetFramework="native" />
diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py
index a4e00b92823cd..f1f62738843a3 100644
--- a/tools/nuget/generate_nuspec_for_native_nuget.py
+++ b/tools/nuget/generate_nuspec_for_native_nuget.py
@@ -192,7 +192,7 @@ def generate_repo_url(line_list, repo_url, commit_id):
def generate_dependencies(xml_text, package_name, version):
-    dml_dependency = '<dependency id="Microsoft.AI.DirectML" version="1.12.0"/>'
+    dml_dependency = '<dependency id="Microsoft.AI.DirectML.Preview" version="1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5"/>'
if package_name == "Microsoft.AI.MachineLearning":
xml_text.append("")