From fcd9aac562226c15ed472574e7f230512637bb41 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Mon, 6 Nov 2023 22:02:33 -0800
Subject: [PATCH] SkipGroupNorm fusion and SDXL Pipeline Update (#18273)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Update a few optimizations for Stable Diffusion XL:
(1) Add SkipGroupNorm fusion.
(2) Remove GroupNorm fusion limits. Previously, GroupNorm was fused only when the number of channels is one of `320, 640, 960, 1280, 1920, 2560, 128, 256, 512`, so some GroupNorm nodes in the refiner were not fused.
(3) Tune SkipLayerNormalization to use the vectorized kernel for hidden sizes 320, 640 and 1280.

Pipeline improvements:
(4) Enable CUDA graph for unetxl.
(5) Change optimization to generate an optimized fp32 model with ORT first, then convert it to fp16. Otherwise, the fp16 model might be invalid.
(6) Add the --enable-vae-slicing option.

Bug fixes:
(a) Fix VAE decode in the SD demo.
(b) Fix UniPC add_noise missing a parameter.
(c) EulerA raises an exception in the SDXL demo; disable it for now.
(d) Batch size > 4 fails in VAE without slicing; force VAE slicing when batch size > 4.

#### Performance

Tested on A100-SXM4-80GB.

Description of the experiments in the results:
*Baseline*: removed GroupNorm fusion limits; CUDA graph is enabled in Clip and VAE, but not in Clip2 and UNet.
*UNetCG*: enable CUDA graph on UNet.
*SLN*: tune SkipLayerNormalization.
*SGN*: add SkipGroupNorm fusion.

The latency (ms) of generating a 1024x1024 image with 30 steps of the base model and 9 steps of the refiner model:

  | Baseline | UNetCG | UNetCG+SLN | UNetCG+SLN+SGN
-- | -- | -- | -- | --
Base Clip | 3.74 | 3.70 | 3.88 | 3.81
Base Unet x30 | 2567.73 | 2510.69 | 2505.09 | 2499.99
Refiner Clip | 7.59 | 7.42 | 7.41 | 7.58
Refiner Unet x9 | 814.43 | 803.03 | 802.20 | 799.06
Refiner VAE Decoder | 84.62 | 85.18 | 85.24 | 87.43
E2E | 3480.56 | 3412.05 | 3405.77 | 3400.23

Enabling CUDA graph brings the major gain (around 68 ms). SkipLayerNormalization tuning gains about 7 ms, and SkipGroupNorm fusion about 5 ms. Although SkipGroupNorm fusion does not reduce latency much, it also reduces memory usage, so it is recommended to enable it.

### Motivation and Context
Additional optimizations upon previous work in https://github.com/microsoft/onnxruntime/pull/17536.
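For reference, below is a minimal sketch of how the fusion options touched by this PR could be exercised from Python. It only covers the graph-fusion and fp16-conversion steps (the ORT-level optimization pass driven by `ort_optimizer.py` is omitted), the model paths are placeholders, and the pip-installed module layout `onnxruntime.transformers` is an assumption; the demo scripts reach the same code through the stable diffusion pipeline.

```python
# Sketch only: fuse an fp32 Stable Diffusion UNet, then convert to fp16,
# mirroring the fp32-then-fp16 flow used in this PR.
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

options = FusionOptions("unet")
# SkipGroupNorm fusion defaults to enabled for unet/vae; uncomment the next line
# (equivalent to the --disable_skip_group_norm CLI flag) to turn it off.
# options.enable_skip_group_norm = False

m = optimize_model(
    "sd_unet/model.onnx",   # placeholder input path
    model_type="unet",
    num_heads=0,            # deduced from the graph
    hidden_size=0,          # deduced from the graph
    opt_level=0,
    optimization_options=options,
    use_gpu=True,
)
m.convert_float_to_float16(keep_io_types=False)  # convert after fusion
print(m.get_fused_operator_statistics())         # GroupNorm/SkipGroupNorm counts show up here
m.save_model_to_file("sd_unet_fp16/model.onnx", use_external_data_format=True)
```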
--- .../cuda/bert/skip_layer_norm_impl.cu | 7 +- .../python/tools/symbolic_shape_infer.py | 2 + .../python/tools/transformers/float16.py | 3 +- .../tools/transformers/fusion_group_norm.py | 22 +- .../tools/transformers/fusion_options.py | 11 + .../transformers/fusion_skip_group_norm.py | 255 ++++++++++++++++++ .../models/stable_diffusion/demo_txt2img.py | 3 + .../stable_diffusion/demo_txt2img_xl.py | 8 + .../models/stable_diffusion/demo_utils.py | 5 +- .../stable_diffusion/diffusion_models.py | 36 ++- .../stable_diffusion/diffusion_schedulers.py | 1 + .../models/stable_diffusion/engine_builder.py | 21 +- .../engine_builder_ort_cuda.py | 41 ++- .../models/stable_diffusion/ort_optimizer.py | 52 ++-- .../stable_diffusion/pipeline_txt2img.py | 2 +- .../python/tools/transformers/onnx_model.py | 4 +- .../tools/transformers/onnx_model_bert.py | 26 +- .../tools/transformers/onnx_model_unet.py | 10 +- .../tools/transformers/onnx_model_vae.py | 1 + .../python/tools/transformers/optimizer.py | 13 +- 20 files changed, 447 insertions(+), 76 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/fusion_skip_group_norm.py diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index 973ef8d304e2e..50c8e4b5e0398 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -51,7 +51,7 @@ half maybe2half(float x) { // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. -constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192}; +constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192}; constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); constexpr int kMaxSize = kSizes[kNumOfSizes - 1]; constexpr int kMinBlockSize = 32; @@ -206,7 +206,7 @@ void LaunchSkipLayerNormKernel( #define CASE_NEXT_SIZE(next_size_value) \ case next_size_value: { \ static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ - if constexpr (next_size_value >= 8 * 256) { \ + if constexpr (next_size_value >= 320) { \ if (can_unroll_vec8) { \ constexpr int block_size = next_size_value / 8; \ LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ @@ -239,6 +239,9 @@ void LaunchSkipLayerNormKernel( CASE_NEXT_SIZE(kSizes[5]); CASE_NEXT_SIZE(kSizes[6]); CASE_NEXT_SIZE(kSizes[7]); + CASE_NEXT_SIZE(kSizes[8]); + CASE_NEXT_SIZE(kSizes[9]); + CASE_NEXT_SIZE(kSizes[10]); default: { constexpr int block_size = 256; LAUNCH_SKIP_LAYER_NORM_KERNEL(); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 5899c4fcfc0a0..69f8530dff39a 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -154,6 +154,8 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "MatMulInteger16": self._infer_MatMulInteger, "MaxPool": self._infer_Pool, "Max": self._infer_symbolic_compute_ops, + "MemcpyFromHost": self._pass_on_shape_and_type, + "MemcpyToHost": self._pass_on_shape_and_type, "Min": self._infer_symbolic_compute_ops, "Mul": self._infer_symbolic_compute_ops, "NonMaxSuppression": self._infer_NonMaxSuppression, diff --git a/onnxruntime/python/tools/transformers/float16.py 
b/onnxruntime/python/tools/transformers/float16.py index 222f5f5e27d98..95e7437493bc8 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -145,7 +145,8 @@ def make_value_info_from_tensor(tensor): # Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices -ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2]} +# Note that DirectML allows float16 gamma and beta in GroupNorm. Using the force_fp16_inputs parameter can override this. +ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]} class InitializerTracker: diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index cd7dc7017cf16..c718d2c27e015 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -82,23 +82,11 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): return instance_norm_scale = self.model.get_constant_value(instance_norm.input[1]) - if instance_norm_scale is None: - return - instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) - if instance_norm_bias is None: - return - - # Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32]. - if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32): - logger.debug( - "Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape) - ) + if instance_norm_scale is None or len(instance_norm_scale.shape) != 1: return - if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32): - logger.debug( - "Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape) - ) + instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) + if instance_norm_bias is None or instance_norm_bias.shape != instance_norm_scale.shape: return if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): @@ -108,10 +96,6 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") - if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]: - logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements) - return - self.add_initializer( name=group_norm_name + "_gamma", data_type=TensorProto.FLOAT, diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 8c80fcad0ab49..b9b92d2fe8a00 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -61,6 +61,7 @@ def __init__(self, model_type): if model_type in ["unet", "vae", "clip"]: self.enable_nhwc_conv = True self.enable_group_norm = True + self.enable_skip_group_norm = True self.enable_bias_splitgelu = True self.enable_packed_qkv = True self.enable_packed_kv = True @@ -116,6 +117,8 @@ def parse(args): options.enable_nhwc_conv = False if args.disable_group_norm: options.enable_group_norm = False + if args.disable_skip_group_norm: + options.enable_skip_group_norm = False if args.disable_bias_splitgelu: options.enable_bias_splitgelu = False if args.disable_packed_qkv: @@ -250,6 +253,14 @@ def
add_arguments(parser: ArgumentParser): ) parser.set_defaults(disable_group_norm=False) + parser.add_argument( + "--disable_skip_group_norm", + required=False, + action="store_true", + help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae", + ) + parser.set_defaults(disable_skip_group_norm=False) + parser.add_argument( "--disable_packed_kv", required=False, diff --git a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py new file mode 100644 index 0000000000000..df80acbd97807 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py @@ -0,0 +1,255 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import List + +from fusion_base import Fusion +from fusion_utils import NumpyHelper +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionSkipGroupNorm(Fusion): + """ + Fuse Add + GroupNorm into one node: SkipGroupNorm. + """ + + def __init__(self, model: OnnxModel): + super().__init__(model, "SkipGroupNorm", "GroupNorm") + # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. + self.shape_infer_helper = self.model.infer_runtime_shape(update=True) + + if self.shape_infer_helper is None: + logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.") + + def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + """Append a Transpose node after an input""" + node_name = self.model.create_node_name("Transpose") + if output_name is None: + output_name = node_name + "_out" + "-" + input_name + transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name) + transpose_node.attribute.extend([helper.make_attribute("perm", perm)]) + return transpose_node + + def get_skip_index(self, add, is_channel_last: bool): + """Add has two inputs. 
This classifies which input is skip based on shape info (skip allows broadcast).""" + skip = -1 + broadcast = False + + assert self.shape_infer_helper is not None + shape_a = self.shape_infer_helper.get_edge_shape(add.input[0]) + shape_b = self.shape_infer_helper.get_edge_shape(add.input[1]) + assert shape_a is not None and shape_b is not None + + if len(shape_a) == 4 and len(shape_b) == 4: + if shape_a == shape_b: + skip = 1 + else: + c = 3 if is_channel_last else 1 + h = 1 if is_channel_last else 2 + w = 2 if is_channel_last else 3 + if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]: + if shape_b[h] == 1 and shape_b[w] == 1: + skip = 1 + broadcast = True + elif shape_a[h] == 1 and shape_a[w] == 1: + skip = 0 + broadcast = True + + if skip < 0: + logger.debug( + "skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected", + add.input[0], + add.input[1], + ) + return skip, broadcast + + def has_multiple_consumers(self, output_name, input_name_to_nodes): + """Whether an output has multiple consumers (like graph output or more than one children nodes)""" + return self.model.find_graph_output(output_name) is not None or ( + output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1 + ) + + def remove_if_safe(self, node, input_name_to_nodes): + """Remove a node if it is safe (only one children, and not graph output)""" + if not self.has_multiple_consumers(node.output[0], input_name_to_nodes): + self.nodes_to_remove.extend([node]) + + def is_bias_1d(self, bias_name: str): + """Whether bias is an initializer of one dimension""" + initializer = self.model.get_initializer(bias_name) + if initializer is None: + return False + + bias_weight = NumpyHelper.to_array(initializer) + if bias_weight is None: + logger.debug("Bias weight not found") + return False + + if len(bias_weight.shape) != 1: + logger.debug("Bias weight is not 1D") + return False + return True + + def match_bias_path(self, node, input_name_to_nodes, output_name_to_node): + """ + Match the bias graph pattern from an Transpose node after Reshape node like in below example. + It checks whether the bias is 1D initializer. If so, remove Add and redirect MatMul output to Reshape. + """ + # Before Fusion: + # MatMul (bias) + # \ / (shape) + # Add / + # \ / + # (a) Reshape + # \ | + # Transpose([0, 3, 1, 2]) Transpose([0, 3, 1, 2]) --- the start node, this func only handles the above nodes. 
+ # \ / + # Add + # / \ + # (c) Transpose([0,2,3,1]) + # | + # GroupNorm + # | + # (d) + # + # After Fusion (the nodes below Reshape is handled in the fuse function): + # MatMul (shape) + # \ / + # (a) Reshape + # \ / + # SkipGroupNorm + # / \ + # (d) Transpose([0, 3, 1, 2]) + # \ + # (c) + + add_input_index = [] + bias_nodes = self.model.match_parent_path( + node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index + ) + if bias_nodes is None: + return None + + (reshape, add_bias, matmul) = bias_nodes + bias = bias_nodes[1].input[1 - add_input_index[0]] + if not self.is_bias_1d(bias): + return None + + reshape.input[0] = matmul.output[0] + self.remove_if_safe(add_bias, input_name_to_nodes) + + return bias + + def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node): + """Match whether an output is from a Transpose(perm=[0,3,1,2]) node.""" + parent = output_name_to_node[output_name] if output_name in output_name_to_node else None + if parent is not None and parent.op_type == "Transpose": + permutation = OnnxModel.get_node_attribute(parent, "perm") + if permutation == [0, 3, 1, 2]: + self.remove_if_safe(parent, input_name_to_nodes) + return parent + return None + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + # This fusion requires shape information, so skip it if shape is not available. + if self.shape_infer_helper is None: + return + + # Before Fusion: + # (a) (b) + # \ / + # Add + # /\ + # (c) Transpose([0,2,3,1]) + # \ + # GroupNorm + # | + # (d) + # + # After Fusion: + # (a) (b) + # \ / + # Transpose([0,2,3,1]) Transpose([0,2,3,1]) + # \ / + # SkipGroupNorm + # / \ + # / Transpose([0, 3, 1, 2]) + # / \ + # (d) (c) + nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node) + if nodes is None: + return + + (transpose, add) = nodes + if transpose in self.nodes_to_remove or add in self.nodes_to_remove: + return + + if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes): + return + + permutation = OnnxModel.get_node_attribute(transpose, "perm") + if permutation != [0, 2, 3, 1]: + return + + inputs = [] + bias = None + for i in range(2): + matched_transpose = self.match_transpose_from_nhwc(add.input[i], input_name_to_nodes, output_name_to_node) + if matched_transpose: + # When there is an Transpose node before Add (see examples in match_bias_path), we do not need to + # insert another Transpose node. The existing Transpose node will be removed in prune_graph if it + # has only one consumer. + inputs.append(matched_transpose.input[0]) + # See whether it match bias pattern. + if bias is None: + bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node) + else: + # Otherwise, insert a Transpose node before Add. + new_transpose = self.create_transpose_node(add.input[i], [0, 2, 3, 1]) + self.model.add_node(new_transpose, self.this_graph_name) + inputs.append(new_transpose.output[0]) + + skip, broadcast = self.get_skip_index(add, is_channel_last=False) + if skip < 0: + return + + inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]] + if bias: + inputs = [*inputs, bias] + + outputs = node.output + + new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm") + if self.has_multiple_consumers(add.output[0], input_name_to_nodes): + add_out_name = new_node_name + "_add_out" + outputs.append(add_out_name) + + # Insert a Transpose node after add output. 
+ add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0]) + self.model.add_node(add_out_transpose, self.this_graph_name) + + skip_group_norm = helper.make_node( + self.fused_op_type, + inputs=inputs, + outputs=outputs, + name=new_node_name, + ) + skip_group_norm.domain = "com.microsoft" + + self.increase_counter( + f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})" + ) + + # Pass attributes from GroupNorm node to SkipGroupNorm + for att in node.attribute: + skip_group_norm.attribute.extend([att]) + + self.nodes_to_remove.extend([add, transpose, node]) + self.nodes_to_add.append(skip_group_norm) + self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index d6de5c45a5210..fb051ac1ed3b4 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -61,6 +61,9 @@ _, shared_device_memory = cudart.cudaMalloc(max_device_memory) pipeline.backend.activate_engines(shared_device_memory) + if engine_type == EngineType.ORT_CUDA and args.enable_vae_slicing: + pipeline.backend.enable_vae_slicing() + pipeline.load_resources(image_height, image_width, batch_size) def run_inference(warmup=False): diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index efc87a207d130..16e776a08282c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -70,6 +70,14 @@ def run_demo(): base.backend.activate_engines(shared_device_memory) refiner.backend.activate_engines(shared_device_memory) + if engine_type == EngineType.ORT_CUDA: + enable_vae_slicing = args.enable_vae_slicing + if batch_size > 4 and not enable_vae_slicing: + print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.") + enable_vae_slicing = True + if enable_vae_slicing: + refiner.backend.enable_vae_slicing() + base.load_resources(image_height, image_width, batch_size) refiner.load_resources(image_height, image_width, batch_size) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index 3996c8c325be3..e65efd2c53839 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -68,7 +68,7 @@ def parse_arguments(is_xl: bool, description: str): "--scheduler", type=str, default="DDIM", - choices=["DDIM", "EulerA", "UniPC"], + choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"], help="Scheduler for diffusion process", ) @@ -145,6 +145,9 @@ def parse_arguments(is_xl: bool, description: str): parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + group = parser.add_argument_group("Options for ORT_CUDA engine only") + group.add_argument("--enable-vae-slicing", action="store_true", 
help="True will feed only one image to VAE once.") + # TensorRT only options group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from.") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index dc777e26938e4..4a2e9eb3443da 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -303,7 +303,15 @@ def fp32_input_output_names(self) -> List[str]: """ return [] - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_op_list=None, optimize_by_ort=True): + def optimize_ort( + self, + input_onnx_path, + optimized_onnx_path, + to_fp16=True, + fp32_op_list=None, + optimize_by_ort=True, + optimize_by_fusion=True, + ): optimizer = self.get_ort_optimizer() optimizer.optimize( input_onnx_path, @@ -312,6 +320,7 @@ def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_ keep_io_types=self.fp32_input_output_names(), fp32_op_list=fp32_op_list, optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): @@ -471,7 +480,15 @@ def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, onnx_model.add_node(cast_node) onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_op_list=None, optimize_by_ort=True): + def optimize_ort( + self, + input_onnx_path, + optimized_onnx_path, + to_fp16=True, + fp32_op_list=None, + optimize_by_ort=True, + optimize_by_fusion=True, + ): optimizer = self.get_ort_optimizer() if not self.output_hidden_state: @@ -483,8 +500,9 @@ def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_ fp32_op_list=fp32_op_list, keep_outputs=["text_embeddings"], optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, ) - else: + elif optimize_by_fusion: with tempfile.TemporaryDirectory() as tmp_dir: # Save to a temporary file so that we can load it with Onnx Runtime. logger.info("Saving a temporary model to add hidden_states to graph output ...") @@ -500,7 +518,19 @@ def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_ fp32_op_list=fp32_op_list, keep_outputs=["text_embeddings", "hidden_states"], optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, ) + else: # input is optimized model, there is no need to add hidden states. 
+ optimizer.optimize( + input_onnx_path, + optimized_onnx_path, + float16=to_fp16, + keep_io_types=[], + fp32_op_list=fp32_op_list, + keep_outputs=["text_embeddings", "hidden_states"], + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, + ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): onnx_graph = onnx.load(input_onnx_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py index 13c450a517eba..ec3041e134e75 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -695,6 +695,7 @@ def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, + idx, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index 029125c639c09..dfdfa007d74eb 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -73,6 +73,10 @@ def __init__( self.models = {} self.engines = {} self.torch_models = {} + self.use_vae_slicing = False + + def enable_vae_slicing(self): + self.use_vae_slicing = True def teardown(self): for engine in self.engines.values(): @@ -84,9 +88,9 @@ def get_cached_model_name(self, model_name): model_name += "_inpaint" return model_name - def get_onnx_path(self, model_name, onnx_dir, opt=True): + def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""): engine_name = self.engine_type.name.lower() - directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + suffix onnx_model_dir = os.path.join(onnx_dir, directory_name) os.makedirs(onnx_model_dir, exist_ok=True) return os.path.join(onnx_model_dir, "model.onnx") @@ -160,11 +164,12 @@ def load_resources(self, image_height, image_width, batch_size): for model_name, obj in self.models.items(): if model_name == "vae" and self.vae_torch_fallback: continue + slice_size = 1 if (model_name == "vae" and self.use_vae_slicing) else batch_size self.engines[model_name].allocate_buffers( - shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + shape_dict=obj.get_shape_dict(slice_size, image_height, image_width), device=self.torch_device ) - def vae_decode(self, latents): + def _vae_decode(self, latents): if self.vae_torch_fallback: if not self.custom_fp16_vae: latents = latents.to(dtype=torch.float32) @@ -175,6 +180,14 @@ def vae_decode(self, latents): return images + def vae_decode(self, latents): + if self.use_vae_slicing: + # The output tensor points to same buffer. Need clone it to avoid overwritten. + decoded_slices = [self._vae_decode(z_slice).clone() for z_slice in latents.split(1)] + return torch.cat(decoded_slices) + + return self._vae_decode(latents) + def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): root_dir = work_dir or "." 
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 11a39b0decad6..07c675b2ed990 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -144,7 +144,7 @@ def configure_xl(self, onnx_opset_version: int): self._configure( "unetxl", onnx_opset_version=onnx_opset_version, - use_cuda_graph=False, # TODO: fix Runtime Error with cuda graph + use_cuda_graph=self.use_cuda_graph, ) self._configure( @@ -164,6 +164,7 @@ def build_engines( opt_batch_size: int = 1, force_engine_rebuild: bool = False, device_id: int = 0, + save_fp32_intermediate_model=False, ): self.torch_device = torch.device("cuda", device_id) self.load_models(framework_model_dir) @@ -195,7 +196,9 @@ def build_engines( continue onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) - onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) + onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") + onnx_fp16_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp16") + onnx_opt_path = onnx_fp16_path if self.model_config[model_name].fp16 else onnx_fp32_path if not os.path.exists(onnx_opt_path): if not os.path.exists(onnx_path): print("----") @@ -225,17 +228,41 @@ def build_engines( else: logger.info("Found cached model: %s", onnx_path) - # Run graph optimization and convert to mixed precision (computation in FP16) + # Generate fp32 optimized model. + # If final target is fp16 model, we save fp32 optimized model so that it is easy to tune + # fp16 conversion. That could save a lot of time in developing. + use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16 + if use_fp32_intermediate: + if not os.path.exists(onnx_fp32_path): + print("------") + logger.info("Generating optimized model: %s", onnx_fp32_path) + + # There is risk that some ORT fused ops fp32 only. So far, we have not encountered such issue. + model_obj.optimize_ort( + onnx_path, + onnx_fp32_path, + to_fp16=False, + fp32_op_list=self.model_config[model_name].force_fp32_ops, + optimize_by_ort=self.model_config[model_name].optimize_by_ort, + ) + else: + logger.info("Found cached optimized model: %s", onnx_fp32_path) + + # Generate the final optimized model. if not os.path.exists(onnx_opt_path): print("------") logger.info("Generating optimized model: %s", onnx_opt_path) + # When there is fp32 intermediate optimized model, this will just convert model from fp32 to fp16. 
+ optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort + model_obj.optimize_ort( - onnx_path, + onnx_fp32_path if use_fp32_intermediate else onnx_path, onnx_opt_path, to_fp16=self.model_config[model_name].fp16, fp32_op_list=self.model_config[model_name].force_fp32_ops, - optimize_by_ort=self.model_config[model_name].optimize_by_ort, + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=not use_fp32_intermediate, ) else: logger.info("Found cached optimized model: %s", onnx_opt_path) @@ -245,7 +272,9 @@ def build_engines( if model_name == "vae" and self.vae_torch_fallback: continue - onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) + onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") + onnx_fp16_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp16") + onnx_opt_path = onnx_fp16_path if self.model_config[model_name].fp16 else onnx_fp32_path use_cuda_graph = self.model_config[model_name].use_cuda_graph diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 2078f8d1a497c..4b48396b6c783 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -60,34 +60,37 @@ def optimize( fp32_op_list=None, keep_outputs=None, optimize_by_ort=True, + optimize_by_fusion=True, + final_target_float16=True, ): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") - fusion_options = FusionOptions(self.model_type) - if self.model_type in ["unet"] and not float16: - fusion_options.enable_packed_kv = False - fusion_options.enable_packed_qkv = False - - m = optimize_model( - input_fp32_onnx_path, - model_type=self.model_type, - num_heads=0, # will be deduced from graph - hidden_size=0, # will be deduced from graph - opt_level=0, - optimization_options=fusion_options, - use_gpu=True, - ) + + if optimize_by_fusion: + fusion_options = FusionOptions(self.model_type) + + # It is allowed float16=False and final_target_float16=True, for using fp32 as intermediate optimization step. + # For rare fp32 use case, we can disable packed kv/qkv since there is no fp32 TRT fused attention kernel. + if self.model_type in ["unet"] and not final_target_float16: + fusion_options.enable_packed_kv = False + fusion_options.enable_packed_qkv = False + + m = optimize_model( + input_fp32_onnx_path, + model_type=self.model_type, + num_heads=0, # will be deduced from graph + hidden_size=0, # will be deduced from graph + opt_level=0, + optimization_options=fusion_options, + use_gpu=True, + ) + else: + model = onnx.load_model(input_fp32_onnx_path, load_external_data=True) + m = self.model_type_class_mapping[self.model_type](model) if keep_outputs: m.prune_graph(outputs=keep_outputs) - if float16: - logger.info("Convert to float16 ...") - m.convert_float_to_float16( - keep_io_types=keep_io_types, - op_block_list=fp32_op_list, - ) - use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF # Note that ORT < 1.16 could not save model larger than 2GB. 
@@ -100,6 +103,13 @@ def optimize( if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format): m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) + if float16: + logger.info("Convert to float16 ...") + m.convert_float_to_float16( + keep_io_types=keep_io_types, + op_block_list=fp32_op_list, + ) + m.get_operator_statistics() m.get_fused_operator_statistics() m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py index 444b6d9a8ca14..b9759b44e7635 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -79,7 +79,7 @@ def _infer( latents = self.denoise_latent(latents, text_embeddings, guidance=guidance) # VAE decode latent - images = self.decode_latent(latents) + images = self.decode_latent(latents / self.vae_scaling_factor) torch.cuda.synchronize() e2e_toc = time.perf_counter() diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 392f2f948968e..5fda3e6d84c1b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -1126,7 +1126,9 @@ def get_operator_statistics(self, include_domain=False): op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type op_count[op] = 1 if op not in op_count else (op_count[op] + 1) - logger.info(f"Operators:{op_count}") + # Sorted by count in the descending order, then by key in alphabetical order. + logger.info(f"Operators:{sorted(op_count.items(), key=lambda kv:(-kv[1], kv[0]))}") + return op_count @staticmethod diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 7a69922e67072..882100a0d019e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -488,16 +488,22 @@ def get_fused_operator_statistics(self): logger.info(f"Optimized operators: {op_count}") return op_count - def is_fully_optimized(self): + def is_fully_optimized(self, fused_op_count=None): """ Returns True when the model is fully optimized. 
""" - op_count = self.get_fused_operator_statistics() - embed = op_count["EmbedLayerNormalization"] - attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"] - gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"] - layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"] - simple_layer_norm = op_count["SimplifiedLayerNormalization"] + op_count["SkipSimplifiedLayerNormalization"] + if fused_op_count is None: + fused_op_count = self.get_fused_operator_statistics() + + def op_count(op_name: str): + return fused_op_count.get(op_name) or 0 + + embed = op_count("EmbedLayerNormalization") + attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("QOrderedAttention") + gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu") + layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization") + simple_layer_norm = op_count("SimplifiedLayerNormalization") + op_count("SkipSimplifiedLayerNormalization") + is_perfect = ( (embed > 0) and (attention > 0) @@ -512,13 +518,13 @@ def is_fully_optimized(self): logger.debug("Simple Layer Normalization not fused") if gelu == 0: - logger.debug("Gelu/FastGelu not fused") + logger.debug("Gelu (or FastGelu) not fused") if embed == 0: - logger.debug("Embed Layer not fused") + logger.debug("EmbedLayerNormalization not fused") if attention == 0: - logger.warning("Attention not fused") + logger.warning("Attention (or MultiHeadAttention) not fused") return is_perfect diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 294641dd1e067..4d15b9288e7b6 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -12,6 +12,7 @@ from fusion_group_norm import FusionGroupNorm from fusion_nhwc_conv import FusionNhwcConv from fusion_options import FusionOptions +from fusion_skip_group_norm import FusionSkipGroupNorm from fusion_transpose import FusionInsertTranspose, FusionTranspose from onnx import ModelProto from onnx_model import OnnxModel @@ -57,8 +58,8 @@ def remove_useless_div(self): logger.info("Removed %d Div nodes", len(nodes_to_remove)) def convert_conv_to_nhwc(self): - # Do not update weight here since save external data has a bug - conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False) + # Transpose weights in offline might help since ORT does not apply constant-folding on Transpose nodes. + conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=True) conv_to_nhwc_conv.apply() def merge_adjacent_transpose(self): @@ -150,6 +151,10 @@ def optimize(self, options: Optional[FusionOptions] = None): # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. self.utils.remove_useless_reshape_nodes() + if (options is None) or options.enable_skip_group_norm: + skip_group_norm_fusion = FusionSkipGroupNorm(self) + skip_group_norm_fusion.apply() + if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. 
self.fuse_add_bias_skip_layer_norm() @@ -181,6 +186,7 @@ def get_fused_operator_statistics(self): "SkipLayerNormalization", "BiasSplitGelu", "GroupNorm", + "SkipGroupNorm", "NhwcConv", "BiasAdd", ] diff --git a/onnxruntime/python/tools/transformers/onnx_model_vae.py b/onnxruntime/python/tools/transformers/onnx_model_vae.py index 9e79014e71027..de8b59074a871 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_vae.py +++ b/onnxruntime/python/tools/transformers/onnx_model_vae.py @@ -32,6 +32,7 @@ def get_fused_operator_statistics(self): ops = [ "Attention", "GroupNorm", + "SkipGroupNorm", "NhwcConv", ] for op in ops: diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index f47bbaffaac51..b2d6423a45d21 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -510,11 +510,14 @@ def main(): if args.input_int32: optimizer.change_graph_inputs_to_int32() - if args.model_type in set(MODEL_TYPES.keys()): - if optimizer.is_fully_optimized(): - logger.info("The model has been fully optimized.") - else: - logger.info("The model has been optimized.") + # Printing the operator statistics might help end users. + optimizer.get_operator_statistics() + + fused_op_count = optimizer.get_fused_operator_statistics() + if "bert" in args.model_type and optimizer.is_fully_optimized(fused_op_count): + logger.info("The model has been fully optimized.") + else: + logger.info("The model has been optimized.") if args.convert_to_packing_mode: if args.model_type == "bert":