From 1431215dcfe95911c42938a09e18caaf92ff7697 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 20 Sep 2024 21:32:16 -0700 Subject: [PATCH] Add fusion script for segment anything v2 (#22167) ### Description * Add MultiHeadAttention fusion for SAM2. * Add LayerNormalization fusion for NCHW format by inserting Transpose from NCHW to NHWC before layer normalization, and add another Transpose after layer norm to convert NHWC back to NCHW. Hopefully, those extra Transpose nodes will be removed when prefer_nhwc is enabled later. * Add a condition that the input shall be 3D when fuse SkipLayerNorm. * Update convert_to_onnx.py to add `--optimize` and `--use_gpu` options to output optimized onnx model for CPU/CUDA eps. * Add an option `--dtype fp16|fp32` in convert_to_onnx.py to support converting optimized model to float16. * Update the demo to use the optimized onnx models. ### Motivation and Context To support optimization of SAM2 for CPU/CUDA eps that is exported in https://github.com/microsoft/onnxruntime/pull/22119 --- .../transformers/fusion_attention_sam2.py | 534 ++++++++++++++++++ .../tools/transformers/fusion_layernorm.py | 160 +++++- .../transformers/fusion_skiplayernorm.py | 7 + .../tools/transformers/models/sam2/README.md | 27 + .../models/sam2/convert_to_onnx.py | 136 +++-- .../transformers/models/sam2/sam2_demo.py | 79 ++- .../models/sam2/sam2_image_onnx_predictor.py | 53 +- .../transformers/models/sam2/sam2_utils.py | 49 +- .../tools/transformers/onnx_model_sam2.py | 138 +++++ .../python/tools/transformers/optimizer.py | 8 +- 10 files changed, 1085 insertions(+), 106 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/fusion_attention_sam2.py create mode 100644 onnxruntime/python/tools/transformers/onnx_model_sam2.py diff --git a/onnxruntime/python/tools/transformers/fusion_attention_sam2.py b/onnxruntime/python/tools/transformers/fusion_attention_sam2.py new file mode 100644 index 0000000000000..ce7ddd3c1050e --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_attention_sam2.py @@ -0,0 +1,534 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Tuple, Union + +import numpy as np +from fusion_base import Fusion +from fusion_utils import NumpyHelper +from onnx import NodeProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionMultiHeadAttentionSam2(Fusion): + """ + Fuse MultiHeadAttention subgraph of Segment Anything v2 (SAM2). + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + ): + super().__init__(model, "MultiHeadAttention", ["LayerNormalization"]) + self.hidden_size = hidden_size + self.num_heads = num_heads + + # Flags to show warning only once + self.num_heads_warning = True + self.hidden_size_warning = True + + def get_decoder_num_heads(self, reshape_q: NodeProto) -> int: + """Detect num_heads from a reshape node. + + Args: + reshape_q (NodeProto): reshape node for Q + Returns: + int: num_heads, or 0 if not found + """ + num_heads = 0 + + # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size] + shape_value = self.model.get_constant_value(reshape_q.input[1]) + if shape_value is not None: + if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [4]: + num_heads = int(shape_value[2]) + + if isinstance(num_heads, int) and num_heads > 0: + return num_heads + + return 0 + + def get_encoder_num_heads(self, reshape_in: NodeProto) -> int: + """Detect num_heads from a reshape node. + + Args: + reshape_q (NodeProto): reshape node for Q + Returns: + int: num_heads, or 0 if not found + """ + num_heads = 0 + + shape_value = self.model.get_constant_value(reshape_in.input[1]) + if shape_value is not None: + if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [5]: + num_heads = int(shape_value[3]) + else: + concat_shape = self.model.match_parent(reshape_in, "Concat", 1) + if concat_shape is not None and len(concat_shape.input) == 5: + # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size] + shape_value = self.model.get_constant_value(concat_shape.input[3]) + if shape_value is not None: + if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [1]: + num_heads = int(shape_value[0]) + + if isinstance(num_heads, int) and num_heads > 0: + return num_heads + + return 0 + + def get_hidden_size(self, layernorm_node): + """Detect hidden_size from LayerNormalization node. + Args: + layernorm_node (NodeProto): LayerNormalization node before Q, K and V + Returns: + int: hidden_size, or 0 if not found + """ + layernorm_bias = self.model.get_initializer(layernorm_node.input[2]) + if layernorm_bias: + return NumpyHelper.to_array(layernorm_bias).shape[0] + + return 0 + + def get_num_heads_and_hidden_size( + self, reshape_q: NodeProto, layernorm_node: NodeProto, is_encoder: bool = False + ) -> Tuple[int, int]: + """Detect num_heads and hidden_size. + + Args: + reshape_q (NodeProto): reshape node for Q + layernorm_node (NodeProto): LayerNormalization node before Q, K, V + Returns: + Tuple[int, int]: num_heads and hidden_size + """ + if is_encoder: + num_heads = self.get_encoder_num_heads(reshape_q) + else: + num_heads = self.get_decoder_num_heads(reshape_q) + if num_heads <= 0: + num_heads = self.num_heads # Fall back to user specified value + + if self.num_heads > 0 and num_heads != self.num_heads: + if self.num_heads_warning: + logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.") + self.num_heads_warning = False # Do not show the warning more than once + + hidden_size = self.get_hidden_size(layernorm_node) + if hidden_size <= 0: + hidden_size = self.hidden_size # Fall back to user specified value + + if self.hidden_size > 0 and hidden_size != self.hidden_size: + if self.hidden_size_warning: + logger.warning( + f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value." + ) + self.hidden_size_warning = False # Do not show the warning more than once + + return num_heads, hidden_size + + def create_attention_node( + self, + q_matmul: NodeProto, + q_add: NodeProto, + k_matmul: NodeProto, + k_add: NodeProto, + v_matmul: NodeProto, + v_add: NodeProto, + num_heads: int, + hidden_size: int, + output: str, + ) -> Union[NodeProto, None]: + """Create an Attention node. + + Args: + q_matmul (NodeProto): MatMul node in fully connection for Q + q_add (NodeProto): Add bias node in fully connection for Q + k_matmul (NodeProto): MatMul node in fully connection for K + k_add (NodeProto): Add bias node in fully connection for K + v_matmul (NodeProto): MatMul node in fully connection for V + v_add (NodeProto): Add bias node in fully connection for V + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. + output (str): output name + + Returns: + Union[NodeProto, None]: the node created or None if failed. + """ + if hidden_size > 0 and (hidden_size % num_heads) != 0: + logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + return None + + q_weight = self.model.get_initializer(q_matmul.input[1]) + k_weight = self.model.get_initializer(k_matmul.input[1]) + v_weight = self.model.get_initializer(v_matmul.input[1]) + if not (q_weight and k_weight and v_weight): + return None + + qw = NumpyHelper.to_array(q_weight) + kw = NumpyHelper.to_array(k_weight) + vw = NumpyHelper.to_array(v_weight) + logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}") + + attention_node_name = self.model.create_node_name("MultiHeadAttention") + + attention_inputs = [ + q_add.output[0], + k_add.output[0], + v_add.output[0], + ] + + attention_node = helper.make_node( + "MultiHeadAttention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + counter_name = "MultiHeadAttention ({})".format("cross attention") + self.increase_counter(counter_name) + return attention_node + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if self.fuse_sam_encoder_pattern(normalize_node, input_name_to_nodes, output_name_to_node): + return + + match_qkv = self.match_attention_subgraph(normalize_node) + if match_qkv is None: + if normalize_node.input[0] not in output_name_to_node: + return + + skip_add = output_name_to_node[normalize_node.input[0]] + if skip_add.op_type != "Add": + return + + match_qkv = self.match_attention_subgraph(skip_add) + + if match_qkv is None: + return + + reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v = match_qkv + + attention_last_node = reshape_qkv + + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, False) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + add_q, + matmul_k, + add_k, + matmul_v, + add_v, + q_num_heads, + q_hidden_size, + output=attention_last_node.output[0], + ) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True + + def match_attention_subgraph(self, node_after_output_projection): + """Match Q, K and V paths exported by PyTorch 2.*""" + qkv_nodes = self.model.match_parent_path( + node_after_output_projection, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [None, None, None, 0, 0], + ) + + if qkv_nodes is None: + return None + + (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return None + (_, _, add_v, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) + if qk_nodes is not None: + (_softmax_qk, matmul_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return None + + q_nodes = self.model.match_parent_path( + matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, None, 0, 0, None] + ) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return None + (mul_q, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path( + matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [1, None, 0, 0, None] + ) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return None + + (_mul_k, _, _, add_k, matmul_k) = k_nodes + + # The scalar for Q and K is sqrt(1.0/sqrt(head_size)). + mul_q_nodes = self.model.match_parent_path( + mul_q, + ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"], + [None, 0, 1, 0, 0, 0, 0, 0], + ) + if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q: + logger.debug("fuse_attention: failed to match mul_q path") + return None + + return reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v + + # -------------------------------------------------------- + # The following are for SAM encoder + # -------------------------------------------------------- + def fuse_sam_encoder_pattern(self, normalize_node, input_name_to_nodes, output_name_to_node) -> bool: + # SAM encoder attention layer pattern: + # Add -----------+ + # | | + # LayerNorm | + # | | + # Reshape | + # | | + # Transpose | + # | | + # MatMul | + # | | + # Add | + # | | + # Reshape | + # | | + # Split | + # | | + # Self Attention subgraph | + # | | + # Reshape | + # | | + # Transpose | + # | | + # Reshape | + # | | + # Add ----------+ + # | + # LayerNorm (starts from here) + + nodes = self.model.match_parent_path( + normalize_node, + ["Add", "Reshape", "Transpose", "Reshape"], + [0, None, 0, 0], + ) + if nodes is None: + nodes = self.model.match_parent_path( + normalize_node, + ["Add", "Slice", "Slice", "Reshape", "Transpose", "Reshape"], + [0, None, 0, 0, 0, 0], + ) + if nodes is None: + nodes = self.model.match_parent_path( + normalize_node, + ["Add"], + [0], + ) + if nodes is None: + return False + + node_after_output_projection = nodes[-1] + matched_sdpa = self.match_sam_encoder_attention_subgraph( + node_after_output_projection, input_index=1 if len(nodes) == 1 else None + ) + if matched_sdpa is None: + return False + + reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v = matched_sdpa + + # B, S, N, H => B, N, S, H + permutation_q = OnnxModel.get_node_attribute(transpose_q, "perm") + if (not isinstance(permutation_q, list)) or permutation_q != [0, 2, 1, 3]: + return False + + # B, S, N, H => B, N, H, S + permutation_k = OnnxModel.get_node_attribute(transpose_k, "perm") + if (not isinstance(permutation_k, list)) or permutation_k != [0, 2, 3, 1]: + return False + + # B, S, N, H => B, N, S, H + permutation_v = OnnxModel.get_node_attribute(transpose_v, "perm") + if (not isinstance(permutation_v, list)) or permutation_v != [0, 2, 1, 3]: + return False + + input_projection_nodes = self.model.match_parent_path( + split_qkv, + ["Reshape", "Add", "MatMul"], + [0, 0, None], + ) + if input_projection_nodes is None: + return False + reshape_in, add_in, matmul_in = input_projection_nodes + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_in, normalize_node, True) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return False + + # Add a shape to convert 4D BxSxNxH to 3D BxSxD, which is required by MHA operator. + new_dims_name = "bsnh_to_bsd_reshape_dims" + new_dims = self.model.get_initializer(new_dims_name) + if new_dims is None: + new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name) + self.model.add_initializer(new_dims, self.this_graph_name) + reshape_q_name = self.model.create_node_name("Reshape") + reshape_q = helper.make_node( + "Reshape", + inputs=[transpose_q.input[0], new_dims_name], + outputs=[transpose_q.input[0] + "_BSD"], + name=reshape_q_name, + ) + self.nodes_to_add.append(reshape_q) + self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name + + # Reuse the transpose_q node to transpose K from BSNH to BNSH. Here we update the input and output of the node. + transpose_k_bnsh = transpose_q + transpose_k_bnsh.input[0] = transpose_k.input[0] + transpose_k_bnsh.output[0] = transpose_k.input[0] + "_BNSH" + + logger.debug(f"Found MHA: {q_num_heads=} {q_hidden_size=}") + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_mha_node( + reshape_q, + transpose_k_bnsh, + transpose_v, + q_num_heads, + ) + if new_node is None: + return False + + # Update the input of the next node that consumes the output of the MHA. + assert len(self.model.get_children(transpose_out, input_name_to_nodes)) == 1 + reshape_out.input[0] = new_node.output[0] + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.nodes_to_remove.extend([transpose_out]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True + return True + + def match_sam_encoder_attention_subgraph(self, node_after_output_projection, input_index=None): + """Match SDPA pattern in SAM2 enconder.*""" + + # nodes of output projection and the second MatMul in SDPA. + out_nodes = self.model.match_parent_path( + node_after_output_projection, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [input_index, None, None, 0, 0], + ) + + if out_nodes is None: + return None + + (_, _, reshape_out, transpose_out, matmul_qk_v) = out_nodes + + # Split and Reshape is for packed QKV + v_nodes = self.model.match_parent_path(matmul_qk_v, ["Transpose", "Squeeze", "Split", "Reshape"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("failed to match v path") + return None + (transpose_v, _, split_qkv, reshape_qkv) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qk_v, ["Softmax", "MatMul"], [0, 0]) + if qk_nodes is not None: + (_softmax_qk, matmul_qk) = qk_nodes + else: + logger.debug("failed to match qk path") + return None + + q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [0, None, 0, 0]) + if q_nodes is None: + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Transpose", "MaxPool", "Transpose", "Reshape", "Squeeze", "Split"], + [0, None, 0, 0, 0, 0, 0, 0, 0], + ) + if q_nodes is None: + logger.debug("failed to match q path") + return None + + if q_nodes[-1] != split_qkv: + return None + transpose_q = q_nodes[1] + + k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [1, None, 0, 0]) + if k_nodes is None: + logger.debug("failed to match k path") + return None + + if k_nodes[-1] != split_qkv: + return None + (mul_k, transpose_k, _squeeze_k, _) = k_nodes + + return reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v + + def create_mha_node( + self, + reshape_q: NodeProto, + transpose_k: NodeProto, + transpose_v: NodeProto, + num_heads: int, + ) -> NodeProto: + """Create a MultiHeadAttention node for SAM2 encoder. + + Args: + reshape_q (NodeProto): Reshape node for Q, output is 3D BxSxNH format + transpose_k (NodeProto): Transpose node for K, output is BNSH format + transpose_v (NodeProto): Transpose node for V, output is BNSH format + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + + Returns: + NodeProto: the MultiHeadAttention node created. + """ + + attention_node_name = self.model.create_node_name("MultiHeadAttention") + + inputs = [ + reshape_q.output[0], + transpose_k.output[0], + transpose_v.output[0], + ] + + # Create a new output name since the shape is 3D, which is different from the original output shape (4D). + output = attention_node_name + "_out" + + attention_node = helper.make_node( + "MultiHeadAttention", + inputs=inputs, + outputs=[output], + name=attention_node_name, + ) + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + counter_name = "MultiHeadAttention ({})".format("self attention") + self.increase_counter(counter_name) + return attention_node diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index 678d8c42bad67..aac05a7f01325 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -3,10 +3,10 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict +from typing import Dict, List from fusion_base import Fusion -from onnx import helper +from onnx import TensorProto, helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -143,6 +143,162 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name +class FusionLayerNormalizationNCHW(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "LayerNormalization", "ReduceMean") + + def get_weight_or_bias(self, output_name, description): + value = self.model.get_constant_value(output_name) + if value is None: + logger.debug(f"{description} {output_name} is not initializer.") + return None + + if len(value.shape) != 3 or value.shape[1] != 1 or value.shape[2] != 1: + logger.debug(f"{description} {output_name} shall have 3 dimensions Cx1x1. Got shape {value.shape}") + return None + + return value.reshape([value.shape[0]]) + + def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + """Append a Transpose node after an input""" + node_name = self.model.create_node_name("Transpose") + + if output_name is None: + output_name = node_name + "_out" + "-" + input_name + + transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name) + transpose_node.attribute.extend([helper.make_attribute("perm", perm)]) + + return transpose_node + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + Fuse Layer Normalization subgraph into one node LayerNormalization: + +----------------------+ + | NxCxHxW | + | v (Cx1x1) (Cx1x1) + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add --> + (axes=1) | (Y=2) (axes=1) (E-6) ^ + | | + +-----------------------------------------------+ + + Fused subgraph: + (0,2,3,1) (0,3,1,2) + [Root] --> Transpose --> LayerNormalization --> Transpose --> + """ + axes = OnnxModel.get_node_attribute(node, "axes") + if (not isinstance(axes, list)) or axes != [1]: + return + + subgraph_nodes = [] + children = self.model.get_children(node, input_name_to_nodes) + if len(children) != 1: + return + + root_input = node.input[0] + + if children[0].op_type != "Sub" or children[0].input[0] != root_input: + return + sub = children[0] + + div_node = self.model.find_first_child_by_type(sub, "Div", input_name_to_nodes, recursive=False) + if div_node is None: + return + + parent_nodes = self.model.match_parent_path( + div_node, + ["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], + [1, 0, 0, 0, 0], + output_name_to_node, + ) + if parent_nodes is None: + return + + _sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node = parent_nodes + if sub != sub_node: + return + + i, add_weight = self.model.get_constant_input(second_add_node) + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: + logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {add_weight}") + return + + axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes") + assert isinstance(axes, list) + if axes != [1]: + return + + if self.model.find_constant_input(pow_node, 2.0) != 1: + return + + temp_node = input_name_to_nodes[div_node.output[0]][0] + mul_node = temp_node + if mul_node.op_type != "Mul": + return + + last_add_node = input_name_to_nodes[mul_node.output[0]][0] + if last_add_node.op_type != "Add": + return + + subgraph_nodes.append(node) + subgraph_nodes.extend(parent_nodes) + subgraph_nodes.extend([last_add_node, mul_node, div_node]) + + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + last_add_node.output, + input_name_to_nodes, + output_name_to_node, + ): + logger.debug("It is not safe to fuse LayerNormalization node. Skip") + return + + node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node + weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] + weight = self.get_weight_or_bias(weight_input, "layernorm weight") + if weight is None: + return + + bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)] + bias = self.get_weight_or_bias(bias_input, "layernorm bias") + if bias is None: + return + + weight_nhwc = helper.make_tensor(weight_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight) + + bias_nhwc = helper.make_tensor(bias_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight) + self.model.add_initializer(weight_nhwc, self.this_graph_name) + self.model.add_initializer(bias_nhwc, self.this_graph_name) + + self.nodes_to_remove.extend(subgraph_nodes) + + transpose_input = self.create_transpose_node(node.input[0], [0, 2, 3, 1]) + + layernorm_node_name = self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm") + + transpose_output = self.create_transpose_node( + layernorm_node_name + "_out_nhwc", [0, 3, 1, 2], last_add_node.output[0] + ) + + normalize_node = helper.make_node( + "LayerNormalization", + inputs=[transpose_input.output[0], weight_input + "_NHWC", bias_input + "_NHWC"], + outputs=[layernorm_node_name + "_out_nhwc"], + name=layernorm_node_name, + ) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + + self.nodes_to_add.append(transpose_input) + self.nodes_to_add.append(normalize_node) + self.nodes_to_add.append(transpose_output) + self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name + + counter_name = "LayerNormalization(NHWC)" + self.increase_counter(counter_name) + + class FusionLayerNormalizationTF(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "LayerNormalization", "Add", "TF") diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index a10b61fdc3f08..4728caaaf3289 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -59,6 +59,13 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if hasattr(self, "shape_infer_helper"): if self.shape_infer_helper is not None: + if ( + self.shape_infer_helper.get_edge_shape(add.input[0]) + and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3 + ): + logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0]) + return + # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): logger.debug( diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md b/onnxruntime/python/tools/transformers/models/sam2/README.md index 6ae2b35ba248c..83c0c51f09929 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/README.md +++ b/onnxruntime/python/tools/transformers/models/sam2/README.md @@ -54,12 +54,39 @@ To see all parameters, run the following command: python3 convert_to_onnx.py -h ``` +## Optimize ONNX + +To optimize the onnx models for CPU with float32 data type: +```bash +python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp32 +``` + +To optimize the onnx models for GPU with float16 data type: +```bash +python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp16 --use_gpu +``` + +Another option is to use optimizer.py like the following: +``` +cd ../.. +python optimizer.py --input models/sam2/sam2_onnx_models/sam2_hiera_large_image_encoder.onnx \ + --output models/sam2/sam2_onnx_models/sam2_hiera_large_image_encoder_fp16_gpu.onnx \ + --use_gpu --model_type sam2 --float16 +``` +The optimizer.py could be helpful when you have SAM2 onnx models that is exported by other tools. + ## Run Demo + The exported ONNX models can run on a CPU. The demo will output sam2_demo.png. ```bash curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo ``` +It is able to run demo on optimized model as well. For example, +```bash +python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp16 --use_gpu --demo +``` + ## Limitations - The exported image_decoder model does not support batch mode for now. diff --git a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py index 9b629f5c40802..8ad69dee0a763 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py @@ -13,7 +13,7 @@ from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx from sam2_demo import run_demo, show_all_images -from sam2_utils import build_sam2_model, get_decoder_onnx_path, get_image_encoder_onnx_path, setup_logger +from sam2_utils import load_sam2_model, sam2_onnx_path, setup_logger def parse_arguments(): @@ -93,6 +93,26 @@ def parse_arguments(): help="Run demo with the exported ONNX models.", ) + parser.add_argument( + "--optimize", + required=False, + default=False, + action="store_true", + help="Optimize onnx models", + ) + + parser.add_argument( + "--dtype", required=False, choices=["fp32", "fp16"], default="fp32", help="Data type for inference." + ) + + parser.add_argument( + "--use_gpu", + required=False, + default=False, + action="store_true", + help="Optimize onnx models for GPU", + ) + parser.add_argument( "--verbose", required=False, @@ -105,41 +125,36 @@ def parse_arguments(): return args -def main(): - args = parse_arguments() +def optimize_sam2_model(onnx_model_path, optimized_model_path, use_gpu: bool, float16: bool): + print(f"Optimizing {onnx_model_path} to {optimized_model_path} with float16={float16} and use_gpu={use_gpu}...") - checkpoints_dir = os.path.join(args.sam2_dir, "checkpoints") - sam2_config_dir = os.path.join(args.sam2_dir, "sam2_configs") - if not os.path.exists(args.sam2_dir): - raise FileNotFoundError(f"{args.sam2_dir} does not exist. Please specify --sam2_dir correctly.") + # Import from source directory. + transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) + if transformers_dir not in sys.path: + sys.path.insert(0, transformers_dir) + from optimizer import optimize_model - if not os.path.exists(checkpoints_dir): - raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.") + optimized_model = optimize_model(onnx_model_path, model_type="sam2", opt_level=1, use_gpu=use_gpu) + if float16: + optimized_model.convert_float_to_float16(keep_io_types=False) + optimized_model.save_model_to_file(optimized_model_path) - if not os.path.exists(sam2_config_dir): - raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.") - if not os.path.exists(os.path.join(checkpoints_dir, f"{args.model_type}.pt")): - raise FileNotFoundError( - f"{checkpoints_dir}/{args.model_type}.pt does not exist. Please download checkpoints under the directory." - ) +def main(): + args = parse_arguments() - if args.sam2_dir not in sys.path: - sys.path.append(args.sam2_dir) + sam2_model = load_sam2_model(args.sam2_dir, args.model_type, device="cpu") pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) - sam2_model = build_sam2_model(checkpoints_dir, args.model_type, device="cpu") - for component in args.components: + onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output) if component == "image_encoder": - onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type) if args.overwrite or not os.path.exists(onnx_model_path): export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False) elif component == "mask_decoder": - onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_mask_decoder.onnx") if args.overwrite or not os.path.exists(onnx_model_path): export_mask_decoder_onnx( sam2_model, @@ -155,38 +170,87 @@ def main(): not args.disable_dynamic_multimask_via_stability, ) elif component == "prompt_encoder": - onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_prompt_encoder.onnx") if args.overwrite or not os.path.exists(onnx_model_path): export_prompt_encoder_onnx(sam2_model, onnx_model_path) test_prompt_encoder_onnx(sam2_model, onnx_model_path) - elif component == "image_decoder": - onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, args.multimask_output) + else: + assert component == "image_decoder" if args.overwrite or not os.path.exists(onnx_model_path): export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output) test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output) + suffix = "" + convert_to_fp16 = args.dtype == "fp16" + if args.optimize: + suffix = f"_{args.dtype}_" + ("gpu" if args.use_gpu else "cpu") + for component in args.components: + onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output) + optimized_model_path = sam2_onnx_path( + args.output_dir, args.model_type, component, args.multimask_output, suffix + ) + optimize_sam2_model(onnx_model_path, optimized_model_path, convert_to_fp16, args.use_gpu) + if args.demo: # Export required ONNX models for demo if not already exported. - onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type) - if not os.path.exists(onnx_model_path): - export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) + image_encoder_onnx_path = sam2_onnx_path( + args.output_dir, args.model_type, "image_encoder", args.multimask_output + ) + if not os.path.exists(image_encoder_onnx_path): + export_image_encoder_onnx(sam2_model, image_encoder_onnx_path, args.dynamic_batch_axes, args.verbose) + + image_decoder_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", False) + if not os.path.exists(image_decoder_onnx_path): + export_decoder_onnx(sam2_model, image_decoder_onnx_path, False) + + image_decoder_multi_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", True) + if not os.path.exists(image_decoder_multi_onnx_path): + export_decoder_onnx(sam2_model, image_decoder_multi_onnx_path, True) + + dtype = torch.float32 if args.dtype == "fp32" else torch.float16 + if suffix: + optimized_image_encoder_onnx_path = image_encoder_onnx_path.replace(".onnx", f"{suffix}.onnx") + if not os.path.exists(optimized_image_encoder_onnx_path): + optimize_sam2_model( + image_encoder_onnx_path, optimized_image_encoder_onnx_path, convert_to_fp16, args.use_gpu + ) - onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, True) - if not os.path.exists(onnx_model_path): - export_decoder_onnx(sam2_model, onnx_model_path, True) + optimized_image_decoder_onnx_path = image_decoder_onnx_path.replace(".onnx", f"{suffix}.onnx") + if not os.path.exists(optimized_image_decoder_onnx_path): + optimize_sam2_model( + image_decoder_onnx_path, optimized_image_decoder_onnx_path, convert_to_fp16, args.use_gpu + ) - onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, False) - if not os.path.exists(onnx_model_path): - export_decoder_onnx(sam2_model, onnx_model_path, False) + optimized_image_decoder_multi_onnx_path = image_decoder_multi_onnx_path.replace(".onnx", f"{suffix}.onnx") + if not os.path.exists(optimized_image_decoder_multi_onnx_path): + optimize_sam2_model( + image_decoder_multi_onnx_path, + optimized_image_decoder_multi_onnx_path, + convert_to_fp16, + args.use_gpu, + ) - ort_image_files = run_demo(checkpoints_dir, args.model_type, engine="ort", onnx_directory=args.output_dir) + # Use optimized models to run demo. + image_encoder_onnx_path = optimized_image_encoder_onnx_path + image_decoder_onnx_path = optimized_image_decoder_onnx_path + image_decoder_multi_onnx_path = optimized_image_decoder_multi_onnx_path + + ort_image_files = run_demo( + args.sam2_dir, + args.model_type, + engine="ort", + dtype=dtype, + image_encoder_onnx_path=image_encoder_onnx_path, + image_decoder_onnx_path=image_decoder_onnx_path, + image_decoder_multi_onnx_path=image_decoder_multi_onnx_path, + use_gpu=args.use_gpu, + ) print("demo output files for ONNX Runtime:", ort_image_files) # Get results from torch engine to compare. - torch_image_files = run_demo(checkpoints_dir, args.model_type, engine="torch", onnx_directory=args.output_dir) + torch_image_files = run_demo(args.sam2_dir, args.model_type, engine="torch", dtype=dtype, use_gpu=args.use_gpu) print("demo output files for PyTorch:", torch_image_files) - show_all_images(ort_image_files, torch_image_files) + show_all_images(ort_image_files, torch_image_files, suffix) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py index e2cd93ae2157d..9533e2652f8a5 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import os +from typing import Union import matplotlib.image as mpimg import matplotlib.pyplot as plt @@ -12,7 +13,9 @@ from PIL import Image from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor -from sam2_utils import build_sam2_model +from sam2_utils import load_sam2_model + +import onnxruntime def show_mask(mask, ax, random_color=False, borders=True): @@ -88,46 +91,72 @@ def show_masks( def get_predictor( - checkpoint_dir: str, - device: torch.device, + sam2_dir: str, + device: Union[str, torch.device], + dtype: torch.dtype, model_type="sam2_hiera_large", engine="torch", - onnx_directory="sam2_onnx_models", + image_encoder_onnx_path: str = "", + image_decoder_onnx_path: str = "", + image_decoder_multi_onnx_path: str = "", + provider: str = "CUDAExecutionProvider", ): - sam2_model = build_sam2_model(checkpoint_dir, model_type, device=device) + sam2_model = load_sam2_model(sam2_dir, model_type, device=device) if engine == "torch": predictor = SAM2ImagePredictor(sam2_model) else: - predictor = SAM2ImageOnnxPredictor(sam2_model, onnx_directory=onnx_directory, model_type=model_type) + predictor = SAM2ImageOnnxPredictor( + sam2_model, + image_encoder_onnx_path=image_encoder_onnx_path, + image_decoder_onnx_path=image_decoder_onnx_path, + image_decoder_multi_onnx_path=image_decoder_multi_onnx_path, + provider=provider, + device=device, + onnx_dtype=dtype, + ) return predictor def run_demo( - checkpoint_dir: str, - model_type="sam2_hiera_large", - engine="torch", - onnx_directory="sam2_onnx_models", - enable_batch=False, + sam2_dir: str, + model_type: str = "sam2_hiera_large", + engine: str = "torch", + dtype: torch.dtype = torch.float32, + image_encoder_onnx_path: str = "", + image_decoder_onnx_path: str = "", + image_decoder_multi_onnx_path: str = "", + use_gpu: bool = True, + enable_batch: bool = False, ): - use_gpu = torch.cuda.is_available() - device = torch.device("cuda" if use_gpu else "cpu") - if use_gpu: - if engine == "torch": - # Turn on tfloat32 for Ampere GPUs. - if torch.cuda.get_device_properties(0).major >= 8: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - elif engine == "ort": - import onnxruntime + assert torch.cuda.is_available() + assert "CUDAExecutionProvider" in onnxruntime.get_available_providers() + provider = "CUDAExecutionProvider" + else: + provider = "CPUExecutionProvider" - assert use_gpu == ("CUDAExecutionProvider" in onnxruntime.get_available_providers()) + device = torch.device("cuda" if use_gpu else "cpu") + + if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8: + # Turn on tfloat32 for Ampere GPUs. + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True np.random.seed(3) image = Image.open("truck.jpg") image = np.array(image.convert("RGB")) - predictor = get_predictor(checkpoint_dir, device, model_type, engine, onnx_directory=onnx_directory) + predictor = get_predictor( + sam2_dir, + device, + dtype, + model_type, + engine, + image_encoder_onnx_path, + image_decoder_onnx_path, + image_decoder_multi_onnx_path, + provider=provider, + ) predictor.set_image(image) prefix = f"sam2_demo_{engine}_" @@ -271,7 +300,7 @@ def run_demo( return image_files -def show_all_images(left_images, right_images): +def show_all_images(left_images, right_images, suffix=""): # Show images in two rows since display screen is horizontal in most cases. fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80)) for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images)): @@ -289,5 +318,5 @@ def show_all_images(left_images, right_images): axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0]) plt.tight_layout() - plt.savefig("sam2_demo.png", format="png", bbox_inches="tight", dpi=1000) + plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000) plt.show() diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py index 36b87f0ffbd90..363b5daf461a4 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py @@ -11,7 +11,7 @@ from PIL.Image import Image from sam2.modeling.sam2_base import SAM2Base from sam2.sam2_image_predictor import SAM2ImagePredictor -from sam2_utils import decoder_shape_dict, encoder_shape_dict, get_decoder_onnx_path, get_image_encoder_onnx_path +from sam2_utils import decoder_shape_dict, encoder_shape_dict from onnxruntime import InferenceSession from onnxruntime.transformers.io_binding_helper import CudaSession @@ -33,12 +33,16 @@ def create_ort_session( providers = [(provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] - print(f"Using providers: {providers}") + logger.info("Using providers: %s", providers) return InferenceSession(onnx_path, session_options, providers=providers) def create_session( - onnx_path: str, session_options=None, provider="CUDAExecutionProvider", device="cuda", enable_cuda_graph=False + onnx_path: str, + session_options=None, + provider="CUDAExecutionProvider", + device: Union[str, torch.device] = "cuda", + enable_cuda_graph=False, ) -> CudaSession: ort_session = create_ort_session( onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True @@ -51,8 +55,11 @@ class SAM2ImageOnnxPredictor(SAM2ImagePredictor): def __init__( self, sam_model: SAM2Base, - onnx_directory: str = "sam2_onnx_models", - model_type: str = "sam2_hiera_large", + image_encoder_onnx_path: str = "", + image_decoder_onnx_path: str = "", + image_decoder_multi_onnx_path: str = "", + provider: str = "CUDAExecutionProvider", + device: Union[str, torch.device] = "cuda", onnx_dtype: torch.dtype = torch.float32, mask_threshold=0.0, max_hole_area=0.0, @@ -76,19 +83,11 @@ def __init__( sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area ) - print(self.device) - if torch.cuda.is_available(): - provider = "CUDAExecutionProvider" - device = "cuda" - else: - provider = "CPUExecutionProvider" - device = "cpu" + logger.debug("self.device=%s, device=%s", self.device, device) # This model is exported by image_encoder.py. - onnx_path = get_image_encoder_onnx_path(onnx_directory, model_type) - self.encoder_session = create_session( - onnx_path, + image_encoder_onnx_path, session_options=None, provider=provider, device=device, @@ -97,9 +96,8 @@ def __init__( self.onnx_dtype = onnx_dtype # This model is exported by image_decoder.py. It outputs only one mask. - onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=False) self.decoder_session = create_session( - onnx_path, + image_decoder_onnx_path, session_options=None, provider=provider, device=device, @@ -107,9 +105,8 @@ def __init__( ) # This model is exported by image_decoder.py. It outputs multiple (3) masks. - onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=True) self.decoder_session_multi_out = create_session( - onnx_path, + image_decoder_multi_onnx_path, session_options=None, provider=provider, device=device, @@ -253,20 +250,20 @@ def _predict( image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0) if mask_input is None: - input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float, device=self.device) - has_input_masks = torch.zeros(num_labels, dtype=torch.float, device=self.device) + input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=self.onnx_dtype, device=self.device) + has_input_masks = torch.zeros(num_labels, dtype=self.onnx_dtype, device=self.device) else: input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1) - has_input_masks = torch.ones(num_labels, dtype=torch.float, device=self.device) + has_input_masks = torch.ones(num_labels, dtype=self.onnx_dtype, device=self.device) feed_dict = { - "image_embeddings": image_embeddings.contiguous().to(dtype=torch.float32).to(self.device), - "image_features_0": image_features_0.contiguous().to(dtype=torch.float32).to(self.device), - "image_features_1": image_features_1.contiguous().to(dtype=torch.float32).to(self.device), - "point_coords": concat_points[0].to(dtype=torch.float32).to(self.device), + "image_embeddings": image_embeddings.contiguous().to(dtype=self.onnx_dtype).to(self.device), + "image_features_0": image_features_0.contiguous().to(dtype=self.onnx_dtype).to(self.device), + "image_features_1": image_features_1.contiguous().to(dtype=self.onnx_dtype).to(self.device), + "point_coords": concat_points[0].to(dtype=self.onnx_dtype).to(self.device), "point_labels": concat_points[1].to(dtype=torch.int32).to(self.device), - "input_masks": input_masks.to(dtype=torch.float32).to(self.device), - "has_input_masks": has_input_masks.to(dtype=torch.float32).to(self.device), + "input_masks": input_masks.to(dtype=self.onnx_dtype).to(self.device), + "has_input_masks": has_input_masks.to(dtype=self.onnx_dtype).to(self.device), "original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device), } diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py index cf88eb42213f2..4ec4ccc274291 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py @@ -4,6 +4,8 @@ # -------------------------------------------------------------------------- import logging import os +import sys +from typing import List, Mapping, Union import torch from sam2.build_sam import build_sam2 @@ -12,7 +14,7 @@ logger = logging.getLogger(__name__) -def get_model_cfg(model_type) -> str: +def _get_model_cfg(model_type) -> str: assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"] if model_type == "sam2_hiera_tiny": model_cfg = "sam2_hiera_t.yaml" @@ -25,22 +27,45 @@ def get_model_cfg(model_type) -> str: return model_cfg -def build_sam2_model(checkpoint_dir: str, model_type: str, device="cpu") -> SAM2Base: - sam2_checkpoint = os.path.join(checkpoint_dir, f"{model_type}.pt") - model_cfg = get_model_cfg(model_type) - sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device) - return sam2_model +def load_sam2_model(sam2_dir, model_type, device: Union[str, torch.device] = "cpu") -> SAM2Base: + checkpoints_dir = os.path.join(sam2_dir, "checkpoints") + sam2_config_dir = os.path.join(sam2_dir, "sam2_configs") + if not os.path.exists(sam2_dir): + raise FileNotFoundError(f"{sam2_dir} does not exist. Please specify --sam2_dir correctly.") + + if not os.path.exists(checkpoints_dir): + raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.") + + if not os.path.exists(sam2_config_dir): + raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.") + checkpoint_path = os.path.join(checkpoints_dir, f"{model_type}.pt") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"{checkpoint_path} does not exist. Please download checkpoints under the directory.") -def get_decoder_onnx_path(dir: str, model_type, multimask_output) -> str: - return os.path.join(dir, f"{model_type}_decoder" + ("_multi" if multimask_output else "") + ".onnx") + if sam2_dir not in sys.path: + sys.path.append(sam2_dir) + + model_cfg = _get_model_cfg(model_type) + sam2_model = build_sam2(model_cfg, checkpoint_path, device=device) + return sam2_model -def get_image_encoder_onnx_path(dir: str, model_type) -> str: - return os.path.join(dir, f"{model_type}_image_encoder.onnx") +def sam2_onnx_path(output_dir, model_type, component, multimask_output=False, suffix=""): + if component == "image_encoder": + return os.path.join(output_dir, f"{model_type}_image_encoder{suffix}.onnx") + elif component == "mask_decoder": + return os.path.join(output_dir, f"{model_type}_mask_decoder{suffix}.onnx") + elif component == "prompt_encoder": + return os.path.join(output_dir, f"{model_type}_prompt_encoder{suffix}.onnx") + else: + assert component == "image_decoder" + return os.path.join( + output_dir, f"{model_type}_image_decoder" + ("_multi" if multimask_output else "") + f"{suffix}.onnx" + ) -def encoder_shape_dict(batch_size: int, height: int, width: int): +def encoder_shape_dict(batch_size: int, height: int, width: int) -> Mapping[str, List[int]]: assert height == 1024 and width == 1024, "Only 1024x1024 images are supported." return { "image": [batch_size, 3, height, width], @@ -109,7 +134,7 @@ def compare_tensors_with_tolerance( def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor: - image = torch.randn(batch_size, 3, image_height, image_width).cpu() + image = torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32).cpu() return image diff --git a/onnxruntime/python/tools/transformers/onnx_model_sam2.py b/onnxruntime/python/tools/transformers/onnx_model_sam2.py new file mode 100644 index 0000000000000..ac608fb509a81 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_sam2.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import logging +from typing import Optional + +from fusion_attention_sam2 import FusionMultiHeadAttentionSam2 +from fusion_layernorm import FusionLayerNormalizationNCHW +from fusion_options import FusionOptions +from import_utils import is_installed +from onnx import ModelProto +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class Sam2OnnxModel(BertOnnxModel): + def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): + """Initialize SAM2 ONNX Model. + + Args: + model (ModelProto): the ONNX model + num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically). + hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically). + """ + assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) + + super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + + def postprocess(self): + self.prune_graph() + self.remove_unused_constant() + + def fuse_layer_norm(self): + super().fuse_layer_norm() + + fusion = FusionLayerNormalizationNCHW(self) + fusion.apply() + + def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): + mha_fusion = FusionMultiHeadAttentionSam2(self, self.hidden_size, self.num_heads) + mha_fusion.apply() + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 12 + progress_bar = tqdm.tqdm(range(steps), initial=0, desc="sam2 fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): + if (options is not None) and not options.enable_shape_inference: + self.disable_shape_inference() + + self.utils.remove_identity_nodes() + if progress_bar: + progress_bar.update(1) + + # Remove cast nodes that having same data type of input and output based on symbolic shape inference. + self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_layer_norm: + self.fuse_layer_norm() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_gelu: + self.fuse_gelu() + if progress_bar: + progress_bar.update(1) + + self.fuse_reshape() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_attention: + self.fuse_multi_head_attention(options) + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_skip_layer_norm: + self.fuse_skip_layer_norm() + if progress_bar: + progress_bar.update(1) + + self.fuse_shape() + if progress_bar: + progress_bar.update(1) + + # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. + self.utils.remove_useless_reshape_nodes() + if progress_bar: + progress_bar.update(1) + + if (options is None) or options.enable_bias_skip_layer_norm: + # Fuse SkipLayerNormalization and Add Bias before it. + self.fuse_add_bias_skip_layer_norm() + if progress_bar: + progress_bar.update(1) + + if options is not None and options.enable_gelu_approximation: + self.gelu_approximation() + if progress_bar: + progress_bar.update(1) + + self.postprocess() + if progress_bar: + progress_bar.update(1) + + logger.info(f"opset version: {self.get_opset_version()}") + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. + """ + op_count = {} + ops = [ + "MultiHeadAttention", + "LayerNormalization", + "SkipLayerNormalization", + ] + + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators:{op_count}") + return op_count diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 06264b426d1e5..933bd785dc00d 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -36,6 +36,7 @@ from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_phi import PhiOnnxModel +from onnx_model_sam2 import Sam2OnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel @@ -53,17 +54,18 @@ "bert_tf": (BertOnnxModelTF, "tf2onnx", 0), "bert_keras": (BertOnnxModelKeras, "keras2onnx", 0), "clip": (ClipOnnxModel, "pytorch", 1), # Clip in Stable Diffusion + "conformer": (ConformerOnnxModel, "pytorch", 1), "gpt2": (Gpt2OnnxModel, "pytorch", 1), "gpt2_tf": (Gpt2OnnxModel, "tf2onnx", 0), # might add a class for GPT2OnnxModel for TF later. "gpt_neox": (BertOnnxModel, "pytorch", 0), # GPT-NeoX + "phi": (PhiOnnxModel, "pytorch", 0), + "sam2": (Sam2OnnxModel, "pytorch", 1), "swin": (BertOnnxModel, "pytorch", 1), "tnlr": (TnlrOnnxModel, "pytorch", 1), "t5": (T5OnnxModel, "pytorch", 2), "unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), - "conformer": (ConformerOnnxModel, "pytorch", 1), - "phi": (PhiOnnxModel, "pytorch", 0), } @@ -235,7 +237,7 @@ def optimize_by_fusion( Returns: object of an optimizer class. """ - if model_type not in ["bert", "swin", "unet", "vae", "clip"] and (num_heads == 0 or hidden_size == 0): + if model_type not in ["bert", "swin", "unet", "vae", "clip", "sam2"] and (num_heads == 0 or hidden_size == 0): logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") if model_type not in MODEL_TYPES: