diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config
index 8eef0b5bac62b..04c49f40ecbcf 100644
--- a/.pipelines/nuget_config/x64/packages.config
+++ b/.pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="utf-8"?>
 <packages>
-  <package id="Microsoft.AI.DirectML" version="1.12.0" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML.Preview" version="1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5" targetFramework="native" />
diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config
index 81f97948f1817..ad0ea2196a4d6 100644
--- a/.pipelines/nuget_config/x86/packages.config
+++ b/.pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="utf-8"?>
 <packages>
-  <package id="Microsoft.AI.DirectML" version="1.12.0" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML.Preview" version="1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5" targetFramework="native" />
diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake
index f446c2be9847c..7cab60514302b 100644
--- a/cmake/external/dml.cmake
+++ b/cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
   set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
   set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
   get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
-  set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.0)
+  set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.Preview.1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5)

   # Restore nuget packages, which will pull down the DirectML redist package.
   add_custom_command(
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 7fb03487a255e..047b7c1ca98d2 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -1314,13 +1314,13 @@ if (onnxruntime_USE_DML)
   if (GDK_PLATFORM STREQUAL Scarlett)
     target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs})
   else()
-    target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib)
+    target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib)
   endif()

   target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib)

   if (NOT GDK_PLATFORM)
-    set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199")
+    set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-1-0.dll /ignore:4199")
   endif()

   target_compile_definitions(onnxruntime_providers_dml
diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h
index 0782d2d9ed760..dd4ffb835d51c 100644
--- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h
+++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h
@@ -30,6 +30,31 @@ typedef struct IDMLDevice IDMLDevice;
 extern "C" {
 #endif

+enum OrtDmlPerformancePreference {
+  Default = 0,
+  HighPerformance = 1,
+  MinimumPower = 2
+};
+
+enum OrtDmlDeviceFilter : uint32_t {
+  Any = 0xffffffff,
+  Gpu = 1 << 0,
+  Npu = 1 << 1,
+};
+
+inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; }
+inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a | (int)b); }
+inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a & (int)b); }
+inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a ^ (int)b); }
+inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a |= (int)b); }
+inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); }
+inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); }
+
+struct OrtDmlDeviceOptions {
+  OrtDmlPerformancePreference Preference;
+  OrtDmlDeviceFilter Filter;
+};
+
 /**
  * [[deprecated]]
  * This export is deprecated.
@@ -99,6 +124,13 @@ struct OrtDmlApi {
    * This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
    */
   ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);
+
+  /**
+   * SessionOptionsAppendExecutionProvider_DML2
+   * Creates a DirectML Execution Provider given the supplied device options, which contain a performance
+   * preference (default, high performance, or minimum power) and a device filter (Any, GPU, or NPU).
+   */
+  ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
 };

 #ifdef __cplusplus
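Since the header only declares the new device-selection surface, a minimal usage sketch may help reviewers. It is illustrative rather than part of the change: it assumes a DML-enabled build, that the public `dml_provider_factory.h` is on the include path, and it elides `OrtStatus` checking; the helper name is hypothetical. `OrtApi::GetExecutionProviderApi` is the existing way to reach the `OrtDmlApi` table:

```cpp
#include <onnxruntime_c_api.h>
#include <dml_provider_factory.h>  // public DML EP header; path may vary by install layout

// Sketch: register the DML EP with an explicit device policy through the new
// SessionOptionsAppendExecutionProvider_DML2 entry (status checks elided).
void AppendDmlNpuLowPower(OrtSessionOptions* session_options) {
  const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  const OrtDmlApi* dml_api = nullptr;
  ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION,
                                   reinterpret_cast<const void**>(&dml_api));

  // Prefer low power draw and consider only NPU adapters.
  OrtDmlDeviceOptions device_options = {};
  device_options.Preference = OrtDmlPerformancePreference::MinimumPower;
  device_options.Filter = OrtDmlDeviceFilter::Npu;
  dml_api->SessionOptionsAppendExecutionProvider_DML2(session_options, &device_options);
}
```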
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
index 232a022d869f4..034b05e36aaaa 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
@@ -82,7 +82,10 @@ namespace Windows::AI::MachineLearning::Adapter
     {
         uint32_t nodeCount;
         std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;
+
+        // TODO: Remove this
         std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;
+
         std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
         std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
         std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index 570a0f82b62ff..8558e33aaa8e5 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -3,38 +3,6 @@

 #pragma once

-struct DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC
-{
-    const DML_TENSOR_DESC* ATensor;
-    const DML_TENSOR_DESC* AScaleTensor;
-    _Maybenull_ const DML_TENSOR_DESC* AZeroPointTensor;
-    const DML_TENSOR_DESC* BTensor;
-    const DML_TENSOR_DESC* BScaleTensor;
-    _Maybenull_ const DML_TENSOR_DESC* BZeroPointTensor;
-    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
-    const DML_TENSOR_DESC* OutputTensor;
-};
-const int DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT = 0x80000011;
-
-struct DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC
-{
-    const DML_TENSOR_DESC* InputTensor;
-    const DML_TENSOR_DESC* InputScaleTensor;
-    _Maybenull_ const DML_TENSOR_DESC* InputZeroPointTensor;
-    const DML_TENSOR_DESC* OutputScaleTensor;
-    _Maybenull_ const DML_TENSOR_DESC* OutputZeroPointTensor;
-    const DML_TENSOR_DESC* OutputTensor;
-    UINT DimensionCount;
-    _Field_size_(DimensionCount) const UINT* Strides;
-    _Field_size_(DimensionCount) const UINT* WindowSize;
-    _Field_size_(DimensionCount) const UINT* StartPadding;
-    _Field_size_(DimensionCount) const UINT* EndPadding;
-    _Field_size_(DimensionCount) const UINT* Dilations;
-    BOOL IncludePadding;
-};
-const int DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING = 0x8000000B;
-
-
 namespace ApiTraits
 {
 template <typename T>
@@ -2711,11 +2679,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
     case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2";
     case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1";
     case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1";
-#pragma warning(push)
-#pragma warning(disable: 4063)
     case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT";
     case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION";
-#pragma warning(pop)

     default:
         assert(false);
         return "<unknown>";
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
index 93a53e8d2e05d..30f510ca7ec78 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
@@ -314,7 +314,11 @@ namespace Dml::GraphDescBuilder
                 // This is a highly inefficient approach to generating constant nodes. It duplicates constant data
                 // across the graph input as well as every consumer's unique constant node. However it is currently
                 // only used for small inputs.
-                uint32_t c_maxConstNodeDataSize = 64;
+
+                // TODO: Rework this to create DML constant nodes with the minimum data size actually used by the
+                // consuming nodes. That would allow this limit to be reduced again while still handling 1D scale
+                // and zero-point tensors that kernels have de-duplicated into broadcast scalars.
+                uint32_t c_maxConstNodeDataSize = 1024 * 1024;

                 ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
@@ -405,6 +409,15 @@ namespace Dml::GraphDescBuilder
             auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];
             DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);
+
+            // TODO: Change as new header is ingested
+            if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
+                dmlDesc.Type = (DML_OPERATOR_TYPE) 169;
+
+            // TODO: Change as new header is ingested
+            if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
+                dmlDesc.Type = (DML_OPERATOR_TYPE) 170;
+
             ComPtr<IDMLOperator> op;
             ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
             allocator.Reset();
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
index 25c7be42d6425..8343cd1b2a465 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
@@ -6,6 +6,9 @@

 namespace Dml
 {
+
+    /*static*/ const uint32_t DmlOperator::zeroArray[8] = {};
+
     DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo)
     {
         ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider));
@@ -38,10 +41,6 @@ namespace Dml
             }
         }

-        // Create and compile the operator.
-        ComPtr<IDMLOperator> dmlOperator;
-        ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDesc, IID_PPV_ARGS(&dmlOperator)));
-
         ComPtr<IMLOperatorKernelCreationContextPrivate> contextPrivate;
         ORT_THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf()));
@@ -87,6 +86,20 @@ namespace Dml
         }
         else
         {
+            auto operatorDescCopy = operatorDesc;
+
+            // TODO: Change as new header is ingested
+            if (operatorDescCopy.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
+                operatorDescCopy.Type = (DML_OPERATOR_TYPE) 169;
+
+            // TODO: Change as new header is ingested
+            if (operatorDescCopy.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
+                operatorDescCopy.Type = (DML_OPERATOR_TYPE) 170;
+
+            // Create and compile the operator.
+            ComPtr<IDMLOperator> dmlOperator;
+            ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDescCopy, IID_PPV_ARGS(&dmlOperator)));
+
             DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags();
             ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), executionFlags, IID_PPV_ARGS(&m_compiledOperator)));
@@ -792,8 +805,18 @@ namespace Dml
         graphDesc.NodeCount = operatorGraphDesc.nodeCount;
         for (size_t i = 0; i < graphDesc.NodeCount; ++i)
         {
+            DML_OPERATOR_DESC opDesc = *operatorGraphDesc.nodesAsOpDesc[i];
+
+            // TODO: Change as new header is ingested
+            if (opDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
+                opDesc.Type = (DML_OPERATOR_TYPE) 169;
+
+            // TODO: Change as new header is ingested
+            if (opDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
+                opDesc.Type = (DML_OPERATOR_TYPE) 170;
+
             // Create the operator.
-            ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodesAsOpDesc[i], IID_PPV_ARGS(&dmlOperators[i])));
+            ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&opDesc, IID_PPV_ARGS(&dmlOperators[i])));
             dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{dmlOperators[i].Get()};
             dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
         }
@@ -824,4 +847,80 @@ namespace Dml
         graphDesc.IntermediateEdges = dmlIntermediateEdges.data();
     }

+    /*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar(
+        const MLOperatorKernelCreationContext& kernelInfo,
+        const DML_TENSOR_DESC* tensor,
+        uint32_t kernelInputIndex)
+    {
+        if (!tensor)
+        {
+            return;
+        }
+
+        auto constExpTensor = kernelInfo.TryGetConstantInputTensor(kernelInputIndex);
+        if (!constExpTensor)
+        {
+            return;
+        }
+
+        uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount();
+        if (totalKernelInputElementCount <= 1)
+        {
+            return;
+        }
+
+        uint32_t elementSize = 0;
+
+        switch (constExpTensor->GetTensorDataType())
+        {
+        case MLOperatorTensorDataType::UInt8:
+        case MLOperatorTensorDataType::Int8:
+            elementSize = 1;
+            break;
+
+        case MLOperatorTensorDataType::Float16:
+        case MLOperatorTensorDataType::UInt16:
+        case MLOperatorTensorDataType::Int16:
+            elementSize = 2;
+            break;
+
+        case MLOperatorTensorDataType::/*Float32*/Float:
+        case MLOperatorTensorDataType::UInt32:
+        case MLOperatorTensorDataType::Int32:
+            elementSize = 4;
+            break;
+
+        case MLOperatorTensorDataType::/*Float64*/Double:
+        case MLOperatorTensorDataType::UInt64:
+        case MLOperatorTensorDataType::Int64:
+            elementSize = 8;
+            break;
+
+        default:
+            return;
+        }
+
+        const std::uint8_t* byteData = static_cast<const std::uint8_t*>(constExpTensor->GetByteData());
+
+        assert(tensor->Type == DML_TENSOR_TYPE_BUFFER);
+        auto* bufferTensorDesc = const_cast<DML_BUFFER_TENSOR_DESC*>(static_cast<const DML_BUFFER_TENSOR_DESC*>(tensor->Desc));
+
+        for (size_t i = 1; i < totalKernelInputElementCount; ++i)
+        {
+            if (memcmp(byteData, byteData + i * elementSize, elementSize))
+            {
+                return;
+            }
+        }
+
+        if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0]))
+        {
+            assert(false);
+            return;
+        }
+
+        bufferTensorDesc->Strides = zeroArray;
+        bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3;
+    }
+
 } // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
index c1e8cf42a974c..3995c3309bb92 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
@@ -149,6 +149,11 @@ namespace Dml
             uint32_t minDimensionCount = NchwDimensionCount
             ) const;

+        static void TryConvertTensorToBroadcastScalar(
+            const MLOperatorKernelCreationContext& kernelInfo,
+            const DML_TENSOR_DESC* tensor,
+            uint32_t kernelInputIndex);
+
     private:
         // For each input or output of the DML kernel, the corresponding input or output of the original
         // kernel. Entries for unused DML inputs are nullopt.
@@ -164,6 +169,7 @@ namespace Dml
             _Inout_ std::vector<DML_OUTPUT_GRAPH_EDGE_DESC>& dmlOutputEdges,
             _Inout_ std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC>& dmlIntermediateEdges);

+        static const uint32_t zeroArray[8];
     };
 } // namespace Dml
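The conversion above leans on a DirectML buffer-tensor convention: when every entry of `Strides` is zero, all logical elements alias byte offset 0, so a constant tensor whose elements are identical can be backed by a single physical element. Below is a standalone sketch of that descriptor shape; the sizes are hypothetical and nothing here is taken verbatim from the change:

```cpp
#include <DirectML.h>

// Sketch: a logical [1, 256, 1, 1] FLOAT32 tensor backed by one physical element.
// All-zero strides make DML read the same scalar at every logical coordinate,
// which is the state TryConvertTensorToBroadcastScalar rewrites descriptors into.
inline DML_BUFFER_TENSOR_DESC MakeBroadcastScalarDesc() {
  static const UINT sizes[4] = {1, 256, 1, 1};
  static const UINT strides[4] = {};  // zero strides: broadcast a single element

  DML_BUFFER_TENSOR_DESC desc = {};
  desc.DataType = DML_TENSOR_DATA_TYPE_FLOAT32;
  desc.DimensionCount = 4;
  desc.Sizes = sizes;
  desc.Strides = strides;
  // One 4-byte element; mirrors the (elementSize + 3) & ~3 rounding above,
  // which pads the total size up to DML's 4-byte minimum.
  desc.TotalTensorSizeInBytes = 4;
  return desc;
}
```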
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp
index 60b235880e23f..20163869a2aaf 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp
@@ -111,6 +111,8 @@ class DmlOperatorBatchNormalization15 : public DmlOperator, BatchNormalizationHe
         std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
         std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

+        // TODO: Port this to a graph description to enable DML graph optimization
+
         dml::Graph graph(m_dmlDevice.Get());
         dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X];
         dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale];
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
index 0eb14cea8dc9f..7336f2c7fab5d 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
@@ -585,6 +585,9 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
             opDesc.ScaleTensor = &inputDescs[1];
             opDesc.ZeroPointTensor = &inputDescs[2];
             opDesc.OutputTensor = &outputDescs[0];
+
+            TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
+            TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);

             SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
         }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp
index 7b50dfb9ff1ad..789ce5b5c56f1 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp
@@ -58,6 +58,15 @@ class DmlOperatorQLinearAdd : public DmlOperator
         AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
         AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
         AddDesc.OutputTensor = &outputDescs[0];
+
+        TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT);
+
+        TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT);
+
+        TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT);

         DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
         SetDmlOperatorDesc(opDesc, kernelInfo);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
index 0fccedfe311c1..605e5fffb6a76 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
@@ -118,8 +118,8 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe
         qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
         qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
         qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
-        qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
-        qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
+        qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];
+        qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];
         qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
         qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
         qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
@@ -129,6 +129,12 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe
         qLinearAvgPooldesc.Dilations = m_kernel.dilations;
         qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);

+        TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale);
+        TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint);
+
+        TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale);
+        TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint);
+
         DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
         SetDmlOperatorDesc(opDesc, kernelInfo);
     }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
index 67711fdc28b84..c97b03dc36b62 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
@@ -123,6 +123,9 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
             dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
             dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
             dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
+
+            TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ScaleTensor, tupleStartIndex + 1);
+            TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ZeroPointTensor, tupleStartIndex + 2);

             dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
             opDescs.push_back(&dmlOpDesc[inputIndex]);
@@ -154,6 +157,10 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
         quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
         quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
         quantizeOperatorDesc.OutputTensor = &outputDescs[0];
+
+        TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::YScale);
+        TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::YZeroPoint);
+
         const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
         opDescs.push_back(&opQuantizeDesc);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
index d45fdef3c8807..4e121a6502cba 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp
@@ -117,6 +117,15 @@ class DmlOperatorQLinearConv : public DmlOperator, public ConvolutionHelperBase
         convDesc.EndPadding = kernelArgs.endPadding;
         convDesc.GroupCount = m_groupCount;

+        TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputScaleTensor, IN_X_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputZeroPointTensor, IN_X_ZERO_POINT);
+
+        TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterScaleTensor, IN_F_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterZeroPointTensor, IN_F_ZERO_POINT);
+
+        TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputScaleTensor, IN_Y_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
+
         DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &convDesc };
         SetDmlOperatorDesc(opDesc, kernelInfo);
     }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp
index b746a0e81a5cf..b38acd8cbf978 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp
@@ -104,6 +104,15 @@ class DmlOperatorQLinearMatMul : public DmlOperator
         matMulDesc.OutputZeroPointTensor = inputDescs[IN_Y_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_Y_ZERO_POINT] : nullptr;
         matMulDesc.OutputTensor = &outputDescs[0];

+        TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AScaleTensor, IN_A_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AZeroPointTensor, IN_A_ZERO_POINT);
+
+        TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BScaleTensor, IN_B_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BZeroPointTensor, IN_B_ZERO_POINT);
+
+        TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputScaleTensor, IN_Y_SCALE);
+        TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT);
+
         DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, &matMulDesc };
         SetDmlOperatorDesc(opDesc, kernelInfo);
     }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
index 14326a46b2a64..84d86c23f71f4 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
@@ -88,6 +88,9 @@ class DmlOperatorQLinearSigmoid : public DmlOperator
         dequantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::X_scale];
         dequantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::X_zero_point];
         dequantizeOperatorDesc.OutputTensor = &namedIntermediateOutputTensorDesc;
+
+        TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ScaleTensor, OnnxInputIndex::X_scale);
+        TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::X_zero_point);

         const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDesc};
@@ -101,6 +104,10 @@ class DmlOperatorQLinearSigmoid : public DmlOperator
         quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::Y_scale];
         quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::Y_zero_point];
         quantizeOperatorDesc.OutputTensor = &outputDescs[0];
+
+        TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::Y_scale);
+        TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::Y_zero_point);
+
         const DML_OPERATOR_DESC opDesc3{DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};

         MLOperatorGraphDesc operatorGraphDesc = {};
diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc
index 6a2740e4369e9..424c0a43da974 100644
--- a/onnxruntime/core/providers/dml/dml_provider_factory.cc
+++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc
@@ -1,6 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

+#include <algorithm>
+#include <optional>
+
 #include <DirectML.h>
 #ifndef _GAMING_XBOX
 #include <dxgi1_6.h>
 #endif
@@ -92,11 +95,298 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
   return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId);
 }

+static bool IsHardwareAdapter(IDXCoreAdapter* adapter) {
+  bool is_hardware = false;
+  THROW_IF_FAILED(adapter->GetProperty(
+      DXCoreAdapterProperty::IsHardware,
+      &is_hardware));
+  return is_hardware;
+}
+
+static bool IsGPU(IDXCoreAdapter* compute_adapter) {
+  // Only considering hardware adapters
+  if (!IsHardwareAdapter(compute_adapter)) {
+    return false;
+  }
+  return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
+}
+
+static bool IsNPU(IDXCoreAdapter* compute_adapter) {
+  // Only considering hardware adapters
+  if (!IsHardwareAdapter(compute_adapter)) {
+    return false;
+  }
+  return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS));
+}
+
+enum class DeviceType { GPU, NPU, BadDevice };
+
+static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFilter filter) {
+  auto allow_gpus = (filter & OrtDmlDeviceFilter::Gpu) == OrtDmlDeviceFilter::Gpu;
+  if (IsGPU(adapter) && allow_gpus) {
+    return DeviceType::GPU;
+  }
+
+  auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
+  if (IsNPU(adapter) && allow_npus) {
+    return DeviceType::NPU;
+  }
+
+  return DeviceType::BadDevice;
+}
+
+// Struct for holding each adapter
+struct AdapterInfo {
+  ComPtr<IDXCoreAdapter> Adapter;
+  DeviceType Type;  // GPU or NPU
+};
+
+static ComPtr<IDXCoreAdapterList> EnumerateDXCoreAdapters(IDXCoreAdapterFactory* adapter_factory) {
+  ComPtr<IDXCoreAdapterList> adapter_list;
+
+  // TODO: use_dxcore_workload_enumeration should be determined by QI
+  // When DXCore APIs are available QI for relevant enumeration interfaces
+  constexpr bool use_dxcore_workload_enumeration = false;
+
+  if (!use_dxcore_workload_enumeration) {
+    // Get a list of all the adapters that support compute
+    GUID attributes[]{ DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
+    ORT_THROW_IF_FAILED(
+        adapter_factory->CreateAdapterList(_countof(attributes),
+                                           attributes,
+                                           adapter_list.GetAddressOf()));
+  }
+
+  return adapter_list;
+}
+
+static void SortDXCoreAdaptersByPreference(
+    IDXCoreAdapterList* adapter_list,
+    OrtDmlPerformancePreference preference) {
+  if (adapter_list->GetAdapterCount() <= 1) {
+    return;
+  }
+
+  // DML prefers the HighPerformance adapter by default
+  std::array<DXCoreAdapterPreference, 1> adapter_list_preferences = {
+    DXCoreAdapterPreference::HighPerformance
+  };
+
+  // If callers specify minimum power, change the DXCore sort policy.
+  // NOTE: DXCoreAdapterPreference does not apply to mixed adapter lists, only to GPU lists.
+  if (preference == OrtDmlPerformancePreference::MinimumPower) {
+    adapter_list_preferences[0] = DXCoreAdapterPreference::MinimumPower;
+  }
+
+  ORT_THROW_IF_FAILED(adapter_list->Sort(
+      static_cast<uint32_t>(adapter_list_preferences.size()),
+      adapter_list_preferences.data()));
+}
+
+static std::vector<AdapterInfo> FilterDXCoreAdapters(
+    IDXCoreAdapterList* adapter_list,
+    OrtDmlDeviceFilter filter) {
+  auto adapter_infos = std::vector<AdapterInfo>();
+  const uint32_t count = adapter_list->GetAdapterCount();
+  for (uint32_t i = 0; i < count; ++i) {
+    ComPtr<IDXCoreAdapter> candidate_adapter;
+    ORT_THROW_IF_FAILED(adapter_list->GetAdapter(i, candidate_adapter.GetAddressOf()));
+
+    // Add the adapters that are valid based on the device filter (GPU, NPU, or both)
+    auto adapter_type = FilterAdapterTypeQuery(candidate_adapter.Get(), filter);
+    if (adapter_type != DeviceType::BadDevice) {
+      adapter_infos.push_back(AdapterInfo{candidate_adapter, adapter_type});
+    }
+  }
+
+  return adapter_infos;
+}
+
+static void SortHeterogenousDXCoreAdapterList(
+    std::vector<AdapterInfo>& adapter_infos,
+    OrtDmlDeviceFilter filter,
+    OrtDmlPerformancePreference preference) {
+  if (adapter_infos.size() <= 1) {
+    return;
+  }
+
+  // When considering both GPUs and NPUs, sort them by the performance preference:
+  // Default (GPUs first), HighPerformance (GPUs first), or MinimumPower (NPUs first)
+  auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
+  auto only_npus = filter == OrtDmlDeviceFilter::Npu;
+  if (!keep_npus || only_npus) {
+    return;
+  }
+
+  struct SortingPolicy {
+    // default is false because GPUs are considered higher priority in
+    // a mixed adapter environment
+    bool npus_first_ = false;
+
+    SortingPolicy(bool npus_first = false) : npus_first_(npus_first) { }
+
+    bool operator()(const AdapterInfo& a, const AdapterInfo& b) {
+      // DeviceType::GPU orders before DeviceType::NPU, so ">" places NPUs first.
+      return npus_first_ ? a.Type > b.Type : a.Type < b.Type;
+    }
+  };
+
+  auto npus_first = (preference == OrtDmlPerformancePreference::MinimumPower);
+  auto policy = SortingPolicy(npus_first);
+  std::sort(adapter_infos.begin(), adapter_infos.end(), policy);
+}
+
 std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id) {
   return Create(device_id, /*skip_software_device_check*/ false);
 }

-std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromOptions(
+    OrtDmlDeviceOptions* device_options) {
+  auto default_device_options = OrtDmlDeviceOptions { Default, Gpu };
+  if (device_options == nullptr) {
+    device_options = &default_device_options;
+  }
+
+  OrtDmlPerformancePreference preference = device_options->Preference;
+  OrtDmlDeviceFilter filter = device_options->Filter;
+
+  // Create DXCore Adapter Factory
+  ComPtr<IDXCoreAdapterFactory> adapter_factory;
+  ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf()));
+
+  // Get all DML compatible DXCore adapters
+  ComPtr<IDXCoreAdapterList> adapter_list;
+  adapter_list = EnumerateDXCoreAdapters(adapter_factory.Get());
+
+  if (adapter_list->GetAdapterCount() == 0) {
+    ORT_THROW("No GPUs or NPUs detected.");
+  }
+
+  // Sort the adapter list to honor DXCore hardware ordering
+  SortDXCoreAdaptersByPreference(adapter_list.Get(), preference);
+
+  // TODO: use_dxcore_workload_enumeration should be determined by QI
+  // When DXCore APIs are available QI for relevant enumeration interfaces
+  constexpr bool use_dxcore_workload_enumeration = false;
+
+  std::vector<AdapterInfo> adapter_infos;
+  if (!use_dxcore_workload_enumeration) {
+    // Filter all DXCore adapters to the hardware type specified by the device filter
+    adapter_infos = FilterDXCoreAdapters(adapter_list.Get(), filter);
+    if (adapter_infos.size() == 0) {
+      ORT_THROW("No devices detected that match the filter criteria.");
+    }
+  }
+
+  // DXCore's sort ignores NPUs. When both GPUs and NPUs are present, sort them manually.
+  SortHeterogenousDXCoreAdapterList(adapter_infos, filter, preference);
+
+  // Extract just the adapters
+  auto adapters = std::vector<ComPtr<IDXCoreAdapter>>(adapter_infos.size());
+  std::transform(
+      adapter_infos.begin(), adapter_infos.end(),
+      adapters.begin(),
+      [](auto& a){ return a.Adapter; });
+
+  return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters));
+}
+
+static std::optional<OrtDmlPerformancePreference> ParsePerformancePreference(const ProviderOptions& provider_options) {
+  static const std::string PerformancePreference = "performance_preference";
+  static const std::string Default = "default";
+  static const std::string HighPerformance = "high_performance";
+  static const std::string MinimumPower = "minimum_power";
+
+  auto preference_it = provider_options.find(PerformancePreference);
+  if (preference_it != provider_options.end()) {
+    if (preference_it->second == Default) {
+      return OrtDmlPerformancePreference::Default;
+    }
+
+    if (preference_it->second == HighPerformance) {
+      return OrtDmlPerformancePreference::HighPerformance;
+    }
+
+    if (preference_it->second == MinimumPower) {
+      return OrtDmlPerformancePreference::MinimumPower;
+    }
+
+    ORT_THROW("Invalid PerformancePreference provided for DirectML EP device selection.");
+  }
+
+  return {};
+}
+
+static std::optional<OrtDmlDeviceFilter> ParseFilter(const ProviderOptions& provider_options) {
+  static const std::string Filter = "device_filter";
+  static const std::string Any = "any";
+  static const std::string Gpu = "gpu";
+  static const std::string Npu = "npu";
+
+  auto preference_it = provider_options.find(Filter);
+  if (preference_it != provider_options.end()) {
+    if (preference_it->second == Any) {
+      return OrtDmlDeviceFilter::Any;
+    }
+
+    if (preference_it->second == Gpu) {
+      return OrtDmlDeviceFilter::Gpu;
+    }
+
+    if (preference_it->second == Npu) {
+      return OrtDmlDeviceFilter::Npu;
+    }
+
+    ORT_THROW("Invalid Filter provided for DirectML EP device selection.");
+  }
+
+  return {};
+}
+
+static std::optional<int> ParseDeviceId(const ProviderOptions& provider_options) {
+  static const std::string DeviceId = "device_id";
+
+  auto preference_it = provider_options.find(DeviceId);
+  if (preference_it != provider_options.end()) {
+    if (!preference_it->second.empty()) {
+      return std::stoi(preference_it->second);
+    }
+  }
+
+  return {};
+}
+
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromProviderOptions(
+    const ProviderOptions& provider_options) {
+  auto device_id = ParseDeviceId(provider_options);
+  if (device_id.has_value()) {
+    return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value());
+  }
+
+  auto preference = ParsePerformancePreference(provider_options);
+  auto filter = ParseFilter(provider_options);
+
+  // If no preference/filters are specified then create with default preference/filters.
+  if (!preference.has_value() && !filter.has_value()) {
+    return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr);
+  }
+
+  if (!preference.has_value()) {
+    preference = OrtDmlPerformancePreference::Default;
+  }
+
+  if (!filter.has_value()) {
+    filter = OrtDmlDeviceFilter::Gpu;
+  }
+
+  OrtDmlDeviceOptions device_options;
+  device_options.Preference = preference.value();
+  device_options.Filter = filter.value();
+  return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options);
+}
+
+Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Device(
+    int device_id,
+    bool skip_software_device_check) {
 #ifdef _GAMING_XBOX
   ComPtr<ID3D12Device> d3d12_device;
   D3D12XBOX_CREATE_DEVICE_PARAMETERS params = {};
@@ -124,35 +414,65 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int
   ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
 #endif

-  D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
-  cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
-  cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
-
-  ComPtr<ID3D12CommandQueue> cmd_queue;
-  ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
+  return d3d12_device;
+}

+Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) {
   DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE;

   // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled
 #if _DEBUG && !_GAMING_XBOX
-  ComPtr<ID3D12DebugDevice> debug_device;
+  Microsoft::WRL::ComPtr<ID3D12DebugDevice> debug_device;
   (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device));  // ignore failure
   const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr);

   if (is_d3d12_debug_layer_enabled) {
     flags |= DML_CREATE_DEVICE_FLAG_DEBUG;
   }
 #endif

-  ComPtr<IDMLDevice> dml_device;
-  ORT_THROW_IF_FAILED(DMLCreateDevice1(d3d12_device.Get(),
-                                       flags,
-                                       DML_FEATURE_LEVEL_5_0,
-                                       IID_PPV_ARGS(&dml_device)));
+  Microsoft::WRL::ComPtr<IDMLDevice> dml_device;
+  ORT_THROW_IF_FAILED(DMLCreateDevice1(
+      d3d12_device,
+      flags,
+      DML_FEATURE_LEVEL_5_0,
+      IID_PPV_ARGS(&dml_device)));
+
+  return dml_device;
+}
+
+std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) {
+  D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
+  cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
+  cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
+
+  ComPtr<ID3D12CommandQueue> cmd_queue;
+  ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));

+  auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device);
   return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
 }

+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
+  ComPtr<ID3D12Device> d3d12_device = CreateD3D12Device(device_id, skip_software_device_check);
+  return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
+}
+
+std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromAdapterList(
+    std::vector<ComPtr<IDXCoreAdapter>>&& dxcore_devices) {
+  // Choose the first device from the list since it's the highest priority
+  auto dxcore_device = dxcore_devices[0];
+
+  // Create D3D12 Device from DXCore Adapter
+  ComPtr<ID3D12Device> d3d12_device;
+  ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
+
+  return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
+}
+
 }  // namespace onnxruntime

 // [[deprecated]]
@@ -197,6 +517,17 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) {
   API_IMPL_END
 }

+ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) {
+  API_IMPL_BEGIN
+#ifdef USE_DML
+  auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options);
+  // return the create function for a dxcore device
+  options->provider_factories.push_back(factory);
+#endif  // USE_DML
+  return nullptr;
+  API_IMPL_END
+}
+
 ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
   API_IMPL_BEGIN
 #ifdef USE_DML
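Together, `ParseDeviceId`, `ParsePerformancePreference`, and `ParseFilter` define the string surface that `CreateFromProviderOptions` accepts. A short sketch of how those keys look through the public C++ API (illustrative only; assumes a DML-enabled build):

```cpp
#include <onnxruntime_cxx_api.h>

// Sketch: the provider-option keys recognized by CreateFromProviderOptions.
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider("DML", {
    {"performance_preference", "minimum_power"},  // "default" | "high_performance" | "minimum_power"
    {"device_filter", "npu"},                     // "any" | "gpu" | "npu"
});
// Passing "device_id" instead selects an explicit adapter index and takes the
// legacy Create(device_id) path, bypassing preference/filter selection.
```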
diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
index 3c3f04561d383..330055d64fd30 100644
--- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
+++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h
@@ -5,11 +5,30 @@

 #include <memory>

+#include <string>
+#include <vector>
+#include "core/framework/provider_options.h"
 #include "core/providers/providers.h"

+#include <wrl/client.h>
+#include <dxcore.h>
+
+interface IDMLDevice;
+struct OrtDmlDeviceOptions;
+
 namespace onnxruntime {
 struct DMLProviderFactoryCreator {
   static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
   static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
+
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromProviderOptions(
+      const ProviderOptions& provider_options_map);
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromOptions(OrtDmlDeviceOptions* device_options);
+
+  static std::shared_ptr<IExecutionProviderFactory> CreateFromAdapterList(
+      std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices);
+
+  static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
+  static Microsoft::WRL::ComPtr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12_device);
 };
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 8b32ec05719b6..2618a6a28127d 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -11,6 +11,10 @@
 #include "core/session/onnxruntime_c_api.h"
 #include "core/session/ort_apis.h"

+#if defined(USE_DML)
+#include "core/providers/dml/dml_provider_factory_creator.h"
+#endif
+
 using namespace onnxruntime;

 namespace {
@@ -66,7 +70,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
         (std::string(provider_name) + " execution provider is not supported in this build. ").c_str());
   };

-  if (strcmp(provider_name, "QNN") == 0) {
+  if (strcmp(provider_name, "DML") == 0) {
+#if defined(USE_DML)
+    options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options));
+#else
+    status = create_not_supported_status();
+#endif
+  } else if (strcmp(provider_name, "QNN") == 0) {
 #if defined(USE_QNN)
     options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value)));
 #else
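The same registration is reachable through the plain C API entry point this file extends; a sketch with error handling elided (the helper name is hypothetical):

```cpp
#include <onnxruntime_c_api.h>

// Sketch: append the DML EP via OrtApi::SessionOptionsAppendExecutionProvider,
// which routes the key/value pairs to DMLProviderFactoryCreator::CreateFromProviderOptions.
void AppendDmlViaCApi(OrtSessionOptions* session_options) {
  const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  const char* keys[] = {"performance_preference", "device_filter"};
  const char* values[] = {"high_performance", "gpu"};
  ort_api->SessionOptionsAppendExecutionProvider(session_options, "DML", keys, values, 2);
}
```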
").c_str()); }; - if (strcmp(provider_name, "QNN") == 0) { + if (strcmp(provider_name, "DML") == 0) { +#if defined(USE_DML) + options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); +#else + status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "QNN") == 0) { #if defined(USE_QNN) options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value))); #else diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 12b020f32b22f..ae31e7da9c23f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -797,18 +797,10 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML - int device_id = 0; - auto it = provider_options_map.find(type); - if (it != provider_options_map.end()) { - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - device_id = std::stoi(option.second); - } - } - } - } - return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider(); + auto cit = provider_options_map.find(type); + return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions( + cit == provider_options_map.end() ? ProviderOptions{} : cit->second) + ->CreateProvider(); #endif } else if (type == kNnapiExecutionProvider) { #if defined(USE_NNAPI) diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index d283d9df62d6a..e3031ccc02a50 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -677,9 +677,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0)); + std::unordered_map dml_options; + dml_options["performance_preference"] = "high_performance"; + dml_options["device_filter"] = "gpu"; + session_options.AppendExecutionProvider("DML", dml_options); #else - ORT_THROW("DirectML is not supported in this build\n"); + ORT_THROW("DML is not supported in this build\n"); #endif } else if (provider_name == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL diff --git a/packages.config b/packages.config index b2c918c414ccc..54c8c14872fc1 100644 --- a/packages.config +++ b/packages.config @@ -1,6 +1,6 @@  - + diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index a4e00b92823cd..f1f62738843a3 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -192,7 +192,7 @@ def generate_repo_url(line_list, repo_url, commit_id): def generate_dependencies(xml_text, package_name, version): - dml_dependency = '' + dml_dependency = '' if package_name == "Microsoft.AI.MachineLearning": xml_text.append("")