diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index 973ef8d304e2e..50c8e4b5e0398 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -51,7 +51,7 @@ half maybe2half(float x) { // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. -constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192}; +constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192}; constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); constexpr int kMaxSize = kSizes[kNumOfSizes - 1]; constexpr int kMinBlockSize = 32; @@ -206,7 +206,7 @@ void LaunchSkipLayerNormKernel( #define CASE_NEXT_SIZE(next_size_value) \ case next_size_value: { \ static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ - if constexpr (next_size_value >= 8 * 256) { \ + if constexpr (next_size_value >= 320) { \ if (can_unroll_vec8) { \ constexpr int block_size = next_size_value / 8; \ LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ @@ -239,6 +239,9 @@ void LaunchSkipLayerNormKernel( CASE_NEXT_SIZE(kSizes[5]); CASE_NEXT_SIZE(kSizes[6]); CASE_NEXT_SIZE(kSizes[7]); + CASE_NEXT_SIZE(kSizes[8]); + CASE_NEXT_SIZE(kSizes[9]); + CASE_NEXT_SIZE(kSizes[10]); default: { constexpr int block_size = 256; LAUNCH_SKIP_LAYER_NORM_KERNEL(); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 81ae3eb9e7e5d..a91ff91010e4b 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -154,6 +154,8 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "MatMulInteger16": self._infer_MatMulInteger, "MaxPool": self._infer_Pool, "Max": self._infer_symbolic_compute_ops, + "MemcpyFromHost": self._pass_on_shape_and_type, + "MemcpyToHost": self._pass_on_shape_and_type, "Min": self._infer_symbolic_compute_ops, "Mul": self._infer_symbolic_compute_ops, "NonMaxSuppression": self._infer_NonMaxSuppression, diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index f2cdec0fa30f1..f680a15fc2c1b 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -145,7 +145,8 @@ def make_value_info_from_tensor(tensor): # Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices -ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2]} +# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this. +ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]} class InitializerTracker: diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index cd7dc7017cf16..c718d2c27e015 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -82,23 +82,11 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): return instance_norm_scale = self.model.get_constant_value(instance_norm.input[1]) - if instance_norm_scale is None: - return - instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) - if instance_norm_bias is None: - return - - # Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32]. - if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32): - logger.debug( - "Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape) - ) + if instance_norm_scale is None or len(instance_norm_scale.shape) != 1: return - if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32): - logger.debug( - "Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape) - ) + instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) + if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape: return if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): @@ -108,10 +96,6 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") - if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]: - logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements) - return - self.add_initializer( name=group_norm_name + "_gamma", data_type=TensorProto.FLOAT, diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 8c80fcad0ab49..b9b92d2fe8a00 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -61,6 +61,7 @@ def __init__(self, model_type): if model_type in ["unet", "vae", "clip"]: self.enable_nhwc_conv = True self.enable_group_norm = True + self.enable_skip_group_norm = True self.enable_bias_splitgelu = True self.enable_packed_qkv = True self.enable_packed_kv = True @@ -116,6 +117,8 @@ def parse(args): options.enable_nhwc_conv = False if args.disable_group_norm: options.enable_group_norm = False + if args.disable_skip_group_norm: + options.enable_skip_group_norm = False if args.disable_bias_splitgelu: options.enable_bias_splitgelu = False if args.disable_packed_qkv: @@ -250,6 +253,14 @@ def add_arguments(parser: ArgumentParser): ) parser.set_defaults(disable_group_norm=False) + parser.add_argument( + "--disable_skip_group_norm", + required=False, + action="store_true", + help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae", + ) + parser.set_defaults(disable_skip_group_norm=False) + parser.add_argument( "--disable_packed_kv", required=False, diff --git a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py new file mode 100644 index 0000000000000..df80acbd97807 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py @@ -0,0 +1,255 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import List + +from fusion_base import Fusion +from fusion_utils import NumpyHelper +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionSkipGroupNorm(Fusion): + """ + Fuse Add + GroupNorm into one node: SkipGroupNorm. + """ + + def __init__(self, model: OnnxModel): + super().__init__(model, "SkipGroupNorm", "GroupNorm") + # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. + self.shape_infer_helper = self.model.infer_runtime_shape(update=True) + + if self.shape_infer_helper is None: + logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.") + + def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + """Append a Transpose node after an input""" + node_name = self.model.create_node_name("Transpose") + if output_name is None: + output_name = node_name + "_out" + "-" + input_name + transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name) + transpose_node.attribute.extend([helper.make_attribute("perm", perm)]) + return transpose_node + + def get_skip_index(self, add, is_channel_last: bool): + """Add has two inputs. This classifies which input is skip based on shape info (skip allows broadcast).""" + skip = -1 + broadcast = False + + assert self.shape_infer_helper is not None + shape_a = self.shape_infer_helper.get_edge_shape(add.input[0]) + shape_b = self.shape_infer_helper.get_edge_shape(add.input[1]) + assert shape_a is not None and shape_b is not None + + if len(shape_a) == 4 and len(shape_b) == 4: + if shape_a == shape_b: + skip = 1 + else: + c = 3 if is_channel_last else 1 + h = 1 if is_channel_last else 2 + w = 2 if is_channel_last else 3 + if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]: + if shape_b[h] == 1 and shape_b[w] == 1: + skip = 1 + broadcast = True + elif shape_a[h] == 1 and shape_a[w] == 1: + skip = 0 + broadcast = True + + if skip < 0: + logger.debug( + "skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected", + add.input[0], + add.input[1], + ) + return skip, broadcast + + def has_multiple_consumers(self, output_name, input_name_to_nodes): + """Whether an output has multiple consumers (like graph output or more than one children nodes)""" + return self.model.find_graph_output(output_name) is not None or ( + output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1 + ) + + def remove_if_safe(self, node, input_name_to_nodes): + """Remove a node if it is safe (only one children, and not graph output)""" + if not self.has_multiple_consumers(node.output[0], input_name_to_nodes): + self.nodes_to_remove.extend([node]) + + def is_bias_1d(self, bias_name: str): + """Whether bias is an initializer of one dimension""" + initializer = self.model.get_initializer(bias_name) + if initializer is None: + return False + + bias_weight = NumpyHelper.to_array(initializer) + if bias_weight is None: + logger.debug("Bias weight not found") + return False + + if len(bias_weight.shape) != 1: + logger.debug("Bias weight is not 1D") + return False + return True + + def match_bias_path(self, node, input_name_to_nodes, output_name_to_node): + """ + Match the bias graph pattern from an Transpose node after Reshape node like in below example. + It checks whether the bias is 1D initializer. If so, remove Add and redirect MatMul output to Reshape. + """ + # Before Fusion: + # MatMul (bias) + # \ / (shape) + # Add / + # \ / + # (a) Reshape + # \ | + # Transpose([0, 3, 1, 2]) Transpose([0, 3, 1, 2]) --- the start node, this func only handles the above nodes. + # \ / + # Add + # / \ + # (c) Transpose([0,2,3,1]) + # | + # GroupNorm + # | + # (d) + # + # After Fusion (the nodes below Reshape is handled in the fuse function): + # MatMul (shape) + # \ / + # (a) Reshape + # \ / + # SkipGroupNorm + # / \ + # (d) Transpose([0, 3, 1, 2]) + # \ + # (c) + + add_input_index = [] + bias_nodes = self.model.match_parent_path( + node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index + ) + if bias_nodes is None: + return None + + (reshape, add_bias, matmul) = bias_nodes + bias = bias_nodes[1].input[1 - add_input_index[0]] + if not self.is_bias_1d(bias): + return None + + reshape.input[0] = matmul.output[0] + self.remove_if_safe(add_bias, input_name_to_nodes) + + return bias + + def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node): + """Match whether an output is from a Transpose(perm=[0,3,1,2]) node.""" + parent = output_name_to_node[output_name] if output_name in output_name_to_node else None + if parent is not None and parent.op_type == "Transpose": + permutation = OnnxModel.get_node_attribute(parent, "perm") + if permutation == [0, 3, 1, 2]: + self.remove_if_safe(parent, input_name_to_nodes) + return parent + return None + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + # This fusion requires shape information, so skip it if shape is not available. + if self.shape_infer_helper is None: + return + + # Before Fusion: + # (a) (b) + # \ / + # Add + # /\ + # (c) Transpose([0,2,3,1]) + # \ + # GroupNorm + # | + # (d) + # + # After Fusion: + # (a) (b) + # \ / + # Transpose([0,2,3,1]) Transpose([0,2,3,1]) + # \ / + # SkipGroupNorm + # / \ + # / Transpose([0, 3, 1, 2]) + # / \ + # (d) (c) + nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node) + if nodes is None: + return + + (transpose, add) = nodes + if transpose in self.nodes_to_remove or add in self.nodes_to_remove: + return + + if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes): + return + + permutation = OnnxModel.get_node_attribute(transpose, "perm") + if permutation != [0, 2, 3, 1]: + return + + inputs = [] + bias = None + for i in range(2): + matched_transpose = self.match_transpose_from_nhwc(add.input[i], input_name_to_nodes, output_name_to_node) + if matched_transpose: + # When there is an Transpose node before Add (see examples in match_bias_path), we do not need to + # insert another Transpose node. The existing Transpose node will be removed in prune_graph if it + # has only one consumer. + inputs.append(matched_transpose.input[0]) + # See whether it match bias pattern. + if bias is None: + bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node) + else: + # Otherwise, insert a Transpose node before Add. + new_transpose = self.create_transpose_node(add.input[i], [0, 2, 3, 1]) + self.model.add_node(new_transpose, self.this_graph_name) + inputs.append(new_transpose.output[0]) + + skip, broadcast = self.get_skip_index(add, is_channel_last=False) + if skip < 0: + return + + inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]] + if bias: + inputs = [*inputs, bias] + + outputs = node.output + + new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm") + if self.has_multiple_consumers(add.output[0], input_name_to_nodes): + add_out_name = new_node_name + "_add_out" + outputs.append(add_out_name) + + # Insert a Transpose node after add output. + add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0]) + self.model.add_node(add_out_transpose, self.this_graph_name) + + skip_group_norm = helper.make_node( + self.fused_op_type, + inputs=inputs, + outputs=outputs, + name=new_node_name, + ) + skip_group_norm.domain = "com.microsoft" + + self.increase_counter( + f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})" + ) + + # Pass attributes from GroupNorm node to SkipGroupNorm + for att in node.attribute: + skip_group_norm.attribute.extend([att]) + + self.nodes_to_remove.extend([add, transpose, node]) + self.nodes_to_add.append(skip_group_norm) + self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index d6de5c45a5210..fb051ac1ed3b4 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -61,6 +61,9 @@ _, shared_device_memory = cudart.cudaMalloc(max_device_memory) pipeline.backend.activate_engines(shared_device_memory) + if engine_type == EngineType.ORT_CUDA and args.enable_vae_slicing: + pipeline.backend.enable_vae_slicing() + pipeline.load_resources(image_height, image_width, batch_size) def run_inference(warmup=False): diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index efc87a207d130..16e776a08282c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -70,6 +70,14 @@ def run_demo(): base.backend.activate_engines(shared_device_memory) refiner.backend.activate_engines(shared_device_memory) + if engine_type == EngineType.ORT_CUDA: + enable_vae_slicing = args.enable_vae_slicing + if batch_size > 4 and not enable_vae_slicing: + print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.") + enable_vae_slicing = True + if enable_vae_slicing: + refiner.backend.enable_vae_slicing() + base.load_resources(image_height, image_width, batch_size) refiner.load_resources(image_height, image_width, batch_size) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index 3996c8c325be3..e65efd2c53839 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -68,7 +68,7 @@ def parse_arguments(is_xl: bool, description: str): "--scheduler", type=str, default="DDIM", - choices=["DDIM", "EulerA", "UniPC"], + choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"], help="Scheduler for diffusion process", ) @@ -145,6 +145,9 @@ def parse_arguments(is_xl: bool, description: str): parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + group = parser.add_argument_group("Options for ORT_CUDA engine only") + group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") + # TensorRT only options group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from.") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index dc777e26938e4..4a2e9eb3443da 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -303,7 +303,15 @@ def fp32_input_output_names(self) -> List[str]: """ return [] - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_op_list=None, optimize_by_ort=True): + def optimize_ort( + self, + input_onnx_path, + optimized_onnx_path, + to_fp16=True, + fp32_op_list=None, + optimize_by_ort=True, + optimize_by_fusion=True, + ): optimizer = self.get_ort_optimizer() optimizer.optimize( input_onnx_path, @@ -312,6 +320,7 @@ def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_ keep_io_types=self.fp32_input_output_names(), fp32_op_list=fp32_op_list, optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): @@ -471,7 +480,15 @@ def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, onnx_model.add_node(cast_node) onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_op_list=None, optimize_by_ort=True): + def optimize_ort( + self, + input_onnx_path, + optimized_onnx_path, + to_fp16=True, + fp32_op_list=None, + optimize_by_ort=True, + optimize_by_fusion=True, + ): optimizer = self.get_ort_optimizer() if not self.output_hidden_state: @@ -483,8 +500,9 @@ def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_ fp32_op_list=fp32_op_list, keep_outputs=["text_embeddings"], optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, ) - else: + elif optimize_by_fusion: with tempfile.TemporaryDirectory() as tmp_dir: # Save to a temporary file so that we can load it with Onnx Runtime. logger.info("Saving a temporary model to add hidden_states to graph output ...") @@ -500,7 +518,19 @@ def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True, fp32_ fp32_op_list=fp32_op_list, keep_outputs=["text_embeddings", "hidden_states"], optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, ) + else: # input is optimized model, there is no need to add hidden states. + optimizer.optimize( + input_onnx_path, + optimized_onnx_path, + float16=to_fp16, + keep_io_types=[], + fp32_op_list=fp32_op_list, + keep_outputs=["text_embeddings", "hidden_states"], + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, + ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): onnx_graph = onnx.load(input_onnx_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py index 13c450a517eba..ec3041e134e75 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -695,6 +695,7 @@ def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, + idx, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index 029125c639c09..dfdfa007d74eb 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -73,6 +73,10 @@ def __init__( self.models = {} self.engines = {} self.torch_models = {} + self.use_vae_slicing = False + + def enable_vae_slicing(self): + self.use_vae_slicing = True def teardown(self): for engine in self.engines.values(): @@ -84,9 +88,9 @@ def get_cached_model_name(self, model_name): model_name += "_inpaint" return model_name - def get_onnx_path(self, model_name, onnx_dir, opt=True): + def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""): engine_name = self.engine_type.name.lower() - directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + suffix onnx_model_dir = os.path.join(onnx_dir, directory_name) os.makedirs(onnx_model_dir, exist_ok=True) return os.path.join(onnx_model_dir, "model.onnx") @@ -160,11 +164,12 @@ def load_resources(self, image_height, image_width, batch_size): for model_name, obj in self.models.items(): if model_name == "vae" and self.vae_torch_fallback: continue + slice_size = 1 if (model_name == "vae" and self.use_vae_slicing) else batch_size self.engines[model_name].allocate_buffers( - shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + shape_dict=obj.get_shape_dict(slice_size, image_height, image_width), device=self.torch_device ) - def vae_decode(self, latents): + def _vae_decode(self, latents): if self.vae_torch_fallback: if not self.custom_fp16_vae: latents = latents.to(dtype=torch.float32) @@ -175,6 +180,14 @@ def vae_decode(self, latents): return images + def vae_decode(self, latents): + if self.use_vae_slicing: + # The output tensor points to same buffer. Need clone it to avoid overwritten. + decoded_slices = [self._vae_decode(z_slice).clone() for z_slice in latents.split(1)] + return torch.cat(decoded_slices) + + return self._vae_decode(latents) + def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): root_dir = work_dir or "." diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 11a39b0decad6..07c675b2ed990 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -144,7 +144,7 @@ def configure_xl(self, onnx_opset_version: int): self._configure( "unetxl", onnx_opset_version=onnx_opset_version, - use_cuda_graph=False, # TODO: fix Runtime Error with cuda graph + use_cuda_graph=self.use_cuda_graph, ) self._configure( @@ -164,6 +164,7 @@ def build_engines( opt_batch_size: int = 1, force_engine_rebuild: bool = False, device_id: int = 0, + save_fp32_intermediate_model=False, ): self.torch_device = torch.device("cuda", device_id) self.load_models(framework_model_dir) @@ -195,7 +196,9 @@ def build_engines( continue onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) - onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) + onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") + onnx_fp16_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp16") + onnx_opt_path = onnx_fp16_path if self.model_config[model_name].fp16 else onnx_fp32_path if not os.path.exists(onnx_opt_path): if not os.path.exists(onnx_path): print("----") @@ -225,17 +228,41 @@ def build_engines( else: logger.info("Found cached model: %s", onnx_path) - # Run graph optimization and convert to mixed precision (computation in FP16) + # Generate fp32 optimized model. + # If final target is fp16 model, we save fp32 optimized model so that it is easy to tune + # fp16 conversion. That could save a lot of time in developing. + use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16 + if use_fp32_intermediate: + if not os.path.exists(onnx_fp32_path): + print("------") + logger.info("Generating optimized model: %s", onnx_fp32_path) + + # There is risk that some ORT fused ops fp32 only. So far, we have not encountered such issue. + model_obj.optimize_ort( + onnx_path, + onnx_fp32_path, + to_fp16=False, + fp32_op_list=self.model_config[model_name].force_fp32_ops, + optimize_by_ort=self.model_config[model_name].optimize_by_ort, + ) + else: + logger.info("Found cached optimized model: %s", onnx_fp32_path) + + # Generate the final optimized model. if not os.path.exists(onnx_opt_path): print("------") logger.info("Generating optimized model: %s", onnx_opt_path) + # When there is fp32 intermediate optimized model, this will just convert model from fp32 to fp16. + optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort + model_obj.optimize_ort( - onnx_path, + onnx_fp32_path if use_fp32_intermediate else onnx_path, onnx_opt_path, to_fp16=self.model_config[model_name].fp16, fp32_op_list=self.model_config[model_name].force_fp32_ops, - optimize_by_ort=self.model_config[model_name].optimize_by_ort, + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=not use_fp32_intermediate, ) else: logger.info("Found cached optimized model: %s", onnx_opt_path) @@ -245,7 +272,9 @@ def build_engines( if model_name == "vae" and self.vae_torch_fallback: continue - onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) + onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") + onnx_fp16_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp16") + onnx_opt_path = onnx_fp16_path if self.model_config[model_name].fp16 else onnx_fp32_path use_cuda_graph = self.model_config[model_name].use_cuda_graph diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 2078f8d1a497c..4b48396b6c783 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -60,34 +60,37 @@ def optimize( fp32_op_list=None, keep_outputs=None, optimize_by_ort=True, + optimize_by_fusion=True, + final_target_float16=True, ): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") - fusion_options = FusionOptions(self.model_type) - if self.model_type in ["unet"] and not float16: - fusion_options.enable_packed_kv = False - fusion_options.enable_packed_qkv = False - - m = optimize_model( - input_fp32_onnx_path, - model_type=self.model_type, - num_heads=0, # will be deduced from graph - hidden_size=0, # will be deduced from graph - opt_level=0, - optimization_options=fusion_options, - use_gpu=True, - ) + + if optimize_by_fusion: + fusion_options = FusionOptions(self.model_type) + + # It is allowed float16=False and final_target_float16=True, for using fp32 as intermediate optimization step. + # For rare fp32 use case, we can disable packed kv/qkv since there is no fp32 TRT fused attention kernel. + if self.model_type in ["unet"] and not final_target_float16: + fusion_options.enable_packed_kv = False + fusion_options.enable_packed_qkv = False + + m = optimize_model( + input_fp32_onnx_path, + model_type=self.model_type, + num_heads=0, # will be deduced from graph + hidden_size=0, # will be deduced from graph + opt_level=0, + optimization_options=fusion_options, + use_gpu=True, + ) + else: + model = onnx.load_model(input_fp32_onnx_path, load_external_data=True) + m = self.model_type_class_mapping[self.model_type](model) if keep_outputs: m.prune_graph(outputs=keep_outputs) - if float16: - logger.info("Convert to float16 ...") - m.convert_float_to_float16( - keep_io_types=keep_io_types, - op_block_list=fp32_op_list, - ) - use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF # Note that ORT < 1.16 could not save model larger than 2GB. @@ -100,6 +103,13 @@ def optimize( if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format): m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) + if float16: + logger.info("Convert to float16 ...") + m.convert_float_to_float16( + keep_io_types=keep_io_types, + op_block_list=fp32_op_list, + ) + m.get_operator_statistics() m.get_fused_operator_statistics() m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py index 444b6d9a8ca14..b9759b44e7635 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -79,7 +79,7 @@ def _infer( latents = self.denoise_latent(latents, text_embeddings, guidance=guidance) # VAE decode latent - images = self.decode_latent(latents) + images = self.decode_latent(latents / self.vae_scaling_factor) torch.cuda.synchronize() e2e_toc = time.perf_counter() diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 66ec0de88b44c..7bdbc08cf733a 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -1126,7 +1126,9 @@ def get_operator_statistics(self, include_domain=False): op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type op_count[op] = 1 if op not in op_count else (op_count[op] + 1) - logger.info(f"Operators:{op_count}") + # Sorted by count in the descending order, then by key in alphabetical order. + logger.info(f"Operators:{sorted(op_count.items(), key=lambda kv:(-kv[1], kv[0]))}") + return op_count @staticmethod diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 7a69922e67072..882100a0d019e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -488,16 +488,22 @@ def get_fused_operator_statistics(self): logger.info(f"Optimized operators: {op_count}") return op_count - def is_fully_optimized(self): + def is_fully_optimized(self, fused_op_count=None): """ Returns True when the model is fully optimized. """ - op_count = self.get_fused_operator_statistics() - embed = op_count["EmbedLayerNormalization"] - attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"] - gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"] - layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"] - simple_layer_norm = op_count["SimplifiedLayerNormalization"] + op_count["SkipSimplifiedLayerNormalization"] + if fused_op_count is None: + fused_op_count = self.get_fused_operator_statistics() + + def op_count(op_name: str): + return fused_op_count.get(op_name) or 0 + + embed = op_count("EmbedLayerNormalization") + attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("QOrderedAttention") + gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu") + layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization") + simple_layer_norm = op_count("SimplifiedLayerNormalization") + op_count("SkipSimplifiedLayerNormalization") + is_perfect = ( (embed > 0) and (attention > 0) @@ -512,13 +518,13 @@ def is_fully_optimized(self): logger.debug("Simple Layer Normalization not fused") if gelu == 0: - logger.debug("Gelu/FastGelu not fused") + logger.debug("Gelu (or FastGelu) not fused") if embed == 0: - logger.debug("Embed Layer not fused") + logger.debug("EmbedLayerNormalization not fused") if attention == 0: - logger.warning("Attention not fused") + logger.warning("Attention (or MultiHeadAttention) not fused") return is_perfect diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 294641dd1e067..4d15b9288e7b6 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -12,6 +12,7 @@ from fusion_group_norm import FusionGroupNorm from fusion_nhwc_conv import FusionNhwcConv from fusion_options import FusionOptions +from fusion_skip_group_norm import FusionSkipGroupNorm from fusion_transpose import FusionInsertTranspose, FusionTranspose from onnx import ModelProto from onnx_model import OnnxModel @@ -57,8 +58,8 @@ def remove_useless_div(self): logger.info("Removed %d Div nodes", len(nodes_to_remove)) def convert_conv_to_nhwc(self): - # Do not update weight here since save external data has a bug - conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False) + # Transpose weights in offline might help since ORT does not apply constant-folding on Transpose nodes. + conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=True) conv_to_nhwc_conv.apply() def merge_adjacent_transpose(self): @@ -150,6 +151,10 @@ def optimize(self, options: Optional[FusionOptions] = None): # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. self.utils.remove_useless_reshape_nodes() + if (options is None) or options.enable_skip_group_norm: + skip_group_norm_fusion = FusionSkipGroupNorm(self) + skip_group_norm_fusion.apply() + if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. self.fuse_add_bias_skip_layer_norm() @@ -181,6 +186,7 @@ def get_fused_operator_statistics(self): "SkipLayerNormalization", "BiasSplitGelu", "GroupNorm", + "SkipGroupNorm", "NhwcConv", "BiasAdd", ] diff --git a/onnxruntime/python/tools/transformers/onnx_model_vae.py b/onnxruntime/python/tools/transformers/onnx_model_vae.py index 9e79014e71027..de8b59074a871 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_vae.py +++ b/onnxruntime/python/tools/transformers/onnx_model_vae.py @@ -32,6 +32,7 @@ def get_fused_operator_statistics(self): ops = [ "Attention", "GroupNorm", + "SkipGroupNorm", "NhwcConv", ] for op in ops: diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 00b26c019d4b5..94a757320e598 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -546,11 +546,14 @@ def main(): if args.input_int32: optimizer.change_graph_inputs_to_int32() - if args.model_type in set(MODEL_TYPES.keys()): - if optimizer.is_fully_optimized(): - logger.info("The model has been fully optimized.") - else: - logger.info("The model has been optimized.") + # Print the operator statistics might help end user. + optimizer.get_operator_statistics() + + fused_op_count = optimizer.get_fused_operator_statistics() + if "bert" in args.model_type and optimizer.is_fully_optimized(fused_op_count): + logger.info("The model has been fully optimized.") + else: + logger.info("The model has been optimized.") if args.convert_to_packing_mode: if args.model_type == "bert":