From eb576a6783a9d3294d96dc6b05b682645211f0ff Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:09:11 -0800 Subject: [PATCH] Enable QLinearAveragePooling DML EP (#17384) (#18240) [Cherry Pick Reviewed] DML EP Implementation for [QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool) ``` Note: Google Test filter = *QLinear*Pool* [==========] Running 72 tests from 2 test suites. [----------] Global test environment set-up. [----------] 36 tests from QLinearGlobalAveragePool [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms) [ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 [ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms) [ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 [ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms) [----------] 36 tests from QLinearGlobalAveragePool (5968 ms total) [----------] 36 tests from QLinearPoolTest [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage [ OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global [ OK ] QLinearPoolTest.AveragePool2D_Global (481 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms) [ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms) [ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms) [ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms) [ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms) [ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms) [ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms) [ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 [ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms) [----------] 36 tests from QLinearPoolTest (12914 ms total) [----------] Global test environment tear-down [==========] 72 tests from 2 test suites ran. (18885 ms total) [ PASSED ] 72 tests. memleakdbg: ----- No memory leaks detected ----- ``` ### Description ### Motivation and Context --- .../src/External/DirectMLHelpers/ApiTraits.h | 20 ++- .../External/DirectMLHelpers/DirectMLSchema.h | 25 +++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 26 +++ .../DmlOperatorQLinearAveragePooling.cpp | 150 ++++++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 8 + .../DmlExecutionProvider/src/TensorDesc.cpp | 36 +++++ .../dml/DmlExecutionProvider/src/TensorDesc.h | 3 + .../dml/OperatorAuthorHelper/Attributes.h | 2 +- .../OperatorAuthorHelper/OperatorHelper.cpp | 42 ++++- .../dml/OperatorAuthorHelper/OperatorHelper.h | 28 +++- .../OperatorAuthorHelper/OperatorVersions.h | 2 + .../qlinear_global_average_pool_test.cc | 3 + 12 files changed, 339 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 94f2220fcc168..a5415ba85f3d3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,7 +24,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 160; + static constexpr auto ValueCount = 161; static constexpr size_t ActivationFunctionCount = 24; }; @@ -495,6 +495,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; +}; + template <> struct OperatorDescTraits { @@ -1496,6 +1502,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> using DescType = DML_ROI_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> +{ + using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> { @@ -2522,6 +2534,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args case DML_OPERATOR_ACTIVATION_GELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...); +#pragma warning(push) +#pragma warning(disable: 4063) + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); +#pragma warning(pop) + default: ORT_THROW_HR(E_INVALIDARG); return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 9eae1c1fe8158..2a82c12872a72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -869,6 +869,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, }; + +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", + static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 13, + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index ad4cceb85cfd2..99218c135f058 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -502,6 +502,24 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), }; } +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) { return { @@ -2509,6 +2527,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_GELU_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); +#pragma warning(push) +#pragma warning(disable: 4063) + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); +#pragma warning(pop) + default: ORT_THROW_HR(E_INVALIDARG); return AbstractOperatorDesc( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp new file mode 100644 index 0000000000000..0fccedfe311c1 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase +{ + // For QLinear Avg Pool ORT and DML have same indexing order + enum OrtInputTensors : uint32_t + { + ortInput, + ortInputScale, + ortInputZeroPoint, + ortOutputScale, + ortOutputZeroPoint, + ortInputCount + }; + +public: + using Self = DmlOperatorQLinearAveragePooling; + + DmlOperatorQLinearAveragePooling( + const MLOperatorKernelCreationContext& kernelInfo, + bool useGlobalPooling + ) + : DmlOperator(kernelInfo), + PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling) + { + DmlOperator::Initialize(kernelInfo); + + bool isNhwc = m_kernel.channelsLast; + std::vector inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount(); + ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2); + + // DML requires that DimensionCount be equal to Input.dmlDimSize - 2 for Pooling + uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2; + if (m_kernel.spatialDimensionCount < expectedSpatialDimCount) + { + size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount; + + for (int i = gsl::narrow_cast(m_kernel.spatialDimensionCount) - 1; i >= 0; i--) + { + m_kernel.windowSize[i + shift] = m_kernel.windowSize[i]; + m_kernel.windowSize[i] = 1; + + m_kernel.strides[i + shift] = m_kernel.strides[i]; + m_kernel.strides[i] = 1; + + m_kernel.startPadding[i + shift] = m_kernel.startPadding[i]; + m_kernel.startPadding[i] = 0; + + m_kernel.endPadding[i + shift] = m_kernel.endPadding[i]; + m_kernel.endPadding[i] = 0; + + m_kernel.dilations[i + shift] = m_kernel.dilations[i]; + m_kernel.dilations[i] = 1; + } + + m_kernel.spatialDimensionCount = expectedSpatialDimCount; + } + + // Initialize dimensionMapping for NCHW or NHWC layout + std::vector dimensionMapping = {0u, dmlDimSize - 1u}; + dimensionMapping.resize(dmlDimSize); + if (isNhwc) + { + // Form a remapping for dimensions so C is moved before the spatial dimensions. + // e.g. NWC -> {0,2,1} -> NCW + // NHWC -> {0,3,1,2} -> NCHW + // NDHWC -> {0,4,1,2,3} -> NCDHW + std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u); + } + else + { + // Use NCHW {0,1,2,3} format with increasing order of indexs + std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u); + } + m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint)) + { + m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + } + + // Reshape the Output Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + // Reshape the Input ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint)) + { + m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + } + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned); + + assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize)); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {}; + + qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput]; + qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale]; + qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint]; + qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];; + qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];; + qLinearAvgPooldesc.OutputTensor = &outputDescs[0]; + qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount; + qLinearAvgPooldesc.WindowSize = m_kernel.windowSize; + qLinearAvgPooldesc.Strides = m_kernel.strides; + qLinearAvgPooldesc.StartPadding = m_kernel.startPadding; + qLinearAvgPooldesc.EndPadding = m_kernel.endPadding; + qLinearAvgPooldesc.Dilations = m_kernel.dilations; + qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +template +class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling +{ +public: + DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling) + { + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 6119e3ed6a472..fa8ebb78b8d6c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -320,6 +320,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool); DML_OP_EXTERN_CREATION_FUNCTION(LpPool); DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool); DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16); DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization); @@ -631,6 +633,10 @@ constexpr static std::array supportedTypeListQLinea SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, }; +constexpr static std::array supportedTypeListQLinearAveragePool = { + SupportedTensorDataTypes::Ints8Bit +}; + template constexpr auto requiredConstantCpuInputs(Args... args) { @@ -1036,6 +1042,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, {REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9. + {REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp index 067a320dd8000..a2183aab52eed 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp @@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm } m_bufferTensorDesc.DimensionCount = newDimensionCount; } + +// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout +void TensorDesc::PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment) +{ + EnsureStridesExist(); + SetDimensionCount(static_cast(dimensionMapping.size()), alignment); + + // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping + std::vector tempSizes{m_sizes, m_sizes + MaximumDimensionCount}; + std::vector tempStrides{m_strides, m_strides + MaximumDimensionCount}; + + for (size_t i = 0; i < dimensionMapping.size(); i++) + { + m_sizes[i] = tempSizes[dimensionMapping[i]]; + m_strides[i] = tempStrides[dimensionMapping[i]]; + } + + m_bufferTensorDesc.Sizes = m_sizes; + m_bufferTensorDesc.Strides = m_strides; +} + +void TensorDesc::EnsureStridesExist() +{ + if (m_bufferTensorDesc.Strides != nullptr) + { + // Strides are populated + return; + } + + uint32_t stride = 1; + for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;) + { + m_strides[i] = stride; + stride *= m_sizes[i]; + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h index ff70dec5b8871..909e2084d0163 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h @@ -44,6 +44,7 @@ namespace Dml gsl::span GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; } gsl::span GetStrides() const; void SetStrides(gsl::span strides); + void PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment); inline uint64_t GetBufferSizeInBytes() const { @@ -90,6 +91,8 @@ namespace Dml uint32_t m_sizes[MaximumDimensionCount] = {}; uint32_t m_strides[MaximumDimensionCount] = {}; DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {}; + + void EnsureStridesExist(); }; class TensorDescBuilder diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index dac128f92ae0c..543e30fcd9722 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -23,8 +23,8 @@ namespace AttrName static constexpr const char* BlockSize = "blocksize"; static constexpr const char* Border = "border"; static constexpr const char* Broadcast = "broadcast"; - static constexpr const char* ChannelsLast = "channels_last"; static constexpr const char* CeilMode = "ceil_mode"; + static constexpr const char* ChannelsLast = "channels_last"; static constexpr const char* Clip = "clip"; static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode"; static constexpr const char* CountIncludePad = "count_include_pad"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 4d59964dcc664..1fcd3b04300f4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -365,13 +365,20 @@ namespace OperatorHelper } // Creates a kernel that spans the entire spatial dimensions of the input. - KernelArgs InitializeGlobalKernel(gsl::span inputDimensions) + KernelArgs InitializeGlobalKernel( + const MLOperatorAttributes& kernelInfo, + gsl::span inputDimensions) { ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount); // Must be at least 1D convolution (in 3D tensor) uint32_t spatialDimensionCount = gsl::narrow_cast(inputDimensions.size()) - NonspatialDimensionCount; ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor). KernelArgs args(spatialDimensionCount); + args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0); + // For Global Pooling, kernel size equal to the spatial dimension of input tensor + // NHWC layout need to offset by one dim to acount for channel placed at the end + int dimOffset = args.channelsLast ? 1 : 0; for (size_t dim = 0; dim < spatialDimensionCount; ++dim) { @@ -379,7 +386,7 @@ namespace OperatorHelper args.dilations[dim] = 1; args.startPadding[dim] = 0; args.endPadding[dim] = 0; - args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim]); + args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim - dimOffset]); } return args; @@ -495,6 +502,7 @@ namespace OperatorHelper } args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0); return args; } @@ -2012,7 +2020,37 @@ namespace OperatorHelper } return outputShapes; } + + std::vector QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto inputShape = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast); + + const uint32_t outputCount = shapeInfo.GetOutputCount(); + + std::vector outputShapes; + for (uint32_t i = 0; i < outputCount; ++i) + { + outputShapes.push_back(outputDimensions); + } + return outputShapes; + } + + std::vector QLinearGlobalAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + auto inputShape = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast); + const uint32_t outputCount = shapeInfo.GetOutputCount(); + + std::vector outputShapes; + for (uint32_t i = 0; i < outputCount; ++i) + { + outputShapes.push_back(outputDimensions); + } + return outputShapes; + } + std::vector RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 24a022778adf4..7cadcb2feeab6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -160,6 +160,7 @@ struct KernelArgs bool autoPad = false; bool autoPadSameUpper = false; bool useCeilingOutputShape = false; + bool channelsLast = false; uint32_t spatialDimensionCount = 0; KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount) @@ -188,6 +189,7 @@ struct KernelArgs KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount) : autoPad(kernelArgs.autoPad), autoPadSameUpper(kernelArgs.autoPadSameUpper), + channelsLast(kernelArgs.channelsLast), spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount)) { ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); @@ -211,7 +213,9 @@ std::vector InitializeKernelOutputDimsTranspose( gsl::span inputDimensions, const KernelArgs& args); -KernelArgs InitializeGlobalKernel(gsl::span inputDimensions); +KernelArgs InitializeGlobalKernel( + const MLOperatorAttributes& kernelInfo, + gsl::span inputDimensions); KernelArgs InitializeKernel( const MLOperatorAttributes& kernelInfo, @@ -1059,7 +1063,7 @@ class PoolingHelperBase bool useGlobalPooling ) : m_kernel(useGlobalPooling - ? InitializeGlobalKernel(shape.GetInputTensorShape(0)) + ? InitializeGlobalKernel(info, shape.GetInputTensorShape(0)) : InitializeKernel(info, static_cast(shape.GetInputTensorShape(0).size()), gsl::span())) { if (!useGlobalPooling) @@ -1161,6 +1165,24 @@ class RoiAlignHelper : public RoiPoolingHelperBase std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; }; +class QLinearAveragePoolingHelper : public PoolingHelperBase +{ +public: + template + QLinearAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +}; + +class QLinearGlobalAveragePoolingHelper : public PoolingHelperBase +{ +public: + template + QLinearGlobalAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +}; + class SqueezeHelper { public: @@ -1490,6 +1512,8 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper; using ShapeInferenceHelper_LpPool = PoolingHelper; using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper; using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper; +using ShapeInferenceHelper_QLinearAveragePool = QLinearAveragePoolingHelper; +using ShapeInferenceHelper_QLinearGlobalAveragePool = QLinearGlobalAveragePoolingHelper; using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper; using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper; using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 94d3b127ce168..442cf2073eb02 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -445,6 +445,8 @@ namespace OperatorHelper static const int sc_sinceVer_GroupNorm = 1; static const int sc_sinceVer_DynamicQuantizeMatMul = 1; static const int sc_sinceVer_QLinearConcat = 1; + static const int sc_sinceVer_QLinearAveragePool = 1; + static const int sc_sinceVer_QLinearGlobalAveragePool = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper diff --git a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc index 8fb245819fd26..71b6f27b5391f 100644 --- a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc +++ b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc @@ -66,6 +66,9 @@ void RunQLinearGlobalAveragePool( test.AddInput("y_scale", {}, {y_scale}); test.AddInput("y_zero_point", {}, {y_zero_point}); test.AddOutput("Y", y_dims, y_data); + if (channels_last) { + test.AddAttribute("channels_last", (int64_t)1LL); + } auto q8checker = [&](const std::vector& fetches, const std::string& provider_type) { const OrtValue& ort_value = fetches[0];