Skip to content

Commit

Permalink
Enable QLinearAveragePooling DML EP (#17384)
Browse files Browse the repository at this point in the history
DML EP Implementation for
[QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool)
```
Note: Google Test filter = *QLinear*Pool*
[==========] Running 72 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 36 tests from QLinearGlobalAveragePool
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x1x32x32
[       OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x32x32x1
[       OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x256
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x255
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x255
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x256
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x256
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x255
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x8x8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x255
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x7x7
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x256
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms)
[ RUN      ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8
[       OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms)
[ RUN      ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8
[       OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms)
[----------] 36 tests from QLinearGlobalAveragePool (5968 ms total)

[----------] 36 tests from QLinearPoolTest
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage
[       OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global
[       OK ] QLinearPoolTest.AveragePool2D_Global (481 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_nhwc
[       OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_S8
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms)
[ RUN      ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms)
[ RUN      ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_S8
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_S8
[       OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms)
[ RUN      ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8
[       OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms)
[----------] 36 tests from QLinearPoolTest (12914 ms total)

[----------] Global test environment tear-down
[==========] 72 tests from 2 test suites ran. (18885 ms total)
[  PASSED  ] 72 tests.
memleakdbg:
----- No memory leaks detected -----
```
  • Loading branch information
raoanag committed Oct 13, 2023
1 parent adb2dd3 commit 43ea845
Show file tree
Hide file tree
Showing 14 changed files with 679 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,26 @@ struct DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC
_Maybenull_ const DML_TENSOR_DESC* BiasTensor;
const DML_TENSOR_DESC* OutputTensor;
};
// Extension (preview) operator type value for DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC.
// Cast to DML_OPERATOR_TYPE at use sites (see OperatorTypeVisitor).
// NOTE(review): the source contained this definition twice back-to-back (diff
// old/new lines both present); a second identical namespace-scope definition
// is a redefinition error, so only one is kept.
const int DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT = 0x80000011;

// Descriptor for the extension (preview) quantized-linear average pooling
// operator. Field order and layout must match the 13-entry schema table
// declared for DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING elsewhere in
// this change — do not reorder or insert fields.
struct DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC
{
const DML_TENSOR_DESC* InputTensor;
const DML_TENSOR_DESC* InputScaleTensor;
_Maybenull_ const DML_TENSOR_DESC* InputZeroPointTensor;   // optional; may be null
const DML_TENSOR_DESC* OutputScaleTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputZeroPointTensor;  // optional; may be null
const DML_TENSOR_DESC* OutputTensor;
UINT DimensionCount;                                       // element count of each UINT array below (per _Field_size_)
_Field_size_(DimensionCount) const UINT* Strides;
_Field_size_(DimensionCount) const UINT* WindowSize;
_Field_size_(DimensionCount) const UINT* StartPadding;
_Field_size_(DimensionCount) const UINT* EndPadding;
_Field_size_(DimensionCount) const UINT* Dilations;
BOOL IncludePadding;                                       // presumably: count padded elements in the averaging divisor (cf. AveragePool count_include_pad) — TODO confirm
};
// Extension operator type value; cast to DML_OPERATOR_TYPE at use sites.
const int DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING = 0x8000000B;


namespace ApiTraits
{
Expand All @@ -38,7 +57,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
// Number of enumerants in DML_OPERATOR_TYPE (including extension operators).
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
// 162 = previous count of 161 plus the newly added
// DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING. The source showed both the
// stale (161) and updated (162) definitions; a class may define a static
// member only once, so only the updated value is kept.
static constexpr auto ValueCount = 162;
static constexpr size_t ActivationFunctionCount = 24;
};

Expand Down Expand Up @@ -497,6 +516,12 @@ struct OperatorDescTraits<DML_ROI_POOLING_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING;
};

// Maps the quantized-linear average pooling desc struct to its operator type
// value (desc -> type direction; the inverse mapping is provided by
// OperatorTypeTraits). The cast is needed because the value is an extension
// constant declared as int rather than a DML_OPERATOR_TYPE enumerant.
template <>
struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
};

template <>
struct OperatorDescTraits<DML_SLICE_OPERATOR_DESC>
{
Expand Down Expand Up @@ -1492,6 +1517,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING>
using DescType = DML_ROI_POOLING_OPERATOR_DESC;
};

// Maps the extension operator type value back to its desc struct
// (type -> desc direction; the inverse mapping is provided by
// OperatorDescTraits).
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING>
{
using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE>
{
Expand Down Expand Up @@ -2524,6 +2555,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
#pragma warning(disable: 4063)
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
#pragma warning(pop)

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA {
DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS,
};


// Schema field table for DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING.
// Entry order and count (13) must stay in lockstep with the members of
// DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC and with the Fields[i]
// indices used by the matching GetFields() overload.
constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
};

// Operator schema record tying the extension operator type value to its
// 13-entry field table above. The field count literal (13) must equal the
// array length of DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS.
constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
13,
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,24 @@ inline std::vector<OperatorField> GetFields(const DML_ROI_POOLING_OPERATOR_DESC&
OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SIZE_2D>(desc.PooledSize))),
};
}
inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
};
}
inline std::vector<OperatorField> GetFields(const DML_SLICE_OPERATOR_DESC& desc)
{
return {
Expand Down Expand Up @@ -2492,6 +2510,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
#pragma warning(pop)

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
/*
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
N is number of attention heads, H is head size, and W=N*H
M is mask_index tensor
M is mask_index tensor, P is optional past tensor
M A B C // M, A, B, and C are Inputs
// M, A, B, C and P are Inputs
M A B C
| \ | /
| Gemm
| / | \
Expand All @@ -17,13 +19,30 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
| Slice Slice Slice
| | | |
| | | |
| Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
| | | | // keeping the GEMM strides as NCHW to better target metacommands
| Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
| | | | // keeping the GEMM strides as NCHW to better target metacommands
| | | |
----------------- MHA -----
|
|
Output // Final output
| | | | P
| | | | / \
| | | | / \
| | | | Slice Slice
| | | | | |
| | | | | |
| | | | | |
--------------------------MHA -----------
/ | \
/ | \
/ | \
/ | \
/ | \
/ | \
Output1 presentKey presentValue
\ /
\ /
\ /
Concat
|
Output2 (present)
This kernel creates a DML_GRAPH, as mentioned above.
For reference, refer to this Doc:
Expand Down Expand Up @@ -540,24 +559,13 @@ class DmlOperatorAttention : public DmlOperator
void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
{
*isSupported = false;
// `past` input tensor is not supported yet
if (context->IsInputValid(4))
{
return;
}

// `past_sequence_length` input tensor is not supported yet
if (context->IsInputValid(6))
{
return;
}

// `present` output tensor is not supported yet
if (context->IsOutputValid(1))
{
return;
}

// `unidirectional == 1` is not supported yet
MLOperatorAttributes attributes(context);
if (attributes.GetOptionalAttribute<int32_t>(AttrName::Unidirectional, 0) != 0)
Expand Down
Loading

0 comments on commit 43ea845

Please sign in to comment.