Skip to content

Commit

Permalink
Enable QLinearAveragePooling DML EP (microsoft#17384) (microsoft#18240)
Browse files Browse the repository at this point in the history
[Cherry Pick Reviewed]
DML EP Implementation for

[QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool)
```
Note: Google Test filter = *QLinear*Pool*
[==========] Running 72 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 36 tests from QLinearGlobalAveragePool
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x1x32x32
[       OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x32x32x1
[       OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x256
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x255
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x255
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x256
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x256
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x255
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x255
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x256
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms)
[----------] 36 tests from QLinearGlobalAveragePool (5968 ms total)

[----------] 36 tests from QLinearPoolTest
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage
[       OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global
[       OK ] QLinearPoolTest.AveragePool2D_Global (481 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_S8
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_S8
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_S8
[       OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms)
[----------] 36 tests from QLinearPoolTest (12914 ms total)

[----------] Global test environment tear-down
[==========] 72 tests from 2 test suites ran. (18885 ms total)
[  PASSED  ] 72 tests.
memleakdbg:
----- No memory leaks detected -----
```

### Description
<!-- Describe your changes. -->

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
raoanag authored and jslap-ubi committed Apr 5, 2024
1 parent ef08f12 commit 28b0e8b
Show file tree
Hide file tree
Showing 12 changed files with 339 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 160;
static constexpr auto ValueCount = 161;
static constexpr size_t ActivationFunctionCount = 24;
};

Expand Down Expand Up @@ -495,6 +495,12 @@ struct OperatorDescTraits<DML_ROI_POOLING_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING;
};

template <>
struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
{
    // This operator's value is not a member of the public DML_OPERATOR_TYPE enum
    // (it is an extension value), so an explicit cast is required. Use static_cast
    // rather than a C-style cast, matching the schema definition for this operator.
    static constexpr DML_OPERATOR_TYPE Type = static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING);
};

template <>
struct OperatorDescTraits<DML_SLICE_OPERATOR_DESC>
{
Expand Down Expand Up @@ -1496,6 +1502,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING>
using DescType = DML_ROI_POOLING_OPERATOR_DESC;
};

// Inverse mapping of OperatorDescTraits: recovers the operator desc struct type
// from the (extension) operator type value, for template-based dispatch.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING>
{
    using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE>
{
Expand Down Expand Up @@ -2522,6 +2534,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
case DML_OPERATOR_ACTIVATION_GELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);

#pragma warning(push)
#pragma warning(disable: 4063)
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
#pragma warning(pop)

default:
ORT_THROW_HR(E_INVALIDARG);
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA {
DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS,
};


// Field table for DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC.
// Order and count (13) must match the members of the desc struct; the bool flag
// marks the field optional (true for the two zero-point tensors).
constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] {
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
};

// Schema descriptor tying the extension operator type value to its 13-entry
// field table above. The field count literal must stay in sync with the array size.
constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
    "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
    static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    13,
    DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,24 @@ inline std::vector<OperatorField> GetFields(const DML_ROI_POOLING_OPERATOR_DESC&
OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SIZE_2D>(desc.PooledSize))),
};
}
inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
};
}
inline std::vector<OperatorField> GetFields(const DML_SLICE_OPERATOR_DESC& desc)
{
return {
Expand Down Expand Up @@ -2509,6 +2527,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ACTIVATION_GELU_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_GELU_OPERATOR_DESC*>(opDesc.Desc)));
#pragma warning(push)
#pragma warning(disable: 4063)
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
#pragma warning(pop)

default:
ORT_THROW_HR(E_INVALIDARG);
return AbstractOperatorDesc(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "precomp.h"

namespace Dml
{

class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase
{
// For QLinear Avg Pool ORT and DML have same indexing order
enum OrtInputTensors : uint32_t
{
ortInput,
ortInputScale,
ortInputZeroPoint,
ortOutputScale,
ortOutputZeroPoint,
ortInputCount
};

public:
using Self = DmlOperatorQLinearAveragePooling;

DmlOperatorQLinearAveragePooling(
const MLOperatorKernelCreationContext& kernelInfo,
bool useGlobalPooling
)
: DmlOperator(kernelInfo),
PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling)
{
DmlOperator::Initialize(kernelInfo);

bool isNhwc = m_kernel.channelsLast;
std::vector<DimensionType> inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);

uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount();
ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2);

// DML requires that DimensionCount be equal to Input.dmlDimSize - 2 for Pooling
uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
{
size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount;

for (int i = gsl::narrow_cast<int>(m_kernel.spatialDimensionCount) - 1; i >= 0; i--)
{
m_kernel.windowSize[i + shift] = m_kernel.windowSize[i];
m_kernel.windowSize[i] = 1;

m_kernel.strides[i + shift] = m_kernel.strides[i];
m_kernel.strides[i] = 1;

m_kernel.startPadding[i + shift] = m_kernel.startPadding[i];
m_kernel.startPadding[i] = 0;

m_kernel.endPadding[i + shift] = m_kernel.endPadding[i];
m_kernel.endPadding[i] = 0;

m_kernel.dilations[i + shift] = m_kernel.dilations[i];
m_kernel.dilations[i] = 1;
}

m_kernel.spatialDimensionCount = expectedSpatialDimCount;
}

// Initialize dimensionMapping for NCHW or NHWC layout
std::vector<uint32_t> dimensionMapping = {0u, dmlDimSize - 1u};
dimensionMapping.resize(dmlDimSize);
if (isNhwc)
{
// Form a remapping for dimensions so C is moved before the spatial dimensions.
// e.g. NWC -> {0,2,1} -> NCW
// NHWC -> {0,3,1,2} -> NCHW
// NDHWC -> {0,4,1,2,3} -> NCDHW
std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u);
}
else
{
// Use NCHW {0,1,2,3} format with increasing order of indexs
std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u);
}
m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);

// Reshape the Input Scale to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);

// Reshape the Input ZeroPoint to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint))
{
m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
}

// Reshape the Output Scale to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);

// Reshape the Input ZeroPoint to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint))
{
m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
}

// Initialize the output description while overriding the shape
m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);

assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize));

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {};

qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
qLinearAvgPooldesc.Strides = m_kernel.strides;
qLinearAvgPooldesc.StartPadding = m_kernel.startPadding;
qLinearAvgPooldesc.EndPadding = m_kernel.endPadding;
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);

DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};

// Thin wrapper selecting global vs. windowed pooling at compile time:
// <false> registers QLinearAveragePool, <true> registers QLinearGlobalAveragePool.
template <bool UseGlobalPooling>
class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling
{
public:
    DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo)
    :   DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling)
    {
    }
};

// Creation-function entry points referenced by the operator registration table.
DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate<false>);
DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate<true>);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool);
DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
Expand Down Expand Up @@ -634,6 +636,10 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinea
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
};

// Single type group for QLinearAveragePool / QLinearGlobalAveragePool: the DML EP
// registers these kernels for 8-bit integer tensors only (int8/uint8).
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearAveragePool = {
    SupportedTensorDataTypes::Ints8Bit
};

template<typename... Args>
constexpr auto requiredConstantCpuInputs(Args... args)
{
Expand Down Expand Up @@ -1040,6 +1046,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9.

{REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
}
m_bufferTensorDesc.DimensionCount = newDimensionCount;
}

// Reorders m_sizes and m_strides according to dimensionMapping so the desc matches
// a specific tensor layout: after the call, dimension i holds what was previously
// at index dimensionMapping[i]. Strides are materialized first so the permutation
// preserves the original memory layout.
void TensorDesc::PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment)
{
    EnsureStridesExist();
    SetDimensionCount(static_cast<uint32_t>(dimensionMapping.size()), alignment);

    // Snapshot the current layout, then gather each dimension from the slot named
    // by the mapping.
    std::vector<uint32_t> originalSizes(m_sizes, m_sizes + MaximumDimensionCount);
    std::vector<uint32_t> originalStrides(m_strides, m_strides + MaximumDimensionCount);

    for (size_t dim = 0; dim < dimensionMapping.size(); ++dim)
    {
        const uint32_t source = dimensionMapping[dim];
        m_sizes[dim] = originalSizes[source];
        m_strides[dim] = originalStrides[source];
    }

    m_bufferTensorDesc.Sizes = m_sizes;
    m_bufferTensorDesc.Strides = m_strides;
}

void TensorDesc::EnsureStridesExist()
{
if (m_bufferTensorDesc.Strides != nullptr)
{
// Strides are populated
return;
}

uint32_t stride = 1;
for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;)
{
m_strides[i] = stride;
stride *= m_sizes[i];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace Dml
gsl::span<const uint32_t> GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
gsl::span<const uint32_t> GetStrides() const;
void SetStrides(gsl::span<const uint32_t> strides);
void PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment);

inline uint64_t GetBufferSizeInBytes() const
{
Expand Down Expand Up @@ -90,6 +91,8 @@ namespace Dml
uint32_t m_sizes[MaximumDimensionCount] = {};
uint32_t m_strides[MaximumDimensionCount] = {};
DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};

void EnsureStridesExist();
};

class TensorDescBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ namespace AttrName
static constexpr const char* BlockSize = "blocksize";
static constexpr const char* Border = "border";
static constexpr const char* Broadcast = "broadcast";
static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* CeilMode = "ceil_mode";
static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* Clip = "clip";
static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode";
static constexpr const char* CountIncludePad = "count_include_pad";
Expand Down
Loading

0 comments on commit 28b0e8b

Please sign in to comment.