Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SkipGroupNorm fusion and optimizations for SDXL #18285

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ half maybe2half(float x) {

// Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
// in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time.
constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192};
constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
constexpr int kMinBlockSize = 32;
Expand Down Expand Up @@ -206,7 +206,7 @@ void LaunchSkipLayerNormKernel(
#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \
if constexpr (next_size_value >= 8 * 256) { \
if constexpr (next_size_value >= 320) { \
if (can_unroll_vec8) { \
constexpr int block_size = next_size_value / 8; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \
Expand Down Expand Up @@ -239,6 +239,9 @@ void LaunchSkipLayerNormKernel(
CASE_NEXT_SIZE(kSizes[5]);
CASE_NEXT_SIZE(kSizes[6]);
CASE_NEXT_SIZE(kSizes[7]);
CASE_NEXT_SIZE(kSizes[8]);
CASE_NEXT_SIZE(kSizes[9]);
CASE_NEXT_SIZE(kSizes[10]);
default: {
constexpr int block_size = 256;
LAUNCH_SKIP_LAYER_NORM_KERNEL();
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"MatMulInteger16": self._infer_MatMulInteger,
"MaxPool": self._infer_Pool,
"Max": self._infer_symbolic_compute_ops,
"MemcpyFromHost": self._pass_on_shape_and_type,
"MemcpyToHost": self._pass_on_shape_and_type,
"Min": self._infer_symbolic_compute_ops,
"Mul": self._infer_symbolic_compute_ops,
"NonMaxSuppression": self._infer_NonMaxSuppression,
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/transformers/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def make_value_info_from_tensor(tensor):


# Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2]}
# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this.
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}


class InitializerTracker:
Expand Down
22 changes: 3 additions & 19 deletions onnxruntime/python/tools/transformers/fusion_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,11 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
return

instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
if instance_norm_scale is None:
return
instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
if instance_norm_bias is None:
return

# Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32].
if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32):
logger.debug(
"Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape)
)
if instance_norm_scale is None or len(instance_norm_scale.shape) != 1:
return

if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32):
logger.debug(
"Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape)
)
instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape:
return

if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
Expand All @@ -108,10 +96,6 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):

group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")

if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]:
logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements)
return

self.add_initializer(
name=group_norm_name + "_gamma",
data_type=TensorProto.FLOAT,
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(self, model_type):
if model_type in ["unet", "vae", "clip"]:
self.enable_nhwc_conv = True
self.enable_group_norm = True
self.enable_skip_group_norm = True
self.enable_bias_splitgelu = True
self.enable_packed_qkv = True
self.enable_packed_kv = True
Expand Down Expand Up @@ -116,6 +117,8 @@ def parse(args):
options.enable_nhwc_conv = False
if args.disable_group_norm:
options.enable_group_norm = False
if args.disable_skip_group_norm:
options.enable_skip_group_norm = False
if args.disable_bias_splitgelu:
options.enable_bias_splitgelu = False
if args.disable_packed_qkv:
Expand Down Expand Up @@ -250,6 +253,14 @@ def add_arguments(parser: ArgumentParser):
)
parser.set_defaults(disable_group_norm=False)

parser.add_argument(
"--disable_skip_group_norm",
required=False,
action="store_true",
help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae",
)
parser.set_defaults(disable_skip_group_norm=False)

parser.add_argument(
"--disable_packed_kv",
required=False,
Expand Down
255 changes: 255 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_skip_group_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionSkipGroupNorm(Fusion):
"""
Fuse Add + GroupNorm into one node: SkipGroupNorm.
"""

def __init__(self, model: OnnxModel):
super().__init__(model, "SkipGroupNorm", "GroupNorm")
# Update shape inference is needed since other fusions might add new edge which does not have shape info yet.
self.shape_infer_helper = self.model.infer_runtime_shape(update=True)

if self.shape_infer_helper is None:
logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.")

def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
"""Append a Transpose node after an input"""
node_name = self.model.create_node_name("Transpose")
if output_name is None:
output_name = node_name + "_out" + "-" + input_name
transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
return transpose_node

def get_skip_index(self, add, is_channel_last: bool):
"""Add has two inputs. This classifies which input is skip based on shape info (skip allows broadcast)."""
skip = -1
broadcast = False

assert self.shape_infer_helper is not None
shape_a = self.shape_infer_helper.get_edge_shape(add.input[0])
shape_b = self.shape_infer_helper.get_edge_shape(add.input[1])
assert shape_a is not None and shape_b is not None

if len(shape_a) == 4 and len(shape_b) == 4:
if shape_a == shape_b:
skip = 1
else:
c = 3 if is_channel_last else 1
h = 1 if is_channel_last else 2
w = 2 if is_channel_last else 3
if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]:
if shape_b[h] == 1 and shape_b[w] == 1:
skip = 1
broadcast = True
elif shape_a[h] == 1 and shape_a[w] == 1:
skip = 0
broadcast = True

if skip < 0:
logger.debug(
"skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected",
add.input[0],
add.input[1],
)
return skip, broadcast

def has_multiple_consumers(self, output_name, input_name_to_nodes):
"""Whether an output has multiple consumers (like graph output or more than one children nodes)"""
return self.model.find_graph_output(output_name) is not None or (
output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1
)

def remove_if_safe(self, node, input_name_to_nodes):
"""Remove a node if it is safe (only one children, and not graph output)"""
if not self.has_multiple_consumers(node.output[0], input_name_to_nodes):
self.nodes_to_remove.extend([node])

def is_bias_1d(self, bias_name: str):
"""Whether bias is an initializer of one dimension"""
initializer = self.model.get_initializer(bias_name)
if initializer is None:
return False

bias_weight = NumpyHelper.to_array(initializer)
if bias_weight is None:
logger.debug("Bias weight not found")
return False

if len(bias_weight.shape) != 1:
logger.debug("Bias weight is not 1D")
return False
return True

def match_bias_path(self, node, input_name_to_nodes, output_name_to_node):
"""
Match the bias graph pattern from an Transpose node after Reshape node like in below example.
It checks whether the bias is 1D initializer. If so, remove Add and redirect MatMul output to Reshape.
"""
# Before Fusion:
# MatMul (bias)
# \ / (shape)
# Add /
# \ /
# (a) Reshape
# \ |
# Transpose([0, 3, 1, 2]) Transpose([0, 3, 1, 2]) --- the start node, this func only handles the above nodes.
# \ /
# Add
# / \
# (c) Transpose([0,2,3,1])
# |
# GroupNorm
# |
# (d)
#
# After Fusion (the nodes below Reshape is handled in the fuse function):
# MatMul (shape)
# \ /
# (a) Reshape
# \ /
# SkipGroupNorm
# / \
# (d) Transpose([0, 3, 1, 2])
# \
# (c)

add_input_index = []
bias_nodes = self.model.match_parent_path(
node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index
)
if bias_nodes is None:
return None

(reshape, add_bias, matmul) = bias_nodes
bias = bias_nodes[1].input[1 - add_input_index[0]]
if not self.is_bias_1d(bias):
return None

reshape.input[0] = matmul.output[0]
self.remove_if_safe(add_bias, input_name_to_nodes)

return bias

def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node):
"""Match whether an output is from a Transpose(perm=[0,3,1,2]) node."""
parent = output_name_to_node[output_name] if output_name in output_name_to_node else None
if parent is not None and parent.op_type == "Transpose":
permutation = OnnxModel.get_node_attribute(parent, "perm")
if permutation == [0, 3, 1, 2]:
self.remove_if_safe(parent, input_name_to_nodes)
return parent
return None

def fuse(self, node, input_name_to_nodes, output_name_to_node):
# This fusion requires shape information, so skip it if shape is not available.
if self.shape_infer_helper is None:
return

# Before Fusion:
# (a) (b)
# \ /
# Add
# /\
# (c) Transpose([0,2,3,1])
# \
# GroupNorm
# |
# (d)
#
# After Fusion:
# (a) (b)
# \ /
# Transpose([0,2,3,1]) Transpose([0,2,3,1])
# \ /
# SkipGroupNorm
# / \
# / Transpose([0, 3, 1, 2])
# / \
# (d) (c)
nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node)
if nodes is None:
return

(transpose, add) = nodes
if transpose in self.nodes_to_remove or add in self.nodes_to_remove:
return

if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes):
return

permutation = OnnxModel.get_node_attribute(transpose, "perm")
if permutation != [0, 2, 3, 1]:
return

inputs = []
bias = None
for i in range(2):
matched_transpose = self.match_transpose_from_nhwc(add.input[i], input_name_to_nodes, output_name_to_node)
if matched_transpose:
# When there is an Transpose node before Add (see examples in match_bias_path), we do not need to
# insert another Transpose node. The existing Transpose node will be removed in prune_graph if it
# has only one consumer.
inputs.append(matched_transpose.input[0])
# See whether it match bias pattern.
if bias is None:
bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node)
else:
# Otherwise, insert a Transpose node before Add.
new_transpose = self.create_transpose_node(add.input[i], [0, 2, 3, 1])
self.model.add_node(new_transpose, self.this_graph_name)
inputs.append(new_transpose.output[0])

skip, broadcast = self.get_skip_index(add, is_channel_last=False)
if skip < 0:
return

inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]]
if bias:
inputs = [*inputs, bias]

outputs = node.output

new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm")
if self.has_multiple_consumers(add.output[0], input_name_to_nodes):
add_out_name = new_node_name + "_add_out"
outputs.append(add_out_name)

# Insert a Transpose node after add output.
add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0])
self.model.add_node(add_out_transpose, self.this_graph_name)

skip_group_norm = helper.make_node(
self.fused_op_type,
inputs=inputs,
outputs=outputs,
name=new_node_name,
)
skip_group_norm.domain = "com.microsoft"

self.increase_counter(
f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})"
)

# Pass attributes from GroupNorm node to SkipGroupNorm
for att in node.attribute:
skip_group_norm.attribute.extend([att])

self.nodes_to_remove.extend([add, transpose, node])
self.nodes_to_add.append(skip_group_norm)
self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name
self.prune_graph = True
Loading
Loading