Skip to content

Commit

Permalink
Register groupnorm for opset 21 (#22830)
Browse files Browse the repository at this point in the history
### Description
This PR registers GroupNormalization for opset 21



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
dtang317 authored Nov 14, 2024
1 parent 5659d05 commit 12dfe28
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,7 @@ Do not modify directly.*
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|GroupNorm||21+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Hardmax|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|||11+|**T** = tensor(float), tensor(float16)|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
{REG_INFO( 21, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, MatMulNBits, typeNameListTwo, supportedTypeListMatMulNBits, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMatMulNBits)},

// Operators that need to alias an input with an output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,7 @@ using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;
using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_GroupNorm = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_GroupNorm21 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_SkipLayerNormalization = SkipLayerNormHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Transpose = 21;
static const int sc_sinceVer_Identity = 21;
static const int sc_sinceVer_QLinearMatMul = 21;
static const int sc_sinceVer_GroupNorm = 21;
}

namespace MsftOperatorSet1
Expand Down

0 comments on commit 12dfe28

Please sign in to comment.