unet fusion for stable diffusion webui #19227

Merged: 5 commits, Jan 23, 2024
14 changes: 10 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -129,6 +129,9 @@ def __init__(
self.num_heads_warning = True
self.hidden_size_warning = True

self.shape_infer = None
self.shape_infer_done = False

def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]:
"""
Detect num_heads and hidden_size from Concat node in the following subgraph:
@@ -202,12 +205,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]
return num_heads, hidden_size

def get_add_qk_str(self, add_qk: NodeProto):
shape_infer = self.model.infer_runtime_shape(update=True)
if shape_infer is None:
if not self.shape_infer_done:
self.shape_infer = self.model.infer_runtime_shape(update=True)
self.shape_infer_done = True

if self.shape_infer is None:
return None

input_0_shape = shape_infer.get_edge_shape(add_qk.input[0])
input_1_shape = shape_infer.get_edge_shape(add_qk.input[1])
input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])

if input_0_shape is None or input_1_shape is None:
logger.debug(f"one of the inputs of {add_qk} is None")
166 changes: 152 additions & 14 deletions onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -28,10 +28,19 @@ def __init__(
enable_packed_qkv: bool,
enable_packed_kv: bool,
):
super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"])
super().__init__(
model,
"Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention",
["LayerNormalization"],
)
self.hidden_size = hidden_size
self.num_heads = num_heads
self.is_cross_attention = is_cross_attention

# Note: packing the Q/K/V or K/V weights into one tensor makes it harder to update initializers for LoRA.
# To support LoRA, it is better to use separate Q, K and V inputs in offline optimization,
# and let the CUDA operator pre-pack those tensors into its preferred format based on available kernels.
# In this way, we can support LoRA and get optimal performance at the same time.
self.enable_packed_qkv = enable_packed_qkv
self.enable_packed_kv = enable_packed_kv
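The comment in this hunk captures the trade-off behind `enable_packed_qkv` / `enable_packed_kv`: once projection weights are packed into a single initializer, a LoRA delta that targets only one projection can no longer be merged with a simple in-place update. A rough NumPy illustration (all shapes and names below are made up for the example):

```python
import numpy as np

hidden, rank = 320, 4
q_w = np.random.randn(hidden, hidden).astype(np.float16)
k_w = np.random.randn(hidden, hidden).astype(np.float16)
v_w = np.random.randn(hidden, hidden).astype(np.float16)

# A LoRA update that targets only the Q projection: W_q' = W_q + B @ A
lora_a = np.random.randn(rank, hidden).astype(np.float16)
lora_b = np.random.randn(hidden, rank).astype(np.float16)
delta_q = lora_b @ lora_a

# Separate Q/K/V initializers: a trivial in-place update.
q_updated = q_w + delta_q

# Packed QKV initializer: the Q slice must be located and re-packed,
# which requires knowing the packing layout used by the fused node.
qkv_w = np.concatenate([q_w, k_w, v_w], axis=1)
qkv_w[:, :hidden] += delta_q
```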

@@ -170,9 +179,7 @@ def create_attention_node(
return None

# Sometimes weights are stored in fp16
if q_weight.data_type == 10:
logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
return None
float_type = q_weight.data_type

qw = NumpyHelper.to_array(q_weight)
kw = NumpyHelper.to_array(k_weight)
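Previously the fusion bailed out on fp16 weights (`data_type == 10`, i.e. `TensorProto.FLOAT16`); with this change the original element type is kept in `float_type` and reused for the fused initializers in the hunks below instead of hard-coding `TensorProto.FLOAT`. A small sketch of building an initializer that preserves the source dtype, using plain `onnx.numpy_helper` rather than the `Fusion.add_initializer` wrapper from the diff:

```python
import numpy as np
from onnx import TensorProto, numpy_helper


def make_weight_initializer(name: str, weight: np.ndarray, float_type: int):
    """Create an initializer whose element type matches the original Q/K/V weights.

    float_type is expected to be TensorProto.FLOAT or TensorProto.FLOAT16.
    """
    np_dtype = np.float16 if float_type == TensorProto.FLOAT16 else np.float32
    return numpy_helper.from_array(weight.astype(np_dtype), name=name)
```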
@@ -212,7 +219,7 @@
matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
self.add_initializer(
name=matmul_node_name + "_weight",
data_type=TensorProto.FLOAT,
data_type=float_type,
dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
vals=qkv_weight,
)
@@ -235,8 +242,11 @@

reshape_node = helper.make_node(
"Reshape",
inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
outputs=[attention_node_name + "_input"],
inputs=[
matmul_node_name + "_out",
matmul_node_name + "_reshape_shape",
],
outputs=[attention_node_name + "_qkv_input"],
name=matmul_node_name + "_reshape",
)
self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
@@ -251,7 +261,7 @@

self.add_initializer(
name=attention_node_name + "_qkv_weight",
data_type=TensorProto.FLOAT,
data_type=float_type,
dims=[qw_in_size, qkv_weight_dim],
vals=qkv_weight,
)
@@ -280,7 +290,7 @@
matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
self.add_initializer(
name=matmul_node_name + "_weight",
data_type=TensorProto.FLOAT,
data_type=float_type,
dims=[kv_weight.shape[0], kv_weight.shape[1]],
vals=kv_weight,
)
@@ -303,8 +313,11 @@

reshape_node = helper.make_node(
"Reshape",
inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
outputs=[k_matmul.output[0]],
inputs=[
matmul_node_name + "_out",
matmul_node_name + "_reshape_shape",
],
outputs=[attention_node_name + "_kv_input"],
name=matmul_node_name + "_reshape",
)
self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
@@ -317,7 +330,7 @@

self.add_initializer(
name=attention_node_name + "_qkv_bias",
data_type=TensorProto.FLOAT,
data_type=float_type,
dims=[qkv_bias_dim],
vals=qkv_bias,
)
@@ -330,7 +343,7 @@
attention_node_name + "_qkv_bias",
]
else:
attention_inputs = [attention_node_name + "_input"]
attention_inputs = [attention_node_name + "_qkv_input"]
else:
if not self.enable_packed_kv:
attention_inputs = [
@@ -342,7 +355,7 @@
else:
attention_inputs = [
q_matmul.output[0],
k_matmul.output[0],
attention_node_name + "_kv_input",
]

attention_node = helper.make_node(
@@ -839,6 +852,9 @@ def create_attention_node_lora(
return attention_node

def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node):
return

node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)

# In SD 1.5, for self attention, LayerNorm has parent Reshape
@@ -1168,3 +1184,125 @@ def match_lora_path(
return (lora_mul_node, lora_matmul_1_node)

return None

def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node):
"""Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension"""
entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0])
if entry_path is None:
entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0])
if entry_path is None:
return False
_cast, node_before_layernorm = entry_path

root_input = node_before_layernorm.output[0]

children_nodes = input_name_to_nodes[root_input]
skip_add = None
for node in children_nodes:
if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet
skip_add = node
break
if skip_add is None:
return False

match_qkv = self.match_qkv_a1111(root_input, skip_add)
if match_qkv is None:
return False

(
reshape_qkv,
transpose_qkv,
reshape_q,
matmul_q,
matmul_k,
matmul_v,
) = match_qkv

cast_q = self.model.match_parent(matmul_q, "Cast", 0)
cast_k = self.model.match_parent(matmul_k, "Cast", 0)
cast_v = self.model.match_parent(matmul_v, "Cast", 0)
if not (
cast_q is not None
and cast_k is not None
and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k)
and cast_k == cast_v
):
return False

if cast_q.input[0] != normalize_node.output[0]:
return False

attention_last_node = reshape_qkv

q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
if q_num_heads <= 0:
logger.debug("fuse_attention: failed to detect num_heads")
return False

q_hidden_size = self.get_hidden_size(normalize_node)

# The number of heads is the same for all paths, so q_num_heads is used when creating the attention node
new_node = self.create_attention_node(
matmul_q,
matmul_k,
matmul_v,
q_num_heads,
q_hidden_size,
input=matmul_q.input[0],
output=attention_last_node.output[0],
)
if new_node is None:
return False

self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name

self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

# Use prune graph to remove nodes since they are shared by all attention nodes.
self.prune_graph = True
return True

def match_qkv_a1111(self, root_input, skip_add):
"""Match Q, K and V paths exported by A1111 (stable diffusion webui) extension"""
another_input = 1 if skip_add.input[0] == root_input else 0
qkv_nodes = self.model.match_parent_path(
skip_add,
["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"],
[another_input, None, None, 0, 0, 0],
)

if qkv_nodes is None:
return None

(_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes

v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
if v_nodes is None:
logger.debug("fuse_attention: failed to match v path")
return None
(_, _, _, matmul_v) = v_nodes

qk_nodes = self.model.match_parent_path(
einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None]
)
if qk_nodes is not None:
(_, _, _softmax_qk, _, einsum_qk) = qk_nodes
else:
logger.debug("fuse_attention: failed to match qk path")
return None

q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
return None
(_, _transpose_q, reshape_q, matmul_q) = q_nodes

k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
if k_nodes is None:
logger.debug("fuse_attention: failed to match k path")
return None

(_, _, _, matmul_k) = k_nodes

return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
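For context, `match_qkv_a1111` targets the Einsum-based attention that the A1111 export produces: one Einsum for the attention scores (followed by Mul, Softmax and Cast nodes) and a second Einsum for the context, wrapped in Reshape/Transpose pairs that fold the heads into the batch dimension. A rough NumPy rendition of what that matched subgraph computes (equations and shapes are assumed from a typical A1111/diffusers export, not taken from this PR):

```python
import numpy as np


def a1111_attention(q, k, v, num_heads):
    """Approximate the Einsum attention subgraph matched above.

    q, k, v: (batch, seq, hidden) outputs of the Q/K/V MatMul nodes.
    """
    b, s_q, hidden = q.shape
    head = hidden // num_heads

    def split_heads(x):
        # Reshape -> Transpose -> Reshape: fold heads into the batch dimension.
        bx, sx, _ = x.shape
        return x.reshape(bx, sx, num_heads, head).transpose(0, 2, 1, 3).reshape(bx * num_heads, sx, head)

    qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
    scale = 1.0 / np.sqrt(head)
    scores = np.einsum("bid,bjd->bij", qh, kh) * scale  # first Einsum + Mul
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)                # Softmax (Cast fp16<->fp32 in the graph)
    ctx = np.einsum("bij,bjd->bid", probs, vh)           # second Einsum
    # Reshape -> Transpose -> Reshape back to (batch, seq, hidden).
    return ctx.reshape(b, num_heads, s_q, head).transpose(0, 2, 1, 3).reshape(b, s_q, hidden)
```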
18 changes: 12 additions & 6 deletions onnxruntime/python/tools/transformers/fusion_embedlayer.py
@@ -28,7 +28,9 @@ def __init__(self, model: OnnxModel, description: str = "no mask"):
description,
)
self.utils = FusionUtils(model)
self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True)
self.shape_infer = None
self.shape_infer_done = False

# The following will be reset in each fuse call of FusionEmbedLayerNormalization
self.attention = None
self.embed_node = None
@@ -329,9 +331,13 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
position_ids = position_embedding_gather.input[1]

if self.shape_infer_helper is not None:
input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids)
position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids)
if not self.shape_infer_done:
self.shape_infer = self.model.infer_runtime_shape(update=True)
self.shape_infer_done = True

if self.shape_infer is not None:
input_ids_shape = self.shape_infer.get_edge_shape(input_ids)
position_ids_shape = self.shape_infer.get_edge_shape(position_ids)
assert input_ids_shape and position_ids_shape
if not (
len(input_ids_shape) == 2
@@ -345,11 +351,11 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
)
return False

if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids):
if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
logger.info(
"Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format(
input_ids_shape,
self.shape_infer_helper.get_edge_shape(segment_ids),
self.shape_infer.get_edge_shape(segment_ids),
)
)
return False
@@ -32,7 +32,7 @@ def get_dimensions(self, input_name: str) -> Union[int, None]:
return self.get_dimensions_from_tensor_proto(graph_input)

if not self.shape_infer_done:
self.shape_infer = self.model.infer_runtime_shape({}, update=True)
self.shape_infer = self.model.infer_runtime_shape(update=True)
self.shape_infer_done = True

if self.shape_infer is not None:
15 changes: 13 additions & 2 deletions onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
@@ -7,7 +7,8 @@
from typing import List

from fusion_base import Fusion
from onnx import TensorProto, helper, numpy_helper
from fusion_utils import FusionUtils
from onnx import helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)
@@ -19,6 +20,7 @@ class FusionNhwcConv(Fusion):
def __init__(self, model: OnnxModel, update_weight=False):
super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
self.update_weight = update_weight
self.fusion_utils = FusionUtils(model)

def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
"""Append a Transpose node after an input"""
@@ -49,14 +51,23 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node):
if len(weight.shape) != 4:
return

dtype = self.model.get_dtype(nhwc_conv_input)
if not (dtype is not None and weight_tensor.data_type == dtype):
cast_node = self.fusion_utils.add_cast_node(
input_name=nhwc_conv_input,
to_type=weight_tensor.data_type,
output_name_to_node=output_name_to_node,
)
nhwc_conv_input = cast_node.output[0]

if self.update_weight:
# Transpose weights from NCHW to NHWC
weight = weight.transpose(0, 2, 3, 1)

weight_name = node_name + "_weight_NHWC"
self.add_initializer(
name=weight_name,
data_type=TensorProto.FLOAT,
data_type=weight_tensor.data_type,
dims=list(weight.shape),
vals=weight,
)
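The new dtype check above inserts a Cast when the Conv input's element type differs from the (possibly fp16) weight initializer, and the transposed NHWC weight now keeps `weight_tensor.data_type` instead of always being written as float32. A standalone sketch of the Cast insertion with plain `onnx.helper` (the `FusionUtils.add_cast_node` call in the diff presumably also wires the new node into the graph and into `output_name_to_node`; the helper below is hypothetical):

```python
from onnx import helper


def cast_input_to(input_name: str, to_type: int, name_prefix: str = "Cast_NhwcConv"):
    """Create a Cast node so the consumer sees input_name in the weight's element type."""
    cast_output = f"{input_name}_cast_to_{to_type}"
    cast_node = helper.make_node(
        "Cast",
        inputs=[input_name],
        outputs=[cast_output],
        name=f"{name_prefix}_{input_name}",
        to=to_type,  # e.g. 10 (FLOAT16) when the Conv weight is stored in fp16
    )
    return cast_node, cast_output
```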
8 changes: 4 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_shape.py
@@ -29,12 +29,12 @@ def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
return None

def get_dimensions(self, input_name: str) -> Union[int, None]:
graph_input = self.model.find_graph_input(input_name)
if graph_input:
return self.get_dimensions_from_tensor_proto(graph_input)
shape = self.model.get_shape(input_name)
if shape is not None:
return len(shape)

if not self.shape_infer_done:
self.shape_infer = self.model.infer_runtime_shape({}, update=True)
self.shape_infer = self.model.infer_runtime_shape(update=True)
self.shape_infer_done = True

if self.shape_infer is not None: