From c3d96a7b35c975a4eb4ad5a4c94349797defbc78 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Tue, 2 Jan 2024 18:06:26 -0800 Subject: [PATCH 01/25] Update DML version to 1.13.0 (#18978) Update DML nuget version to 1.13.0 --- .pipelines/nuget_config/x64/packages.config | 2 +- .pipelines/nuget_config/x86/packages.config | 2 +- cmake/external/dml.cmake | 2 +- packages.config | 2 +- tools/nuget/generate_nuspec_for_native_nuget.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 2ac650b0e6dc9..2583e0d1b2ead 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index f80f96194a230..5ca659941c159 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index 5d25b9529e030..d777306722cd6 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.1) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0) # Restore nuget packages, which will pull down the DirectML redist package. add_custom_command( diff --git a/packages.config b/packages.config index da61a10adfa74..b67219d6d6913 100644 --- a/packages.config +++ b/packages.config @@ -1,6 +1,6 @@  - + diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 66248565a3e3a..56e50750ac153 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -219,7 +219,7 @@ def add_common_dependencies(xml_text, package_name, version): def generate_dependencies(xml_text, package_name, version): - dml_dependency = '' + dml_dependency = '' if package_name == "Microsoft.AI.MachineLearning": xml_text.append("") From 9bbe425d7f805d4dfaad2e69c3edca40063fe673 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Thu, 27 Jul 2023 19:13:15 -0700 Subject: [PATCH 02/25] Register LPpool18 and AvgPool 19 (#16880) --- .../src/External/DirectMLHelpers/ApiTraits.h | 30 ++++++++++++++ .../External/DirectMLHelpers/DirectMLSchema.h | 40 +++++++++++++++++++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 39 ++++++++++++++++++ .../src/Operators/DmlOperatorPooling.cpp | 40 +++++++++++++++++-- .../src/Operators/OperatorRegistration.cpp | 2 + .../OperatorAuthorHelper/OperatorVersions.h | 6 +++ .../test/providers/cpu/nn/pool_op_test.cc | 25 ------------ 7 files changed, 153 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index c75b662af788d..94f2220fcc168 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -459,12 +459,24 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING1; +}; + template <> struct OperatorDescTraits { static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING1; +}; + template <> struct OperatorDescTraits { @@ -1448,12 +1460,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING> using DescType = DML_AVERAGE_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING1> +{ + using DescType = DML_AVERAGE_POOLING1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING> { using DescType = DML_LP_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING1> +{ + using DescType = DML_LP_POOLING1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING> { @@ -2259,8 +2283,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_AVERAGE_POOLING: return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_AVERAGE_POOLING1: + return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_LP_POOLING: return std::invoke(std::forward(visitor), DML_LP_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_LP_POOLING1: + return std::invoke(std::forward(visitor), DML_LP_POOLING1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MAX_POOLING: return std::invoke(std::forward(visitor), DML_MAX_POOLING_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MAX_POOLING1: @@ -2554,7 +2582,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; + case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1"; case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; + case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1"; case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 1ebd52d4ed427..9eae1c1fe8158 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -757,6 +757,26 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_OPERATOR_SCHEMA { DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING1_OPERATOR_SCHEMA { + "DML_OPERATOR_AVERAGE_POOLING1", + DML_OPERATOR_AVERAGE_POOLING1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -776,6 +796,26 @@ constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING_OPERATOR_SCHEMA { DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING1_OPERATOR_SCHEMA { + "DML_OPERATOR_LP_POOLING1", + DML_OPERATOR_LP_POOLING1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS[7] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 833871de0bbd9..ad4cceb85cfd2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -425,6 +425,21 @@ inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_D OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } + +inline std::vector GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& desc) { return { @@ -438,6 +453,20 @@ inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.P))), }; } +inline std::vector GetFields(const DML_LP_POOLING1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.P))), + }; +} inline std::vector GetFields(const DML_MAX_POOLING_OPERATOR_DESC& desc) { return { @@ -1684,7 +1713,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA; case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA; case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_AVERAGE_POOLING1: return DML_AVERAGE_POOLING1_OPERATOR_SCHEMA; case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_LP_POOLING1: return DML_LP_POOLING1_OPERATOR_SCHEMA; case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA; case DML_OPERATOR_MAX_POOLING1: return DML_MAX_POOLING1_OPERATOR_SCHEMA; case DML_OPERATOR_ROI_POOLING: return DML_ROI_POOLING_OPERATOR_SCHEMA; @@ -2002,10 +2033,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_AVERAGE_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_AVERAGE_POOLING1: + return AbstractOperatorDesc( + &DML_AVERAGE_POOLING1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_LP_POOLING: return AbstractOperatorDesc( &DML_LP_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_LP_POOLING1: + return AbstractOperatorDesc( + &DML_LP_POOLING1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_MAX_POOLING: return AbstractOperatorDesc( &DML_MAX_POOLING_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp index e8d5b2746aa13..10ff1d8be8a29 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp @@ -34,7 +34,7 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase kernelOutputIndices.emplace_back(1); } DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices); - + std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1."); @@ -98,6 +98,21 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase SetOpDesc(desc); break; } + case DML_OPERATOR_AVERAGE_POOLING1: + { + if (hasDilations) { + DML_AVERAGE_POOLING1_OPERATOR_DESC desc = {}; + desc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + desc.Dilations = m_kernel.dilations; + SetOpDesc(desc); + } + else { + DML_AVERAGE_POOLING_OPERATOR_DESC desc = {}; + desc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + SetOpDesc(desc); + } + break; + } case DML_OPERATOR_LP_POOLING: { DML_LP_POOLING_OPERATOR_DESC desc = {}; @@ -106,6 +121,23 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase SetOpDesc(desc); break; } + case DML_OPERATOR_LP_POOLING1: + { + if (hasDilations) { + DML_LP_POOLING1_OPERATOR_DESC desc = {}; + desc.P = kernelInfo.GetOptionalAttribute(AttrName::P, 2); + ML_CHECK_VALID_ARGUMENT(desc.P > 0); + desc.Dilations = m_kernel.dilations; + SetOpDesc(desc); + } + else { + DML_LP_POOLING_OPERATOR_DESC desc = {}; + desc.P = kernelInfo.GetOptionalAttribute(AttrName::P, 2); + ML_CHECK_VALID_ARGUMENT(desc.P > 0); + SetOpDesc(desc); + } + break; + } case DML_OPERATOR_MAX_POOLING: case DML_OPERATOR_MAX_POOLING1: case DML_OPERATOR_MAX_POOLING2: @@ -152,7 +184,7 @@ class DmlOperatorPoolingTemplate : public DmlOperatorPooling void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) { *isSupported = false; - + MLOperatorAttributes attributes(context); int storageOrder = attributes.GetOptionalAttribute(AttrName::StorageOrder, 0); @@ -164,11 +196,11 @@ void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* *isSupported = true; } -DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalAveragePool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(MaxPool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalMaxPool, DmlOperatorPoolingTemplate); -DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalLpPool, DmlOperatorPoolingTemplate); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 28360f09bcba3..dbe9f5da4f569 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -667,6 +667,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 19, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, @@ -677,6 +678,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 18, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index e18ba31def48a..3eb35faeba82f 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -406,6 +406,12 @@ namespace OperatorHelper static const int sc_sinceVer_BitwiseNot = 18; static const int sc_sinceVer_Pad = 18; static const int sc_sinceVer_Split = 18; + static const int sc_sinceVer_LpPool = 18; + } + + namespace OnnxOperatorSet19 + { + static const int sc_sinceVer_AveragePool = 19; } namespace MsftOperatorSet1 diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 10476ada2fa69..4b194ec18b31b 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -777,11 +777,6 @@ TEST(PoolTest, GlobalMaxPool3D) { } TEST(PoolTest, AveragePool) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool"); test.AddAttribute("auto_pad", ""); @@ -863,11 +858,6 @@ TEST(PoolTest, AveragePool) { } TEST(PoolTest, AveragePool_IncludePadPixel) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool"); test.AddAttribute("auto_pad", ""); @@ -911,11 +901,6 @@ TEST(PoolTest, AveragePool_DefaultStrides) { } TEST(PoolTest, AveragePool_10_ceil1_2d) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool", 10); test.AddAttribute("auto_pad", ""); @@ -939,11 +924,6 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) { } TEST(PoolTest, AveragePool_19_dilation_2d) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool", 19); test.AddAttribute("auto_pad", ""); @@ -1070,11 +1050,6 @@ TEST(PoolTest, GlobalAveragePool_Large_256) { } TEST(PoolTest, LpPool) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("LpPool"); test.AddAttribute("auto_pad", ""); From 9ff5e3b7b0f5e9683422fa3609175fe4f82ccb77 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 3 Nov 2023 09:34:35 -0700 Subject: [PATCH 03/25] Add QLinearConcat for DML EP (#16971) (#18268) ### Description [Cherry Pick Reviewed] ``` [ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms) [ RUN ] QLinearConcatS8.InputOne_Dynamic [ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms) [ RUN ] QLinearConcatS8.InputOne_Const [ OK ] QLinearConcatS8.InputOne_Const (255 ms) [----------] 11 tests from QLinearConcatS8 (3385 ms total) [----------] Global test environment tear-down [==========] 21 tests from 3 test suites ran. (9355 ms total) [ PASSED ] 21 tests. ``` [#16971](https://github.com/microsoft/onnxruntime/pull/16971) ### Motivation and Context Co-authored-by: Xiang Zhang --- .../Operators/DmlOperatorQLinearConcat.cpp | 236 ++++++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 16 +- .../src/Operators/OperatorUtility.cpp | 4 +- .../src/Operators/OperatorUtility.h | 3 +- .../OperatorAuthorHelper/OperatorHelper.cpp | 18 +- .../dml/OperatorAuthorHelper/OperatorHelper.h | 25 +- .../OperatorAuthorHelper/OperatorVersions.h | 1 + 7 files changed, 290 insertions(+), 13 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp new file mode 100644 index 0000000000000..67711fdc28b84 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -0,0 +1,236 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ +// QLinearConcat = Dequantize + Join + Quantize +class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper +{ + // This order matches the ONNX schema. + enum OnnxInputIndex + { + YScale, + YZeroPoint, + Count, + }; + +public: + DmlOperatorQLinearConcat(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext), + QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription()) + { + DmlOperator::Initialize(kernelCreationContext); + + auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0); + + // inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)} + uint32_t inputDefinitionCount = kernelCreationContext.GetInputCount(); + ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs."); + ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!"); + + uint32_t inputCount = (inputDefinitionCount - 2) / 3; + + auto yScaleDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YScale).tensorDataType; + auto yZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YZeroPoint).tensorDataType; + + // broadcast y_scale and y_zero_point to output shape + m_inputTensorDescs[OnnxInputIndex::YScale] = TensorDesc( + yScaleDataType, + outputShape, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YScale), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + m_inputTensorDescs[OnnxInputIndex::YZeroPoint] = TensorDesc( + yZeroPointDataType, + outputShape, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YZeroPoint), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + // Validate input tensors + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + // Inputs(input tensor, scale, zero_point) are in tuple and starting from index 2 + auto tupleStartIndex = 2 + inputIndex * 3; + auto xScaleDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 1).tensorDataType; + auto xZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 2).tensorDataType; + ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale"); + ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point"); + + // broadcast x_scale and x_zero_point to shape of corresponding x + m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc( + xScaleDataType, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 1), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc( + xZeroPointDataType, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 2), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + } + + uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount(), 2); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + // 1. output edges between Dequantize and Join node + // 2. input edge between Join and Quantize node + std::vector intermediateOutputTensorDescs(inputCount); + std::vector namedDequantizeOperatorDescs(inputCount); + std::vector dequantizeOperatorDescs(inputCount); + std::vector dmlOpDesc(inputCount); + std::vector opDescs; + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + auto tupleStartIndex = 2 + inputIndex * 3; + intermediateOutputTensorDescs[inputIndex] = TensorDesc( + MLOperatorTensorDataType::Float, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment) + ); + namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc(); + + dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex]; + dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1]; + dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2]; + dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex]; + + dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]}; + opDescs.push_back(&dmlOpDesc[inputIndex]); + } + + TensorDesc joinOutputTensorDesc = TensorDesc( + MLOperatorTensorDataType::Float, + outputShape, + outputShape, + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + DML_TENSOR_DESC namedJoinOutputTensorDesc = joinOutputTensorDesc.GetDmlDesc(); + + DML_JOIN_OPERATOR_DESC joinDesc = {}; + joinDesc.InputCount = gsl::narrow_cast(namedDequantizeOperatorDescs.size()); + joinDesc.InputTensors = namedDequantizeOperatorDescs.data(); + joinDesc.OutputTensor = &namedJoinOutputTensorDesc; + joinDesc.Axis = dmlAxis; + + const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc}; + opDescs.push_back(&opJoinDesc); + + DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC quantizeOperatorDesc = {}; + quantizeOperatorDesc.InputTensor = joinDesc.OutputTensor; + quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale]; + quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint]; + quantizeOperatorDesc.OutputTensor = &outputDescs[0]; + const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc}; + opDescs.push_back(&opQuantizeDesc); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.nodeCount = static_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2; + uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1; + + std::vector inputEdges; + // Input edges to Dequantize nodes + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + auto tupleStartIndex = 2 + inputIndex * 3; + for (auto edge_index = 0; edge_index < 3; ++edge_index) + { + DML_INPUT_GRAPH_EDGE_DESC inputEdge = {}; + inputEdge.GraphInputIndex = tupleStartIndex + edge_index; + inputEdge.ToNodeIndex = inputIndex; + inputEdge.ToNodeInputIndex = edge_index; + inputEdges.push_back(inputEdge); + } + } + + // Input edge from y_scale to quantize node + DML_INPUT_GRAPH_EDGE_DESC yScaleInputEdge = {}; + yScaleInputEdge.GraphInputIndex = 0; // Y_scale + yScaleInputEdge.ToNodeIndex = quantizeNodeIndex; + yScaleInputEdge.ToNodeInputIndex = 1; + inputEdges.push_back(yScaleInputEdge); + + // Input edge from y_zero_point to quantize node + DML_INPUT_GRAPH_EDGE_DESC yZeroPointInputEdge = {}; + yZeroPointInputEdge.GraphInputIndex = 1; // Y_zero_point + yZeroPointInputEdge.ToNodeIndex = quantizeNodeIndex; + yZeroPointInputEdge.ToNodeInputIndex = 2; + inputEdges.push_back(yZeroPointInputEdge); + + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + + // set intermediate edges + std::vector intermediateEdges; + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC dequantizeToJoinEdge = {}; + dequantizeToJoinEdge.FromNodeIndex = inputIndex; + dequantizeToJoinEdge.FromNodeOutputIndex = 0; + dequantizeToJoinEdge.ToNodeIndex = joinNodeIndex; // The second last node Join + dequantizeToJoinEdge.ToNodeInputIndex = inputIndex; + intermediateEdges.push_back(dequantizeToJoinEdge); + } + + DML_INTERMEDIATE_GRAPH_EDGE_DESC joinToQuantizeEdge = {}; + joinToQuantizeEdge.FromNodeIndex = joinNodeIndex; + joinToQuantizeEdge.FromNodeOutputIndex = 0; + joinToQuantizeEdge.ToNodeIndex = quantizeNodeIndex; // The second last node Join + joinToQuantizeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(joinToQuantizeEdge); + + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + + // set the output edges + std::vector outputEdges; + DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; + outputEdge.FromNodeIndex = quantizeNodeIndex; + outputEdge.FromNodeOutputIndex = 0; + outputEdge.GraphOutputIndex = 0; + outputEdges.push_back(outputEdge); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + }; +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QLinearConcat, DmlOperatorQLinearConcat); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index dbe9f5da4f569..fa2750a22425f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -496,6 +496,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ScatterND); DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd); DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv); DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat); DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); @@ -547,6 +548,7 @@ constexpr static std::array typeNameListEyeLike = { "T1", "T2" } constexpr static std::array typeNameShape = { "T", "T1" }; constexpr static std::array typeNameSize = { "T", "T1" }; constexpr static std::array typeNameListGroupNorm = {"T", "M"}; +constexpr static std::array typeNameListQLinearConcat= {"TF", "T8", "TV"}; constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All}; constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32}; @@ -618,7 +620,18 @@ constexpr static std::array supportedTypeListQLinea constexpr static std::array supportedTypeListDynamicQuantizeLinear = { SupportedTensorDataTypes::Float32, - SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Ints8Bit +}; + +constexpr static std::array supportedTypeListDynamicQuantizeMatMul= { + SupportedTensorDataTypes::Float32, + SupportedTensorDataTypes::Ints8Bit, +}; + +constexpr static std::array supportedTypeListQLinearConcat= { + SupportedTensorDataTypes::Float32, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, }; template @@ -1012,6 +1025,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)}, {REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)}, {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index d8290bbdaee3e..2965fa32ce131 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -419,9 +419,9 @@ namespace Dml } // namespace FusionHelpers - uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount) + uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex) { - const std::vector inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); + const std::vector inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(firstInputIndex); uint32_t onnxDimCount = gsl::narrow_cast(inputDimensions.size()); onnxAxis = HandleNegativeAxis(onnxAxis, onnxDimCount); return GetDmlAdjustedAxis(onnxAxis, onnxDimCount, dmlDimCount); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h index f0fad6a05ffb0..8b2da6084242d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h @@ -64,8 +64,7 @@ namespace Dml } // namespace FusionHelpers // Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced. - // Note this function presumes the axis attribute is relative to the first input tensor (which is always the case). - uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount); + uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex = 0); uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, uint32_t onnxDimCount, uint32_t dmlDimCount); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 370f336ff5203..4d59964dcc664 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -1862,7 +1862,7 @@ namespace OperatorHelper return { std::move(outputShape) }; } - void ConcatHelper::Initialize( + void ConcatHelperBase::Initialize( const MLOperatorAttributes& operatorAttributes, gsl::span inputDimensions ) @@ -1872,13 +1872,13 @@ namespace OperatorHelper ML_CHECK_VALID_ARGUMENT(m_axis < static_cast(inputDimensions.size())); } - std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + std::vector ConcatHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const { - auto outputShape = shapeInfo.GetInputTensorShape(0); + auto outputShape = shapeInfo.GetInputTensorShape(firstInputIndex); uint32_t inputCount = shapeInfo.GetInputCount(); - for (uint32_t i = 1; i < inputCount; ++i) + for (uint32_t i = firstInputIndex + step; i < inputCount; i += step) { auto inputShape = shapeInfo.GetInputTensorShape(i); for (size_t j = 0; j < outputShape.size(); ++j) @@ -1893,6 +1893,16 @@ namespace OperatorHelper return { EdgeShapes(outputShape) }; } + std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + return ConcatHelperBase::GetOutputShapes(shapeInfo, 0, 1); + } + + std::vector QLinearConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + return ConcatHelperBase::GetOutputShapes(shapeInfo, 2, 3); + } + void CropHelper::Initialize( const MLOperatorAttributes& operatorAttributes, gsl::span inputDimensions diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index f7e545d9d99a9..55a01c59ee4b5 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -864,7 +864,7 @@ class RecurrentHelper int m_hiddenSize = 0; }; -class ConcatHelper +class ConcatHelperBase { public: void Initialize( @@ -875,17 +875,33 @@ class ConcatHelper // Info_t is used to obtain attributes which will be used for calculating the output shape later. // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template - ConcatHelper(const Info_t& info, const Shape_t& shape) + ConcatHelperBase(const Info_t& info, const Shape_t& shape, uint32_t firstInputIndex) { - Initialize(info, shape.GetInputTensorShape(0)); + Initialize(info, shape.GetInputTensorShape(firstInputIndex)); } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const; protected: int m_axis; }; +class ConcatHelper: public ConcatHelperBase +{ +public: + template + ConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 0) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; +}; + +class QLinearConcatHelper: public ConcatHelperBase +{ +public: + template + QLinearConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 2) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; +}; + class CropHelper { public: @@ -1512,6 +1528,7 @@ using ShapeInferenceHelper_Split13 = VersionedOpsetHelper; using ShapeInferenceHelper_Split18 = VersionedOpsetHelper; using ShapeInferenceHelper_Transpose = TransposeHelper; using ShapeInferenceHelper_Concat = ConcatHelper; +using ShapeInferenceHelper_QLinearConcat = QLinearConcatHelper; using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper; using ShapeInferenceHelper_Slice10 = VersionedOpsetHelper; using ShapeInferenceHelper_Slice11 = VersionedOpsetHelper; // Note 11 and 10 are identical - no functional change. diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 3eb35faeba82f..996ea1ddcb52c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -443,6 +443,7 @@ namespace OperatorHelper static const int sc_sinceVer_BiasAdd = 1; static const int sc_sinceVer_QuickGelu = 1; static const int sc_sinceVer_GroupNorm = 1; + static const int sc_sinceVer_QLinearConcat = 1; static const int sc_sinceVer_RotaryEmbedding = 1; } // namespace MsftOperatorSet1 From cb7f28a16ab78601363c2e694679d2dada149dd1 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 3 Nov 2023 09:43:49 -0700 Subject: [PATCH 04/25] Register Resize for INT8 and UINT8 (#18252) ### Description ### Motivation and Context Co-authored-by: Adrian Tsai --- .../DmlExecutionProvider/src/Operators/OperatorRegistration.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index fa2750a22425f..d7910a6c6849f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -589,7 +589,7 @@ constexpr static std::array supportedTypeListLogica constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int64 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32}; -constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */}; +constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */}; constexpr static std::array supportedTypeListResize13 = supportedTypeListResize11; constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; From dcfff10f57fbbdf81ea9ebcae008be712be42700 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:09:11 -0800 Subject: [PATCH 05/25] Enable QLinearAveragePooling DML EP (#17384) (#18240) [Cherry Pick Reviewed] DML EP Implementation for [QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool) ``` Note: Google Test filter = *QLinear*Pool* [==========] Running 72 tests from 2 test suites. [----------] Global test environment set-up. [----------] 36 tests from QLinearGlobalAveragePool [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms) [----------] 36 tests from QLinearGlobalAveragePool (5968 ms total) [----------] 36 tests from QLinearPoolTest [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage [ OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global [ OK ] QLinearPoolTest.AveragePool2D_Global (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms) [----------] 36 tests from QLinearPoolTest (12914 ms total) [----------] Global test environment tear-down [==========] 72 tests from 2 test suites ran. (18885 ms total) [ PASSED ] 72 tests. memleakdbg: ----- No memory leaks detected ----- ``` ### Description ### Motivation and Context --- .../src/External/DirectMLHelpers/ApiTraits.h | 20 ++- .../External/DirectMLHelpers/DirectMLSchema.h | 25 +++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 26 +++ .../DmlOperatorQLinearAveragePooling.cpp | 150 ++++++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 8 + .../DmlExecutionProvider/src/TensorDesc.cpp | 36 +++++ .../dml/DmlExecutionProvider/src/TensorDesc.h | 3 + .../dml/OperatorAuthorHelper/Attributes.h | 2 +- .../OperatorAuthorHelper/OperatorHelper.cpp | 42 ++++- .../dml/OperatorAuthorHelper/OperatorHelper.h | 28 +++- .../OperatorAuthorHelper/OperatorVersions.h | 2 + .../qlinear_global_average_pool_test.cc | 3 + 12 files changed, 339 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 94f2220fcc168..a5415ba85f3d3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,7 +24,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 160; + static constexpr auto ValueCount = 161; static constexpr size_t ActivationFunctionCount = 24; }; @@ -495,6 +495,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; +}; + template <> struct OperatorDescTraits { @@ -1496,6 +1502,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> using DescType = DML_ROI_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> +{ + using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> { @@ -2522,6 +2534,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args case DML_OPERATOR_ACTIVATION_GELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...); +#pragma warning(push) +#pragma warning(disable: 4063) + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); +#pragma warning(pop) + default: ORT_THROW_HR(E_INVALIDARG); return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 9eae1c1fe8158..2a82c12872a72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -869,6 +869,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, }; + +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", + static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 13, + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index ad4cceb85cfd2..99218c135f058 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -502,6 +502,24 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), }; } +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) { return { @@ -2509,6 +2527,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_GELU_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); +#pragma warning(push) +#pragma warning(disable: 4063) + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); +#pragma warning(pop) + default: ORT_THROW_HR(E_INVALIDARG); return AbstractOperatorDesc( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp new file mode 100644 index 0000000000000..0fccedfe311c1 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase +{ + // For QLinear Avg Pool ORT and DML have same indexing order + enum OrtInputTensors : uint32_t + { + ortInput, + ortInputScale, + ortInputZeroPoint, + ortOutputScale, + ortOutputZeroPoint, + ortInputCount + }; + +public: + using Self = DmlOperatorQLinearAveragePooling; + + DmlOperatorQLinearAveragePooling( + const MLOperatorKernelCreationContext& kernelInfo, + bool useGlobalPooling + ) + : DmlOperator(kernelInfo), + PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling) + { + DmlOperator::Initialize(kernelInfo); + + bool isNhwc = m_kernel.channelsLast; + std::vector inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount(); + ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2); + + // DML requires that DimensionCount be equal to Input.dmlDimSize - 2 for Pooling + uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2; + if (m_kernel.spatialDimensionCount < expectedSpatialDimCount) + { + size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount; + + for (int i = gsl::narrow_cast(m_kernel.spatialDimensionCount) - 1; i >= 0; i--) + { + m_kernel.windowSize[i + shift] = m_kernel.windowSize[i]; + m_kernel.windowSize[i] = 1; + + m_kernel.strides[i + shift] = m_kernel.strides[i]; + m_kernel.strides[i] = 1; + + m_kernel.startPadding[i + shift] = m_kernel.startPadding[i]; + m_kernel.startPadding[i] = 0; + + m_kernel.endPadding[i + shift] = m_kernel.endPadding[i]; + m_kernel.endPadding[i] = 0; + + m_kernel.dilations[i + shift] = m_kernel.dilations[i]; + m_kernel.dilations[i] = 1; + } + + m_kernel.spatialDimensionCount = expectedSpatialDimCount; + } + + // Initialize dimensionMapping for NCHW or NHWC layout + std::vector dimensionMapping = {0u, dmlDimSize - 1u}; + dimensionMapping.resize(dmlDimSize); + if (isNhwc) + { + // Form a remapping for dimensions so C is moved before the spatial dimensions. + // e.g. NWC -> {0,2,1} -> NCW + // NHWC -> {0,3,1,2} -> NCHW + // NDHWC -> {0,4,1,2,3} -> NCDHW + std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u); + } + else + { + // Use NCHW {0,1,2,3} format with increasing order of indexs + std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u); + } + m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint)) + { + m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + } + + // Reshape the Output Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint)) + { + m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + } + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize)); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {}; + + qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput]; + qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale]; + qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint]; + qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];; + qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];; + qLinearAvgPooldesc.OutputTensor = &outputDescs[0]; + qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount; + qLinearAvgPooldesc.WindowSize = m_kernel.windowSize; + qLinearAvgPooldesc.Strides = m_kernel.strides; + qLinearAvgPooldesc.StartPadding = m_kernel.startPadding; + qLinearAvgPooldesc.EndPadding = m_kernel.endPadding; + qLinearAvgPooldesc.Dilations = m_kernel.dilations; + qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +template +class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling +{ +public: + DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling) + { + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index d7910a6c6849f..0234bb6b7ec1e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -320,6 +320,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool); DML_OP_EXTERN_CREATION_FUNCTION(LpPool); DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool); DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16); DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization); @@ -634,6 +636,10 @@ constexpr static std::array supportedTypeListQLinea SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, }; +constexpr static std::array supportedTypeListQLinearAveragePool = { + SupportedTensorDataTypes::Ints8Bit +}; + template constexpr auto requiredConstantCpuInputs(Args... args) { @@ -1040,6 +1046,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, {REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9. + {REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp index 067a320dd8000..a2183aab52eed 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp @@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm } m_bufferTensorDesc.DimensionCount = newDimensionCount; } + +// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout +void TensorDesc::PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment) +{ + EnsureStridesExist(); + SetDimensionCount(static_cast(dimensionMapping.size()), alignment); + + // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping + std::vector tempSizes{m_sizes, m_sizes + MaximumDimensionCount}; + std::vector tempStrides{m_strides, m_strides + MaximumDimensionCount}; + + for (size_t i = 0; i < dimensionMapping.size(); i++) + { + m_sizes[i] = tempSizes[dimensionMapping[i]]; + m_strides[i] = tempStrides[dimensionMapping[i]]; + } + + m_bufferTensorDesc.Sizes = m_sizes; + m_bufferTensorDesc.Strides = m_strides; +} + +void TensorDesc::EnsureStridesExist() +{ + if (m_bufferTensorDesc.Strides != nullptr) + { + // Strides are populated + return; + } + + uint32_t stride = 1; + for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;) + { + m_strides[i] = stride; + stride *= m_sizes[i]; + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h index ff70dec5b8871..909e2084d0163 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h @@ -44,6 +44,7 @@ namespace Dml gsl::span GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; } gsl::span GetStrides() const; void SetStrides(gsl::span strides); + void PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment); inline uint64_t GetBufferSizeInBytes() const { @@ -90,6 +91,8 @@ namespace Dml uint32_t m_sizes[MaximumDimensionCount] = {}; uint32_t m_strides[MaximumDimensionCount] = {}; DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {}; + + void EnsureStridesExist(); }; class TensorDescBuilder diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index e9591cfce6870..85333aa77b686 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -23,8 +23,8 @@ namespace AttrName static constexpr const char* BlockSize = "blocksize"; static constexpr const char* Border = "border"; static constexpr const char* Broadcast = "broadcast"; - static constexpr const char* ChannelsLast = "channels_last"; static constexpr const char* CeilMode = "ceil_mode"; + static constexpr const char* ChannelsLast = "channels_last"; static constexpr const char* Clip = "clip"; static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode"; static constexpr const char* CountIncludePad = "count_include_pad"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 4d59964dcc664..1fcd3b04300f4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -365,13 +365,20 @@ namespace OperatorHelper } // Creates a kernel that spans the entire spatial dimensions of the input. - KernelArgs InitializeGlobalKernel(gsl::span inputDimensions) + KernelArgs InitializeGlobalKernel( + const MLOperatorAttributes& kernelInfo, + gsl::span inputDimensions) { ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount); // Must be at least 1D convolution (in 3D tensor) uint32_t spatialDimensionCount = gsl::narrow_cast(inputDimensions.size()) - NonspatialDimensionCount; ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor). KernelArgs args(spatialDimensionCount); + args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0); + // For Global Pooling, kernel size equal to the spatial dimension of input tensor + // NHWC layout need to offset by one dim to acount for channel placed at the end + int dimOffset = args.channelsLast ? 1 : 0; for (size_t dim = 0; dim < spatialDimensionCount; ++dim) { @@ -379,7 +386,7 @@ namespace OperatorHelper args.dilations[dim] = 1; args.startPadding[dim] = 0; args.endPadding[dim] = 0; - args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim]); + args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim - dimOffset]); } return args; @@ -495,6 +502,7 @@ namespace OperatorHelper } args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0); return args; } @@ -2012,7 +2020,37 @@ namespace OperatorHelper } return outputShapes; } + + std::vector QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto inputShape = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast); + + const uint32_t outputCount = shapeInfo.GetOutputCount(); + + std::vector outputShapes; + for (uint32_t i = 0; i < outputCount; ++i) + { + outputShapes.push_back(outputDimensions); + } + return outputShapes; + } + + std::vector QLinearGlobalAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto inputShape = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast); + const uint32_t outputCount = shapeInfo.GetOutputCount(); + + std::vector outputShapes; + for (uint32_t i = 0; i < outputCount; ++i) + { + outputShapes.push_back(outputDimensions); + } + return outputShapes; + } + std::vector RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 55a01c59ee4b5..d8d09efd8d6e8 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -160,6 +160,7 @@ struct KernelArgs bool autoPad = false; bool autoPadSameUpper = false; bool useCeilingOutputShape = false; + bool channelsLast = false; uint32_t spatialDimensionCount = 0; KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount) @@ -188,6 +189,7 @@ struct KernelArgs KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount) : autoPad(kernelArgs.autoPad), autoPadSameUpper(kernelArgs.autoPadSameUpper), + channelsLast(kernelArgs.channelsLast), spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount)) { ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); @@ -211,7 +213,9 @@ std::vector InitializeKernelOutputDimsTranspose( gsl::span inputDimensions, const KernelArgs& args); -KernelArgs InitializeGlobalKernel(gsl::span inputDimensions); +KernelArgs InitializeGlobalKernel( + const MLOperatorAttributes& kernelInfo, + gsl::span inputDimensions); KernelArgs InitializeKernel( const MLOperatorAttributes& kernelInfo, @@ -1059,7 +1063,7 @@ class PoolingHelperBase bool useGlobalPooling ) : m_kernel(useGlobalPooling - ? InitializeGlobalKernel(shape.GetInputTensorShape(0)) + ? InitializeGlobalKernel(info, shape.GetInputTensorShape(0)) : InitializeKernel(info, static_cast(shape.GetInputTensorShape(0).size()), gsl::span())) { if (!useGlobalPooling) @@ -1161,6 +1165,24 @@ class RoiAlignHelper : public RoiPoolingHelperBase std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; }; +class QLinearAveragePoolingHelper : public PoolingHelperBase +{ +public: + template + QLinearAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +}; + +class QLinearGlobalAveragePoolingHelper : public PoolingHelperBase +{ +public: + template + QLinearGlobalAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +}; + class SqueezeHelper { public: @@ -1490,6 +1512,8 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper; using ShapeInferenceHelper_LpPool = PoolingHelper; using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper; using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper; +using ShapeInferenceHelper_QLinearAveragePool = QLinearAveragePoolingHelper; +using ShapeInferenceHelper_QLinearGlobalAveragePool = QLinearGlobalAveragePoolingHelper; using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper; using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper; using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 996ea1ddcb52c..e9d88adf3e221 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -445,6 +445,8 @@ namespace OperatorHelper static const int sc_sinceVer_GroupNorm = 1; static const int sc_sinceVer_QLinearConcat = 1; static const int sc_sinceVer_RotaryEmbedding = 1; + static const int sc_sinceVer_QLinearAveragePool = 1; + static const int sc_sinceVer_QLinearGlobalAveragePool = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper diff --git a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc index 8fb245819fd26..71b6f27b5391f 100644 --- a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc +++ b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc @@ -66,6 +66,9 @@ void RunQLinearGlobalAveragePool( test.AddInput("y_scale", {}, {y_scale}); test.AddInput("y_zero_point", {}, {y_zero_point}); test.AddOutput("Y", y_dims, y_data); + if (channels_last) { + test.AddAttribute("channels_last", (int64_t)1LL); + } auto q8checker = [&](const std::vector& fetches, const std::string& provider_type) { const OrtValue& ort_value = fetches[0]; From d5f3aae3fd30a79d9bbeed26a70907618ac84835 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:43:09 -0800 Subject: [PATCH 06/25] Utilize DML constant input graph node (#18267) ### Description This PR also includes, 8b0a55e7cc DML constant pow operator 7520974970 Enable custom heaps based on query- ### Motivation and Context --------- Co-authored-by: Jeff Bloomfield --- .../src/DmlGraphFusionHelper.cpp | 27 ++++- .../src/ExecutionProvider.cpp | 31 ++++++ .../src/ExecutionProvider.h | 2 + .../src/GraphDescBuilder.cpp | 104 ++++++++++++++---- .../src/GraphDescBuilder.h | 5 +- .../src/IExecutionProvider.h | 1 + .../src/MLOperatorAuthorImpl.cpp | 29 ++++- .../src/MLOperatorAuthorImpl.h | 7 ++ .../src/Operators/DmlOperatorElementWise.cpp | 38 +++++-- .../MLOperatorAuthorHelper.h | 13 +++ .../MLOperatorAuthorPrivate.h | 10 ++ 11 files changed, 229 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 4f7ec188140b5..18cdc5d1bf86e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -226,8 +226,7 @@ namespace DmlGraphFusionHelper { ComPtr initializeInputBuffer; - // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps - if (providerImpl->IsMcdmDevice()) + if (!providerImpl->CustomHeapsSupported()) { initializeInputBuffer = CreateResource(providerImpl, tensorPtr, tensorByteSize); } @@ -294,6 +293,7 @@ namespace DmlGraphFusionHelper const uint32_t inputCount, const uint32_t outputCount, _Inout_ std::vector& dmlOperatorGraphNodes, + _Inout_ std::vector& dmlConstantGraphNodes, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, @@ -302,8 +302,24 @@ namespace DmlGraphFusionHelper for (size_t i = 0; i < graphDesc.nodes.size(); ++i) { auto& nodeInfo = graphDesc.nodes[i]; - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + + if (std::holds_alternative>(nodeInfo.nodeDef)) + { + dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + } + else + { + auto& nodeDefinitionData = std::get>(nodeInfo.nodeDef); + dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{ + nodeDefinitionData.data(), + nodeDefinitionData.size(), + nodeInfo.name.data() + }; + + // TODO: Change as new header is ingested + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast(2), &dmlConstantGraphNodes[i]}; + } } for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) @@ -392,6 +408,8 @@ namespace DmlGraphFusionHelper // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator DML_GRAPH_DESC dmlGraphDesc = {}; std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); + std::vector dmlConstantGraphNodes(graphDesc.nodes.size()); + std::vector dmlGraphNodes(graphDesc.nodes.size()); std::vector dmlInputEdges(graphDesc.inputEdges.size()); std::vector dmlOutputEdges(graphDesc.outputEdges.size()); @@ -402,6 +420,7 @@ namespace DmlGraphFusionHelper fusedNodeInputCount, fusedNodeOutputCount, dmlOperatorGraphNodes, + dmlConstantGraphNodes, dmlGraphNodes, dmlInputEdges, dmlOutputEdges, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 8644b8d56a426..49a64c4810252 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -182,6 +182,32 @@ namespace Dml } m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE); + m_areCustomHeapsSupported = !m_isMcdmDevice; + + if (m_isMcdmDevice) + { + + // TODO: Ingest updated header file + typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS19 + { + BOOL MismatchingOutputDimensionsSupported; + UINT SupportedSampleCountsWithNoOutputs; + BOOL PointSamplingAddressesNeverRoundUp; + BOOL RasterizerDesc2Supported; + BOOL NarrowQuadrilateralLinesSupported; + BOOL AnisoFilterWithPointMipSupported; + UINT MaxSamplerDescriptorHeapSize; + UINT MaxSamplerDescriptorHeapSizeWithStaticSamplers; + UINT MaxViewDescriptorHeapSize; + _Out_ BOOL ComputeOnlyCustomHeapSupported; + } D3D12_FEATURE_DATA_D3D12_OPTIONS19; + + D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {}; + + // The call may fail in which case the default value is false + d3d12Device->CheckFeatureSupport(static_cast(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19)); + m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported; + } m_context = std::make_shared(m_d3d12Device.Get(), m_dmlDevice.Get(), queue); @@ -1089,6 +1115,11 @@ namespace Dml return m_isMcdmDevice; } + bool __stdcall ExecutionProviderImpl::CustomHeapsSupported() const noexcept + { + return m_areCustomHeapsSupported; + } + bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept { return m_areMetacommandsEnabled; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 3aaa11cdee479..ab932fb8a4367 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -150,6 +150,7 @@ namespace Dml } STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; bool DynamicGraphFusionEnabled() const noexcept; @@ -186,6 +187,7 @@ namespace Dml ComPtr m_d3d12Device; ComPtr m_dmlDevice; bool m_isMcdmDevice = false; + bool m_areCustomHeapsSupported = false; bool m_areMetacommandsEnabled = true; bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 3fc8f415e5a58..ba022533a1e94 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -149,7 +149,7 @@ namespace Dml::GraphDescBuilder const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle, + const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, @@ -198,7 +198,7 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (subgraphNodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()) { reuseCommandList = true; } @@ -232,14 +232,22 @@ namespace Dml::GraphDescBuilder { ComPtr tensor = nullptr; - // Check whether this specific node requested support for constant CPU inputs - if (std::find(requiredConstantCpuInputs.begin(), requiredConstantCpuInputs.end(), inputIndex) != requiredConstantCpuInputs.end()) + auto inputDefs = node.InputDefs(); + + if (inputIndex < inputDefs.size()) { - auto inputDefs = node.InputDefs(); - if (inputIndex < inputDefs.size()) + const onnxruntime::NodeArg* arg = inputDefs[inputIndex]; + tensor = constantCpuGraphInputGetter(arg->Name()); + + if (tensor == nullptr) { - const onnxruntime::NodeArg* arg = inputDefs[inputIndex]; - tensor = constantCpuGraphInputGetter(arg->Name()); + bool inputRequiredAsConstant = std::find( + requiredConstantCpuInputs.begin(), + requiredConstantCpuInputs.end(), + inputIndex) != requiredConstantCpuInputs.end(); + + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. + ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } } @@ -289,6 +297,7 @@ namespace Dml::GraphDescBuilder std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0; + size_t firstOpDescGraphNodeIndex = graphNodes.size(); if (isNodeAsOpDesc) { @@ -298,6 +307,8 @@ namespace Dml::GraphDescBuilder ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]); operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); } + + graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount); } else { @@ -306,7 +317,7 @@ namespace Dml::GraphDescBuilder ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get()); operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); NodeInfo nodeInfo = {}; - nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); + nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); graphNodes.push_back(std::move(nodeInfo)); } } @@ -328,21 +339,59 @@ namespace Dml::GraphDescBuilder const uint32_t dmlFusedNodeInputIndex = iter->second; - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex - graphInputEdges.push_back(edge); - // If this is a constant input, set the appropriate flags on the desc if (isNodeAsOpDesc && dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; - tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently + // only used for small inputs. + + // TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming + // nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point + // values that have been de-duplicated with conversion to scalars in kernels. + uint32_t c_maxConstNodeDataSize = 1024 * 1024; + + ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); + + if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize) + { + auto data = static_cast(constantInput->GetData()); + std::vector tensorData(data, data + constantInput->GetTensorByteSize()); + + NodeInfo nodeInfo = {}; + nodeInfo.nodeDef = std::move(tensorData); + graphNodes.push_back(std::move(nodeInfo)); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; + edge.FromNodeIndex = static_cast(graphNodes.size() - 1); + edge.FromNodeOutputIndex = 0; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphIntermediateEdges.push_back(edge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC edge = {}; + edge.GraphInputIndex = dmlFusedNodeInputIndex; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphInputEdges.push_back(edge); + + auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + } + } + else + { + DML_INPUT_GRAPH_EDGE_DESC edge = {}; + edge.GraphInputIndex = dmlFusedNodeInputIndex; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphInputEdges.push_back(edge); } } else @@ -387,17 +436,28 @@ namespace Dml::GraphDescBuilder if (isNodeAsOpDesc) { - for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc) + for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i) { + auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i]; + DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator); + + // TODO: Change as new header is ingested + if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) + dmlDesc.Type = (DML_OPERATOR_TYPE) 169; + + // TODO: Change as new header is ingested + if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) + dmlDesc.Type = (DML_OPERATOR_TYPE) 170; + ComPtr op; ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); allocator.Reset(); NodeInfo nodeInfo = {}; - nodeInfo.op = std::move(op); + nodeInfo.nodeDef = std::move(op); nodeInfo.name = node.Name(); - graphNodes.push_back(std::move(nodeInfo)); + graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo); } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 0039678c00e59..c95e89b45541b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -4,6 +4,7 @@ #pragma once #include "MLOperatorAuthorImpl.h" +#include "ExecutionProvider.h" namespace Dml { @@ -27,7 +28,7 @@ namespace Dml struct NodeInfo { - Microsoft::WRL::ComPtr op; + std::variant, std::vector> nodeDef; std::string name; }; @@ -47,7 +48,7 @@ namespace Dml const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle, + const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index a8a6d6745e908..17fd7c18ba4a1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -76,6 +76,7 @@ namespace Dml STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0; STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept = 0; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 4deec620fe5fb..dbd06abf82f72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1153,6 +1153,33 @@ namespace Windows::AI::MachineLearning::Adapter ORT_CATCH_RETURN } + template + HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::TryGetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept + { + ORT_TRY + { + auto constantInput = m_constantInputGetter(inputIndex); + ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative>(constantInput)); + + auto tensorWrapper = std::get>(constantInput); + if (tensorWrapper == nullptr) + { + bool inputRequiredAsConstant = std::find( + m_requiredConstantCpuInputs.begin(), + m_requiredConstantCpuInputs.end(), + inputIndex) != m_requiredConstantCpuInputs.end(); + + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. + ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); + } + + *tensor = tensorWrapper.Detach(); + + return S_OK; + } + ORT_CATCH_RETURN + } + HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetOutputTensorShape(uint32_t outputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept { ORT_TRY diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 913997ff4ad49..6530d89d895e7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -204,6 +204,11 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable _Outptr_ IMLOperatorTensor** tensor ) const noexcept; + HRESULT STDMETHODCALLTYPE TryGetConstantInputTensor( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept; + protected: // Lifetime is managed by the caller and guaranteed to outlive this class const onnxruntime::OpNodeProtoHelper* m_impl = nullptr; @@ -299,6 +304,8 @@ class OnnxTensorWrapper : public WRL::Base, public Closable const onnxruntime::Tensor* GetInterface() const { return nullptr; } onnxruntime::Tensor* GetInterface() { return nullptr; } + size_t GetTensorByteSize() const { return m_tensorByteSize; } + private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 43d34657098ef..f0a16da3a3c06 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -479,17 +479,37 @@ class DmlOperatorElementwisePow : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1); + if (constExpTensor && constExpTensor->GetTotalElementCount() == 1) + { + std::vector> kernelInputIndices = {0}; - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); - DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; - opDesc.InputTensor = &inputDescs[0]; - opDesc.ExponentTensor = &inputDescs[1]; - opDesc.OutputTensor = &outputDescs[0]; + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.OutputTensor = &outputDescs[0]; + opDesc.Exponent = static_cast(ReadScalarTensorCastToFloat64(*constExpTensor)); - SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo); + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo); + } + else + { + Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.ExponentTensor = &inputDescs[1]; + opDesc.OutputTensor = &outputDescs[0]; + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo); + } } }; @@ -565,7 +585,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator opDesc.ScaleTensor = &inputDescs[1]; opDesc.ZeroPointTensor = &inputDescs[2]; opDesc.OutputTensor = &outputDescs[0]; - + SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index f94270cfadb8b..59a1719d08ee6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -6,6 +6,7 @@ #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" #include "MLOperatorAuthorPrivate.h" #include "core/common/gsl.h" +#include #ifdef ORT_NO_EXCEPTIONS #define ML_CHECK_BOOL(x) ORT_THROW_HR_IF(E_INVALIDARG, !(x)) @@ -604,6 +605,18 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return MLOperatorTensor(tensor.Get()); } + std::optional TryGetConstantInputTensor(uint32_t inputIndex) const + { + Microsoft::WRL::ComPtr tensor; + ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor)); + if (tensor) + { + return MLOperatorTensor(tensor.Get()); + } + + return std::nullopt; + } + uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const { auto shapeDesc = GetTensorShapeDescription(); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index d1a705e151ddf..3bec8d3864cba 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -41,6 +41,11 @@ IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContex _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; + STDMETHOD(TryGetConstantInputTensor)( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept PURE; + //! Gets the number of dimensions of a tensor output of the operator. STDMETHOD(GetSequenceInputInfo)( uint32_t inputIndex, @@ -73,6 +78,11 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; + STDMETHOD(TryGetConstantInputTensor)( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept PURE; + STDMETHOD_(bool, IsDmlGraphNode)() const noexcept PURE; STDMETHOD(SetDmlOperator)( From 531e875fb550f7a866cd13639fb2e332440e6b96 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:43:47 -0800 Subject: [PATCH 07/25] Avoid command list reset in common case of re-used command list execution (#18370) ### Description ### Motivation and Context Co-authored-by: Jeff Bloomfield --- .../src/DmlCommandRecorder.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp index 530c26d212083..98345f37b68d4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp @@ -251,6 +251,24 @@ void DmlCommandRecorder::ExecuteCommandList( _Out_ uint64_t* completionValue ) { + if (!m_operationsRecordedInCurrentCommandList) + { + // The caller can re-use relevant resources after the next set of work to be + // flushed has completed. Its command list hasn't been executed yet, just batched. + GpuEvent gpuEvent = m_queue->GetNextCompletionEvent(); + gpuEvent.fence.CopyTo(fence); + *completionValue = gpuEvent.fenceValue; + + m_queue->ExecuteCommandLists( + gsl::span(reinterpret_cast(&commandList), 1)); + + // Fail early if something horrifying happens + ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason()); + ORT_THROW_IF_FAILED(m_d3dDevice->GetDeviceRemovedReason()); + + return; + } + ORT_THROW_IF_FAILED(m_currentCommandList->Close()); if (m_operationsRecordedInCurrentCommandList) From a1000a0a3c2552b78045bb5452aea09ae3490e6a Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:45:16 -0800 Subject: [PATCH 08/25] Enable GEMM activation fusions on MCDM (#18372) ### Description ### Motivation and Context Co-authored-by: Jeff Bloomfield --- .../src/Operators/OperatorUtility.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index 2965fa32ce131..fb86648d595b3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -154,13 +154,13 @@ namespace Dml OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MeanVarianceNormalization }, OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MeanVarianceNormalization }, OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MeanVarianceNormalization }, - OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Gemm }, - OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_Gemm }, - OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Gemm }, - OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Gemm }, - OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MatMul }, - OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MatMul }, - OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Gemm, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_Gemm, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Gemm, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Gemm, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MatMul, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MatMul, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul, {}, true, {"Relu", "LeakyRelu"} }, // The filter for activation functions maps to what DML's fused op internally fuses at the shader level. OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, From 5c283340c384d5f9e402ba762c3c33af14e6763e Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 18:03:55 -0800 Subject: [PATCH 09/25] Filter activation fusions on MCDM (#18371) ### Description ### Motivation and Context --------- Co-authored-by: Jeff Bloomfield --- .../src/GraphTransformer.cpp | 15 ++++++-- .../src/GraphTransformer.h | 12 +++--- .../src/Operators/OperatorUtility.cpp | 37 ++++++++++++++----- .../src/Operators/OperatorUtility.h | 3 +- onnxruntime/core/session/inference_session.cc | 3 +- 5 files changed, 50 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp index 09922310b56c1..2e04da843696e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp @@ -17,6 +17,14 @@ namespace Dml { + GraphTransformer::GraphTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) : onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + onnxruntime::common::Status GraphTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -27,7 +35,7 @@ namespace Dml // Perform fusion { bool transformModifiedGraph = false; - PerformOperatorFusion(&graph, &transformModifiedGraph); + PerformOperatorFusion(&graph, m_providerImpl->IsMcdmDevice(), &transformModifiedGraph); modified |= transformModifiedGraph; if (modified) @@ -50,7 +58,7 @@ namespace Dml return ss.str(); } - void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const + void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const { struct NodeToAdd { @@ -112,7 +120,8 @@ namespace Dml gsl::narrow_cast(node.InputDefs().size()), outputNode.OpType(), outputNode.Domain(), - outputNode.Op()->SinceVersion()); + outputNode.Op()->SinceVersion(), + isMcdmDevice); if (!fusedOpProperties) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h index a7f8186fb3b64..337c0df7ff521 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h @@ -10,6 +10,7 @@ namespace Dml { + class ExecutionProviderImpl; // Applies transforms to a Lotus graph. The graph transformer is responsible for setting the execution provider // on the graph nodes which DML supports. @@ -17,16 +18,17 @@ namespace Dml { public: GraphTransformer( - const std::string& name - ) : onnxruntime::GraphTransformer(name) - { - } + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); private: onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const final; private: - void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const; + void PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const; + + const ExecutionProviderImpl* m_providerImpl = nullptr; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index fb86648d595b3..46442fe942539 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -132,6 +132,8 @@ namespace Dml std::string_view domain; int sinceVersion; std::vector activationFilter; + bool enableOnMcdm; + std::vector extraMcdmActivationFilter; std::optional inputCountFilter; }; @@ -142,10 +144,10 @@ namespace Dml static const OperatorInfo c_fusableOps[] = { - OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv }, - OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv }, - OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose }, - OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose }, + OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_BatchNormalization }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_BatchNormalization }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_BatchNormalization }, @@ -163,11 +165,11 @@ namespace Dml OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul, {}, true, {"Relu", "LeakyRelu"} }, // The filter for activation functions maps to what DML's fused op internally fuses at the shader level. - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 }, - OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 }, + OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 }, }; // Not all activations can be fused - only simple elementwise activations (i.e. activation functions which @@ -205,7 +207,8 @@ namespace Dml int candidateOpInputCount, std::string_view activationOpType, std::string_view activationOpDomain, - int activationOpSinceVersion) + int activationOpSinceVersion, + bool isMcdmDevice) { auto opIt = std::find( std::begin(c_fusableOps), @@ -233,6 +236,20 @@ namespace Dml return std::nullopt; } + if (isMcdmDevice) + { + if (!opIt->enableOnMcdm) + { + return std::nullopt; + } + + if (!opIt->extraMcdmActivationFilter.empty() && + std::find(opIt->extraMcdmActivationFilter.begin(), opIt->extraMcdmActivationFilter.end(), activationOpType) == opIt->extraMcdmActivationFilter.end()) + { + return std::nullopt; + } + } + if (opIt->inputCountFilter && *opIt->inputCountFilter != static_cast(candidateOpInputCount)) { return std::nullopt; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h index 8b2da6084242d..d3483cb5e8de2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h @@ -40,7 +40,8 @@ namespace Dml int candidateOpInputCount, std::string_view activationOpType, std::string_view activationOpDomain, - int activationOpSinceVersion); + int activationOpSinceVersion, + bool isMcdmDevice); // Returns true if the given activation operator type supports being fused with a fusable operator, false // otherwise. diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 575529a06fb7a..cef160489ac46 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1696,7 +1696,8 @@ common::Status InferenceSession::Initialize() { // This transformer applies DML-specific fusions that go beyond what ORT offers by default bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2; if (dml_operator_fusion_enabled) { - std::unique_ptr dmlOperatorFusionTransformer = std::make_unique("DmlOperatorFusionTransformer"); + std::unique_ptr dmlOperatorFusionTransformer = std::make_unique("DmlOperatorFusionTransformer", + execution_providers_.Get(kDmlExecutionProvider)); if (dmlOperatorFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr"); } From 613fdce12e528d6921d0c2c7674e0fbbe2c18e6a Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 20 Nov 2023 10:16:47 -0800 Subject: [PATCH 10/25] Create ring buffer for re-used command lists (#18368) ### Description ### Motivation and Context Co-authored-by: Jeff Bloomfield --- .../src/FusedGraphKernel.cpp | 118 ++++++++++-------- 1 file changed, 68 insertions(+), 50 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index 67c3f110e5a50..430ccec3cda10 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -13,6 +13,26 @@ namespace Dml { class FusedGraphKernel : public onnxruntime::OpKernel { + private: + struct ReusedCommandListState + { + // Re-usable command list, supporting descriptor heap, and DML binding table to update that heap. + ComPtr graphicsCommandList; + ComPtr commandAllocator; + ComPtr heap; + ComPtr bindingTable; + + // Bindings from previous executions of a re-used command list + mutable std::vector inputBindingAllocIds; + mutable std::vector outputBindingAllocIds; + mutable uint64_t tempBindingAllocId = 0; + + // Fence tracking the status of the command list's last execution, and whether its descriptor heap + // can safely be updated. + mutable ComPtr fence; + mutable uint64_t completionValue = 0; + }; + public: FusedGraphKernel() = delete; @@ -89,7 +109,7 @@ namespace Dml if (reuseCommandList) { - BuildReusableCommandList(); + m_reusedCommandLists.push_back(BuildReusableCommandList()); } } @@ -97,8 +117,7 @@ namespace Dml { // Only re-use the cached command list if its prior execution is complete on the GPU. // This requirement can be avoided by mantaining ring buffers. - if (!m_graphicsCommandList || - (m_fence != nullptr && m_fence->GetCompletedValue() < m_completionValue)) + if (m_reusedCommandLists.empty()) { // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator OpKernelContextWrapper contextWrapper( @@ -147,7 +166,15 @@ namespace Dml } else { - ExecuteReusableCommandList(kernelContext); + if (m_reusedCommandLists.front()->fence && + m_reusedCommandLists.front()->fence->GetCompletedValue() < m_reusedCommandLists.front()->completionValue) + { + m_reusedCommandLists.push_front(BuildReusableCommandList()); + } + + ExecuteReusableCommandList(kernelContext, *m_reusedCommandLists.front()); + m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front())); + m_reusedCommandLists.pop_front(); } return onnxruntime::Status::OK(); @@ -217,8 +244,10 @@ namespace Dml } private: - void BuildReusableCommandList() + std::unique_ptr BuildReusableCommandList() const { + auto commandListState = std::make_unique(); + ComPtr device; ORT_THROW_IF_FAILED(m_provider->GetDmlDevice(device.GetAddressOf())); @@ -232,47 +261,49 @@ namespace Dml ComPtr d3dDevice; ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf())); - ORT_THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(m_heap.ReleaseAndGetAddressOf()))); + ORT_THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(commandListState->heap.ReleaseAndGetAddressOf()))); // Create a binding table for execution. DML_BINDING_TABLE_DESC bindingTableDesc = {}; bindingTableDesc.Dispatchable = m_compiledExecutionPlanOperator.Get(); - bindingTableDesc.CPUDescriptorHandle = m_heap->GetCPUDescriptorHandleForHeapStart(); - bindingTableDesc.GPUDescriptorHandle = m_heap->GetGPUDescriptorHandleForHeapStart(); + bindingTableDesc.CPUDescriptorHandle = commandListState->heap->GetCPUDescriptorHandleForHeapStart(); + bindingTableDesc.GPUDescriptorHandle = commandListState->heap->GetGPUDescriptorHandleForHeapStart(); bindingTableDesc.SizeInDescriptors = execBindingProps.RequiredDescriptorCount; - ORT_THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&m_bindingTable))); + ORT_THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&commandListState->bindingTable))); ORT_THROW_IF_FAILED(d3dDevice->CreateCommandAllocator( m_provider->GetCommandListTypeForQueue(), - IID_GRAPHICS_PPV_ARGS(m_commandAllocator.ReleaseAndGetAddressOf()))); + IID_GRAPHICS_PPV_ARGS(commandListState->commandAllocator.ReleaseAndGetAddressOf()))); ORT_THROW_IF_FAILED(d3dDevice->CreateCommandList( 0, m_provider->GetCommandListTypeForQueue(), - m_commandAllocator.Get(), + commandListState->commandAllocator.Get(), nullptr, - IID_GRAPHICS_PPV_ARGS(m_graphicsCommandList.ReleaseAndGetAddressOf()))); + IID_GRAPHICS_PPV_ARGS(commandListState->graphicsCommandList.ReleaseAndGetAddressOf()))); if (m_persistentResource) { DML_BINDING_DESC persistentResourceBindingDesc = { DML_BINDING_TYPE_BUFFER, m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr }; - m_bindingTable->BindPersistentResource(&persistentResourceBindingDesc); + commandListState->bindingTable->BindPersistentResource(&persistentResourceBindingDesc); } - ID3D12DescriptorHeap* descriptorHeaps[] = { m_heap.Get() }; - m_graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps); + ID3D12DescriptorHeap* descriptorHeaps[] = { commandListState->heap.Get() }; + commandListState->graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps); ComPtr recorder; ORT_THROW_IF_FAILED(device->CreateCommandRecorder(IID_PPV_ARGS(recorder.GetAddressOf()))); - recorder->RecordDispatch(m_graphicsCommandList.Get(), m_compiledExecutionPlanOperator.Get(), m_bindingTable.Get()); + recorder->RecordDispatch(commandListState->graphicsCommandList.Get(), m_compiledExecutionPlanOperator.Get(), commandListState->bindingTable.Get()); + + ORT_THROW_IF_FAILED(commandListState->graphicsCommandList->Close()); - ORT_THROW_IF_FAILED(m_graphicsCommandList->Close()); + return commandListState; } - void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext) const + void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext, ReusedCommandListState& commandListState) const { DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties(); @@ -287,7 +318,7 @@ namespace Dml // Populate input bindings, excluding those which were specified as owned by DML and provided // at initialization instead. - m_inputBindingAllocIds.resize(inputBindings.size()); + commandListState.inputBindingAllocIds.resize(inputBindings.size()); bool inputBindingsChanged = false; for (uint32_t i = 0; i < inputBindings.size(); ++i) @@ -307,25 +338,25 @@ namespace Dml uint64_t allocId; DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId); - inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId); + inputBindingsChanged = inputBindingsChanged || (!allocId || commandListState.inputBindingAllocIds[i] != allocId); inputBindings[i].Buffer->Release(); // Avoid holding an additional reference inputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]}; - m_inputBindingAllocIds[i] = allocId; + commandListState.inputBindingAllocIds[i] = allocId; } } } if (inputBindingsChanged) { - m_bindingTable->BindInputs(gsl::narrow_cast(inputBindingDescs.size()), inputBindingDescs.data()); + commandListState.bindingTable->BindInputs(gsl::narrow_cast(inputBindingDescs.size()), inputBindingDescs.data()); } // Populate Output bindings std::vector outputBindings(kernelContext->OutputCount()); std::vector outputBindingDescs(kernelContext->OutputCount()); - m_outputBindingAllocIds.resize(outputBindings.size()); + commandListState.outputBindingAllocIds.resize(outputBindings.size()); bool outputBindingsChanged = false; for (uint32_t i = 0; i < outputBindings.size(); ++i) @@ -344,16 +375,16 @@ namespace Dml uint64_t allocId; DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &outputBindings[i].Buffer, &allocId); - outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId); + outputBindingsChanged = outputBindingsChanged || (!allocId || commandListState.outputBindingAllocIds[i] != allocId); outputBindings[i].Buffer->Release(); // Avoid holding an additional reference outputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]}; - m_outputBindingAllocIds[i] = allocId; + commandListState.outputBindingAllocIds[i] = allocId; } if (outputBindingsChanged) { - m_bindingTable->BindOutputs(gsl::narrow_cast(outputBindingDescs.size()), outputBindingDescs.data()); + commandListState.bindingTable->BindOutputs(gsl::narrow_cast(outputBindingDescs.size()), outputBindingDescs.data()); } if (execBindingProps.TemporaryResourceSize > 0) @@ -373,19 +404,19 @@ namespace Dml DML_BUFFER_BINDING tempBufferBinding = {tempResource.Get(), 0, execBindingProps.TemporaryResourceSize}; DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding }; - if (!tempAllocId || m_tempBindingAllocId != tempAllocId) + if (!tempAllocId || commandListState.tempBindingAllocId != tempAllocId) { - m_bindingTable->BindTemporaryResource(&tempBindingDesc); + commandListState.bindingTable->BindTemporaryResource(&tempBindingDesc); } - m_tempBindingAllocId = tempAllocId; + commandListState.tempBindingAllocId = tempAllocId; } // Execute the command list and if it succeeds, update the fence value at which this command may be // re-used. ComPtr fence; uint64_t completionValue; - HRESULT hr = m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue); + HRESULT hr = m_provider->ExecuteCommandList(commandListState.graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue); if (hr == DXGI_ERROR_DEVICE_REMOVED) { @@ -395,13 +426,13 @@ namespace Dml } ORT_THROW_IF_FAILED(hr); - m_fence = fence; - m_completionValue = completionValue; + commandListState.fence = fence; + commandListState.completionValue = completionValue; // Queue references to objects which must be kept alive until resulting GPU work completes - m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_graphicsCommandList).Get()); - m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_heap).Get()); - m_winmlProvider->QueueReference(m_bindingTable.Get()); + m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(commandListState.graphicsCommandList).Get()); + m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(commandListState.heap).Get()); + m_winmlProvider->QueueReference(commandListState.bindingTable.Get()); m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); } @@ -412,25 +443,12 @@ namespace Dml ComPtr m_provider; Windows::AI::MachineLearning::Adapter::EdgeShapes& m_outputShapes; - // Re-usable command list, supporting descriptor heap, and DML binding table to update that heap. - ComPtr m_graphicsCommandList; - ComPtr m_commandAllocator; - ComPtr m_heap; - ComPtr m_bindingTable; + mutable std::deque> m_reusedCommandLists; + std::optional m_persistentResourceBinding; ComPtr m_persistentResource; ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator - // Bindings from previous executions of a re-used command list - mutable std::vector m_inputBindingAllocIds; - mutable std::vector m_outputBindingAllocIds; - mutable uint64_t m_tempBindingAllocId = 0; - - // Fence tracking the status of the command list's last execution, and whether its descriptor heap - // can safely be updated. - mutable ComPtr m_fence; - mutable uint64_t m_completionValue = 0; - std::vector m_isInputsUploadedByDmlEP; std::vector> m_nonOwnedGraphInputsFromInitializers; }; From 7f9e6c42c2e7800c8a9d38cd12e0eac881c81dd5 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:12:47 -0800 Subject: [PATCH 11/25] readd npu enumeration (#18437) (#18518) [Cherry pick Reviewed] Re-add changes which were merged out... --------- ### Description ### Motivation and Context Co-authored-by: Sheil Kumar Co-authored-by: Sheil Kumar --- .../providers/dml/dml_provider_factory.cc | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 33f1f59e07f3f..cd4eb20c856c0 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -475,12 +475,39 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID return dml_device; } +static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_device) { + D3D12_FEATURE_DATA_FEATURE_LEVELS feature_levels = {}; + + D3D_FEATURE_LEVEL feature_levels_list[] = { + D3D_FEATURE_LEVEL_1_0_CORE, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_12_0, + D3D_FEATURE_LEVEL_12_1 + }; + + feature_levels.NumFeatureLevels = ARRAYSIZE(feature_levels_list); + feature_levels.pFeatureLevelsRequested = feature_levels_list; + ORT_THROW_IF_FAILED(d3d12_device->CheckFeatureSupport( + D3D12_FEATURE_FEATURE_LEVELS, + &feature_levels, + sizeof(feature_levels) + )); + + auto is_feature_level_1_0_core = (feature_levels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE); + if (is_feature_level_1_0_core) { + return D3D12_COMMAND_LIST_TYPE_COMPUTE; + } + + return D3D12_COMMAND_LIST_TYPE_DIRECT; +} + std::shared_ptr CreateDMLDeviceAndProviderFactory( - ID3D12Device* d3d12_device, - bool disable_metacommands, - bool enable_dynamic_graph_fusion) { + ID3D12Device* d3d12_device, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; - cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + cmd_queue_desc.Type = CalculateCommandListType(d3d12_device); cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; ComPtr cmd_queue; @@ -500,16 +527,20 @@ std::shared_ptr DMLProviderFactoryCreator::Create( } std::shared_ptr DMLProviderFactoryCreator::CreateFromAdapterList( - std::vector>&& dxcore_devices, + std::vector>&& adapters, bool disable_metacommands, bool enable_dynamic_graph_fusion) { // Choose the first device from the list since it's the highest priority - auto dxcore_device = dxcore_devices[0]; + auto adapter = adapters[0]; + + auto feature_level = D3D_FEATURE_LEVEL_11_0; + if (IsNPU(adapter.Get())) { + feature_level = D3D_FEATURE_LEVEL_1_0_CORE; + } // Create D3D12 Device from DXCore Adapter ComPtr d3d12_device; - ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); - + ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), feature_level, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion); } From e8209ce2b0f966739dfbb6ed4486ce78ca3f2e39 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 22 Nov 2023 14:39:36 -0800 Subject: [PATCH 12/25] CP 7fd1ce95a4e4f3c2b6152dfc2b1807a983ef45e5 (#18560) CP 7fd1ce95a4e4f3c2b6152dfc2b1807a983ef45e5 for onnxruntime_perf_test changes. Co-authored-by: Sheil Kumar --- .../DmlExecutionProvider/src/CommandQueue.cpp | 15 ++- .../src/ExecutionContext.cpp | 18 ++-- .../src/ExecutionContext.h | 10 +- .../src/ExecutionProvider.cpp | 2 +- .../src/ExecutionProvider.h | 2 +- .../providers/dml/dml_provider_factory.cc | 5 +- .../test/perftest/command_args_parser.cc | 4 + onnxruntime/test/perftest/ort_test_session.cc | 99 +++++++++++++++---- onnxruntime/test/perftest/ort_test_session.h | 1 + onnxruntime/test/util/default_providers.cc | 3 +- 10 files changed, 113 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp index 5516fc62cdda0..2b4f3bb96537b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp @@ -46,12 +46,12 @@ namespace Dml return GpuEvent{ m_lastFenceValue + 1, m_fence }; } - void CommandQueue::QueueReference(IUnknown* object, bool waitForUnsubmittedWork) + void CommandQueue::QueueReference(IUnknown* object, bool waitForUnsubmittedWork) { - // If the CommandQueue is closing, then m_queuedReferences is being cleared -- it is not OK - // to queue additional references at this time, since those references would be leaked. This - // affects any objects in m_queuedReferences whose destructors indirectly call QueueReference; - // for example, an allocation from BucketizedBufferAllocator attempts to queue a reference + // If the CommandQueue is closing, then m_queuedReferences is being cleared -- it is not OK + // to queue additional references at this time, since those references would be leaked. This + // affects any objects in m_queuedReferences whose destructors indirectly call QueueReference; + // for example, an allocation from BucketizedBufferAllocator attempts to queue a reference // to its underlying D3D resource when freed. Furthermore, these references are unnecessary // since Close() already blocks for scheduled GPU work before clearing m_queuedReferences. if (!m_closing) @@ -68,7 +68,7 @@ namespace Dml m_queuedReferences.push_back(queuedReference); } } - + void CommandQueue::Close() { // Wait for flushed work: @@ -79,7 +79,7 @@ namespace Dml m_queuedReferences.clear(); m_closing = false; } - + void CommandQueue::ReleaseCompletedReferences() { uint64_t completedValue = GetFence()->GetCompletedValue(); @@ -89,5 +89,4 @@ namespace Dml } } - } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp index a894d0660d6ff..bc82c7ab1c44a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp @@ -15,7 +15,7 @@ namespace Dml : m_queue(std::make_shared(queue)) , m_dmlRecorder(d3d12Device, dmlDevice, m_queue) { - ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf()))); + ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf()))); } void ExecutionContext::SetAllocator(std::weak_ptr allocator) @@ -78,14 +78,14 @@ namespace Dml ID3D12GraphicsCommandList* commandList, _Outptr_ ID3D12Fence** fence, _Out_ uint64_t* completionValue - ) + ) { assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue); } - + void ExecutionContext::InitializeOperator( IDMLCompiledOperator* op, const DML_BINDING_DESC& persistentResourceBinding, @@ -110,7 +110,7 @@ namespace Dml } void ExecutionContext::AddUAVBarrier() - { + { assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); @@ -173,9 +173,9 @@ namespace Dml m_currentRecorder = nullptr; SetCommandRecorder(&m_dmlRecorder); } - - void ExecutionContext::QueueReference(IUnknown* object) - { + + void ExecutionContext::QueueReference(IUnknown* object) + { assert(!m_closed); // If something has been recorded into a command list but not submitted yet, it means that the *next* fence // value is the one to signal completion. @@ -186,14 +186,14 @@ namespace Dml void ExecutionContext::Close() { assert(!m_closed); - + // Discard unflushed work and clear queued references. This prevents the circular reference: // Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel m_queue->Close(); m_currentRecorder = nullptr; m_closed = true; } - + GpuEvent ExecutionContext::GetCurrentCompletionEvent() { assert(!m_closed); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h index b06f11a5efd0a..ac8d3ff875786 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h @@ -20,13 +20,13 @@ namespace Dml public: // Constructs an ExecutionContext that executes on the supplied queue. ExecutionContext( - ID3D12Device* d3d12Device, - IDMLDevice* dmlDevice, + ID3D12Device* d3d12Device, + IDMLDevice* dmlDevice, ID3D12CommandQueue* queue); void SetAllocator(std::weak_ptr allocator); - // Waits for flushed work, discards unflushed work, and discards associated references to + // Waits for flushed work, discards unflushed work, and discards associated references to // prevent circular references. Must be the last call on the object before destruction. void Close(); @@ -75,12 +75,12 @@ namespace Dml // Returns an event which will become signaled when everything submitted to the execution context thus far has // completed execution on the GPU, including work that has yet to be flushed to the queue. GpuEvent GetCurrentCompletionEvent(); - + // Adds a reference which will be released when queued GPU work is completed void QueueReference(IUnknown* object); // Release any accumulated references who corresponding GPU fence values have - // been reached. + // been reached. void ReleaseCompletedReferences(); D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 49a64c4810252..8a32d06534dda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -205,7 +205,7 @@ namespace Dml D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {}; // The call may fail in which case the default value is false - d3d12Device->CheckFeatureSupport(static_cast(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19)); + d3d12Device->CheckFeatureSupport(static_cast(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19)); m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index ab932fb8a4367..5617bc7bdcac6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -150,7 +150,7 @@ namespace Dml } STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; - STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; bool DynamicGraphFusionEnabled() const noexcept; diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index cd4eb20c856c0..73a068f3e1de2 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -118,7 +118,6 @@ static bool IsGPU(IDXCoreAdapter* compute_adapter) { return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); } -#ifdef ENABLE_NPU_ADAPTER_ENUMERATION static bool IsNPU(IDXCoreAdapter* compute_adapter) { // Only considering hardware adapters if (!IsHardwareAdapter(compute_adapter)) { @@ -126,7 +125,6 @@ static bool IsNPU(IDXCoreAdapter* compute_adapter) { } return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)); } -#endif enum class DeviceType { GPU, NPU, BadDevice }; @@ -327,7 +325,8 @@ static std::optional ParsePerformancePreference(con } static std::optional ParseFilter(const ProviderOptions& provider_options) { - static const std::string Filter = "filter"; + static const std::string Filter = "device_filter"; + static const std::string Any = "any"; static const std::string Gpu = "gpu"; #ifdef ENABLE_NPU_ADAPTER_ENUMERATION static const std::string Any = "any"; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 6e3252aaeb4b8..f8d6296d2d785 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -58,6 +58,10 @@ namespace perftest { "\t-q [CUDA only] use separate stream for copy. \n" "\t-z: Set denormal as zero. When turning on this option reduces latency dramatically, a model may have denormals.\n" "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" + "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" + "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" + "\t [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 04c9ae1f23108..ac25c98b15758 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -16,6 +16,10 @@ #include "providers.h" #include "TestCase.h" +#ifdef USE_DML +#include "core/providers/dml/dml_provider_factory.h" +#endif + #ifdef _WIN32 #define strdup _strdup #endif @@ -42,8 +46,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device const TestModelInfo& m) : rand_engine_(rd()), input_names_(m.GetInputCount()), input_names_str_(m.GetInputCount()), input_length_(m.GetInputCount()) { Ort::SessionOptions session_options; - const std::string& provider_name = performance_test_config.machine_config.provider_type_name; - if (provider_name == onnxruntime::kDnnlExecutionProvider) { + provider_name_ = performance_test_config.machine_config.provider_type_name; + if (provider_name_ == onnxruntime::kDnnlExecutionProvider) { #ifdef USE_DNNL // Generate provider options OrtDnnlProviderOptions dnnl_options; @@ -96,7 +100,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("DNNL is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kCudaExecutionProvider) { + } else if (provider_name_ == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA const auto& api = Ort::GetApi(); OrtCUDAProviderOptionsV2* cuda_options; @@ -161,7 +165,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("CUDA is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kTensorrtExecutionProvider) { + } else if (provider_name_ == onnxruntime::kTensorrtExecutionProvider) { #ifdef USE_TENSORRT const auto& api = Ort::GetApi(); OrtTensorRTProviderOptionsV2* tensorrt_options; @@ -215,7 +219,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("TensorRT is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kOpenVINOExecutionProvider) { + } else if (provider_name_ == onnxruntime::kOpenVINOExecutionProvider) { #ifdef USE_OPENVINO #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -251,7 +255,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ov_options[key] = value; } else { ORT_THROW( - "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " + "[ERROR] [OpenVINO] You have selected a wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); @@ -305,7 +309,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("OpenVINO is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kQnnExecutionProvider) { + } else if (provider_name_ == onnxruntime::kQnnExecutionProvider) { #ifdef USE_QNN #ifdef _MSC_VER std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -378,7 +382,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("QNN is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kSnpeExecutionProvider) { + } else if (provider_name_ == onnxruntime::kSnpeExecutionProvider) { #ifdef USE_SNPE #ifdef _MSC_VER std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -430,7 +434,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("SNPE is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kNnapiExecutionProvider) { + } else if (provider_name_ == onnxruntime::kNnapiExecutionProvider) { #ifdef USE_NNAPI uint32_t nnapi_flags = 0; #ifdef _MSC_VER @@ -458,22 +462,81 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("NNAPI is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kCoreMLExecutionProvider) { + } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) { #ifdef USE_COREML Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0)); #else ORT_THROW("COREML is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kDmlExecutionProvider) { + } else if (provider_name_ == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML std::unordered_map dml_options; dml_options["performance_preference"] = "high_performance"; dml_options["device_filter"] = "gpu"; + dml_options["disable_metacommands"] = "false"; + dml_options["enable_dynamic_graph_fusion"] = "false"; +#ifdef _MSC_VER + std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); +#else + std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; +#endif + std::istringstream ss(ov_string); + std::string token; + while (ss >> token) { + if (token == "") { + continue; + } + auto pos = token.find("|"); + if (pos == std::string::npos || pos == 0 || pos == token.length()) { + ORT_THROW("[ERROR] [DML] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + } + + auto key = token.substr(0, pos); + auto value = token.substr(pos + 1); + + if (key == "device_filter") { + std::set ov_supported_device_types = {"gpu", "npu"}; + if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { + dml_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [DML] You have selected a wrong configuration value for the key 'device_filter'. " + "Select from 'gpu', or 'npu' \n"); + } + } else if (key == "performance_preference") { + std::set ov_supported_values = {"default", "high_performance", "minimal_power"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + dml_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. " + "Select from 'default', 'high_performance' or 'minimal_power' \n"); + } + } else if (key == "disable_metacommands") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + dml_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [DML] You have selcted wrong value for the key 'disable_metacommands'. " + "Select from 'true' or 'false' \n"); + } + } else if (key == "enable_dynamic_graph_fusion") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + dml_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [DML] You have selcted wrong value for the key 'enable_dynamic_graph_fusion'. " + "Select from 'true' or 'false' \n"); + } + } + } session_options.AppendExecutionProvider("DML", dml_options); #else ORT_THROW("DML is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kAclExecutionProvider) { + } else if (provider_name_ == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL Ort::ThrowOnError( OrtSessionOptionsAppendExecutionProvider_ACL(session_options, @@ -481,14 +544,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("Acl is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kArmNNExecutionProvider) { + } else if (provider_name_ == onnxruntime::kArmNNExecutionProvider) { #ifdef USE_ARMNN Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_options, performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0)); #else ORT_THROW("ArmNN is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kRocmExecutionProvider) { + } else if (provider_name_ == onnxruntime::kRocmExecutionProvider) { #ifdef USE_ROCM OrtROCMProviderOptions rocm_options; rocm_options.miopen_conv_exhaustive_search = performance_test_config.run_config.cudnn_conv_algo; @@ -498,7 +561,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("ROCM is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kMIGraphXExecutionProvider) { + } else if (provider_name_ == onnxruntime::kMIGraphXExecutionProvider) { #ifdef USE_MIGRAPHX Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(session_options, 0)); OrtROCMProviderOptions rocm_options; @@ -508,7 +571,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("MIGraphX is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kXnnpackExecutionProvider) { + } else if (provider_name_ == onnxruntime::kXnnpackExecutionProvider) { #ifdef USE_XNNPACK session_options.AddConfigEntry(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0"); session_options.AppendExecutionProvider( @@ -516,7 +579,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("Xnnpack is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kVitisAIExecutionProvider) { + } else if (provider_name_ == onnxruntime::kVitisAIExecutionProvider) { #ifdef USE_VITISAI #ifdef _MSC_VER std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -544,7 +607,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("VitisAI is not supported in this build\n"); #endif - } else if (!provider_name.empty() && provider_name != onnxruntime::kCpuExecutionProvider) { + } else if (!provider_name_.empty() && provider_name_ != onnxruntime::kCpuExecutionProvider) { ORT_THROW("This backend is not included in perf test runner.\n"); } diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index 208e3de53b1d2..f1a4220ab325e 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -45,6 +45,7 @@ class OnnxRuntimeTestSession : public TestSession { std::vector input_names_; std::vector input_names_str_; const int input_length_; + std::string provider_name_; }; } // namespace perftest diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 4468a64d18258..a94f7b5b707c7 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -274,8 +274,9 @@ std::unique_ptr DefaultCannExecutionProvider() { std::unique_ptr DefaultDmlExecutionProvider() { #ifdef USE_DML - if (auto factory = DMLProviderFactoryCreator::Create(0, false, false, false)) + if (auto factory = DMLProviderFactoryCreator::CreateFromOptions(nullptr, false, false)) { return factory->CreateProvider(); + } #endif return nullptr; } From 623d957607b571d8dadbba6d57021546c38658a4 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Fri, 1 Dec 2023 10:27:20 -0800 Subject: [PATCH 13/25] register resize with uint8/int8 support (#18647) ### Description 1. Expand input datatype support for Resize with uint8/int8. 2. Update the logic to compute output shape of Resize Op, roiRange is got rid of to align with how tests compute the output shape to go around the size asserting in MLOperatorAuthorImpl.cpp `m_inputDimensions[i] * roiRange * scale` -> `m_inputDimensions[i] * scale` 3. disable 4 tests because of the result mismatch. The results of DML with float32 and uint8/int8 match each other, so it should be problem of resize implementation, which is out the scope of this PR. `ResizeOpTest.NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_uint8 ResizeOpTest.NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8 ResizeOpTest.NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_uint8 ResizeOpTest.NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8` --- .../dml/OperatorAuthorHelper/OperatorHelper.cpp | 3 +-- .../test/providers/cpu/tensor/resize_op_test.cc | 13 +++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 1fcd3b04300f4..7c2320160dd3b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -2450,8 +2450,7 @@ namespace OperatorHelper { float scale = m_scales[i]; ML_CHECK_VALID_ARGUMENT(scale > FLT_EPSILON, "Scale values should be positive."); - float roiRange = m_regionOfInterest.empty() ? 1.0f : m_regionOfInterest[i + rank] - m_regionOfInterest[i]; - m_outputDimensions.push_back(gsl::narrow_cast(floor(m_inputDimensions[i] * roiRange * scale))); + m_outputDimensions.push_back(gsl::narrow_cast(floor(m_inputDimensions[i] * scale))); } } else diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 3ea7295aef5a2..f473c98ca713e 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -187,7 +187,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + // DML: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -214,7 +215,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e 0, 0, 0}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(); + // DML: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); } TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { @@ -530,7 +532,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + // DML: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -558,7 +561,9 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe std::vector Y = {0, 2, -9}; test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: results mismatch + // TensorRT: results mismatch + // DML: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) { From c1ec3c3f9369d6a01b727d4d1ef772c9fe6a5213 Mon Sep 17 00:00:00 2001 From: Christian Larson Date: Fri, 8 Dec 2023 17:08:29 -0800 Subject: [PATCH 14/25] User/chrila/fix dml dx12 warning (#18746) Update resource creation flag to avoid D3D12 WARNING ### Description Update the DML DX12 allocator to use D3D12_RESOUCE_STATE_COMMON to avoid DX12 Warning messages. ### Motivation and Context When directML is created with debug layer there are warnings when resources are created by ORT. --------- Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com> --- .../src/DmlCommittedResourceAllocator.cpp | 2 +- .../DmlExecutionProvider/src/DmlExternalBufferAllocator.h | 2 +- .../dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp index b696aefecf664..54393e9bf1539 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp @@ -16,7 +16,7 @@ namespace Dml unmove_ptr(CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT)), D3D12_HEAP_FLAG_NONE, &buffer, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON, nullptr, IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) )); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h index 9514a24b4e781..22fd3be42c416 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h @@ -39,7 +39,7 @@ namespace Dml &props, D3D12_HEAP_FLAG_NONE, &buffer, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON, nullptr, IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) )); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 18cdc5d1bf86e..642d9aa03eeef 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -36,7 +36,7 @@ namespace DmlGraphFusionHelper &heapProperties, D3D12_HEAP_FLAG_NONE, &resourceDesc, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON, nullptr, IID_GRAPHICS_PPV_ARGS(buffer.GetAddressOf()))); @@ -74,7 +74,7 @@ namespace DmlGraphFusionHelper &heapProperties, D3D12_HEAP_FLAG_NONE, &resourceDesc, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON, nullptr, IID_GRAPHICS_PPV_ARGS(buffer.GetAddressOf()))); @@ -302,7 +302,7 @@ namespace DmlGraphFusionHelper for (size_t i = 0; i < graphDesc.nodes.size(); ++i) { auto& nodeInfo = graphDesc.nodes[i]; - + if (std::holds_alternative>(nodeInfo.nodeDef)) { dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; From 107d7492b9c0f82dc61974065c08c91550d0dea4 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 8 Dec 2023 19:38:44 -0800 Subject: [PATCH 15/25] [DirectML EP] Add DML EP registration for Col2Im (#17786) ### Description [DirectML EP] Add DML EP registration for Col2Im operator ### Motivation and Context Add Col2Im support for opset 18. This operator is implemented as the DirectML Fold operator. --------- Co-authored-by: Sheil Kumar Co-authored-by: Dwayne Robinson --- cmake/external/dml.cmake | 4 +- .../src/Operators/DmlOperatorCol2Im.cpp | 59 +++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 2 + .../OperatorAuthorHelper/OperatorHelper.cpp | 71 +++++++++++++++++-- .../dml/OperatorAuthorHelper/OperatorHelper.h | 29 ++++++++ .../OperatorAuthorHelper/OperatorVersions.h | 1 + 6 files changed, 158 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index d777306722cd6..dfd9ad120eb98 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -72,12 +72,11 @@ else() if (dml_EXTERNAL_PROJECT) set(dml_preset_config $,debug,release>) set(dml_preset_name ${onnxruntime_target_platform}-win-redist-${dml_preset_config}) - target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1) include(ExternalProject) ExternalProject_Add( directml_repo GIT_REPOSITORY https://dev.azure.com/microsoft/WindowsAI/_git/DirectML - GIT_TAG d460f0f46967bea878786f1bed69487692c779bf + GIT_TAG a5312f72c51864b4d705ac62d25d08bcd88c4fb1 GIT_SHALLOW OFF # not allowed when GIT_TAG is a commit SHA, which is preferred (it's stable, unlike branches) GIT_PROGRESS ON BUILD_IN_SOURCE ON @@ -94,6 +93,7 @@ else() target_link_libraries(DirectML INTERFACE ${directml_install_path}/lib/DirectML.lib) add_dependencies(DirectML directml_repo-install) include_directories(BEFORE ${directml_install_path}/include) + target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1) else() include_directories(BEFORE ${dml_INCLUDE_DIR}) set(DML_PACKAGE_DIR ${dml_INCLUDE_DIR}/..) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp new file mode 100644 index 0000000000000..13a51f2be7560 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "./precomp.h" + +namespace Dml +{ + +class DmlOperatorCol2Im : public DmlOperator, public Col2ImHelper +{ +public: + explicit DmlOperatorCol2Im(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext), + Col2ImHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription()) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "Col2Im expects 3 inputs."); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Col2Im expects 1 output."); + + auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); + std::vector inputTensorShape = tensorShapeDescription.GetInputTensorShape(0); + std::vector outputTensorShape = tensorShapeDescription.GetOutputTensorShape(0); + + ML_CHECK_VALID_ARGUMENT(outputTensorShape == m_outputShape); + + std::vector> inputIndices = { 0 }; + gsl::span inputShapes[1] = { m_inputShape }; + gsl::span outputShapes[1] = { m_outputShape }; + DmlOperator::InitializeWithShapes( + kernelCreationContext, + inputIndices, + std::nullopt, + inputShapes, + outputShapes, + 3 + ); + // Prepare DML_FOLD_OPERATOR_DESC + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + assert(inputDescs.size() == 1); + assert(outputDescs.size() == 1); + + DML_FOLD_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = inputDescs.data(); + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.DimensionCount = gsl::narrow_cast(m_blockShape.size()); + operatorDesc.WindowSizes = m_blockShape.data(); + operatorDesc.Dilations = m_dilations.data(); + operatorDesc.StartPadding = m_pads.data(); + operatorDesc.EndPadding = m_pads.data(); + operatorDesc.Strides = m_strides.data(); + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FOLD, &operatorDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(Col2Im, DmlOperatorCol2Im); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 0234bb6b7ec1e..2ab73afb8f1e1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -503,6 +503,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); DML_OP_EXTERN_CREATION_FUNCTION(Trilu); +DML_OP_EXTERN_CREATION_FUNCTION(Col2Im); DML_OP_EXTERN_CREATION_FUNCTION(Shape); DML_OP_EXTERN_CREATION_FUNCTION(Size); DML_OP_EXTERN_CREATION_FUNCTION(Attention); @@ -770,6 +771,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)}, {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)}, {REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 18, Col2Im, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2))}, // Data reorganization that merely changes the dimensions while keeping the data identical. {REG_INFO_COPY( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 7c2320160dd3b..83c6748fadd35 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -257,14 +257,15 @@ namespace OperatorHelper } } - void DowncastDimensions(gsl::span inputDimensions, std::vector& outputDimensions) + template + void DowncastDimensions(gsl::span inputDimensions, std::vector& outputDimensions) { outputDimensions.reserve(inputDimensions.size()); outputDimensions.clear(); - for (int64_t dim : inputDimensions) + for (T dim : inputDimensions) { - outputDimensions.push_back(gsl::narrow_cast(std::clamp(dim, INT32_MIN, INT32_MAX))); + outputDimensions.push_back(gsl::narrow_cast(std::clamp(dim, INT32_MIN, INT32_MAX))); } } @@ -1870,6 +1871,64 @@ namespace OperatorHelper return { std::move(outputShape) }; } + void Col2ImHelper::Initialize( + const IKernelInformationAdapter& kernelInformation, + const IShapeInformationAdapter& shapeInformation) + { + std::vector shapeData; + ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(1), /*out*/ shapeData); + m_imageShape.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_imageShape); + ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(2), /*out*/ shapeData); + m_blockShape.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_blockShape); + + const uint32_t dimCount = gsl::narrow_cast(m_blockShape.size()); + m_dilations = {dimCount, 1}; + m_pads = {dimCount * 2, 0}; + m_strides = {dimCount, 1}; + + if (kernelInformation.HasAttribute(AttrName::Dilations, MLOperatorAttributeType::IntArray)) + { + shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Dilations); + m_dilations.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_dilations); + ML_CHECK_VALID_ARGUMENT(m_dilations.size() == dimCount); + } + + if (kernelInformation.HasAttribute(AttrName::Pads, MLOperatorAttributeType::IntArray)) + { + shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Pads); + m_pads.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_pads); + ML_CHECK_VALID_ARGUMENT(m_pads.size() == dimCount * 2); + } + + if (kernelInformation.HasAttribute(AttrName::Strides, MLOperatorAttributeType::IntArray)) + { + shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Strides); + m_strides.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_strides); + ML_CHECK_VALID_ARGUMENT(m_strides.size() == dimCount); + } + + m_inputShape = shapeInformation.GetInputTensorShape(0); + + auto blockShapeProduct = ComputeElementCountFromDimensions(m_blockShape); + m_outputShape.resize(2 + m_imageShape.size()); + m_outputShape[0] = m_inputShape[0]; // N + m_outputShape[1] = m_inputShape[1] / blockShapeProduct; // C + for (int i = 2; i < m_outputShape.size(); i++) + { + m_outputShape[i] = m_imageShape[i - 2]; + }; + } + + std::vector Col2ImHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + return { EdgeShapes(m_outputShape) }; + } + void ConcatHelperBase::Initialize( const MLOperatorAttributes& operatorAttributes, gsl::span inputDimensions @@ -2020,7 +2079,7 @@ namespace OperatorHelper } return outputShapes; } - + std::vector QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto inputShape = shapeInfo.GetInputTensorShape(0); @@ -2050,7 +2109,7 @@ namespace OperatorHelper } return outputShapes; } - + std::vector RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS); @@ -2113,7 +2172,7 @@ namespace OperatorHelper { std::vector outputDimensions64bit = shapeInfo.GetAttributeVector(AttrName::OutputShape); ML_CHECK_VALID_ARGUMENT(outputDimensions64bit.size() == m_inputShape.size(), "Input dimensions and output_shape must have same rank."); - DowncastDimensions(outputDimensions64bit, /*out*/ outputDimensions); + DowncastDimensions(gsl::span(outputDimensions64bit), /*out*/ outputDimensions); } else { diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index d8d09efd8d6e8..0e0e6bb1eaf5c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1206,6 +1206,34 @@ class SqueezeHelper std::vector m_axes; }; +class Col2ImHelper +{ +public: + void Initialize( + const IKernelInformationAdapter& kernelInformation, + const IShapeInformationAdapter& shapeInformation); + + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + Col2ImHelper(const Info_t& info, const Shape_t& shape) + { + Initialize(KernelInformationAdapter(info), ShapeInformationAdapter(shape)); + } + + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +protected: + std::vector m_dilations; + std::vector m_pads; + std::vector m_strides; + std::vector m_imageShape; + std::vector m_blockShape; + std::vector m_inputShape; + std::vector m_outputShape; +}; + + class UnsqueezeHelper { public: @@ -1572,6 +1600,7 @@ using ShapeInferenceHelper_Unsqueeze11 = VersionedOpsetHelper; using ShapeInferenceHelper_EyeLike = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Trilu = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Col2Im = Col2ImHelper; using ShapeInferenceHelper_Expand = ExpandHelper; using ShapeInferenceHelper_Reshape7 = ReshapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index e9d88adf3e221..8438bc620712c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -407,6 +407,7 @@ namespace OperatorHelper static const int sc_sinceVer_Pad = 18; static const int sc_sinceVer_Split = 18; static const int sc_sinceVer_LpPool = 18; + static const int sc_sinceVer_Col2Im = 18; } namespace OnnxOperatorSet19 From d2f7a5b1286e34ea55dceccff8f17a45f2f799aa Mon Sep 17 00:00:00 2001 From: Jake Mathern Date: Mon, 11 Dec 2023 17:41:16 -0800 Subject: [PATCH 16/25] Cherry pick fix constant pow (#18785) ### Description Cherry pick https://github.com/microsoft/onnxruntime/pull/18784 --- .../src/Operators/DmlOperatorElementWise.cpp | 2 +- .../dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index f0a16da3a3c06..ec94772238cc9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -479,7 +479,7 @@ class DmlOperatorElementwisePow : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1); + auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(1); if (constExpTensor && constExpTensor->GetTotalElementCount() == 1) { std::vector> kernelInputIndices = {0}; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 59a1719d08ee6..c40f82a8c31c6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -605,11 +605,11 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return MLOperatorTensor(tensor.Get()); } - std::optional TryGetConstantInputTensor(uint32_t inputIndex) const + std::optional TryGetConstantCpuInputTensor(uint32_t inputIndex) const { Microsoft::WRL::ComPtr tensor; ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor)); - if (tensor) + if (tensor && tensor->IsCpuData()) { return MLOperatorTensor(tensor.Get()); } From b2f81c8725a0dd1ba343c91ff45aa082e9331f5d Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 14 Dec 2023 14:20:55 -0800 Subject: [PATCH 17/25] Hide Col2Im registration behind DML_TARGET_VERSION 6300 (#18829) Hide Col2Im registration behind DML_TARGET_VERSION 6300 Co-authored-by: Sheil Kumar --- .../src/Operators/DmlOperatorCol2Im.cpp | 4 ++++ .../src/Operators/OperatorRegistration.cpp | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp index 13a51f2be7560..f80b4f98236bc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp @@ -6,6 +6,8 @@ namespace Dml { +#if DML_TARGET_VERSION >= 0x6300 + class DmlOperatorCol2Im : public DmlOperator, public Col2ImHelper { public: @@ -56,4 +58,6 @@ class DmlOperatorCol2Im : public DmlOperator, public Col2ImHelper DML_OP_DEFINE_CREATION_FUNCTION(Col2Im, DmlOperatorCol2Im); +#endif // DML_TARGET_VERSION >= 0x6300 + } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 2ab73afb8f1e1..15a8051953c79 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -503,7 +503,11 @@ DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); DML_OP_EXTERN_CREATION_FUNCTION(Trilu); + +#if DML_TARGET_VERSION >= 0x6300 DML_OP_EXTERN_CREATION_FUNCTION(Col2Im); +#endif + DML_OP_EXTERN_CREATION_FUNCTION(Shape); DML_OP_EXTERN_CREATION_FUNCTION(Size); DML_OP_EXTERN_CREATION_FUNCTION(Attention); @@ -771,7 +775,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)}, {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)}, {REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + +#if DML_TARGET_VERSION >= 0x6300 {REG_INFO( 18, Col2Im, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2))}, +#endif // Data reorganization that merely changes the dimensions while keeping the data identical. {REG_INFO_COPY( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, From bdaeebd6ff5d4667f288ebb9b8072f66b45de1c8 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Mon, 18 Dec 2023 15:47:57 -0800 Subject: [PATCH 18/25] Fix bug in DML EP ExecuteCommandList fast path and simplify design (#18866) ### Description This addresses a bug in a fast path that was added for submission of re-used command lists of fused graph kernels in the DML EP, addressing a D3D debug layer error. ### Motivation and Context The fast path in DmlCommandRecorder::ExecuteCommandList enabled a current non-reused command list, if empty, to be used for commands following submission of the fused command list. The fix ensures the associated command allocator is only re-used after the next fence value is completed, which is higher due to submission of the other command list. The command recorder design was intended to support batching of provided command list execution, however it submits command lists immedately as an implementation detail to maximize CPU/GPU parallelism. If that heuristic was removed, it would expose additional issues in this same fast path. Because of this and complexity and inefficiency of the old batching mechanism, I also removed this. --- .../src/CommandAllocatorRing.h | 8 ++ .../src/DmlCommandRecorder.cpp | 89 +++++++------------ .../src/DmlCommandRecorder.h | 16 ++-- 3 files changed, 46 insertions(+), 67 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h index 570f62aac8105..2eee9c9a9e5a3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h @@ -47,6 +47,14 @@ namespace Dml return m_commandAllocators[m_currentCommandAllocator].Get(); } + // Updates the completion event of the current allocator to a different value. This is used when the caller + // decides to issue an unrelated call to the queue such as ExecuteCommandLists which updates its fence between calling + // GetNextAllocator and executing the work which it recorded using the allocator it received. + void UpdateCurrentAllocatorCompletionEvent(GpuEvent nextCompletionEvent) + { + m_commandAllocators[m_currentCommandAllocator].completionEvent = nextCompletionEvent; + } + private: struct CommandAllocatorInfo { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp index 98345f37b68d4..5254b23f56376 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp @@ -262,6 +262,9 @@ void DmlCommandRecorder::ExecuteCommandList( m_queue->ExecuteCommandLists( gsl::span(reinterpret_cast(&commandList), 1)); + // The fence value at which the current command allocator may be re-used will now be higher + m_commandAllocatorRing.UpdateCurrentAllocatorCompletionEvent(m_queue->GetNextCompletionEvent()); + // Fail early if something horrifying happens ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason()); ORT_THROW_IF_FAILED(m_d3dDevice->GetDeviceRemovedReason()); @@ -269,42 +272,20 @@ void DmlCommandRecorder::ExecuteCommandList( return; } - ORT_THROW_IF_FAILED(m_currentCommandList->Close()); - - if (m_operationsRecordedInCurrentCommandList) - { - m_pendingCommandLists.push_back(m_currentCommandList.Get()); - m_pendingCommandListsCacheable.push_back(true); - } - else - { - m_cachedCommandLists.push_back(m_currentCommandList.Get()); - } - - m_currentCommandList = nullptr; - m_operationsRecordedInCurrentCommandList = false; - - m_pendingCommandLists.push_back(commandList); - m_pendingCommandListsCacheable.push_back(false); - - // Remember the descriptor heap and apply it to the next command list + // Remember the descriptor heap and apply it to the next command list. This avoids unnecessarily setting it onto + // the D3D object lazily at a point when the operation may not be parallelized with GPU work. auto heap = m_currentDescriptorHeap; - m_currentDescriptorHeap = nullptr; - Open(); - - // The caller can re-use relevant resources after the next set of work to be - // flushed has completed. Its command list hasn't been executed yet, just batched. - GpuEvent gpuEvent = m_queue->GetNextCompletionEvent(); - gpuEvent.fence.CopyTo(fence); - *completionValue = gpuEvent.fenceValue; - // Trigger a flush of the command list, with the assumption that it contains enough GPU work that this - // will help parallelize GPU work with subsequent CPU work. This policy is related to the choice of - // minNodeCountToReuseCommandList within FusedGraphKernel, so both should be tuned together. - CloseAndExecute(); + // Execute work in the current command list plus provided command list while closing the recorder. + CloseAndExecute(commandList); Open(); + // Reset the descriptor heap opportunistically per above comment SetDescriptorHeap(heap); + + GpuEvent gpuEvent = m_queue->GetCurrentCompletionEvent(); + gpuEvent.fence.CopyTo(fence); + *completionValue = gpuEvent.fenceValue; } ComPtr DmlCommandRecorder::GetCommandList() @@ -334,7 +315,7 @@ void DmlCommandRecorder::Open() ID3D12CommandAllocator* allocator = m_commandAllocatorRing.GetNextAllocator(m_queue->GetNextCompletionEvent()); - if (m_cachedCommandLists.empty()) + if (!m_cachedCommandList) { ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList( 0, @@ -345,47 +326,43 @@ void DmlCommandRecorder::Open() } else { - m_currentCommandList = m_cachedCommandLists.front(); - m_cachedCommandLists.pop_front(); + m_currentCommandList = m_cachedCommandList; + m_cachedCommandList = nullptr; ORT_THROW_IF_FAILED(m_currentCommandList->Reset(allocator, nullptr)); } } void DmlCommandRecorder::CloseAndExecute() { + CloseAndExecute(nullptr); +} + +void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList) +{ ORT_THROW_IF_FAILED(m_currentCommandList->Close()); + ID3D12GraphicsCommandList* commandListsToExecute[2] = {}; + uint32_t commandListsToExecuteCount = 0; + if (m_operationsRecordedInCurrentCommandList) { - m_pendingCommandLists.push_back(m_currentCommandList.Get()); - m_pendingCommandListsCacheable.push_back(true); + commandListsToExecute[commandListsToExecuteCount++] = m_currentCommandList.Get(); } - else + + if (commandList) { - m_cachedCommandLists.push_back(m_currentCommandList.Get()); + commandListsToExecute[commandListsToExecuteCount++] = commandList; } - m_currentCommandList = nullptr; - m_operationsRecordedInCurrentCommandList = false; - - if (!m_pendingCommandLists.empty()) + if (commandListsToExecuteCount > 0) { - // Close and execute the command list m_queue->ExecuteCommandLists( - gsl::span(reinterpret_cast(m_pendingCommandLists.data()), m_pendingCommandLists.size())); - - assert(m_pendingCommandLists.size() == m_pendingCommandListsCacheable.size()); - for (size_t i = 0; i < m_pendingCommandLists.size(); ++i) - { - if (m_pendingCommandListsCacheable[i]) - { - m_cachedCommandLists.push_back(m_pendingCommandLists[i]); - } - } - - m_pendingCommandLists.clear(); - m_pendingCommandListsCacheable.clear(); + gsl::span(reinterpret_cast(commandListsToExecute), commandListsToExecuteCount)); } + + m_cachedCommandList = m_currentCommandList; + m_currentCommandList = nullptr; + m_operationsRecordedInCurrentCommandList = false; // The descriptor heap must be set on the command list the next time it's opened. m_currentDescriptorHeap = nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h index 7ad7032317d77..83051c8ca4ff9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h @@ -58,7 +58,7 @@ namespace Dml bool HasUnsubmittedWork() override { - return m_operationsRecordedInCurrentCommandList || !m_pendingCommandLists.empty(); + return m_operationsRecordedInCurrentCommandList; } // Forces the descriptor heap to be reset to D3D before executing future operations @@ -68,7 +68,8 @@ namespace Dml } private: - + void CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList); + std::shared_ptr m_queue; ComPtr m_d3dDevice; ComPtr m_dmlDevice; @@ -89,15 +90,8 @@ namespace Dml ComPtr m_currentCommandList; bool m_operationsRecordedInCurrentCommandList = false; - // Command lists which have been batched up for execution. The values in - // m_pendingCommandListsCacheable indicate whether they can be moved into this - // class's cache after execution, versus if they belong to the caller and were - // passed to ExecuteCommandList. - std::vector> m_pendingCommandLists; - std::vector m_pendingCommandListsCacheable; - - // A pool of cached command lists which may be re-used. - std::deque> m_cachedCommandLists; + // A cached command list which may be re-used. + ComPtr m_cachedCommandList; void SetDescriptorHeap(ID3D12DescriptorHeap* descriptorHeap); }; From 70d3f682a7a7553f337bd0422c3b8fc555af52a1 Mon Sep 17 00:00:00 2001 From: tbqh <111796392+tbqh@users.noreply.github.com> Date: Tue, 2 Jan 2024 13:22:30 -0600 Subject: [PATCH 19/25] De-duplicate 1D scale and zero point tensors to scalars in DML kernels (#18862) ### Description Cleanup and rebase from [this PR](https://github.com/microsoft/onnxruntime/pull/18629) ### Motivation and Context --------- Co-authored-by: Christian Larson Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com> Co-authored-by: Jeff Bloomfield Co-authored-by: Anagha Rao --- .../inc/IWinmlExecutionProvider.h | 3 + .../src/Operators/DmlOperator.cpp | 83 +++++++++++++++++++ .../src/Operators/DmlOperator.h | 6 ++ .../DmlOperatorBatchNormalization.cpp | 2 + .../src/Operators/DmlOperatorElementWise.cpp | 3 + .../src/Operators/DmlOperatorQLinearAdd.cpp | 23 +++-- .../DmlOperatorQLinearAveragePooling.cpp | 10 ++- .../Operators/DmlOperatorQLinearConcat.cpp | 7 ++ .../src/Operators/DmlOperatorQLinearConv.cpp | 9 ++ .../Operators/DmlOperatorQLinearMatMul.cpp | 9 ++ .../Operators/DmlOperatorQLinearSigmoid.cpp | 7 ++ 11 files changed, 153 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 074f13b309181..f29cc3afc3cda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -85,7 +85,10 @@ namespace Windows::AI::MachineLearning::Adapter { uint32_t nodeCount = 0; std::vector> nodesAsOperatorDesc; + + // TODO (jeffbloo): Remove this std::vector> nodesAsIDMLOperator; + std::vector inputEdges; std::vector outputEdges; std::vector intermediateEdges; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index 25c7be42d6425..c3bb1a52210f5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -6,6 +6,9 @@ namespace Dml { + + /*static*/ const uint32_t DmlOperator::zeroArray[8] = {}; + DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo) { ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider)); @@ -824,4 +827,84 @@ namespace Dml graphDesc.IntermediateEdges = dmlIntermediateEdges.data(); } + /*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar( + const MLOperatorKernelCreationContext& kernelInfo, + const DML_TENSOR_DESC* tensor, + uint32_t kernelInputIndex) + { + if (!tensor) + { + return; + } + + auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(kernelInputIndex); + if (!constExpTensor) + { + return; + } + else if (!constExpTensor->IsCpuData()) + { + return; + } + + uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount(); + if (totalKernelInputElementCount <= 1) + { + return; + } + + uint32_t elementSize = 0; + + switch (constExpTensor->GetTensorDataType()) + { + case MLOperatorTensorDataType::UInt8: + case MLOperatorTensorDataType::Int8: + elementSize = 1; + break; + + case MLOperatorTensorDataType::Float16: + case MLOperatorTensorDataType::UInt16: + case MLOperatorTensorDataType::Int16: + elementSize = 2; + break; + + case MLOperatorTensorDataType::/*Float32*/Float: + case MLOperatorTensorDataType::UInt32: + case MLOperatorTensorDataType::Int32: + elementSize = 4; + break; + + case MLOperatorTensorDataType::/*Float64*/Double: + case MLOperatorTensorDataType::UInt64: + case MLOperatorTensorDataType::Int64: + elementSize = 8; + break; + + default: + return; + } + + const std::uint8_t* byteData = static_cast(constExpTensor->GetByteData()); + + assert(tensor->Type == DML_TENSOR_TYPE_BUFFER); + auto *bufferTensorDesc = const_cast(static_cast(tensor->Desc)); + + for (size_t i = 1; i < totalKernelInputElementCount; ++i) + { + if (memcmp(byteData, byteData + i * elementSize, elementSize)) + { + return; + } + } + + if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0])) + { + assert(false); + return; + } + + bufferTensorDesc->Strides = zeroArray; + bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3; + } + } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h index c1e8cf42a974c..fa54d4b041b5f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h @@ -149,6 +149,11 @@ namespace Dml uint32_t minDimensionCount = NchwDimensionCount ) const; + static void TryConvertTensorToBroadcastScalar( + const MLOperatorKernelCreationContext& kernelInfo, + const DML_TENSOR_DESC* tensor, + uint32_t kernelInputIndex); + private: // For each input or output of the DML kernel, the corresponding input or output of the original // kernel. Entries for unused DML inputs are nullopt. @@ -164,6 +169,7 @@ namespace Dml _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges); + static const uint32_t zeroArray[8]; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp index 9f9cfad670919..ee497715dd73f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp @@ -111,6 +111,8 @@ class DmlOperatorBatchNormalization15 : public DmlOperator, BatchNormalizationHe std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); + // TODO (jeffbloo): Port this to a graph description to enable DML graph optimization + dml::Graph graph(m_dmlDevice.Get()); dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X]; dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale]; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index ec94772238cc9..835d43037eaee 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -586,6 +586,9 @@ class DmlOperatorElementwiseQLinear : public DmlOperator opDesc.ZeroPointTensor = &inputDescs[2]; opDesc.OutputTensor = &outputDescs[0]; + TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1); + TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2); + SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp index 7b50dfb9ff1ad..a19e37e15e768 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp @@ -8,15 +8,15 @@ namespace Dml class DmlOperatorQLinearAdd : public DmlOperator { - enum InputTensors { - IN_A, + enum InputTensors { + IN_A, IN_A_SCALE, - IN_A_ZERO_POINT, - IN_B, + IN_A_ZERO_POINT, + IN_B, IN_B_SCALE, IN_B_ZERO_POINT, - IN_C_SCALE, - IN_C_ZERO_POINT + IN_C_SCALE, + IN_C_ZERO_POINT }; public: @@ -56,9 +56,18 @@ class DmlOperatorQLinearAdd : public DmlOperator AddDesc.BScaleTensor = &inputDescs[IN_B_SCALE]; AddDesc.BZeroPointTensor = inputDescs[IN_B_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_B_ZERO_POINT] : nullptr; AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE]; - AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr; + AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr; AddDesc.OutputTensor = &outputDescs[0]; + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT); + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp index 0fccedfe311c1..605e5fffb6a76 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp @@ -118,8 +118,8 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput]; qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale]; qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint]; - qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];; - qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];; + qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale]; + qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint]; qLinearAvgPooldesc.OutputTensor = &outputDescs[0]; qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount; qLinearAvgPooldesc.WindowSize = m_kernel.windowSize; @@ -129,6 +129,12 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe qLinearAvgPooldesc.Dilations = m_kernel.dilations; qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale); + TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint); + + TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale); + TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint); + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp index 67711fdc28b84..c97b03dc36b62 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -123,6 +123,9 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1]; dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2]; dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex]; + + TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ScaleTensor, tupleStartIndex + 1); + TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ZeroPointTensor, tupleStartIndex + 2); dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]}; opDescs.push_back(&dmlOpDesc[inputIndex]); @@ -154,6 +157,10 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale]; quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint]; quantizeOperatorDesc.OutputTensor = &outputDescs[0]; + + TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::YScale); + TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::YZeroPoint); + const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc}; opDescs.push_back(&opQuantizeDesc); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp index d45fdef3c8807..4e121a6502cba 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp @@ -117,6 +117,15 @@ class DmlOperatorQLinearConv : public DmlOperator, public ConvolutionHelperBase convDesc.EndPadding = kernelArgs.endPadding; convDesc.GroupCount = m_groupCount; + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputScaleTensor, IN_X_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputZeroPointTensor, IN_X_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterScaleTensor, IN_F_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterZeroPointTensor, IN_F_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputScaleTensor, IN_Y_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT); + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &convDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp index b746a0e81a5cf..b38acd8cbf978 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp @@ -104,6 +104,15 @@ class DmlOperatorQLinearMatMul : public DmlOperator matMulDesc.OutputZeroPointTensor = inputDescs[IN_Y_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_Y_ZERO_POINT] : nullptr; matMulDesc.OutputTensor = &outputDescs[0]; + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AScaleTensor, IN_A_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AZeroPointTensor, IN_A_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BScaleTensor, IN_B_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BZeroPointTensor, IN_B_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputScaleTensor, IN_Y_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT); + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, &matMulDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp index 1da4a5cab7623..35f926d62c92a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp @@ -88,6 +88,9 @@ class DmlOperatorQLinearSigmoid : public DmlOperator dequantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::X_scale]; dequantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::X_zero_point]; dequantizeOperatorDesc.OutputTensor = &namedIntermediateOutputTensorDesc; + + TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ScaleTensor, OnnxInputIndex::X_scale); + TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::X_zero_point); const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDesc}; @@ -101,6 +104,10 @@ class DmlOperatorQLinearSigmoid : public DmlOperator quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::Y_scale]; quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::Y_zero_point]; quantizeOperatorDesc.OutputTensor = &outputDescs[0]; + + TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::Y_scale); + TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::Y_zero_point); + const DML_OPERATOR_DESC opDesc3{DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc}; MLOperatorGraphDesc operatorGraphDesc = {}; From ee60e3af6c0e12ea3ab4a17499122c289081e92f Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Tue, 2 Jan 2024 14:58:13 -0800 Subject: [PATCH 20/25] =?UTF-8?q?Limit=20size=20of=20constant=20nodes=20cr?= =?UTF-8?q?eates=20by=20DML=20EP=20following=20deduplicatio=E2=80=A6=20(#1?= =?UTF-8?q?8915)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This limits the size of constant data nodes which the DML EP creates in the DML graph following de-duplication of 1D quantization tensors. In the process it reduces a check for the maximum size of the constant node. This is merged from: https://github.com/microsoft/onnxruntime/pull/18494 ### Motivation and Context --- .../src/GraphDescBuilder.cpp | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index ba022533a1e94..adb4fd131119f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -347,19 +347,23 @@ namespace Dml::GraphDescBuilder // This is a highly inefficient approach to generating constant nodes. It duplicates constant data // across the graph input as well as every consumer's unique constant node. However it is currently // only used for small inputs. - - // TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming - // nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point - // values that have been de-duplicated with conversion to scalars in kernels. - uint32_t c_maxConstNodeDataSize = 1024 * 1024; + uint32_t c_maxConstNodeDataSize = 8; ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); - if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize) + auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + + if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) { + // The tensor description's size should be no larger than the constant input unless it was rounded to + // the required alignment. + assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); + size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), tensorDesc->totalTensorSizeInBytes); auto data = static_cast(constantInput->GetData()); - std::vector tensorData(data, data + constantInput->GetTensorByteSize()); - + std::vector tensorData(data, data + minimumConstantSize); + NodeInfo nodeInfo = {}; nodeInfo.nodeDef = std::move(tensorData); graphNodes.push_back(std::move(nodeInfo)); @@ -379,9 +383,6 @@ namespace Dml::GraphDescBuilder edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; graphInputEdges.push_back(edge); - auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; } } @@ -445,7 +446,7 @@ namespace Dml::GraphDescBuilder // TODO: Change as new header is ingested if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) dmlDesc.Type = (DML_OPERATOR_TYPE) 169; - + // TODO: Change as new header is ingested if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) dmlDesc.Type = (DML_OPERATOR_TYPE) 170; From 56fcea94e3c6f158f3a16dbe42dc1b16caf35565 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Tue, 2 Jan 2024 18:06:05 -0800 Subject: [PATCH 21/25] Enable QDQ quantization for DML EP (#18367) ### Description This enables QDQ transforms with the DML EP --- .../qdq_selector_action_transformer.cc | 60 +++++++++++++------ .../selectors_actions/qdq_selectors.cc | 7 +++ .../selectors_actions/qdq_selectors.h | 26 ++++---- onnxruntime/core/session/inference_session.cc | 20 ------- tools/ci_build/build.py | 3 + 5 files changed, 67 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 29178fe87f75c..29f7575b2a638 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -105,8 +105,8 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when unary QLinear* ops support 16-bit. - std::unique_ptr selector = std::make_unique(); + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"AveragePool", {}}, {"LeakyRelu", {}}, @@ -123,20 +123,43 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // 4 nodes. 2 x DQ for inputs, target, Q // Replace with internal QLinear version of operator. Delete all original nodes. - const std::string action_name{"2DQ"}; - std::unique_ptr action = std::make_unique(kMSDomain); + { + const std::string action_name{"2DQ"}; + std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit. - std::unique_ptr selector = std::make_unique(); - qdq_selector_action_registry.RegisterSelectorAndAction(action_name, - {{"Add", {}}, - {"Mul", {}}}, - std::move(selector), - std::move(action)); + // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit. + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"Add", {}}, + {"Mul", {}}}, + std::move(selector), + std::move(action)); #else - qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); + qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); +#endif + } + +#ifdef USE_DML + { + const std::string action_name{"2DQ_DML"}; + std::unique_ptr action = std::make_unique(kMSDomain); + +#if !defined(ORT_MINIMAL_BUILD) + std::vector providers = {kDmlExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); + + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"Add", {}}}, + std::move(selector), + std::move(action)); + +#else +#error "ORT_MINIMAL_BUILD and USE_DML are not expected simultaneously. This would require RegisterAction to be called here." +#endif + } #endif } @@ -214,8 +237,8 @@ void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when QGemm supports 16-bit. - std::unique_ptr selector = std::make_unique(); + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Gemm", {}}}, std::move(selector), @@ -235,8 +258,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when QLinearWhere supports 16-bit. - std::unique_ptr selector = std::make_unique(); + + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Where", {}}}, std::move(selector), @@ -271,8 +295,8 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed), apply_context, - // this transformer is only compatible with the CPU EP - {kCpuExecutionProvider}} { + // this transformer is only compatible with the CPU and DML EP + {kCpuExecutionProvider, kDmlExecutionProvider}} { } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 15b501c667046..8535b8c9a944a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -91,6 +91,13 @@ std::optional NodeGroupSelector::GetQDQSelection(const GraphViewer& g } std::optional BaseSelector::Select(const GraphViewer& graph_viewer, const Node& node) const { + const std::string_view node_ep = node.GetExecutionProviderType(); + + if (!compatible_providers_.empty() && + std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) { + return std::nullopt; + } + const auto qdq_group = node_group_selector_->GetQDQSelection(graph_viewer, node); if (!qdq_group.has_value()) { return std::nullopt; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index d0d7fb2c2af17..5ce2a97026516 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -257,12 +257,15 @@ class BaseSelector : public NodeSelector { // We std::move SelectorActionRegistry into the SelectorActionTransformer so this class needs to have a move ctor BaseSelector(BaseSelector&& rhs) noexcept - : node_group_selector_{std::move(rhs.node_group_selector_)} { + : node_group_selector_{std::move(rhs.node_group_selector_)}, + compatible_providers_{std::move(rhs.compatible_providers_)} { } protected: - BaseSelector(std::unique_ptr node_group_selector) - : node_group_selector_{std::move(node_group_selector)} {} + BaseSelector(std::unique_ptr node_group_selector, gsl::span compatible_providers = {}) + : node_group_selector_{std::move(node_group_selector)}, + compatible_providers_(compatible_providers.begin(), compatible_providers.end()) { + } // override if you need to adjust the values in NodesToOptimize. // e.g. add entries for missing optional DQ inputs or set num_inputs to handle variadic inputs @@ -271,6 +274,7 @@ class BaseSelector : public NodeSelector { private: std::unique_ptr node_group_selector_; + std::vector compatible_providers_; }; class DropQDQNodesSelector : public BaseSelector { @@ -287,14 +291,14 @@ class DropDQNodesSelector : public BaseSelector { class UnarySelector : public BaseSelector { public: - explicit UnarySelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit UnarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} }; class BinarySelector : public BaseSelector { public: - explicit BinarySelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit BinarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} }; // Variadic DQ nodes -> node -> Q @@ -326,8 +330,8 @@ class ConvSelector : public BaseSelector { class WhereSelector : public BaseSelector { public: - explicit WhereSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit WhereSelector(gsl::span compatible_providers = {}, bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} }; // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not @@ -342,8 +346,8 @@ class MatMulSelector : public BaseSelector { // Output: optional Q node for Y class GemmSelector : public BaseSelector { public: - explicit GemmSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit GemmSelector(gsl::span compatible_providers = {}, bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cef160489ac46..665cdbc36a963 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -633,26 +633,6 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr session_options_.enable_mem_pattern = false; } - // Default this option to true when the DML EP is registered. - // This should be removed if QDQ is supported for DML through QDQSelectorActionTransformer and the DML EP does not - // rely on the constant folding pass for DequantizeLinear. - optional disable_quant_qdq = session_options_.config_options.GetConfigEntry(kOrtSessionOptionsDisableQuantQDQ); - - if (disable_quant_qdq == std::nullopt) { - LOGS(*session_logger_, INFO) - << "QDQ quantization is not supported while using the DML Execution Provider. " - << "So disabling it for this session since it uses the DML Execution Provider."; - - auto st = session_options_.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"); - if (!st.IsOK()) { - return st; - } - } else if (*disable_quant_qdq != "1") { - LOGS(*session_logger_, WARNING) - << "QDQ quantization is not supported while using the DML Execution Provider. " - << "It is enabled within session options which may result in lower performance."; - } - // Parallel execution mode does not support DML EP if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) { LOGS(*session_logger_, INFO) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 5cc537c4596e8..c655100fbf475 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1641,6 +1641,9 @@ def setup_dml_build(args, cmake_path, build_dir, configs): ] run_subprocess(cmd_args) + if args.minimal_build is not None: + raise BuildError("use_dml and minimal_build may not both be set") + def setup_rocm_build(args): rocm_home = None From 70a6f816af1a482b65015616d56b73d21a43d7fe Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Wed, 3 Jan 2024 16:22:54 -0800 Subject: [PATCH 22/25] Port attention query fix from b2768bbf2347b4ea564f2a937f9f48987620ddf0 --- .../src/Operators/DmlOperatorAttention.cpp | 6 ++++++ .../core/providers/dml/OperatorAuthorHelper/Attributes.h | 1 + 2 files changed, 7 insertions(+) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index bbebb4a333baf..c8ca6806e75f7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -571,6 +571,12 @@ void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*o return; } + // `past_present_share_buffer == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::PastPresentShareBuffer, 0) != 0) + { + return; + } + *isSupported = true; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 85333aa77b686..e3df1d00b3e8a 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -107,6 +107,7 @@ namespace AttrName static constexpr const char* QkvHiddenSizes = "qkv_hidden_sizes"; static constexpr const char* Unidirectional = "unidirectional"; static constexpr const char* NumHeads = "num_heads"; + static constexpr const char* PastPresentShareBuffer = "past_present_share_buffer"; static constexpr const char* FusedActivation = "fused_activation"; static constexpr const char* FusedActivationDomain = "fused_activation_domain"; From f4ad940ff3c69e139c3baf91517dfd458e166fb1 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Wed, 3 Jan 2024 18:37:14 -0800 Subject: [PATCH 23/25] Disable MatMul QDQ selector on DML EP until MatMulIntegerToFloat is re-enabled --- .../selectors_actions/qdq_selector_action_transformer.cc | 3 ++- .../qdq_transformer/selectors_actions/qdq_selectors.h | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 29f7575b2a638..244bd8b153077 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -216,7 +216,8 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i #if !defined(ORT_MINIMAL_BUILD) // TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit. - std::unique_ptr selector = std::make_unique(is_int8_allowed); + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers, is_int8_allowed); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"MatMul", {}}}, std::move(selector), diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 5ce2a97026516..deee6e7f25f1a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -337,9 +337,10 @@ class WhereSelector : public BaseSelector { // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not class MatMulSelector : public BaseSelector { public: - MatMulSelector(bool int8_allowed, bool allow_16bit = false) + MatMulSelector(gsl::span compatible_providers, bool int8_allowed, bool allow_16bit = false) : BaseSelector(std::make_unique(int8_allowed, /*matmulintegertofloat_allowed*/ true, - allow_16bit)) {} + allow_16bit), + compatible_providers) {} }; // Input: DQ nodes for A, B and optional C From 8ea3e68192967314f80ef4e7abd21a59b7cfd722 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Thu, 4 Jan 2024 10:10:46 -0800 Subject: [PATCH 24/25] Update ContribOperators.md --- docs/ContribOperators.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 131db5d8d9b37..38fceef67de25 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
+
rotary_embedding_dim : int
+
Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
From 7401b6661d29094c1cb08d117fcaf8af7038b16f Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Thu, 4 Jan 2024 11:27:03 -0800 Subject: [PATCH 25/25] Update OperatorKernels.md --- docs/OperatorKernels.md | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1ce9b3254d91f..e401baae2d803 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -903,7 +903,8 @@ Do not modify directly.* |Asinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| |Atan|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Atanh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|AveragePool|*in* X:**T**
*out* Y:**T**|19+|**T** = tensor(float), tensor(float16)| +|||11+|**T** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| |BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(float), tensor(float16)| @@ -951,7 +952,7 @@ Do not modify directly.* |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Dropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T1**|7+|**T** = tensor(float), tensor(float16)| -|DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| |Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| @@ -1030,7 +1031,8 @@ Do not modify directly.* |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| |LpNormalization|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|LpPool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|LpPool|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(float), tensor(float16)| +|||11+|**T** = tensor(float), tensor(float16)| |||2+|**T** = tensor(float), tensor(float16)| |MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||9+|**T** = tensor(float), tensor(float16)| @@ -1145,8 +1147,8 @@ Do not modify directly.* |Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| -|||11+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| +|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||11+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| @@ -1247,6 +1249,9 @@ Do not modify directly.* |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| +|QLinearAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| +|QLinearConcat|*in* Y_scale:**TF**
*in* Y_zero_point:**T8**
*in* inputs:**TV**
*out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)
**TF** = tensor(float)
**TV** = tensor(float), tensor(int8), tensor(uint8)| +|QLinearGlobalAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|