From 43ea8459b2487a8e40900c728207453392cc66ee Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Thu, 21 Sep 2023 19:21:07 -0700 Subject: [PATCH] Enable QLinearAveragePooling DML EP (#17384) DML EP Implementation for [QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool) ``` Note: Google Test filter = *QLinear*Pool* [==========] Running 72 tests from 2 test suites. [----------] Global test environment set-up. [----------] 36 tests from QLinearGlobalAveragePool [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255 [ OK ] 
QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 
(155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms) [----------] 36 tests from QLinearGlobalAveragePool (5968 ms total) [----------] 36 tests from QLinearPoolTest [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms) [ RUN ] 
QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage [ OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global [ OK ] QLinearPoolTest.AveragePool2D_Global (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 [ OK ] 
QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms) [----------] 36 tests from QLinearPoolTest (12914 ms total) [----------] Global test environment tear-down [==========] 72 tests from 2 test suites ran. (18885 ms total) [ PASSED ] 72 tests. 
memleakdbg: ----- No memory leaks detected ----- ``` --- .../src/External/DirectMLHelpers/ApiTraits.h | 37 ++- .../External/DirectMLHelpers/DirectMLSchema.h | 25 ++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 22 ++ .../src/Operators/DmlOperatorAttention.cpp | 46 ++-- .../src/Operators/DmlOperatorQAttention.cpp | 252 ++++++++++++++++-- .../DmlOperatorQLinearAveragePooling.cpp | 150 +++++++++++ .../src/Operators/OperatorRegistration.cpp | 8 + .../DmlExecutionProvider/src/TensorDesc.cpp | 36 +++ .../dml/DmlExecutionProvider/src/TensorDesc.h | 3 + .../dml/OperatorAuthorHelper/Attributes.h | 1 + .../OperatorAuthorHelper/OperatorHelper.cpp | 95 ++++++- .../dml/OperatorAuthorHelper/OperatorHelper.h | 47 +++- .../OperatorAuthorHelper/OperatorVersions.h | 2 + .../qlinear_global_average_pool_test.cc | 3 + 14 files changed, 679 insertions(+), 48 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index f6d71ce629a8d..570a0f82b62ff 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -14,7 +14,26 @@ struct DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC _Maybenull_ const DML_TENSOR_DESC* BiasTensor; const DML_TENSOR_DESC* OutputTensor; }; -const int DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT = 0x80000011; +const int DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT = 0x80000011; + +struct DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC +{ + const DML_TENSOR_DESC* InputTensor; + const DML_TENSOR_DESC* InputScaleTensor; + _Maybenull_ const DML_TENSOR_DESC* InputZeroPointTensor; + const DML_TENSOR_DESC* OutputScaleTensor; + _Maybenull_ const 
DML_TENSOR_DESC* OutputZeroPointTensor; + const DML_TENSOR_DESC* OutputTensor; + UINT DimensionCount; + _Field_size_(DimensionCount) const UINT* Strides; + _Field_size_(DimensionCount) const UINT* WindowSize; + _Field_size_(DimensionCount) const UINT* StartPadding; + _Field_size_(DimensionCount) const UINT* EndPadding; + _Field_size_(DimensionCount) const UINT* Dilations; + BOOL IncludePadding; +}; +const int DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING = 0x8000000B; + namespace ApiTraits { @@ -38,7 +57,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 161; + static constexpr auto ValueCount = 162; static constexpr size_t ActivationFunctionCount = 24; }; @@ -497,6 +516,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; +}; + template <> struct OperatorDescTraits { @@ -1492,6 +1517,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> using DescType = DML_ROI_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> +{ + using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> { @@ -2524,6 +2555,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... 
args #pragma warning(disable: 4063) case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return std::invoke(std::forward(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); #pragma warning(pop) default: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index f3a3aec50e4b4..2e9217cf3f4f7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -829,6 +829,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, }; + +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, 
DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", + static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 13, + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 90915c7e757de..1b82295ea4f9e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -473,6 +473,24 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), }; } +inline std::vector GetFields(const 
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} inline std::vector GetFields(const 
DML_SLICE_OPERATOR_DESC& desc) { return { @@ -2492,6 +2510,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); #pragma warning(pop) default: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index bbebb4a333baf..e9559956704d5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -6,9 +6,11 @@ /* Abbreviations: B is batch_size, S is sequence_length, W is hidden_size N is number of attention heads, H is head size, and W=N*H - M is mask_index tensor + M is mask_index tensor, P is optional past tensor - M A B C // M, A, B, and C are Inputs + // M, A, B, C and P are Inputs + + M A B C | \ | / | Gemm | / | \ @@ -17,13 +19,30 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size | Slice Slice Slice | | | | | | | | - | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while - | | | | // keeping the GEMM strides as NCHW to better target metacommands + | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while + | | | | // keeping the GEMM strides as NCHW to better target metacommands | | | | - ----------------- MHA ----- - | - | - Output // Final output + | | | | P + | | | | / \ + | | | | / \ + | | | | Slice Slice + | | | | | | + | | | | | | + | | | | | | + --------------------------MHA ----------- + / | \ + / | \ + / | \ + / | \ + / | \ + / | \ + Output1 presentKey 
presentValue + \ / + \ / + \ / + Concat + | + Output2 (present) This kernel creates a DML_GRAPH, as mentioned above. For reference, refer to this Doc: @@ -540,11 +559,6 @@ class DmlOperatorAttention : public DmlOperator void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) { *isSupported = false; - // `past` input tensor is not supported yet - if (context->IsInputValid(4)) - { - return; - } // `past_sequence_length` input tensor is not supported yet if (context->IsInputValid(6)) @@ -552,12 +566,6 @@ void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*o return; } - // `present` output tensor is not supported yet - if (context->IsOutputValid(1)) - { - return; - } - // `unidirectional == 1` is not supported yet MLOperatorAttributes attributes(context); if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index 6e0785c91a43b..9cec8dc804f0e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -8,7 +8,9 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size N is number of attention heads, H is head size, and W=N*H M is mask_index tensor - M A B C // M, A, B, and C are Inputs +// M, A, B, C and P are Inputs + + M A B C | | | / | MatMulIntToFloat | / | \ @@ -20,10 +22,27 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while | | | | // keeping the GEMM strides as NCHW to better target metacommands | | | | - ----------------- MHA ----- - | - | - Output // Final output + | | | | P + | | | | / \ + | | | | / \ + | 
| | | Slice Slice + | | | | | | + | | | | | | + | | | | | | + --------------------------MHA ----------- + / | \ + / | \ + / | \ + / | \ + / | \ + / | \ + Output1 presentKey presentValue + \ / + \ / + \ / + Concat + | + Output2 (present) This kernel creates a DML_GRAPH, as mentioned above. For reference, refer to this Doc: @@ -69,18 +88,28 @@ class DmlOperatorQAttention : public DmlOperator inputCount, }; + enum MHAOutputIndex : uint32_t + { + mhaOutputIndex, + mhaPresentKeyIndex, + mhaPresentValueIndex, + mhaOutputCount, + }; + enum OutputIndex : uint32_t { outputIndex, + presentIndex, outputCount, }; - ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 2); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1; + const bool hasPast = kernelCreationContext.IsInputValid(pastIndex); DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1); @@ -117,12 +146,15 @@ class DmlOperatorQAttention : public DmlOperator ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0); } + const bool unidirectional = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::Unidirectional)); + const uint32_t hiddenSize = qkvHiddenSizes.empty() ? weightTensorShape[1] / 3 : qkvHiddenSizes[0]; const uint32_t vHiddenSize = qkvHiddenSizes.empty() ? weightTensorShape[1] / 3 : qkvHiddenSizes[2]; const uint32_t headSize = hiddenSize / numHeads; const uint32_t vHeadSize = vHiddenSize / numHeads; const uint32_t batchSize = inputTensorShape[0]; const uint32_t sequenceLength = inputTensorShape[1]; + const uint32_t pastSequenceLength = hasPast ? 
m_inputTensorDescs[pastIndex].GetSizes()[3] : 0; uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], hiddenSize + hiddenSize + vHiddenSize}; MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType; @@ -189,6 +221,16 @@ class DmlOperatorQAttention : public DmlOperator } } + MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined; + MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined; + if (hasPast) + { + pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType; + presentTensorDataType = kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType; + auto pastTensorShape = m_inputTensorDescs[pastIndex].GetSizes(); + m_inputTensorDescs[pastIndex] = TensorDesc::ConstructDefaultTensorDesc(kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType, pastTensorShape); + } + TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape); DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc(); @@ -319,12 +361,84 @@ class DmlOperatorQAttention : public DmlOperator } const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc}; + // We need to slice Past to get PastValue and PastKey tensor for MHA + std::array pastKeyOutputShape {1, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastKeyStrides = {0, 1, 1, 1, 1}; + std::array pastKeyOffsets = {0, 0, 0, 0, 0}; + TensorDesc pastKeyOutputTensorDesc; + DML_TENSOR_DESC namedPastKeyOutputTensorDesc; + + std::array pastValueOutputShape {1, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastValueStrides = {0, 1, 1, 1, 1}; + //std::array pastValueOffsets {0, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastValueOffsets {1, 0, 0, 0, 0}; + TensorDesc 
pastValueOutputTensorDesc; + DML_TENSOR_DESC namedPastValueOutputTensorDesc; + + DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {}; + DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {}; + + // Check if needed DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {}; + if (hasPast) + { + pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape); + namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc(); + pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex]; + pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc; + pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(5); + pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data(); + pastKeySlicedOperatorDesc.InputWindowSizes = pastKeyOutputShape.data(); + pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data(); + + pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape); + namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc(); + pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex]; + pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc; + pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(5); + pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data(); + pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data(); + pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data(); + } + + const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc}; + const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc}; + + std::array unidirectionalMaskOutputShape {1, batchSize}; + TensorDesc unidirectionalMaskTensorDesc; + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC unidirectionalMaskOperatorDesc = {}; + DML_TENSOR_DESC namedUnidirectionalMaskTensorDesc; + + if 
(unidirectional && !hasMask) + { + unidirectionalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, unidirectionalMaskOutputShape); + namedUnidirectionalMaskTensorDesc = unidirectionalMaskTensorDesc.GetDmlDesc(); + unidirectionalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32; + unidirectionalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength; + unidirectionalMaskOperatorDesc.ValueDelta.Int32 = 1; + unidirectionalMaskOperatorDesc.OutputTensor = &namedUnidirectionalMaskTensorDesc; + + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH; + } + DML_OPERATOR_DESC unidirectionalMaskDesc = { DML_OPERATOR_FILL_VALUE_SEQUENCE, &unidirectionalMaskOperatorDesc }; + + DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {}; + std::array preseKeyOutputShape {batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; + std::array preseValueOutputShape {batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; + TensorDesc presetKeyTensorDesc; + TensorDesc presetValueTensorDesc; + DML_TENSOR_DESC namedPresentKeyOutputTensorDesc; + DML_TENSOR_DESC namedPresentValueOutputTensorDesc; + mhaOperatorDesc.ValueTensor = hasSlicedValue ? &namedValueSlicedInputTensorDesc : nullptr; mhaOperatorDesc.StackedQueryKeyTensor = hasSlicedValue ? &namedQueryKeyTransposedOutputTensorDesc : nullptr; mhaOperatorDesc.StackedQueryKeyValueTensor = hasSlicedValue ? 
nullptr : &namedQueryKeyValueTransposedOutputTensorDesc; - if (hasMaxSequenceMask) + if (unidirectional && !hasMask) + { + mhaOperatorDesc.MaskTensor = &namedUnidirectionalMaskTensorDesc; + } + else if (hasMaxSequenceMask) { mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc; } @@ -339,8 +453,35 @@ class DmlOperatorQAttention : public DmlOperator mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); mhaOperatorDesc.HeadCount = numHeads; mhaOperatorDesc.MaskType = maskType; + if (hasPast) + { + presetKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, preseKeyOutputShape); + namedPresentKeyOutputTensorDesc = presetKeyTensorDesc.GetDmlDesc(); + presetValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, preseValueOutputShape); + namedPresentValueOutputTensorDesc = presetValueTensorDesc.GetDmlDesc(); + mhaOperatorDesc.PastKeyTensor = hasPast ? &namedPastKeyOutputTensorDesc : nullptr; + mhaOperatorDesc.PastValueTensor = hasPast ? &namedPastValueOutputTensorDesc : nullptr; + mhaOperatorDesc.OutputPresentKeyTensor = hasPast ? &namedPresentKeyOutputTensorDesc : nullptr; + mhaOperatorDesc.OutputPresentValueTensor = hasPast ? 
&namedPresentValueOutputTensorDesc : nullptr; + } + const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc }; + DML_JOIN_OPERATOR_DESC presetKeyValueJoinOperatorDesc = {}; + + if (hasPast) + { + std::vector joinInputDescs; + joinInputDescs.push_back(namedPresentKeyOutputTensorDesc); + joinInputDescs.push_back(namedPresentValueOutputTensorDesc); + presetKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast(joinInputDescs.size()); + presetKeyValueJoinOperatorDesc.InputTensors = joinInputDescs.data(); + presetKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex]; + presetKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast(4);//m_outputTensorDescs[presentIndex].GetDimensionCount(); + } + + DML_OPERATOR_DESC presetKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presetKeyValueJoinOperatorDesc }; + // Construct the graph std::vector inputEdges; std::vector intermediateEdges; @@ -383,6 +524,26 @@ class DmlOperatorQAttention : public DmlOperator maskSliceNodeIndex = currentNodeIndex++; } + uint32_t pastKeySliceNodeIndex = 0; + uint32_t pastValueSliceNodeIndex = 0; + uint32_t concatNodeIndex = 0; + if (hasPast) + { + opDescs.push_back(&pastKeySlicedDesc); + pastKeySliceNodeIndex = currentNodeIndex++; + opDescs.push_back(&pastValueSlicedDesc); + pastValueSliceNodeIndex = currentNodeIndex++; + opDescs.push_back(&presetKeyValueJoinDesc); + concatNodeIndex = currentNodeIndex++; + } + + uint32_t unidirectionalMaskNodeIndex = 0; + if (unidirectional && !hasMask) + { + opDescs.push_back(&unidirectionalMaskDesc); + unidirectionalMaskNodeIndex = currentNodeIndex++; + } + DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {}; inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex; inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; @@ -463,6 +624,55 @@ class DmlOperatorQAttention : public DmlOperator } } + if (unidirectional && !hasMask) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC unidirectionalMaskToMhaEdge = {}; + 
unidirectionalMaskToMhaEdge.FromNodeIndex = unidirectionalMaskNodeIndex; + unidirectionalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex; + unidirectionalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex ; + intermediateEdges.push_back(unidirectionalMaskToMhaEdge); + + } + + if (hasPast) + { + DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {}; + pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex; + pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex; + pastToPastKeySliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(pastToPastKeySliceEdge); + + DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {}; + pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex; + pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex; + pastToPastValueSliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(pastToPastValueSliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {}; + pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex; + pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex; + pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex; + intermediateEdges.push_back(pastKeyToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {}; + pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex; + pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex; + pastValueToMhaEdge.ToNodeInputIndex = mhaPastValueIndex; + intermediateEdges.push_back(pastValueToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {}; + presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex; + presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex; + presentKeyToConcatEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(presentKeyToConcatEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {}; + presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex; + presentValueToConcatEdge.ToNodeIndex = concatNodeIndex; + presentValueToConcatEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(presentValueToConcatEdge); + } + if 
(hasSlicedValue) { // We need to slice QK and V, and transpose QK @@ -521,10 +731,19 @@ class DmlOperatorQAttention : public DmlOperator DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {}; mhaToOutputEdge.FromNodeIndex = mhaNodeIndex; - mhaToOutputEdge.FromNodeOutputIndex = 0; - mhaToOutputEdge.GraphOutputIndex = 0; + mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex; + mhaToOutputEdge.GraphOutputIndex = outputIndex; outputEdges.push_back(mhaToOutputEdge); + if (hasPast) + { + DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {}; + concatToOutputEdge.FromNodeIndex = concatNodeIndex; + concatToOutputEdge.FromNodeOutputIndex = 0; + concatToOutputEdge.GraphOutputIndex = presentIndex; + outputEdges.push_back(concatToOutputEdge); + } + MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); operatorGraphDesc.inputEdges = inputEdges.data(); @@ -542,21 +761,10 @@ class DmlOperatorQAttention : public DmlOperator void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) { *isSupported = false; - // `past` input tensor is not supported yet - if (context->IsInputValid(8)) - { - return; - } - - // `present` output tensor is not supported yet - if (context->IsOutputValid(1)) - { - return; - } - // `unidirectional == 1` is not supported yet + // `unidirectional == 1` with Mask Tensor is not supported yet MLOperatorAttributes attributes(context); - if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0) + if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0 && context->IsInputValid(5)) { return; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp new file mode 100644 index 0000000000000..0fccedfe311c1 --- /dev/null +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase +{ + // For QLinear Avg Pool ORT and DML have same indexing order + enum OrtInputTensors : uint32_t + { + ortInput, + ortInputScale, + ortInputZeroPoint, + ortOutputScale, + ortOutputZeroPoint, + ortInputCount + }; + +public: + using Self = DmlOperatorQLinearAveragePooling; + + DmlOperatorQLinearAveragePooling( + const MLOperatorKernelCreationContext& kernelInfo, + bool useGlobalPooling + ) + : DmlOperator(kernelInfo), + PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling) + { + DmlOperator::Initialize(kernelInfo); + + bool isNhwc = m_kernel.channelsLast; + std::vector inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount(); + ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2); + + // DML requires that DimensionCount be equal to Input.dmlDimSize - 2 for Pooling + uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2; + if (m_kernel.spatialDimensionCount < expectedSpatialDimCount) + { + size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount; + + for (int i = gsl::narrow_cast(m_kernel.spatialDimensionCount) - 1; i >= 0; i--) + { + m_kernel.windowSize[i + shift] = m_kernel.windowSize[i]; + m_kernel.windowSize[i] = 1; + + m_kernel.strides[i + shift] = m_kernel.strides[i]; + m_kernel.strides[i] = 1; + + m_kernel.startPadding[i + shift] = m_kernel.startPadding[i]; + m_kernel.startPadding[i] = 0; + + m_kernel.endPadding[i + shift] = 
m_kernel.endPadding[i]; + m_kernel.endPadding[i] = 0; + + m_kernel.dilations[i + shift] = m_kernel.dilations[i]; + m_kernel.dilations[i] = 1; + } + + m_kernel.spatialDimensionCount = expectedSpatialDimCount; + } + + // Initialize dimensionMapping for NCHW or NHWC layout + std::vector dimensionMapping = {0u, dmlDimSize - 1u}; + dimensionMapping.resize(dmlDimSize); + if (isNhwc) + { + // Form a remapping for dimensions so C is moved before the spatial dimensions. + // e.g. NWC -> {0,2,1} -> NCW + // NHWC -> {0,3,1,2} -> NCHW + // NDHWC -> {0,4,1,2,3} -> NCDHW + std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u); + } + else + { + // Use NCHW {0,1,2,3} format with increasing order of indices + std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u); + } + m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint)) + { + m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + } + + // Reshape the Output Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Output ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel.
+ if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint)) + { + m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + } + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize)); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {}; + + qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput]; + qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale]; + qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint]; + qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];; + qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];; + qLinearAvgPooldesc.OutputTensor = &outputDescs[0]; + qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount; + qLinearAvgPooldesc.WindowSize = m_kernel.windowSize; + qLinearAvgPooldesc.Strides = m_kernel.strides; + qLinearAvgPooldesc.StartPadding = m_kernel.startPadding; + qLinearAvgPooldesc.EndPadding = m_kernel.endPadding; + qLinearAvgPooldesc.Dilations = m_kernel.dilations; + qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +template +class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling +{ +public: + DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling) + { + } +}; + 
+DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 07ff4f3145459..daa8d70b6dac2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -257,6 +257,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool); DML_OP_EXTERN_CREATION_FUNCTION(LpPool); DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool); DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16); DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization); @@ -587,6 +589,10 @@ constexpr static std::array supportedTypeListQLinea SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, }; +constexpr static std::array supportedTypeListQLinearAveragePool = { + SupportedTensorDataTypes::Ints8Bit +}; + template constexpr auto requiredConstantCpuInputs(Args... args) { @@ -992,6 +998,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, {REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9. 
+ {REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp index 067a320dd8000..a2183aab52eed 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp @@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm } m_bufferTensorDesc.DimensionCount = newDimensionCount; } + +// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout +void TensorDesc::PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment) +{ + EnsureStridesExist(); + SetDimensionCount(static_cast(dimensionMapping.size()), alignment); + + // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping + std::vector tempSizes{m_sizes, m_sizes + MaximumDimensionCount}; + std::vector tempStrides{m_strides, m_strides + MaximumDimensionCount}; + + for (size_t i = 0; i < dimensionMapping.size(); i++) + { + m_sizes[i] = tempSizes[dimensionMapping[i]]; + m_strides[i] = tempStrides[dimensionMapping[i]]; + } + + m_bufferTensorDesc.Sizes = m_sizes; + m_bufferTensorDesc.Strides = m_strides; +} + +void TensorDesc::EnsureStridesExist() +{ + if (m_bufferTensorDesc.Strides != nullptr) + { + // Strides are populated + return; + } + + 
uint32_t stride = 1; + for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;) + { + m_strides[i] = stride; + stride *= m_sizes[i]; + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h index ff70dec5b8871..909e2084d0163 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h @@ -44,6 +44,7 @@ namespace Dml gsl::span GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; } gsl::span GetStrides() const; void SetStrides(gsl::span strides); + void PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment); inline uint64_t GetBufferSizeInBytes() const { @@ -90,6 +91,8 @@ namespace Dml uint32_t m_sizes[MaximumDimensionCount] = {}; uint32_t m_strides[MaximumDimensionCount] = {}; DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {}; + + void EnsureStridesExist(); }; class TensorDescBuilder diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 5be84a931f4f1..543e30fcd9722 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -24,6 +24,7 @@ namespace AttrName static constexpr const char* Border = "border"; static constexpr const char* Broadcast = "broadcast"; static constexpr const char* CeilMode = "ceil_mode"; + static constexpr const char* ChannelsLast = "channels_last"; static constexpr const char* Clip = "clip"; static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode"; static constexpr const char* CountIncludePad = "count_include_pad"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 
4d59964dcc664..7158ec864e47a 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -365,13 +365,20 @@ namespace OperatorHelper } // Creates a kernel that spans the entire spatial dimensions of the input. - KernelArgs InitializeGlobalKernel(gsl::span inputDimensions) + KernelArgs InitializeGlobalKernel( + const MLOperatorAttributes& kernelInfo, + gsl::span inputDimensions) { ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount); // Must be at least 1D convolution (in 3D tensor) uint32_t spatialDimensionCount = gsl::narrow_cast(inputDimensions.size()) - NonspatialDimensionCount; ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor). KernelArgs args(spatialDimensionCount); + args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0); + // For Global Pooling, the kernel size is equal to the spatial dimensions of the input tensor + // NHWC layout needs to offset by one dim to account for the channel placed at the end + int dimOffset = args.channelsLast ?
1 : 0; for (size_t dim = 0; dim < spatialDimensionCount; ++dim) { @@ -379,7 +386,7 @@ namespace OperatorHelper args.dilations[dim] = 1; args.startPadding[dim] = 0; args.endPadding[dim] = 0; - args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim]); + args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim - dimOffset]); } return args; @@ -495,6 +502,7 @@ namespace OperatorHelper } args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0); return args; } @@ -2012,7 +2020,37 @@ namespace OperatorHelper } return outputShapes; } + + std::vector QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto inputShape = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast); + + const uint32_t outputCount = shapeInfo.GetOutputCount(); + + std::vector outputShapes; + for (uint32_t i = 0; i < outputCount; ++i) + { + outputShapes.push_back(outputDimensions); + } + return outputShapes; + } + + std::vector QLinearGlobalAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto inputShape = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast); + const uint32_t outputCount = shapeInfo.GetOutputCount(); + + std::vector outputShapes; + for (uint32_t i = 0; i < outputCount; ++i) + { + outputShapes.push_back(outputDimensions); + } + return outputShapes; + } + std::vector RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS); @@ -2706,6 +2744,59 @@ namespace OperatorHelper m_qkvHiddenSizes = 
kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes); } + std::vector QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5); + + auto queryShape = shapeInfo.GetInputTensorShape(0); + ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3); + + auto weightShape = shapeInfo.GetInputTensorShape(1); + ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2); + + if (m_qkvHiddenSizes.empty()) + { + ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0); + } + else + { + ML_CHECK_VALID_ARGUMENT(m_qkvHiddenSizes.size() == 3); + } + + const uint32_t batchSize = queryShape[0]; + const uint32_t sequenceLength = queryShape[1]; + const uint32_t vHiddenSize = m_qkvHiddenSizes.empty() ? weightShape[1] / 3 : m_qkvHiddenSizes[2]; + uint32_t hiddenSize = m_qkvHiddenSizes.empty() ? weightShape[1] / 3 : m_qkvHiddenSizes[0]; + uint32_t headSize = hiddenSize / m_numHeads; + + + std::vector outputShapes(2); + + outputShapes[0] = EdgeShapes({batchSize, sequenceLength, vHiddenSize}); + + uint32_t totalSequenceLength = sequenceLength; + if (shapeInfo.IsInputValid(8)) + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5); + const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3]; + totalSequenceLength += pastSequenceLength; + } + + if (shapeInfo.IsOutputValid(1)) + { + outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize}); + } + + return outputShapes; + } + + void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation) + { + m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes); + m_numHeads = gsl::narrow_cast(kernelInformation.GetAttributes().GetAttribute(AttrName::NumHeads)); + + } + std::vector SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() 
>= 3); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 5add951dccb78..4a81d5cff6d14 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -160,6 +160,7 @@ struct KernelArgs bool autoPad = false; bool autoPadSameUpper = false; bool useCeilingOutputShape = false; + bool channelsLast = false; uint32_t spatialDimensionCount = 0; KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount) @@ -188,6 +189,7 @@ struct KernelArgs KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount) : autoPad(kernelArgs.autoPad), autoPadSameUpper(kernelArgs.autoPadSameUpper), + channelsLast(kernelArgs.channelsLast), spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount)) { ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); @@ -211,7 +213,9 @@ std::vector InitializeKernelOutputDimsTranspose( gsl::span inputDimensions, const KernelArgs& args); -KernelArgs InitializeGlobalKernel(gsl::span inputDimensions); +KernelArgs InitializeGlobalKernel( + const MLOperatorAttributes& kernelInfo, + gsl::span inputDimensions); KernelArgs InitializeKernel( const MLOperatorAttributes& kernelInfo, @@ -1066,7 +1070,7 @@ class PoolingHelperBase bool useGlobalPooling ) : m_kernel(useGlobalPooling - ? InitializeGlobalKernel(shape.GetInputTensorShape(0)) + ? 
InitializeGlobalKernel(info, shape.GetInputTensorShape(0)) : InitializeKernel(info, static_cast(shape.GetInputTensorShape(0).size()), gsl::span())) { if (!useGlobalPooling) @@ -1168,6 +1172,24 @@ class RoiAlignHelper : public RoiPoolingHelperBase std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; }; +class QLinearAveragePoolingHelper : public PoolingHelperBase +{ +public: + template + QLinearAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +}; + +class QLinearGlobalAveragePoolingHelper : public PoolingHelperBase +{ +public: + template + QLinearGlobalAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +}; + class SqueezeHelper { public: @@ -1468,6 +1490,23 @@ class AttentionHelper std::vector m_qkvHiddenSizes; }; +class QAttentionHelper +{ +public: + template + QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo) + { + Initialize(KernelInformationAdapter(info)); + } + + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +private: + void Initialize(const IKernelInformationAdapter& kernelInformation); + std::vector m_qkvHiddenSizes; + uint32_t m_numHeads; +}; + class SkipLayerNormHelper { public: @@ -1497,6 +1536,8 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper; using ShapeInferenceHelper_LpPool = PoolingHelper; using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper; using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper; +using ShapeInferenceHelper_QLinearAveragePool = QLinearAveragePoolingHelper; +using ShapeInferenceHelper_QLinearGlobalAveragePool = QLinearGlobalAveragePoolingHelper; using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper; using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper; using 
ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper; @@ -1606,7 +1647,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper; -using ShapeInferenceHelper_QAttention = AttentionHelper; +using ShapeInferenceHelper_QAttention = QAttentionHelper; using ShapeInferenceHelper_Attention = AttentionHelper; using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper; using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index d785f77e24344..078f4a7aef6b0 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -441,6 +441,8 @@ namespace OperatorHelper static const int sc_sinceVer_GroupNorm = 1; static const int sc_sinceVer_DynamicQuantizeMatMul = 1; static const int sc_sinceVer_QLinearConcat = 1; + static const int sc_sinceVer_QLinearAveragePool = 1; + static const int sc_sinceVer_QLinearGlobalAveragePool = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper diff --git a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc index 8fb245819fd26..71b6f27b5391f 100644 --- a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc +++ b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc @@ -66,6 +66,9 @@ void RunQLinearGlobalAveragePool( test.AddInput("y_scale", {}, {y_scale}); test.AddInput("y_zero_point", {}, {y_zero_point}); test.AddOutput("Y", y_dims, y_data); + if (channels_last) { + 
test.AddAttribute("channels_last", (int64_t)1LL); + } auto q8checker = [&](const std::vector& fetches, const std::string& provider_type) { const OrtValue& ort_value = fetches[0];