From c3d96a7b35c975a4eb4ad5a4c94349797defbc78 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Tue, 2 Jan 2024 18:06:26 -0800 Subject: [PATCH 01/45] Update DML version to 1.13.0 (#18978) Update DML nuget version to 1.13.0 --- .pipelines/nuget_config/x64/packages.config | 2 +- .pipelines/nuget_config/x86/packages.config | 2 +- cmake/external/dml.cmake | 2 +- packages.config | 2 +- tools/nuget/generate_nuspec_for_native_nuget.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 2ac650b0e6dc9..2583e0d1b2ead 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index f80f96194a230..5ca659941c159 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index 5d25b9529e030..d777306722cd6 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.1) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0) # Restore nuget packages, which will pull down the DirectML redist package. add_custom_command( diff --git a/packages.config b/packages.config index da61a10adfa74..b67219d6d6913 100644 --- a/packages.config +++ b/packages.config @@ -1,6 +1,6 @@  - + diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 66248565a3e3a..56e50750ac153 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -219,7 +219,7 @@ def add_common_dependencies(xml_text, package_name, version): def generate_dependencies(xml_text, package_name, version): - dml_dependency = '' + dml_dependency = '' if package_name == "Microsoft.AI.MachineLearning": xml_text.append("") From 9bbe425d7f805d4dfaad2e69c3edca40063fe673 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Thu, 27 Jul 2023 19:13:15 -0700 Subject: [PATCH 02/45] Register LPpool18 and AvgPool 19 (#16880) --- .../src/External/DirectMLHelpers/ApiTraits.h | 30 ++++++++++++++ .../External/DirectMLHelpers/DirectMLSchema.h | 40 +++++++++++++++++++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 39 ++++++++++++++++++ .../src/Operators/DmlOperatorPooling.cpp | 40 +++++++++++++++++-- .../src/Operators/OperatorRegistration.cpp | 2 + .../OperatorAuthorHelper/OperatorVersions.h | 6 +++ .../test/providers/cpu/nn/pool_op_test.cc | 25 ------------ 7 files changed, 153 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index c75b662af788d..94f2220fcc168 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -459,12 +459,24 @@ struct 
OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING1; +}; + template <> struct OperatorDescTraits { static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING1; +}; + template <> struct OperatorDescTraits { @@ -1448,12 +1460,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING> using DescType = DML_AVERAGE_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING1> +{ + using DescType = DML_AVERAGE_POOLING1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING> { using DescType = DML_LP_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING1> +{ + using DescType = DML_LP_POOLING1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING> { @@ -2259,8 +2283,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_AVERAGE_POOLING: return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_AVERAGE_POOLING1: + return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_LP_POOLING: return std::invoke(std::forward(visitor), DML_LP_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_LP_POOLING1: + return std::invoke(std::forward(visitor), DML_LP_POOLING1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MAX_POOLING: return std::invoke(std::forward(visitor), DML_MAX_POOLING_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MAX_POOLING1: @@ -2554,7 +2582,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; + case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1"; case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; + case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1"; case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 1ebd52d4ed427..9eae1c1fe8158 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -757,6 +757,26 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_OPERATOR_SCHEMA { DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING1_OPERATOR_SCHEMA { + "DML_OPERATOR_AVERAGE_POOLING1", + DML_OPERATOR_AVERAGE_POOLING1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -776,6 +796,26 @@ constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING_OPERATOR_SCHEMA { DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING1_OPERATOR_SCHEMA { + "DML_OPERATOR_LP_POOLING1", + DML_OPERATOR_LP_POOLING1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS[7] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 833871de0bbd9..ad4cceb85cfd2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -425,6 +425,21 @@ inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_D OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } + +inline std::vector GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& desc) { return { @@ -438,6 +453,20 @@ inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.P))), }; } +inline std::vector GetFields(const DML_LP_POOLING1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.P))), + }; +} inline std::vector GetFields(const DML_MAX_POOLING_OPERATOR_DESC& desc) { return { @@ -1684,7 +1713,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA; case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA; case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_AVERAGE_POOLING1: return DML_AVERAGE_POOLING1_OPERATOR_SCHEMA; case 
DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_LP_POOLING1: return DML_LP_POOLING1_OPERATOR_SCHEMA; case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA; case DML_OPERATOR_MAX_POOLING1: return DML_MAX_POOLING1_OPERATOR_SCHEMA; case DML_OPERATOR_ROI_POOLING: return DML_ROI_POOLING_OPERATOR_SCHEMA; @@ -2002,10 +2033,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_AVERAGE_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_AVERAGE_POOLING1: + return AbstractOperatorDesc( + &DML_AVERAGE_POOLING1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_LP_POOLING: return AbstractOperatorDesc( &DML_LP_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_LP_POOLING1: + return AbstractOperatorDesc( + &DML_LP_POOLING1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_MAX_POOLING: return AbstractOperatorDesc( &DML_MAX_POOLING_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp index e8d5b2746aa13..10ff1d8be8a29 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp @@ -34,7 +34,7 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase kernelOutputIndices.emplace_back(1); } DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices); - + std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1."); @@ -98,6 +98,21 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase SetOpDesc(desc); break; } + case DML_OPERATOR_AVERAGE_POOLING1: + { + if (hasDilations) { + DML_AVERAGE_POOLING1_OPERATOR_DESC desc = {}; + desc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + desc.Dilations = m_kernel.dilations; + SetOpDesc(desc); + } + else { + DML_AVERAGE_POOLING_OPERATOR_DESC desc = {}; + desc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + SetOpDesc(desc); + } + break; + } case DML_OPERATOR_LP_POOLING: { DML_LP_POOLING_OPERATOR_DESC desc = {}; @@ -106,6 +121,23 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase SetOpDesc(desc); break; } + case DML_OPERATOR_LP_POOLING1: + { + if (hasDilations) { + DML_LP_POOLING1_OPERATOR_DESC desc = {}; + desc.P = kernelInfo.GetOptionalAttribute(AttrName::P, 2); + ML_CHECK_VALID_ARGUMENT(desc.P > 0); + desc.Dilations = m_kernel.dilations; + SetOpDesc(desc); + } + else { + DML_LP_POOLING_OPERATOR_DESC desc = {}; + desc.P = kernelInfo.GetOptionalAttribute(AttrName::P, 2); + ML_CHECK_VALID_ARGUMENT(desc.P > 0); + SetOpDesc(desc); + } + break; + } case DML_OPERATOR_MAX_POOLING: case DML_OPERATOR_MAX_POOLING1: case DML_OPERATOR_MAX_POOLING2: @@ -152,7 +184,7 @@ class DmlOperatorPoolingTemplate : public DmlOperatorPooling void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) { *isSupported = false; - + MLOperatorAttributes attributes(context); int storageOrder = attributes.GetOptionalAttribute(AttrName::StorageOrder, 0); @@ -164,11 +196,11 @@ void CALLBACK 
QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* *isSupported = true; } -DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalAveragePool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(MaxPool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalMaxPool, DmlOperatorPoolingTemplate); -DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalLpPool, DmlOperatorPoolingTemplate); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 28360f09bcba3..dbe9f5da4f569 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -667,6 +667,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 19, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, @@ -677,6 +678,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 18, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index e18ba31def48a..3eb35faeba82f 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -406,6 +406,12 @@ namespace OperatorHelper static const int sc_sinceVer_BitwiseNot = 18; static const int sc_sinceVer_Pad = 18; static const int sc_sinceVer_Split = 18; + static const int sc_sinceVer_LpPool = 18; + } + + namespace OnnxOperatorSet19 + { + static const int sc_sinceVer_AveragePool = 19; } 
namespace MsftOperatorSet1 diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 10476ada2fa69..4b194ec18b31b 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -777,11 +777,6 @@ TEST(PoolTest, GlobalMaxPool3D) { } TEST(PoolTest, AveragePool) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool"); test.AddAttribute("auto_pad", ""); @@ -863,11 +858,6 @@ TEST(PoolTest, AveragePool) { } TEST(PoolTest, AveragePool_IncludePadPixel) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool"); test.AddAttribute("auto_pad", ""); @@ -911,11 +901,6 @@ TEST(PoolTest, AveragePool_DefaultStrides) { } TEST(PoolTest, AveragePool_10_ceil1_2d) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool", 10); test.AddAttribute("auto_pad", ""); @@ -939,11 +924,6 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) { } TEST(PoolTest, AveragePool_19_dilation_2d) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool", 19); test.AddAttribute("auto_pad", ""); @@ -1070,11 +1050,6 @@ TEST(PoolTest, GlobalAveragePool_Large_256) { } TEST(PoolTest, LpPool) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("LpPool"); test.AddAttribute("auto_pad", ""); From 9ff5e3b7b0f5e9683422fa3609175fe4f82ccb77 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 3 Nov 2023 09:34:35 -0700 Subject: [PATCH 03/45] Add QLinearConcat for DML EP (#16971) (#18268) ### Description [Cherry Pick Reviewed] ``` [ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms) [ RUN ] QLinearConcatS8.InputOne_Dynamic [ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms) [ RUN ] QLinearConcatS8.InputOne_Const [ OK ] QLinearConcatS8.InputOne_Const (255 ms) [----------] 11 tests from QLinearConcatS8 (3385 ms total) [----------] Global test environment tear-down [==========] 21 tests from 3 test suites ran. (9355 ms total) [ PASSED ] 21 tests. 
``` [#16971](https://github.com/microsoft/onnxruntime/pull/16971) ### Motivation and Context Co-authored-by: Xiang Zhang --- .../Operators/DmlOperatorQLinearConcat.cpp | 236 ++++++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 16 +- .../src/Operators/OperatorUtility.cpp | 4 +- .../src/Operators/OperatorUtility.h | 3 +- .../OperatorAuthorHelper/OperatorHelper.cpp | 18 +- .../dml/OperatorAuthorHelper/OperatorHelper.h | 25 +- .../OperatorAuthorHelper/OperatorVersions.h | 1 + 7 files changed, 290 insertions(+), 13 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp new file mode 100644 index 0000000000000..67711fdc28b84 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -0,0 +1,236 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ +// QLinearConcat = Dequantize + Join + Quantize +class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper +{ + // This order matches the ONNX schema. + enum OnnxInputIndex + { + YScale, + YZeroPoint, + Count, + }; + +public: + DmlOperatorQLinearConcat(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext), + QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription()) + { + DmlOperator::Initialize(kernelCreationContext); + + auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0); + + // inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)} + uint32_t inputDefinitionCount = kernelCreationContext.GetInputCount(); + ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs."); + ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!"); + + uint32_t inputCount = (inputDefinitionCount - 2) / 3; + + auto yScaleDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YScale).tensorDataType; + auto yZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YZeroPoint).tensorDataType; + + // broadcast y_scale and y_zero_point to output shape + m_inputTensorDescs[OnnxInputIndex::YScale] = TensorDesc( + yScaleDataType, + outputShape, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YScale), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + m_inputTensorDescs[OnnxInputIndex::YZeroPoint] = TensorDesc( + yZeroPointDataType, + outputShape, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YZeroPoint), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + // Validate input tensors + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + // Inputs(input tensor, scale, zero_point) are in tuple and starting from index 2 + auto tupleStartIndex = 2 + inputIndex * 3; + auto xScaleDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 
1).tensorDataType; + auto xZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 2).tensorDataType; + ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale"); + ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point"); + + // broadcast x_scale and x_zero_point to shape of corresponding x + m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc( + xScaleDataType, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 1), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc( + xZeroPointDataType, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 2), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + } + + uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount(), 2); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + // 1. output edges between Dequantize and Join node + // 2. input edge between Join and Quantize node + std::vector intermediateOutputTensorDescs(inputCount); + std::vector namedDequantizeOperatorDescs(inputCount); + std::vector dequantizeOperatorDescs(inputCount); + std::vector dmlOpDesc(inputCount); + std::vector opDescs; + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + auto tupleStartIndex = 2 + inputIndex * 3; + intermediateOutputTensorDescs[inputIndex] = TensorDesc( + MLOperatorTensorDataType::Float, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment) + ); + namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc(); + + dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex]; + dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1]; + dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2]; + dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex]; + + dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]}; + opDescs.push_back(&dmlOpDesc[inputIndex]); + } + + TensorDesc joinOutputTensorDesc = TensorDesc( + MLOperatorTensorDataType::Float, + outputShape, + outputShape, + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + DML_TENSOR_DESC namedJoinOutputTensorDesc = joinOutputTensorDesc.GetDmlDesc(); + + DML_JOIN_OPERATOR_DESC joinDesc = {}; + joinDesc.InputCount = gsl::narrow_cast(namedDequantizeOperatorDescs.size()); + joinDesc.InputTensors = namedDequantizeOperatorDescs.data(); + joinDesc.OutputTensor = 
&namedJoinOutputTensorDesc; + joinDesc.Axis = dmlAxis; + + const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc}; + opDescs.push_back(&opJoinDesc); + + DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC quantizeOperatorDesc = {}; + quantizeOperatorDesc.InputTensor = joinDesc.OutputTensor; + quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale]; + quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint]; + quantizeOperatorDesc.OutputTensor = &outputDescs[0]; + const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc}; + opDescs.push_back(&opQuantizeDesc); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.nodeCount = static_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2; + uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1; + + std::vector inputEdges; + // Input edges to Dequantize nodes + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + auto tupleStartIndex = 2 + inputIndex * 3; + for (auto edge_index = 0; edge_index < 3; ++edge_index) + { + DML_INPUT_GRAPH_EDGE_DESC inputEdge = {}; + inputEdge.GraphInputIndex = tupleStartIndex + edge_index; + inputEdge.ToNodeIndex = inputIndex; + inputEdge.ToNodeInputIndex = edge_index; + inputEdges.push_back(inputEdge); + } + } + + // Input edge from y_scale to quantize node + DML_INPUT_GRAPH_EDGE_DESC yScaleInputEdge = {}; + yScaleInputEdge.GraphInputIndex = 0; // Y_scale + yScaleInputEdge.ToNodeIndex = quantizeNodeIndex; + yScaleInputEdge.ToNodeInputIndex = 1; + inputEdges.push_back(yScaleInputEdge); + + // Input edge from y_zero_point to quantize node + DML_INPUT_GRAPH_EDGE_DESC yZeroPointInputEdge = {}; + yZeroPointInputEdge.GraphInputIndex = 1; // Y_zero_point + yZeroPointInputEdge.ToNodeIndex = quantizeNodeIndex; + yZeroPointInputEdge.ToNodeInputIndex = 2; + inputEdges.push_back(yZeroPointInputEdge); + + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + + // set intermediate edges + std::vector intermediateEdges; + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC dequantizeToJoinEdge = {}; + dequantizeToJoinEdge.FromNodeIndex = inputIndex; + dequantizeToJoinEdge.FromNodeOutputIndex = 0; + dequantizeToJoinEdge.ToNodeIndex = joinNodeIndex; // The second last node Join + dequantizeToJoinEdge.ToNodeInputIndex = inputIndex; + intermediateEdges.push_back(dequantizeToJoinEdge); + } + + DML_INTERMEDIATE_GRAPH_EDGE_DESC joinToQuantizeEdge = {}; + joinToQuantizeEdge.FromNodeIndex = joinNodeIndex; + joinToQuantizeEdge.FromNodeOutputIndex = 0; + joinToQuantizeEdge.ToNodeIndex = quantizeNodeIndex; // The second last node Join + joinToQuantizeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(joinToQuantizeEdge); + + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + + // set the output edges + std::vector outputEdges; + DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; + outputEdge.FromNodeIndex = quantizeNodeIndex; + outputEdge.FromNodeOutputIndex = 0; + outputEdge.GraphOutputIndex = 0; + outputEdges.push_back(outputEdge); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + + 
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + }; +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QLinearConcat, DmlOperatorQLinearConcat); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index dbe9f5da4f569..fa2750a22425f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -496,6 +496,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ScatterND); DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd); DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv); DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat); DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); @@ -547,6 +548,7 @@ constexpr static std::array typeNameListEyeLike = { "T1", "T2" } constexpr static std::array typeNameShape = { "T", "T1" }; constexpr static std::array typeNameSize = { "T", "T1" }; constexpr static std::array typeNameListGroupNorm = {"T", "M"}; +constexpr static std::array typeNameListQLinearConcat= {"TF", "T8", "TV"}; constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All}; constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32}; @@ -618,7 +620,18 @@ constexpr static std::array supportedTypeListQLinea constexpr static std::array supportedTypeListDynamicQuantizeLinear = { SupportedTensorDataTypes::Float32, - SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Ints8Bit +}; + +constexpr static std::array supportedTypeListDynamicQuantizeMatMul= { + SupportedTensorDataTypes::Float32, + SupportedTensorDataTypes::Ints8Bit, +}; + +constexpr static std::array supportedTypeListQLinearConcat= { + SupportedTensorDataTypes::Float32, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, }; template @@ -1012,6 +1025,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)}, {REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)}, {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index d8290bbdaee3e..2965fa32ce131 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -419,9 +419,9 @@ namespace Dml } // namespace FusionHelpers - uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const 
MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount) + uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex) { - const std::vector inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); + const std::vector inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(firstInputIndex); uint32_t onnxDimCount = gsl::narrow_cast(inputDimensions.size()); onnxAxis = HandleNegativeAxis(onnxAxis, onnxDimCount); return GetDmlAdjustedAxis(onnxAxis, onnxDimCount, dmlDimCount); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h index f0fad6a05ffb0..8b2da6084242d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h @@ -64,8 +64,7 @@ namespace Dml } // namespace FusionHelpers // Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced. - // Note this function presumes the axis attribute is relative to the first input tensor (which is always the case). - uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount); + uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex = 0); uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, uint32_t onnxDimCount, uint32_t dmlDimCount); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 370f336ff5203..4d59964dcc664 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -1862,7 +1862,7 @@ namespace OperatorHelper return { std::move(outputShape) }; } - void ConcatHelper::Initialize( + void ConcatHelperBase::Initialize( const MLOperatorAttributes& operatorAttributes, gsl::span inputDimensions ) @@ -1872,13 +1872,13 @@ namespace OperatorHelper ML_CHECK_VALID_ARGUMENT(m_axis < static_cast(inputDimensions.size())); } - std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + std::vector ConcatHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const { - auto outputShape = shapeInfo.GetInputTensorShape(0); + auto outputShape = shapeInfo.GetInputTensorShape(firstInputIndex); uint32_t inputCount = shapeInfo.GetInputCount(); - for (uint32_t i = 1; i < inputCount; ++i) + for (uint32_t i = firstInputIndex + step; i < inputCount; i += step) { auto inputShape = shapeInfo.GetInputTensorShape(i); for (size_t j = 0; j < outputShape.size(); ++j) @@ -1893,6 +1893,16 @@ namespace OperatorHelper return { EdgeShapes(outputShape) }; } + std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + return ConcatHelperBase::GetOutputShapes(shapeInfo, 0, 1); + } + + std::vector QLinearConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + return ConcatHelperBase::GetOutputShapes(shapeInfo, 2, 3); + } + void CropHelper::Initialize( const MLOperatorAttributes& 
operatorAttributes, gsl::span inputDimensions diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index f7e545d9d99a9..55a01c59ee4b5 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -864,7 +864,7 @@ class RecurrentHelper int m_hiddenSize = 0; }; -class ConcatHelper +class ConcatHelperBase { public: void Initialize( @@ -875,17 +875,33 @@ class ConcatHelper // Info_t is used to obtain attributes which will be used for calculating the output shape later. // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template - ConcatHelper(const Info_t& info, const Shape_t& shape) + ConcatHelperBase(const Info_t& info, const Shape_t& shape, uint32_t firstInputIndex) { - Initialize(info, shape.GetInputTensorShape(0)); + Initialize(info, shape.GetInputTensorShape(firstInputIndex)); } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const; protected: int m_axis; }; +class ConcatHelper: public ConcatHelperBase +{ +public: + template + ConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 0) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; +}; + +class QLinearConcatHelper: public ConcatHelperBase +{ +public: + template + QLinearConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 2) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; +}; + class CropHelper { public: @@ -1512,6 +1528,7 @@ using ShapeInferenceHelper_Split13 = VersionedOpsetHelper; using ShapeInferenceHelper_Split18 = VersionedOpsetHelper; using ShapeInferenceHelper_Transpose = TransposeHelper; using ShapeInferenceHelper_Concat = ConcatHelper; +using ShapeInferenceHelper_QLinearConcat = QLinearConcatHelper; using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper; using ShapeInferenceHelper_Slice10 = VersionedOpsetHelper; using ShapeInferenceHelper_Slice11 = VersionedOpsetHelper; // Note 11 and 10 are identical - no functional change. 
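The QLinearConcat changes above hinge on one convention: graph inputs 0 and 1 are y_scale and y_zero_point, and every data input then occupies a (tensor, scale, zero_point) triple, which is why QLinearConcatHelper calls ConcatHelperBase::GetOutputShapes with firstInputIndex = 2 and step = 3. Below is a minimal standalone sketch of that shape walk; the function name, the free-standing signature, and the example shapes are illustrative and not part of the patch.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Mirrors ConcatHelperBase::GetOutputShapes(shapeInfo, /*firstInputIndex*/ 2, /*step*/ 3):
// data-tensor shapes live at input indices 2, 5, 8, ...; the concat axis accumulates and
// every other dimension must match across inputs.
std::vector<uint32_t> QLinearConcatOutputShape(
    const std::vector<std::vector<uint32_t>>& allInputShapes, // one shape per graph input
    size_t axis)                                              // already made non-negative
{
    const size_t firstInputIndex = 2; // skip y_scale and y_zero_point
    const size_t step = 3;            // (tensor, scale, zero_point) per data input

    std::vector<uint32_t> outputShape = allInputShapes.at(firstInputIndex);
    for (size_t i = firstInputIndex + step; i < allInputShapes.size(); i += step)
    {
        const std::vector<uint32_t>& shape = allInputShapes[i];
        for (size_t d = 0; d < outputShape.size(); ++d)
        {
            if (d == axis)
            {
                outputShape[d] += shape[d];
            }
            else if (outputShape[d] != shape[d])
            {
                throw std::invalid_argument("Non-concat dimensions must match.");
            }
        }
    }
    return outputShape;
}

int main()
{
    // Hypothetical example: two quantized inputs concatenated along axis 1,
    // {1,2,4,4} + {1,3,4,4} -> {1,5,4,4}. Scale/zero-point shapes are ignored here.
    std::vector<std::vector<uint32_t>> inputs = {
        {1}, {1},                   // y_scale, y_zero_point
        {1, 2, 4, 4}, {1}, {1},     // x0, x0_scale, x0_zero_point
        {1, 3, 4, 4}, {1}, {1},     // x1, x1_scale, x1_zero_point
    };
    std::vector<uint32_t> out = QLinearConcatOutputShape(inputs, /*axis*/ 1);
    return out[1] == 5 ? 0 : 1;
}
```

The same triple layout drives the kernel itself: one DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR node per data input, one DML_OPERATOR_JOIN along the adjusted axis, and a final DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR fed by y_scale and y_zero_point, as wired up in DmlOperatorQLinearConcat.cpp above.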
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 3eb35faeba82f..996ea1ddcb52c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -443,6 +443,7 @@ namespace OperatorHelper static const int sc_sinceVer_BiasAdd = 1; static const int sc_sinceVer_QuickGelu = 1; static const int sc_sinceVer_GroupNorm = 1; + static const int sc_sinceVer_QLinearConcat = 1; static const int sc_sinceVer_RotaryEmbedding = 1; } // namespace MsftOperatorSet1 From cb7f28a16ab78601363c2e694679d2dada149dd1 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 3 Nov 2023 09:43:49 -0700 Subject: [PATCH 04/45] Register Resize for INT8 and UINT8 (#18252) ### Description ### Motivation and Context Co-authored-by: Adrian Tsai --- .../DmlExecutionProvider/src/Operators/OperatorRegistration.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index fa2750a22425f..d7910a6c6849f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -589,7 +589,7 @@ constexpr static std::array supportedTypeListLogica constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int64 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32}; -constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */}; +constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */}; constexpr static std::array supportedTypeListResize13 = supportedTypeListResize11; constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; From dcfff10f57fbbdf81ea9ebcae008be712be42700 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:09:11 -0800 Subject: [PATCH 05/45] Enable QLinearAveragePooling DML EP (#17384) (#18240) [Cherry Pick Reviewed] DML EP Implementation for [QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool) ``` Note: Google Test filter = *QLinear*Pool* [==========] Running 72 tests from 2 test suites. 
[----------] Global test environment set-up. [----------] 36 tests from QLinearGlobalAveragePool [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms) [ RUN ] 
QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms) [----------] 36 tests from QLinearGlobalAveragePool (5968 ms total) [----------] 36 tests from QLinearPoolTest [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage [ OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global [ OK ] QLinearPoolTest.AveragePool2D_Global (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms) [ RUN ] 
QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms) [----------] 36 tests from QLinearPoolTest (12914 ms total) [----------] Global test environment tear-down [==========] 72 tests from 2 test suites ran. (18885 ms total) [ PASSED ] 72 tests. memleakdbg: ----- No memory leaks detected ----- ``` ### Description ### Motivation and Context --- .../src/External/DirectMLHelpers/ApiTraits.h | 20 ++- .../External/DirectMLHelpers/DirectMLSchema.h | 25 +++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 26 +++ .../DmlOperatorQLinearAveragePooling.cpp | 150 ++++++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 8 + .../DmlExecutionProvider/src/TensorDesc.cpp | 36 +++++ .../dml/DmlExecutionProvider/src/TensorDesc.h | 3 + .../dml/OperatorAuthorHelper/Attributes.h | 2 +- .../OperatorAuthorHelper/OperatorHelper.cpp | 42 ++++- .../dml/OperatorAuthorHelper/OperatorHelper.h | 28 +++- .../OperatorAuthorHelper/OperatorVersions.h | 2 + .../qlinear_global_average_pool_test.cc | 3 + 12 files changed, 339 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 94f2220fcc168..a5415ba85f3d3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,7 +24,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 160; + static constexpr auto ValueCount = 161; static constexpr size_t ActivationFunctionCount = 24; }; @@ -495,6 +495,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; +}; + template <> struct OperatorDescTraits { @@ -1496,6 +1502,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> using DescType = 
DML_ROI_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> +{ + using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> { @@ -2522,6 +2534,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args case DML_OPERATOR_ACTIVATION_GELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...); +#pragma warning(push) +#pragma warning(disable: 4063) + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); +#pragma warning(pop) + default: ORT_THROW_HR(E_INVALIDARG); return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 9eae1c1fe8158..2a82c12872a72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -869,6 +869,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, }; + +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", + static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 13, + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index ad4cceb85cfd2..99218c135f058 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -502,6 +502,24 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), }; } +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) { return { @@ -2509,6 +2527,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_GELU_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); +#pragma warning(push) +#pragma warning(disable: 4063) + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); +#pragma warning(pop) + default: ORT_THROW_HR(E_INVALIDARG); return AbstractOperatorDesc( diff --git 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp new file mode 100644 index 0000000000000..0fccedfe311c1 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase +{ + // For QLinear Avg Pool ORT and DML have same indexing order + enum OrtInputTensors : uint32_t + { + ortInput, + ortInputScale, + ortInputZeroPoint, + ortOutputScale, + ortOutputZeroPoint, + ortInputCount + }; + +public: + using Self = DmlOperatorQLinearAveragePooling; + + DmlOperatorQLinearAveragePooling( + const MLOperatorKernelCreationContext& kernelInfo, + bool useGlobalPooling + ) + : DmlOperator(kernelInfo), + PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling) + { + DmlOperator::Initialize(kernelInfo); + + bool isNhwc = m_kernel.channelsLast; + std::vector inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount(); + ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2); + + // DML requires that DimensionCount be equal to Input.dmlDimSize - 2 for Pooling + uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2; + if (m_kernel.spatialDimensionCount < expectedSpatialDimCount) + { + size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount; + + for (int i = gsl::narrow_cast(m_kernel.spatialDimensionCount) - 1; i >= 0; i--) + { + m_kernel.windowSize[i + shift] = m_kernel.windowSize[i]; + m_kernel.windowSize[i] = 1; + + m_kernel.strides[i + shift] = m_kernel.strides[i]; + m_kernel.strides[i] = 1; + + m_kernel.startPadding[i + shift] = m_kernel.startPadding[i]; + m_kernel.startPadding[i] = 0; + + m_kernel.endPadding[i + shift] = m_kernel.endPadding[i]; + m_kernel.endPadding[i] = 0; + + m_kernel.dilations[i + shift] = m_kernel.dilations[i]; + m_kernel.dilations[i] = 1; + } + + m_kernel.spatialDimensionCount = expectedSpatialDimCount; + } + + // Initialize dimensionMapping for NCHW or NHWC layout + std::vector dimensionMapping = {0u, dmlDimSize - 1u}; + dimensionMapping.resize(dmlDimSize); + if (isNhwc) + { + // Form a remapping for dimensions so C is moved before the spatial dimensions. + // e.g. NWC -> {0,2,1} -> NCW + // NHWC -> {0,3,1,2} -> NCHW + // NDHWC -> {0,4,1,2,3} -> NCDHW + std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u); + } + else + { + // Use NCHW {0,1,2,3} format with increasing order of indexs + std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u); + } + m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input ZeroPoint to be the same dimension as the input tensor. 
+ // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint)) + { + m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + } + + // Reshape the Output Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint)) + { + m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + } + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize)); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {}; + + qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput]; + qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale]; + qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint]; + qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];; + qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];; + qLinearAvgPooldesc.OutputTensor = &outputDescs[0]; + qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount; + qLinearAvgPooldesc.WindowSize = m_kernel.windowSize; + qLinearAvgPooldesc.Strides = m_kernel.strides; + qLinearAvgPooldesc.StartPadding = m_kernel.startPadding; + qLinearAvgPooldesc.EndPadding = m_kernel.endPadding; + qLinearAvgPooldesc.Dilations = m_kernel.dilations; + qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +template +class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling +{ +public: + DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling) + { + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index d7910a6c6849f..0234bb6b7ec1e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -320,6 +320,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool); DML_OP_EXTERN_CREATION_FUNCTION(LpPool); DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool); DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool); 
+DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16); DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization); @@ -634,6 +636,10 @@ constexpr static std::array supportedTypeListQLinea SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, }; +constexpr static std::array supportedTypeListQLinearAveragePool = { + SupportedTensorDataTypes::Ints8Bit +}; + template constexpr auto requiredConstantCpuInputs(Args... args) { @@ -1040,6 +1046,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, {REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9. + {REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp index 067a320dd8000..a2183aab52eed 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp @@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm } m_bufferTensorDesc.DimensionCount = newDimensionCount; } + +// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout +void TensorDesc::PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment) +{ + EnsureStridesExist(); + SetDimensionCount(static_cast(dimensionMapping.size()), alignment); + + // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping + std::vector tempSizes{m_sizes, m_sizes + MaximumDimensionCount}; + std::vector tempStrides{m_strides, m_strides + MaximumDimensionCount}; + + for (size_t i = 0; i < dimensionMapping.size(); i++) + { + m_sizes[i] = tempSizes[dimensionMapping[i]]; + m_strides[i] = tempStrides[dimensionMapping[i]]; + } + + m_bufferTensorDesc.Sizes = m_sizes; + m_bufferTensorDesc.Strides = m_strides; +} + +void TensorDesc::EnsureStridesExist() +{ + if (m_bufferTensorDesc.Strides != nullptr) + { + // Strides are populated + return; + } + + uint32_t stride = 1; + for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;) + { + m_strides[i] = stride; + stride *= m_sizes[i]; + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h index ff70dec5b8871..909e2084d0163 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h @@ -44,6 +44,7 @@ namespace Dml gsl::span GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; } gsl::span 
GetStrides() const; void SetStrides(gsl::span strides); + void PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment); inline uint64_t GetBufferSizeInBytes() const { @@ -90,6 +91,8 @@ namespace Dml uint32_t m_sizes[MaximumDimensionCount] = {}; uint32_t m_strides[MaximumDimensionCount] = {}; DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {}; + + void EnsureStridesExist(); }; class TensorDescBuilder diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index e9591cfce6870..85333aa77b686 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -23,8 +23,8 @@ namespace AttrName static constexpr const char* BlockSize = "blocksize"; static constexpr const char* Border = "border"; static constexpr const char* Broadcast = "broadcast"; - static constexpr const char* ChannelsLast = "channels_last"; static constexpr const char* CeilMode = "ceil_mode"; + static constexpr const char* ChannelsLast = "channels_last"; static constexpr const char* Clip = "clip"; static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode"; static constexpr const char* CountIncludePad = "count_include_pad"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 4d59964dcc664..1fcd3b04300f4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -365,13 +365,20 @@ namespace OperatorHelper } // Creates a kernel that spans the entire spatial dimensions of the input. - KernelArgs InitializeGlobalKernel(gsl::span inputDimensions) + KernelArgs InitializeGlobalKernel( + const MLOperatorAttributes& kernelInfo, + gsl::span inputDimensions) { ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount); // Must be at least 1D convolution (in 3D tensor) uint32_t spatialDimensionCount = gsl::narrow_cast(inputDimensions.size()) - NonspatialDimensionCount; ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor). KernelArgs args(spatialDimensionCount); + args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0); + // For Global Pooling, the kernel size is equal to the spatial dimensions of the input tensor + // NHWC layout needs to offset by one dim to account for the channel placed at the end + int dimOffset = args.channelsLast ? 
1 : 0; for (size_t dim = 0; dim < spatialDimensionCount; ++dim) { @@ -379,7 +386,7 @@ namespace OperatorHelper args.dilations[dim] = 1; args.startPadding[dim] = 0; args.endPadding[dim] = 0; - args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim]); + args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim - dimOffset]); } return args; @@ -495,6 +502,7 @@ namespace OperatorHelper } args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0); return args; } @@ -2012,7 +2020,37 @@ namespace OperatorHelper } return outputShapes; } + + std::vector QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto inputShape = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast); + + const uint32_t outputCount = shapeInfo.GetOutputCount(); + + std::vector outputShapes; + for (uint32_t i = 0; i < outputCount; ++i) + { + outputShapes.push_back(outputDimensions); + } + return outputShapes; + } + + std::vector QLinearGlobalAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto inputShape = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast); + const uint32_t outputCount = shapeInfo.GetOutputCount(); + + std::vector outputShapes; + for (uint32_t i = 0; i < outputCount; ++i) + { + outputShapes.push_back(outputDimensions); + } + return outputShapes; + } + std::vector RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 55a01c59ee4b5..d8d09efd8d6e8 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -160,6 +160,7 @@ struct KernelArgs bool autoPad = false; bool autoPadSameUpper = false; bool useCeilingOutputShape = false; + bool channelsLast = false; uint32_t spatialDimensionCount = 0; KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount) @@ -188,6 +189,7 @@ struct KernelArgs KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount) : autoPad(kernelArgs.autoPad), autoPadSameUpper(kernelArgs.autoPadSameUpper), + channelsLast(kernelArgs.channelsLast), spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount)) { ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); @@ -211,7 +213,9 @@ std::vector InitializeKernelOutputDimsTranspose( gsl::span inputDimensions, const KernelArgs& args); -KernelArgs InitializeGlobalKernel(gsl::span inputDimensions); +KernelArgs InitializeGlobalKernel( + const MLOperatorAttributes& kernelInfo, + gsl::span inputDimensions); KernelArgs InitializeKernel( const MLOperatorAttributes& kernelInfo, @@ -1059,7 +1063,7 @@ class PoolingHelperBase bool useGlobalPooling ) : m_kernel(useGlobalPooling - ? InitializeGlobalKernel(shape.GetInputTensorShape(0)) + ? 
InitializeGlobalKernel(info, shape.GetInputTensorShape(0)) : InitializeKernel(info, static_cast(shape.GetInputTensorShape(0).size()), gsl::span())) { if (!useGlobalPooling) @@ -1161,6 +1165,24 @@ class RoiAlignHelper : public RoiPoolingHelperBase std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; }; +class QLinearAveragePoolingHelper : public PoolingHelperBase +{ +public: + template + QLinearAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +}; + +class QLinearGlobalAveragePoolingHelper : public PoolingHelperBase +{ +public: + template + QLinearGlobalAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +}; + class SqueezeHelper { public: @@ -1490,6 +1512,8 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper; using ShapeInferenceHelper_LpPool = PoolingHelper; using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper; using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper; +using ShapeInferenceHelper_QLinearAveragePool = QLinearAveragePoolingHelper; +using ShapeInferenceHelper_QLinearGlobalAveragePool = QLinearGlobalAveragePoolingHelper; using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper; using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper; using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 996ea1ddcb52c..e9d88adf3e221 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -445,6 +445,8 @@ namespace OperatorHelper static const int sc_sinceVer_GroupNorm = 1; static const int sc_sinceVer_QLinearConcat = 1; static const int sc_sinceVer_RotaryEmbedding = 1; + static const int sc_sinceVer_QLinearAveragePool = 1; + static const int sc_sinceVer_QLinearGlobalAveragePool = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper diff --git a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc index 8fb245819fd26..71b6f27b5391f 100644 --- a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc +++ b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc @@ -66,6 +66,9 @@ void RunQLinearGlobalAveragePool( test.AddInput("y_scale", {}, {y_scale}); test.AddInput("y_zero_point", {}, {y_zero_point}); test.AddOutput("Y", y_dims, y_data); + if (channels_last) { + test.AddAttribute("channels_last", (int64_t)1LL); + } auto q8checker = [&](const std::vector& fetches, const std::string& provider_type) { const OrtValue& ort_value = fetches[0]; From d5f3aae3fd30a79d9bbeed26a70907618ac84835 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:43:09 -0800 Subject: [PATCH 06/45] Utilize DML constant input graph node (#18267) ### Description This PR also includes, 8b0a55e7cc DML constant pow operator 7520974970 Enable custom heaps based on query- ### Motivation and Context --------- Co-authored-by: Jeff Bloomfield --- .../src/DmlGraphFusionHelper.cpp | 27 ++++- .../src/ExecutionProvider.cpp | 31 ++++++ .../src/ExecutionProvider.h | 2 + 
.../src/GraphDescBuilder.cpp | 104 ++++++++++++++---- .../src/GraphDescBuilder.h | 5 +- .../src/IExecutionProvider.h | 1 + .../src/MLOperatorAuthorImpl.cpp | 29 ++++- .../src/MLOperatorAuthorImpl.h | 7 ++ .../src/Operators/DmlOperatorElementWise.cpp | 38 +++++-- .../MLOperatorAuthorHelper.h | 13 +++ .../MLOperatorAuthorPrivate.h | 10 ++ 11 files changed, 229 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 4f7ec188140b5..18cdc5d1bf86e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -226,8 +226,7 @@ namespace DmlGraphFusionHelper { ComPtr initializeInputBuffer; - // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps - if (providerImpl->IsMcdmDevice()) + if (!providerImpl->CustomHeapsSupported()) { initializeInputBuffer = CreateResource(providerImpl, tensorPtr, tensorByteSize); } @@ -294,6 +293,7 @@ namespace DmlGraphFusionHelper const uint32_t inputCount, const uint32_t outputCount, _Inout_ std::vector& dmlOperatorGraphNodes, + _Inout_ std::vector& dmlConstantGraphNodes, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, @@ -302,8 +302,24 @@ namespace DmlGraphFusionHelper for (size_t i = 0; i < graphDesc.nodes.size(); ++i) { auto& nodeInfo = graphDesc.nodes[i]; - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + + if (std::holds_alternative>(nodeInfo.nodeDef)) + { + dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + } + else + { + auto& nodeDefinitionData = std::get>(nodeInfo.nodeDef); + dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{ + nodeDefinitionData.data(), + nodeDefinitionData.size(), + nodeInfo.name.data() + }; + + // TODO: Change as new header is ingested + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast(2), &dmlConstantGraphNodes[i]}; + } } for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) @@ -392,6 +408,8 @@ namespace DmlGraphFusionHelper // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator DML_GRAPH_DESC dmlGraphDesc = {}; std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); + std::vector dmlConstantGraphNodes(graphDesc.nodes.size()); + std::vector dmlGraphNodes(graphDesc.nodes.size()); std::vector dmlInputEdges(graphDesc.inputEdges.size()); std::vector dmlOutputEdges(graphDesc.outputEdges.size()); @@ -402,6 +420,7 @@ namespace DmlGraphFusionHelper fusedNodeInputCount, fusedNodeOutputCount, dmlOperatorGraphNodes, + dmlConstantGraphNodes, dmlGraphNodes, dmlInputEdges, dmlOutputEdges, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 8644b8d56a426..49a64c4810252 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -182,6 +182,32 @@ namespace Dml } m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == 
D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE); + m_areCustomHeapsSupported = !m_isMcdmDevice; + + if (m_isMcdmDevice) + { + + // TODO: Ingest updated header file + typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS19 + { + BOOL MismatchingOutputDimensionsSupported; + UINT SupportedSampleCountsWithNoOutputs; + BOOL PointSamplingAddressesNeverRoundUp; + BOOL RasterizerDesc2Supported; + BOOL NarrowQuadrilateralLinesSupported; + BOOL AnisoFilterWithPointMipSupported; + UINT MaxSamplerDescriptorHeapSize; + UINT MaxSamplerDescriptorHeapSizeWithStaticSamplers; + UINT MaxViewDescriptorHeapSize; + _Out_ BOOL ComputeOnlyCustomHeapSupported; + } D3D12_FEATURE_DATA_D3D12_OPTIONS19; + + D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {}; + + // The call may fail in which case the default value is false + d3d12Device->CheckFeatureSupport(static_cast(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19)); + m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported; + } m_context = std::make_shared(m_d3d12Device.Get(), m_dmlDevice.Get(), queue); @@ -1089,6 +1115,11 @@ namespace Dml return m_isMcdmDevice; } + bool __stdcall ExecutionProviderImpl::CustomHeapsSupported() const noexcept + { + return m_areCustomHeapsSupported; + } + bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept { return m_areMetacommandsEnabled; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 3aaa11cdee479..ab932fb8a4367 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -150,6 +150,7 @@ namespace Dml } STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; bool DynamicGraphFusionEnabled() const noexcept; @@ -186,6 +187,7 @@ namespace Dml ComPtr m_d3d12Device; ComPtr m_dmlDevice; bool m_isMcdmDevice = false; + bool m_areCustomHeapsSupported = false; bool m_areMetacommandsEnabled = true; bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 3fc8f415e5a58..ba022533a1e94 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -149,7 +149,7 @@ namespace Dml::GraphDescBuilder const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle, + const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, @@ -198,7 +198,7 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (subgraphNodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()) { reuseCommandList = true; } @@ -232,14 +232,22 @@ namespace Dml::GraphDescBuilder { ComPtr tensor = nullptr; - // Check whether this specific node requested support for constant CPU inputs - if (std::find(requiredConstantCpuInputs.begin(), requiredConstantCpuInputs.end(), 
inputIndex) != requiredConstantCpuInputs.end()) + auto inputDefs = node.InputDefs(); + + if (inputIndex < inputDefs.size()) { - auto inputDefs = node.InputDefs(); - if (inputIndex < inputDefs.size()) + const onnxruntime::NodeArg* arg = inputDefs[inputIndex]; + tensor = constantCpuGraphInputGetter(arg->Name()); + + if (tensor == nullptr) { - const onnxruntime::NodeArg* arg = inputDefs[inputIndex]; - tensor = constantCpuGraphInputGetter(arg->Name()); + bool inputRequiredAsConstant = std::find( + requiredConstantCpuInputs.begin(), + requiredConstantCpuInputs.end(), + inputIndex) != requiredConstantCpuInputs.end(); + + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. + ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } } @@ -289,6 +297,7 @@ namespace Dml::GraphDescBuilder std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0; + size_t firstOpDescGraphNodeIndex = graphNodes.size(); if (isNodeAsOpDesc) { @@ -298,6 +307,8 @@ namespace Dml::GraphDescBuilder ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]); operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); } + + graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount); } else { @@ -306,7 +317,7 @@ namespace Dml::GraphDescBuilder ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get()); operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); NodeInfo nodeInfo = {}; - nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); + nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); graphNodes.push_back(std::move(nodeInfo)); } } @@ -328,21 +339,59 @@ namespace Dml::GraphDescBuilder const uint32_t dmlFusedNodeInputIndex = iter->second; - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex - graphInputEdges.push_back(edge); - // If this is a constant input, set the appropriate flags on the desc if (isNodeAsOpDesc && dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; - tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently + // only used for small inputs. + + // TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming + // nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point + // values that have been de-duplicated with conversion to scalars in kernels. 
+ uint32_t c_maxConstNodeDataSize = 1024 * 1024; + + ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); + + if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize) + { + auto data = static_cast(constantInput->GetData()); + std::vector tensorData(data, data + constantInput->GetTensorByteSize()); + + NodeInfo nodeInfo = {}; + nodeInfo.nodeDef = std::move(tensorData); + graphNodes.push_back(std::move(nodeInfo)); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; + edge.FromNodeIndex = static_cast(graphNodes.size() - 1); + edge.FromNodeOutputIndex = 0; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphIntermediateEdges.push_back(edge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC edge = {}; + edge.GraphInputIndex = dmlFusedNodeInputIndex; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphInputEdges.push_back(edge); + + auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + } + } + else + { + DML_INPUT_GRAPH_EDGE_DESC edge = {}; + edge.GraphInputIndex = dmlFusedNodeInputIndex; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphInputEdges.push_back(edge); } } else @@ -387,17 +436,28 @@ namespace Dml::GraphDescBuilder if (isNodeAsOpDesc) { - for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc) + for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i) { + auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i]; + DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator); + + // TODO: Change as new header is ingested + if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) + dmlDesc.Type = (DML_OPERATOR_TYPE) 169; + + // TODO: Change as new header is ingested + if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) + dmlDesc.Type = (DML_OPERATOR_TYPE) 170; + ComPtr op; ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); allocator.Reset(); NodeInfo nodeInfo = {}; - nodeInfo.op = std::move(op); + nodeInfo.nodeDef = std::move(op); nodeInfo.name = node.Name(); - graphNodes.push_back(std::move(nodeInfo)); + graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo); } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 0039678c00e59..c95e89b45541b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -4,6 +4,7 @@ #pragma once #include "MLOperatorAuthorImpl.h" +#include "ExecutionProvider.h" namespace Dml { @@ -27,7 +28,7 @@ namespace Dml struct NodeInfo { - Microsoft::WRL::ComPtr op; + std::variant, std::vector> nodeDef; std::string name; }; @@ -47,7 +48,7 @@ namespace Dml const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle, + const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span 
subgraphNodes, gsl::span subgraphInputs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index a8a6d6745e908..17fd7c18ba4a1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -76,6 +76,7 @@ namespace Dml STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0; STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept = 0; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 4deec620fe5fb..dbd06abf82f72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1153,6 +1153,33 @@ namespace Windows::AI::MachineLearning::Adapter ORT_CATCH_RETURN } + template + HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::TryGetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept + { + ORT_TRY + { + auto constantInput = m_constantInputGetter(inputIndex); + ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative>(constantInput)); + + auto tensorWrapper = std::get>(constantInput); + if (tensorWrapper == nullptr) + { + bool inputRequiredAsConstant = std::find( + m_requiredConstantCpuInputs.begin(), + m_requiredConstantCpuInputs.end(), + inputIndex) != m_requiredConstantCpuInputs.end(); + + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. 
+ ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); + } + + *tensor = tensorWrapper.Detach(); + + return S_OK; + } + ORT_CATCH_RETURN + } + HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetOutputTensorShape(uint32_t outputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept { ORT_TRY diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 913997ff4ad49..6530d89d895e7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -204,6 +204,11 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable _Outptr_ IMLOperatorTensor** tensor ) const noexcept; + HRESULT STDMETHODCALLTYPE TryGetConstantInputTensor( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept; + protected: // Lifetime is managed by the caller and guaranteed to outlive this class const onnxruntime::OpNodeProtoHelper* m_impl = nullptr; @@ -299,6 +304,8 @@ class OnnxTensorWrapper : public WRL::Base, public Closable const onnxruntime::Tensor* GetInterface() const { return nullptr; } onnxruntime::Tensor* GetInterface() { return nullptr; } + size_t GetTensorByteSize() const { return m_tensorByteSize; } + private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 43d34657098ef..f0a16da3a3c06 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -479,17 +479,37 @@ class DmlOperatorElementwisePow : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1); + if (constExpTensor && constExpTensor->GetTotalElementCount() == 1) + { + std::vector> kernelInputIndices = {0}; - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); - DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; - opDesc.InputTensor = &inputDescs[0]; - opDesc.ExponentTensor = &inputDescs[1]; - opDesc.OutputTensor = &outputDescs[0]; + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.OutputTensor = &outputDescs[0]; + opDesc.Exponent = static_cast(ReadScalarTensorCastToFloat64(*constExpTensor)); - SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo); + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo); + } + else + { + Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; + 
opDesc.InputTensor = &inputDescs[0]; + opDesc.ExponentTensor = &inputDescs[1]; + opDesc.OutputTensor = &outputDescs[0]; + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo); + } } }; @@ -565,7 +585,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator opDesc.ScaleTensor = &inputDescs[1]; opDesc.ZeroPointTensor = &inputDescs[2]; opDesc.OutputTensor = &outputDescs[0]; - + SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index f94270cfadb8b..59a1719d08ee6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -6,6 +6,7 @@ #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" #include "MLOperatorAuthorPrivate.h" #include "core/common/gsl.h" +#include #ifdef ORT_NO_EXCEPTIONS #define ML_CHECK_BOOL(x) ORT_THROW_HR_IF(E_INVALIDARG, !(x)) @@ -604,6 +605,18 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return MLOperatorTensor(tensor.Get()); } + std::optional TryGetConstantInputTensor(uint32_t inputIndex) const + { + Microsoft::WRL::ComPtr tensor; + ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor)); + if (tensor) + { + return MLOperatorTensor(tensor.Get()); + } + + return std::nullopt; + } + uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const { auto shapeDesc = GetTensorShapeDescription(); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index d1a705e151ddf..3bec8d3864cba 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -41,6 +41,11 @@ IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContex _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; + STDMETHOD(TryGetConstantInputTensor)( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept PURE; + //! Gets the number of dimensions of a tensor output of the operator. 
STDMETHOD(GetSequenceInputInfo)( uint32_t inputIndex, @@ -73,6 +78,11 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; + STDMETHOD(TryGetConstantInputTensor)( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept PURE; + STDMETHOD_(bool, IsDmlGraphNode)() const noexcept PURE; STDMETHOD(SetDmlOperator)( From 531e875fb550f7a866cd13639fb2e332440e6b96 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:43:47 -0800 Subject: [PATCH 07/45] Avoid command list reset in common case of re-used command list execution (#18370) ### Description ### Motivation and Context Co-authored-by: Jeff Bloomfield --- .../src/DmlCommandRecorder.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp index 530c26d212083..98345f37b68d4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp @@ -251,6 +251,24 @@ void DmlCommandRecorder::ExecuteCommandList( _Out_ uint64_t* completionValue ) { + if (!m_operationsRecordedInCurrentCommandList) + { + // The caller can re-use relevant resources after the next set of work to be + // flushed has completed. Its command list hasn't been executed yet, just batched. + GpuEvent gpuEvent = m_queue->GetNextCompletionEvent(); + gpuEvent.fence.CopyTo(fence); + *completionValue = gpuEvent.fenceValue; + + m_queue->ExecuteCommandLists( + gsl::span(reinterpret_cast(&commandList), 1)); + + // Fail early if something horrifying happens + ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason()); + ORT_THROW_IF_FAILED(m_d3dDevice->GetDeviceRemovedReason()); + + return; + } + ORT_THROW_IF_FAILED(m_currentCommandList->Close()); if (m_operationsRecordedInCurrentCommandList) From a1000a0a3c2552b78045bb5452aea09ae3490e6a Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:45:16 -0800 Subject: [PATCH 08/45] Enable GEMM activation fusions on MCDM (#18372) ### Description ### Motivation and Context Co-authored-by: Jeff Bloomfield --- .../src/Operators/OperatorUtility.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index 2965fa32ce131..fb86648d595b3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -154,13 +154,13 @@ namespace Dml OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MeanVarianceNormalization }, OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MeanVarianceNormalization }, OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MeanVarianceNormalization }, - OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Gemm }, - OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_Gemm }, - OperatorInfo{ "Gemm", 
onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Gemm }, - OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Gemm }, - OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MatMul }, - OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MatMul }, - OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Gemm, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_Gemm, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Gemm, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Gemm, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MatMul, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MatMul, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul, {}, true, {"Relu", "LeakyRelu"} }, // The filter for activation functions maps to what DML's fused op internally fuses at the shader level. OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, From 5c283340c384d5f9e402ba762c3c33af14e6763e Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 18:03:55 -0800 Subject: [PATCH 09/45] Filter activation fusions on MCDM (#18371) ### Description ### Motivation and Context --------- Co-authored-by: Jeff Bloomfield --- .../src/GraphTransformer.cpp | 15 ++++++-- .../src/GraphTransformer.h | 12 +++--- .../src/Operators/OperatorUtility.cpp | 37 ++++++++++++++----- .../src/Operators/OperatorUtility.h | 3 +- onnxruntime/core/session/inference_session.cc | 3 +- 5 files changed, 50 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp index 09922310b56c1..2e04da843696e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp @@ -17,6 +17,14 @@ namespace Dml { + GraphTransformer::GraphTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) : onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + onnxruntime::common::Status GraphTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -27,7 +35,7 @@ namespace Dml // Perform fusion { bool transformModifiedGraph = false; - PerformOperatorFusion(&graph, &transformModifiedGraph); + PerformOperatorFusion(&graph, m_providerImpl->IsMcdmDevice(), &transformModifiedGraph); modified |= transformModifiedGraph; if (modified) @@ -50,7 +58,7 @@ namespace Dml return ss.str(); } - void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const + void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const { struct NodeToAdd { @@ -112,7 +120,8 @@ namespace Dml gsl::narrow_cast(node.InputDefs().size()), outputNode.OpType(), outputNode.Domain(), - 
outputNode.Op()->SinceVersion()); + outputNode.Op()->SinceVersion(), + isMcdmDevice); if (!fusedOpProperties) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h index a7f8186fb3b64..337c0df7ff521 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h @@ -10,6 +10,7 @@ namespace Dml { + class ExecutionProviderImpl; // Applies transforms to a Lotus graph. The graph transformer is responsible for setting the execution provider // on the graph nodes which DML supports. @@ -17,16 +18,17 @@ namespace Dml { public: GraphTransformer( - const std::string& name - ) : onnxruntime::GraphTransformer(name) - { - } + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); private: onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const final; private: - void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const; + void PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const; + + const ExecutionProviderImpl* m_providerImpl = nullptr; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index fb86648d595b3..46442fe942539 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -132,6 +132,8 @@ namespace Dml std::string_view domain; int sinceVersion; std::vector activationFilter; + bool enableOnMcdm; + std::vector extraMcdmActivationFilter; std::optional inputCountFilter; }; @@ -142,10 +144,10 @@ namespace Dml static const OperatorInfo c_fusableOps[] = { - OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv }, - OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv }, - OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose }, - OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose }, + OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_BatchNormalization }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_BatchNormalization }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_BatchNormalization }, @@ -163,11 +165,11 @@ namespace Dml OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul, {}, true, {"Relu", "LeakyRelu"} }, // The filter for activation functions maps to what DML's fused op internally fuses at the 
shader level. - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 }, - OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 }, + OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 }, }; // Not all activations can be fused - only simple elementwise activations (i.e. activation functions which @@ -205,7 +207,8 @@ namespace Dml int candidateOpInputCount, std::string_view activationOpType, std::string_view activationOpDomain, - int activationOpSinceVersion) + int activationOpSinceVersion, + bool isMcdmDevice) { auto opIt = std::find( std::begin(c_fusableOps), @@ -233,6 +236,20 @@ namespace Dml return std::nullopt; } + if (isMcdmDevice) + { + if (!opIt->enableOnMcdm) + { + return std::nullopt; + } + + if (!opIt->extraMcdmActivationFilter.empty() && + std::find(opIt->extraMcdmActivationFilter.begin(), opIt->extraMcdmActivationFilter.end(), activationOpType) == opIt->extraMcdmActivationFilter.end()) + { + return std::nullopt; + } + } + if (opIt->inputCountFilter && *opIt->inputCountFilter != static_cast(candidateOpInputCount)) { return std::nullopt; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h index 8b2da6084242d..d3483cb5e8de2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h @@ -40,7 +40,8 @@ namespace Dml int candidateOpInputCount, std::string_view activationOpType, std::string_view activationOpDomain, - int activationOpSinceVersion); + int activationOpSinceVersion, + bool isMcdmDevice); // Returns true if the given activation operator type supports being fused with a fusable operator, false // otherwise. 
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 575529a06fb7a..cef160489ac46 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1696,7 +1696,8 @@ common::Status InferenceSession::Initialize() { // This transformer applies DML-specific fusions that go beyond what ORT offers by default bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2; if (dml_operator_fusion_enabled) { - std::unique_ptr dmlOperatorFusionTransformer = std::make_unique("DmlOperatorFusionTransformer"); + std::unique_ptr dmlOperatorFusionTransformer = std::make_unique("DmlOperatorFusionTransformer", + execution_providers_.Get(kDmlExecutionProvider)); if (dmlOperatorFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr"); } From 613fdce12e528d6921d0c2c7674e0fbbe2c18e6a Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 20 Nov 2023 10:16:47 -0800 Subject: [PATCH 10/45] Create ring buffer for re-used command lists (#18368) ### Description ### Motivation and Context Co-authored-by: Jeff Bloomfield --- .../src/FusedGraphKernel.cpp | 118 ++++++++++-------- 1 file changed, 68 insertions(+), 50 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index 67c3f110e5a50..430ccec3cda10 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -13,6 +13,26 @@ namespace Dml { class FusedGraphKernel : public onnxruntime::OpKernel { + private: + struct ReusedCommandListState + { + // Re-usable command list, supporting descriptor heap, and DML binding table to update that heap. + ComPtr graphicsCommandList; + ComPtr commandAllocator; + ComPtr heap; + ComPtr bindingTable; + + // Bindings from previous executions of a re-used command list + mutable std::vector inputBindingAllocIds; + mutable std::vector outputBindingAllocIds; + mutable uint64_t tempBindingAllocId = 0; + + // Fence tracking the status of the command list's last execution, and whether its descriptor heap + // can safely be updated. + mutable ComPtr fence; + mutable uint64_t completionValue = 0; + }; + public: FusedGraphKernel() = delete; @@ -89,7 +109,7 @@ namespace Dml if (reuseCommandList) { - BuildReusableCommandList(); + m_reusedCommandLists.push_back(BuildReusableCommandList()); } } @@ -97,8 +117,7 @@ namespace Dml { // Only re-use the cached command list if its prior execution is complete on the GPU. // This requirement can be avoided by mantaining ring buffers. 
- if (!m_graphicsCommandList || - (m_fence != nullptr && m_fence->GetCompletedValue() < m_completionValue)) + if (m_reusedCommandLists.empty()) { // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator OpKernelContextWrapper contextWrapper( @@ -147,7 +166,15 @@ namespace Dml } else { - ExecuteReusableCommandList(kernelContext); + if (m_reusedCommandLists.front()->fence && + m_reusedCommandLists.front()->fence->GetCompletedValue() < m_reusedCommandLists.front()->completionValue) + { + m_reusedCommandLists.push_front(BuildReusableCommandList()); + } + + ExecuteReusableCommandList(kernelContext, *m_reusedCommandLists.front()); + m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front())); + m_reusedCommandLists.pop_front(); } return onnxruntime::Status::OK(); @@ -217,8 +244,10 @@ namespace Dml } private: - void BuildReusableCommandList() + std::unique_ptr BuildReusableCommandList() const { + auto commandListState = std::make_unique(); + ComPtr device; ORT_THROW_IF_FAILED(m_provider->GetDmlDevice(device.GetAddressOf())); @@ -232,47 +261,49 @@ namespace Dml ComPtr d3dDevice; ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf())); - ORT_THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(m_heap.ReleaseAndGetAddressOf()))); + ORT_THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(commandListState->heap.ReleaseAndGetAddressOf()))); // Create a binding table for execution. DML_BINDING_TABLE_DESC bindingTableDesc = {}; bindingTableDesc.Dispatchable = m_compiledExecutionPlanOperator.Get(); - bindingTableDesc.CPUDescriptorHandle = m_heap->GetCPUDescriptorHandleForHeapStart(); - bindingTableDesc.GPUDescriptorHandle = m_heap->GetGPUDescriptorHandleForHeapStart(); + bindingTableDesc.CPUDescriptorHandle = commandListState->heap->GetCPUDescriptorHandleForHeapStart(); + bindingTableDesc.GPUDescriptorHandle = commandListState->heap->GetGPUDescriptorHandleForHeapStart(); bindingTableDesc.SizeInDescriptors = execBindingProps.RequiredDescriptorCount; - ORT_THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&m_bindingTable))); + ORT_THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&commandListState->bindingTable))); ORT_THROW_IF_FAILED(d3dDevice->CreateCommandAllocator( m_provider->GetCommandListTypeForQueue(), - IID_GRAPHICS_PPV_ARGS(m_commandAllocator.ReleaseAndGetAddressOf()))); + IID_GRAPHICS_PPV_ARGS(commandListState->commandAllocator.ReleaseAndGetAddressOf()))); ORT_THROW_IF_FAILED(d3dDevice->CreateCommandList( 0, m_provider->GetCommandListTypeForQueue(), - m_commandAllocator.Get(), + commandListState->commandAllocator.Get(), nullptr, - IID_GRAPHICS_PPV_ARGS(m_graphicsCommandList.ReleaseAndGetAddressOf()))); + IID_GRAPHICS_PPV_ARGS(commandListState->graphicsCommandList.ReleaseAndGetAddressOf()))); if (m_persistentResource) { DML_BINDING_DESC persistentResourceBindingDesc = { DML_BINDING_TYPE_BUFFER, m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr }; - m_bindingTable->BindPersistentResource(&persistentResourceBindingDesc); + commandListState->bindingTable->BindPersistentResource(&persistentResourceBindingDesc); } - ID3D12DescriptorHeap* descriptorHeaps[] = { m_heap.Get() }; - m_graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps); + ID3D12DescriptorHeap* descriptorHeaps[] = { commandListState->heap.Get() }; + commandListState->graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps); ComPtr recorder; ORT_THROW_IF_FAILED(device->CreateCommandRecorder(IID_PPV_ARGS(recorder.GetAddressOf()))); - recorder->RecordDispatch(m_graphicsCommandList.Get(), m_compiledExecutionPlanOperator.Get(), m_bindingTable.Get()); + recorder->RecordDispatch(commandListState->graphicsCommandList.Get(), m_compiledExecutionPlanOperator.Get(), commandListState->bindingTable.Get()); + + ORT_THROW_IF_FAILED(commandListState->graphicsCommandList->Close()); - ORT_THROW_IF_FAILED(m_graphicsCommandList->Close()); + return commandListState; } - void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext) const + void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext, ReusedCommandListState& commandListState) const { DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties(); @@ -287,7 +318,7 @@ namespace Dml // Populate input bindings, excluding those which were specified as owned by DML and provided // at initialization instead. - m_inputBindingAllocIds.resize(inputBindings.size()); + commandListState.inputBindingAllocIds.resize(inputBindings.size()); bool inputBindingsChanged = false; for (uint32_t i = 0; i < inputBindings.size(); ++i) @@ -307,25 +338,25 @@ namespace Dml uint64_t allocId; DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId); - inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId); + inputBindingsChanged = inputBindingsChanged || (!allocId || commandListState.inputBindingAllocIds[i] != allocId); inputBindings[i].Buffer->Release(); // Avoid holding an additional reference inputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]}; - m_inputBindingAllocIds[i] = allocId; + commandListState.inputBindingAllocIds[i] = allocId; } } } if (inputBindingsChanged) { - m_bindingTable->BindInputs(gsl::narrow_cast(inputBindingDescs.size()), inputBindingDescs.data()); + commandListState.bindingTable->BindInputs(gsl::narrow_cast(inputBindingDescs.size()), inputBindingDescs.data()); } // Populate Output bindings std::vector outputBindings(kernelContext->OutputCount()); std::vector outputBindingDescs(kernelContext->OutputCount()); - m_outputBindingAllocIds.resize(outputBindings.size()); + commandListState.outputBindingAllocIds.resize(outputBindings.size()); bool outputBindingsChanged = false; for (uint32_t i = 0; i < outputBindings.size(); ++i) @@ -344,16 +375,16 @@ namespace Dml uint64_t allocId; DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &outputBindings[i].Buffer, &allocId); - outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId); + outputBindingsChanged = outputBindingsChanged || (!allocId || commandListState.outputBindingAllocIds[i] != allocId); outputBindings[i].Buffer->Release(); // Avoid holding an additional reference 
outputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]}; - m_outputBindingAllocIds[i] = allocId; + commandListState.outputBindingAllocIds[i] = allocId; } if (outputBindingsChanged) { - m_bindingTable->BindOutputs(gsl::narrow_cast(outputBindingDescs.size()), outputBindingDescs.data()); + commandListState.bindingTable->BindOutputs(gsl::narrow_cast(outputBindingDescs.size()), outputBindingDescs.data()); } if (execBindingProps.TemporaryResourceSize > 0) @@ -373,19 +404,19 @@ namespace Dml DML_BUFFER_BINDING tempBufferBinding = {tempResource.Get(), 0, execBindingProps.TemporaryResourceSize}; DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding }; - if (!tempAllocId || m_tempBindingAllocId != tempAllocId) + if (!tempAllocId || commandListState.tempBindingAllocId != tempAllocId) { - m_bindingTable->BindTemporaryResource(&tempBindingDesc); + commandListState.bindingTable->BindTemporaryResource(&tempBindingDesc); } - m_tempBindingAllocId = tempAllocId; + commandListState.tempBindingAllocId = tempAllocId; } // Execute the command list and if it succeeds, update the fence value at which this command may be // re-used. ComPtr fence; uint64_t completionValue; - HRESULT hr = m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue); + HRESULT hr = m_provider->ExecuteCommandList(commandListState.graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue); if (hr == DXGI_ERROR_DEVICE_REMOVED) { @@ -395,13 +426,13 @@ namespace Dml } ORT_THROW_IF_FAILED(hr); - m_fence = fence; - m_completionValue = completionValue; + commandListState.fence = fence; + commandListState.completionValue = completionValue; // Queue references to objects which must be kept alive until resulting GPU work completes - m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_graphicsCommandList).Get()); - m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_heap).Get()); - m_winmlProvider->QueueReference(m_bindingTable.Get()); + m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(commandListState.graphicsCommandList).Get()); + m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(commandListState.heap).Get()); + m_winmlProvider->QueueReference(commandListState.bindingTable.Get()); m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); } @@ -412,25 +443,12 @@ namespace Dml ComPtr m_provider; Windows::AI::MachineLearning::Adapter::EdgeShapes& m_outputShapes; - // Re-usable command list, supporting descriptor heap, and DML binding table to update that heap. - ComPtr m_graphicsCommandList; - ComPtr m_commandAllocator; - ComPtr m_heap; - ComPtr m_bindingTable; + mutable std::deque> m_reusedCommandLists; + std::optional m_persistentResourceBinding; ComPtr m_persistentResource; ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator - // Bindings from previous executions of a re-used command list - mutable std::vector m_inputBindingAllocIds; - mutable std::vector m_outputBindingAllocIds; - mutable uint64_t m_tempBindingAllocId = 0; - - // Fence tracking the status of the command list's last execution, and whether its descriptor heap - // can safely be updated. 
- mutable ComPtr m_fence; - mutable uint64_t m_completionValue = 0; - std::vector m_isInputsUploadedByDmlEP; std::vector> m_nonOwnedGraphInputsFromInitializers; }; From 7f9e6c42c2e7800c8a9d38cd12e0eac881c81dd5 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:12:47 -0800 Subject: [PATCH 11/45] readd npu enumeration (#18437) (#18518) [Cherry pick Reviewed] Re-add changes which were merged out... --------- ### Description ### Motivation and Context Co-authored-by: Sheil Kumar Co-authored-by: Sheil Kumar --- .../providers/dml/dml_provider_factory.cc | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 33f1f59e07f3f..cd4eb20c856c0 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -475,12 +475,39 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID return dml_device; } +static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_device) { + D3D12_FEATURE_DATA_FEATURE_LEVELS feature_levels = {}; + + D3D_FEATURE_LEVEL feature_levels_list[] = { + D3D_FEATURE_LEVEL_1_0_CORE, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_12_0, + D3D_FEATURE_LEVEL_12_1 + }; + + feature_levels.NumFeatureLevels = ARRAYSIZE(feature_levels_list); + feature_levels.pFeatureLevelsRequested = feature_levels_list; + ORT_THROW_IF_FAILED(d3d12_device->CheckFeatureSupport( + D3D12_FEATURE_FEATURE_LEVELS, + &feature_levels, + sizeof(feature_levels) + )); + + auto is_feature_level_1_0_core = (feature_levels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE); + if (is_feature_level_1_0_core) { + return D3D12_COMMAND_LIST_TYPE_COMPUTE; + } + + return D3D12_COMMAND_LIST_TYPE_DIRECT; +} + std::shared_ptr CreateDMLDeviceAndProviderFactory( - ID3D12Device* d3d12_device, - bool disable_metacommands, - bool enable_dynamic_graph_fusion) { + ID3D12Device* d3d12_device, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; - cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + cmd_queue_desc.Type = CalculateCommandListType(d3d12_device); cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; ComPtr cmd_queue; @@ -500,16 +527,20 @@ std::shared_ptr DMLProviderFactoryCreator::Create( } std::shared_ptr DMLProviderFactoryCreator::CreateFromAdapterList( - std::vector>&& dxcore_devices, + std::vector>&& adapters, bool disable_metacommands, bool enable_dynamic_graph_fusion) { // Choose the first device from the list since it's the highest priority - auto dxcore_device = dxcore_devices[0]; + auto adapter = adapters[0]; + + auto feature_level = D3D_FEATURE_LEVEL_11_0; + if (IsNPU(adapter.Get())) { + feature_level = D3D_FEATURE_LEVEL_1_0_CORE; + } // Create D3D12 Device from DXCore Adapter ComPtr d3d12_device; - ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); - + ORT_THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), feature_level, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion); } From e8209ce2b0f966739dfbb6ed4486ce78ca3f2e39 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 22 Nov 2023 14:39:36 
-0800 Subject: [PATCH 12/45] CP 7fd1ce95a4e4f3c2b6152dfc2b1807a983ef45e5 (#18560) CP 7fd1ce95a4e4f3c2b6152dfc2b1807a983ef45e5 for onnxruntime_perf_test changes. Co-authored-by: Sheil Kumar --- .../DmlExecutionProvider/src/CommandQueue.cpp | 15 ++- .../src/ExecutionContext.cpp | 18 ++-- .../src/ExecutionContext.h | 10 +- .../src/ExecutionProvider.cpp | 2 +- .../src/ExecutionProvider.h | 2 +- .../providers/dml/dml_provider_factory.cc | 5 +- .../test/perftest/command_args_parser.cc | 4 + onnxruntime/test/perftest/ort_test_session.cc | 99 +++++++++++++++---- onnxruntime/test/perftest/ort_test_session.h | 1 + onnxruntime/test/util/default_providers.cc | 3 +- 10 files changed, 113 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp index 5516fc62cdda0..2b4f3bb96537b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp @@ -46,12 +46,12 @@ namespace Dml return GpuEvent{ m_lastFenceValue + 1, m_fence }; } - void CommandQueue::QueueReference(IUnknown* object, bool waitForUnsubmittedWork) + void CommandQueue::QueueReference(IUnknown* object, bool waitForUnsubmittedWork) { - // If the CommandQueue is closing, then m_queuedReferences is being cleared -- it is not OK - // to queue additional references at this time, since those references would be leaked. This - // affects any objects in m_queuedReferences whose destructors indirectly call QueueReference; - // for example, an allocation from BucketizedBufferAllocator attempts to queue a reference + // If the CommandQueue is closing, then m_queuedReferences is being cleared -- it is not OK + // to queue additional references at this time, since those references would be leaked. This + // affects any objects in m_queuedReferences whose destructors indirectly call QueueReference; + // for example, an allocation from BucketizedBufferAllocator attempts to queue a reference // to its underlying D3D resource when freed. Furthermore, these references are unnecessary // since Close() already blocks for scheduled GPU work before clearing m_queuedReferences. 
if (!m_closing) @@ -68,7 +68,7 @@ namespace Dml m_queuedReferences.push_back(queuedReference); } } - + void CommandQueue::Close() { // Wait for flushed work: @@ -79,7 +79,7 @@ namespace Dml m_queuedReferences.clear(); m_closing = false; } - + void CommandQueue::ReleaseCompletedReferences() { uint64_t completedValue = GetFence()->GetCompletedValue(); @@ -89,5 +89,4 @@ namespace Dml } } - } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp index a894d0660d6ff..bc82c7ab1c44a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp @@ -15,7 +15,7 @@ namespace Dml : m_queue(std::make_shared(queue)) , m_dmlRecorder(d3d12Device, dmlDevice, m_queue) { - ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf()))); + ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf()))); } void ExecutionContext::SetAllocator(std::weak_ptr allocator) @@ -78,14 +78,14 @@ namespace Dml ID3D12GraphicsCommandList* commandList, _Outptr_ ID3D12Fence** fence, _Out_ uint64_t* completionValue - ) + ) { assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue); } - + void ExecutionContext::InitializeOperator( IDMLCompiledOperator* op, const DML_BINDING_DESC& persistentResourceBinding, @@ -110,7 +110,7 @@ namespace Dml } void ExecutionContext::AddUAVBarrier() - { + { assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); @@ -173,9 +173,9 @@ namespace Dml m_currentRecorder = nullptr; SetCommandRecorder(&m_dmlRecorder); } - - void ExecutionContext::QueueReference(IUnknown* object) - { + + void ExecutionContext::QueueReference(IUnknown* object) + { assert(!m_closed); // If something has been recorded into a command list but not submitted yet, it means that the *next* fence // value is the one to signal completion. @@ -186,14 +186,14 @@ namespace Dml void ExecutionContext::Close() { assert(!m_closed); - + // Discard unflushed work and clear queued references. This prevents the circular reference: // Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel m_queue->Close(); m_currentRecorder = nullptr; m_closed = true; } - + GpuEvent ExecutionContext::GetCurrentCompletionEvent() { assert(!m_closed); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h index b06f11a5efd0a..ac8d3ff875786 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h @@ -20,13 +20,13 @@ namespace Dml public: // Constructs an ExecutionContext that executes on the supplied queue. ExecutionContext( - ID3D12Device* d3d12Device, - IDMLDevice* dmlDevice, + ID3D12Device* d3d12Device, + IDMLDevice* dmlDevice, ID3D12CommandQueue* queue); void SetAllocator(std::weak_ptr allocator); - // Waits for flushed work, discards unflushed work, and discards associated references to + // Waits for flushed work, discards unflushed work, and discards associated references to // prevent circular references. Must be the last call on the object before destruction. 
void Close(); @@ -75,12 +75,12 @@ namespace Dml // Returns an event which will become signaled when everything submitted to the execution context thus far has // completed execution on the GPU, including work that has yet to be flushed to the queue. GpuEvent GetCurrentCompletionEvent(); - + // Adds a reference which will be released when queued GPU work is completed void QueueReference(IUnknown* object); // Release any accumulated references who corresponding GPU fence values have - // been reached. + // been reached. void ReleaseCompletedReferences(); D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 49a64c4810252..8a32d06534dda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -205,7 +205,7 @@ namespace Dml D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {}; // The call may fail in which case the default value is false - d3d12Device->CheckFeatureSupport(static_cast(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19)); + d3d12Device->CheckFeatureSupport(static_cast(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19)); m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index ab932fb8a4367..5617bc7bdcac6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -150,7 +150,7 @@ namespace Dml } STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; - STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; bool DynamicGraphFusionEnabled() const noexcept; diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index cd4eb20c856c0..73a068f3e1de2 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -118,7 +118,6 @@ static bool IsGPU(IDXCoreAdapter* compute_adapter) { return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); } -#ifdef ENABLE_NPU_ADAPTER_ENUMERATION static bool IsNPU(IDXCoreAdapter* compute_adapter) { // Only considering hardware adapters if (!IsHardwareAdapter(compute_adapter)) { @@ -126,7 +125,6 @@ static bool IsNPU(IDXCoreAdapter* compute_adapter) { } return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)); } -#endif enum class DeviceType { GPU, NPU, BadDevice }; @@ -327,7 +325,8 @@ static std::optional ParsePerformancePreference(con } static std::optional ParseFilter(const ProviderOptions& provider_options) { - static const std::string Filter = "filter"; + static const std::string Filter = "device_filter"; + static const std::string Any = "any"; static const std::string Gpu = "gpu"; #ifdef ENABLE_NPU_ADAPTER_ENUMERATION static const std::string Any = "any"; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 
6e3252aaeb4b8..f8d6296d2d785 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -58,6 +58,10 @@ namespace perftest { "\t-q [CUDA only] use separate stream for copy. \n" "\t-z: Set denormal as zero. When turning on this option reduces latency dramatically, a model may have denormals.\n" "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" + "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" + "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" + "\t [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 04c9ae1f23108..ac25c98b15758 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -16,6 +16,10 @@ #include "providers.h" #include "TestCase.h" +#ifdef USE_DML +#include "core/providers/dml/dml_provider_factory.h" +#endif + #ifdef _WIN32 #define strdup _strdup #endif @@ -42,8 +46,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device const TestModelInfo& m) : rand_engine_(rd()), input_names_(m.GetInputCount()), input_names_str_(m.GetInputCount()), input_length_(m.GetInputCount()) { Ort::SessionOptions session_options; - const std::string& provider_name = performance_test_config.machine_config.provider_type_name; - if (provider_name == onnxruntime::kDnnlExecutionProvider) { + provider_name_ = performance_test_config.machine_config.provider_type_name; + if (provider_name_ == onnxruntime::kDnnlExecutionProvider) { #ifdef USE_DNNL // Generate provider options OrtDnnlProviderOptions dnnl_options; @@ -96,7 +100,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("DNNL is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kCudaExecutionProvider) { + } else if (provider_name_ == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA const auto& api = Ort::GetApi(); OrtCUDAProviderOptionsV2* cuda_options; @@ -161,7 +165,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("CUDA is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kTensorrtExecutionProvider) { + } else if (provider_name_ == onnxruntime::kTensorrtExecutionProvider) { #ifdef USE_TENSORRT const auto& api = Ort::GetApi(); OrtTensorRTProviderOptionsV2* tensorrt_options; @@ -215,7 +219,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("TensorRT is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kOpenVINOExecutionProvider) { + } else if (provider_name_ == onnxruntime::kOpenVINOExecutionProvider) { #ifdef USE_OPENVINO #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ 
-251,7 +255,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ov_options[key] = value; } else { ORT_THROW( - "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " + "[ERROR] [OpenVINO] You have selected a wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); @@ -305,7 +309,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("OpenVINO is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kQnnExecutionProvider) { + } else if (provider_name_ == onnxruntime::kQnnExecutionProvider) { #ifdef USE_QNN #ifdef _MSC_VER std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -378,7 +382,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else ORT_THROW("QNN is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kSnpeExecutionProvider) { + } else if (provider_name_ == onnxruntime::kSnpeExecutionProvider) { #ifdef USE_SNPE #ifdef _MSC_VER std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -430,7 +434,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("SNPE is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kNnapiExecutionProvider) { + } else if (provider_name_ == onnxruntime::kNnapiExecutionProvider) { #ifdef USE_NNAPI uint32_t nnapi_flags = 0; #ifdef _MSC_VER @@ -458,22 +462,81 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("NNAPI is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kCoreMLExecutionProvider) { + } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) { #ifdef USE_COREML Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0)); #else ORT_THROW("COREML is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kDmlExecutionProvider) { + } else if (provider_name_ == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML std::unordered_map dml_options; dml_options["performance_preference"] = "high_performance"; dml_options["device_filter"] = "gpu"; + dml_options["disable_metacommands"] = "false"; + dml_options["enable_dynamic_graph_fusion"] = "false"; +#ifdef _MSC_VER + std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); +#else + std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; +#endif + std::istringstream ss(ov_string); + std::string token; + while (ss >> token) { + if (token == "") { + continue; + } + auto pos = token.find("|"); + if (pos == std::string::npos || pos == 0 || pos == token.length()) { + ORT_THROW("[ERROR] [DML] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + } + + auto key = token.substr(0, pos); + auto value = token.substr(pos + 1); + + if (key == "device_filter") { + std::set ov_supported_device_types = {"gpu", "npu"}; + if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { + dml_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [DML] You have selected a wrong configuration value for the key 'device_filter'. 
" + "Select from 'gpu', or 'npu' \n"); + } + } else if (key == "performance_preference") { + std::set ov_supported_values = {"default", "high_performance", "minimal_power"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + dml_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. " + "Select from 'default', 'high_performance' or 'minimal_power' \n"); + } + } else if (key == "disable_metacommands") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + dml_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [DML] You have selcted wrong value for the key 'disable_metacommands'. " + "Select from 'true' or 'false' \n"); + } + } else if (key == "enable_dynamic_graph_fusion") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + dml_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [DML] You have selcted wrong value for the key 'enable_dynamic_graph_fusion'. " + "Select from 'true' or 'false' \n"); + } + } + } session_options.AppendExecutionProvider("DML", dml_options); #else ORT_THROW("DML is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kAclExecutionProvider) { + } else if (provider_name_ == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL Ort::ThrowOnError( OrtSessionOptionsAppendExecutionProvider_ACL(session_options, @@ -481,14 +544,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("Acl is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kArmNNExecutionProvider) { + } else if (provider_name_ == onnxruntime::kArmNNExecutionProvider) { #ifdef USE_ARMNN Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_options, performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0)); #else ORT_THROW("ArmNN is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kRocmExecutionProvider) { + } else if (provider_name_ == onnxruntime::kRocmExecutionProvider) { #ifdef USE_ROCM OrtROCMProviderOptions rocm_options; rocm_options.miopen_conv_exhaustive_search = performance_test_config.run_config.cudnn_conv_algo; @@ -498,7 +561,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("ROCM is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kMIGraphXExecutionProvider) { + } else if (provider_name_ == onnxruntime::kMIGraphXExecutionProvider) { #ifdef USE_MIGRAPHX Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(session_options, 0)); OrtROCMProviderOptions rocm_options; @@ -508,7 +571,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("MIGraphX is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kXnnpackExecutionProvider) { + } else if (provider_name_ == onnxruntime::kXnnpackExecutionProvider) { #ifdef USE_XNNPACK session_options.AddConfigEntry(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0"); session_options.AppendExecutionProvider( @@ -516,7 +579,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); #else ORT_THROW("Xnnpack is not supported in this build\n"); #endif - } else if (provider_name == onnxruntime::kVitisAIExecutionProvider) { + } else if (provider_name_ == onnxruntime::kVitisAIExecutionProvider) { #ifdef USE_VITISAI #ifdef _MSC_VER std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -544,7 +607,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("VitisAI is not supported in this build\n"); #endif - } else if (!provider_name.empty() && provider_name != onnxruntime::kCpuExecutionProvider) { + } else if (!provider_name_.empty() && provider_name_ != onnxruntime::kCpuExecutionProvider) { ORT_THROW("This backend is not included in perf test runner.\n"); } diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index 208e3de53b1d2..f1a4220ab325e 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -45,6 +45,7 @@ class OnnxRuntimeTestSession : public TestSession { std::vector input_names_; std::vector input_names_str_; const int input_length_; + std::string provider_name_; }; } // namespace perftest diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 4468a64d18258..a94f7b5b707c7 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -274,8 +274,9 @@ std::unique_ptr DefaultCannExecutionProvider() { std::unique_ptr DefaultDmlExecutionProvider() { #ifdef USE_DML - if (auto factory = DMLProviderFactoryCreator::Create(0, false, false, false)) + if (auto factory = DMLProviderFactoryCreator::CreateFromOptions(nullptr, false, false)) { return factory->CreateProvider(); + } #endif return nullptr; } From 623d957607b571d8dadbba6d57021546c38658a4 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Fri, 1 Dec 2023 10:27:20 -0800 Subject: [PATCH 13/45] register resize with uint8/int8 support (#18647) ### Description 1. Expand input datatype support for Resize with uint8/int8. 2. Update the logic that computes the output shape of the Resize op: the roiRange factor is removed to align with how the tests compute the output shape and to avoid the size assert in MLOperatorAuthorImpl.cpp, i.e. `m_inputDimensions[i] * roiRange * scale` -> `m_inputDimensions[i] * scale` 3. Disable 4 tests because of result mismatches. The DML results for float32 and uint8/int8 match each other, so this is likely a problem in the Resize implementation itself, which is out of the scope of this PR. 
`ResizeOpTest.NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_uint8 ResizeOpTest.NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8 ResizeOpTest.NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_uint8 ResizeOpTest.NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8` --- .../dml/OperatorAuthorHelper/OperatorHelper.cpp | 3 +-- .../test/providers/cpu/tensor/resize_op_test.cc | 13 +++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 1fcd3b04300f4..7c2320160dd3b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -2450,8 +2450,7 @@ namespace OperatorHelper { float scale = m_scales[i]; ML_CHECK_VALID_ARGUMENT(scale > FLT_EPSILON, "Scale values should be positive."); - float roiRange = m_regionOfInterest.empty() ? 1.0f : m_regionOfInterest[i + rank] - m_regionOfInterest[i]; - m_outputDimensions.push_back(gsl::narrow_cast(floor(m_inputDimensions[i] * roiRange * scale))); + m_outputDimensions.push_back(gsl::narrow_cast(floor(m_inputDimensions[i] * scale))); } } else diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 3ea7295aef5a2..f473c98ca713e 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -187,7 +187,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + // DML: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -214,7 +215,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e 0, 0, 0}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(); + // DML: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); } TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { @@ -530,7 +532,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + // DML: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -558,7 +561,9 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe std::vector Y = {0, 2, -9}; test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: results mismatch + 
// TensorRT: results mismatch + // DML: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) { From c1ec3c3f9369d6a01b727d4d1ef772c9fe6a5213 Mon Sep 17 00:00:00 2001 From: Christian Larson Date: Fri, 8 Dec 2023 17:08:29 -0800 Subject: [PATCH 14/45] User/chrila/fix dml dx12 warning (#18746) Update resource creation flag to avoid D3D12 WARNING ### Description Update the DML DX12 allocator to use D3D12_RESOURCE_STATE_COMMON to avoid DX12 warning messages. ### Motivation and Context When DirectML is created with the debug layer enabled, there are warnings when resources are created by ORT. --------- Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com> --- .../src/DmlCommittedResourceAllocator.cpp | 2 +- .../DmlExecutionProvider/src/DmlExternalBufferAllocator.h | 2 +- .../dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp index b696aefecf664..54393e9bf1539 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp @@ -16,7 +16,7 @@ namespace Dml unmove_ptr(CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT)), D3D12_HEAP_FLAG_NONE, &buffer, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON, nullptr, IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) )); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h index 9514a24b4e781..22fd3be42c416 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h @@ -39,7 +39,7 @@ namespace Dml &props, D3D12_HEAP_FLAG_NONE, &buffer, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON, nullptr, IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) )); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 18cdc5d1bf86e..642d9aa03eeef 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -36,7 +36,7 @@ namespace DmlGraphFusionHelper &heapProperties, D3D12_HEAP_FLAG_NONE, &resourceDesc, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON, nullptr, IID_GRAPHICS_PPV_ARGS(buffer.GetAddressOf()))); @@ -74,7 +74,7 @@ namespace DmlGraphFusionHelper &heapProperties, D3D12_HEAP_FLAG_NONE, &resourceDesc, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON, nullptr, IID_GRAPHICS_PPV_ARGS(buffer.GetAddressOf()))); @@ -302,7 +302,7 @@ namespace DmlGraphFusionHelper for (size_t i = 0; i < graphDesc.nodes.size(); ++i) { auto& nodeInfo = graphDesc.nodes[i]; - + if (std::holds_alternative>(nodeInfo.nodeDef)) { dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; From 
107d7492b9c0f82dc61974065c08c91550d0dea4 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 8 Dec 2023 19:38:44 -0800 Subject: [PATCH 15/45] [DirectML EP] Add DML EP registration for Col2Im (#17786) ### Description [DirectML EP] Add DML EP registration for Col2Im operator ### Motivation and Context Add Col2Im support for opset 18. This operator is implemented as the DirectML Fold operator. --------- Co-authored-by: Sheil Kumar Co-authored-by: Dwayne Robinson --- cmake/external/dml.cmake | 4 +- .../src/Operators/DmlOperatorCol2Im.cpp | 59 +++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 2 + .../OperatorAuthorHelper/OperatorHelper.cpp | 71 +++++++++++++++++-- .../dml/OperatorAuthorHelper/OperatorHelper.h | 29 ++++++++ .../OperatorAuthorHelper/OperatorVersions.h | 1 + 6 files changed, 158 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index d777306722cd6..dfd9ad120eb98 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -72,12 +72,11 @@ else() if (dml_EXTERNAL_PROJECT) set(dml_preset_config $,debug,release>) set(dml_preset_name ${onnxruntime_target_platform}-win-redist-${dml_preset_config}) - target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1) include(ExternalProject) ExternalProject_Add( directml_repo GIT_REPOSITORY https://dev.azure.com/microsoft/WindowsAI/_git/DirectML - GIT_TAG d460f0f46967bea878786f1bed69487692c779bf + GIT_TAG a5312f72c51864b4d705ac62d25d08bcd88c4fb1 GIT_SHALLOW OFF # not allowed when GIT_TAG is a commit SHA, which is preferred (it's stable, unlike branches) GIT_PROGRESS ON BUILD_IN_SOURCE ON @@ -94,6 +93,7 @@ else() target_link_libraries(DirectML INTERFACE ${directml_install_path}/lib/DirectML.lib) add_dependencies(DirectML directml_repo-install) include_directories(BEFORE ${directml_install_path}/include) + target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1) else() include_directories(BEFORE ${dml_INCLUDE_DIR}) set(DML_PACKAGE_DIR ${dml_INCLUDE_DIR}/..) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp new file mode 100644 index 0000000000000..13a51f2be7560 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "./precomp.h" + +namespace Dml +{ + +class DmlOperatorCol2Im : public DmlOperator, public Col2ImHelper +{ +public: + explicit DmlOperatorCol2Im(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext), + Col2ImHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription()) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "Col2Im expects 3 inputs."); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Col2Im expects 1 output."); + + auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription(); + std::vector inputTensorShape = tensorShapeDescription.GetInputTensorShape(0); + std::vector outputTensorShape = tensorShapeDescription.GetOutputTensorShape(0); + + ML_CHECK_VALID_ARGUMENT(outputTensorShape == m_outputShape); + + std::vector> inputIndices = { 0 }; + gsl::span inputShapes[1] = { m_inputShape }; + gsl::span outputShapes[1] = { m_outputShape }; + DmlOperator::InitializeWithShapes( + kernelCreationContext, + inputIndices, + std::nullopt, + inputShapes, + outputShapes, + 3 + ); + // Prepare DML_FOLD_OPERATOR_DESC + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + assert(inputDescs.size() == 1); + assert(outputDescs.size() == 1); + + DML_FOLD_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = inputDescs.data(); + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.DimensionCount = gsl::narrow_cast(m_blockShape.size()); + operatorDesc.WindowSizes = m_blockShape.data(); + operatorDesc.Dilations = m_dilations.data(); + operatorDesc.StartPadding = m_pads.data(); + operatorDesc.EndPadding = m_pads.data(); + operatorDesc.Strides = m_strides.data(); + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FOLD, &operatorDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(Col2Im, DmlOperatorCol2Im); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 0234bb6b7ec1e..2ab73afb8f1e1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -503,6 +503,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); DML_OP_EXTERN_CREATION_FUNCTION(Trilu); +DML_OP_EXTERN_CREATION_FUNCTION(Col2Im); DML_OP_EXTERN_CREATION_FUNCTION(Shape); DML_OP_EXTERN_CREATION_FUNCTION(Size); DML_OP_EXTERN_CREATION_FUNCTION(Attention); @@ -770,6 +771,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)}, {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)}, {REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 18, Col2Im, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2))}, // Data reorganization that merely changes the dimensions while keeping the data identical. 
{REG_INFO_COPY( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 7c2320160dd3b..83c6748fadd35 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -257,14 +257,15 @@ namespace OperatorHelper } } - void DowncastDimensions(gsl::span inputDimensions, std::vector& outputDimensions) + template + void DowncastDimensions(gsl::span inputDimensions, std::vector& outputDimensions) { outputDimensions.reserve(inputDimensions.size()); outputDimensions.clear(); - for (int64_t dim : inputDimensions) + for (T dim : inputDimensions) { - outputDimensions.push_back(gsl::narrow_cast(std::clamp(dim, INT32_MIN, INT32_MAX))); + outputDimensions.push_back(gsl::narrow_cast(std::clamp(dim, INT32_MIN, INT32_MAX))); } } @@ -1870,6 +1871,64 @@ namespace OperatorHelper return { std::move(outputShape) }; } + void Col2ImHelper::Initialize( + const IKernelInformationAdapter& kernelInformation, + const IShapeInformationAdapter& shapeInformation) + { + std::vector shapeData; + ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(1), /*out*/ shapeData); + m_imageShape.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_imageShape); + ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(2), /*out*/ shapeData); + m_blockShape.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_blockShape); + + const uint32_t dimCount = gsl::narrow_cast(m_blockShape.size()); + m_dilations = {dimCount, 1}; + m_pads = {dimCount * 2, 0}; + m_strides = {dimCount, 1}; + + if (kernelInformation.HasAttribute(AttrName::Dilations, MLOperatorAttributeType::IntArray)) + { + shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Dilations); + m_dilations.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_dilations); + ML_CHECK_VALID_ARGUMENT(m_dilations.size() == dimCount); + } + + if (kernelInformation.HasAttribute(AttrName::Pads, MLOperatorAttributeType::IntArray)) + { + shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Pads); + m_pads.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_pads); + ML_CHECK_VALID_ARGUMENT(m_pads.size() == dimCount * 2); + } + + if (kernelInformation.HasAttribute(AttrName::Strides, MLOperatorAttributeType::IntArray)) + { + shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Strides); + m_strides.resize(shapeData.size()); + DowncastDimensions(gsl::span(shapeData), /*out*/ m_strides); + ML_CHECK_VALID_ARGUMENT(m_strides.size() == dimCount); + } + + m_inputShape = shapeInformation.GetInputTensorShape(0); + + auto blockShapeProduct = ComputeElementCountFromDimensions(m_blockShape); + m_outputShape.resize(2 + m_imageShape.size()); + m_outputShape[0] = m_inputShape[0]; // N + m_outputShape[1] = m_inputShape[1] / blockShapeProduct; // C + for (int i = 2; i < m_outputShape.size(); i++) + { + m_outputShape[i] = m_imageShape[i - 2]; + }; + } + + std::vector Col2ImHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + return { EdgeShapes(m_outputShape) }; + } + void ConcatHelperBase::Initialize( const MLOperatorAttributes& operatorAttributes, gsl::span 
inputDimensions @@ -2020,7 +2079,7 @@ namespace OperatorHelper } return outputShapes; } - + std::vector QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto inputShape = shapeInfo.GetInputTensorShape(0); @@ -2050,7 +2109,7 @@ namespace OperatorHelper } return outputShapes; } - + std::vector RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS); @@ -2113,7 +2172,7 @@ namespace OperatorHelper { std::vector outputDimensions64bit = shapeInfo.GetAttributeVector(AttrName::OutputShape); ML_CHECK_VALID_ARGUMENT(outputDimensions64bit.size() == m_inputShape.size(), "Input dimensions and output_shape must have same rank."); - DowncastDimensions(outputDimensions64bit, /*out*/ outputDimensions); + DowncastDimensions(gsl::span(outputDimensions64bit), /*out*/ outputDimensions); } else { diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index d8d09efd8d6e8..0e0e6bb1eaf5c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1206,6 +1206,34 @@ class SqueezeHelper std::vector m_axes; }; +class Col2ImHelper +{ +public: + void Initialize( + const IKernelInformationAdapter& kernelInformation, + const IShapeInformationAdapter& shapeInformation); + + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + Col2ImHelper(const Info_t& info, const Shape_t& shape) + { + Initialize(KernelInformationAdapter(info), ShapeInformationAdapter(shape)); + } + + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +protected: + std::vector m_dilations; + std::vector m_pads; + std::vector m_strides; + std::vector m_imageShape; + std::vector m_blockShape; + std::vector m_inputShape; + std::vector m_outputShape; +}; + + class UnsqueezeHelper { public: @@ -1572,6 +1600,7 @@ using ShapeInferenceHelper_Unsqueeze11 = VersionedOpsetHelper; using ShapeInferenceHelper_EyeLike = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Trilu = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Col2Im = Col2ImHelper; using ShapeInferenceHelper_Expand = ExpandHelper; using ShapeInferenceHelper_Reshape7 = ReshapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index e9d88adf3e221..8438bc620712c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -407,6 +407,7 @@ namespace OperatorHelper static const int sc_sinceVer_Pad = 18; static const int sc_sinceVer_Split = 18; static const int sc_sinceVer_LpPool = 18; + static const int sc_sinceVer_Col2Im = 18; } namespace OnnxOperatorSet19 From d2f7a5b1286e34ea55dceccff8f17a45f2f799aa Mon Sep 17 00:00:00 2001 From: Jake Mathern Date: Mon, 11 Dec 2023 17:41:16 -0800 Subject: [PATCH 16/45] Cherry pick fix constant pow (#18785) ### Description Cherry pick https://github.com/microsoft/onnxruntime/pull/18784 --- .../src/Operators/DmlOperatorElementWise.cpp | 2 +- .../dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h | 4 ++-- 2 files changed, 3 
insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index f0a16da3a3c06..ec94772238cc9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -479,7 +479,7 @@ class DmlOperatorElementwisePow : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1); + auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(1); if (constExpTensor && constExpTensor->GetTotalElementCount() == 1) { std::vector> kernelInputIndices = {0}; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 59a1719d08ee6..c40f82a8c31c6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -605,11 +605,11 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return MLOperatorTensor(tensor.Get()); } - std::optional TryGetConstantInputTensor(uint32_t inputIndex) const + std::optional TryGetConstantCpuInputTensor(uint32_t inputIndex) const { Microsoft::WRL::ComPtr tensor; ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor)); - if (tensor) + if (tensor && tensor->IsCpuData()) { return MLOperatorTensor(tensor.Get()); } From b2f81c8725a0dd1ba343c91ff45aa082e9331f5d Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 14 Dec 2023 14:20:55 -0800 Subject: [PATCH 17/45] Hide Col2Im registration behind DML_TARGET_VERSION 6300 (#18829) Hide Col2Im registration behind DML_TARGET_VERSION 6300 Co-authored-by: Sheil Kumar --- .../src/Operators/DmlOperatorCol2Im.cpp | 4 ++++ .../src/Operators/OperatorRegistration.cpp | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp index 13a51f2be7560..f80b4f98236bc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCol2Im.cpp @@ -6,6 +6,8 @@ namespace Dml { +#if DML_TARGET_VERSION >= 0x6300 + class DmlOperatorCol2Im : public DmlOperator, public Col2ImHelper { public: @@ -56,4 +58,6 @@ class DmlOperatorCol2Im : public DmlOperator, public Col2ImHelper DML_OP_DEFINE_CREATION_FUNCTION(Col2Im, DmlOperatorCol2Im); +#endif // DML_TARGET_VERSION >= 0x6300 + } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 2ab73afb8f1e1..15a8051953c79 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -503,7 +503,11 @@ DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); 
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); DML_OP_EXTERN_CREATION_FUNCTION(Trilu); + +#if DML_TARGET_VERSION >= 0x6300 DML_OP_EXTERN_CREATION_FUNCTION(Col2Im); +#endif + DML_OP_EXTERN_CREATION_FUNCTION(Shape); DML_OP_EXTERN_CREATION_FUNCTION(Size); DML_OP_EXTERN_CREATION_FUNCTION(Attention); @@ -771,7 +775,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)}, {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)}, {REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + +#if DML_TARGET_VERSION >= 0x6300 {REG_INFO( 18, Col2Im, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2))}, +#endif // Data reorganization that merely changes the dimensions while keeping the data identical. {REG_INFO_COPY( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, From bdaeebd6ff5d4667f288ebb9b8072f66b45de1c8 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Mon, 18 Dec 2023 15:47:57 -0800 Subject: [PATCH 18/45] Fix bug in DML EP ExecuteCommandList fast path and simplify design (#18866) ### Description This addresses a bug in a fast path that was added for submission of re-used command lists of fused graph kernels in the DML EP, fixing a D3D debug layer error. ### Motivation and Context The fast path in DmlCommandRecorder::ExecuteCommandList enabled a current non-reused command list, if empty, to be used for commands following submission of the fused command list. The fix ensures the associated command allocator is only re-used after the next fence value is completed, which is higher due to submission of the other command list. The command recorder design was intended to support batching of provided command list execution, however it submits command lists immediately as an implementation detail to maximize CPU/GPU parallelism. If that heuristic was removed, it would expose additional issues in this same fast path. Because of this, and the complexity and inefficiency of the old batching mechanism, I also removed it. --- .../src/CommandAllocatorRing.h | 8 ++ .../src/DmlCommandRecorder.cpp | 89 +++++++------------ .../src/DmlCommandRecorder.h | 16 ++-- 3 files changed, 46 insertions(+), 67 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h index 570f62aac8105..2eee9c9a9e5a3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h @@ -47,6 +47,14 @@ namespace Dml return m_commandAllocators[m_currentCommandAllocator].Get(); } + // Updates the completion event of the current allocator to a different value. This is used when the caller + // decides to issue an unrelated call to the queue such as ExecuteCommandLists which updates its fence between calling + // GetNextAllocator and executing the work which it recorded using the allocator it received. 
+ void UpdateCurrentAllocatorCompletionEvent(GpuEvent nextCompletionEvent) + { + m_commandAllocators[m_currentCommandAllocator].completionEvent = nextCompletionEvent; + } + private: struct CommandAllocatorInfo { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp index 98345f37b68d4..5254b23f56376 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp @@ -262,6 +262,9 @@ void DmlCommandRecorder::ExecuteCommandList( m_queue->ExecuteCommandLists( gsl::span(reinterpret_cast(&commandList), 1)); + // The fence value at which the current command allocator may be re-used will now be higher + m_commandAllocatorRing.UpdateCurrentAllocatorCompletionEvent(m_queue->GetNextCompletionEvent()); + // Fail early if something horrifying happens ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason()); ORT_THROW_IF_FAILED(m_d3dDevice->GetDeviceRemovedReason()); @@ -269,42 +272,20 @@ void DmlCommandRecorder::ExecuteCommandList( return; } - ORT_THROW_IF_FAILED(m_currentCommandList->Close()); - - if (m_operationsRecordedInCurrentCommandList) - { - m_pendingCommandLists.push_back(m_currentCommandList.Get()); - m_pendingCommandListsCacheable.push_back(true); - } - else - { - m_cachedCommandLists.push_back(m_currentCommandList.Get()); - } - - m_currentCommandList = nullptr; - m_operationsRecordedInCurrentCommandList = false; - - m_pendingCommandLists.push_back(commandList); - m_pendingCommandListsCacheable.push_back(false); - - // Remember the descriptor heap and apply it to the next command list + // Remember the descriptor heap and apply it to the next command list. This avoids unnecessarily setting it onto + // the D3D object lazily at a point when the operation may not be parallelized with GPU work. auto heap = m_currentDescriptorHeap; - m_currentDescriptorHeap = nullptr; - Open(); - - // The caller can re-use relevant resources after the next set of work to be - // flushed has completed. Its command list hasn't been executed yet, just batched. - GpuEvent gpuEvent = m_queue->GetNextCompletionEvent(); - gpuEvent.fence.CopyTo(fence); - *completionValue = gpuEvent.fenceValue; - // Trigger a flush of the command list, with the assumption that it contains enough GPU work that this - // will help parallelize GPU work with subsequent CPU work. This policy is related to the choice of - // minNodeCountToReuseCommandList within FusedGraphKernel, so both should be tuned together. - CloseAndExecute(); + // Execute work in the current command list plus provided command list while closing the recorder. 
+ CloseAndExecute(commandList); Open(); + // Reset the descriptor heap opportunistically per above comment SetDescriptorHeap(heap); + + GpuEvent gpuEvent = m_queue->GetCurrentCompletionEvent(); + gpuEvent.fence.CopyTo(fence); + *completionValue = gpuEvent.fenceValue; } ComPtr DmlCommandRecorder::GetCommandList() @@ -334,7 +315,7 @@ void DmlCommandRecorder::Open() ID3D12CommandAllocator* allocator = m_commandAllocatorRing.GetNextAllocator(m_queue->GetNextCompletionEvent()); - if (m_cachedCommandLists.empty()) + if (!m_cachedCommandList) { ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList( 0, @@ -345,47 +326,43 @@ void DmlCommandRecorder::Open() } else { - m_currentCommandList = m_cachedCommandLists.front(); - m_cachedCommandLists.pop_front(); + m_currentCommandList = m_cachedCommandList; + m_cachedCommandList = nullptr; ORT_THROW_IF_FAILED(m_currentCommandList->Reset(allocator, nullptr)); } } void DmlCommandRecorder::CloseAndExecute() { + CloseAndExecute(nullptr); +} + +void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList) +{ ORT_THROW_IF_FAILED(m_currentCommandList->Close()); + ID3D12GraphicsCommandList* commandListsToExecute[2] = {}; + uint32_t commandListsToExecuteCount = 0; + if (m_operationsRecordedInCurrentCommandList) { - m_pendingCommandLists.push_back(m_currentCommandList.Get()); - m_pendingCommandListsCacheable.push_back(true); + commandListsToExecute[commandListsToExecuteCount++] = m_currentCommandList.Get(); } - else + + if (commandList) { - m_cachedCommandLists.push_back(m_currentCommandList.Get()); + commandListsToExecute[commandListsToExecuteCount++] = commandList; } - m_currentCommandList = nullptr; - m_operationsRecordedInCurrentCommandList = false; - - if (!m_pendingCommandLists.empty()) + if (commandListsToExecuteCount > 0) { - // Close and execute the command list m_queue->ExecuteCommandLists( - gsl::span(reinterpret_cast(m_pendingCommandLists.data()), m_pendingCommandLists.size())); - - assert(m_pendingCommandLists.size() == m_pendingCommandListsCacheable.size()); - for (size_t i = 0; i < m_pendingCommandLists.size(); ++i) - { - if (m_pendingCommandListsCacheable[i]) - { - m_cachedCommandLists.push_back(m_pendingCommandLists[i]); - } - } - - m_pendingCommandLists.clear(); - m_pendingCommandListsCacheable.clear(); + gsl::span(reinterpret_cast(commandListsToExecute), commandListsToExecuteCount)); } + + m_cachedCommandList = m_currentCommandList; + m_currentCommandList = nullptr; + m_operationsRecordedInCurrentCommandList = false; // The descriptor heap must be set on the command list the next time it's opened. 
m_currentDescriptorHeap = nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h index 7ad7032317d77..83051c8ca4ff9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h @@ -58,7 +58,7 @@ namespace Dml bool HasUnsubmittedWork() override { - return m_operationsRecordedInCurrentCommandList || !m_pendingCommandLists.empty(); + return m_operationsRecordedInCurrentCommandList; } // Forces the descriptor heap to be reset to D3D before executing future operations @@ -68,7 +68,8 @@ namespace Dml } private: - + void CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList); + std::shared_ptr m_queue; ComPtr m_d3dDevice; ComPtr m_dmlDevice; @@ -89,15 +90,8 @@ namespace Dml ComPtr m_currentCommandList; bool m_operationsRecordedInCurrentCommandList = false; - // Command lists which have been batched up for execution. The values in - // m_pendingCommandListsCacheable indicate whether they can be moved into this - // class's cache after execution, versus if they belong to the caller and were - // passed to ExecuteCommandList. - std::vector> m_pendingCommandLists; - std::vector m_pendingCommandListsCacheable; - - // A pool of cached command lists which may be re-used. - std::deque> m_cachedCommandLists; + // A cached command list which may be re-used. + ComPtr m_cachedCommandList; void SetDescriptorHeap(ID3D12DescriptorHeap* descriptorHeap); }; From 70d3f682a7a7553f337bd0422c3b8fc555af52a1 Mon Sep 17 00:00:00 2001 From: tbqh <111796392+tbqh@users.noreply.github.com> Date: Tue, 2 Jan 2024 13:22:30 -0600 Subject: [PATCH 19/45] De-duplicate 1D scale and zero point tensors to scalars in DML kernels (#18862) ### Description Cleanup and rebase from [this PR](https://github.com/microsoft/onnxruntime/pull/18629) ### Motivation and Context --------- Co-authored-by: Christian Larson Co-authored-by: Christian Larson <28911437+chrilaMSFT@users.noreply.github.com> Co-authored-by: Jeff Bloomfield Co-authored-by: Anagha Rao --- .../inc/IWinmlExecutionProvider.h | 3 + .../src/Operators/DmlOperator.cpp | 83 +++++++++++++++++++ .../src/Operators/DmlOperator.h | 6 ++ .../DmlOperatorBatchNormalization.cpp | 2 + .../src/Operators/DmlOperatorElementWise.cpp | 3 + .../src/Operators/DmlOperatorQLinearAdd.cpp | 23 +++-- .../DmlOperatorQLinearAveragePooling.cpp | 10 ++- .../Operators/DmlOperatorQLinearConcat.cpp | 7 ++ .../src/Operators/DmlOperatorQLinearConv.cpp | 9 ++ .../Operators/DmlOperatorQLinearMatMul.cpp | 9 ++ .../Operators/DmlOperatorQLinearSigmoid.cpp | 7 ++ 11 files changed, 153 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 074f13b309181..f29cc3afc3cda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -85,7 +85,10 @@ namespace Windows::AI::MachineLearning::Adapter { uint32_t nodeCount = 0; std::vector> nodesAsOperatorDesc; + + // TODO (jeffbloo): Remove this std::vector> nodesAsIDMLOperator; + std::vector inputEdges; std::vector outputEdges; std::vector intermediateEdges; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index 25c7be42d6425..c3bb1a52210f5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -6,6 +6,9 @@ namespace Dml { + + /*static*/ const uint32_t DmlOperator::zeroArray[8] = {}; + DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo) { ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider)); @@ -824,4 +827,84 @@ namespace Dml graphDesc.IntermediateEdges = dmlIntermediateEdges.data(); } + /*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar( + const MLOperatorKernelCreationContext& kernelInfo, + const DML_TENSOR_DESC* tensor, + uint32_t kernelInputIndex) + { + if (!tensor) + { + return; + } + + auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(kernelInputIndex); + if (!constExpTensor) + { + return; + } + else if (!constExpTensor->IsCpuData()) + { + return; + } + + uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount(); + if (totalKernelInputElementCount <= 1) + { + return; + } + + uint32_t elementSize = 0; + + switch (constExpTensor->GetTensorDataType()) + { + case MLOperatorTensorDataType::UInt8: + case MLOperatorTensorDataType::Int8: + elementSize = 1; + break; + + case MLOperatorTensorDataType::Float16: + case MLOperatorTensorDataType::UInt16: + case MLOperatorTensorDataType::Int16: + elementSize = 2; + break; + + case MLOperatorTensorDataType::/*Float32*/Float: + case MLOperatorTensorDataType::UInt32: + case MLOperatorTensorDataType::Int32: + elementSize = 4; + break; + + case MLOperatorTensorDataType::/*Float64*/Double: + case MLOperatorTensorDataType::UInt64: + case MLOperatorTensorDataType::Int64: + elementSize = 8; + break; + + default: + return; + } + + const std::uint8_t* byteData = static_cast(constExpTensor->GetByteData()); + + assert(tensor->Type == DML_TENSOR_TYPE_BUFFER); + auto *bufferTensorDesc = const_cast(static_cast(tensor->Desc)); + + for (size_t i = 1; i < totalKernelInputElementCount; ++i) + { + if (memcmp(byteData, byteData + i * elementSize, elementSize)) + { + return; + } + } + + if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0])) + { + assert(false); + return; + } + + bufferTensorDesc->Strides = zeroArray; + bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3; + } + } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h index c1e8cf42a974c..fa54d4b041b5f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h @@ -149,6 +149,11 @@ namespace Dml uint32_t minDimensionCount = NchwDimensionCount ) const; + static void TryConvertTensorToBroadcastScalar( + const MLOperatorKernelCreationContext& kernelInfo, + const DML_TENSOR_DESC* tensor, + uint32_t kernelInputIndex); + private: // For each input or output of the DML kernel, the corresponding input or output of the original // kernel. Entries for unused DML inputs are nullopt. 
@@ -164,6 +169,7 @@ namespace Dml _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges); + static const uint32_t zeroArray[8]; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp index 9f9cfad670919..ee497715dd73f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp @@ -111,6 +111,8 @@ class DmlOperatorBatchNormalization15 : public DmlOperator, BatchNormalizationHe std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); + // TODO (jeffbloo): Port this to a graph description to enable DML graph optimization + dml::Graph graph(m_dmlDevice.Get()); dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X]; dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale]; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index ec94772238cc9..835d43037eaee 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -586,6 +586,9 @@ class DmlOperatorElementwiseQLinear : public DmlOperator opDesc.ZeroPointTensor = &inputDescs[2]; opDesc.OutputTensor = &outputDescs[0]; + TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1); + TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2); + SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp index 7b50dfb9ff1ad..a19e37e15e768 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp @@ -8,15 +8,15 @@ namespace Dml class DmlOperatorQLinearAdd : public DmlOperator { - enum InputTensors { - IN_A, + enum InputTensors { + IN_A, IN_A_SCALE, - IN_A_ZERO_POINT, - IN_B, + IN_A_ZERO_POINT, + IN_B, IN_B_SCALE, IN_B_ZERO_POINT, - IN_C_SCALE, - IN_C_ZERO_POINT + IN_C_SCALE, + IN_C_ZERO_POINT }; public: @@ -56,9 +56,18 @@ class DmlOperatorQLinearAdd : public DmlOperator AddDesc.BScaleTensor = &inputDescs[IN_B_SCALE]; AddDesc.BZeroPointTensor = inputDescs[IN_B_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_B_ZERO_POINT] : nullptr; AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE]; - AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr; + AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? 
&inputDescs[IN_C_ZERO_POINT] : nullptr; AddDesc.OutputTensor = &outputDescs[0]; + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT); + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp index 0fccedfe311c1..605e5fffb6a76 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp @@ -118,8 +118,8 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput]; qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale]; qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint]; - qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];; - qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];; + qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale]; + qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint]; qLinearAvgPooldesc.OutputTensor = &outputDescs[0]; qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount; qLinearAvgPooldesc.WindowSize = m_kernel.windowSize; @@ -129,6 +129,12 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe qLinearAvgPooldesc.Dilations = m_kernel.dilations; qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale); + TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint); + + TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale); + TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint); + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp index 67711fdc28b84..c97b03dc36b62 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -123,6 +123,9 @@ class DmlOperatorQLinearConcat : public DmlOperator, 
public QLinearConcatHelper dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1]; dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2]; dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex]; + + TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ScaleTensor, tupleStartIndex + 1); + TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDescs[inputIndex].ZeroPointTensor, tupleStartIndex + 2); dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]}; opDescs.push_back(&dmlOpDesc[inputIndex]); @@ -154,6 +157,10 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale]; quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint]; quantizeOperatorDesc.OutputTensor = &outputDescs[0]; + + TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::YScale); + TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::YZeroPoint); + const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc}; opDescs.push_back(&opQuantizeDesc); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp index d45fdef3c8807..4e121a6502cba 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConv.cpp @@ -117,6 +117,15 @@ class DmlOperatorQLinearConv : public DmlOperator, public ConvolutionHelperBase convDesc.EndPadding = kernelArgs.endPadding; convDesc.GroupCount = m_groupCount; + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputScaleTensor, IN_X_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.InputZeroPointTensor, IN_X_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterScaleTensor, IN_F_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.FilterZeroPointTensor, IN_F_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputScaleTensor, IN_Y_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, convDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT); + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &convDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp index b746a0e81a5cf..b38acd8cbf978 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearMatMul.cpp @@ -104,6 +104,15 @@ class DmlOperatorQLinearMatMul : public DmlOperator matMulDesc.OutputZeroPointTensor = inputDescs[IN_Y_ZERO_POINT].Desc != nullptr ? 
&inputDescs[IN_Y_ZERO_POINT] : nullptr; matMulDesc.OutputTensor = &outputDescs[0]; + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AScaleTensor, IN_A_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.AZeroPointTensor, IN_A_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BScaleTensor, IN_B_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.BZeroPointTensor, IN_B_ZERO_POINT); + + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputScaleTensor, IN_Y_SCALE); + TryConvertTensorToBroadcastScalar(kernelInfo, matMulDesc.OutputZeroPointTensor, IN_Y_ZERO_POINT); + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, &matMulDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp index 1da4a5cab7623..35f926d62c92a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp @@ -88,6 +88,9 @@ class DmlOperatorQLinearSigmoid : public DmlOperator dequantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::X_scale]; dequantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::X_zero_point]; dequantizeOperatorDesc.OutputTensor = &namedIntermediateOutputTensorDesc; + + TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ScaleTensor, OnnxInputIndex::X_scale); + TryConvertTensorToBroadcastScalar(kernelCreationContext, dequantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::X_zero_point); const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDesc}; @@ -101,6 +104,10 @@ class DmlOperatorQLinearSigmoid : public DmlOperator quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::Y_scale]; quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::Y_zero_point]; quantizeOperatorDesc.OutputTensor = &outputDescs[0]; + + TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ScaleTensor, OnnxInputIndex::Y_scale); + TryConvertTensorToBroadcastScalar(kernelCreationContext, quantizeOperatorDesc.ZeroPointTensor, OnnxInputIndex::Y_zero_point); + const DML_OPERATOR_DESC opDesc3{DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc}; MLOperatorGraphDesc operatorGraphDesc = {}; From ee60e3af6c0e12ea3ab4a17499122c289081e92f Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Tue, 2 Jan 2024 14:58:13 -0800 Subject: [PATCH 20/45] =?UTF-8?q?Limit=20size=20of=20constant=20nodes=20cr?= =?UTF-8?q?eates=20by=20DML=20EP=20following=20deduplicatio=E2=80=A6=20(#1?= =?UTF-8?q?8915)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This limits the size of constant data nodes which the DML EP creates in the DML graph following de-duplication of 1D quantization tensors. In the process it reduces a check for the maximum size of the constant node. 
This is merged from: https://github.com/microsoft/onnxruntime/pull/18494 ### Motivation and Context --- .../src/GraphDescBuilder.cpp | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index ba022533a1e94..adb4fd131119f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -347,19 +347,23 @@ namespace Dml::GraphDescBuilder // This is a highly inefficient approach to generating constant nodes. It duplicates constant data // across the graph input as well as every consumer's unique constant node. However it is currently // only used for small inputs. - - // TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming - // nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point - // values that have been de-duplicated with conversion to scalars in kernels. - uint32_t c_maxConstNodeDataSize = 1024 * 1024; + uint32_t c_maxConstNodeDataSize = 8; ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); - if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize) + auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + + if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) { + // The tensor description's size should be no larger than the constant input unless it was rounded to + // the required alignment. 
+ assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); + size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), tensorDesc->totalTensorSizeInBytes); auto data = static_cast(constantInput->GetData()); - std::vector tensorData(data, data + constantInput->GetTensorByteSize()); - + std::vector tensorData(data, data + minimumConstantSize); + NodeInfo nodeInfo = {}; nodeInfo.nodeDef = std::move(tensorData); graphNodes.push_back(std::move(nodeInfo)); @@ -379,9 +383,6 @@ namespace Dml::GraphDescBuilder edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; graphInputEdges.push_back(edge); - auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; } } @@ -445,7 +446,7 @@ namespace Dml::GraphDescBuilder // TODO: Change as new header is ingested if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) dmlDesc.Type = (DML_OPERATOR_TYPE) 169; - + // TODO: Change as new header is ingested if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) dmlDesc.Type = (DML_OPERATOR_TYPE) 170; From 56fcea94e3c6f158f3a16dbe42dc1b16caf35565 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Tue, 2 Jan 2024 18:06:05 -0800 Subject: [PATCH 21/45] Enable QDQ quantization for DML EP (#18367) ### Description This enables QDQ transforms with the DML EP --- .../qdq_selector_action_transformer.cc | 60 +++++++++++++------ .../selectors_actions/qdq_selectors.cc | 7 +++ .../selectors_actions/qdq_selectors.h | 26 ++++---- onnxruntime/core/session/inference_session.cc | 20 ------- tools/ci_build/build.py | 3 + 5 files changed, 67 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 29178fe87f75c..29f7575b2a638 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -105,8 +105,8 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when unary QLinear* ops support 16-bit. - std::unique_ptr selector = std::make_unique(); + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"AveragePool", {}}, {"LeakyRelu", {}}, @@ -123,20 +123,43 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // 4 nodes. 2 x DQ for inputs, target, Q // Replace with internal QLinear version of operator. Delete all original nodes. 
- const std::string action_name{"2DQ"}; - std::unique_ptr action = std::make_unique(kMSDomain); + { + const std::string action_name{"2DQ"}; + std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit. - std::unique_ptr selector = std::make_unique(); - qdq_selector_action_registry.RegisterSelectorAndAction(action_name, - {{"Add", {}}, - {"Mul", {}}}, - std::move(selector), - std::move(action)); + // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit. + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"Add", {}}, + {"Mul", {}}}, + std::move(selector), + std::move(action)); #else - qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); + qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); +#endif + } + +#ifdef USE_DML + { + const std::string action_name{"2DQ_DML"}; + std::unique_ptr action = std::make_unique(kMSDomain); + +#if !defined(ORT_MINIMAL_BUILD) + std::vector providers = {kDmlExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); + + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"Add", {}}}, + std::move(selector), + std::move(action)); + +#else +#error "ORT_MINIMAL_BUILD and USE_DML are not expected simultaneously. This would require RegisterAction to be called here." +#endif + } #endif } @@ -214,8 +237,8 @@ void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when QGemm supports 16-bit. - std::unique_ptr selector = std::make_unique(); + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Gemm", {}}}, std::move(selector), @@ -235,8 +258,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - // TODO: Enable 16-bit types in selector when QLinearWhere supports 16-bit. 
- std::unique_ptr selector = std::make_unique(); + + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Where", {}}}, std::move(selector), @@ -271,8 +295,8 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed), apply_context, - // this transformer is only compatible with the CPU EP - {kCpuExecutionProvider}} { + // this transformer is only compatible with the CPU and DML EP + {kCpuExecutionProvider, kDmlExecutionProvider}} { } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 15b501c667046..8535b8c9a944a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -91,6 +91,13 @@ std::optional NodeGroupSelector::GetQDQSelection(const GraphViewer& g } std::optional BaseSelector::Select(const GraphViewer& graph_viewer, const Node& node) const { + const std::string_view node_ep = node.GetExecutionProviderType(); + + if (!compatible_providers_.empty() && + std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) { + return std::nullopt; + } + const auto qdq_group = node_group_selector_->GetQDQSelection(graph_viewer, node); if (!qdq_group.has_value()) { return std::nullopt; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index d0d7fb2c2af17..5ce2a97026516 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -257,12 +257,15 @@ class BaseSelector : public NodeSelector { // We std::move SelectorActionRegistry into the SelectorActionTransformer so this class needs to have a move ctor BaseSelector(BaseSelector&& rhs) noexcept - : node_group_selector_{std::move(rhs.node_group_selector_)} { + : node_group_selector_{std::move(rhs.node_group_selector_)}, + compatible_providers_{std::move(rhs.compatible_providers_)} { } protected: - BaseSelector(std::unique_ptr node_group_selector) - : node_group_selector_{std::move(node_group_selector)} {} + BaseSelector(std::unique_ptr node_group_selector, gsl::span compatible_providers = {}) + : node_group_selector_{std::move(node_group_selector)}, + compatible_providers_(compatible_providers.begin(), compatible_providers.end()) { + } // override if you need to adjust the values in NodesToOptimize. // e.g. 
add entries for missing optional DQ inputs or set num_inputs to handle variadic inputs @@ -271,6 +274,7 @@ class BaseSelector : public NodeSelector { private: std::unique_ptr node_group_selector_; + std::vector compatible_providers_; }; class DropQDQNodesSelector : public BaseSelector { @@ -287,14 +291,14 @@ class DropDQNodesSelector : public BaseSelector { class UnarySelector : public BaseSelector { public: - explicit UnarySelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit UnarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} }; class BinarySelector : public BaseSelector { public: - explicit BinarySelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit BinarySelector(gsl::span compatible_providers = {}, bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} }; // Variadic DQ nodes -> node -> Q @@ -326,8 +330,8 @@ class ConvSelector : public BaseSelector { class WhereSelector : public BaseSelector { public: - explicit WhereSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit WhereSelector(gsl::span compatible_providers = {}, bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} }; // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not @@ -342,8 +346,8 @@ class MatMulSelector : public BaseSelector { // Output: optional Q node for Y class GemmSelector : public BaseSelector { public: - explicit GemmSelector(bool allow_16bit = false) - : BaseSelector(std::make_unique(allow_16bit)) {} + explicit GemmSelector(gsl::span compatible_providers = {}, bool allow_16bit = false) + : BaseSelector(std::make_unique(allow_16bit), compatible_providers) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cef160489ac46..665cdbc36a963 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -633,26 +633,6 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr session_options_.enable_mem_pattern = false; } - // Default this option to true when the DML EP is registered. - // This should be removed if QDQ is supported for DML through QDQSelectorActionTransformer and the DML EP does not - // rely on the constant folding pass for DequantizeLinear. - optional disable_quant_qdq = session_options_.config_options.GetConfigEntry(kOrtSessionOptionsDisableQuantQDQ); - - if (disable_quant_qdq == std::nullopt) { - LOGS(*session_logger_, INFO) - << "QDQ quantization is not supported while using the DML Execution Provider. " - << "So disabling it for this session since it uses the DML Execution Provider."; - - auto st = session_options_.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"); - if (!st.IsOK()) { - return st; - } - } else if (*disable_quant_qdq != "1") { - LOGS(*session_logger_, WARNING) - << "QDQ quantization is not supported while using the DML Execution Provider. 
" - << "It is enabled within session options which may result in lower performance."; - } - // Parallel execution mode does not support DML EP if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) { LOGS(*session_logger_, INFO) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 5cc537c4596e8..c655100fbf475 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1641,6 +1641,9 @@ def setup_dml_build(args, cmake_path, build_dir, configs): ] run_subprocess(cmd_args) + if args.minimal_build is not None: + raise BuildError("use_dml and minimal_build may not both be set") + def setup_rocm_build(args): rocm_home = None From 70a6f816af1a482b65015616d56b73d21a43d7fe Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Wed, 3 Jan 2024 16:22:54 -0800 Subject: [PATCH 22/45] Port attention query fix from b2768bbf2347b4ea564f2a937f9f48987620ddf0 --- .../src/Operators/DmlOperatorAttention.cpp | 6 ++++++ .../core/providers/dml/OperatorAuthorHelper/Attributes.h | 1 + 2 files changed, 7 insertions(+) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index bbebb4a333baf..c8ca6806e75f7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -571,6 +571,12 @@ void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*o return; } + // `past_present_share_buffer == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::PastPresentShareBuffer, 0) != 0) + { + return; + } + *isSupported = true; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 85333aa77b686..e3df1d00b3e8a 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -107,6 +107,7 @@ namespace AttrName static constexpr const char* QkvHiddenSizes = "qkv_hidden_sizes"; static constexpr const char* Unidirectional = "unidirectional"; static constexpr const char* NumHeads = "num_heads"; + static constexpr const char* PastPresentShareBuffer = "past_present_share_buffer"; static constexpr const char* FusedActivation = "fused_activation"; static constexpr const char* FusedActivationDomain = "fused_activation_domain"; From f4ad940ff3c69e139c3baf91517dfd458e166fb1 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Wed, 3 Jan 2024 18:37:14 -0800 Subject: [PATCH 23/45] Disable MatMul QDQ selector on DML EP until MatMulIntegerToFloat is re-enabled --- .../selectors_actions/qdq_selector_action_transformer.cc | 3 ++- .../qdq_transformer/selectors_actions/qdq_selectors.h | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 29f7575b2a638..244bd8b153077 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -216,7 +216,8 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i #if 
!defined(ORT_MINIMAL_BUILD) // TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit. - std::unique_ptr selector = std::make_unique(is_int8_allowed); + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers, is_int8_allowed); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"MatMul", {}}}, std::move(selector), diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 5ce2a97026516..deee6e7f25f1a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -337,9 +337,10 @@ class WhereSelector : public BaseSelector { // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not class MatMulSelector : public BaseSelector { public: - MatMulSelector(bool int8_allowed, bool allow_16bit = false) + MatMulSelector(gsl::span compatible_providers, bool int8_allowed, bool allow_16bit = false) : BaseSelector(std::make_unique(int8_allowed, /*matmulintegertofloat_allowed*/ true, - allow_16bit)) {} + allow_16bit), + compatible_providers) {} }; // Input: DQ nodes for A, B and optional C From 8ea3e68192967314f80ef4e7abd21a59b7cfd722 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Thu, 4 Jan 2024 10:10:46 -0800 Subject: [PATCH 24/45] Update ContribOperators.md --- docs/ContribOperators.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 131db5d8d9b37..38fceef67de25 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
+
rotary_embedding_dim : int
+
Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
From 7401b6661d29094c1cb08d117fcaf8af7038b16f Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Thu, 4 Jan 2024 11:27:03 -0800 Subject: [PATCH 25/45] Update OperatorKernels.md --- docs/OperatorKernels.md | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1ce9b3254d91f..e401baae2d803 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -903,7 +903,8 @@ Do not modify directly.* |Asinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| |Atan|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Atanh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|AveragePool|*in* X:**T**
*out* Y:**T**|19+|**T** = tensor(float), tensor(float16)| +|||11+|**T** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| |BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(float), tensor(float16)| @@ -951,7 +952,7 @@ Do not modify directly.* |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Dropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T1**|7+|**T** = tensor(float), tensor(float16)| -|DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| |Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| @@ -1030,7 +1031,8 @@ Do not modify directly.* |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| |LpNormalization|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|LpPool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|LpPool|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(float), tensor(float16)| +|||11+|**T** = tensor(float), tensor(float16)| |||2+|**T** = tensor(float), tensor(float16)| |MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||9+|**T** = tensor(float), tensor(float16)| @@ -1145,8 +1147,8 @@ Do not modify directly.* |Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| -|||11+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| +|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||11+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| @@ -1247,6 +1249,9 @@ Do not modify directly.* |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| +|QLinearAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| +|QLinearConcat|*in* Y_scale:**TF**
*in* Y_zero_point:**T8**
*in* inputs:**TV**
*out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)
**TF** = tensor(float)
**TV** = tensor(float), tensor(int8), tensor(uint8)| +|QLinearGlobalAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| From 658e30eb33f157dc7e7cba0e6ac9bf37178722e1 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 4 Jan 2024 12:59:47 -0800 Subject: [PATCH 26/45] Remove DORT since it's in PyTorch main now (#18996) Main code are removed and tests are modified to use DORT directly from PyTorch. --- cmake/onnxruntime_python.cmake | 7 - .../python/training/torchdynamo/__init__.py | 4 - .../training/torchdynamo/ort_backend.py | 729 ------------------ .../training/torchdynamo/register_backend.py | 89 --- .../test/python/orttraining_test_dort.py | 47 +- .../orttraining_test_dort_custom_ops.py | 26 +- setup.py | 1 - 7 files changed, 42 insertions(+), 861 deletions(-) delete mode 100644 orttraining/orttraining/python/training/torchdynamo/__init__.py delete mode 100644 orttraining/orttraining/python/training/torchdynamo/ort_backend.py delete mode 100644 orttraining/orttraining/python/training/torchdynamo/register_backend.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 61922961588b2..2e3594f256f65 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -354,9 +354,6 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_optim_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/optim/*.py" ) - file(GLOB onnxruntime_python_torchdynamo_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/training/torchdynamo/*.py" - ) file(GLOB onnxruntime_python_ortmodule_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/*.py" ) @@ -746,7 +743,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/experimental COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/experimental/gradient_graph COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/optim - COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/torchdynamo COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental/json_config @@ -777,9 +773,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_optim_srcs} $/onnxruntime/training/optim/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_torchdynamo_srcs} - $/onnxruntime/training/torchdynamo/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_srcs} $/onnxruntime/training/ortmodule/ diff --git a/orttraining/orttraining/python/training/torchdynamo/__init__.py b/orttraining/orttraining/python/training/torchdynamo/__init__.py deleted file mode 100644 index 862c45ce31b25..0000000000000 --- a/orttraining/orttraining/python/training/torchdynamo/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py deleted file mode 100644 index 9bafe39a5c211..0000000000000 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ /dev/null @@ -1,729 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import dataclasses -import logging -from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union - -import numpy as np -import onnx -import torch -import torch._C -import torch._ops -import torch._prims.executor -import torch.fx -import torch.onnx - -# TODO(wschin,justinchuby): Since the internal APIs are not stable, please -# contact us if you hit errors. -import torch.onnx._internal -import torch.onnx._internal.diagnostics -import torch.onnx._internal.exporter -import torch.onnx._internal.fx.decomposition_table -import torch.onnx._internal.fx.passes -from torch._subclasses.fake_tensor import FakeTensor -from torch.fx.passes.fake_tensor_prop import FakeTensorProp -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner -from torch.fx.passes.operator_support import OperatorSupport -from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch.utils import _pytree - -import onnxruntime # type: ignore -from onnxruntime.capi import _pybind_state as ORTC - -_NP_DTYPE = { - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.uint8: np.uint8, - torch.int8: np.int8, - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.longlong, - torch.bool: np.bool_, -} - -_ONNX_ELEMENT_TYPE_TO_TORCH_DTYPE = { - 1: torch.float32, - 2: torch.uint8, - 3: torch.int8, - 5: torch.int16, - 6: torch.int32, - 7: torch.int64, - 9: torch.bool, - 10: torch.float16, -} - -_TORCH_DTYPE_TO_ONNX_ELEMENT_TYPE = {value: key for key, value in _ONNX_ELEMENT_TYPE_TO_TORCH_DTYPE.items()} - - -def _nvtx_range_push(name: str): - """If PyTorch is installed with CUDA support, this starts NVTX range. - - Check torch.cuda.nvtx.range_push's document for more details. - """ - if torch.cuda.is_available(): - torch.cuda.nvtx.range_push(name) - - -def _nvtx_range_pop(): - """If PyTorch is installed with CUDA support, this terminates NVTX range. - - Check torch.cuda.nvtx.range_pop's document for more details. - """ - if torch.cuda.is_available(): - torch.cuda.nvtx.range_pop() - - -def _get_ort_device_type(device_type: str): - if device_type == "cuda": - return ORTC.OrtDevice.cuda() # type: ignore - if device_type == "cpu": - return ORTC.OrtDevice.cpu() # type: ignore - # ort pytorch device is mapped to NPU OrtDevice type - if device_type == "ort": - return ORTC.OrtDevice.npu() # type: ignore - raise ValueError("Unsupported device type: " + device_type) - - -logger = logging.getLogger(__name__) -# Uncomment the following lines to print out development info. -# logging.basicConfig(level=logging.INFO) -# logger.setLevel(logging.INFO) - - -class OrtOperatorSupport(OperatorSupport): - """ - Operator support for ONNXRuntime backend. It has two-level of support decision. - One is via support_dict and the other one is via extra_support_dict. 
The logic - of using support_dict is implemented in OrtOperatorSupport and extra_support_dict - is used by OperatorSupport.is_node_supported. - """ - - def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]): - # Use extra_support_dict[op_name] = None to indicate - # we support op_name with all input types. Otherwise, - # see support_dict (type: SupportDict) in operator_support.py - # for specifying supported types. - super().__init__(extra_support_dict) - self._support_dict = support_dict - - def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool: - # OperatorSupport.is_node_supported returns True for non-callable nodes. - # Since ORT can't execute them, we return False here to override the base - # behavior. - if node.op not in CALLABLE_NODE_OPS: - return False - # This is the and the only place to decide if aten op is supported. - if node.op == "call_function" and node.target in self._support_dict: - logger.info("support_dict supports node.target: %s (type: %s)", node.target, type(node.target)) - return True - logger.info("support_dict doesn't support node.target: %s (type: %s)", node.target, type(node.target)) - # If node.target is not in support_dict, we still want to check if torch.jit.script - # can convert it to ONNX equivalence. Let's use base mechanism to do this. - # See extra_support_dict for supported ops. - if super().is_node_supported(submodules, node): - logger.info("extra_support_dict supports node.target: %s (type: %s)", node.target, type(node.target)) - return True - logger.info("extra_support_dict doesn't supports node.target: %s (type: %s)", node.target, type(node.target)) - return False - - -def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: - """ - In torch.fx.Graph, placehoder is a special assignment node. If it's not - executed in the beginning, it could overwrite values computed by upstream - nodes. - """ - - graph = graph_module.graph - placeholders = [] - first_not_placeholder = None - for node in graph.nodes: - if node.op == "placeholder": - placeholders.append(node) - if first_not_placeholder is None and node.op != "placeholder": - first_not_placeholder = node - if first_not_placeholder is None: - return - for placeholder in placeholders: - first_not_placeholder.prepend(placeholder) - - -def _replace_to_copy_with_to(fx_module: torch.fx.GraphModule) -> None: - # aten._to_copy doesn't have exporter so we replace it with aten.to. - for node in fx_module.graph.nodes: - if ( - isinstance(node.target, torch._ops.OpOverload) - and node.target.overloadpacket == torch.ops.aten._to_copy # type: ignore - ): - is_default_layout = True - is_on_same_device = True - is_cast = True - are_kwargs_supported = True - if "layout" in node.kwargs and node.kwargs["layout"] != torch.strided: - is_default_layout = False - if "device" in node.kwargs and node.kwargs["device"] != node.args[0].meta["val"].device: - is_on_same_device = False - if "dtype" not in node.kwargs: - is_cast = False - for kwarg in node.kwargs: - if kwarg not in ["layout", "device", "dtype"]: - are_kwargs_supported = False - - if len(node.args) == 1 and is_default_layout and is_on_same_device and is_cast and are_kwargs_supported: - # This aten::_to_copy looks like ONNX Cast, so other kwargs are ignored. - # This change could lead to invalid FX graph but it doesn't matter, as long as the downstream backend, - # ONNXRuntime, can execute the exported ONNX graph. 
- node.kwargs = {"dtype": node.kwargs["dtype"]} - - node.target = torch.ops.aten.to.dtype # type: ignore - else: - raise RuntimeError( - f"aten._to_copy must be replaced with other ONNX-supported aten ops. \ - args={[arg.meta for arg in node.args]}, kwargs={node.kwargs}" - ) - fx_module.recompile() - - -def _create_onnx_model(onnx_proto): - return onnx.ModelProto.FromString(onnx_proto) - - -def _create_onnx_session(onnx_proto, eps: Tuple[str, ...], session_options): - # TODO(wechi): Add more EPs per PyTorch device types. - # TODO(wechi): enable external allocators. - return onnxruntime.InferenceSession(onnx_proto, providers=eps, sess_options=session_options) - - -def _infer_ep_from_device(*args) -> Tuple[str, ...]: - """Return the first valid device (i.e., GPU or CPU) in argument list.""" - eps = [] - for arg in args: - if hasattr(arg, "device"): - device = arg.device - if device.type == "cuda": - eps.append("CUDAExecutionProvider") - elif device.type == "cpu": - eps.append("CPUExecutionProvider") - return tuple(eps) - - -def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]: - placeholders = [] - for node in graph_module.graph.nodes: - if node.op == "placeholder": - if hasattr(node, "meta") and "val" in node.meta: - assert isinstance(node.meta["val"], torch.Tensor) - placeholders.append(node) - - -def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any: - """Collect "val" fields from outputs metadata in this torch.fx.GraphModule.""" - for node in graph_module.graph.nodes: - if node.op == "output": - # Output node is unique. Let's retrieve output values from - # this node's input list. And then just return. - return node.args[0] - raise ValueError("No output node found in this torch.fx.GraphModule.") - - -def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]: - """Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.""" - flattened_output_args, _ = _pytree.tree_flatten(_extract_graph_module_outputs(graph_module)) - # Output arguments with example value (type: torch.Tensor) in the `graph_module`. - selected_output_args = [ - output_arg.meta["val"] - for output_arg in flattened_output_args - # output_arg must have tensor for its device information. - # Otherwise, skip it. - if (hasattr(output_arg, "meta") and "val" in output_arg.meta) - ] - return _infer_ep_from_device(*selected_output_args) - - -def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]: - """Sort execution providers in eps based on pre-set priority.""" - - def get_execution_provider_priority(ep: str) -> int: - if ep == "CPUExecutionProvider": - # Lowest priority. - return 2 - if ep == "CUDAExecutionProvider": - # Higher priority than CPU but lower than - # other specialized EPs. - return 1 - # Highest priority. - return 0 - - unique_eps = set(eps) - return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True)) - - -def _get_onnx_devices(values: Tuple[torch.Tensor, ...]) -> Tuple[ORTC.OrtDevice, ...]: # type: ignore - assert all(value.device == values[0].device for value in values), "All values must be on the same device." - - def _device_id_or_zero(device_id: int) -> int: - return device_id or 0 - - devices: Tuple[ORTC.OrtDevice, ...] 
= tuple( # type: ignore - ORTC.OrtDevice( # type: ignore - _get_ort_device_type(value.device.type), - ORTC.OrtDevice.default_memory(), # type: ignore - _device_id_or_zero(value.device.index), - ) - for value in values - ) - return devices - - -def _get_ortvalues_from_torch_tensors( - tensors: Tuple[torch.Tensor, ...], devices: Tuple[ORTC.OrtDevice, ...] -) -> Tuple[torch.Tensor, ...]: - ortvalues = ORTC.OrtValueVector() # type: ignore - ortvalues.reserve(len(tensors)) - dtypes = [] - shapes = [] - data_ptrs = [] - - for tensor in tensors: - dtypes.append(_NP_DTYPE[tensor.dtype]) - shapes.append(tensor.size()) - data_ptrs.append(tensor.data_ptr()) - ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices) - return ortvalues - - -def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor: - if tensor.is_sparse: - raise ValueError("sparse tensor is not yet supported.") - out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device) - return out - - -def _run_onnx_session_with_ortvaluevector( - sess: onnxruntime.InferenceSession, - input_names: Tuple[str, ...], - inputs: Tuple[torch.Tensor, ...], - input_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore - output_names: Tuple[str, ...], - outputs: Tuple[torch.Tensor, ...], - output_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore - preallocate_output: bool, -) -> Tuple[torch.Tensor, ...]: - _nvtx_range_push("contiguous") - inputs = tuple(a.contiguous() for a in inputs) - _nvtx_range_pop() - - _nvtx_range_push("push_back_batch") - - ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices) - - # preallocate output pytorch Tensors and use the buffers affined to the torch device for the output ortvalue. - # Because the output ortvalue is not allocated and owned by ort, it does not need to convert the output ortvalue - # to torch Tensor transferring the ownership. - if preallocate_output: - pth_outputs = tuple(map(lambda t: _to_real_tensor(t) if isinstance(t, FakeTensor) else t, outputs)) - ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices) - else: - ort_outputs = ORTC.OrtValueVector() # type: ignore - _nvtx_range_pop() - - _nvtx_range_push("run_with_ortvaluevector") - run_options = onnxruntime.RunOptions() - run_options.add_run_config_entry("disable_synchronize_execution_providers", "1") - sess.run_with_ortvaluevector(run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices) - _nvtx_range_pop() - - if preallocate_output: - return pth_outputs - else: - _nvtx_range_push("after run_with_ortvaluevector") - pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(ort_outputs) # type: ignore - _nvtx_range_pop() - return pth_outputs - - -def _assert_allclose_with_detailed_error_message( - actual: torch.Tensor, expected: torch.Tensor, rtol: float = 1e-03, atol: float = 1e-04 -): - diff = actual - expected - real_atol = torch.max(torch.abs(diff)) - max_value = torch.max(torch.abs(actual), torch.abs(expected)) - max_value[max_value == 0.0] = 1.0 - real_rtol = torch.max(diff / max_value) - allclose = bool(real_atol <= atol or real_rtol <= rtol) - if not allclose: - raise RuntimeError( - "ONNX output doesn't match baseline output with " - f"actual rtol={real_rtol} and actual atol={real_atol} " - f"but expected rtol={rtol} and expected atol={atol}." 
- ) - - -class OrtExecutionInfoPerSession: - """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession""" - - def __init__( - self, - session: onnxruntime.InferenceSession, - input_names: Tuple[str, ...], - input_value_infos: Tuple[onnx.ValueInfoProto, ...], - output_names: Tuple[str, ...], - output_value_infos: Tuple[onnx.ValueInfoProto, ...], - input_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore - output_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore - example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor], - ): - # Carrier of ONNX model and its executor. - self.session: onnxruntime.InferenceSession = session - # For the ONNX model stored in self.session, self.input_names[i] is the - # name of the i-th positional input. - self.input_names: Tuple[str, ...] = input_names - # self.input_name[i]'s type information is stored in self.input_value_infos[i]. - self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos - # Similar to self.input_names, but for outputs. - self.output_names: Tuple[str, ...] = output_names - # Similar to self.input_value_infos but for outputs. - self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos - # For the ONNX model stored in self.session, self.input_devices[i] is the - # i-th positional input's device. - self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices # type: ignore - # Similar to self.input_devices, but for outputs. - self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices # type: ignore - # This is the outputs of executing the original torch.fx.GraphModule with example inputs - # (i.e., args passed into OrtBackend._ort_acclerated_call). - self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = example_outputs - - def is_supported(self, *args): - # Compare the args and the input schema in ONNX model and - # return the first match. - if len(args) != len(self.input_value_infos): - return False - for arg, value_info in zip(args, self.input_value_infos): - if not isinstance(arg, torch.Tensor): - return False - onnx_dtype = _TORCH_DTYPE_TO_ONNX_ELEMENT_TYPE[arg.dtype] - if onnx_dtype != value_info.type.tensor_type.elem_type: - return False - for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim): - if isinstance(dim, int) and (onnx_dim.dim_value == dim or onnx_dim.dim_param): - continue - elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param: - continue - else: - return False - return True - - -@dataclasses.dataclass -class OrtExecutionInfoForAllGraphModules: - def __init__(self): - # All sessions (and their related information) created by exporting the same GraphModule - # with different inputs. - self.execution_info_per_graph_module: Dict[torch.fx.GraphModule, List[OrtExecutionInfoPerSession]] = {} - - def search_reusable_session_execution_info(self, graph_module: torch.fx.GraphModule, *args): - if graph_module not in self.execution_info_per_graph_module: - return None - # All execution information for ONNX models exported from the same `graph_module` - # with different inputs. - candidates = self.execution_info_per_graph_module[graph_module] - - for candidate in candidates: - if candidate.is_supported(*args): - # Returns the first session that accepts this input schema. - return candidate - # No reusable session found. 
- return None - - def cache_session_execution_info(self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession): - if graph_module not in self.execution_info_per_graph_module: - self.execution_info_per_graph_module[graph_module] = [info] - else: - self.execution_info_per_graph_module[graph_module].append(info) - - -class OrtBackend: - """A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls. - - The compiler entry point is OrtBackend.compile, which - 1. partitions the original graph into supported sub-graphs (type: torch.fx.GrpahModule) and unsupported - sub-graphs. - 2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call. - 3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph. - """ - - def __init__( - self, - ep: str = "CPUExecutionProvider", - preallocate_output: bool = False, - session_options=None, - onnx_exporter_options: Optional["torch.onnx.ExportOptions"] = None, - ): - # onnx_exporter_options contains information shared between exporter and DORT. - # For example, they should use the same decomposition table when - # 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py) - # 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model - # (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below). - if onnx_exporter_options is None: - onnx_exporter_options = torch.onnx.ExportOptions() - # Convert user-facing option to internal option used by ONNX exporter - # to access required information. - # Some useful fields: - # - Decomposition table for decomposing FX operators in exporter is - # self.resolved_onnx_exporter_options.decomposition_table. - # - self.resolved_onnx_exporter_options.onnx_registry records what - # aten/prim ops are supported by exporter and their exporters (type: callable). - self.resolved_onnx_exporter_options = torch.onnx._internal.exporter.ResolvedExportOptions(onnx_exporter_options) - - # TODO(wechi): This line must generate result identical to the call of - # _create_onnx_supports_op_overload_table(...) inside - # create_onnx_friendly_decomposition_table(...) in - # torch/onnx/_internal/fx/decomposition_table.py. - support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( - # This is identical to self.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry. - self.resolved_onnx_exporter_options.onnx_registry - ) # type: ignore - - extra_support_dict: Dict[str, Any] = { - "getattr": None, - "_operator.getitem": None, - } - - self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict) - # TODO: this is a naive implementation of cache without proper guard - self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {} - # Conceptually, this filed is a 2-layer dictionary - # GraphModule 0 - # ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) - # ONNX Model 1 - # ... - # GraphModule 1 - # ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) - # ONNX Model 3 - # ... - # ... - # , which caches all previous compilation result so that we can reuse them. - # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs - # (e.g., tensors with different ranks). 
GraphModule 0 and GraphModule 1 are different - # graphs captured by Dynamo and sent to OrtBackend.compile. - self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules() - - self._assert_allclose_to_baseline = False - - self.ep = ep - self.session_options = session_options - - # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession - # in order to avoid internal allocation of output buffers in InferenceSession. - # If output ortvalue returned from InferenceSession is allocated internally, - # it needs to be converted to torch Tensor for return, and the torch Tensor should hold the ownership. - # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor - # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device. - # It can be avoided by allowing for preallocation for output buffers allocated by a custom aten allocator, - # and use the preallocated output buffers for InferenceSession not holding any ownership for them. - self.preallocate_output = preallocate_output - - def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs): - cached_execution_info_per_session = self._all_ort_execution_info.search_reusable_session_execution_info( - graph_module, *args - ) - if cached_execution_info_per_session: - onnx_session = cached_execution_info_per_session.session - input_names = cached_execution_info_per_session.input_names - output_names = cached_execution_info_per_session.output_names - input_devices = cached_execution_info_per_session.input_devices - output_devices = cached_execution_info_per_session.output_devices - prim_outputs = cached_execution_info_per_session.example_outputs - else: - # It's first time seeing such as graph. Let's make a new session - # (type: onnxruntime.InferenceSession) for it. - - # TODO(wechi): this is a workaround for pytorch/pytorch#84311. - _move_placeholder_to_front(graph_module) - # Generate reference outputs. They are used to indicate output - # tensors' types and devices when calling ORT. - # - # WARNING: The downstream code should not change prim_outputs and - # this backend should always produces output with schema identical to prim_outputs'. - - if self.resolved_onnx_exporter_options.dynamic_shapes: - # No pre-allocation when dynamic shape is enabled. - self.preallocate_output = False - extracted_outputs = _extract_graph_module_outputs(graph_module) - - def maybe_map_to_meta_val(value): - if hasattr(value, "meta") and "val" in value.meta: - # Select outputs with "val" information. Without "val", - # it's not possible access output_arg.meta["val"].device. - return value.meta["val"] - else: - return value - - prim_outputs = _pytree.tree_map(maybe_map_to_meta_val, extracted_outputs) - else: - try: - prim_outputs = FakeTensorProp(graph_module).propagate(*args, **kwargs) - except Exception: - logger.info(f"FakeTensorProb failed for {graph_module}") - # When FakeTensorProp fails, it is not possible to preallocate output buffers - # because the output shapes are not inferred. - self.preallocate_output = False - - # rethrow FakeTensorProb failure because it is not yet currently handled. 
- raise - - graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion( - self.resolved_onnx_exporter_options.diagnostic_context, graph_module - ).run() - - from torch.onnx._internal.fx import fx_onnx_interpreter - - # Create the object to iterate through the nodes in graph one-by-one - # and calls the corresponding ONNX exporter for each node. - fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter( - diagnostic_context=self.resolved_onnx_exporter_options.diagnostic_context - ) - # Start the per-node exporting process. It's conceptually a for loop - # scanning through the nodes in the graph. - exported = fx_interpreter.run( - fx_graph_module=graph_module, - onnxfunction_dispatcher=self.resolved_onnx_exporter_options.onnxfunction_dispatcher, - op_level_debug=self.resolved_onnx_exporter_options.op_level_debug, - ) - # Convert the exported result to ONNX ModelProto. - onnx_proto = exported.to_model_proto( - opset_version=self.resolved_onnx_exporter_options.onnx_registry.opset_version - ).SerializeToString() - - # Initialize a ORT session to execute this ONNX model. - # Note that TorchDynamo assumes all inputs/outputs are on the - # same device, but it's subject to change (very likely with - # dynamic shape support), so we add execution providers - # based on the all inputs/outputs plus a default OrtBackend.ep. - eps_from_args = _infer_ep_from_device(args) - eps_from_graph_module = _infer_ep_from_graph_module(graph_module) - if eps_from_args: - # If user feeds CUDA tensor as input argument, - # we want to use CUDA EP. - # Thus, `eps_from_args` (deduced from input arguments) - # has highest priority. - selected_eps = _sort_eps((*eps_from_args, self.ep)) - elif eps_from_graph_module: - # If there is no EP in input arguments, we deduce EP from - # graph_module's outputs. Those outputs may come from - # FakeTensorProp or Dynamo's built-in symbolic shape inference. - selected_eps = _sort_eps((*eps_from_graph_module, self.ep)) - else: - # No EP found in inputs and outputs, let's use default. - selected_eps = (self.ep,) - - onnx_session = _create_onnx_session(onnx_proto, selected_eps, self.session_options) - # Cache ORT session. It's reused for the same "graph_module". - # Generate ONNX model and extract its input and output names. - onnx_model = _create_onnx_model(onnx_proto) - # TODO(wechi): ORT session should provide a API to extract - # input and output names from the underlying model. - input_names = tuple(input.name for input in onnx_model.graph.input) - output_names = tuple(output.name for output in onnx_model.graph.output) - input_devices = _get_onnx_devices(args) - # Cache devices for inputs and outputs. They are used to invoke - # ORT session. 
Output devices indicate where (e.g., GPU or CPU) - # to store outputs - if isinstance(prim_outputs, tuple): - output_devices = _get_onnx_devices(prim_outputs) - else: - output_devices = _get_onnx_devices((prim_outputs,)) - - execution_info_per_session = OrtExecutionInfoPerSession( - session=onnx_session, - input_names=input_names, - input_value_infos=tuple(input for input in onnx_model.graph.input), - output_names=output_names, - output_value_infos=tuple(output for output in onnx_model.graph.output), - input_devices=input_devices, - output_devices=output_devices, - example_outputs=prim_outputs, - ) - - self._all_ort_execution_info.cache_session_execution_info(graph_module, execution_info_per_session) - - if isinstance(prim_outputs, tuple): - assert all(isinstance(elem, torch.Tensor) for elem in prim_outputs) - # ORT always returns a tuple of outputs. If the original is a tuple, just returning - # ORT output is ok. - _nvtx_range_push("run_onnx_session_with_ortvaluevector") - onnx_outputs = _run_onnx_session_with_ortvaluevector( - onnx_session, - input_names, - args, - input_devices, - output_names, - prim_outputs, - output_devices, - self.preallocate_output, - ) - _nvtx_range_pop() - if self._assert_allclose_to_baseline: - # Compute baseline. - baseline_outputs = torch._prims.executor.execute(graph_module, *args, executor="aten") - # Ensure every output tensor is close to the corresponding baseline. - for onnx_output, baseline_output in zip(onnx_outputs, baseline_outputs): - _assert_allclose_with_detailed_error_message(onnx_output, baseline_output) - return onnx_outputs - else: - assert isinstance(prim_outputs, torch.Tensor) - # ORT always returns a tuple of outputs. If the original output is a tensor, - # ORT output's first element must be extracted and returned. Otherwise, type - # mismatch may happen in downstream computation. - onnx_outputs = _run_onnx_session_with_ortvaluevector( - onnx_session, - input_names, - args, - input_devices, - output_names, - (prim_outputs,), - output_devices, - self.preallocate_output, - ) - assert len(onnx_outputs) == 1 - if self._assert_allclose_to_baseline: - # Compute baseline. - baseline_outputs = torch._prims.executor.execute(graph_module, *args, executor="aten") - # Ensure output tensor is close to the corresponding baseline. - _assert_allclose_with_detailed_error_message(onnx_outputs[0], baseline_outputs) - return onnx_outputs[0] - - def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: - # FX graph based partitioning based on ONNX supported ops. - if graph_module in self._partitioner_cache: - partitioned_prim_graph_module = self._partitioner_cache[graph_module] - else: - prim_graph_module = graph_module - # TODO(wechi): this is required for removing aten::_to_copy in _replace_to_copy_with_to. - _replace_to_copy_with_to(prim_graph_module) - partitioner = CapabilityBasedPartitioner( - prim_graph_module, self._supported_ops, allows_single_node_partition=True - ) - partitioned_prim_graph_module = partitioner.partition_and_fuse() - self._partitioner_cache[graph_module] = partitioned_prim_graph_module - - # Overriding fused_module's __call__() function with ort_acclerated_call() - # This loop goes through all graph partitions (each of them is an ONNX-representable graph) - # and override their _wrappped_call function with _ort_accelerated_call. - # Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT. 
- for node in partitioned_prim_graph_module.graph.nodes: - # TODO: use a better way to identify fused submodule - if node.op == "call_module" and "fused_" in node.name: - fused_module = getattr(partitioned_prim_graph_module, node.name) - # self.ort_acclerated_call is responsible for exporting graph to ONNX, - # creating ORT session, and running ORT session. - fused_module._wrapped_call = self._ort_acclerated_call - - return partitioned_prim_graph_module - - def __call__(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: - return self.compile(graph_module, args) diff --git a/orttraining/orttraining/python/training/torchdynamo/register_backend.py b/orttraining/orttraining/python/training/torchdynamo/register_backend.py deleted file mode 100644 index 3a49e85ab836d..0000000000000 --- a/orttraining/orttraining/python/training/torchdynamo/register_backend.py +++ /dev/null @@ -1,89 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from functorch.compile import min_cut_rematerialization_partition -from torch._dynamo.backends.common import aot_autograd -from torch.onnx._internal.exporter import ExportOptions - -from .ort_backend import OrtBackend - - -def make_aot_ort(dynamic: bool = True): - """Wrap OrtBackend as PyTorch's AOT compiler. - - Example usages: - import torch - from onnxruntime.training.torchdynamo.register_backend import make_aot_ort - use_dynamic = True - local_aot_ort, _ = make_aot_ort(dynamic = use_dynamic) - - @torch._dynamo.optimize(local_aot_ort, dynamic=use_dynamic) - def foo(x: torch.Tensor): - return torch.sigmoid(x) - - x = torch.rand(2, 2, dtype=torch.float) - torch.testing.assert_close(torch.sigmoid(x), foo(x)) - """ - ort_backend = OrtBackend(onnx_exporter_options=ExportOptions(dynamic_shapes=dynamic)) - return ( - aot_autograd( - fw_compiler=ort_backend, - partition_fn=min_cut_rematerialization_partition, - decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table, - ), - ort_backend, - ) - - -# Wrap ORT as a compiler in Dynamo for training (i.e., when .backward is called). -# -# Under the hood, OrtBackend.compile is called inside functorch. See aot_function -# and aot_module in aot_autograd.py in PyTorch repo for more details. Basically, -# OrtBackend.compile is mapped to forward graph compiler, fw_compile, and backward -# graph compiler, bw_compile, in aot_autograd.py. -# -# Example usage: -# import torch -# from onnxruntime.training.torchdynamo.register_backend import aot_ort -# model = torch.nn.Linear(2, 2) -# compiled_model = torch._dynamo.optimize(aot_ort)(model) -# result = compiled_model(torch.rand(2, 2, dtype=torch.float) -# result.sum().backward() -# -# DEFAULT_BACKEND should be the underlying compiler for ALL graphs if -# the user uses ORT to accelerate PyTorch via Dynamo. -# By using a global compiler for all graphs, cached compilation -# results can be reused when encountering the identical graphs. -aot_ort, DEFAULT_BACKEND = make_aot_ort(dynamic=False) - -# Similar to aot_ort but should be used with -# torch._dynamo.optimize(dynamic_aot_ort, dynamic=True) -# to enable dynamic shapes in ONNX graph. -# -# Similar to DEFAULT_BACKEND but DEFAULT_DYNAMIC_BACKEND enables dynamic shapes -# when exporting FX graph to ONNX. 
-# Note that this backend must be used with -# torch._dynamo.optimize(DEFAULT_DYNAMIC_BACKEND, dynamic=True) -# Without `dynamic=True`, the FX graph only contains static shapes, and results ONNX graph -# with static shapes. -dynamic_aot_ort, DEFAULT_DYNAMIC_BACKEND = make_aot_ort(dynamic=True) - -# Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called). -# -# ort is usually faster than aot_ort for inference because the graphs generated by aot_autograd -# mechanism are very different than the original graphs. Therefore, some ORT's graph transformers -# are not applicable. -# -# Example usage: -# import torch -# from onnxruntime.training.torchdynamo.register_backend import ort -# model = torch.nn.Linear(2, 2) -# compiled_model = torch._dynamo.optimize(ort)(model) -ort = DEFAULT_BACKEND - -# Similar to ort but should be used with -# torch._dynamo.optimize(dynamic_ort, dynamic=True) -# to enable dynamic shapes in ONNX graph. -dynamic_ort = DEFAULT_DYNAMIC_BACKEND diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index 2a7012787be6e..f0b6b9c5fba28 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -8,9 +8,22 @@ import torch.onnx._internal.exporter from torch import nn from torch.nn import functional as F +from torch.onnx import ExportOptions +from torch.onnx import _OrtBackend as OrtBackend +from torch.onnx import _OrtBackendOptions as OrtBackendOptions from torch.utils import _pytree -from onnxruntime.training.torchdynamo.register_backend import aot_ort, dynamic_aot_ort, make_aot_ort, ort + +def make_local_backend(dynamic: bool = False, use_aot_autograd: bool = False): + ort_backend = OrtBackend( + options=OrtBackendOptions( + export_options=ExportOptions( + dynamic_shapes=dynamic, + ), + use_aot_autograd=use_aot_autograd, + ) + ) + return ort_backend class TestTorchDynamoOrt(unittest.TestCase): @@ -35,9 +48,7 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.sigmoid() return tensor_q - @torch._dynamo.optimize(aot_ort) - def optimized_elementwise_model(tensor_x: torch.Tensor): - return elementwise_model(tensor_x) + optimized_elementwise_model = torch.compile(elementwise_model, backend="onnxrt", dynamic=True) def run(fun, list_x): tensor_x = torch.tensor(list_x, dtype=torch.float32).requires_grad_() @@ -77,9 +88,7 @@ def elementwise_model(tensor_x: torch.Tensor): # With dynamic_shape=True, Dynamo sends FX graphs with dynamic # shapes (e.g., batch size is a symbol "batch" instead of a fixed # number) to OrtBackend.compile(...). - @torch._dynamo.optimize(dynamic_aot_ort, dynamic=True) - def optimized_elementwise_model(tensor_x: torch.Tensor): - return elementwise_model(tensor_x) + optimized_elementwise_model = torch.compile(elementwise_model, backend="onnxrt", dynamic=True) def run(fun, seed: torch.Tensor): tensor_x = seed.detach().clone().requires_grad_() @@ -125,8 +134,8 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.sigmoid() return (tensor_q, (tensor_y, tensor_z)) - local_aot_ort, ort_backend = make_aot_ort(dynamic=True) - cached = ort_backend._all_ort_execution_info.execution_info_per_graph_module + local_backend = make_local_backend(dynamic=True, use_aot_autograd=True) + cached = local_backend._all_ort_execution_info.execution_info_per_graph_module # Before compilation, no graph is generated. 
assert len(cached) == 0 @@ -135,7 +144,7 @@ def elementwise_model(tensor_x: torch.Tensor): # With dynamic_shape=True, Dynamo sends FX graphs with dynamic # shapes (e.g., batch size is a symbol "batch" instead of a fixed # number) to OrtBackend.compile(...). - @torch._dynamo.optimize(local_aot_ort, dynamic=True) + @torch._dynamo.optimize(local_backend, dynamic=True) def optimized_elementwise_model(tensor_x: torch.Tensor): return elementwise_model(tensor_x) @@ -207,9 +216,8 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.relu() return tensor_q - @torch._dynamo.optimize(ort) - def optimized_elementwise_model(tensor_x: torch.Tensor): - return elementwise_model(tensor_x) + local_backend = make_local_backend(dynamic=True, use_aot_autograd=False) + optimized_elementwise_model = torch.compile(elementwise_model, backend=local_backend, dynamic=True) def run(fun, list_x): tensor_x = torch.tensor(list_x, dtype=torch.float32).requires_grad_() @@ -237,9 +245,7 @@ def copy_copy_copy(tensor_x: torch.Tensor): ) return tensor_x1, tensor_x2, tensor_x3 - @torch._dynamo.optimize(aot_ort) - def optimized_copy_copy_copy(tensor_x: torch.Tensor): - return copy_copy_copy(tensor_x) + optimized_copy_copy_copy = torch.compile(copy_copy_copy, backend="onnxrt") def run(fun, list_x): tensor_x = torch.tensor(list_x, dtype=torch.float32) @@ -265,7 +271,7 @@ def run_no_input_model(): def no_input_model(): return torch.ops.aten.full([2, 3], 1.5) - @torch._dynamo.optimize(aot_ort) + @torch._dynamo.optimize("onnxrt") def optimized_no_input_model(): return no_input_model() @@ -291,9 +297,7 @@ def run_no_input_model(): def no_input_model(): return torch.ops.aten.full([2, 3], 1.5, device="cpu") - @torch._dynamo.optimize(aot_ort) - def optimized_no_input_model(): - return no_input_model() + optimized_no_input_model = torch.compile(no_input_model, backend="onnxrt") def run(fun): tensor_x = fun() @@ -355,7 +359,8 @@ def run(model, tensor_x, tensor_y): # Baseline. loss, grads = run(model, tensor_x, tensor_y) # ORT result. - compiled_model = torch._dynamo.optimize(aot_ort)(model) + local_backend = make_local_backend(dynamic=False, use_aot_autograd=True) + compiled_model = torch.compile(model, backend=local_backend, dynamic=False) loss_new, grads_new = run(compiled_model, tensor_x, tensor_y) print(f"MNIST loss: {loss} (pytorch), {loss_new} (ort).") diff --git a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py index c2a6ed504a206..dfc62dba427e5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py @@ -11,9 +11,10 @@ from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd from torch.library import Library +from torch.onnx import _OrtBackend as OrtBackend +from torch.onnx import _OrtBackendOptions as OrtBackendOptions import onnxruntime -from onnxruntime.training.torchdynamo.ort_backend import OrtBackend # Dummy operator set to map aten::mul.Tensor to test.customop::CustomOpOne # in ONNX model executed by DORT. @@ -112,16 +113,18 @@ def test_export_aten_mul_as_onnx_custom_op_and_run_ort(self): # In order to use custom exporting function inside PyTorch-to-ONNX exporter used in DORT, create executor of ONNX model with custom `onnx_registry`. 
ort_backend = OrtBackend( - ep="CPUExecutionProvider", - session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(), - onnx_exporter_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry), + OrtBackendOptions( + preferred_execution_providers="CPUExecutionProvider", + ort_session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(), + export_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry), + ) ) # Wrap ORT executor as a Dynamo backend. aot_ort = aot_autograd( fw_compiler=ort_backend, partition_fn=min_cut_rematerialization_partition, - decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table, + decompositions=ort_backend._resolved_onnx_exporter_options.decomposition_table, ) def one_mul(tensor_x: torch.Tensor, tensor_y: torch.Tensor): @@ -169,19 +172,22 @@ def bar_impl(self: torch.Tensor) -> torch.Tensor: # Create executor of ONNX model. ort_backend = OrtBackend( - ep="CPUExecutionProvider", - session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(), - onnx_exporter_options=torch.onnx.ExportOptions(onnx_registry=onnx_registry), + OrtBackendOptions( + preferred_execution_providers="CPUExecutionProvider", + ort_session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(), + export_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry), + ) ) + # Allow torch.ops.foo.bar.default to be sent to DORT. # _support_dict tells Dynamo which ops to sent to DORT. - ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default) + ort_backend._supported_ops._support_dict[torch.ops.foo.bar.default] = None # Wrap ORT executor as a Dynamo backend. aot_ort = aot_autograd( fw_compiler=ort_backend, partition_fn=min_cut_rematerialization_partition, - decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table, + decompositions=ort_backend._resolved_onnx_exporter_options.decomposition_table, ) def one_foo(tensor_x: torch.Tensor): diff --git a/setup.py b/setup.py index 0c2eb19e82c87..685f0612e3762 100644 --- a/setup.py +++ b/setup.py @@ -464,7 +464,6 @@ def finalize_options(self): "onnxruntime.training.experimental", "onnxruntime.training.experimental.gradient_graph", "onnxruntime.training.optim", - "onnxruntime.training.torchdynamo", "onnxruntime.training.ortmodule", "onnxruntime.training.ortmodule.experimental", "onnxruntime.training.ortmodule.experimental.json_config", From 02b1ff5fa2c41dc026022ca29c9249628f71f026 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 4 Jan 2024 13:32:48 -0800 Subject: [PATCH 27/45] [QNN EP] Support multithreaded inference of a single session (#18981) ### Description - Add mutex to protect QNN API calls for executing a graph and extracting the corresponding profile data. - Ensures QNN EP's execute function does not store unnecessary state (i.e., input and output buffer pointers do not need to be stored as class members.) ### Motivation and Context Allow calling `session.Run()` from multiple threads when using QNN EP. 
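For reference, below is a minimal standalone sketch of the locking pattern this change relies on. It is not the EP code itself: `std::mutex` stands in for `OrtMutex`, and `backend_execute` is a hypothetical stand-in for the non-thread-safe `qnn_interface.graphExecute` call. The point it illustrates is that all per-call input/output buffers stay local to `Execute()` instead of being stored as class members, and only the backend call (plus profile extraction in the real patch) is serialized.

```cpp
// Standalone sketch only: per-call state stays local and the single
// non-thread-safe backend call is serialized with a mutex.
#include <cstddef>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Stand-in for a non-thread-safe backend execute call (hypothetical, not the QNN API).
static int backend_execute(const float* in, std::size_t n, float* out) {
  for (std::size_t i = 0; i < n; ++i) out[i] = in[i] * 2.0f;
  return 0;
}

class GraphExecutor {
 public:
  // Safe to call from multiple threads: buffers are owned by the caller,
  // and only the backend call itself runs under exec_mutex_.
  int Execute(const std::vector<float>& inputs, std::vector<float>& outputs) {
    outputs.resize(inputs.size());                  // per-call output buffer, no class members
    std::lock_guard<std::mutex> lock(exec_mutex_);  // analogous to graph_exec_mutex_ in this patch
    return backend_execute(inputs.data(), inputs.size(), outputs.data());
  }

 private:
  std::mutex exec_mutex_;
};

int main() {
  GraphExecutor executor;
  std::vector<std::thread> threads;
  for (int t = 0; t < 4; ++t) {
    threads.emplace_back([&executor, t]() {
      std::vector<float> in(8, static_cast<float>(t));
      std::vector<float> out;
      executor.Execute(in, out);
    });
  }
  for (auto& th : threads) th.join();
  std::cout << "all threads finished\n";
  return 0;
}
```

With this structure, concurrent `session.Run()` calls only contend on the short critical section around the backend call; input/output binding work proceeds in parallel on each thread.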
--- .../core/providers/qnn/builder/qnn_def.cc | 9 + .../core/providers/qnn/builder/qnn_def.h | 1 + .../core/providers/qnn/builder/qnn_model.cc | 107 ++++++---- .../core/providers/qnn/builder/qnn_model.h | 19 +- .../test/providers/qnn/qnn_basic_test.cc | 194 +++++++++++++++++- .../azure-pipelines/linux-qnn-ci-pipeline.yml | 8 +- .../win-qnn-arm64-ci-pipeline.yml | 6 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 4 +- 8 files changed, 292 insertions(+), 56 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index a77ac16cf624b..55e72670a6971 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -89,6 +89,15 @@ void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector } } +void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.clientBuf.data = buf_data; + qnn_tensor.v1.clientBuf.dataSize = buf_size; + } else { + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); + } +} + void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size) { if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { qnn_tensor.v1.clientBuf.dataSize = client_buf_size; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index f6a3b1bd360ec..c202f2bf79c57 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -100,6 +100,7 @@ void SetQnnTensorDim(Qnn_Tensor_t& qnn_tensor, const std::vector& dime void SetQnnTensorMemType(Qnn_Tensor_t& qnn_tensor, Qnn_TensorMemType_t mem_type); void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf); void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf); +void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size); void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size); void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data); void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index fd3a95b5f1f78..869d9326d9232 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -166,14 +166,14 @@ Status QnnModel::FinalizeGraphs() { Status QnnModel::SetupQnnInputOutput() { LOGS(logger_, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name(); - auto result = SetupTensors(qnn_inputs_, graph_info_->InputTensors()); + auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name(); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!"); } - result = SetupTensors(qnn_outputs_, graph_info_->OutputTensors(), false); + result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name(); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!"); @@ -186,8 +186,8 @@ 
Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { LOGS(logger_, VERBOSE) << "QnnModel::ExecuteGraphs"; const size_t num_inputs = context.GetInputCount(); const size_t num_outputs = context.GetOutputCount(); - ORT_RETURN_IF_NOT(qnn_inputs_.size() <= num_inputs, "Inconsistent input sizes"); - ORT_RETURN_IF_NOT(qnn_outputs_.size() == num_outputs, "Inconsistent output sizes"); + ORT_RETURN_IF_NOT(qnn_input_infos_.size() <= num_inputs, "Inconsistent input sizes"); + ORT_RETURN_IF_NOT(qnn_output_infos_.size() == num_outputs, "Inconsistent output sizes"); using namespace qnn::utils; auto TensorDataSize = [&](auto ort_tensor) -> size_t { @@ -198,49 +198,67 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { return element_size * length; }; - for (auto& qnn_input_tensor : qnn_inputs_) { - const std::string& model_input_name(GetQnnTensorName(qnn_input_tensor)); - auto index = GetOrtInputIndex(model_input_name); - LOGS(logger_, VERBOSE) << "model_input = " << model_input_name << " index = " << index; - auto ort_input_tensor = context.GetInput(index); - auto qnn_tensor_size = GetQnnTensorClientBuf(qnn_input_tensor).dataSize; + std::vector qnn_inputs; + qnn_inputs.reserve(qnn_input_infos_.size()); + + for (const auto& qnn_input_info : qnn_input_infos_) { + LOGS(logger_, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName() + << " index = " << qnn_input_info.ort_index; + auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index); auto ort_tensor_size = TensorDataSize(ort_input_tensor); - LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_tensor_size << "Ort tensor size: " << ort_tensor_size; - ORT_ENFORCE(qnn_tensor_size == ort_tensor_size, + LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size + << "Ort tensor size: " << ort_tensor_size; + ORT_ENFORCE(qnn_input_info.tensor_byte_size == ort_tensor_size, "ORT Tensor data size does not match QNN tensor data size."); - SetQnnTensorClientBufData(qnn_input_tensor, - const_cast(ort_input_tensor.GetTensorData())); + + qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor()); + SetQnnTensorClientBuf(qnn_inputs.back(), + const_cast(ort_input_tensor.GetTensorData()), qnn_input_info.tensor_byte_size); } - for (auto& qnn_output_tensor : qnn_outputs_) { - const std::string& model_output_name(GetQnnTensorName(qnn_output_tensor)); - auto index = GetOutputIndex(model_output_name); - LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << index; - const auto& output_info = GetOutputInfo(model_output_name); - const std::vector& output_shape = output_info->shape_; - auto output_tensor = context.GetOutput(index, output_shape.data(), output_shape.size()); - auto qnn_tensor_size = GetQnnTensorClientBuf(qnn_output_tensor).dataSize; - auto ort_tensor_size = TensorDataSize(output_tensor); - LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_tensor_size << "Ort tensor size: " << ort_tensor_size; - ORT_ENFORCE(qnn_tensor_size == ort_tensor_size, + std::vector qnn_outputs; + qnn_outputs.reserve(qnn_output_infos_.size()); + + for (auto& qnn_output_info : qnn_output_infos_) { + const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName(); + LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index; + const auto& ort_output_info = GetOutputInfo(model_output_name); + const std::vector& output_shape = ort_output_info->shape_; + auto ort_output_tensor = 
context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size()); + auto ort_tensor_size = TensorDataSize(ort_output_tensor); + LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size + << "Ort tensor size: " << ort_tensor_size; + ORT_ENFORCE(qnn_output_info.tensor_byte_size == ort_tensor_size, "ORT Tensor data size does not match QNN tensor data size"); - SetQnnTensorClientBufData(qnn_output_tensor, - const_cast(output_tensor.GetTensorData())); + + qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); + SetQnnTensorClientBuf(qnn_outputs.back(), + const_cast(ort_output_tensor.GetTensorData()), qnn_output_info.tensor_byte_size); } LOGS(logger_, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); auto qnn_interface = qnn_backend_manager_->GetQnnInterface(); auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR; - execute_status = qnn_interface.graphExecute(graph_info_->Graph(), - qnn_inputs_.data(), - static_cast(qnn_inputs_.size()), - qnn_outputs_.data(), - static_cast(qnn_outputs_.size()), - profile_backend_handle, - nullptr); - ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo()); + { + // Acquire mutex before calling graphExecute and profiling APIs to support calling session.Run() + // from multiple threads. + std::lock_guard lock(graph_exec_mutex_); + execute_status = qnn_interface.graphExecute(graph_info_->Graph(), + qnn_inputs.data(), + static_cast(qnn_inputs.size()), + qnn_outputs.data(), + static_cast(qnn_outputs.size()), + profile_backend_handle, + nullptr); + + // NOTE: This function returns immediately when profiling is disabled. + // Extracting profiling data can be expensive, but it is typically only enabled for debugging purposes + // and not in production. We can improve synchronization for event profiling if it becomes an issue. + ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo()); + } + if (QNN_GRAPH_NO_ERROR != execute_status) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN graph execute error. Error code: ", execute_status); } @@ -262,14 +280,13 @@ Status QnnModel::GetQnnTensorDataLength(const std::vector& dims, return Status::OK(); } -// Setup details for Qnn_Tensor_t for execution -// based on information in QnnTensorWrapper -Status QnnModel::SetupTensors(std::vector& qnn_tensors, +// Setup information for Qnn inputs/outputs used during execution. +Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, const std::vector& tensor_wrappers, bool is_input) { size_t tensor_count = tensor_wrappers.size(); ORT_RETURN_IF(0 == tensor_count, "Zero tensor size!"); - qnn_tensors.resize(tensor_count); + qnn_tensor_infos.resize(tensor_count); for (auto& tensor_wrapper : tensor_wrappers) { size_t length = 0; @@ -277,10 +294,14 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensors, ORT_RETURN_IF_ERROR(GetQnnTensorDataLength(tensor_wrapper.GetTensorDims(), tensor_wrapper.GetTensorDataType(), length)); - auto tensor_name = tensor_wrapper.GetName(); - auto index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name); - qnn_tensors[index] = tensor_wrapper.GetQnnTensor(); - SetQnnTensorClientBufSize(qnn_tensors[index], static_cast(length)); + const auto& tensor_name = tensor_wrapper.GetName(); + auto qnn_index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name); + auto ort_index = is_input ? 
GetOrtInputIndex(tensor_name) : qnn_index; + + QnnTensorInfo& qnn_tensor_info = qnn_tensor_infos[qnn_index]; + qnn_tensor_info.tensor_wrapper = &tensor_wrapper; + qnn_tensor_info.tensor_byte_size = static_cast(length); + qnn_tensor_info.ort_index = ort_index; } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index de4f872f73ccf..d0dd091cb1688 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -3,8 +3,11 @@ #pragma once +#include + #include "core/common/status.h" #include "core/graph/graph_viewer.h" +#include "core/platform/ort_mutex.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_backend_manager.h" @@ -14,6 +17,12 @@ namespace onnxruntime { namespace qnn { +struct QnnTensorInfo { + const QnnTensorWrapper* tensor_wrapper = nullptr; + uint32_t tensor_byte_size = 0; + size_t ort_index = 0; +}; + class QnnModel { public: QnnModel(const logging::Logger& logger, @@ -103,7 +112,8 @@ class QnnModel { Qnn_DataType_t data_type, size_t& data_length) const; - Status SetupTensors(std::vector& tensors, const std::vector& tensor_wrappers, bool is_input = true); + Status SetupTensors(std::vector& tensors, const std::vector& tensor_wrappers, + bool is_input = true); QnnBackendType GetQnnBackendType() { return qnn_backend_type_; } @@ -126,9 +136,12 @@ class QnnModel { std::vector output_names_; std::unordered_map inputs_info_; std::unordered_map outputs_info_; - std::vector qnn_inputs_; - std::vector qnn_outputs_; + std::vector qnn_input_infos_; + std::vector qnn_output_infos_; QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; + + // Mutex acquired during graph execution to support multi-threaded inference of a single session. + OrtMutex graph_exec_mutex_; }; } // namespace qnn diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 391d7bebc9589..f9064cad3fe12 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include +#include +#include #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -287,8 +288,199 @@ TEST_F(QnnCPUBackendTests, QnnSaver_OutputFiles) { EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin")); } +struct ModelAndBuilder { + ModelAndBuilder(Graph& graph) : builder(graph) {} + std::string model_data; + ModelTestBuilder builder; +}; + +// Creates a model in memory. Input feeds and output names can be accessed from result.builder. +static void CreateModelInMemory(std::unique_ptr& result, + const GetTestModelFn& model_build_fn, + const std::string& model_name, + int opset_version = 18) { + const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; + auto& logging_manager = DefaultLoggingManager(); + + // Create float model and serialize it to a string. 
+ onnxruntime::Model model(model_name, false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + result = std::make_unique(model.MainGraph()); + model_build_fn(result->builder); + result->builder.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + model.ToProto().SerializeToString(&result->model_data); +} + +// Runs a session and verifies the outputs. Can be run by individual threads. +static void RunSessionAndVerify(InferenceSession& session, const RunOptions& run_options, const NameMLValMap& feeds, + const std::vector& output_names, + const std::vector>& output_shapes, + const std::vector>& expected_values) { + std::vector fetches; + auto status = session.Run(run_options, feeds, output_names, &fetches); + ASSERT_TRUE(status.IsOK()); + + for (size_t i = 0; i < fetches.size(); i++) { + auto& tensor = fetches[i].Get(); + TensorShape expected_shape(output_shapes[i]); + ASSERT_EQ(expected_shape, tensor.Shape()); + + gsl::span actual = tensor.DataAsSpan(); + gsl::span expected(expected_values[i].data(), expected_values[i].size()); + ASSERT_EQ(expected, actual); + } +} + +// Returns a function that builds a float32 model that adds 3 tensors. +static GetTestModelFn F32BuildAdd3Tensors(const TestInputDef& input0_def, + const TestInputDef& input1_def, + const TestInputDef& input2_def) { + return [input0_def, input1_def, input2_def](ModelTestBuilder& builder) { + NodeArg* input0 = MakeTestInput(builder, input0_def); + NodeArg* input1 = MakeTestInput(builder, input1_def); + NodeArg* input2 = MakeTestInput(builder, input1_def); + + auto* add0_out = builder.MakeIntermediate(); + builder.AddNode("Add", {input0, input1}, {add0_out}); + + auto* output = builder.MakeOutput(); + builder.AddNode("Add", {add0_out, input2}, {output}); + }; +} + +// Tests running a single session in multiple threads on the CPU backend. 
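The multithreaded test below exercises the concurrency changes made to qnn_model.cc above: each ExecuteGraph call builds its own local Qnn_Tensor_t copies, and only graphExecute plus profiling extraction is serialized behind graph_exec_mutex_. A minimal, self-contained sketch of that pattern, using std::mutex in place of OrtMutex and hypothetical FakeGraph/RunOnce stand-ins rather than the real QNN types:

#include <mutex>
#include <thread>
#include <vector>

struct FakeGraph {        // hypothetical stand-in for the shared QNN graph handle
  int run_count = 0;
};

// Each call prepares its own buffers (mirroring the per-call Qnn_Tensor_t copies)
// and holds the mutex only around the shared graph execution.
void RunOnce(FakeGraph& graph, std::mutex& exec_mutex, const std::vector<float>& input) {
  std::vector<float> output(input.size());         // per-call output buffer
  {
    std::lock_guard<std::mutex> lock(exec_mutex);  // serialize the "graphExecute" step
    for (size_t i = 0; i < input.size(); ++i) output[i] = input[i] * 3.0f;
    ++graph.run_count;
  }
}

int main() {
  FakeGraph graph;
  std::mutex exec_mutex;
  std::vector<std::thread> threads;
  for (int i = 0; i < 5; ++i)
    threads.emplace_back(RunOnce, std::ref(graph), std::ref(exec_mutex),
                         std::vector<float>{1.0f, 2.0f, 3.0f});
  for (auto& t : threads) t.join();
  return graph.run_count == 5 ? 0 : 1;
}
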
+TEST_F(QnnCPUBackendTests, MultithreadSessionRun) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + F32BuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.f32"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnCpu.dll"; +#else + options["backend_path"] = "libQnnCpu.so"; +#endif + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + + for (int i = 0; i < num_threads; i++) { + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values)); + } + + for (auto& th : threads) { + th.join(); + } +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// Returns a function that builds a QDQ model that adds 3 tensors. Forces all scales and zero-points to be (1.0f, 0), +// so it is only accurate when using non-fractional positive inputs. +template +static GetTestModelFn QDQBuildAdd3Tensors(const TestInputDef& input0_def, + const TestInputDef& input1_def, + const TestInputDef& input2_def) { + return [input0_def, input1_def, input2_def](ModelTestBuilder& builder) { + NodeArg* input0 = MakeTestInput(builder, input0_def); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, 1.0f, 0); + NodeArg* input1 = MakeTestInput(builder, input1_def); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, 1.0f, 0); + NodeArg* input2 = MakeTestInput(builder, input1_def); + NodeArg* input2_after_qdq = AddQDQNodePair(builder, input2, 1.0f, 0); + + auto* add0_out = builder.MakeIntermediate(); + builder.AddNode("Add", {input0_after_qdq, input1_after_qdq}, {add0_out}); + + auto* add0_out_dq = AddQDQNodePair(builder, add0_out, 1.0f, 0); + + auto* add1_out = builder.MakeIntermediate(); + builder.AddNode("Add", {add0_out_dq, input2_after_qdq}, {add1_out}); + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, add1_out, 1.0f, 0); + }; +} + +// Tests running a single session in multiple threads on the HTP backend. 
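The HTP test below reuses the same multithreading setup with the QDQBuildAdd3Tensors helper above, which pins every scale to 1.0f and every zero point to 0, so the quantized Add chain is exact only for integer-valued inputs within the quantized type's range. A small sketch of that round trip (ordinary C++, not ORT code), assuming the standard uint8 affine quantization formula:

#include <algorithm>
#include <cmath>
#include <cstdio>

// q = clamp(round(x / scale) + zero_point, 0, 255); x' = (q - zero_point) * scale
float QdqRoundTrip(float x, float scale = 1.0f, int zero_point = 0) {
  int q = static_cast<int>(std::nearbyint(x / scale)) + zero_point;
  q = std::clamp(q, 0, 255);
  return (q - zero_point) * scale;
}

int main() {
  std::printf("3.0 -> %g\n", QdqRoundTrip(3.0f));  // 3 (exact, so the QDQ Add matches float)
  std::printf("2.5 -> %g\n", QdqRoundTrip(2.5f));  // 2 (rounded away, would not match)
  return 0;
}
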
+TEST_F(QnnHTPBackendTests, MultithreadSessionRun) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + + for (int i = 0; i < num_threads; i++) { + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values)); + } + + for (auto& th : threads) { + th.join(); + } +} + // Test shape inference of QDQ NHWC Resize operator (opset 18) that uses // the sizes input. Use the QNN HTP backend. TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) { diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 07e69ff496720..d286c4f3a46fe 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -86,7 +86,7 @@ jobs: inputs: script: | ./build/Release/onnx_test_runner -e qnn \ - -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \ + -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \ cmake/external/onnx/onnx/backend/test/data/node - task: CmdLine@2 @@ -94,7 +94,7 @@ jobs: inputs: script: | ./build/Release/onnx_test_runner -e qnn \ - -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \ + -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \ /data/float32_models - task: CmdLine@2 @@ -102,7 +102,7 @@ jobs: inputs: script: | ./build/Release/onnx_test_runner -e qnn \ - -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \ + -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \ /data/qdq_models - task: CmdLine@2 @@ -110,5 +110,5 @@ jobs: inputs: script: | ./build/Release/onnx_test_runner -e qnn \ - -v -f -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \ + -v -f -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \ /data/qdq_models/mobilenetv2-1.0_add_transpose_quant diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 5e35cbfed6692..6dc428d6606af 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ 
b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -84,17 +84,17 @@ jobs: displayName: 'Run unit tests' - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node + .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' displayName: 'Run ONNX Tests' - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" C:\data\float32_models + .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" C:\data\float32_models workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' displayName: 'Run float32 model tests' - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnHtp.dll" C:\data\qdq_models + .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnHtp.dll" C:\data\qdq_models workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' displayName: 'Run QDQ model tests' enabled: false diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 65b2924c8be60..fbec572fd346c 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -88,11 +88,11 @@ jobs: displayName: 'Run unit tests' - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node + .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' displayName: 'Run ONNX Tests' - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models + .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' displayName: 'Run float32 model tests' From e10a8ae31feba949b682f2451268c0dc68589ba3 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Thu, 4 Jan 2024 17:41:01 -0800 Subject: [PATCH 28/45] reduce max/min 20 (#17805) ### Description reducemax/min have been updated in onnx(20). 
implement it in ort ### Motivation and Context this is for ort1.17.0 release --------- Signed-off-by: Liqun Fu --- docs/OperatorKernels.md | 6 +- .../providers/cpu/cpu_execution_provider.cc | 100 +++-- .../cpu/reduction/reduction_kernel_base.h | 40 ++ .../providers/cpu/reduction/reduction_ops.cc | 101 ++++- .../providers/cpu/reduction/reduction_ops.h | 175 +++++--- .../providers/cuda/reduction/reduction_ops.h | 2 +- onnxruntime/test/onnx/TestCase.cc | 2 +- .../cpu/reduction/reduction_ops_test.cc | 398 +++++++++++++++++- .../onnx_backend_test_series_filters.jsonc | 55 ++- 9 files changed, 737 insertions(+), 142 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index e401baae2d803..f985cf10ded60 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -278,7 +278,8 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| -|ReduceMax|*in* data:**T**
<br/> *in* axes:**tensor(int64)**<br/> *out* reduced:**T**<br/><br/> or<br/><br/> *in* data:**T**<br/> *out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|ReduceMax|*in* data:**T**<br/> *in* axes:**tensor(int64)**<br/> *out* reduced:**T**<br/><br/> or<br/><br/> *in* data:**T**<br/>
*out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[18, 19]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| @@ -287,7 +288,8 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32)| -|ReduceMin|*in* data:**T**
<br/> *in* axes:**tensor(int64)**<br/> *out* reduced:**T**<br/><br/> or<br/><br/> *in* data:**T**<br/> *out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|ReduceMin|*in* data:**T**<br/> *in* axes:**tensor(int64)**<br/> *out* reduced:**T**<br/><br/> or<br/><br/> *in* data:**T**<br/>
*out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[18, 19]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 1390f60243174..f60c7ddac5c05 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -850,21 +850,21 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceLogSumExp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceMean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceMean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, 
ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceProd); @@ -960,6 +960,20 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, bool, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int64_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int8_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, uint8_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, bool, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int64_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, uint8_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, DFT); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample); @@ -2263,36 +2277,36 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ReduceLogSumExp)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h b/onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h new file mode 100644 index 0000000000000..5725e85f8e1e4 --- /dev/null +++ b/onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h @@ -0,0 +1,40 @@ +#ifndef CORE_PROVIDERS_CPU_REDUCTION_KERNEL_BASE_H +#define CORE_PROVIDERS_CPU_REDUCTION_KERNEL_BASE_H + +#ifndef SHARED_PROVIDER +#include "core/common/optional.h" +#include "core/framework/op_kernel.h" +#endif + +namespace onnxruntime { + +template +class ReduceKernelBase { + protected: + ReduceKernelBase(const OpKernelInfo& info, optional keepdims_override = {}) { + if (allow_multi_axes) { + axes_ = ToShapeVector(info.GetAttrsOrDefault("axes")); + } else { + auto v = info.GetAttrOrDefault("axis", 0); + axes_.push_back(v); + } + int64_t keepdims = 1; + if (keepdims_override.has_value()) { + keepdims = *keepdims_override; + } else { + ORT_ENFORCE(info.GetAttr("keepdims", &keepdims).IsOK()); + } + keepdims_ = (keepdims == 1); + int64_t noop_with_empty_axes = info.GetAttrOrDefault("noop_with_empty_axes", 0); + noop_with_empty_axes_ = (noop_with_empty_axes == 1); + int64_t select_last_index = info.GetAttrOrDefault("select_last_index", 0); + select_last_index_ = (select_last_index != 0); + } + + TensorShapeVector axes_; + bool keepdims_; + bool noop_with_empty_axes_; + bool select_last_index_; +}; +} // namespace onnxruntime +#endif // !CORE_PROVIDERS_CPU_REDUCTION_KERNEL_BASE_H diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index 3c83394fb0bf4..244da35427f49 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -114,6 +114,14 @@ namespace onnxruntime { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ x); +#define REGISTER_UNARY_ELEMENTWISE_KERNEL_BOOL_ONLY(x, sinceVersion) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + x, \ + sinceVersion, \ + bool, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + x); + REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL1, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 11, 12); @@ -173,11 +181,18 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMax, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ReduceMax, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ReduceMax, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 18); -REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceMax, 18); -REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceMax, 18); 
-REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ReduceMax, 18); -REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ReduceMax, 18); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 18, 19); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMax, 18, 19); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMax, 18, 19); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ReduceMax, 18, 19); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ReduceMax, 18, 19); + +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceMax, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceMax, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ReduceMax, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ReduceMax, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_BOOL_ONLY(ReduceMax, 20); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12); @@ -207,11 +222,18 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMin, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ReduceMin, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ReduceMin, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMin, 18); -REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceMin, 18); -REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceMin, 18); -REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ReduceMin, 18); -REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ReduceMin, 18); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 18, 19); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMin, 18, 19); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMin, 18, 19); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ReduceMin, 18, 19); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ReduceMin, 18, 19); + +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMin, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceMin, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceMin, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ReduceMin, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ReduceMin, 20); +REGISTER_UNARY_ELEMENTWISE_KERNEL_BOOL_ONLY(ReduceMin, 20); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceProd, 1, 10); @@ -822,10 +844,57 @@ static void ValidateKeepDims(const Tensor* input, int64_t keepdims) { ValidateKeepDims(input->Shape(), keepdims); } +template +bool check_and_reduce_empty_set_input(OpKernelContext* ctx, const gsl::span axes, bool keepdims) { + const Tensor* input = ctx->Input(0); + const TensorShape& input_shape = input->Shape(); + if (input_shape.Size() != 0) { + return false; + } + + // input is an empty set + std::vector input_axes; + if (ctx->InputCount() == 2) { + ORT_ENFORCE(axes.empty(), "Axes input and attribute should not both be present for reduction."); + // second input holds the axes. 
+ const Tensor* axes_tensor = ctx->Input(1); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + input_axes.insert(input_axes.begin(), data, data + nDims); + } else { + input_axes.resize(axes.size()); + std::copy(axes.begin(), axes.end(), input_axes.begin()); + } + + gsl::span shape_dims = input_shape.GetDims(); + const int64_t input_shape_size = narrow(shape_dims.size()); + TensorShapeVector output_shape_vector; + for (int64_t i = 0; i < input_shape_size; ++i) { + if (input_axes.empty() || std::find(input_axes.begin(), input_axes.end(), i) != input_axes.end()) { + if (keepdims) { + output_shape_vector.push_back(1); + } + } else { + output_shape_vector.push_back(input_shape[onnxruntime::narrow(i)]); + } + } + + TensorShape output_shape(output_shape_vector); + Tensor* output = ctx->Output(0, output_shape); + if (output_shape.Size() != 0) { + AGG::fill_for_empty_set(*output); + } + return true; +} + template void CommonReduce1Loop(OpKernelContext* ctx, const gsl::span& axes_, int64_t keepdims_, bool noop_with_empty_axes) { + if (check_and_reduce_empty_set_input(ctx, axes_, keepdims_ != 0)) { + return; + } + FastReduceKind fast_kind; TensorShapeVector fast_shape; TensorShapeVector output_shape; @@ -838,8 +907,8 @@ void CommonReduce1Loop(OpKernelContext* ctx, const Tensor* input = ctx->Input(0); Tensor* output = ctx->Output(0, output_shape); if (fast_kind == FastReduceKind::kEmpty) { - const TensorShape& new_input_shape = input->Shape(); - if (new_input_shape.Size() == 1) { + const TensorShape& input_shape = input->Shape(); + if (input_shape.Size() == 1) { const typename AGG::input_type* from_data = input->Data(); typename AGG::value_type* to_data = output->MutableData(); AGG agg(1, *from_data); @@ -859,6 +928,10 @@ template void CommonReduce2Loops(OpKernelContext* ctx, const gsl::span& axes_, int64_t keepdims_, bool noop_with_empty_axes) { + if (check_and_reduce_empty_set_input(ctx, axes_, keepdims_ != 0)) { + return; + } + FastReduceKind fast_kind; TensorShapeVector fast_shape, output_shape, fast_axes; if (CommonFastReduce(ctx, axes_, keepdims_, noop_with_empty_axes, @@ -869,8 +942,8 @@ void CommonReduce2Loops(OpKernelContext* ctx, const Tensor* input = ctx->Input(0); Tensor* output = ctx->Output(0, output_shape); if (fast_kind == FastReduceKind::kEmpty) { - const TensorShape& new_input_shape = input->Shape(); - if (new_input_shape.Size() == 1) { + const TensorShape& input_shape = input->Shape(); + if (input_shape.Size() == 1) { const typename AGG::input_type* from_data = input->Data(); typename AGG::value_type* to_data = output->MutableData(); AGG agg(1, *from_data); diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.h b/onnxruntime/core/providers/cpu/reduction/reduction_ops.h index 7105fd2ddad2e..4d205acaa015a 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.h @@ -11,8 +11,10 @@ #include "core/providers/cpu/containers.h" #include "core/util/math.h" #endif +#include "core/framework/math.h" #include "core/util/math_cpuonly.h" #include "core/platform/threadpool.h" +#include "core/providers/cpu/reduction/reduction_kernel_base.h" #include "core/common/safeint.h" #include @@ -178,6 +180,7 @@ class ReduceAggregator : public ReduceAggregatorBase { inline void update0(const T&) {} inline TVAL aggall(const T*) {} inline TVAL get_value() { return accumulator_; } + static void fill_for_empty_set(Tensor&) { ORT_NOT_IMPLEMENTED(); } protected: static void 
CommonFastReduceRKR(const Tensor& input, const gsl::span& fast_shape, @@ -217,6 +220,10 @@ class ReduceAggregatorSum : public ReduceAggregator { return aggall(from_data, this->N_); } + static void fill_for_empty_set(Tensor& output) { + EigenMap(output).array() = static_cast(0); + } + // Fast reduction static inline FastReduceKind WhichFastReduce() { return FastReduceKind::kKR | FastReduceKind::kRK | FastReduceKind::kKRK | FastReduceKind::kRKR; @@ -290,6 +297,9 @@ class ReduceAggregatorSumSquare : public ReduceAggregator { return Eigen::Map>(from_data, onnxruntime::narrow(this->N_)).squaredNorm(); } inline void update(const T& v) { this->accumulator_ += v * v; } + static void fill_for_empty_set(Tensor& output) { + EigenMap(output).array() = static_cast(0); + } }; template @@ -363,7 +373,11 @@ class ReduceAggregatorMax : public ReduceAggregator { public: inline ReduceAggregatorMax(int64_t N, const T& init) : ReduceAggregator(N, init) {} static T aggall(const T* from_data, int64_t size) { - return Eigen::Map>(from_data, onnxruntime::narrow(size)).maxCoeff(); + if constexpr (std::is_same_v) { /* bool specific impl */ + return Eigen::Map>(from_data, onnxruntime::narrow(size)).cast().maxCoeff(); + } else { /* generic impl */ + return Eigen::Map>(from_data, onnxruntime::narrow(size)).maxCoeff(); + } } inline T aggall(const T* from_data) { return aggall(from_data, this->N_); @@ -383,10 +397,19 @@ class ReduceAggregatorMax : public ReduceAggregator { concurrency::ThreadPool::TryParallelFor( tp, onnxruntime::narrow(fast_shape[0]), ParallelReduceFastCost(1, stridei, sizeof(T), 6), [data, stridei, out](std::ptrdiff_t first, std::ptrdiff_t last) { - EigenVectorMap(out + first, last - first) = ConstEigenMatrixMap( - data + first * stridei, onnxruntime::narrow(stridei), last - first) - .colwise() - .maxCoeff(); + if constexpr (std::is_same_v) { /* bool specific impl */ + EigenVectorMap(out + first, last - first) = ConstEigenMatrixMap( + data + first * stridei, onnxruntime::narrow(stridei), last - first) + .cast() + .colwise() + .maxCoeff() + .cast(); + } else { + EigenVectorMap(out + first, last - first) = ConstEigenMatrixMap( + data + first * stridei, onnxruntime::narrow(stridei), last - first) + .colwise() + .maxCoeff(); + } }); } @@ -405,8 +428,12 @@ class ReduceAggregatorMax : public ReduceAggregator { for (int64_t row = 1; row < n_rows; ++row) { p = data + row * N; for (int64_t j = begin; j < end; ++j) { - if (out[j] < p[j]) - out[j] = p[j]; + if constexpr (std::is_same_v) { /* bool specific impl */ + out[j] = out[j] || p[j]; + } else { + if (out[j] < p[j]) + out[j] = p[j]; + } } } }); @@ -422,11 +449,21 @@ class ReduceAggregatorMax : public ReduceAggregator { tp, onnxruntime::narrow(fast_shape[0]), ParallelReduceFastCost(fast_shape[1], fast_shape[2], sizeof(T), 6), [data, fast_shape, stridei, strideo, out](ptrdiff_t begin, ptrdiff_t end) { for (ptrdiff_t j = begin; j < end; ++j) { - EigenVectorMap(out + j * strideo, onnxruntime::narrow(strideo)) = - ConstEigenMatrixMap( - data + j * stridei, onnxruntime::narrow(fast_shape[2]), onnxruntime::narrow(fast_shape[1])) - .rowwise() - .maxCoeff(); + if constexpr (std::is_same_v) { /* bool specific impl */ + EigenVectorMap(out + j * strideo, onnxruntime::narrow(strideo)) = + ConstEigenMatrixMap( + data + j * stridei, onnxruntime::narrow(fast_shape[2]), onnxruntime::narrow(fast_shape[1])) + .cast() + .rowwise() + .maxCoeff() + .cast(); + } else { + EigenVectorMap(out + j * strideo, onnxruntime::narrow(strideo)) = + ConstEigenMatrixMap( + data + j * stridei, 
onnxruntime::narrow(fast_shape[2]), onnxruntime::narrow(fast_shape[1])) + .rowwise() + .maxCoeff(); + } } }); } @@ -438,8 +475,12 @@ class ReduceAggregatorMax : public ReduceAggregator { [=](const T* p) -> T { return p[0]; }, [=](T& value, const T* p, int64_t size) { T v = aggall(p, size); - if (v > value) - value = v; + if constexpr (std::is_same_v) { /* bool specific impl */ + value = value || v; + } else { + if (v > value) + value = v; + } }); } }; @@ -545,6 +586,14 @@ class ReduceAggregatorMin : public ReduceAggregator { } inline void update(const T& v) { this->accumulator_ = v < this->accumulator_ ? v : this->accumulator_; } + static void fill_for_empty_set(Tensor& output) { + if constexpr (std::is_same_v) { /* bool specific impl */ + ORT_NOT_IMPLEMENTED(); + } else { + EigenMap(output).array() = std::numeric_limits::infinity(); + } + } + // Fast reduction static inline FastReduceKind WhichFastReduce() { return FastReduceKind::kKR | FastReduceKind::kRK | FastReduceKind::kKRK | FastReduceKind::kRKR; @@ -558,10 +607,19 @@ class ReduceAggregatorMin : public ReduceAggregator { concurrency::ThreadPool::TryParallelFor( tp, onnxruntime::narrow(fast_shape[0]), ParallelReduceFastCost(1, stridei, sizeof(T), 6), [data, stridei, out](std::ptrdiff_t first, std::ptrdiff_t last) { - EigenVectorMap(out + first, last - first) = ConstEigenMatrixMap( - data + first * stridei, onnxruntime::narrow(stridei), last - first) - .colwise() - .minCoeff(); + if constexpr (std::is_same_v) { /* bool specific impl */ + EigenVectorMap(out + first, last - first) = ConstEigenMatrixMap( + data + first * stridei, onnxruntime::narrow(stridei), last - first) + .cast() + .colwise() + .minCoeff() + .cast(); + } else { + EigenVectorMap(out + first, last - first) = ConstEigenMatrixMap( + data + first * stridei, onnxruntime::narrow(stridei), last - first) + .colwise() + .minCoeff(); + } }); } @@ -580,8 +638,12 @@ class ReduceAggregatorMin : public ReduceAggregator { for (int64_t row = 1; row < n_rows; ++row) { p = data + row * N; for (int64_t j = begin; j < end; ++j) { - if (out[j] > p[j]) - out[j] = p[j]; + if constexpr (std::is_same_v) { /* bool specific impl */ + out[j] = out[j] && p[j]; + } else { + if (out[j] > p[j]) + out[j] = p[j]; + } } } }); @@ -597,11 +659,21 @@ class ReduceAggregatorMin : public ReduceAggregator { tp, onnxruntime::narrow(fast_shape[0]), ParallelReduceFastCost(fast_shape[1], fast_shape[2], sizeof(T), 6), [data, fast_shape, stridei, strideo, out](ptrdiff_t begin, ptrdiff_t end) { for (ptrdiff_t j = begin; j < end; ++j) { - EigenVectorMap(out + j * strideo, onnxruntime::narrow(strideo)) = - ConstEigenMatrixMap( - data + j * stridei, onnxruntime::narrow(fast_shape[2]), onnxruntime::narrow(fast_shape[1])) - .rowwise() - .minCoeff(); + if constexpr (std::is_same_v) { /* bool specific impl */ + EigenVectorMap(out + j * strideo, onnxruntime::narrow(strideo)) = + ConstEigenMatrixMap( + data + j * stridei, onnxruntime::narrow(fast_shape[2]), onnxruntime::narrow(fast_shape[1])) + .cast() + .rowwise() + .minCoeff() + .cast(); + } else { + EigenVectorMap(out + j * strideo, onnxruntime::narrow(strideo)) = + ConstEigenMatrixMap( + data + j * stridei, onnxruntime::narrow(fast_shape[2]), onnxruntime::narrow(fast_shape[1])) + .rowwise() + .minCoeff(); + } } }); } @@ -613,8 +685,12 @@ class ReduceAggregatorMin : public ReduceAggregator { [=](const T* p) -> T { return p[0]; }, [=](T& value, const T* p, int64_t size) { T v = aggall(p, size); - if (v < value) - value = v; + if constexpr (std::is_same_v) { /* bool specific 
impl */ + value = value && v; + } else { + if (v < value) + value = v; + } }); } }; @@ -627,6 +703,9 @@ class ReduceAggregatorProd : public ReduceAggregator { return Eigen::Map>(from_data, onnxruntime::narrow(this->N_)).prod(); } inline void update(const T& v) { this->accumulator_ *= v; } + static void fill_for_empty_set(Tensor& output) { + EigenMap(output).array() = static_cast(1); + } }; template @@ -637,6 +716,10 @@ class ReduceAggregatorL1 : public ReduceAggregator { return Eigen::Map>(from_data, onnxruntime::narrow(this->N_)).cwiseAbs().sum(); } inline void update(const T& v) { this->accumulator_ += v > 0 ? v : -v; } + + static void fill_for_empty_set(Tensor& output) { + EigenMap(output).array() = static_cast(0); + } }; template @@ -648,6 +731,9 @@ class ReduceAggregatorL2 : public ReduceAggregator { } inline void update(const T& v) { this->accumulator_ += v * v; } inline T get_value() { return reduce_sqrt(this->accumulator_); } + static void fill_for_empty_set(Tensor& output) { + EigenMap(output).array() = static_cast(0); + } }; template @@ -659,6 +745,9 @@ class ReduceAggregatorLogSum : public ReduceAggregator { } inline void update(const T& v) { this->accumulator_ += v; } inline T get_value() { return reduce_log(this->accumulator_); } + static void fill_for_empty_set(Tensor& output) { + EigenMap(output).array() = -std::numeric_limits::infinity(); + } }; template @@ -682,6 +771,9 @@ class ReduceAggregatorLogSumExp : public ReduceAggregator { } inline void update(const T& v) { this->accumulator_ += reduce_exp(v - max_); } inline T get_value() { return reduce_log(this->accumulator_) + max_; } + static void fill_for_empty_set(Tensor& output) { + EigenMap(output).array() = -std::numeric_limits::infinity(); + } }; void NoTransposePrepareForReduce(const TensorShape& new_input_shape, @@ -710,35 +802,6 @@ void CommonReduce2Loops(OpKernelContext* ctx, const gsl::span& axes_, int64_t keepdims_, bool noop_with_empty_axes = false); -template -class ReduceKernelBase { - protected: - ReduceKernelBase(const OpKernelInfo& info, optional keepdims_override = {}) { - if (allow_multi_axes) { - axes_ = ToShapeVector(info.GetAttrsOrDefault("axes")); - } else { - auto v = info.GetAttrOrDefault("axis", 0); - axes_.push_back(v); - } - int64_t keepdims = 1; - if (keepdims_override.has_value()) { - keepdims = *keepdims_override; - } else { - ORT_ENFORCE(info.GetAttr("keepdims", &keepdims).IsOK()); - } - keepdims_ = (keepdims == 1); - int64_t noop_with_empty_axes = info.GetAttrOrDefault("noop_with_empty_axes", 0); - noop_with_empty_axes_ = (noop_with_empty_axes == 1); - int64_t select_last_index = info.GetAttrOrDefault("select_last_index", 0); - select_last_index_ = (select_last_index != 0); - } - - TensorShapeVector axes_; - bool keepdims_; - bool noop_with_empty_axes_; - bool select_last_index_; -}; - template class ReduceKernel : public OpKernel, public ReduceKernelBase { protected: diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h index ee8e13db2eb53..c22ff2d01a37d 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/optional.h" #include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cpu/reduction/reduction_ops.h" +#include "core/providers/cpu/reduction/reduction_kernel_base.h" #include "core/providers/cuda/reduction/reduction_functions.h" namespace onnxruntime { diff --git 
a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 6d07ddde5c442..57c2061883736 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -954,7 +954,6 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"reduce_log_sum_exp_empty_set_expanded", "unknown version", {}}, {"reduce_prod_empty_set", "unknown version", {}}, {"reduce_sum_empty_set", "unknown version", {}}, - {"reduce_sum_square_empty_set", "unknown version", {}}, {"reduce_sum_square_empty_set_expanded", "unknown version", {}}, #ifdef ENABLE_TRAINING_CORE {"adagrad", "not a registered function/op", {}}, // Op not registered. @@ -1352,6 +1351,7 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"gridsample_volumetric_nearest_align_corners_0", "unknown version"}); broken_tests->insert({"gridsample_volumetric_nearest_align_corners_1", "unknown version"}); broken_tests->insert({"spacetodepth", "result differs"}); + broken_tests->insert({"reduce_sum_square_empty_set_expanded", "unknown version"}); // Fails with QNN SDK 2.17.0: // expected 7.70947 (40f6b3f3), got 7.84096 (40fae920), diff: 0.131491, tol=0.00870947 idx=419. 100 of 1715 differ broken_tests->insert({"facedetection_op8_qdq", "result differs"}); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 79da8004a9edd..b0e0a0dd0d564 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -924,7 +924,280 @@ TEST(ReductionOpTest, ReduceMax_default_axes_do_not_keep_dims) { 55.0f, 1.0f, 60.0f, 2.0f}); test.AddOutput("reduced", {}, {60.0f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: full reduce without keepDimensions is not supported with explicit batch //TensorRT: axis must be 0 + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: full reduce without keepDimensions is not supported with explicit batch //TensorRT: axis must be 0 +} + +TEST(ReductionOpTest, test_bool_ReduceMax_0) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {2}, {-1, 1}); + test.AddOutput("reduced", {2}, {true, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_1) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {2}, {-1, 1}); + test.AddOutput("reduced", {2}, {false, false}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_2) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {2}, {-1, 1}); + test.AddOutput("reduced", {2, 1, 1}, {true, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + } + + ); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_3) { + OpTester 
test("ReduceMin", 20); + test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {2}, {-1, 1}); + test.AddOutput("reduced", {2, 1, 1}, {false, false}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_4) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {2}, {2, 1}); + test.AddOutput("reduced", {2}, {true, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_5) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {2}, {2, 1}); + test.AddOutput("reduced", {2}, {false, false}); + test.Run(); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_6) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {2}, {2, 1}); + test.AddOutput("reduced", {2, 1, 1}, {true, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_7) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {2}, {2, 1}); + test.AddOutput("reduced", {2, 1, 1}, {false, false}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_8) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {1}, {0}); + test.AddOutput("reduced", {3, 2}, {false, true, true, true, false, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_9) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {1}, {0}); + test.AddOutput("reduced", {3, 2}, {false, false, false, true, false, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_10) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {1}, {0}); + test.AddOutput("reduced", {1, 3, 2}, {false, true, true, true, false, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_11) { + OpTester test("ReduceMin", 20); + 
test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {1}, {0}); + test.AddOutput("reduced", {1, 3, 2}, {false, false, false, true, false, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_12) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {1}, {2}); + test.AddOutput("reduced", {2, 3}, {false, true, true, true, true, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_13) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {1}, {2}); + test.AddOutput("reduced", {2, 3}, {false, true, false, false, false, false}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_14) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {1}, {2}); + test.AddOutput("reduced", {2, 3, 1}, {false, true, true, true, true, true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_15) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddInput("axes", {1}, {2}); + test.AddOutput("reduced", {2, 3, 1}, {false, true, false, false, false, false}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_16) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddOutput("reduced", {}, {true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_17) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", static_cast(0)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddOutput("reduced", {}, {false}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMax_18) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddOutput("reduced", {1, 1, 1}, {true}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); +} + +TEST(ReductionOpTest, test_bool_ReduceMin_19) { + OpTester test("ReduceMin", 20); + 
test.AddAttribute("keepdims", static_cast(1)); + test.AddInput("data", {2, 3, 2}, {false, false, true, true, false, true, false, true, false, true, false, true}); + test.AddOutput("reduced", {1, 1, 1}, {false}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kOpenVINOExecutionProvider, + }); } TEST(ReductionOpTest, ReduceMax_do_not_keepdims) { @@ -3254,7 +3527,7 @@ TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1b) { // test that PrepareForReduce handles this case. Called by all reduction ops so any op can be used in the test TEST(ReductionOpTest, ReduceDimWithZero1) { // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { + if (DefaultDmlExecutionProvider().get() != nullptr || DefaultRocmExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{1,0,1}] did not match run output shape [{1,1,1}] for reduced"; } @@ -3264,8 +3537,12 @@ TEST(ReductionOpTest, ReduceDimWithZero1) { tester.Run(expect, error_msg, // exclude EPs that don't handle this + // TODO: fix reduce kernel for zero set cases. see: https://github.com/microsoft/onnxruntime/issues/18588 { kCoreMLExecutionProvider, + kCudaExecutionProvider, + kDnnlExecutionProvider, + kMIGraphXExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider, kTensorrtExecutionProvider, @@ -3275,9 +3552,8 @@ TEST(ReductionOpTest, ReduceDimWithZero1) { // reduce on all axes keeping dims. should allow the 0 to be the reduced value OpTester test("ReduceSum", 10); test.AddAttribute("keepdims", int64_t(1)); - test.AddShapeToTensorData(true, 1); // make second dim symbolic so that we don't break during shape inferencing test.AddInput("data", {3, 0, 2}, {}); - test.AddOutput("reduced", {1, 0, 1}, {}); + test.AddOutput("reduced", {1, 1, 1}, {0.0f}); run(test); } @@ -3301,8 +3577,8 @@ TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero2) { TEST(ReductionOpTest, ReduceDimWithZero2) { // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Can't reduce on dim with value of 0 if 'keepdims' is false. Invalid output shape would be produced. input_shape:{3,0,2}"; + if (DefaultDmlExecutionProvider().get() != nullptr || DefaultRocmExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Can't reduce on dim with value of 0 if 'keepdims' is false. Invalid output shape would be produced. input_shape:{?,0,?}"; } auto run = [](OpTester& tester, const std::string& error_msg = "") { @@ -3311,23 +3587,25 @@ TEST(ReductionOpTest, ReduceDimWithZero2) { tester.Run(expect, error_msg, // exclude EPs that don't handle this + // TODO: fix reduce kernel for zero set cases. see: https://github.com/microsoft/onnxruntime/issues/18588 { + kCoreMLExecutionProvider, + kCudaExecutionProvider, + kDnnlExecutionProvider, + kMIGraphXExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider, kTensorrtExecutionProvider, - kCoreMLExecutionProvider, }); }; - // reduction without keeping dims on all axes. can't reduce on an axis with value of 0 + // reducing on all axes including one or more with 0 dimension, with keepdims=0, results a scalar of 0. 
OpTester test2("ReduceSum", 10); test2.AddAttribute("keepdims", int64_t(0)); test2.AddShapeToTensorData(true, 1); test2.AddInput("data", {3, 0, 2}, {}); - test2.AddOutput("reduced", {}, {0.f}); - run(test2, - "Can't reduce on dim with value of 0 if 'keepdims' is false. " - "Invalid output shape would be produced. input_shape:{3,0,2}"); + test2.AddOutput("reduced", {}, {0.0f}); + run(test2); } TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero3) { @@ -5478,5 +5756,101 @@ TEST(ReductionOpTest, ReduceSum_RKRK_keepdims) { test.Run(); } +void test_empty_set(const std::string& op, int opset, bool axes_as_input, float empty_value) { + OpTester test(op, opset); + std::vector input_shape = {2, 0, 4}; + int64_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), static_cast(1), std::multiplies()); + std::vector data(input_size); + test.AddInput("data", input_shape, data); + std::vector axes = {1}; + if (axes_as_input) { + test.AddInput("axes", {(int64_t)(axes.size())}, axes); + } else { + test.AddAttribute("axes", axes); + } + + std::vector output_shape = {2, 1, 4}; + int64_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), static_cast(1), std::multiplies()); + std::vector reduced(output_size, empty_value); + test.AddOutput("reduced", output_shape, reduced); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + { + kCoreMLExecutionProvider, + kCudaExecutionProvider, + kDmlExecutionProvider, + kDnnlExecutionProvider, + kMIGraphXExecutionProvider, + kOpenVINOExecutionProvider, + kQnnExecutionProvider, + kRocmExecutionProvider, + kTensorrtExecutionProvider, + }); +} + +TEST(ReductionOpTest, empty_set_ReduceL1) { + test_empty_set("ReduceL1", 20, true, 0); +} + +TEST(ReductionOpTest, empty_set_ReduceL1_13) { + test_empty_set("ReduceL1", 13, false, 0); +} + +TEST(ReductionOpTest, empty_set_ReduceL2) { + test_empty_set("ReduceL2", 20, true, 0); +} + +TEST(ReductionOpTest, empty_set_ReduceL2_13) { + test_empty_set("ReduceL2", 13, false, 0); +} + +TEST(ReductionOpTest, empty_set_ReduceLogSum) { + test_empty_set("ReduceLogSum", 20, true, -std::numeric_limits::infinity()); +} + +TEST(ReductionOpTest, empty_set_ReduceLogSum_13) { + test_empty_set("ReduceLogSum", 13, false, -std::numeric_limits::infinity()); +} + +TEST(ReductionOpTest, empty_set_ReduceLogSumExp) { + test_empty_set("ReduceLogSumExp", 20, true, -std::numeric_limits::infinity()); +} + +TEST(ReductionOpTest, empty_set_ReduceLogSumExp_13) { + test_empty_set("ReduceLogSumExp", 13, false, -std::numeric_limits::infinity()); +} + +TEST(ReductionOpTest, empty_set_ReduceMin) { + test_empty_set("ReduceMin", 20, true, std::numeric_limits::infinity()); +} + +TEST(ReductionOpTest, empty_set_ReduceMin_13) { + test_empty_set("ReduceMin", 13, false, std::numeric_limits::infinity()); +} + +TEST(ReductionOpTest, empty_set_ReduceProd) { + test_empty_set("ReduceProd", 20, true, 1.0f); +} + +TEST(ReductionOpTest, empty_set_ReduceProd_13) { + test_empty_set("ReduceProd", 13, false, 1.0f); +} + +TEST(ReductionOpTest, empty_set_ReduceSum) { + test_empty_set("ReduceSum", 20, true, 0.0f); +} + +TEST(ReductionOpTest, empty_set_ReduceSum_13) { + test_empty_set("ReduceSum", 11, false, 0.0f); +} + +TEST(ReductionOpTest, empty_set_ReduceSumSquare) { + test_empty_set("ReduceSumSquare", 20, true, 0.0f); +} + +TEST(ReductionOpTest, empty_set_ReduceSumSquare_13) { + test_empty_set("ReduceSumSquare", 13, false, 0.0f); +} } // namespace test } // namespace onnxruntime diff --git 
a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 49d8d7150a117..3a13e39702904 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -262,22 +262,18 @@ "^test_string_split_empty_tensor", "^test_string_split_maxsplit", "^test_string_split_no_delimiter", - "^test_reduce_max_bool_inputs", - "^test_reduce_min_bool_inputs", - "^test_reduce_min_empty_set", - "^test_reduce_l1_empty_set", - "^test_reduce_l1_empty_set_expanded", - "^test_reduce_l2_empty_set", - "^test_reduce_l2_empty_set_expanded", - "^test_reduce_log_sum_empty_set", - "^test_reduce_log_sum_empty_set_expanded", - "^test_reduce_log_sum_exp_empty_set", - "^test_reduce_log_sum_exp_empty_set_expanded", - "^test_reduce_prod_empty_set", - "^test_reduce_sum_empty_set", - "^test_reduce_sum_empty_set_non_reduced_axis_zero", - "^test_reduce_sum_square_empty_set", - "^test_reduce_sum_square_empty_set_expanded" + "^test_reduce_l1_empty_set_cuda", + "^test_reduce_l1_empty_set_expanded_cuda", + "^test_reduce_l2_empty_set_cuda", + "^test_reduce_l2_empty_set_expanded_cuda", + "^test_reduce_log_sum_empty_set_cuda", + "^test_reduce_log_sum_empty_set_expanded_cuda", + "^test_reduce_log_sum_exp_empty_set_cuda", + "^test_reduce_log_sum_exp_empty_set_expanded_cuda", + "^test_reduce_prod_empty_set_cuda", + "^test_reduce_sum_empty_set_cuda", + "^test_reduce_sum_square_empty_set_cuda", + "^test_reduce_sum_square_empty_set_expanded_cuda" ], "current_failing_tests_x86": [ "^test_vgg19", @@ -377,7 +373,23 @@ "^test_constantofshape_int_zeros", // https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1141563&view=logs&j=a018b46d-e41a-509d-6581-c95fdaa42fcd&t=d61c1d37-f101-5d28-982f-e5931b720302 "^test_gelu_tanh_2_cpu", - "^test_gelu_tanh_2_expanded_cpu" + "^test_gelu_tanh_2_expanded_cpu", + "^test_reduce_max_bool_inputs", + "^test_reduce_min_bool_inputs", + "^test_reduce_min_empty_set", + "^test_reduce_l1_empty_set", + "^test_reduce_l1_empty_set_expanded", + "^test_reduce_l2_empty_set", + "^test_reduce_l2_empty_set_expanded", + "^test_reduce_log_sum_empty_set", + "^test_reduce_log_sum_empty_set_expanded", + "^test_reduce_log_sum_exp_empty_set", + "^test_reduce_log_sum_exp_empty_set_expanded", + "^test_reduce_prod_empty_set", + "^test_reduce_sum_empty_set", + "^test_reduce_sum_empty_set_non_reduced_axis_zero", + "^test_reduce_sum_square_empty_set", + "^test_reduce_sum_square_empty_set_expanded" ], "current_failing_tests_NNAPI": [ "^test_maxpool_2d_uint8", @@ -498,7 +510,8 @@ "test_range_int32_type_negative_delta_expanded_cpu", // Error but not a failure. "test_range_float_type_positive_delta_expanded_cpu", // Error but not a failure. "test_scan_sum_cpu", // Disabled due to output mismatch with tolerance. - "test_scan9_sum_cpu" // Disabled due to output mismatch with tolerance. + "test_scan9_sum_cpu", // Disabled due to output mismatch with tolerance. 
+ "test_reduce_max_bool_inputs_cpu" ], "current_failing_tests_OPENVINO_NPU_FP16": [ "^test_prelu_broadcast", @@ -656,8 +669,10 @@ "^test_affine_grid_3d_expanded", "^test_constantofshape_float_ones", "^test_constantofshape_int_shape_zero", - "^test_constantofshape_int_zeros" - + "^test_constantofshape_int_zeros", + "^test_reduce_log_sum_empty_set_cpu", + "^test_reduce_log_sum_exp_empty_set_cpu", + "^test_reduce_prod_empty_set_cpu" ], // ORT first supported opset 7, so models with nodes that require versions prior to opset 7 are not supported "tests_with_pre_opset7_dependencies": [ From e155c66b4a3df8a173ce3b16b31707aefec7b052 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 4 Jan 2024 17:44:49 -0800 Subject: [PATCH 29/45] Change all macOS python packages to use universal2 (#19013) ### Description Change all macOS python packages to use universal2, to reduce the number of packages we have. ### Motivation and Context According to [wikipedia](https://en.wikipedia.org/wiki/MacOS_Big_Sur), macOS 11 is the first macOS version that supports universal 2. And it is the min macOS version we support. So we no longer need to maintain separate binaries for different CPU archs. --- .../templates/py-packaging-stage.yml | 95 +------------------ .../linux/docker/scripts/requirements.txt | 4 +- 2 files changed, 5 insertions(+), 94 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index f2b91bbaacb89..44904f9248b10 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -356,96 +356,11 @@ stages: inputs: versionSpec: $(PythonVersion) - - template: use-xcode-version.yml - - script: | set -e -x - pushd . 
- mkdir -p /tmp/scripts - mkdir -p $(Build.BinariesDirectory)/installed - cp $(Build.SourcesDirectory)/cmake/deps.txt /tmp/scripts - $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_protobuf.sh -p $(Build.BinariesDirectory)/installed - popd - export PATH=$(Build.BinariesDirectory)/installed/bin:$PATH - export ONNX_ML=1 - export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" - export _PYTHON_HOST_PLATFORM=macosx-${{variables.MACOSX_DEPLOYMENT_TARGET}}-x86_64 + export _PYTHON_HOST_PLATFORM=macosx-${{variables.MACOSX_DEPLOYMENT_TARGET}}-universal2 python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --build_dir $(Build.BinariesDirectory) --use_coreml --skip_submodule_sync --parallel --config Release --skip_onnx_tests --build_wheel ${{ parameters.build_py_parameters }} - displayName: 'Command Line Script' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)/Release/dist' - Contents: '*.whl' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - ArtifactName: onnxruntime - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - - ${{ if eq(parameters.enable_mac_silicon, true) }}: - - job: MacOS_silicon_py_Wheels - timeoutInMinutes: 120 - workspace: - clean: all - pool: - vmImage: 'macOS-13' - variables: - MACOSX_DEPLOYMENT_TARGET: '11.0' - strategy: - # As of 3.9.1, Python now fully supports building and running on macOS 11.0 (Big Sur) and on Apple Silicon Macs (based on the ARM64 architecture). - # https://docs.python.org/3/whatsnew/3.9.html - matrix: - Python38: - PythonVersion: '3.8' - Python39: - PythonVersion: '3.9' - Python310: - PythonVersion: '3.10' - Python311: - PythonVersion: '3.11' - steps: - - checkout: self - clean: true - submodules: recursive - - - task: UsePythonVersion@0 - displayName: 'Use Python' - inputs: - versionSpec: $(PythonVersion) - - - script: | - set -ex - uname -m - system_profiler SPSoftwareDataType SPHardwareDataType - displayName: 'Mac machine info' - - - template: use-xcode-version.yml - - # Don't remove _PYTHON_HOST_PLATFORM, it's used to generate correct package name - # Setting _PYTHON_HOST_PLATFORM overwrites the value return by get_platform() - # Ref: https://wiki.debian.org/Python/MultiArch - - script: | - set -e -x - pushd . 
- mkdir -p /tmp/scripts - mkdir -p $(Build.BinariesDirectory)/installed - cp $(Build.SourcesDirectory)/cmake/deps.txt /tmp/scripts - $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_protobuf.sh -p $(Build.BinariesDirectory)/installed - popd - export PATH=$(Build.BinariesDirectory)/installed/bin:$PATH - export ONNX_ML=1 - export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" - export _PYTHON_HOST_PLATFORM=macosx-${{variables.MACOSX_DEPLOYMENT_TARGET}}-arm64 - python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --build_dir $(Build.BinariesDirectory) --use_coreml --skip_submodule_sync --parallel --config Release --skip_tests --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 --build_wheel ${{ parameters.build_py_parameters }} + python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --build_dir $(Build.BinariesDirectory) --use_coreml --skip_submodule_sync --parallel --config Release --build_wheel ${{ parameters.build_py_parameters }} --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" --update --build displayName: 'Command Line Script' - script: | @@ -454,13 +369,9 @@ stages: cd '$(Build.BinariesDirectory)/Release/dist' ls for file in *.whl - do - [[ "$file" == *arm64* ]] || ( echo "Mac Silicon package name is NOT correct" && exit 1) - done - for file in *.whl do delocate-listdeps "$file" - delocate-wheel --require-archs=arm64 -w fixed_wheels -v "$file" + delocate-wheel --require-archs=x86_64,arm64 -w fixed_wheels -v "$file" done displayName: 'delocate wheel' diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index a6452721a2b7d..0fc80b30c1b3a 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -3,8 +3,8 @@ numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' mypy pytest -setuptools>=68.2.2 -wheel>=0.35.1 +setuptools==69.0.3 +wheel==0.42.0 onnx==1.15.0 argparse sympy==1.12 From 7f0aac0d8a5cfc8fa08c3f51b9f8002ec6eed4ba Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 6 Jan 2024 00:15:50 +0800 Subject: [PATCH 30/45] Revert "[WebNN EP] Rename op logicalNot to not" (#18997) Reverts microsoft/onnxruntime#18936 WebNN spec is discussing using the `logicalNot` name at https://github.com/webmachinelearning/webnn/issues/496, and the Chromium implementation has suspended the renaming change. For consistent, we should keep using `logicalNot` in WebNN EP util it is finalized. 
---
 onnxruntime/core/providers/webnn/builders/helper.h         | 2 +-
 .../core/providers/webnn/builders/impl/unary_op_builder.cc | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index 28857d3002ede..8b8b85339a87c 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -179,7 +179,7 @@ static const InlinedHashMap op_map = {
     {"Min", {"min", true}},
     {"Mul", {"mul", true}},
     {"Neg", {"neg", true}},
-    {"Not", {"not", false}},
+    {"Not", {"logicalNot", false}},
     {"Pad", {"pad", true}},
     {"Pow", {"pow", false}},
     {"PRelu", {"prelu", true}},
diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc
index 129532e91f5a0..e6c5cf24080cd 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc
@@ -48,7 +48,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   } else if (op_type == "Neg") {
     output = model_builder.GetBuilder().call<emscripten::val>("neg", input);
   } else if (op_type == "Not") {
-    output = model_builder.GetBuilder().call<emscripten::val>("not", input);
+    output = model_builder.GetBuilder().call<emscripten::val>("logicalNot", input);
   } else if (op_type == "Reciprocal") {
     output = model_builder.GetBuilder().call<emscripten::val>("reciprocal", input);
  } else if (op_type == "Sin") {

From 447a3a7c706495fdc0f8dae8b1a130ef73af18e1 Mon Sep 17 00:00:00 2001
From: Jiajie Hu
Date: Sat, 6 Jan 2024 00:16:15 +0800
Subject: [PATCH 31/45] [js/webgpu] Fix Expand/Gather when input type is bool
 (#18999)

### Description
Also update the op test suite.

### Motivation and Context
Previously the *total* size in case `Expand - last dim is not divisible by 4`
was a multiple of 4, even though the *last dimension* was not, so the bug has
never been caught.
---
 js/web/lib/wasm/jsep/webgpu/ops/expand.ts |  2 +-
 js/web/lib/wasm/jsep/webgpu/ops/gather.ts |  2 +-
 js/web/test/data/ops/expand.jsonc         | 29 +++++++++++++++++----
 js/web/test/data/ops/gather.jsonc         | 22 +++++++++++++++++
 4 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
index 3dc4e957e0fee..035d89755c7d7 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
@@ -47,7 +47,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
   const outputShape: number[] = calculateOutputShape(inputShape, shape);
   const dataType = inputs[0].dataType;
   const components = dataType === DataType.bool ? 4 : 1;
-  const outputSize = ShapeUtil.size(outputShape) / components;
+  const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);

   const enableInputShapeUniform = enableShapesUniforms(inputShape.length);
   const enableOutputShapeUniform = enableShapesUniforms(outputShape.length);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
index 53ca094abfd62..469249f92ff28 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
@@ -31,7 +31,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
   const axisDimLimit = inputShape[axis];
   const components = inputs[0].dataType === DataType.bool ?
4 : 1; - const outputSize = ShapeUtil.size(outputShape) / components; + const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 22bc04d558d98..613b4507b2b15 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -168,20 +168,39 @@ "name": "Expand - last dim is not divisible by 4", "inputs": [ { - "data": [true, false, false, true, true, true, false, false, false, true, true, true], - "dims": [2, 6], + "data": [true, false, false, true, true, true], + "dims": [1, 6], "type": "bool" }, { - "data": [2, 1], + "data": [3, 1], "dims": [2], "type": "int64" } ], "outputs": [ { - "data": [true, false, false, true, true, true, false, false, false, true, true, true], - "dims": [2, 6], + "data": [ + true, + false, + false, + true, + true, + true, + true, + false, + false, + true, + true, + true, + true, + false, + false, + true, + true, + true + ], + "dims": [3, 6], "type": "bool" } ] diff --git a/js/web/test/data/ops/gather.jsonc b/js/web/test/data/ops/gather.jsonc index 0be077d237b88..d218d120d356d 100644 --- a/js/web/test/data/ops/gather.jsonc +++ b/js/web/test/data/ops/gather.jsonc @@ -99,6 +99,28 @@ "operator": "Gather", "attributes": [], "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [false, true, false, false], + "dims": [4], + "type": "bool" + }, + { + "data": [1], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [true], + "dims": [], + "type": "bool" + } + ] + }, { "name": "data[2,4] indices[1]", "inputs": [ From efdcefcf8cea1b724dd7694a8acb62fa14268e83 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Sat, 6 Jan 2024 02:05:34 +0800 Subject: [PATCH 32/45] [ROCm] fix security warning (#19017) fix security warning --- .../github/linux/docker/migraphx-ci-pipeline-env.Dockerfile | 3 --- tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile | 3 --- 2 files changed, 6 deletions(-) diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile index 85d738d2167e1..6c71631368822 100644 --- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile @@ -65,9 +65,6 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86 conda update --all && \ rm ~/miniconda.sh && conda clean -ya -# Conda base patch -RUN pip install cryptography==41.0.4 - # Create migraphx-ci environment ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/migraphx-ci ENV CONDA_DEFAULT_ENV migraphx-ci diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 29048b79d4b81..4db9df80ed187 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -67,9 +67,6 @@ ENV CONDA_DEFAULT_ENV rocm-ci RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.9 ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} -# Conda base patch -RUN pip install cryptography==41.0.4 - # Enable rocm-ci environment SHELL ["conda", "run", "-n", "rocm-ci", "/bin/bash", "-c"] From 
4190c29d2260bb8f274049d7eab2a634fae95a21 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:51:07 -0800 Subject: [PATCH 33/45] Add MatMulNBits accuracy_level parameter to quantization utilities. (#19015) Allow MatMulNBits `accuracy_level` attribute (added in #17669) to be set to a particular value when the model is quantized. --- .../quantization/matmul_4bits_quantizer.py | 37 ++++++++++++++++--- .../models/llama/convert_to_onnx.py | 37 ++++++++++++++----- .../transformers/models/llama/llama_inputs.py | 4 +- .../transformers/models/llama/llama_parity.py | 7 ++-- 4 files changed, 64 insertions(+), 21 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 9f90196e301e5..6293bcbbf95bd 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -4,10 +4,11 @@ # license information. # -------------------------------------------------------------------------- +from __future__ import annotations + import argparse import logging import os -from typing import List, Tuple import numpy as np import numpy.typing as npt @@ -26,16 +27,24 @@ class MatMul4BitsQuantizer: """Perform 4b quantization of constant MatMul weights""" - def __init__(self, model: ModelProto, block_size: int, is_symmetric: bool, nodes_to_exclude=None): + def __init__( + self, + model: ModelProto, + block_size: int, + is_symmetric: bool, + accuracy_level: int | None = None, + nodes_to_exclude: list[str] | None = None, + ): if nodes_to_exclude is None: nodes_to_exclude = [] self.model = ONNXModel(model) self.block_size = block_size self.is_symmetric = is_symmetric + self.accuracy_level = accuracy_level self.nodes_to_exclude = set(nodes_to_exclude) @staticmethod - def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: + def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: for gid in range(len(graph_path) - 1, -1, -1): graph = graph_path[gid] for tensor in graph.initializer: @@ -66,7 +75,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: return (packed, scales, zero_point) - def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: + def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" if node.op_type != "MatMul": @@ -113,6 +122,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) kwargs["N"] = cols kwargs["bits"] = 4 kwargs["block_size"] = self.block_size + if self.accuracy_level is not None: + kwargs["accuracy_level"] = self.accuracy_level matmul_q4_node = onnx.helper.make_node( "MatMulNBits", @@ -127,7 +138,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) return matmul_q4_node - def _process_subgraph(self, graph_stack: List[GraphProto]): + def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] graph = graph_stack[-1] @@ -201,6 +212,14 @@ def parse_args(): type=bool, help="Indicate whether to quantize the model symmetrically", ) + parser.add_argument( + "--accuracy_level", + required=False, + type=int, + help="Accuracy level of the 4-bit quantized MatMul computation. 
" + "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details " + "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).", + ) parser.add_argument("-v", "--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) parser.add_argument( @@ -228,6 +247,12 @@ def parse_args(): raise Exception(f"file {output_model_path} already exists") model = onnx.load(input_model_path) - quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric, nodes_to_exclude=args.nodes_to_exclude) + quant = MatMul4BitsQuantizer( + model=model, + block_size=args.block_size, + is_symmetric=args.symmetric, + accuracy_level=args.accuracy_level, + nodes_to_exclude=args.nodes_to_exclude, + ) quant.process() quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index e694b5050cc8c..bc09b52574a27 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import argparse import logging import os import shutil from itertools import chain -from typing import List import onnx import torch @@ -21,11 +22,12 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer +torch_export_onnx_opset_version = 14 logger = logging.getLogger("") init_dist() -def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): +def get_model_dynamic_axes(input_names: list[str], output_names: list[str]): dynamic_axes = {} for name in input_names + output_names: if name in input_names: @@ -42,7 +44,7 @@ def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): return dynamic_axes -def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: List[str]): +def get_model_with_past_kv_dynamic_axes(input_names: list[str], output_names: list[str]): dynamic_axes = {} for name in input_names + output_names: if name in {"input_ids", "position_ids"}: @@ -65,7 +67,7 @@ def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: Li return dynamic_axes -def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]): +def get_merged_model_dynamic_axes(input_names: list[str], output_names: list[str]): dynamic_axes = {} for name in input_names + output_names: if name in {"input_ids", "position_ids"}: @@ -229,7 +231,7 @@ def run_torchscript_separate_export( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=13, + opset_version=torch_export_onnx_opset_version, do_constant_folding=True, verbose=args.verbose, ) @@ -288,7 +290,7 @@ def run_torchscript_separate_export( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=13, + opset_version=torch_export_onnx_opset_version, do_constant_folding=True, verbose=args.verbose, ) @@ -368,7 +370,7 @@ def run_torchscript_merged_export( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=13, + opset_version=torch_export_onnx_opset_version, do_constant_folding=True, verbose=args.verbose, ) @@ -412,7 +414,7 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov def convert_to_float16( - args: argparse.Namespace, 
config: AutoConfig, old_paths: List[str], rank: int = 0, world_size: int = 1 + args: argparse.Namespace, config: AutoConfig, old_paths: list[str], rank: int = 0, world_size: int = 1 ): decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx") decoder_with_past_model_fp16_path = os.path.join( @@ -635,7 +637,7 @@ def get_args(): help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.", ) - blockwise_group = parser.add_argument_group("4-bit quantization") + blockwise_group = parser.add_argument_group("blockwise (4-bit quantization)") blockwise_group.add_argument( "--block_size", @@ -645,6 +647,15 @@ def get_args(): help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.", ) + blockwise_group.add_argument( + "--int4_accuracy_level", + required=False, + type=int, + help="Accuracy level of the 4-bit quantized MatMul computation. " + "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details " + "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).", + ) + smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)") smooth_quant_group.add_argument( @@ -937,7 +948,13 @@ def main(): for fp_path, int4_path in zip(old_paths, new_paths): if os.path.exists(fp_path): model = onnx.load_model(fp_path, load_external_data=True) - quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) + quant = MatMul4BitsQuantizer( + model=model, + block_size=args.block_size, + is_symmetric=True, + accuracy_level=args.int4_accuracy_level, + nodes_to_exclude=[], + ) quant.process() quant.model.save_model_to_file(int4_path, use_external_data_format=True) del model diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index bae1ae82e8f7e..a329b73259dda 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from __future__ import annotations import numpy as np import torch @@ -235,7 +235,7 @@ def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, u # Convert list of past_key_values to dict of past_key and past_value -def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]): +def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]): past_kv = {} for i, (past_k, past_v) in enumerate(past_key_values): past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy() diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 418a65325c8f0..25d7519769604 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import argparse import logging import os import time -from typing import List import numpy as np import torch @@ -139,7 +140,7 @@ def verify_parity( return kv_cache_ortvalues -def get_args(argv: List[str]): +def 
get_args(argv: list[str]): parser = argparse.ArgumentParser() parser.add_argument( @@ -232,7 +233,7 @@ def get_args(argv: List[str]): return args -def main(argv: List[str] = []): # noqa: B006 +def main(argv: list[str] = []): # noqa: B006 args = get_args(argv) setup_logger(args.verbose) logger.info(f"Arguments: {args}") From db3c07608130e4853bb6d9db66fcf57f95a864e9 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 8 Jan 2024 03:06:45 -0800 Subject: [PATCH 34/45] [ROCm] do not use failed miopen fusion compile (#19012) The FusedConv operator for the ROCm EP could fail to compile the fused operation, in which case it should not attempt to use the failed fusion plan. In addition, the hash for the miopenConvolutionDescriptor_t for newer ROCm versions was failing to use all components of the descriptor. --- onnxruntime/contrib_ops/rocm/fused_conv.cc | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc index d597e0d57fbcb..63804f79a32fb 100644 --- a/onnxruntime/contrib_ops/rocm/fused_conv.cc +++ b/onnxruntime/contrib_ops/rocm/fused_conv.cc @@ -76,7 +76,12 @@ struct FNVHash { void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { int spatial_dim = 1; #if ROCM_VERSION >= 50500 - miopenGetConvolutionSpatialDim(cdesc, &spatial_dim); + MIOPEN_CALL(miopenGetConvolutionSpatialDim(cdesc, &spatial_dim)); + std::vector pads{spatial_dim}; + std::vector strides{spatial_dim}; + std::vector dilations{spatial_dim}; + miopenConvolutionMode_t mode; + MIOPEN_CALL(miopenGetConvolutionNdDescriptor(cdesc, spatial_dim, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)); #else // Previous versions of MIOpen doesn't provide API to probe the dimension of a // miopenConvolutionDescriptor_t, so we have to guess. @@ -100,11 +105,12 @@ struct FNVHash { pads.resize(spatial_dim); strides.resize(spatial_dim); dilations.resize(spatial_dim); +#endif (*this) << spatial_dim; (*this) << pads; (*this) << strides; (*this) << dilations; -#endif + (*this) << mode; } private: @@ -313,6 +319,8 @@ class FusedConv : public onnxruntime::rocm::Conv { auto ret = miopenCompileFusionPlan(handle, fusion->plan); if (miopenStatusSuccess == ret) { fusion->compiled_on.insert(handle); + } else { + return ret; } return miopenStatusSuccess; } From e8ac97c8d864eb3088cf87732b5fc0a7d7df495f Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Mon, 8 Jan 2024 17:19:58 +0000 Subject: [PATCH 35/45] Move Windows GPU training job to A10 (#19041) ### Description 1. 
Update sm to 86 ### Motivation and Context We have more A10 quota then T4 and Nvidia AXX could be partitioned --- .../ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index fdb9238071c9e..eee38ac04b355 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -59,15 +59,14 @@ stages: BuildConfig: 'RelWithDebInfo' EnvSetupScript: setup_env_cuda.bat buildArch: x64 - additionalBuildFlags: --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --skip_onnx_tests --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 + additionalBuildFlags: --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --skip_onnx_tests --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} ORT_EP_NAME: CUDA WITH_CACHE: true - # Some unit tests crash on A10 GPUs. So this job still needs to use T4. - MachinePool: onnxruntime-Win2022-GPU-T4 + MachinePool: onnxruntime-Win2022-GPU-A10 isTraining: true - stage: dml From 52e560144978d453c73198c93043c3f1b8a30d04 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 8 Jan 2024 12:44:12 -0800 Subject: [PATCH 36/45] [QNN Nuget Pipeline] Build with ML ops and detect ORT version (#19024) ### Description - Removes `--disable_ml_ops` build flag - Automatically detects ORT version from VERSION file via `templates/set-version-number-variables-step.yml`. We will no longer need to create a commit to update ORT versions. ### Motivation and Context - A new unit test caused failures in the QNN Nuget pipeline because it did not enable ml ops. 
- Automate ORT version specification --- .../qnn-ep-nuget-packaging-pipeline.yml | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index d9aff36c4ad34..f6fcbd08ff03a 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -9,11 +9,6 @@ parameters: type: string default: qnn-v2.17.0.231124_win -- name: ort_package_version - displayName: OnnxRuntime Nuget package version - type: string - default: 1.15.0 - - name: build_config displayName: Build Configuration type: string @@ -47,7 +42,7 @@ jobs: buildArch: x64 setVcvars: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' - commonBuildArgs: '--compile_no_warning_as_error --disable_ml_ops --build_dir $(Build.BinariesDirectory)\Windows --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --use_qnn --qnn_home ${{parameters.qnn_sdk_path_win}}' + commonBuildArgs: '--compile_no_warning_as_error --build_dir $(Build.BinariesDirectory)\Windows --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --use_qnn --qnn_home ${{parameters.qnn_sdk_path_win}}' steps: - template: templates/set-version-number-variables-step.yml @@ -90,7 +85,7 @@ jobs: displayName: 'Generating nuspec for the native Nuget package x64' inputs: script: | - python "$(Build.SourcesDirectory)\tools\nuget\generate_nuspec_for_native_nuget.py" --package_version ${{ parameters.ort_package_version }} --package_name Microsoft.ML.OnnxRuntime.QNN --target_architecture x64 --build_config ${{ parameters.build_config }} --native_build_path=$(Build.BinariesDirectory)\Windows\${{ parameters.build_config }}\${{ parameters.build_config }} --packages_path $(Build.BinariesDirectory)\Windows\packages --ort_build_path $(Build.BinariesDirectory)\Windows --sources_path $(Build.SourcesDirectory) --commit_id $(OnnxRuntimeGitCommitHash) --is_release_build ${{ parameters.IsReleaseBuild }} --sdk_info ${{ parameters.qnn_sdk_info }} + python "$(Build.SourcesDirectory)\tools\nuget\generate_nuspec_for_native_nuget.py" --package_version $(OnnxRuntimeVersion) --package_name Microsoft.ML.OnnxRuntime.QNN --target_architecture x64 --build_config ${{ parameters.build_config }} --native_build_path=$(Build.BinariesDirectory)\Windows\${{ parameters.build_config }}\${{ parameters.build_config }} --packages_path $(Build.BinariesDirectory)\Windows\packages --ort_build_path $(Build.BinariesDirectory)\Windows --sources_path $(Build.SourcesDirectory) --commit_id $(OnnxRuntimeGitCommitHash) --is_release_build ${{ parameters.IsReleaseBuild }} --sdk_info ${{ parameters.qnn_sdk_info }} cd $(Build.BinariesDirectory)\Windows\${{ parameters.build_config }}\${{ parameters.build_config }} nuget pack NativeNuget.nuspec mkdir $(Build.ArtifactStagingDirectory)\x64 @@ -130,7 +125,7 @@ jobs: displayName: 'Generate CMake Configuration for arm64' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--update --arm64 --disable_ml_ops --build_dir $(Build.BinariesDirectory)\Win_arm64 --skip_submodule_sync --skip_tests --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --use_qnn --qnn_home ${{parameters.qnn_sdk_path_win}}' + arguments: '--update --arm64 
--build_dir $(Build.BinariesDirectory)\Win_arm64 --skip_submodule_sync --skip_tests --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --use_qnn --qnn_home ${{parameters.qnn_sdk_path_win}}' - task: VSBuild@1 displayName: 'Build onnxruntime arm64' @@ -178,7 +173,7 @@ jobs: displayName: 'Generating nuspec for the native Nuget package arm64' inputs: script: | - python "$(Build.SourcesDirectory)\tools\nuget\generate_nuspec_for_native_nuget.py" --package_version ${{ parameters.ort_package_version }} --package_name Microsoft.ML.OnnxRuntime.QNN --target_architecture arm64 --build_config ${{ parameters.build_config }} --native_build_path=$(Build.BinariesDirectory)\Win_arm64\${{ parameters.build_config }}\${{ parameters.build_config }} --packages_path $(Build.BinariesDirectory)\Win_arm64\packages --ort_build_path $(Build.BinariesDirectory)\Win_arm64 --sources_path $(Build.SourcesDirectory) --commit_id $(OnnxRuntimeGitCommitHash) --is_release_build ${{ parameters.IsReleaseBuild }} --sdk_info ${{ parameters.qnn_sdk_info }} + python "$(Build.SourcesDirectory)\tools\nuget\generate_nuspec_for_native_nuget.py" --package_version $(OnnxRuntimeVersion) --package_name Microsoft.ML.OnnxRuntime.QNN --target_architecture arm64 --build_config ${{ parameters.build_config }} --native_build_path=$(Build.BinariesDirectory)\Win_arm64\${{ parameters.build_config }}\${{ parameters.build_config }} --packages_path $(Build.BinariesDirectory)\Win_arm64\packages --ort_build_path $(Build.BinariesDirectory)\Win_arm64 --sources_path $(Build.SourcesDirectory) --commit_id $(OnnxRuntimeGitCommitHash) --is_release_build ${{ parameters.IsReleaseBuild }} --sdk_info ${{ parameters.qnn_sdk_info }} cd $(Build.BinariesDirectory)\Win_arm64\${{ parameters.build_config }}\${{ parameters.build_config }} nuget pack NativeNuget.nuspec mkdir $(Build.ArtifactStagingDirectory)\arm64 From 99a8400e903ab330e3067629d9aa4e23ce82cf12 Mon Sep 17 00:00:00 2001 From: zesongw Date: Tue, 9 Jan 2024 09:16:52 +0800 Subject: [PATCH 37/45] [WebNN EP] Fall back resize nearest mode for WebNN CPU backend (#19039) WebNN CPU backend only supports linear mode. Fall back for this case. --- .../webnn/builders/impl/resize_op_builder.cc | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index ea9fc379ee23f..186d1e7c1035a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -30,7 +30,7 @@ class ResizeOpBuilder : public BaseOpBuilder { // Operator support related. private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const logging::Logger& logger) const override; // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. // We only support Resize opset 11+ here. 
@@ -161,7 +161,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, + const WebnnDeviceType device_type, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -181,9 +181,18 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers const auto mode = helper.Get("mode", "nearest"); bool is_linear_resize = mode == "linear"; bool is_nearest_resize = mode == "nearest"; - if (!is_linear_resize && !is_nearest_resize) { - LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode; - return false; + // WebNN CPU backend only supports "linear" mode. + // WebNN GPU backend only supports "linear" and "nearest" modes. + if (device_type == WebnnDeviceType::CPU) { + if (!is_linear_resize) { + LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode << " for CPU backend."; + return false; + } + } else { + if (!is_linear_resize && !is_nearest_resize) { + LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode << " for GPU backend."; + return false; + } } const auto exclude_outside = helper.Get("exclude_outside", 0); From 975a315cd70aac54a6c4ff8b7d4e0a76d25666de Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com> Date: Mon, 8 Jan 2024 17:49:19 -0800 Subject: [PATCH 38/45] Fix x86 build error in GraphDescBuilder.cpp affecting packaging pipeline (#19045) ### Description This addresses a 32 bit build error affecting the packaging pipeline ### Motivation and Context --- .../providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index adb4fd131119f..c6a15e76f4736 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -360,7 +360,7 @@ namespace Dml::GraphDescBuilder // The tensor description's size should be no larger than the constant input unless it was rounded to // the required alignment. 
assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); - size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), tensorDesc->totalTensorSizeInBytes); + size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); auto data = static_cast(constantInput->GetData()); std::vector tensorData(data, data + minimumConstantSize); From a8bb1df331e56e3a65106578b6475d89d17b27c5 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 8 Jan 2024 17:58:38 -0800 Subject: [PATCH 39/45] [js/webgpu] fix heap access > 2GB (#19010) --- onnxruntime/core/providers/js/js_kernel.h | 1 + .../core/providers/js/operators/conv.h | 12 ++--- .../providers/js/operators/conv_transpose.h | 20 ++++---- onnxruntime/core/providers/js/operators/pad.h | 2 +- .../core/providers/js/operators/reduce.h | 46 +++++++++---------- .../core/providers/js/operators/resize.h | 2 +- .../core/providers/js/operators/slice.h | 6 +-- .../core/providers/js/operators/split.h | 2 +- .../core/providers/js/operators/transpose.h | 2 +- 9 files changed, 47 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 5c2d1f0b881ba..b850bea4bc275 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -67,6 +67,7 @@ namespace js { float value; \ ORT_ENFORCE(info.GetAttr(#attr_name, &value));, \ , ({#attr_name : $1}), static_cast(value)) +#define JSEP_HEAP_PTR(ptr) reinterpret_cast(ptr) // TODO: // class JsMultiProgramKernel : public OpKernel { /* TBD */ }; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 5c0fbf93a4004..98a530c6b77f6 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -54,13 +54,13 @@ class ConvBase : public JsKernel { static_cast(conv_attrs_.group), static_cast(kernel_shape_0), static_cast(local_pads.size()), - reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, + JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_), + JSEP_HEAP_PTR(&w_is_const_), conv_attrs_.activation.c_str(), activation_params.size(), - reinterpret_cast(activation_params_ptr) >> 2); + JSEP_HEAP_PTR(activation_params_ptr) >> 2); } else { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $11 ? "NHWC" : "NCHW", @@ -81,14 +81,14 @@ class ConvBase : public JsKernel { static_cast(kernel_shape_0), static_cast(kernel_shape_1), static_cast(local_pads.size()), - reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, + JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(conv_attrs_.strides.size() > 1 ? 
conv_attrs_.strides[1] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_), + JSEP_HEAP_PTR(&w_is_const_), conv_attrs_.activation.c_str(), activation_params.size(), - reinterpret_cast(activation_params_ptr) >> 2); + JSEP_HEAP_PTR(activation_params_ptr) >> 2); } } diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 5d30dc851e00f..353a946e95c21 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -64,11 +64,11 @@ class ConvTranspose : public JsKernel { static_cast(pads_1), static_cast(strides), static_cast(channels_last), - reinterpret_cast(&w_is_const_), + JSEP_HEAP_PTR(&w_is_const_), gsl::narrow_cast(local_output_padding.size()), - reinterpret_cast(local_output_padding_ptr) >> 2, + JSEP_HEAP_PTR(local_output_padding_ptr) >> 2, gsl::narrow_cast(local_output_shape.size()), - reinterpret_cast(local_output_shape_ptr) >> 2, + JSEP_HEAP_PTR(local_output_shape_ptr) >> 2, conv_transpose_attrs_.activation.c_str()); } else { constexpr size_t pads_vec_size = 4; @@ -114,17 +114,17 @@ class ConvTranspose : public JsKernel { "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), - reinterpret_cast(local_dilations.data()) >> 2, + JSEP_HEAP_PTR(local_dilations.data()) >> 2, static_cast(conv_transpose_attrs_.group), - reinterpret_cast(local_kernel_shape.data()) >> 2, - reinterpret_cast(local_pads.data()) >> 2, - reinterpret_cast(local_strides.data()) >> 2, + JSEP_HEAP_PTR(local_kernel_shape.data()) >> 2, + JSEP_HEAP_PTR(local_pads.data()) >> 2, + JSEP_HEAP_PTR(local_strides.data()) >> 2, static_cast(channels_last), - reinterpret_cast(&w_is_const_), + JSEP_HEAP_PTR(&w_is_const_), gsl::narrow_cast(local_output_padding.size()), - reinterpret_cast(local_output_padding_ptr) >> 2, + JSEP_HEAP_PTR(local_output_padding_ptr) >> 2, gsl::narrow_cast(local_output_shape.size()), - reinterpret_cast(local_output_shape_ptr) >> 2, + JSEP_HEAP_PTR(local_output_shape_ptr) >> 2, conv_transpose_attrs_.activation.c_str()); } } diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h index 19168f40b4722..bf808be949cf8 100644 --- a/onnxruntime/core/providers/js/operators/pad.h +++ b/onnxruntime/core/providers/js/operators/pad.h @@ -26,7 +26,7 @@ class Pad : public JsKernel, public PadBase { static_cast(mode_), static_cast(value_), gsl::narrow_cast(pads.size()), - reinterpret_cast((pads.size() > 0) ? pads.data() : nullptr) >> 2); + JSEP_HEAP_PTR((pads.size() > 0) ? 
pads.data() : nullptr) >> 2); } }; diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index a5a4aa834c2ca..95c4f2bec230d 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -8,29 +8,29 @@ namespace onnxruntime { namespace js { -#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ - template \ - class ReduceKernel : public JsKernel, public ReduceKernelBase { \ - public: \ - using ReduceKernelBase::axes_; \ - using ReduceKernelBase::noop_with_empty_axes_; \ - using ReduceKernelBase::keepdims_; \ - ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ - std::vector axes(axes_.size()); \ - if (axes_.size() > 0) { \ - std::transform(axes_.begin(), axes_.end(), axes.begin(), \ - [](int64_t axis) { return gsl::narrow_cast(axis); }); \ - } \ - JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ - "keepDims" : !!$1, \ - "noopWithEmptyAxes" : !!$2, \ - "axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \ - }), \ - static_cast(keepdims_), \ - static_cast(noop_with_empty_axes_), \ - gsl::narrow_cast(axes.size()), \ - reinterpret_cast((axes.size() > 0) ? axes.data() : nullptr) >> 2); \ - } \ +#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ + template \ + class ReduceKernel : public JsKernel, public ReduceKernelBase { \ + public: \ + using ReduceKernelBase::axes_; \ + using ReduceKernelBase::noop_with_empty_axes_; \ + using ReduceKernelBase::keepdims_; \ + ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ + std::vector axes(axes_.size()); \ + if (axes_.size() > 0) { \ + std::transform(axes_.begin(), axes_.end(), axes.begin(), \ + [](int64_t axis) { return gsl::narrow_cast(axis); }); \ + } \ + JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ + "keepDims" : !!$1, \ + "noopWithEmptyAxes" : !!$2, \ + "axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \ + }), \ + static_cast(keepdims_), \ + static_cast(noop_with_empty_axes_), \ + gsl::narrow_cast(axes.size()), \ + JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2); \ + } \ }; JSEP_DEFINE_REDUCE_KERNEL(ReduceMax); diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h index 65854222ba988..4b1c288ae3015 100644 --- a/onnxruntime/core/providers/js/operators/resize.h +++ b/onnxruntime/core/providers/js/operators/resize.h @@ -34,7 +34,7 @@ class Resize : public JsKernel, public UpsampleBase { }), static_cast(antialias_), gsl::narrow_cast(axes.size()), - reinterpret_cast((axes.size() > 0) ? axes.data() : nullptr) >> 2, + JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2, resize_coordinate_transformation_mode.c_str(), static_cast(cubic_coeff_a_), static_cast(exclude_outside_), diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h index 6792997025d65..989adabf029a5 100644 --- a/onnxruntime/core/providers/js/operators/slice.h +++ b/onnxruntime/core/providers/js/operators/slice.h @@ -24,11 +24,11 @@ class Slice : public JsKernel, public SliceBase { "ends" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : [], "axes" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : []}), gsl::narrow_cast(starts.size()), - reinterpret_cast((starts.size() > 0) ? starts.data() : nullptr) >> 2, + JSEP_HEAP_PTR((starts.size() > 0) ? starts.data() : nullptr) >> 2, gsl::narrow_cast(ends.size()), - reinterpret_cast((ends.size() > 0) ? 
ends.data() : nullptr) >> 2, + JSEP_HEAP_PTR((ends.size() > 0) ? ends.data() : nullptr) >> 2, gsl::narrow_cast(axes.size()), - reinterpret_cast((axes.size() > 0) ? axes.data() : nullptr) >> 2); + JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2); } }; diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index cfacc1aa6a363..1c1874e5aa98e 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -53,7 +53,7 @@ class Split : public JsKernel, public SplitBase { static_cast(axis_), static_cast(num_outputs_), gsl::narrow_cast(split_sizes.size()), - reinterpret_cast((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2); + JSEP_HEAP_PTR((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2); } }; diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index 311badbde0d11..dae442b9f5a13 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -27,7 +27,7 @@ class Transpose final : public JsKernel, public TransposeBase { gsl::narrow_cast(perm_specified_ ? perm_.size() : 0), // $2: index to HEAP32 of the first int32 element. calculated from right shift memory // address by 2 - reinterpret_cast(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2); + JSEP_HEAP_PTR(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2); } }; From 8f024b739439c521cd19fe5a9830d4015de99bd7 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 9 Jan 2024 10:16:25 +0800 Subject: [PATCH 40/45] [js/webgpu] Support uniforms for layer-norm (#18755) --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 83 ++++++++++--------- 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 8e1ec782079be..06c3c6c196501 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -17,7 +17,7 @@ import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; -import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; +import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; @@ -83,7 +83,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]], - ['LayerNormalization', [layerNorm, parseLayerNormAttributes]], + ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], ['Less', [binaryOps.less]], ['LessOrEqual', [binaryOps.lessOrEqual]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 8a9eeecf2c68d..bc446079faf8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -4,12 +4,11 @@ import {DataType} from '../../../wasm-common'; import 
{TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType,} from './common'; -export interface LayerNormAttributes extends AttributeWithCacheKey { +interface LayerNormAttributes { axis: number; epsilon: number; } @@ -39,7 +38,7 @@ const createLayerNormProgramInfo = Got scale size of ${scaleSize} and bias size of ${biasSize}`); } - const meanInvStdDevDim = []; + const meanInvStdDevDim: number[] = []; for (let i = 0; i < xShape.length; ++i) { if (i < axis) { meanInvStdDevDim.push(xShape[i]); @@ -47,50 +46,57 @@ const createLayerNormProgramInfo = meanInvStdDevDim.push(1); } } - const components = getMaxComponents(normSize); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('scale', scale.dataType, scale.dims, components), + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: normCount}, {type: 'float32', data: normSize}, + {type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon} ]; if (bias) { - variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + inputDependencies.push('type'); } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - const hasMeanDataOutput = outputCount > 1; const hasInvStdOutput = outputCount > 2; - if (hasMeanDataOutput) { - variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const normSize: f32 = ${normSize}; - const normSizeVectorized: u32 = ${normSize / components}; - const epsilon: f32 = ${attributes.epsilon}; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('scale', scale.dataType, scale.dims, components), + ]; + if (bias) { + variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanDataOutput) { + variables.push(outputVariable('mean_data_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } - ${shaderHelper.declareVariables(...variables)} + const uniforms: UniformsArrayType = [ + {name: 'norm_count', type: 'u32'}, {name: 'norm_size', type: 'f32'}, + {name: 'norm_size_vectorized', type: 'u32'}, {name: 'epsilon', type: 'f32'} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart()} - 
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)} - let offset = global_idx * normSizeVectorized; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} + let offset = global_idx * uniforms.norm_size_vectorized; var meanVector = ${fillVector('f32', components)}; var meanSquareVector = ${fillVector('f32', components)}; - for (var h: u32 = 0u; h < normSizeVectorized; h++) { + for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; meanVector += value; meanSquareVector += value * value; } - let mean = ${sumVector('meanVector', components)} / normSize; - let meanSquare = sqrt(${sumVector('meanSquareVector', components)} - / normSize - mean * mean + epsilon); + let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; + let meanSquare = sqrt(${sumVector('meanSquareVector', components)} + / uniforms.norm_size - mean * mean + uniforms.epsilon); - for (var j: u32 = 0; j < normSizeVectorized; j++) { + for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale @@ -98,9 +104,10 @@ const createLayerNormProgramInfo = ); } - ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'invStdOutput[global_idx] = 1 / meanSquare' : ''}; + ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = 1 / meanSquare' : ''}; }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (hasMeanDataOutput) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -111,15 +118,13 @@ const createLayerNormProgramInfo = return { name: 'LayerNormalization', - shaderCache: {hint: `${attributes.cacheKey}|${outputCount}|${inputs.length}`}, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}}), + shaderCache: {hint: `${components};${outputCount}`, inputDependencies}, + getRunData: () => + ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}, programUniforms}), getShaderSource, }; }; -export const parseLayerNormAttributes = (attributes: LayerNormAttributes): LayerNormAttributes => - createAttributeWithCacheKey({axis: attributes.axis, epsilon: attributes.epsilon}); - export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => { validateInputs(context.inputs); context.compute(createLayerNormProgramInfo(context.inputs, attributes, context.outputCount)); From 68c29ece23821b1d2b73ac55c2a4266c72865219 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 8 Jan 2024 19:46:33 -0800 Subject: [PATCH 41/45] In a Linux or Android build check if the compiler support bfloat16 and float16 (#18813) ### Description Restrict clang version because we have an upcoming change that requires clang version >=16 , which will mainly affect Android build. 
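
Below is a minimal, standalone sketch of the compiler probe this change introduces (illustration only, not part of the patch; the project name and error wording are invented here, while the flag spellings follow the diff below):

```cmake
# Probe an AArch64 toolchain for ARMv8.2 bfloat16 / fp16 code generation support.
cmake_minimum_required(VERSION 3.18)
project(arch_feature_probe CXX)
include(CheckCXXCompilerFlag)

if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
  check_cxx_compiler_flag(-march=armv8.2-a+bf16 HAS_ARM64_BFLOAT16)
  check_cxx_compiler_flag(-march=armv8.2-a+fp16 HAS_ARM64_FLOAT16)
  if(NOT HAS_ARM64_BFLOAT16 OR NOT HAS_ARM64_FLOAT16)
    # Older toolchains (e.g. clang < 16 in some Android NDKs) fail this probe.
    message(FATAL_ERROR "The compiler does not support ARMv8.2 bfloat16/float16 code generation.")
  endif()
endif()
```
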
--- cmake/CMakeLists.txt | 27 +++++++++++++++++-------- cmake/adjust_global_compile_flags.cmake | 25 +++++++++++++++++++++++ 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 34355fb0fd936..0f57258dca706 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -354,13 +354,7 @@ if (onnxruntime_USE_ROCM) endif() endif() -if (APPLE) - if (NOT CMAKE_OSX_ARCHITECTURES) - message("Building ONNX Runtime for ${CMAKE_HOST_SYSTEM_PROCESSOR}") - endif() -elseif (NOT WIN32 AND NOT APPLE) - message("Building ONNX Runtime for ${CMAKE_SYSTEM_PROCESSOR}") -endif() + # Single output director for all binaries set(RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin CACHE PATH "Single output directory for all binaries.") @@ -493,6 +487,14 @@ endif() include(adjust_global_compile_flags.cmake) +if (APPLE) + if (NOT CMAKE_OSX_ARCHITECTURES) + message("Building ONNX Runtime for ${CMAKE_HOST_SYSTEM_PROCESSOR} CPU ARCH") + endif() +elseif (NOT WIN32 AND NOT APPLE) + message("Building ONNX Runtime for ${onnxruntime_target_platform} CPU ARCH") +endif() + # We need to link with libatomic on systems that do not have built-in atomics, or # don't have built-in support for 8 byte atomics # Derived from https://github.com/protocolbuffers/protobuf/blob/master/cmake/CMakeLists.txt @@ -639,7 +641,16 @@ else() check_cxx_compiler_flag(-Wunused-variable HAS_UNUSED_VARIABLE) check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) check_function_exists(reallocarray HAS_REALLOCARRAY) - + if (NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_target_platform STREQUAL "aarch64") + check_cxx_compiler_flag(-march=armv8.2-a+bf16 HAS_ARM64_BFLOAT16) + if(NOT HAS_ARM64_BFLOAT16) + message(FATAL_ERROR "The compiler doesn't support BFLOAT16!!!") + endif() + check_cxx_compiler_flag(-march=armv8.2-a+fp16 HAS_ARM64_FLOAT16) + if(NOT HAS_ARM64_FLOAT16) + message(FATAL_ERROR "The compiler doesn't support FLOAT16!!!") + endif() + endif() if (HAS_TAUTOLOGICAL_POINTER_COMPARE) #we may have extra null pointer checkings in debug build, it's not an issue list(APPEND ORT_WARNING_FLAGS -Wno-tautological-pointer-compare) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index e825bfeaea952..9f00c873715f4 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -300,6 +300,31 @@ if (MSVC) endif() else() if (NOT APPLE) + #XXX: Sometimes the value of CMAKE_SYSTEM_PROCESSOR is set but it's wrong. For example, if you run an armv7 docker + #image on an aarch64 machine with an aarch64 Ubuntu host OS, in the docker instance cmake may still report + # CMAKE_SYSTEM_PROCESSOR as aarch64 by default. Given compiling this code may need more than 2GB memory, we do not + # support compiling for ARM32 natively(only support cross-compiling), we will ignore this issue for now. + if(NOT CMAKE_SYSTEM_PROCESSOR) + message(WARNING "CMAKE_SYSTEM_PROCESSOR is not set. 
Please set it in your toolchain cmake file.") + # Try to detect it + if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" OR "${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") + execute_process( + COMMAND "${CMAKE_C_COMPILER}" -dumpmachine + OUTPUT_VARIABLE GCC_DUMP_MACHINE_OUT OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_VARIABLE _err + RESULT_VARIABLE _res + ) + if(NOT _res EQUAL 0) + message(SEND_ERROR "Failed to run 'gcc -dumpmachine':\n ${_res}") + endif() + string(REPLACE "-" ";" GCC_DUMP_MACHINE_OUT_LIST "${GCC_DUMP_MACHINE_OUT}") + list(LENGTH GCC_DUMP_MACHINE_OUT_LIST GCC_TRIPLET_LEN) + if(GCC_TRIPLET_LEN EQUAL 4) + list(GET GCC_DUMP_MACHINE_OUT_LIST 0 CMAKE_SYSTEM_PROCESSOR) + message("Setting CMAKE_SYSTEM_PROCESSOR to ${CMAKE_SYSTEM_PROCESSOR}") + endif() + endif() + endif() set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) endif() if (onnxruntime_BUILD_FOR_NATIVE_MACHINE) From eb35896ede6e77bf4a453b9e7314e728e40f96ba Mon Sep 17 00:00:00 2001 From: zesongw Date: Tue, 9 Jan 2024 14:02:44 +0800 Subject: [PATCH 42/45] [WebNN EP] Update WebNN normalization ops (#18817) Use batchNormalization, layerNormalization and instanceNormalization instead of meanVarianceNormalization to implement normalization Ops. The spec of meanVarianceNormalization has been deleted. Remove groupNormalization. --- .../core/providers/webnn/builders/helper.h | 7 +- .../builders/impl/normalization_op_builder.cc | 141 +++++++----------- .../webnn/builders/op_builder_factory.cc | 1 - 3 files changed, 57 insertions(+), 92 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 8b8b85339a87c..5aec81af15761 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -139,7 +139,7 @@ static const InlinedHashMap op_map = { {"ArgMax", {"argMax", false}}, {"ArgMin", {"argMin", false}}, {"AveragePool", {"averagePool2d", true}}, - {"BatchNormalization", {"meanVarianceNormalization", false}}, + {"BatchNormalization", {"batchNormalization", false}}, {"Cast", {"cast", false}}, {"Ceil", {"ceil", true}}, {"Clip", {"clamp", true}}, @@ -162,12 +162,11 @@ static const InlinedHashMap op_map = { {"GlobalLpPool", {"l2Pool2d", false}}, {"Greater", {"greater", false}}, {"GreaterOrEqual", {"greaterOrEqual", false}}, - {"GroupNormalization", {"meanVarianceNormalization", false}}, {"HardSigmoid", {"hardSigmoid", false}}, {"HardSwish", {"hardSwish", true}}, {"Identity", {"identity", false}}, - {"InstanceNormalization", {"meanVarianceNormalization", false}}, - {"LayerNormalization", {"meanVarianceNormalization", false}}, + {"InstanceNormalization", {"instanceNormalization", false}}, + {"LayerNormalization", {"layerNormalization", false}}, {"LeakyRelu", {"leakyRelu", true}}, {"Less", {"lesser", false}}, {"LessOrEqual", {"lesserOrEqual", false}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 756a838cc0c3e..4d2470dfe7deb 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -27,8 +27,6 @@ class NormalizationOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; -// All normalization are based on layout NCHW. -// TODO: add support for NHWC. 
Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { @@ -61,49 +59,13 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape."); } - std::vector new_scale_shape; - if (scale_size < rank) { - if (op_type == "BatchNormalization") { - scale_shape.insert(scale_shape.begin(), 1); - scale_shape.insert(scale_shape.end(), rank - 2, 1); - } else if (op_type == "LayerNormalization") { - // Align right with leading ones. - scale_shape.insert(scale_shape.begin(), rank - scale_size, 1); - } else if (op_type == "InstanceNormalization") { - // Insert ones before and after the channel dimension. - scale_shape.insert(scale_shape.begin(), 1); - ORT_RETURN_IF(scale_size != 1 || rank < 2, - "The scale size should be 1 and rank should be at least 2 for InstanceNorm."); - scale_shape.insert(scale_shape.end(), rank - scale_size - 1, 1); - } else if (op_type == "GroupNormalization") { - // The input will be reshaped to 3D later. So just insert ones before the channel and after. - scale_shape.insert(scale_shape.begin(), 1); - scale_shape.insert(scale_shape.end(), 1); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type); - } + emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); + options.set("scale", scale); - std::transform(scale_shape.cbegin(), scale_shape.cend(), - std::back_inserter(new_scale_shape), - [](int64_t dim) -> uint32_t { return SafeInt(dim); }); - emscripten::val reshape_scale = model_builder.GetOperand(input_defs[1]->Name()); - emscripten::val reshape_output_scale = - model_builder.GetBuilder().call("reshape", reshape_scale, emscripten::val::array(new_scale_shape)); - options.set("scale", reshape_output_scale); - - if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) { - // Bias input exists, and bias's shape is the same as scale's shape. - emscripten::val reshape_bias = model_builder.GetOperand(input_defs[2]->Name()); - emscripten::val reshape_output_bias = - model_builder.GetBuilder().call("reshape", reshape_bias, emscripten::val::array(new_scale_shape)); - options.set("bias", reshape_output_bias); - } - } else { - options.set("scale", model_builder.GetOperand(input_defs[1]->Name())); - if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) { - // Bias input exists, and bias's shape is the same as scale's shape. - options.set("bias", model_builder.GetOperand(input_defs[2]->Name())); - } + if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) { + // Bias input exists, and bias's shape is the same as scale's shape. + emscripten::val bias = model_builder.GetOperand(input_defs[2]->Name()); + options.set("bias", bias); } NodeAttrHelper helper(node); @@ -114,56 +76,62 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs."); emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name()); emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name()); - // Enlarge 1-D mean and variance to new scale shape. 
- emscripten::val reshape_mean = - model_builder.GetBuilder().call("reshape", mean, emscripten::val::array(new_scale_shape)); - emscripten::val reshape_variance = - model_builder.GetBuilder().call("reshape", variance, emscripten::val::array(new_scale_shape)); - - std::vector axes = {0}; - for (uint32_t i = 2; i < rank; i++) { - axes.push_back(i); + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + options.set("axis", rank - 1); } - - options.set("axes", emscripten::val::array(axes)); - options.set("mean", reshape_mean); - options.set("variance", reshape_variance); - output = model_builder.GetBuilder().call("meanVarianceNormalization", input, options); + output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); } else if (op_type == "LayerNormalization") { int64_t axis = helper.Get("axis", -1); axis = HandleNegativeAxis(axis, rank); std::vector axes(rank - SafeInt(axis)); - std::iota(axes.begin(), axes.end(), axis); + if (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) { + std::iota(axes.begin(), axes.end(), axis - 1); + } else { + std::iota(axes.begin(), axes.end(), axis); + } options.set("axes", emscripten::val::array(axes)); - output = model_builder.GetBuilder().call("meanVarianceNormalization", input, options); + output = model_builder.GetBuilder().call("layerNormalization", input, options); } else if (op_type == "InstanceNormalization") { - std::vector axes; - for (uint32_t i = 2; i < rank; i++) { - axes.emplace_back(i); + // WebNN spec only supports 4D input for instanceNormalization. + // Supports 3D input by prepending 1 size dimension. + // For models with dimensions greater than 4, they will be reshaped into 4D. + constexpr size_t webnn_shape_rank = 4; + if (input_shape.size() != webnn_shape_rank) { + std::vector new_shape; + new_shape.reserve(std::max(input_shape.size(), webnn_shape_rank)); + std::transform(input_shape.begin(), input_shape.end(), + std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + + size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3; + ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank; + auto insertion_point = new_shape.begin() + insertion_offset; + if (input_shape.size() < webnn_shape_rank) { + // Pad the shape with extra 1's to satisfy WebNN v1's rank requirements. + new_shape.insert(insertion_point, -excess_rank, 1); + } else { + // Fold the extra range to fit within WebNN v1's rank requirements. + uint32_t sum = std::accumulate( + insertion_point, insertion_point + excess_rank + 1, 1, std::multiplies()); + new_shape.erase(insertion_point, insertion_point + excess_rank); + *insertion_point = sum; + } + input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + } + + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + options.set("layout", emscripten::val("nhwc")); + } + output = model_builder.GetBuilder().call("instanceNormalization", input, options); + // Reshape back to the original output shape for 3D input. 
+ if (input_shape.size() != 4) { + std::vector output_shape; + std::transform(input_shape.begin(), input_shape.end(), + std::back_inserter(output_shape), + [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + output = model_builder.GetBuilder().call( + "reshape", output, emscripten::val::array(output_shape)); } - options.set("axes", emscripten::val::array(axes)); - output = model_builder.GetBuilder().call("meanVarianceNormalization", input, options); - } else if (op_type == "GroupNormalization") { - ORT_RETURN_IF_NOT(helper.HasAttr("num_groups"), "GroupNormalization num_group must be provided."); - int32_t group_count = helper.Get("num_groups", -1); - std::vector orig_shape, new_shape; - std::transform(input_shape.cbegin(), input_shape.cend(), - std::back_inserter(orig_shape), - [](int64_t dim) -> uint32_t { return SafeInt(dim); }); - // Add N and Group. - ORT_RETURN_IF_NOT(rank >= 2, "Input for GroupNormalization cannot be a scalar or 1D"); - new_shape.emplace_back(SafeInt(input_shape[0])); - new_shape.emplace_back(SafeInt(group_count)); - - ORT_RETURN_IF_NOT(group_count > 0 && input_shape[1] % group_count == 0, - "GroupNormalization num_group must be divisible by group."); - new_shape.emplace_back(SafeInt(std::reduce(input_shape.begin() + 2, input_shape.end(), - input_shape[1] / group_count, std::multiplies()))); - // Input will be reshaped to (N, group count, channels per group x D1 x D2 ... Dn) and recovered after normalization. - options.set("axes", emscripten::val::array(std::vector{2})); - output = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); - output = model_builder.GetBuilder().call("meanVarianceNormalization", output, options); - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(orig_shape)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type); } @@ -214,7 +182,6 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat constexpr static std::string_view op_types[] = { "BatchNormalization", - "GroupNormalization", "InstanceNormalization", "LayerNormalization", }; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 463317a4dafda..613771eda71fe 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -111,7 +111,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { { // Normalization CreateNormalizationOpBuilder("BatchNormalization", op_registrations); - CreateNormalizationOpBuilder("GroupNormalization", op_registrations); CreateNormalizationOpBuilder("InstanceNormalization", op_registrations); CreateNormalizationOpBuilder("LayerNormalization", op_registrations); } From 7cb8b20db2d329cf67e170293b2d2c81213e6100 Mon Sep 17 00:00:00 2001 From: pengwa Date: Tue, 9 Jan 2024 20:05:34 +0800 Subject: [PATCH 43/45] Remove mem consuming test case to unblock running ci on lower-end gpu (#19059) ### Description ### Motivation and Context --- orttraining/orttraining/test/gradient/gradient_ops_test.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 6fb42dd59b6a0..feca94ae27c13 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc 
@@ -2218,12 +2218,6 @@ TEST(GradientUtilsTest, InPlaceAccumulatorV2_GPU) { {3072, 768}, {4096, 768}, {8192, 768}, - {16384, 768}, - {32768, 768}, - {65536, 768}, - {131072, 768}, - {250002, 768}, - {500004, 768}, }; for (const auto& test_dim : test_dims) { From ab897a4a4064e8259ed952cbd33c35d3ddd5f370 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 9 Jan 2024 07:45:03 -0800 Subject: [PATCH 44/45] Remove Windows ARM32 from nuget packaging pipelines (#19049) ### Description 1. Remove Windows ARM32 from nuget packaging pipelines 2. Add missing component-governance-component-detection-steps.yml to some build jobs. ### Motivation and Context Stop supporting Windows ARM32 to align with [Windows's support policy](https://learn.microsoft.com/en-us/windows/arm/arm32-to-arm64). Users who need this feature still can build the DLLs from source. However, later on we will remove that support too. --- ...anch.Nuget-WindowsAI-Pipeline.Official.yml | 66 +------------- .../c-api-noopenmp-packaging-pipelines.yml | 85 +++++++++---------- .../azure-pipelines/templates/c-api-cpu.yml | 26 ------ .../azure-pipelines/templates/py-linux.yml | 4 + 4 files changed, 49 insertions(+), 132 deletions(-) diff --git a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml index b9de1b79e1d51..67f9d8b0ce392 100644 --- a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml +++ b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml @@ -53,10 +53,6 @@ extends: BuildArch: x86 PythonPackageName: pythonx86 - - template: .pipelines/windowsai-steps.yml@self - parameters: - BuildArch: arm - - template: .pipelines/windowsai-steps.yml@self parameters: BuildArch: arm64 @@ -72,11 +68,6 @@ extends: PythonPackageName: pythonx86 Runtime: static - - template: .pipelines/windowsai-steps.yml@self - parameters: - BuildArch: arm - Runtime: static - - template: .pipelines/windowsai-steps.yml@self parameters: BuildArch: arm64 @@ -94,11 +85,9 @@ extends: dependsOn: - Windows_Packaging_x64_dynamic - Windows_Packaging_x86_dynamic - - Windows_Packaging_arm_dynamic - Windows_Packaging_arm64_dynamic - Windows_Packaging_x64_static - Windows_Packaging_x86_static - - Windows_Packaging_arm_static - Windows_Packaging_arm64_static condition: succeeded() steps: @@ -120,12 +109,6 @@ extends: artifactName: 'drop_Windows_Build_Windows_Packaging_arm64_dynamic' targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet DirectML arm' - inputs: - artifactName: 'drop_Windows_Build_Windows_Packaging_arm_dynamic' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm' - - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet DirectML x64 StaticRuntime' inputs: @@ -144,12 +127,6 @@ extends: artifactName: 'drop_Windows_Build_Windows_Packaging_arm64_static' targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64-static-runtime' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet DirectML arm StaticRuntime' - inputs: - artifactName: 'drop_Windows_Build_Windows_Packaging_arm_static' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm-static-runtime' - - task: PowerShell@2 displayName: 'Bundle NuGet and other binaries' inputs: @@ -194,17 +171,7 @@ extends: $arm64_static_runtime_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm64_static_runtime_nupkg_unzipped_directory_root, 'binaries', 
[System.IO.Path]::GetFileNameWithoutExtension($arm64_static_runtime_nuget_package)) [System.IO.Compression.ZipFile]::ExtractToDirectory($arm64_static_runtime_nuget_package, $arm64_static_runtime_nupkg_unzipped_directory) - $nupkgs = (Get-ChildItem ..\nuget-artifact-arm -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) - $arm_nuget_package = $nupkgs[0].FullName - $arm_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName - $arm_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($arm_nuget_package)) - [System.IO.Compression.ZipFile]::ExtractToDirectory($arm_nuget_package, $arm_nupkg_unzipped_directory) - - $nupkgs = (Get-ChildItem ..\nuget-artifact-arm-static-runtime -Filter Microsoft.AI.MachineLearning*.nupkg -Recurse) - $arm_static_runtime_nuget_package = $nupkgs[0].FullName - $arm_static_runtime_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName - $arm_static_runtime_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm_static_runtime_nupkg_unzipped_directory_root, 'binaries', [System.IO.Path]::GetFileNameWithoutExtension($arm_static_runtime_nuget_package)) - [System.IO.Compression.ZipFile]::ExtractToDirectory($arm_static_runtime_nuget_package, $arm_static_runtime_nupkg_unzipped_directory) + $x64_static_runtime_path_old = [System.IO.Path]::Combine($x64_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-x64', '_native') $x64_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-x64', '_native', 'static') @@ -216,10 +183,7 @@ extends: $arm64_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') $arm64_static_runtime_path_old = [System.IO.Path]::Combine($arm64_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') $arm64_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native', 'static') - $arm_runtime_path_old = [System.IO.Path]::Combine($arm_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - $arm_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - $arm_static_runtime_path_old = [System.IO.Path]::Combine($arm_static_runtime_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - $arm_static_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native', 'static') + $uap_build_path_old = [System.IO.Path]::Combine($x64_static_runtime_nupkg_unzipped_directory, 'build', 'native') $uap_build_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'build', 'uap10.0') @@ -228,8 +192,6 @@ extends: New-Item -Path $x86_static_runtime_path_new -ItemType Directory New-Item -Path $arm64_runtime_path_new -ItemType Directory New-Item -Path $arm64_static_runtime_path_new -ItemType Directory - New-Item -Path $arm_runtime_path_new -ItemType Directory - New-Item -Path $arm_static_runtime_path_new -ItemType Directory Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.dll')) $x86_runtime_path_new Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.lib')) $x86_runtime_path_new @@ -241,11 +203,6 @@ extends: Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.dll')) $arm64_runtime_path_new Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.lib')) 
$arm64_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'onnxruntime.dll')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'onnxruntime.lib')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'microsoft.ai.machinelearning.dll')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'microsoft.ai.machinelearning.lib')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'onnxruntime.dll')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'onnxruntime.dll')) Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'onnxruntime.lib')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'onnxruntime.lib')) Copy-Item ([System.IO.Path]::Combine($x64_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($x64_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) @@ -261,11 +218,6 @@ extends: Copy-Item ([System.IO.Path]::Combine($arm64_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($arm64_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) Copy-Item ([System.IO.Path]::Combine($arm64_static_runtime_path_old, 'microsoft.ai.machinelearning.lib')) ([System.IO.Path]::Combine($arm64_static_runtime_path_new, 'microsoft.ai.machinelearning.lib')) - Copy-Item ([System.IO.Path]::Combine($arm_static_runtime_path_old, 'onnxruntime.dll')) ([System.IO.Path]::Combine($arm_static_runtime_path_new, 'onnxruntime.dll')) - Copy-Item ([System.IO.Path]::Combine($arm_static_runtime_path_old, 'onnxruntime.lib')) ([System.IO.Path]::Combine($arm_static_runtime_path_new, 'onnxruntime.lib')) - Copy-Item ([System.IO.Path]::Combine($arm_static_runtime_path_old, 'microsoft.ai.machinelearning.dll')) ([System.IO.Path]::Combine($arm_static_runtime_path_new, 'microsoft.ai.machinelearning.dll')) - Copy-Item ([System.IO.Path]::Combine($arm_static_runtime_path_old, 'microsoft.ai.machinelearning.lib')) ([System.IO.Path]::Combine($arm_static_runtime_path_new, 'microsoft.ai.machinelearning.lib')) - Copy-Item -Recurse $uap_build_path_old $uap_build_path_new $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'merged') @@ -304,22 +256,13 @@ extends: $arm64_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm64_nupkg_unzipped_directory_root, 'symbols', [System.IO.Path]::GetFileNameWithoutExtension($arm64_nuget_package)) [System.IO.Compression.ZipFile]::ExtractToDirectory($arm64_nuget_package, $arm64_nupkg_unzipped_directory) - $nupkgs = (Get-ChildItem ..\nuget-artifact-arm -Filter Microsoft.AI.MachineLearning*.snupkg -Recurse) - $arm_nuget_package = $nupkgs[0].FullName - $arm_nupkg_unzipped_directory_root = $nupkgs[0].Directory.FullName - $arm_nupkg_unzipped_directory = [System.IO.Path]::Combine($arm_nupkg_unzipped_directory_root, 'symbols', [System.IO.Path]::GetFileNameWithoutExtension($arm_nuget_package)) - [System.IO.Compression.ZipFile]::ExtractToDirectory($arm_nuget_package, $arm_nupkg_unzipped_directory) - $x86_runtime_path_old = [System.IO.Path]::Combine($x86_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') $x86_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-x86', '_native') $arm64_runtime_path_old = [System.IO.Path]::Combine($arm64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') $arm64_runtime_path_new = 
[System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm64', '_native') - $arm_runtime_path_old = [System.IO.Path]::Combine($arm_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - $arm_runtime_path_new = [System.IO.Path]::Combine($x64_nupkg_unzipped_directory, 'runtimes', 'win-arm', '_native') - + New-Item -Path $x86_runtime_path_new -ItemType Directory New-Item -Path $arm64_runtime_path_new -ItemType Directory - New-Item -Path $arm_runtime_path_new -ItemType Directory Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'onnxruntime.pdb')) $x86_runtime_path_new Copy-Item ([System.IO.Path]::Combine($x86_runtime_path_old, 'microsoft.ai.machinelearning.pdb')) $x86_runtime_path_new @@ -327,9 +270,6 @@ extends: Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'onnxruntime.pdb')) $arm64_runtime_path_new Copy-Item ([System.IO.Path]::Combine($arm64_runtime_path_old, 'microsoft.ai.machinelearning.pdb')) $arm64_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'onnxruntime.pdb')) $arm_runtime_path_new - Copy-Item ([System.IO.Path]::Combine($arm_runtime_path_old, 'microsoft.ai.machinelearning.pdb')) $arm_runtime_path_new - $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'merged') if (!(Test-Path $merged_nuget_path)) { New-Item -Path $merged_nuget_path -ItemType Directory diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index badee79fd78b3..172a0dc1866ab 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -92,6 +92,9 @@ stages: vmImage: ubuntu-latest steps: - checkout: none + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - bash: | # Do not output ##vso[] commands with `set -x` or they may be parsed again and include a trailing quote. 
set +x @@ -105,6 +108,10 @@ stages: echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]" fi name: Set_Release_Version_Suffix + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' + - stage: Debug dependsOn: Setup @@ -116,7 +123,14 @@ stages: MyVar: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] steps: - checkout: none + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - bash: echo $(MyVar) + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' + - stage: Download_Java_Tools dependsOn: [] @@ -126,6 +140,9 @@ stages: vmImage: ubuntu-latest steps: - checkout: none + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - task: CmdLine@2 displayName: Download Java Tools inputs: @@ -141,6 +158,9 @@ stages: inputs: targetPath: '$(Agent.TempDirectory)/java-tools' artifact: 'onnxruntime-java-tools' + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' - template: templates/c-api-cpu.yml parameters: @@ -525,6 +545,9 @@ stages: submodules: false - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux submodules: false + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - script: | set -e -x @@ -603,6 +626,10 @@ stages: inputs: targetPath: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' artifactName: 'onnxruntime-linux-x64-gpu' + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' + - stage: Windows_Packaging_combined_GPU dependsOn: @@ -619,6 +646,10 @@ stages: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples submodules: false + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + - script: dir $(Build.SourcesDirectory) - task: BatchScript@1 displayName: 'setup env' @@ -688,7 +719,9 @@ stages: inputs: artifactName: 'onnxruntime-win-x64-gpu' targetPath: '$(Build.ArtifactStagingDirectory)' - + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' - stage: NuGet_Packaging_GPU dependsOn: @@ -1246,45 +1279,21 @@ stages: mkdir $(Build.ArtifactStagingDirectory)\testdata copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\custom_op_library.* $(Build.ArtifactStagingDirectory)\testdata -- template: nuget/templates/dml-vs-2022.yml - parameters: - AgentPool : 'onnxruntime-Win-CPU-2022' - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - ArtifactName: 'drop-win-dml-arm-zip' - StageName: 'Windows_CI_GPU_DML_Dev_arm' - BuildCommand: --build_dir $(Build.BinariesDirectory) --arm --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" - BuildArch: 'x64' - EnvSetupScript: 
'setup_env.bat' - sln_platform: 'arm' - DoDebugBuild: 'false' - DoNugetPack : 'true' - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - RunTests: 'false' - BuildNodejs: 'false' - NuPackScript: | - msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /p:TargetArchitecture=arm /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - cd $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\ - ren Microsoft.ML.OnnxRuntime.DirectML.* win-dml-arm.zip - copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\win-dml-arm.zip $(Build.ArtifactStagingDirectory) - mkdir $(Build.ArtifactStagingDirectory)\testdata - copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\custom_op_library.* $(Build.ArtifactStagingDirectory)\testdata - - stage: NuGet_Packaging_DML dependsOn: - Windows_CI_GPU_DML_Dev - Windows_CI_GPU_DML_Dev_x86 - Windows_CI_GPU_DML_Dev_arm64 - - Windows_CI_GPU_DML_Dev_arm condition: succeeded() jobs: - job: workspace: clean: all - pool: 'onnxruntime-Win2022-GPU-T4' - + pool: 'onnxruntime-Win2022-GPU-dml-A10' steps: - + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet DirectML' inputs: @@ -1303,12 +1312,6 @@ stages: artifactName: 'drop-win-dml-arm64-zip' targetPath: '$(Build.BinariesDirectory)/nuget-artifact-dml' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet DirectML arm' - inputs: - artifactName: 'drop-win-dml-arm-zip' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-dml' - - script: | pushd $(Build.BinariesDirectory)\nuget-artifact-dml dir @@ -1339,13 +1342,6 @@ stages: move win-arm64\runtimes\win-arm64\native\onnxruntime.lib %%~ni\runtimes\win-arm64\native\onnxruntime.lib move win-arm64\runtimes\win-arm64\native\onnxruntime.pdb %%~ni\runtimes\win-arm64\native\onnxruntime.pdb - unzip win-dml-arm.zip -d win-arm - mkdir %%~ni\runtimes\win-arm - mkdir %%~ni\runtimes\win-arm\native - - move win-arm\runtimes\win-arm\native\onnxruntime.dll %%~ni\runtimes\win-arm\native\onnxruntime.dll - move win-arm\runtimes\win-arm\native\onnxruntime.lib %%~ni\runtimes\win-arm\native\onnxruntime.lib - move win-arm\runtimes\win-arm\native\onnxruntime.pdb %%~ni\runtimes\win-arm\native\onnxruntime.pdb pushd %%~ni zip -r ..\%%~ni.zip . 
@@ -1368,7 +1364,7 @@ stages: PackageType: 'nuget' PackagePath: '$(Build.ArtifactStagingDirectory)' PackageName: 'Microsoft.ML.OnnxRuntime.DirectML*nupkg' - PlatformsSupported: 'win-x64,win-x86,win-arm64,win-arm' + PlatformsSupported: 'win-x64,win-x86,win-arm64' VerifyNugetSigning: ${{ parameters.DoEsrp }} - task: PublishPipelineArtifact@0 @@ -1376,3 +1372,6 @@ stages: inputs: artifactName: 'drop-signed-nuget-dml' targetPath: '$(Build.ArtifactStagingDirectory)' + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 37b4bdc43afcd..e6025ae1b56bd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -161,20 +161,6 @@ stages: buildJava: false buildNodejs: false -- template: win-ci.yml - parameters: - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - stage_name_suffix: CPU_arm_${{ parameters.BuildVariant }} - buildArch: x64 - msbuildPlatform: arm - packageName: arm - buildparameter: --arm ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe - runTests: false - buildJava: false - buildNodejs: false - ort_build_pool_name: onnxruntime-Win-CPU-2022 - - template: win-ci.yml parameters: DoCompliance: ${{ parameters.DoCompliance }} @@ -205,10 +191,7 @@ stages: dependsOn: - Linux_C_API_Packaging_CPU - MacOS_C_API_Package_Publish - - Windows_Packaging_CPU_x86_${{ parameters.BuildVariant }} - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_arm_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_arm64_${{ parameters.BuildVariant }} - Download_Java_Tools condition: succeeded() jobs: @@ -297,7 +280,6 @@ stages: - MacOS_C_API_Package_Publish - Windows_Packaging_CPU_x86_${{ parameters.BuildVariant }} - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_arm_${{ parameters.BuildVariant }} - Windows_Packaging_CPU_arm64_${{ parameters.BuildVariant }} - Android_Java_API_AAR_Packaging_Full - iOS_Full_xcframework @@ -340,14 +322,6 @@ stages: SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download win-arm Pipeline Artifact' - ArtifactName: 'onnxruntime-win-arm' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - template: flex-downloadPipelineArtifact.yml parameters: StepName: 'Download osx-x64 Pipeline Artifact' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index db3782c69cf62..2adcbb13dbeb8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -106,3 +106,7 @@ jobs: inputs: artifactName: 'drop-linux-cpu-${{ parameters.arch }}' targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' + + - template: component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' \ No newline at end of file From dee6a5b3715c5bdf7a6d29c2b9516902ebd0e0b1 Mon Sep 17 00:00:00 2001 From: Xu 
Xing Date: Tue, 9 Jan 2024 23:46:30 +0800 Subject: [PATCH 45/45] [js/webgpu] Support uniforms for attention and multihead attention (#18903) --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 330 +++++++++--------- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 4 +- .../jsep/webgpu/ops/multi-head-attentiion.ts | 38 +- 4 files changed, 190 insertions(+), 186 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 06c3c6c196501..c182d3c4eaf6f 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; -import {attention, parseAttentionAttributes} from './ops/attention'; +import {attention} from './ops/attention'; import {batchNorm} from './ops/batch-norm'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; @@ -50,7 +50,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], - ['Attention', [attention, parseAttentionAttributes]], + ['Attention', [attention]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], ['BatchNormalization', [batchNorm]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index e1f2a47301bfb..ef8038dff487e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {tensorDataTypeEnumToString} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType} from '../types'; +import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common'; export const enum AttentionQkvFormat { unknown, // enum value not set, or depends on qkv projection implementation details @@ -231,20 +231,8 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte }; }; -export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({...attributes}); - export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { const components = getMaxComponents(d); - const inputHelper = outputVariable('x', input.dataType, input.dims, components); - - let threadMaxValue = 'threadMaxVector'; - if (components === 2) { - threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)'; - } else if (components === 4) { - threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))'; - } - const dataType = tensorTypeToWsglStorageType(input.dataType); let WG = 64; const dComp = d / components; if (dComp < WG) { @@ -253,25 +241,41 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView WG = Math.ceil(dComp / 8); } const elementsPerWG = Math.ceil(d / components / WG); + const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type']; + const programUniforms: ProgramUniform[] = + [{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}]; + const dataType = tensorTypeToWsglStorageType(input.dataType, components); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputHelper = outputVariable('x', input.dataType, input.dims, components); + let threadMaxValue = 'thread_max_vector'; + if (components === 2) { + threadMaxValue = 'max(thread_max_vector.x, thread_max_vector.y)'; + } else if (components === 4) { + threadMaxValue = + 'max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))'; + } + const elemValueType = tensorTypeToWsglValueType(input.dataType); + const uniforms: UniformsArrayType = [ + {name: 'd_inv', type: elemValueType as UniformDataElementType}, {name: 'd_comp', type: 'u32'}, + {name: 'elements_per_wg', type: 'u32'} + ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const dInv: ${dataType} = 1 / ${d}; - const dComp = ${d / components}; + return ` var wgMax: array; var wgSum: array; - - ${shaderHelper.declareVariables(inputHelper)} - @compute @workgroup_size(${WG}, 1, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_index) local_index : u32) { - let localOffset = local_index * ${elementsPerWG}; - let offset: u32 = workgroup_id.x * dComp + localOffset; - - var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - 
threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); + ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)} + ${shaderHelper.mainStart([ + WG, 1, 1 + ])} + let localOffset = local_idx * uniforms.elements_per_wg; + let offset: u32 = workgroup_id.x * uniforms.d_comp + localOffset; + + var thread_max_vector = ${fillVector('f32', components, '-3.402823e+38f')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + thread_max_vector = max(${castToF32(elemValueType, components, 'x[offset + i]')}, thread_max_vector); } - wgMax[local_index] = ${threadMaxValue}; + wgMax[local_idx] = ${threadMaxValue}; workgroupBarrier(); var maxValue = -3.402823e+38f; @@ -280,10 +284,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView } var sumVector = ${fillVector('f32', components, '0')}; - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + sumVector += exp(${castToF32(elemValueType, components, 'x[offset + i]')} - maxValue); } - wgSum[local_index] = ${sumVector('sumVector', components)}; + wgSum[local_idx] = ${sumVector('sumVector', components)}; workgroupBarrier(); var sum: f32 = 0; @@ -292,26 +296,24 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView } if (sum == 0) { - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - x[offset + i] = ${fillVector(dataType, components, 'dInv')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')}; } } else { - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + let f32input = ${castToF32(elemValueType, components, 'x[offset + i]')}; x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); } } }`; + }; context.compute( { name: 'AttentionProbsSoftmax', - shaderCache: {hint: `${d}`}, + shaderCache: {hint: `${WG};${dataType};${components}`}, getShaderSource, - getRunData: () => ({ - outputs: [], - dispatchGroup: {x: n}, - }), + getRunData: () => ({outputs: [], dispatchGroup: {x: n}, programUniforms}), }, {inputs: [input], outputs: []}); }; @@ -326,47 +328,43 @@ const computeAttentionProbs = // TODO: handle mask const alpha = attributes.scale === 0 ? 
1.0 / Math.sqrt(parameters.headSize) : attributes.scale; - - const dataType = tensorTypeToWsglStorageType(q.dataType); - const components = getMaxComponents(parameters.headSize); - const qInput = inputVariable('q', q.dataType, q.dims, components); - const kInput = inputVariable('key', key.dataType, key.dims, components); - const output = outputVariable('output', q.dataType, probsShape); - const vectorizedHeadSize = parameters.headSize / components; - const M = parameters.sequenceLength; - const N = parameters.totalSequenceLength; - const K = vectorizedHeadSize; - const TILE_SIZE = 12; - const dispatch = { x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; + const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type']; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize}, + {type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength}, + {type: tensorDataType, data: alpha} + ]; const inputs = [q, key]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${M}u; - const N: u32 = ${N}u; - const K: u32 = ${K}u; - const alpha: ${dataType} = ${alpha}; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const output = outputVariable('output', q.dataType, probsShape); + const dataType = tensorTypeToWsglStorageType(q.dataType); + + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'kv_sequence_length', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType} + ]; + return ` const beta: ${dataType} = 1.0; const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - - ${shaderHelper.declareVariables(qInput, kInput, output)} - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + ${shaderHelper.registerUniforms(uniforms).declareVariables(qInput, kInput, output)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} // x holds the N and y holds the M let headIdx = workgroup_id.z; let m = workgroup_id.y * TILE_SIZE; @@ -374,40 +372,42 @@ const computeAttentionProbs = let lm = m + local_id.y; let ln = n + local_id.x; - let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; - let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; + let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K; var value = ${fillVector(dataType, components)}; - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m + local_id.y < M && w + local_id.x < K) { - tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (m 
+ local_id.y < uniforms.M && w + local_id.x < uniforms.K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x]; } - if (n + local_id.y < N && w + local_id.x < K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; + if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * uniforms.K + w + local_id.x]; } workgroupBarrier(); - for (var k: u32 = 0u; k ({ outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, @@ -423,78 +423,76 @@ const computeAttentionProbs = const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; - - const probsHelper = inputVariable('probs', probs.dataType, probs.dims); - const vHelper = inputVariable('v', v.dataType, v.dims); - const output = outputVariable('output', probs.dataType, outputShape); - - const dataType = tensorTypeToWsglStorageType(probs.dataType); - const TILE_SIZE = 12; const dispatch = { x: Math.ceil(params.vHeadSize / TILE_SIZE), y: Math.ceil(params.sequenceLength / TILE_SIZE), z: params.batchSize * params.numHeads }; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength}, + {type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads}, + {type: 'uint32', data: params.vHiddenSize} + ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${params.sequenceLength}u; - const N: u32 = ${params.vHeadSize}u; - const K: u32 = ${params.totalSequenceLength}u; - const numHeads: u32 = ${params.numHeads}u; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const output = outputVariable('output', probs.dataType, outputShape); + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'} + ]; + return ` const TILE_SIZE = ${TILE_SIZE}u; - - var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - - ${shaderHelper.declareVariables(probsHelper, vHelper, output)} - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + ${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} let headIdx = workgroup_id.z; let m = workgroup_id.y * TILE_SIZE + local_id.y; let n = workgroup_id.x * TILE_SIZE + local_id.x; - let offsetA = headIdx * (M * K) + m * K; - let offsetB = headIdx * (N * K) + n; + let offsetA = headIdx 
* (uniforms.M * uniforms.K) + m * uniforms.K; + let offsetB = headIdx * (uniforms.N * uniforms.K) + n; - var value = ${dataType}(0); - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m < M && w + local_id.x < K) { + var value = ${probsHelper.type.storage}(0); + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (m < uniforms.M && w + local_id.x < uniforms.K) { tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; } - if (n < N && w + local_id.y < K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + if (n < uniforms.N && w + local_id.y < uniforms.K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N]; } workgroupBarrier(); - for (var k: u32 = 0u; k ({ outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, @@ -517,71 +515,71 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { parameters.sequenceLength, parameters.headSize, ]; - - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); - const M = parameters.sequenceLength; const K = parameters.inputHiddenSize; const N = parameters.headSize; - const TILE_SIZE = 12; const dispatch = { x: Math.ceil(parameters.headSize / TILE_SIZE), y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; + const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N}, + {type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize}, + {type: 'uint32', data: parameters.hiddenSize}, + {type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + ]; - const getShaderSource = () => ` - const M: u32 = ${M}u; - const K: u32 = ${K}u; - const N: u32 = ${N}u; - const numHeads: u32 = ${parameters.numHeads}; - const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const outputQ = outputVariable('output_q', inputs[0].dataType, outputShape); + const outputK = outputVariable('output_k', inputs[0].dataType, outputShape); + const outputV = outputVariable('output_v', inputs[0].dataType, outputShape); + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims); + const weight = inputVariable('weight', inputs[1].dataType, inputs[1].dims); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); + const dataType = input.type.storage; + + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'num_heads', type: 'u32'}, + {name: 'head_size', type: 'u32'}, {name: 'hidden_size', type: 'u32'}, {name: 'ldb', type: 'u32'} + ]; + return ` const TILE_SIZE = ${TILE_SIZE}u; - var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; - - @group(0) @binding(0) var input: array<${dataType}>; - @group(0) @binding(1) var weight: array<${dataType}>; - @group(0) @binding(2) var bias: array<${dataType}>; - @group(0) @binding(3) var outputQ: array<${dataType}>; - @group(0) @binding(4) var outputK: 
array<${dataType}>; - @group(0) @binding(5) var outputV: array<${dataType}>; - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - - let batchIndex = workgroup_id.z / ${parameters.numHeads}; - let headNumber = workgroup_id.z % ${parameters.numHeads}; + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, bias, outputQ, outputK, outputV)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} + let batchIndex = workgroup_id.z / uniforms.num_heads; + let headNumber = workgroup_id.z % uniforms.num_heads; let m = workgroup_id.y * TILE_SIZE + local_id.y; let n = workgroup_id.x * TILE_SIZE + local_id.x; - let inputOffset = batchIndex * (M * K) + m * K; - let biasOffsetQ = headNumber * ${parameters.headSize}; - let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; - let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; + let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K; + let biasOffsetQ = headNumber * uniforms.head_size; + let biasOffsetK = uniforms.hidden_size + biasOffsetQ; + let biasOffsetV = uniforms.hidden_size + biasOffsetK; var valueQ = ${dataType}(0); var valueK = ${dataType}(0); var valueV = ${dataType}(0); - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m < M && w + local_id.x < K) { + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (m < uniforms.M && w + local_id.x < uniforms.K) { tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; } - if (n < N && w + local_id.y < K) { - let offset = n + (w + local_id.y) * ldb; + if (n < uniforms.N && w + local_id.y < uniforms.K) { + let offset = n + (w + local_id.y) * uniforms.ldb; tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; } workgroupBarrier(); - for (var k: u32 = 0u; k { workgroupBarrier(); } - let headOffset = (m * N + n) % ${parameters.headSize}; + let headOffset = (m * uniforms.N + n) % uniforms.head_size; valueQ += bias[headOffset + biasOffsetQ]; valueK += bias[headOffset + biasOffsetK]; valueV += bias[headOffset + biasOffsetV]; - let offset = workgroup_id.z * M * N; - if (m < M && n < N) { - let outputIdx = offset + m * N + n; - outputQ[outputIdx] = valueQ; - outputK[outputIdx] = valueK; - outputV[outputIdx] = valueV; + let offset = workgroup_id.z * uniforms.M * uniforms.N; + if (m < uniforms.M && n < uniforms.N) { + let outputIdx = offset + m * uniforms.N + n; + output_q[outputIdx] = valueQ; + output_k[outputIdx] = valueK; + output_v[outputIdx] = valueV; } }`; - - const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; + }; return context.compute( { name: 'AttentionPrepare', - shaderCache: {hint: JSON.stringify(parameters)}, + shaderCache: {inputDependencies: ['type', 'type', 'type']}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, @@ -619,6 +616,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: 
GpuDataType.default}, ], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 3ce114c5d3884..bc3265be955f0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -780,8 +780,10 @@ class ShaderHelperImpl implements ShaderHelper { const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1; const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, + @builtin(workgroup_id) workgroup_id : vec3, @builtin(local_invocation_id) local_id : vec3` : - `@builtin(local_invocation_index) local_idx : u32, + `@builtin(local_invocation_id) local_id : vec3, + @builtin(local_invocation_index) local_idx : u32, @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts index b7726a36bcaad..6d22e3780efd9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -4,10 +4,10 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType} from '../types'; +import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { @@ -228,7 +228,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr }; }; - export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => createAttributeWithCacheKey({...attributes}); @@ -239,30 +238,35 @@ const addBiasTranspose = hiddenSize: number, biasOffset: number) => { const outputShape = [batchSize, sequenceLength, hiddenSize]; const outputSize = ShapeUtil.size(outputShape); - - const dataType = tensorTypeToWsglStorageType(qkv.dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const biasOffset = ${biasOffset}u; - const hiddenSize = ${hiddenSize}u; - - @group(0) @binding(0) var qkv: array<${dataType}>; - @group(0) @binding(1) var bias: array<${dataType}>; - @group(0) @binding(2) var qkv_with_bias: array<${dataType}>; - + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); + const qkvInput = inputVariable('qkv', qkv.dataType, outputShape); + const biasInput = inputVariable('bias', bias.dataType, outputShape); + + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'bias_offset', type: 'u32'}, {name: 'hidden_size', type: 'u32'} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(qkvInput, 
biasInput, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let bias_offset_idx = (global_idx % uniforms.hidden_size) + uniforms.bias_offset; - qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + qkv_with_bias[global_idx] = qkv[global_idx] + bias[bias_offset_idx]; }`; + }; return context.compute( { name: 'MultiHeadAttentionAddBias', - shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + shaderCache: {inputDependencies: ['type', 'type']}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, },