diff --git a/onnxruntime/python/tools/transformers/fusion_attention_sam2.py b/onnxruntime/python/tools/transformers/fusion_attention_sam2.py new file mode 100644 index 0000000000000..ce7ddd3c1050e --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_attention_sam2.py @@ -0,0 +1,534 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Tuple, Union + +import numpy as np +from fusion_base import Fusion +from fusion_utils import NumpyHelper +from onnx import NodeProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionMultiHeadAttentionSam2(Fusion): + """ + Fuse MultiHeadAttention subgraph of Segment Anything v2 (SAM2). + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + ): + super().__init__(model, "MultiHeadAttention", ["LayerNormalization"]) + self.hidden_size = hidden_size + self.num_heads = num_heads + + # Flags to show warning only once + self.num_heads_warning = True + self.hidden_size_warning = True + + def get_decoder_num_heads(self, reshape_q: NodeProto) -> int: + """Detect num_heads from a reshape node. + + Args: + reshape_q (NodeProto): reshape node for Q + Returns: + int: num_heads, or 0 if not found + """ + num_heads = 0 + + # We assume that reshape fusion has been done, so the shape is a constant tensor like [0, 0, num_heads, head_size]. + shape_value = self.model.get_constant_value(reshape_q.input[1]) + if shape_value is not None: + if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [4]: + num_heads = int(shape_value[2]) + + if isinstance(num_heads, int) and num_heads > 0: + return num_heads + + return 0 + + def get_encoder_num_heads(self, reshape_in: NodeProto) -> int: + """Detect num_heads from a reshape node. + + Args: + reshape_in (NodeProto): reshape node of the packed QKV input projection + Returns: + int: num_heads, or 0 if not found + """ + num_heads = 0 + + shape_value = self.model.get_constant_value(reshape_in.input[1]) + if shape_value is not None: + if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [5]: + num_heads = int(shape_value[3]) + else: + concat_shape = self.model.match_parent(reshape_in, "Concat", 1) + if concat_shape is not None and len(concat_shape.input) == 5: + # We assume that reshape fusion has been done, so the target shape has 5 elements and num_heads is the fourth one. + shape_value = self.model.get_constant_value(concat_shape.input[3]) + if shape_value is not None: + if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [1]: + num_heads = int(shape_value[0]) + + if isinstance(num_heads, int) and num_heads > 0: + return num_heads + + return 0 + + def get_hidden_size(self, layernorm_node): + """Detect hidden_size from LayerNormalization node. + Args: + layernorm_node (NodeProto): LayerNormalization node before Q, K and V + Returns: + int: hidden_size, or 0 if not found + """ + layernorm_bias = self.model.get_initializer(layernorm_node.input[2]) + if layernorm_bias: + return NumpyHelper.to_array(layernorm_bias).shape[0] + + return 0 + + def get_num_heads_and_hidden_size( + self, reshape_q: NodeProto, layernorm_node: NodeProto, is_encoder: bool = False + ) -> Tuple[int, int]: + """Detect num_heads and hidden_size.
+ + Args: + reshape_q (NodeProto): reshape node for Q (decoder), or reshape node of the packed QKV input projection (encoder) + layernorm_node (NodeProto): LayerNormalization node before Q, K, V + is_encoder (bool): whether the subgraph belongs to the SAM2 image encoder + Returns: + Tuple[int, int]: num_heads and hidden_size + """ + if is_encoder: + num_heads = self.get_encoder_num_heads(reshape_q) + else: + num_heads = self.get_decoder_num_heads(reshape_q) + if num_heads <= 0: + num_heads = self.num_heads # Fall back to the user-specified value + + if self.num_heads > 0 and num_heads != self.num_heads: + if self.num_heads_warning: + logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.") + self.num_heads_warning = False # Do not show the warning more than once + + hidden_size = self.get_hidden_size(layernorm_node) + if hidden_size <= 0: + hidden_size = self.hidden_size # Fall back to the user-specified value + + if self.hidden_size > 0 and hidden_size != self.hidden_size: + if self.hidden_size_warning: + logger.warning( + f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value." + ) + self.hidden_size_warning = False # Do not show the warning more than once + + return num_heads, hidden_size + + def create_attention_node( + self, + q_matmul: NodeProto, + q_add: NodeProto, + k_matmul: NodeProto, + k_add: NodeProto, + v_matmul: NodeProto, + v_add: NodeProto, + num_heads: int, + hidden_size: int, + output: str, + ) -> Union[NodeProto, None]: + """Create an Attention node. + + Args: + q_matmul (NodeProto): MatMul node in the fully connected layer for Q + q_add (NodeProto): Add bias node in the fully connected layer for Q + k_matmul (NodeProto): MatMul node in the fully connected layer for K + k_add (NodeProto): Add bias node in the fully connected layer for K + v_matmul (NodeProto): MatMul node in the fully connected layer for V + v_add (NodeProto): Add bias node in the fully connected layer for V + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. + output (str): output name + + Returns: + Union[NodeProto, None]: the node created, or None if fusion failed.
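+ + Note: the outputs of the Q/K/V bias Add nodes are used directly as the inputs of the fused MultiHeadAttention node, so this fusion does not create a packed QKV weight initializer.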
+ """ + if hidden_size > 0 and (hidden_size % num_heads) != 0: + logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + return None + + q_weight = self.model.get_initializer(q_matmul.input[1]) + k_weight = self.model.get_initializer(k_matmul.input[1]) + v_weight = self.model.get_initializer(v_matmul.input[1]) + if not (q_weight and k_weight and v_weight): + return None + + qw = NumpyHelper.to_array(q_weight) + kw = NumpyHelper.to_array(k_weight) + vw = NumpyHelper.to_array(v_weight) + logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}") + + attention_node_name = self.model.create_node_name("MultiHeadAttention") + + attention_inputs = [ + q_add.output[0], + k_add.output[0], + v_add.output[0], + ] + + attention_node = helper.make_node( + "MultiHeadAttention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + counter_name = "MultiHeadAttention ({})".format("cross attention") + self.increase_counter(counter_name) + return attention_node + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if self.fuse_sam_encoder_pattern(normalize_node, input_name_to_nodes, output_name_to_node): + return + + match_qkv = self.match_attention_subgraph(normalize_node) + if match_qkv is None: + if normalize_node.input[0] not in output_name_to_node: + return + + skip_add = output_name_to_node[normalize_node.input[0]] + if skip_add.op_type != "Add": + return + + match_qkv = self.match_attention_subgraph(skip_add) + + if match_qkv is None: + return + + reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v = match_qkv + + attention_last_node = reshape_qkv + + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, False) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + add_q, + matmul_k, + add_k, + matmul_v, + add_v, + q_num_heads, + q_hidden_size, + output=attention_last_node.output[0], + ) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. 
+ self.prune_graph = True + + def match_attention_subgraph(self, node_after_output_projection): + """Match Q, K and V paths exported by PyTorch 2.*""" + qkv_nodes = self.model.match_parent_path( + node_after_output_projection, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [None, None, None, 0, 0], + ) + + if qkv_nodes is None: + return None + + (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return None + (_, _, add_v, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) + if qk_nodes is not None: + (_softmax_qk, matmul_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return None + + q_nodes = self.model.match_parent_path( + matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, None, 0, 0, None] + ) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return None + (mul_q, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path( + matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [1, None, 0, 0, None] + ) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return None + + (_mul_k, _, _, add_k, matmul_k) = k_nodes + + # The scalar for Q and K is sqrt(1.0/sqrt(head_size)). + mul_q_nodes = self.model.match_parent_path( + mul_q, + ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"], + [None, 0, 1, 0, 0, 0, 0, 0], + ) + if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q: + logger.debug("fuse_attention: failed to match mul_q path") + return None + + return reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v + + # -------------------------------------------------------- + # The following are for SAM encoder + # -------------------------------------------------------- + def fuse_sam_encoder_pattern(self, normalize_node, input_name_to_nodes, output_name_to_node) -> bool: + # SAM encoder attention layer pattern: + # Add -----------+ + # | | + # LayerNorm | + # | | + # Reshape | + # | | + # Transpose | + # | | + # MatMul | + # | | + # Add | + # | | + # Reshape | + # | | + # Split | + # | | + # Self Attention subgraph | + # | | + # Reshape | + # | | + # Transpose | + # | | + # Reshape | + # | | + # Add ----------+ + # | + # LayerNorm (starts from here) + + nodes = self.model.match_parent_path( + normalize_node, + ["Add", "Reshape", "Transpose", "Reshape"], + [0, None, 0, 0], + ) + if nodes is None: + nodes = self.model.match_parent_path( + normalize_node, + ["Add", "Slice", "Slice", "Reshape", "Transpose", "Reshape"], + [0, None, 0, 0, 0, 0], + ) + if nodes is None: + nodes = self.model.match_parent_path( + normalize_node, + ["Add"], + [0], + ) + if nodes is None: + return False + + node_after_output_projection = nodes[-1] + matched_sdpa = self.match_sam_encoder_attention_subgraph( + node_after_output_projection, input_index=1 if len(nodes) == 1 else None + ) + if matched_sdpa is None: + return False + + reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v = matched_sdpa + + # B, S, N, H => B, N, S, H + permutation_q = OnnxModel.get_node_attribute(transpose_q, "perm") + if (not isinstance(permutation_q, list)) or permutation_q != [0, 2, 1, 3]: + return False + + # B, S, N, H => B, N, H, S + 
permutation_k = OnnxModel.get_node_attribute(transpose_k, "perm") + if (not isinstance(permutation_k, list)) or permutation_k != [0, 2, 3, 1]: + return False + + # B, S, N, H => B, N, S, H + permutation_v = OnnxModel.get_node_attribute(transpose_v, "perm") + if (not isinstance(permutation_v, list)) or permutation_v != [0, 2, 1, 3]: + return False + + input_projection_nodes = self.model.match_parent_path( + split_qkv, + ["Reshape", "Add", "MatMul"], + [0, 0, None], + ) + if input_projection_nodes is None: + return False + reshape_in, add_in, matmul_in = input_projection_nodes + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_in, normalize_node, True) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return False + + # Add a shape to convert 4D BxSxNxH to 3D BxSxD, which is required by MHA operator. + new_dims_name = "bsnh_to_bsd_reshape_dims" + new_dims = self.model.get_initializer(new_dims_name) + if new_dims is None: + new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name) + self.model.add_initializer(new_dims, self.this_graph_name) + reshape_q_name = self.model.create_node_name("Reshape") + reshape_q = helper.make_node( + "Reshape", + inputs=[transpose_q.input[0], new_dims_name], + outputs=[transpose_q.input[0] + "_BSD"], + name=reshape_q_name, + ) + self.nodes_to_add.append(reshape_q) + self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name + + # Reuse the transpose_q node to transpose K from BSNH to BNSH. Here we update the input and output of the node. + transpose_k_bnsh = transpose_q + transpose_k_bnsh.input[0] = transpose_k.input[0] + transpose_k_bnsh.output[0] = transpose_k.input[0] + "_BNSH" + + logger.debug(f"Found MHA: {q_num_heads=} {q_hidden_size=}") + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_mha_node( + reshape_q, + transpose_k_bnsh, + transpose_v, + q_num_heads, + ) + if new_node is None: + return False + + # Update the input of the next node that consumes the output of the MHA. + assert len(self.model.get_children(transpose_out, input_name_to_nodes)) == 1 + reshape_out.input[0] = new_node.output[0] + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.nodes_to_remove.extend([transpose_out]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True + return True + + def match_sam_encoder_attention_subgraph(self, node_after_output_projection, input_index=None): + """Match SDPA pattern in SAM2 enconder.*""" + + # nodes of output projection and the second MatMul in SDPA. 
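+ # Walking up from the node after the output projection: Add (residual) <- MatMul (output projection) <- Reshape <- Transpose <- MatMul (attention probabilities x V).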
+ out_nodes = self.model.match_parent_path( + node_after_output_projection, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [input_index, None, None, 0, 0], + ) + + if out_nodes is None: + return None + + (_, _, reshape_out, transpose_out, matmul_qk_v) = out_nodes + + # Split and Reshape is for packed QKV + v_nodes = self.model.match_parent_path(matmul_qk_v, ["Transpose", "Squeeze", "Split", "Reshape"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("failed to match v path") + return None + (transpose_v, _, split_qkv, reshape_qkv) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qk_v, ["Softmax", "MatMul"], [0, 0]) + if qk_nodes is not None: + (_softmax_qk, matmul_qk) = qk_nodes + else: + logger.debug("failed to match qk path") + return None + + q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [0, None, 0, 0]) + if q_nodes is None: + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Transpose", "MaxPool", "Transpose", "Reshape", "Squeeze", "Split"], + [0, None, 0, 0, 0, 0, 0, 0, 0], + ) + if q_nodes is None: + logger.debug("failed to match q path") + return None + + if q_nodes[-1] != split_qkv: + return None + transpose_q = q_nodes[1] + + k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [1, None, 0, 0]) + if k_nodes is None: + logger.debug("failed to match k path") + return None + + if k_nodes[-1] != split_qkv: + return None + (mul_k, transpose_k, _squeeze_k, _) = k_nodes + + return reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v + + def create_mha_node( + self, + reshape_q: NodeProto, + transpose_k: NodeProto, + transpose_v: NodeProto, + num_heads: int, + ) -> NodeProto: + """Create a MultiHeadAttention node for SAM2 encoder. + + Args: + reshape_q (NodeProto): Reshape node for Q, output is 3D BxSxNH format + transpose_k (NodeProto): Transpose node for K, output is BNSH format + transpose_v (NodeProto): Transpose node for V, output is BNSH format + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + + Returns: + NodeProto: the MultiHeadAttention node created. + """ + + attention_node_name = self.model.create_node_name("MultiHeadAttention") + + inputs = [ + reshape_q.output[0], + transpose_k.output[0], + transpose_v.output[0], + ] + + # Create a new output name since the shape is 3D, which is different from the original output shape (4D). + output = attention_node_name + "_out" + + attention_node = helper.make_node( + "MultiHeadAttention", + inputs=inputs, + outputs=[output], + name=attention_node_name, + ) + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + counter_name = "MultiHeadAttention ({})".format("self attention") + self.increase_counter(counter_name) + return attention_node diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index 678d8c42bad67..aac05a7f01325 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -3,10 +3,10 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict +from typing import Dict, List from fusion_base import Fusion -from onnx import helper +from onnx import TensorProto, helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -143,6 +143,162 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name +class FusionLayerNormalizationNCHW(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "LayerNormalization", "ReduceMean") + + def get_weight_or_bias(self, output_name, description): + value = self.model.get_constant_value(output_name) + if value is None: + logger.debug(f"{description} {output_name} is not initializer.") + return None + + if len(value.shape) != 3 or value.shape[1] != 1 or value.shape[2] != 1: + logger.debug(f"{description} {output_name} shall have 3 dimensions Cx1x1. Got shape {value.shape}") + return None + + return value.reshape([value.shape[0]]) + + def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + """Append a Transpose node after an input""" + node_name = self.model.create_node_name("Transpose") + + if output_name is None: + output_name = node_name + "_out" + "-" + input_name + + transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name) + transpose_node.attribute.extend([helper.make_attribute("perm", perm)]) + + return transpose_node + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + Fuse Layer Normalization subgraph into one node LayerNormalization: + +----------------------+ + | NxCxHxW | + | v (Cx1x1) (Cx1x1) + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add --> + (axes=1) | (Y=2) (axes=1) (E-6) ^ + | | + +-----------------------------------------------+ + + Fused subgraph: + (0,2,3,1) (0,3,1,2) + [Root] --> Transpose --> LayerNormalization --> Transpose --> + """ + axes = OnnxModel.get_node_attribute(node, "axes") + if (not isinstance(axes, list)) or axes != [1]: + return + + subgraph_nodes = [] + children = self.model.get_children(node, input_name_to_nodes) + if len(children) != 1: + return + + root_input = node.input[0] + + if children[0].op_type != "Sub" or children[0].input[0] != root_input: + return + sub = children[0] + + div_node = self.model.find_first_child_by_type(sub, "Div", input_name_to_nodes, recursive=False) + if div_node is None: + return + + parent_nodes = self.model.match_parent_path( + div_node, + ["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], + [1, 0, 0, 0, 0], + output_name_to_node, + ) + if parent_nodes is None: + return + + _sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node = parent_nodes + if sub != sub_node: + return + + i, add_weight = self.model.get_constant_input(second_add_node) + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: + logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {add_weight}") + return + + axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes") + assert isinstance(axes, list) + if axes != [1]: + return + + if self.model.find_constant_input(pow_node, 2.0) != 1: + return + + temp_node = input_name_to_nodes[div_node.output[0]][0] + mul_node = temp_node + if mul_node.op_type != "Mul": + return + + last_add_node = input_name_to_nodes[mul_node.output[0]][0] + if last_add_node.op_type 
!= "Add": + return + + subgraph_nodes.append(node) + subgraph_nodes.extend(parent_nodes) + subgraph_nodes.extend([last_add_node, mul_node, div_node]) + + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + last_add_node.output, + input_name_to_nodes, + output_name_to_node, + ): + logger.debug("It is not safe to fuse LayerNormalization node. Skip") + return + + node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node + weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] + weight = self.get_weight_or_bias(weight_input, "layernorm weight") + if weight is None: + return + + bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)] + bias = self.get_weight_or_bias(bias_input, "layernorm bias") + if bias is None: + return + + weight_nhwc = helper.make_tensor(weight_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight) + + bias_nhwc = helper.make_tensor(bias_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight) + self.model.add_initializer(weight_nhwc, self.this_graph_name) + self.model.add_initializer(bias_nhwc, self.this_graph_name) + + self.nodes_to_remove.extend(subgraph_nodes) + + transpose_input = self.create_transpose_node(node.input[0], [0, 2, 3, 1]) + + layernorm_node_name = self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm") + + transpose_output = self.create_transpose_node( + layernorm_node_name + "_out_nhwc", [0, 3, 1, 2], last_add_node.output[0] + ) + + normalize_node = helper.make_node( + "LayerNormalization", + inputs=[transpose_input.output[0], weight_input + "_NHWC", bias_input + "_NHWC"], + outputs=[layernorm_node_name + "_out_nhwc"], + name=layernorm_node_name, + ) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + + self.nodes_to_add.append(transpose_input) + self.nodes_to_add.append(normalize_node) + self.nodes_to_add.append(transpose_output) + self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name + + counter_name = "LayerNormalization(NHWC)" + self.increase_counter(counter_name) + + class FusionLayerNormalizationTF(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "LayerNormalization", "Add", "TF") diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index a10b61fdc3f08..4728caaaf3289 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -59,6 +59,13 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if hasattr(self, "shape_infer_helper"): if self.shape_infer_helper is not None: + if ( + self.shape_infer_helper.get_edge_shape(add.input[0]) + and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3 + ): + logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0]) + return + # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): logger.debug( diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md b/onnxruntime/python/tools/transformers/models/sam2/README.md index 6ae2b35ba248c..83c0c51f09929 100644 --- 
a/onnxruntime/python/tools/transformers/models/sam2/README.md +++ b/onnxruntime/python/tools/transformers/models/sam2/README.md @@ -54,12 +54,39 @@ To see all parameters, run the following command: python3 convert_to_onnx.py -h ``` +## Optimize ONNX + +To optimize the ONNX models for CPU with float32 data type: +```bash +python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp32 +``` + +To optimize the ONNX models for GPU with float16 data type: +```bash +python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp16 --use_gpu +``` + +Another option is to use optimizer.py as follows: +``` +cd ../.. +python optimizer.py --input models/sam2/sam2_onnx_models/sam2_hiera_large_image_encoder.onnx \ + --output models/sam2/sam2_onnx_models/sam2_hiera_large_image_encoder_fp16_gpu.onnx \ + --use_gpu --model_type sam2 --float16 +``` +Using optimizer.py directly is helpful when you have SAM2 ONNX models that were exported by other tools. + ## Run Demo + The exported ONNX models can run on a CPU. The demo will output sam2_demo.png. ```bash curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo ``` +The demo can also run with the optimized models. For example: +```bash +python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp16 --use_gpu --demo +``` + ## Limitations - The exported image_decoder model does not support batch mode for now. diff --git a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py index 9b629f5c40802..8ad69dee0a763 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py @@ -13,7 +13,7 @@ from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx from sam2_demo import run_demo, show_all_images -from sam2_utils import build_sam2_model, get_decoder_onnx_path, get_image_encoder_onnx_path, setup_logger +from sam2_utils import load_sam2_model, sam2_onnx_path, setup_logger def parse_arguments(): @@ -93,6 +93,26 @@ def parse_arguments(): help="Run demo with the exported ONNX models.", ) + parser.add_argument( + "--optimize", + required=False, + default=False, + action="store_true", + help="Optimize ONNX models", + ) + + parser.add_argument( + "--dtype", required=False, choices=["fp32", "fp16"], default="fp32", help="Data type for inference." + ) + + parser.add_argument( + "--use_gpu", + required=False, + default=False, + action="store_true", + help="Optimize ONNX models for GPU", + ) + parser.add_argument( + "--verbose", + required=False, @@ -105,41 +125,36 @@ def parse_arguments(): return args -def main(): - args = parse_arguments() +def optimize_sam2_model(onnx_model_path, optimized_model_path, float16: bool, use_gpu: bool): + print(f"Optimizing {onnx_model_path} to {optimized_model_path} with float16={float16} and use_gpu={use_gpu}...") - checkpoints_dir = os.path.join(args.sam2_dir, "checkpoints") - sam2_config_dir = os.path.join(args.sam2_dir, "sam2_configs") - if not os.path.exists(args.sam2_dir): - raise FileNotFoundError(f"{args.sam2_dir} does not exist. Please specify --sam2_dir correctly.") + # Import from source directory.
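+ # optimizer.py lives two directories up (onnxruntime/python/tools/transformers), so make sure that directory is on sys.path before importing it.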
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) + if transformers_dir not in sys.path: + sys.path.insert(0, transformers_dir) + from optimizer import optimize_model - if not os.path.exists(checkpoints_dir): - raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.") + optimized_model = optimize_model(onnx_model_path, model_type="sam2", opt_level=1, use_gpu=use_gpu) + if float16: + optimized_model.convert_float_to_float16(keep_io_types=False) + optimized_model.save_model_to_file(optimized_model_path) - if not os.path.exists(sam2_config_dir): - raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.") - if not os.path.exists(os.path.join(checkpoints_dir, f"{args.model_type}.pt")): - raise FileNotFoundError( - f"{checkpoints_dir}/{args.model_type}.pt does not exist. Please download checkpoints under the directory." - ) +def main(): + args = parse_arguments() - if args.sam2_dir not in sys.path: - sys.path.append(args.sam2_dir) + sam2_model = load_sam2_model(args.sam2_dir, args.model_type, device="cpu") pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) - sam2_model = build_sam2_model(checkpoints_dir, args.model_type, device="cpu") - for component in args.components: + onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output) if component == "image_encoder": - onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type) if args.overwrite or not os.path.exists(onnx_model_path): export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False) elif component == "mask_decoder": - onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_mask_decoder.onnx") if args.overwrite or not os.path.exists(onnx_model_path): export_mask_decoder_onnx( sam2_model, @@ -155,38 +170,87 @@ def main(): not args.disable_dynamic_multimask_via_stability, ) elif component == "prompt_encoder": - onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_prompt_encoder.onnx") if args.overwrite or not os.path.exists(onnx_model_path): export_prompt_encoder_onnx(sam2_model, onnx_model_path) test_prompt_encoder_onnx(sam2_model, onnx_model_path) - elif component == "image_decoder": - onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, args.multimask_output) + else: + assert component == "image_decoder" if args.overwrite or not os.path.exists(onnx_model_path): export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output) test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output) + suffix = "" + convert_to_fp16 = args.dtype == "fp16" + if args.optimize: + suffix = f"_{args.dtype}_" + ("gpu" if args.use_gpu else "cpu") + for component in args.components: + onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output) + optimized_model_path = sam2_onnx_path( + args.output_dir, args.model_type, component, args.multimask_output, suffix + ) + optimize_sam2_model(onnx_model_path, optimized_model_path, convert_to_fp16, args.use_gpu) + if args.demo: # Export required ONNX models for demo if not already exported. 
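+ # The demo needs the image encoder plus both the single-mask and multi-mask image decoders.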
- onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type) - if not os.path.exists(onnx_model_path): - export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) + image_encoder_onnx_path = sam2_onnx_path( + args.output_dir, args.model_type, "image_encoder", args.multimask_output + ) + if not os.path.exists(image_encoder_onnx_path): + export_image_encoder_onnx(sam2_model, image_encoder_onnx_path, args.dynamic_batch_axes, args.verbose) + + image_decoder_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", False) + if not os.path.exists(image_decoder_onnx_path): + export_decoder_onnx(sam2_model, image_decoder_onnx_path, False) + + image_decoder_multi_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", True) + if not os.path.exists(image_decoder_multi_onnx_path): + export_decoder_onnx(sam2_model, image_decoder_multi_onnx_path, True) + + dtype = torch.float32 if args.dtype == "fp32" else torch.float16 + if suffix: + optimized_image_encoder_onnx_path = image_encoder_onnx_path.replace(".onnx", f"{suffix}.onnx") + if not os.path.exists(optimized_image_encoder_onnx_path): + optimize_sam2_model( + image_encoder_onnx_path, optimized_image_encoder_onnx_path, convert_to_fp16, args.use_gpu + ) - onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, True) - if not os.path.exists(onnx_model_path): - export_decoder_onnx(sam2_model, onnx_model_path, True) + optimized_image_decoder_onnx_path = image_decoder_onnx_path.replace(".onnx", f"{suffix}.onnx") + if not os.path.exists(optimized_image_decoder_onnx_path): + optimize_sam2_model( + image_decoder_onnx_path, optimized_image_decoder_onnx_path, convert_to_fp16, args.use_gpu + ) - onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, False) - if not os.path.exists(onnx_model_path): - export_decoder_onnx(sam2_model, onnx_model_path, False) + optimized_image_decoder_multi_onnx_path = image_decoder_multi_onnx_path.replace(".onnx", f"{suffix}.onnx") + if not os.path.exists(optimized_image_decoder_multi_onnx_path): + optimize_sam2_model( + image_decoder_multi_onnx_path, + optimized_image_decoder_multi_onnx_path, + convert_to_fp16, + args.use_gpu, + ) - ort_image_files = run_demo(checkpoints_dir, args.model_type, engine="ort", onnx_directory=args.output_dir) + # Use optimized models to run demo. + image_encoder_onnx_path = optimized_image_encoder_onnx_path + image_decoder_onnx_path = optimized_image_decoder_onnx_path + image_decoder_multi_onnx_path = optimized_image_decoder_multi_onnx_path + + ort_image_files = run_demo( + args.sam2_dir, + args.model_type, + engine="ort", + dtype=dtype, + image_encoder_onnx_path=image_encoder_onnx_path, + image_decoder_onnx_path=image_decoder_onnx_path, + image_decoder_multi_onnx_path=image_decoder_multi_onnx_path, + use_gpu=args.use_gpu, + ) print("demo output files for ONNX Runtime:", ort_image_files) # Get results from torch engine to compare. 
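+ # The torch engine run uses the same model type, dtype and device, so its outputs are directly comparable with the ONNX Runtime results.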
- torch_image_files = run_demo(checkpoints_dir, args.model_type, engine="torch", onnx_directory=args.output_dir) + torch_image_files = run_demo(args.sam2_dir, args.model_type, engine="torch", dtype=dtype, use_gpu=args.use_gpu) print("demo output files for PyTorch:", torch_image_files) - show_all_images(ort_image_files, torch_image_files) + show_all_images(ort_image_files, torch_image_files, suffix) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py index e2cd93ae2157d..9533e2652f8a5 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import os +from typing import Union import matplotlib.image as mpimg import matplotlib.pyplot as plt @@ -12,7 +13,9 @@ from PIL import Image from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor -from sam2_utils import build_sam2_model +from sam2_utils import load_sam2_model + +import onnxruntime def show_mask(mask, ax, random_color=False, borders=True): @@ -88,46 +91,72 @@ def show_masks( def get_predictor( - checkpoint_dir: str, - device: torch.device, + sam2_dir: str, + device: Union[str, torch.device], + dtype: torch.dtype, model_type="sam2_hiera_large", engine="torch", - onnx_directory="sam2_onnx_models", + image_encoder_onnx_path: str = "", + image_decoder_onnx_path: str = "", + image_decoder_multi_onnx_path: str = "", + provider: str = "CUDAExecutionProvider", ): - sam2_model = build_sam2_model(checkpoint_dir, model_type, device=device) + sam2_model = load_sam2_model(sam2_dir, model_type, device=device) if engine == "torch": predictor = SAM2ImagePredictor(sam2_model) else: - predictor = SAM2ImageOnnxPredictor(sam2_model, onnx_directory=onnx_directory, model_type=model_type) + predictor = SAM2ImageOnnxPredictor( + sam2_model, + image_encoder_onnx_path=image_encoder_onnx_path, + image_decoder_onnx_path=image_decoder_onnx_path, + image_decoder_multi_onnx_path=image_decoder_multi_onnx_path, + provider=provider, + device=device, + onnx_dtype=dtype, + ) return predictor def run_demo( - checkpoint_dir: str, - model_type="sam2_hiera_large", - engine="torch", - onnx_directory="sam2_onnx_models", - enable_batch=False, + sam2_dir: str, + model_type: str = "sam2_hiera_large", + engine: str = "torch", + dtype: torch.dtype = torch.float32, + image_encoder_onnx_path: str = "", + image_decoder_onnx_path: str = "", + image_decoder_multi_onnx_path: str = "", + use_gpu: bool = True, + enable_batch: bool = False, ): - use_gpu = torch.cuda.is_available() - device = torch.device("cuda" if use_gpu else "cpu") - if use_gpu: - if engine == "torch": - # Turn on tfloat32 for Ampere GPUs. 
- if torch.cuda.get_device_properties(0).major >= 8: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - elif engine == "ort": - import onnxruntime + assert torch.cuda.is_available() + assert "CUDAExecutionProvider" in onnxruntime.get_available_providers() + provider = "CUDAExecutionProvider" + else: + provider = "CPUExecutionProvider" - assert use_gpu == ("CUDAExecutionProvider" in onnxruntime.get_available_providers()) + device = torch.device("cuda" if use_gpu else "cpu") + + if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8: + # Turn on tfloat32 for Ampere GPUs. + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True np.random.seed(3) image = Image.open("truck.jpg") image = np.array(image.convert("RGB")) - predictor = get_predictor(checkpoint_dir, device, model_type, engine, onnx_directory=onnx_directory) + predictor = get_predictor( + sam2_dir, + device, + dtype, + model_type, + engine, + image_encoder_onnx_path, + image_decoder_onnx_path, + image_decoder_multi_onnx_path, + provider=provider, + ) predictor.set_image(image) prefix = f"sam2_demo_{engine}_" @@ -271,7 +300,7 @@ def run_demo( return image_files -def show_all_images(left_images, right_images): +def show_all_images(left_images, right_images, suffix=""): # Show images in two rows since display screen is horizontal in most cases. fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80)) for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images)): @@ -289,5 +318,5 @@ def show_all_images(left_images, right_images): axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0]) plt.tight_layout() - plt.savefig("sam2_demo.png", format="png", bbox_inches="tight", dpi=1000) + plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000) plt.show() diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py index 36b87f0ffbd90..363b5daf461a4 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py @@ -11,7 +11,7 @@ from PIL.Image import Image from sam2.modeling.sam2_base import SAM2Base from sam2.sam2_image_predictor import SAM2ImagePredictor -from sam2_utils import decoder_shape_dict, encoder_shape_dict, get_decoder_onnx_path, get_image_encoder_onnx_path +from sam2_utils import decoder_shape_dict, encoder_shape_dict from onnxruntime import InferenceSession from onnxruntime.transformers.io_binding_helper import CudaSession @@ -33,12 +33,16 @@ def create_ort_session( providers = [(provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] - print(f"Using providers: {providers}") + logger.info("Using providers: %s", providers) return InferenceSession(onnx_path, session_options, providers=providers) def create_session( - onnx_path: str, session_options=None, provider="CUDAExecutionProvider", device="cuda", enable_cuda_graph=False + onnx_path: str, + session_options=None, + provider="CUDAExecutionProvider", + device: Union[str, torch.device] = "cuda", + enable_cuda_graph=False, ) -> CudaSession: ort_session = create_ort_session( onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True @@ -51,8 +55,11 @@ class SAM2ImageOnnxPredictor(SAM2ImagePredictor): def __init__( self, 
sam_model: SAM2Base, - onnx_directory: str = "sam2_onnx_models", - model_type: str = "sam2_hiera_large", + image_encoder_onnx_path: str = "", + image_decoder_onnx_path: str = "", + image_decoder_multi_onnx_path: str = "", + provider: str = "CUDAExecutionProvider", + device: Union[str, torch.device] = "cuda", onnx_dtype: torch.dtype = torch.float32, mask_threshold=0.0, max_hole_area=0.0, @@ -76,19 +83,11 @@ def __init__( sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area ) - print(self.device) - if torch.cuda.is_available(): - provider = "CUDAExecutionProvider" - device = "cuda" - else: - provider = "CPUExecutionProvider" - device = "cpu" + logger.debug("self.device=%s, device=%s", self.device, device) # This model is exported by image_encoder.py. - onnx_path = get_image_encoder_onnx_path(onnx_directory, model_type) - self.encoder_session = create_session( - onnx_path, + image_encoder_onnx_path, session_options=None, provider=provider, device=device, @@ -97,9 +96,8 @@ def __init__( self.onnx_dtype = onnx_dtype # This model is exported by image_decoder.py. It outputs only one mask. - onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=False) self.decoder_session = create_session( - onnx_path, + image_decoder_onnx_path, session_options=None, provider=provider, device=device, @@ -107,9 +105,8 @@ def __init__( ) # This model is exported by image_decoder.py. It outputs multiple (3) masks. - onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=True) self.decoder_session_multi_out = create_session( - onnx_path, + image_decoder_multi_onnx_path, session_options=None, provider=provider, device=device, @@ -253,20 +250,20 @@ def _predict( image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0) if mask_input is None: - input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float, device=self.device) - has_input_masks = torch.zeros(num_labels, dtype=torch.float, device=self.device) + input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=self.onnx_dtype, device=self.device) + has_input_masks = torch.zeros(num_labels, dtype=self.onnx_dtype, device=self.device) else: input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1) - has_input_masks = torch.ones(num_labels, dtype=torch.float, device=self.device) + has_input_masks = torch.ones(num_labels, dtype=self.onnx_dtype, device=self.device) feed_dict = { - "image_embeddings": image_embeddings.contiguous().to(dtype=torch.float32).to(self.device), - "image_features_0": image_features_0.contiguous().to(dtype=torch.float32).to(self.device), - "image_features_1": image_features_1.contiguous().to(dtype=torch.float32).to(self.device), - "point_coords": concat_points[0].to(dtype=torch.float32).to(self.device), + "image_embeddings": image_embeddings.contiguous().to(dtype=self.onnx_dtype).to(self.device), + "image_features_0": image_features_0.contiguous().to(dtype=self.onnx_dtype).to(self.device), + "image_features_1": image_features_1.contiguous().to(dtype=self.onnx_dtype).to(self.device), + "point_coords": concat_points[0].to(dtype=self.onnx_dtype).to(self.device), "point_labels": concat_points[1].to(dtype=torch.int32).to(self.device), - "input_masks": input_masks.to(dtype=torch.float32).to(self.device), - "has_input_masks": has_input_masks.to(dtype=torch.float32).to(self.device), + "input_masks": input_masks.to(dtype=self.onnx_dtype).to(self.device), + "has_input_masks": 
has_input_masks.to(dtype=self.onnx_dtype).to(self.device), "original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device), } diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py index cf88eb42213f2..4ec4ccc274291 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py @@ -4,6 +4,8 @@ # -------------------------------------------------------------------------- import logging import os +import sys +from typing import List, Mapping, Union import torch from sam2.build_sam import build_sam2 @@ -12,7 +14,7 @@ logger = logging.getLogger(__name__) -def get_model_cfg(model_type) -> str: +def _get_model_cfg(model_type) -> str: assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"] if model_type == "sam2_hiera_tiny": model_cfg = "sam2_hiera_t.yaml" @@ -25,22 +27,45 @@ def get_model_cfg(model_type) -> str: return model_cfg -def build_sam2_model(checkpoint_dir: str, model_type: str, device="cpu") -> SAM2Base: - sam2_checkpoint = os.path.join(checkpoint_dir, f"{model_type}.pt") - model_cfg = get_model_cfg(model_type) - sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device) - return sam2_model +def load_sam2_model(sam2_dir, model_type, device: Union[str, torch.device] = "cpu") -> SAM2Base: + checkpoints_dir = os.path.join(sam2_dir, "checkpoints") + sam2_config_dir = os.path.join(sam2_dir, "sam2_configs") + if not os.path.exists(sam2_dir): + raise FileNotFoundError(f"{sam2_dir} does not exist. Please specify --sam2_dir correctly.") + + if not os.path.exists(checkpoints_dir): + raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.") + + if not os.path.exists(sam2_config_dir): + raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.") + checkpoint_path = os.path.join(checkpoints_dir, f"{model_type}.pt") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"{checkpoint_path} does not exist. 
Please download checkpoints under the directory.") -def get_decoder_onnx_path(dir: str, model_type, multimask_output) -> str: - return os.path.join(dir, f"{model_type}_decoder" + ("_multi" if multimask_output else "") + ".onnx") + if sam2_dir not in sys.path: + sys.path.append(sam2_dir) + + model_cfg = _get_model_cfg(model_type) + sam2_model = build_sam2(model_cfg, checkpoint_path, device=device) + return sam2_model -def get_image_encoder_onnx_path(dir: str, model_type) -> str: - return os.path.join(dir, f"{model_type}_image_encoder.onnx") +def sam2_onnx_path(output_dir, model_type, component, multimask_output=False, suffix=""): + if component == "image_encoder": + return os.path.join(output_dir, f"{model_type}_image_encoder{suffix}.onnx") + elif component == "mask_decoder": + return os.path.join(output_dir, f"{model_type}_mask_decoder{suffix}.onnx") + elif component == "prompt_encoder": + return os.path.join(output_dir, f"{model_type}_prompt_encoder{suffix}.onnx") + else: + assert component == "image_decoder" + return os.path.join( + output_dir, f"{model_type}_image_decoder" + ("_multi" if multimask_output else "") + f"{suffix}.onnx" + ) -def encoder_shape_dict(batch_size: int, height: int, width: int): +def encoder_shape_dict(batch_size: int, height: int, width: int) -> Mapping[str, List[int]]: assert height == 1024 and width == 1024, "Only 1024x1024 images are supported." return { "image": [batch_size, 3, height, width], @@ -109,7 +134,7 @@ def compare_tensors_with_tolerance( def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor: - image = torch.randn(batch_size, 3, image_height, image_width).cpu() + image = torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32).cpu() return image diff --git a/onnxruntime/python/tools/transformers/onnx_model_sam2.py b/onnxruntime/python/tools/transformers/onnx_model_sam2.py new file mode 100644 index 0000000000000..ac608fb509a81 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_sam2.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import logging +from typing import Optional + +from fusion_attention_sam2 import FusionMultiHeadAttentionSam2 +from fusion_layernorm import FusionLayerNormalizationNCHW +from fusion_options import FusionOptions +from import_utils import is_installed +from onnx import ModelProto +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class Sam2OnnxModel(BertOnnxModel): + def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): + """Initialize SAM2 ONNX Model. + + Args: + model (ModelProto): the ONNX model + num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically). + hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically). 
+ """ + assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) + + super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + + def postprocess(self): + self.prune_graph() + self.remove_unused_constant() + + def fuse_layer_norm(self): + super().fuse_layer_norm() + + fusion = FusionLayerNormalizationNCHW(self) + fusion.apply() + + def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): + mha_fusion = FusionMultiHeadAttentionSam2(self, self.hidden_size, self.num_heads) + mha_fusion.apply() + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 12 + progress_bar = tqdm.tqdm(range(steps), initial=0, desc="sam2 fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): + if (options is not None) and not options.enable_shape_inference: + self.disable_shape_inference() + + self.utils.remove_identity_nodes() + if progress_bar: + progress_bar.update(1) + + # Remove cast nodes that having same data type of input and output based on symbolic shape inference. + self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_layer_norm: + self.fuse_layer_norm() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_gelu: + self.fuse_gelu() + if progress_bar: + progress_bar.update(1) + + self.fuse_reshape() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_attention: + self.fuse_multi_head_attention(options) + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_skip_layer_norm: + self.fuse_skip_layer_norm() + if progress_bar: + progress_bar.update(1) + + self.fuse_shape() + if progress_bar: + progress_bar.update(1) + + # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. + self.utils.remove_useless_reshape_nodes() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_bias_skip_layer_norm: + # Fuse SkipLayerNormalization and Add Bias before it. + self.fuse_add_bias_skip_layer_norm() + if progress_bar: + progress_bar.update(1) + + if options is not None and options.enable_gelu_approximation: + self.gelu_approximation() + if progress_bar: + progress_bar.update(1) + + self.postprocess() + if progress_bar: + progress_bar.update(1) + + logger.info(f"opset version: {self.get_opset_version()}") + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. 
+ """ + op_count = {} + ops = [ + "MultiHeadAttention", + "LayerNormalization", + "SkipLayerNormalization", + ] + + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators:{op_count}") + return op_count diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 06264b426d1e5..933bd785dc00d 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -36,6 +36,7 @@ from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_phi import PhiOnnxModel +from onnx_model_sam2 import Sam2OnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel @@ -53,17 +54,18 @@ "bert_tf": (BertOnnxModelTF, "tf2onnx", 0), "bert_keras": (BertOnnxModelKeras, "keras2onnx", 0), "clip": (ClipOnnxModel, "pytorch", 1), # Clip in Stable Diffusion + "conformer": (ConformerOnnxModel, "pytorch", 1), "gpt2": (Gpt2OnnxModel, "pytorch", 1), "gpt2_tf": (Gpt2OnnxModel, "tf2onnx", 0), # might add a class for GPT2OnnxModel for TF later. "gpt_neox": (BertOnnxModel, "pytorch", 0), # GPT-NeoX + "phi": (PhiOnnxModel, "pytorch", 0), + "sam2": (Sam2OnnxModel, "pytorch", 1), "swin": (BertOnnxModel, "pytorch", 1), "tnlr": (TnlrOnnxModel, "pytorch", 1), "t5": (T5OnnxModel, "pytorch", 2), "unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), - "conformer": (ConformerOnnxModel, "pytorch", 1), - "phi": (PhiOnnxModel, "pytorch", 0), } @@ -235,7 +237,7 @@ def optimize_by_fusion( Returns: object of an optimizer class. """ - if model_type not in ["bert", "swin", "unet", "vae", "clip"] and (num_heads == 0 or hidden_size == 0): + if model_type not in ["bert", "swin", "unet", "vae", "clip", "sam2"] and (num_heads == 0 or hidden_size == 0): logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") if model_type not in MODEL_TYPES: