Skip to content

Commit

Permalink
skip_infer for SkipGroupNorm in SymbolicShapeInference (#18630)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
#18273 added
`SkipGroupNorm` contrib op but it did not skip onnx shape inference for
this op in `SymbolicShapeInference`.

This leads to failed shape inference of the transformers optimized model
with `enable_skip_group_norm=True`. Also results in an invalid float16
model for the SD CUDA example.

This PR adds `SkipGroupNorm` to `skip_infer` so that it skips onnx shape
inference for this op and instead uses the relevant dispatcher.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Fix shape inference failure for models with `SkipGroupNorm` nodes.
  • Loading branch information
jambayk authored Nov 30, 2023
1 parent 227dcb3 commit c20488c
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def _onnx_infer_single_node(self, node):
"PythonOp",
"MultiHeadAttention",
"GroupNorm",
"SkipGroupNorm",
"BiasSplitGelu",
"BiasAdd",
"NhwcConv",
Expand Down

0 comments on commit c20488c

Please sign in to comment.