Skip to content

Commit

Permalink
SkipGroupNorm fusion and SDXL Pipeline Update (#18273)
Browse files Browse the repository at this point in the history
Update a few optimizations for Stable Diffusion XL:
(1) Add SkipGroupNorm fusion
(2) Remvoe GroupNorm fusion limits. Previously, we only fuse GroupNorm
when channels is one of `320, 640, 960, 1280, 1920, 2560, 128, 256, 512`
so some GroupNorm in refiner was not fused.
(3) Tune SkipLayerNormalization to use vectorized kernel for hidden size
320, 640 and 1280.

Pipeline Improvements:
(4) Enable cuda graph for unetxl.
(5) Change optimization to generate optimized fp32 model with ORT, then
convert to fp16. Otherwise, fp16 model might be invalid.
(6) Add option to enable-vae-slicing.

Bug fixes:
(a) Fix vae decode in SD demo.
(b) Fix UnipPC add_noise missing a parameter.
(c) EulerA exception in SDXL demo. Disable it for now.
(d) Batch size > 4 has error in VAE without slicing. Force to enable vae
slicing when batch size > 4.

#### Performance Test on A100-SXM4-80GB

Description about the experiment in results:
*Baseline*: removed GroupNorm fusion limits; CUDA graph is enabled in
Clip and VAE, but not in Clip2 and UNet.
*UNetCG*: Enable Cuda Graph on UNet
*SLN*: Tune SkipLayerNormalization
*SGN*: Add SkipGroupNorm fusion

The latency (ms) of generating an image of size 1024x1024 with 30 steps
base model and 9 steps of refiner model:

  | Baseline | UNetCG| UNetCG+SLN | UNetCG+SLN+SGN
-- | -- | -- | -- | --
Base Clip | 3.74 | 3.70 | 3.88 | 3.81
Base Unet x30 | 2567.73 | 2510.69 | 2505.09 | 2499.99
Refiner Clip | 7.59 | 7.42 | 7.41 | 7.58
Refiner Unet x 9 | 814.43 | 803.03 | 802.20 | 799.06
Refiner VAE Decoder | 84.62 | 85.18 | 85.24 | 87.43
E2E | 3480.56 | 3412.05 | 3405.77 | 3400.23

We can see that enable cuda graph brought major gain (around 68ms). SLN
Tuning has about 7ms gain. SkipGroupNorm fusion has 5ms gain.

SkipGroupNorm fusion won't reduce latency much, while it also has
benefit of reducing memory usage, so it is recommended to enable it.

### Motivation and Context
Additional optimizations upon previous work in
#17536.
  • Loading branch information
tianleiwu committed Nov 7, 2023
1 parent 37d6219 commit fcd9aac
Show file tree
Hide file tree
Showing 20 changed files with 447 additions and 76 deletions.
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ half maybe2half(float x) {

// Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
// in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time.
constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192};
constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
constexpr int kMinBlockSize = 32;
Expand Down Expand Up @@ -206,7 +206,7 @@ void LaunchSkipLayerNormKernel(
#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \
if constexpr (next_size_value >= 8 * 256) { \
if constexpr (next_size_value >= 320) { \
if (can_unroll_vec8) { \
constexpr int block_size = next_size_value / 8; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \
Expand Down Expand Up @@ -239,6 +239,9 @@ void LaunchSkipLayerNormKernel(
CASE_NEXT_SIZE(kSizes[5]);
CASE_NEXT_SIZE(kSizes[6]);
CASE_NEXT_SIZE(kSizes[7]);
CASE_NEXT_SIZE(kSizes[8]);
CASE_NEXT_SIZE(kSizes[9]);
CASE_NEXT_SIZE(kSizes[10]);
default: {
constexpr int block_size = 256;
LAUNCH_SKIP_LAYER_NORM_KERNEL();
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"MatMulInteger16": self._infer_MatMulInteger,
"MaxPool": self._infer_Pool,
"Max": self._infer_symbolic_compute_ops,
"MemcpyFromHost": self._pass_on_shape_and_type,
"MemcpyToHost": self._pass_on_shape_and_type,
"Min": self._infer_symbolic_compute_ops,
"Mul": self._infer_symbolic_compute_ops,
"NonMaxSuppression": self._infer_NonMaxSuppression,
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/transformers/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def make_value_info_from_tensor(tensor):


# Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2]}
# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this.
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}


class InitializerTracker:
Expand Down
22 changes: 3 additions & 19 deletions onnxruntime/python/tools/transformers/fusion_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,11 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
return

instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
if instance_norm_scale is None:
return
instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
if instance_norm_bias is None:
return

# Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32].
if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32):
logger.debug(
"Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape)
)
if instance_norm_scale is None or len(instance_norm_scale.shape) != 1:
return

if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32):
logger.debug(
"Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape)
)
instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape:
return

if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
Expand All @@ -108,10 +96,6 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):

group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")

if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]:
logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements)
return

self.add_initializer(
name=group_norm_name + "_gamma",
data_type=TensorProto.FLOAT,
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(self, model_type):
if model_type in ["unet", "vae", "clip"]:
self.enable_nhwc_conv = True
self.enable_group_norm = True
self.enable_skip_group_norm = True
self.enable_bias_splitgelu = True
self.enable_packed_qkv = True
self.enable_packed_kv = True
Expand Down Expand Up @@ -116,6 +117,8 @@ def parse(args):
options.enable_nhwc_conv = False
if args.disable_group_norm:
options.enable_group_norm = False
if args.disable_skip_group_norm:
options.enable_skip_group_norm = False
if args.disable_bias_splitgelu:
options.enable_bias_splitgelu = False
if args.disable_packed_qkv:
Expand Down Expand Up @@ -250,6 +253,14 @@ def add_arguments(parser: ArgumentParser):
)
parser.set_defaults(disable_group_norm=False)

parser.add_argument(
"--disable_skip_group_norm",
required=False,
action="store_true",
help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae",
)
parser.set_defaults(disable_skip_group_norm=False)

parser.add_argument(
"--disable_packed_kv",
required=False,
Expand Down
255 changes: 255 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_skip_group_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionSkipGroupNorm(Fusion):
"""
Fuse Add + GroupNorm into one node: SkipGroupNorm.
"""

def __init__(self, model: OnnxModel):
super().__init__(model, "SkipGroupNorm", "GroupNorm")
# Update shape inference is needed since other fusions might add new edge which does not have shape info yet.
self.shape_infer_helper = self.model.infer_runtime_shape(update=True)

if self.shape_infer_helper is None:
logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.")

def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
"""Append a Transpose node after an input"""
node_name = self.model.create_node_name("Transpose")
if output_name is None:
output_name = node_name + "_out" + "-" + input_name
transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
return transpose_node

def get_skip_index(self, add, is_channel_last: bool):
"""Add has two inputs. This classifies which input is skip based on shape info (skip allows broadcast)."""
skip = -1
broadcast = False

assert self.shape_infer_helper is not None
shape_a = self.shape_infer_helper.get_edge_shape(add.input[0])
shape_b = self.shape_infer_helper.get_edge_shape(add.input[1])
assert shape_a is not None and shape_b is not None

if len(shape_a) == 4 and len(shape_b) == 4:
if shape_a == shape_b:
skip = 1
else:
c = 3 if is_channel_last else 1
h = 1 if is_channel_last else 2
w = 2 if is_channel_last else 3
if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]:
if shape_b[h] == 1 and shape_b[w] == 1:
skip = 1
broadcast = True
elif shape_a[h] == 1 and shape_a[w] == 1:
skip = 0
broadcast = True

if skip < 0:
logger.debug(
"skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected",
add.input[0],
add.input[1],
)
return skip, broadcast

def has_multiple_consumers(self, output_name, input_name_to_nodes):
"""Whether an output has multiple consumers (like graph output or more than one children nodes)"""
return self.model.find_graph_output(output_name) is not None or (
output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1
)

def remove_if_safe(self, node, input_name_to_nodes):
"""Remove a node if it is safe (only one children, and not graph output)"""
if not self.has_multiple_consumers(node.output[0], input_name_to_nodes):
self.nodes_to_remove.extend([node])

def is_bias_1d(self, bias_name: str):
"""Whether bias is an initializer of one dimension"""
initializer = self.model.get_initializer(bias_name)
if initializer is None:
return False

bias_weight = NumpyHelper.to_array(initializer)
if bias_weight is None:
logger.debug("Bias weight not found")
return False

if len(bias_weight.shape) != 1:
logger.debug("Bias weight is not 1D")
return False
return True

def match_bias_path(self, node, input_name_to_nodes, output_name_to_node):
"""
Match the bias graph pattern from an Transpose node after Reshape node like in below example.
It checks whether the bias is 1D initializer. If so, remove Add and redirect MatMul output to Reshape.
"""
# Before Fusion:
# MatMul (bias)
# \ / (shape)
# Add /
# \ /
# (a) Reshape
# \ |
# Transpose([0, 3, 1, 2]) Transpose([0, 3, 1, 2]) --- the start node, this func only handles the above nodes.
# \ /
# Add
# / \
# (c) Transpose([0,2,3,1])
# |
# GroupNorm
# |
# (d)
#
# After Fusion (the nodes below Reshape is handled in the fuse function):
# MatMul (shape)
# \ /
# (a) Reshape
# \ /
# SkipGroupNorm
# / \
# (d) Transpose([0, 3, 1, 2])
# \
# (c)

add_input_index = []
bias_nodes = self.model.match_parent_path(
node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index
)
if bias_nodes is None:
return None

(reshape, add_bias, matmul) = bias_nodes
bias = bias_nodes[1].input[1 - add_input_index[0]]
if not self.is_bias_1d(bias):
return None

reshape.input[0] = matmul.output[0]
self.remove_if_safe(add_bias, input_name_to_nodes)

return bias

def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node):
"""Match whether an output is from a Transpose(perm=[0,3,1,2]) node."""
parent = output_name_to_node[output_name] if output_name in output_name_to_node else None
if parent is not None and parent.op_type == "Transpose":
permutation = OnnxModel.get_node_attribute(parent, "perm")
if permutation == [0, 3, 1, 2]:
self.remove_if_safe(parent, input_name_to_nodes)
return parent
return None

def fuse(self, node, input_name_to_nodes, output_name_to_node):
# This fusion requires shape information, so skip it if shape is not available.
if self.shape_infer_helper is None:
return

# Before Fusion:
# (a) (b)
# \ /
# Add
# /\
# (c) Transpose([0,2,3,1])
# \
# GroupNorm
# |
# (d)
#
# After Fusion:
# (a) (b)
# \ /
# Transpose([0,2,3,1]) Transpose([0,2,3,1])
# \ /
# SkipGroupNorm
# / \
# / Transpose([0, 3, 1, 2])
# / \
# (d) (c)
nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node)
if nodes is None:
return

(transpose, add) = nodes
if transpose in self.nodes_to_remove or add in self.nodes_to_remove:
return

if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes):
return

permutation = OnnxModel.get_node_attribute(transpose, "perm")
if permutation != [0, 2, 3, 1]:
return

inputs = []
bias = None
for i in range(2):
matched_transpose = self.match_transpose_from_nhwc(add.input[i], input_name_to_nodes, output_name_to_node)
if matched_transpose:
# When there is an Transpose node before Add (see examples in match_bias_path), we do not need to
# insert another Transpose node. The existing Transpose node will be removed in prune_graph if it
# has only one consumer.
inputs.append(matched_transpose.input[0])
# See whether it match bias pattern.
if bias is None:
bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node)
else:
# Otherwise, insert a Transpose node before Add.
new_transpose = self.create_transpose_node(add.input[i], [0, 2, 3, 1])
self.model.add_node(new_transpose, self.this_graph_name)
inputs.append(new_transpose.output[0])

skip, broadcast = self.get_skip_index(add, is_channel_last=False)
if skip < 0:
return

inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]]
if bias:
inputs = [*inputs, bias]

outputs = node.output

new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm")
if self.has_multiple_consumers(add.output[0], input_name_to_nodes):
add_out_name = new_node_name + "_add_out"
outputs.append(add_out_name)

# Insert a Transpose node after add output.
add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0])
self.model.add_node(add_out_transpose, self.this_graph_name)

skip_group_norm = helper.make_node(
self.fused_op_type,
inputs=inputs,
outputs=outputs,
name=new_node_name,
)
skip_group_norm.domain = "com.microsoft"

self.increase_counter(
f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})"
)

# Pass attributes from GroupNorm node to SkipGroupNorm
for att in node.attribute:
skip_group_norm.attribute.extend([att])

self.nodes_to_remove.extend([add, transpose, node])
self.nodes_to_add.append(skip_group_norm)
self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name
self.prune_graph = True
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
pipeline.backend.activate_engines(shared_device_memory)

if engine_type == EngineType.ORT_CUDA and args.enable_vae_slicing:
pipeline.backend.enable_vae_slicing()

pipeline.load_resources(image_height, image_width, batch_size)

def run_inference(warmup=False):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def run_demo():
base.backend.activate_engines(shared_device_memory)
refiner.backend.activate_engines(shared_device_memory)

if engine_type == EngineType.ORT_CUDA:
enable_vae_slicing = args.enable_vae_slicing
if batch_size > 4 and not enable_vae_slicing:
print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.")
enable_vae_slicing = True
if enable_vae_slicing:
refiner.backend.enable_vae_slicing()

base.load_resources(image_height, image_width, batch_size)
refiner.load_resources(image_height, image_width, batch_size)

Expand Down
Loading

0 comments on commit fcd9aac

Please sign in to comment.