diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index d11cb91d98b0c..f48cabd25fc5c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -129,6 +129,9 @@ def __init__( self.num_heads_warning = True self.hidden_size_warning = True + self.shape_infer = None + self.shape_infer_done = True + def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]: """ Detect num_heads and hidden_size from Concat node in the following subgraph: @@ -202,12 +205,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] return num_heads, hidden_size def get_add_qk_str(self, add_qk: NodeProto): - shape_infer = self.model.infer_runtime_shape(update=True) - if shape_infer is None: + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is None: return None - input_0_shape = shape_infer.get_edge_shape(add_qk.input[0]) - input_1_shape = shape_infer.get_edge_shape(add_qk.input[1]) + input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0]) + input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1]) if input_0_shape is None or input_1_shape is None: logger.debug(f"one of the inputs of {add_qk} is None") diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 250ec5f3eb159..9a353e7e2d675 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -28,10 +28,19 @@ def __init__( enable_packed_qkv: bool, enable_packed_kv: bool, ): - super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"]) + super().__init__( + model, + "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention", + ["LayerNormalization"], + ) self.hidden_size = hidden_size self.num_heads = num_heads self.is_cross_attention = is_cross_attention + + # Note: pack Q/K/V or K/V weights into one tensor make it harder for updating initializers for LoRA. + # To support LoRA, it is better to use separated Q, K and V inputs in offline optimization, + # and CUDA operator pre-packs those tensors to preferred format based on available kernels. + # In this way, we can support LoRA and get optimal performance at same time. self.enable_packed_qkv = enable_packed_qkv self.enable_packed_kv = enable_packed_kv @@ -170,9 +179,7 @@ def create_attention_node( return None # Sometimes weights are stored in fp16 - if q_weight.data_type == 10: - logger.debug("weights are in fp16. Please run fp16 conversion after optimization") - return None + float_type = q_weight.data_type qw = NumpyHelper.to_array(q_weight) kw = NumpyHelper.to_array(k_weight) @@ -212,7 +219,7 @@ def create_attention_node( matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV") self.add_initializer( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_weight.shape[0], qkv_weight.shape[1]], vals=qkv_weight, ) @@ -235,8 +242,11 @@ def create_attention_node( reshape_node = helper.make_node( "Reshape", - inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], - outputs=[attention_node_name + "_input"], + inputs=[ + matmul_node_name + "_out", + matmul_node_name + "_reshape_shape", + ], + outputs=[attention_node_name + "_qkv_input"], name=matmul_node_name + "_reshape", ) self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name @@ -251,7 +261,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qw_in_size, qkv_weight_dim], vals=qkv_weight, ) @@ -280,7 +290,7 @@ def create_attention_node( matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") self.add_initializer( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[kv_weight.shape[0], kv_weight.shape[1]], vals=kv_weight, ) @@ -303,8 +313,11 @@ def create_attention_node( reshape_node = helper.make_node( "Reshape", - inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], - outputs=[k_matmul.output[0]], + inputs=[ + matmul_node_name + "_out", + matmul_node_name + "_reshape_shape", + ], + outputs=[attention_node_name + "_kv_input"], name=matmul_node_name + "_reshape", ) self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name @@ -317,7 +330,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_bias", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_bias_dim], vals=qkv_bias, ) @@ -330,7 +343,7 @@ def create_attention_node( attention_node_name + "_qkv_bias", ] else: - attention_inputs = [attention_node_name + "_input"] + attention_inputs = [attention_node_name + "_qkv_input"] else: if not self.enable_packed_kv: attention_inputs = [ @@ -342,7 +355,7 @@ def create_attention_node( else: attention_inputs = [ q_matmul.output[0], - k_matmul.output[0], + attention_node_name + "_kv_input", ] attention_node = helper.make_node( @@ -839,6 +852,9 @@ def create_attention_node_lora( return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node): + return + node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) # In SD 1.5, for self attention, LayerNorm has parent Reshape @@ -1168,3 +1184,125 @@ def match_lora_path( return (lora_mul_node, lora_matmul_1_node) return None + + def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node): + """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension""" + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0]) + if entry_path is None: + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0]) + if entry_path is None: + return False + _cast, node_before_layernorm = entry_path + + root_input = node_before_layernorm.output[0] + + children_nodes = input_name_to_nodes[root_input] + skip_add = None + for node in children_nodes: + if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet + skip_add = node + break + if skip_add is None: + return False + + match_qkv = self.match_qkv_a1111(root_input, skip_add) + if match_qkv is None: + return False + + ( + reshape_qkv, + transpose_qkv, + reshape_q, + matmul_q, + matmul_k, + matmul_v, + ) = match_qkv + + cast_q = self.model.match_parent(matmul_q, "Cast", 0) + cast_k = self.model.match_parent(matmul_k, "Cast", 0) + cast_v = self.model.match_parent(matmul_v, "Cast", 0) + if not ( + cast_q is not None + and cast_k is not None + and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k) + and cast_k == cast_v + ): + return False + + if cast_q.input[0] != normalize_node.output[0]: + return False + + attention_last_node = reshape_qkv + + q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return False + + q_hidden_size = self.get_hidden_size(normalize_node) + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + matmul_k, + matmul_v, + q_num_heads, + q_hidden_size, + input=matmul_q.input[0], + output=attention_last_node.output[0], + ) + if new_node is None: + return False + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True + return True + + def match_qkv_a1111(self, root_input, skip_add): + """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension""" + another_input = 1 if skip_add.input[0] == root_input else 0 + qkv_nodes = self.model.match_parent_path( + skip_add, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"], + [another_input, None, None, 0, 0, 0], + ) + + if qkv_nodes is None: + return None + + (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return None + (_, _, _, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path( + einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None] + ) + if qk_nodes is not None: + (_, _, _softmax_qk, _, einsum_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return None + + q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0]) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return None + (_, _transpose_q, reshape_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return None + + (_, _, _, matmul_k) = k_nodes + + return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index bc38399e3cce5..42156d9123383 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -28,7 +28,9 @@ def __init__(self, model: OnnxModel, description: str = "no mask"): description, ) self.utils = FusionUtils(model) - self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = None + self.shape_infer_done = False + # The following will be reset in each fuse call of FusionEmbedLayerNormalization self.attention = None self.embed_node = None @@ -329,9 +331,13 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None position_ids = position_embedding_gather.input[1] - if self.shape_infer_helper is not None: - input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids) - position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids) + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is not None: + input_ids_shape = self.shape_infer.get_edge_shape(input_ids) + position_ids_shape = self.shape_infer.get_edge_shape(position_ids) assert input_ids_shape and position_ids_shape if not ( len(input_ids_shape) == 2 @@ -345,11 +351,11 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit ) return False - if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids): + if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids): logger.info( "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format( input_ids_shape, - self.shape_infer_helper.get_edge_shape(segment_ids), + self.shape_infer.get_edge_shape(segment_ids), ) ) return False diff --git a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py index f1d803a3cc082..4d9913f427b37 100644 --- a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py @@ -32,7 +32,7 @@ def get_dimensions(self, input_name: str) -> Union[int, None]: return self.get_dimensions_from_tensor_proto(graph_input) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py index 141ebb1f95a11..5233fdf272fbd 100644 --- a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py +++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py @@ -7,7 +7,8 @@ from typing import List from fusion_base import Fusion -from onnx import TensorProto, helper, numpy_helper +from fusion_utils import FusionUtils +from onnx import helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -19,6 +20,7 @@ class FusionNhwcConv(Fusion): def __init__(self, model: OnnxModel, update_weight=False): super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv") self.update_weight = update_weight + self.fusion_utils = FusionUtils(model) def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): """Append a Transpose node after an input""" @@ -49,6 +51,15 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): if len(weight.shape) != 4: return + dtype = self.model.get_dtype(nhwc_conv_input) + if not (dtype is not None and weight_tensor.data_type == dtype): + cast_node = self.fusion_utils.add_cast_node( + input_name=nhwc_conv_input, + to_type=weight_tensor.data_type, + output_name_to_node=output_name_to_node, + ) + nhwc_conv_input = cast_node.output[0] + if self.update_weight: # Transpose weights from NCHW to NHWC weight = weight.transpose(0, 2, 3, 1) @@ -56,7 +67,7 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): weight_name = node_name + "_weight_NHWC" self.add_initializer( name=weight_name, - data_type=TensorProto.FLOAT, + data_type=weight_tensor.data_type, dims=list(weight.shape), vals=weight, ) diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index bc32d78eda66c..dfa77fc7d0221 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -29,12 +29,12 @@ def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[i return None def get_dimensions(self, input_name: str) -> Union[int, None]: - graph_input = self.model.find_graph_input(input_name) - if graph_input: - return self.get_dimensions_from_tensor_proto(graph_input) + shape = self.model.get_shape(input_name) + if shape is not None: + return len(shape) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index afc968fab46c1..726c587ff7043 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple +from typing import Optional, Tuple import numpy from numpy import array_equal, ndarray @@ -29,17 +29,7 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]: return False, input_name def cast_input(self, input_name: str, target_type="int32"): - cast_output = input_name + "_" + target_type - - # Avoid consequent Cast nodes. - inputs = [input_name] - output_name_to_node = self.model.output_name_to_node() - if input_name in output_name_to_node: - parent_node = output_name_to_node[input_name] - if parent_node and parent_node.op_type == "Cast": - inputs = [parent_node.input[0]] - - cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output]) + output_name = input_name + "_" + target_type if target_type == "int32": to_type = int(TensorProto.INT32) @@ -50,10 +40,36 @@ def cast_input(self, input_name: str, target_type="int32"): else: raise ValueError("Invalid target_type: {target_type}") + cast_node = self.add_cast_node(input_name, to_type, output_name) + + return output_name, cast_node + + def add_cast_node( + self, + input_name: str, + to_type: int, + output_name: Optional[str] = None, + output_name_to_node=None, + graph_name: Optional[str] = None, + ): + if output_name is None: + output_name = input_name + f"_cast_to_{to_type}" + + # Avoid consequent Cast nodes. + inputs = [input_name] + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + if input_name in output_name_to_node: + parent_node = output_name_to_node[input_name] + if parent_node and parent_node.op_type == "Cast": + inputs = [parent_node.input[0]] + + cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name]) + cast_node.attribute.extend([helper.make_attribute("to", to_type)]) - self.model.add_node(cast_node) + self.model.add_node(cast_node, graph_name=graph_name) - return cast_output, cast_node + return cast_node def cast_input_to_int32(self, input_name: str): return self.cast_input(input_name, "int32") @@ -224,9 +240,10 @@ def check_node_input_value(self, node, input_index: int, expected_value): def remove_identity_nodes(self): """Remove Identity nodes, except those right before graph output.""" nodes_to_remove = [] + graph_output_names = self.model.get_graphs_output_names() for node in self.model.nodes(): if node.op_type == "Identity": - if node.output[0] not in self.model.get_graphs_output_names(): + if node.output[0] not in graph_output_names: self.model.replace_input_of_all_nodes(node.output[0], node.input[0]) nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/import_utils.py b/onnxruntime/python/tools/transformers/import_utils.py new file mode 100644 index 0000000000000..9755a26b7b004 --- /dev/null +++ b/onnxruntime/python/tools/transformers/import_utils.py @@ -0,0 +1,20 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import importlib.metadata +import importlib.util + + +def is_installed(package): + try: + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False + + return spec is not None + + return dist is not None diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index b10c10c87ee57..8607485bc265b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -51,7 +51,7 @@ sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_ve --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --allow_running_as_root python3 -m pip install --upgrade pip -python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall +python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-*.whl --force-reinstall ``` If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line according to the GPU compute capacity (like 89 for RTX 4090, or 86 for RTX 3090). diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 37b39c91b5c15..9d1066b6e372b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -40,6 +40,12 @@ def initialize(self, model): self.enable_shape_infer: bool = True self.all_graphs: Optional[List[GraphProto]] = None + # Cache of shape and data type from onnx graph to speed up optimization. + # Be careful that fusion shall not reuse node output name for different shape/type (in adding/removing nodes) + # Note that these do not cache the symbolic shape inference result. + self._dtype_dict: Optional[Dict[str, int]] = None + self._shape_dict: Optional[Dict[str, List]] = None + def disable_shape_inference(self): self.enable_shape_infer = False @@ -519,20 +525,60 @@ def tensor_shape_to_list(self, tensor_type): shape_list.append("?") # shall not happen return shape_list - def get_dtype(self, input_or_output: str): - """Try get data type given a name (could be initializer, graph input or output).""" - tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info} + def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get data type given a name (could be initializer, input or output of graph or node).""" + + if self._dtype_dict is None: + self._dtype_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + self._dtype_dict[value_info.name] = value_info.type.tensor_type.elem_type + + for initializer in self.model.graph.initializer: + if initializer.name not in self._dtype_dict: + self._dtype_dict[initializer.name] = initializer.data_type - if input_or_output in tensor_type_map: - return tensor_type_map[input_or_output].tensor_type.elem_type + if name in self._dtype_dict: + return self._dtype_dict[name] - graph_input = self.find_graph_input(input_or_output) - if graph_input: - return graph_input.type.tensor_type.elem_type + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type + + return None - graph_output = self.find_graph_output(input_or_output) - if graph_output: - return graph_output.type.tensor_type.elem_type + def get_shape(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get shape given a name (could be initializer, input or output of graph or node).""" + + if self._shape_dict is None: + self._shape_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + if value_info.type.tensor_type.HasField("shape"): + shape = [] + for dim in value_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + self._shape_dict[value_info.name] = shape + + for initializer in self.model.graph.initializer: + if initializer.name not in self._shape_dict: + self._shape_dict[initializer.name] = initializer.dims + + if name in self._shape_dict: + return self._shape_dict[name] + + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type return None @@ -566,23 +612,14 @@ def remove_cascaded_cast_nodes(self): def remove_useless_cast_nodes(self): """Remove cast nodes that are not needed: input and output has same data type.""" shape_infer = self.infer_runtime_shape(update=True) - if shape_infer is None: - logger.info("Skip removing useless cast nodes since shape inference failed.") - return - - def get_data_type(input_or_output_name): - dtype = self.get_dtype(input_or_output_name) - if dtype: - return dtype - if shape_infer.known_vi_[input_or_output_name].type.tensor_type.HasField("elem_type"): - return shape_infer.known_vi_[input_or_output_name].type.tensor_type.elem_type - return None + if self.enable_shape_infer and shape_infer is None: + logger.warning("shape inference failed which might impact useless cast node detection.") nodes_to_remove = [] for node in self.nodes(): if node.op_type == "Cast": - input_dtype = get_data_type(node.input[0]) - output_dtype = get_data_type(node.output[0]) + input_dtype = self.get_dtype(node.input[0], shape_infer) + output_dtype = self.get_dtype(node.output[0], shape_infer) if input_dtype and input_dtype == output_dtype: nodes_to_remove.append(node) @@ -601,7 +638,10 @@ def get_data_type(input_or_output_name): self.replace_input_of_all_nodes(node.output[0], node.input[0]) self.remove_node(node) - logger.info("Removed %d Cast nodes with output type same as input", len(nodes_to_remove)) + logger.info( + "Removed %d Cast nodes with output type same as input", + len(nodes_to_remove), + ) def convert_model_float32_to_float16(self, cast_input_output=True): logger.warning( @@ -1214,7 +1254,10 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): continue for j in range(i + 1, initializer_count): if OnnxModel.has_same_value( - self.model.graph.initializer[i], self.model.graph.initializer[j], cache, cache + self.model.graph.initializer[i], + self.model.graph.initializer[j], + cache, + cache, ): same[j] = i @@ -1223,7 +1266,8 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): if same[i] >= 0: count += 1 self.replace_input_of_all_nodes( - self.model.graph.initializer[i].name, self.model.graph.initializer[same[i]].name + self.model.graph.initializer[i].name, + self.model.graph.initializer[same[i]].name, ) if count > 0: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 51deb67ce5bf3..431e64509e3cc 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -126,7 +126,8 @@ def fuse_rotary_embeddings(self): # Remove non-MS domain functions rot_emb_nodes = list( filter( - lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node + lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", + self.model.graph.node, ) ) non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes)) @@ -350,7 +351,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo self.attention_mask.set_mask_format(options.attention_mask_format) if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): self.attention_fusion = FusionAttention( - self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention + self, + self.hidden_size, + self.num_heads, + self.attention_mask, + options.use_multi_head_attention, ) if (options is None) or options.enable_attention: @@ -415,7 +420,12 @@ def get_fused_operator_statistics(self): "SkipSimplifiedLayerNormalization", "RotaryEmbedding", ] - q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"] + q_ops = [ + "QOrderedAttention", + "QOrderedGelu", + "QOrderedLayerNormalization", + "QOrderedMatMul", + ] for op in ops + q_ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 4d15b9288e7b6..01298b3576eb1 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from logging import getLogger +import logging from typing import Optional from fusion_attention_unet import FusionAttentionUnet @@ -14,11 +14,12 @@ from fusion_options import FusionOptions from fusion_skip_group_norm import FusionSkipGroupNorm from fusion_transpose import FusionInsertTranspose, FusionTranspose +from import_utils import is_installed from onnx import ModelProto from onnx_model import OnnxModel from onnx_model_bert import BertOnnxModel -logger = getLogger(__name__) +logger = logging.getLogger(__name__) class UnetOnnxModel(BertOnnxModel): @@ -94,14 +95,24 @@ def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): # Self Attention enable_packed_qkv = (options is None) or options.enable_packed_qkv self_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, False, enable_packed_qkv, False + self, + self.hidden_size, + self.num_heads, + is_cross_attention=False, + enable_packed_qkv=enable_packed_qkv, + enable_packed_kv=False, ) self_attention_fusion.apply() # Cross Attention enable_packed_kv = (options is None) or options.enable_packed_kv cross_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, True, False, enable_packed_kv + self, + self.hidden_size, + self.num_heads, + is_cross_attention=True, + enable_packed_qkv=False, + enable_packed_kv=enable_packed_kv, ) cross_attention_fusion.apply() @@ -110,23 +121,48 @@ def fuse_bias_add(self): fusion.apply() def optimize(self, options: Optional[FusionOptions] = None): + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 18 + progress_bar = tqdm.tqdm(range(0, steps), initial=0, desc="fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() self.utils.remove_identity_nodes() + if progress_bar: + progress_bar.update(1) # Remove cast nodes that having same data type of input and output based on symbolic shape inference. self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_gelu: self.fuse_gelu() + if progress_bar: + progress_bar.update(1) self.preprocess() + if progress_bar: + progress_bar.update(1) self.fuse_reshape() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_group_norm: channels_last = (options is None) or options.group_norm_channels_last @@ -135,42 +171,66 @@ def optimize(self, options: Optional[FusionOptions] = None): insert_transpose_fusion = FusionInsertTranspose(self) insert_transpose_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_splitgelu: bias_split_gelu_fusion = FusionBiasSplitGelu(self) bias_split_gelu_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_attention: + # self.save_model_to_file("before_mha.onnx") self.fuse_multi_head_attention(options) + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + if progress_bar: + progress_bar.update(1) self.fuse_shape() + if progress_bar: + progress_bar.update(1) # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. self.utils.remove_useless_reshape_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_group_norm: skip_group_norm_fusion = FusionSkipGroupNorm(self) skip_group_norm_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. self.fuse_add_bias_skip_layer_norm() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_gelu_approximation: self.gelu_approximation() + if progress_bar: + progress_bar.update(1) if options is None or options.enable_nhwc_conv: self.convert_conv_to_nhwc() - self.merge_adjacent_transpose() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_bias_add: self.fuse_bias_add() + if progress_bar: + progress_bar.update(1) self.postprocess() + if progress_bar: + progress_bar.update(1) logger.info(f"opset version: {self.get_opset_version()}") @@ -190,6 +250,7 @@ def get_fused_operator_statistics(self): "NhwcConv", "BiasAdd", ] + for op in ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes)