diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index c75b662af788d..94f2220fcc168 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -459,12 +459,24 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING1; +}; + template <> struct OperatorDescTraits { static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING1; +}; + template <> struct OperatorDescTraits { @@ -1448,12 +1460,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING> using DescType = DML_AVERAGE_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING1> +{ + using DescType = DML_AVERAGE_POOLING1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING> { using DescType = DML_LP_POOLING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING1> +{ + using DescType = DML_LP_POOLING1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING> { @@ -2259,8 +2283,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_AVERAGE_POOLING: return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_AVERAGE_POOLING1: + return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_LP_POOLING: return std::invoke(std::forward(visitor), DML_LP_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_LP_POOLING1: + return std::invoke(std::forward(visitor), DML_LP_POOLING1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MAX_POOLING: return std::invoke(std::forward(visitor), DML_MAX_POOLING_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MAX_POOLING1: @@ -2554,7 +2582,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; + case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1"; case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; + case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1"; case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 1ebd52d4ed427..9eae1c1fe8158 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -757,6 +757,26 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_OPERATOR_SCHEMA { DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING1_OPERATOR_SCHEMA { + "DML_OPERATOR_AVERAGE_POOLING1", + DML_OPERATOR_AVERAGE_POOLING1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -776,6 +796,26 @@ constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING_OPERATOR_SCHEMA { DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING1_OPERATOR_SCHEMA { + "DML_OPERATOR_LP_POOLING1", + DML_OPERATOR_LP_POOLING1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS[7] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 833871de0bbd9..ad4cceb85cfd2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -425,6 +425,21 @@ inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_D OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } + +inline std::vector GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& desc) { return { @@ -438,6 +453,20 @@ inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.P))), }; } +inline std::vector GetFields(const DML_LP_POOLING1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.P))), + }; +} inline std::vector GetFields(const DML_MAX_POOLING_OPERATOR_DESC& desc) { return { @@ -1684,7 +1713,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA; case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA; case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_AVERAGE_POOLING1: return DML_AVERAGE_POOLING1_OPERATOR_SCHEMA; case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_LP_POOLING1: return DML_LP_POOLING1_OPERATOR_SCHEMA; case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA; case DML_OPERATOR_MAX_POOLING1: return DML_MAX_POOLING1_OPERATOR_SCHEMA; case DML_OPERATOR_ROI_POOLING: return DML_ROI_POOLING_OPERATOR_SCHEMA; @@ -2002,10 +2033,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_AVERAGE_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_AVERAGE_POOLING1: + return AbstractOperatorDesc( + &DML_AVERAGE_POOLING1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_LP_POOLING: return AbstractOperatorDesc( &DML_LP_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_LP_POOLING1: + return AbstractOperatorDesc( + &DML_LP_POOLING1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_MAX_POOLING: return AbstractOperatorDesc( &DML_MAX_POOLING_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp index e8d5b2746aa13..10ff1d8be8a29 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp @@ -34,7 +34,7 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase kernelOutputIndices.emplace_back(1); } DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices); - + std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1."); @@ -98,6 +98,21 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase SetOpDesc(desc); break; } + case DML_OPERATOR_AVERAGE_POOLING1: + { + if (hasDilations) { + DML_AVERAGE_POOLING1_OPERATOR_DESC desc = {}; + desc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + desc.Dilations = m_kernel.dilations; + SetOpDesc(desc); + } + else { + DML_AVERAGE_POOLING_OPERATOR_DESC desc = {}; + desc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false); + SetOpDesc(desc); + } + break; + } case DML_OPERATOR_LP_POOLING: { DML_LP_POOLING_OPERATOR_DESC desc = {}; @@ -106,6 +121,23 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase SetOpDesc(desc); break; } + case DML_OPERATOR_LP_POOLING1: + { + if (hasDilations) { + DML_LP_POOLING1_OPERATOR_DESC desc = {}; + desc.P = kernelInfo.GetOptionalAttribute(AttrName::P, 2); + ML_CHECK_VALID_ARGUMENT(desc.P > 0); + desc.Dilations = m_kernel.dilations; + SetOpDesc(desc); + } + else { + DML_LP_POOLING_OPERATOR_DESC desc = {}; + desc.P = kernelInfo.GetOptionalAttribute(AttrName::P, 2); + ML_CHECK_VALID_ARGUMENT(desc.P > 0); + SetOpDesc(desc); + } + break; + } case DML_OPERATOR_MAX_POOLING: case DML_OPERATOR_MAX_POOLING1: case DML_OPERATOR_MAX_POOLING2: @@ -152,7 +184,7 @@ class DmlOperatorPoolingTemplate : public DmlOperatorPooling void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) { *isSupported = false; - + MLOperatorAttributes attributes(context); int storageOrder = attributes.GetOptionalAttribute(AttrName::StorageOrder, 0); @@ -164,11 +196,11 @@ void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* *isSupported = true; } -DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalAveragePool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(MaxPool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalMaxPool, DmlOperatorPoolingTemplate); -DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalLpPool, DmlOperatorPoolingTemplate); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 28360f09bcba3..dbe9f5da4f569 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -667,6 +667,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 19, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, @@ -677,6 +678,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 18, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index e18ba31def48a..3eb35faeba82f 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -406,6 +406,12 @@ namespace OperatorHelper static const int sc_sinceVer_BitwiseNot = 18; static const int sc_sinceVer_Pad = 18; static const int sc_sinceVer_Split = 18; + static const int sc_sinceVer_LpPool = 18; + } + + namespace OnnxOperatorSet19 + { + static const int sc_sinceVer_AveragePool = 19; } namespace MsftOperatorSet1 diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 10476ada2fa69..4b194ec18b31b 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -777,11 +777,6 @@ TEST(PoolTest, GlobalMaxPool3D) { } TEST(PoolTest, AveragePool) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool"); test.AddAttribute("auto_pad", ""); @@ -863,11 +858,6 @@ TEST(PoolTest, AveragePool) { } TEST(PoolTest, AveragePool_IncludePadPixel) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool"); test.AddAttribute("auto_pad", ""); @@ -911,11 +901,6 @@ TEST(PoolTest, AveragePool_DefaultStrides) { } TEST(PoolTest, AveragePool_10_ceil1_2d) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool", 10); test.AddAttribute("auto_pad", ""); @@ -939,11 +924,6 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) { } TEST(PoolTest, AveragePool_19_dilation_2d) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("AveragePool", 19); test.AddAttribute("auto_pad", ""); @@ -1070,11 +1050,6 @@ TEST(PoolTest, GlobalAveragePool_Large_256) { } TEST(PoolTest, LpPool) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect."; - } - OpTester test("LpPool"); test.AddAttribute("auto_pad", "");