
Fix GroupNorm fusion: skip if num of channels not supported (#17869)
Right now, GroupNorm only supports a limited set of channel counts (320, 640, 960, 1280, 1920, 2560, 128, 256, 512). Skip the fusion if the number of channels is not supported.
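
For context, here is a minimal standalone sketch of the guard this change adds, with simplified names (`num_channels` and `num_groups` are hypothetical parameters here; the actual code checks the InstanceNormalization scale/bias shapes and the gamma weight element count, as in the diff below):

```python
# Channel counts the contrib GroupNorm kernel handles today (list taken from this commit).
SUPPORTED_GROUP_NORM_CHANNELS = {320, 640, 960, 1280, 1920, 2560, 128, 256, 512}

def can_fuse_group_norm(num_channels: int, num_groups: int) -> bool:
    """Return True only if the GroupNorm kernel supports this configuration."""
    # The kernel only supports 32 groups and the channel counts listed above;
    # otherwise the fusion must be skipped so the original subgraph is kept.
    return num_groups == 32 and num_channels in SUPPORTED_GROUP_NORM_CHANNELS
```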

### Motivation and Context
The SD XL refiner model uses channel counts of 384, 768, 1152, 2304, and 3072 in GroupNorm, which are not in the supported list.
tianleiwu authored Oct 12, 2023
1 parent 25bbd8d commit e2cd674
Showing 1 changed file with 13 additions and 8 deletions.
onnxruntime/python/tools/transformers/fusion_group_norm.py
@@ -88,13 +88,17 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         if instance_norm_bias is None:
             return
 
-        if not (
-            len(instance_norm_scale.shape) == 1
-            and len(instance_norm_bias.shape) == 1
-            and instance_norm_scale.shape == instance_norm_bias.shape
-            and instance_norm_scale.shape[0] == 32
-        ):
-            logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0])
+        # Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32].
+        if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32):
+            logger.debug(
+                "Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape)
+            )
+            return
+
+        if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32):
+            logger.debug(
+                "Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape)
+            )
             return
 
         if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
@@ -105,7 +109,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")
 
         if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]:
-            logger.info("GroupNorm channels=%d", weight_elements)
+            logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements)
+            return
 
         self.add_initializer(
             name=group_norm_name + "_gamma",
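
For reference, FusionGroupNorm runs as part of the ONNX Runtime transformers optimizer, so the new skip paths are hit when optimizing a model whose GroupNorm channel counts fall outside the supported list. A minimal usage sketch, assuming a Stable Diffusion style UNet exported to the hypothetical path `unet.onnx` and that the `unet` model type (which enables GroupNorm fusion) is the right choice for the model:

```python
# Sketch: run the Python transformer fusions (including FusionGroupNorm) over a UNet model.
# "unet.onnx" and "unet_optimized.onnx" are placeholder paths.
from onnxruntime.transformers import optimizer

opt_model = optimizer.optimize_model(
    "unet.onnx",        # input ONNX model (placeholder path)
    model_type="unet",  # UNet path of the optimizer, which includes GroupNorm fusion
    opt_level=0,        # apply only the Python fusions, no ORT graph optimization pass first
)
opt_model.save_model_to_file("unet_optimized.onnx")
```

With this commit, models such as the SD XL refiner (channel counts 384, 768, 1152, and so on) simply keep their original InstanceNormalization subgraphs instead of producing an unsupported GroupNorm node.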
