Skip to content

Commit

Permalink
Register LPpool18 and AvgPool 19 (#16880)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxiang1993 authored and jeffbloo committed Jan 4, 2024
1 parent c3d96a7 commit 9bbe425
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -459,12 +459,24 @@ struct OperatorDescTraits<DML_AVERAGE_POOLING_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING;
};

template <>
struct OperatorDescTraits<DML_AVERAGE_POOLING1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING1;
};

template <>
struct OperatorDescTraits<DML_LP_POOLING_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING;
};

template <>
struct OperatorDescTraits<DML_LP_POOLING1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING1;
};

template <>
struct OperatorDescTraits<DML_MAX_POOLING_OPERATOR_DESC>
{
Expand Down Expand Up @@ -1448,12 +1460,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING>
using DescType = DML_AVERAGE_POOLING_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING1>
{
using DescType = DML_AVERAGE_POOLING1_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING>
{
using DescType = DML_LP_POOLING_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING1>
{
using DescType = DML_LP_POOLING1_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING>
{
Expand Down Expand Up @@ -2259,8 +2283,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_AVERAGE_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_AVERAGE_POOLING1:
return std::invoke(std::forward<Visitor>(visitor), DML_AVERAGE_POOLING1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_LP_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_LP_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_LP_POOLING1:
return std::invoke(std::forward<Visitor>(visitor), DML_LP_POOLING1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MAX_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_MAX_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MAX_POOLING1:
Expand Down Expand Up @@ -2554,7 +2582,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN";
case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX";
case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING";
case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1";
case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING";
case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1";
case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING";
case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1";
case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,26 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_OPERATOR_SCHEMA {
DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS[9] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
};

constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING1_OPERATOR_SCHEMA {
"DML_OPERATOR_AVERAGE_POOLING1",
DML_OPERATOR_AVERAGE_POOLING1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
9,
DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
Expand All @@ -776,6 +796,26 @@ constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING_OPERATOR_SCHEMA {
DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS[9] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false },
};

constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING1_OPERATOR_SCHEMA {
"DML_OPERATOR_LP_POOLING1",
DML_OPERATOR_LP_POOLING1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
9,
DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS[7] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,21 @@ inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING_OPERATOR_D
OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
};
}

inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
};
}
inline std::vector<OperatorField> GetFields(const DML_LP_POOLING_OPERATOR_DESC& desc)
{
return {
Expand All @@ -438,6 +453,20 @@ inline std::vector<OperatorField> GetFields(const DML_LP_POOLING_OPERATOR_DESC&
OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<UINT>(desc.P))),
};
}
inline std::vector<OperatorField> GetFields(const DML_LP_POOLING1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<UINT>(desc.P))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MAX_POOLING_OPERATOR_DESC& desc)
{
return {
Expand Down Expand Up @@ -1684,7 +1713,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA;
case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA;
case DML_OPERATOR_AVERAGE_POOLING1: return DML_AVERAGE_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA;
case DML_OPERATOR_LP_POOLING1: return DML_LP_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING1: return DML_MAX_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_ROI_POOLING: return DML_ROI_POOLING_OPERATOR_SCHEMA;
Expand Down Expand Up @@ -2002,10 +2033,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_AVERAGE_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_AVERAGE_POOLING1:
return AbstractOperatorDesc(
&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_AVERAGE_POOLING1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_LP_POOLING:
return AbstractOperatorDesc(
&DML_LP_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_LP_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_LP_POOLING1:
return AbstractOperatorDesc(
&DML_LP_POOLING1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_LP_POOLING1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MAX_POOLING:
return AbstractOperatorDesc(
&DML_MAX_POOLING_OPERATOR_SCHEMA,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
kernelOutputIndices.emplace_back(1);
}
DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices);

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1.");
Expand Down Expand Up @@ -98,6 +98,21 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
SetOpDesc(desc);
break;
}
case DML_OPERATOR_AVERAGE_POOLING1:
{
if (hasDilations) {
DML_AVERAGE_POOLING1_OPERATOR_DESC desc = {};
desc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
desc.Dilations = m_kernel.dilations;
SetOpDesc(desc);
}
else {
DML_AVERAGE_POOLING_OPERATOR_DESC desc = {};
desc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
SetOpDesc(desc);
}
break;
}
case DML_OPERATOR_LP_POOLING:
{
DML_LP_POOLING_OPERATOR_DESC desc = {};
Expand All @@ -106,6 +121,23 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
SetOpDesc(desc);
break;
}
case DML_OPERATOR_LP_POOLING1:
{
if (hasDilations) {
DML_LP_POOLING1_OPERATOR_DESC desc = {};
desc.P = kernelInfo.GetOptionalAttribute<int>(AttrName::P, 2);
ML_CHECK_VALID_ARGUMENT(desc.P > 0);
desc.Dilations = m_kernel.dilations;
SetOpDesc(desc);
}
else {
DML_LP_POOLING_OPERATOR_DESC desc = {};
desc.P = kernelInfo.GetOptionalAttribute<int>(AttrName::P, 2);
ML_CHECK_VALID_ARGUMENT(desc.P > 0);
SetOpDesc(desc);
}
break;
}
case DML_OPERATOR_MAX_POOLING:
case DML_OPERATOR_MAX_POOLING1:
case DML_OPERATOR_MAX_POOLING2:
Expand Down Expand Up @@ -152,7 +184,7 @@ class DmlOperatorPoolingTemplate : public DmlOperatorPooling
void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
{
*isSupported = false;

MLOperatorAttributes attributes(context);

int storageOrder = attributes.GetOptionalAttribute<int>(AttrName::StorageOrder, 0);
Expand All @@ -164,11 +196,11 @@ void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool*
*isSupported = true;
}

DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate<DML_OPERATOR_AVERAGE_POOLING, false>);
DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate<DML_OPERATOR_AVERAGE_POOLING1, false>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalAveragePool, DmlOperatorPoolingTemplate<DML_OPERATOR_AVERAGE_POOLING, true>);
DML_OP_DEFINE_CREATION_FUNCTION(MaxPool, DmlOperatorPoolingTemplate<DML_OPERATOR_MAX_POOLING2, false>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalMaxPool, DmlOperatorPoolingTemplate<DML_OPERATOR_MAX_POOLING, true>);
DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate<DML_OPERATOR_LP_POOLING, false>);
DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate<DML_OPERATOR_LP_POOLING1, false>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalLpPool, DmlOperatorPoolingTemplate<DML_OPERATOR_LP_POOLING, true>);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 19, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
Expand All @@ -677,6 +678,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 18, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,12 @@ namespace OperatorHelper
static const int sc_sinceVer_BitwiseNot = 18;
static const int sc_sinceVer_Pad = 18;
static const int sc_sinceVer_Split = 18;
static const int sc_sinceVer_LpPool = 18;
}

namespace OnnxOperatorSet19
{
static const int sc_sinceVer_AveragePool = 19;
}

namespace MsftOperatorSet1
Expand Down
Loading

0 comments on commit 9bbe425

Please sign in to comment.