refine unet attention fusion
tianleiwu committed Jan 22, 2024
1 parent 8577e25 commit f8a9cf5
Showing 1 changed file with 50 additions and 48 deletions.
98 changes: 50 additions & 48 deletions onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -1206,60 +1206,62 @@ def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_no
             return False

         match_qkv = self.match_qkv_a1111(root_input, skip_add)
-        if match_qkv is not None:
-            (
-                is_torch2,
-                reshape_qkv,
-                transpose_qkv,
-                reshape_q,
-                matmul_q,
-                matmul_k,
-                matmul_v,
-            ) = match_qkv
-
-            cast_q = self.model.match_parent(matmul_q, "Cast", 0)
-            cast_k = self.model.match_parent(matmul_k, "Cast", 0)
-            cast_v = self.model.match_parent(matmul_v, "Cast", 0)
-            if not (
-                cast_q is not None
-                and cast_k is not None
-                and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k)
-                and cast_k == cast_v
-            ):
-                return False
-            if cast_q.input[0] != normalize_node.output[0]:
-                return False
+        if match_qkv is None:
+            return False

-            attention_last_node = reshape_qkv
+        (
+            reshape_qkv,
+            transpose_qkv,
+            reshape_q,
+            matmul_q,
+            matmul_k,
+            matmul_v,
+        ) = match_qkv

+        cast_q = self.model.match_parent(matmul_q, "Cast", 0)
+        cast_k = self.model.match_parent(matmul_k, "Cast", 0)
+        cast_v = self.model.match_parent(matmul_v, "Cast", 0)
+        if not (
+            cast_q is not None
+            and cast_k is not None
+            and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k)
+            and cast_k == cast_v
+        ):
+            return False

-            q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
-            if q_num_heads <= 0:
-                logger.debug("fuse_attention: failed to detect num_heads")
-                return False
+        if cast_q.input[0] != normalize_node.output[0]:
+            return False

-            q_hidden_size = self.get_hidden_size(normalize_node)
+        attention_last_node = reshape_qkv

-            # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
-            new_node = self.create_attention_node(
-                matmul_q,
-                matmul_k,
-                matmul_v,
-                q_num_heads,
-                q_hidden_size,
-                input=matmul_q.input[0],
-                output=attention_last_node.output[0],
-            )
-            if new_node is None:
-                return False
+        q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
+        if q_num_heads <= 0:
+            logger.debug("fuse_attention: failed to detect num_heads")
+            return False

-            self.nodes_to_add.append(new_node)
-            self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+        q_hidden_size = self.get_hidden_size(normalize_node)

+        # The number of heads is the same for all paths, so pass q_num_heads when creating the attention node.
+        new_node = self.create_attention_node(
+            matmul_q,
+            matmul_k,
+            matmul_v,
+            q_num_heads,
+            q_hidden_size,
+            input=matmul_q.input[0],
+            output=attention_last_node.output[0],
+        )
+        if new_node is None:
+            return False

+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

-            self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
+        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

-            # Use prune graph to remove nodes since they are shared by all attention nodes.
-            self.prune_graph = True
-            return True
+        # Use prune graph to remove nodes since they are shared by all attention nodes.
+        self.prune_graph = True
+        return True

     def match_qkv_a1111(self, root_input, skip_add):
         """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension"""

@@ -1303,4 +1305,4 @@ def match_qkv_a1111(self, root_input, skip_add):

         (_, _, _, matmul_k) = k_nodes

-        return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
+        return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
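
The heart of the commit is a control-flow refactor rather than new fusion logic: the old code nested the entire fusion body under `if match_qkv is not None:`, while the new code returns early when the match fails and dedents the happy path. A minimal sketch of the pattern, with a hypothetical `match_pattern` standing in for `match_qkv_a1111`:

    from typing import Optional, Tuple

    def match_pattern(node: str) -> Optional[Tuple[str, str]]:
        # Hypothetical stand-in for match_qkv_a1111: matched nodes on success, None otherwise.
        return ("MatMul_q", "MatMul_k") if node == "attention" else None

    # Before the commit: the whole fusion body sits inside the success branch.
    def fuse_nested(node: str) -> bool:
        match = match_pattern(node)
        if match is not None:
            matmul_q, matmul_k = match
            # ... many lines of fusion logic, all one level deeper ...
            return True
        return False

    # After the commit: a guard clause handles failure first, so the body stays flat.
    def fuse_flat(node: str) -> bool:
        match = match_pattern(node)
        if match is None:
            return False
        matmul_q, matmul_k = match
        # ... the same fusion logic, without the extra indentation ...
        return True

    assert fuse_nested("attention") and fuse_flat("attention")
    assert not fuse_nested("conv") and not fuse_flat("conv")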

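The relocated Cast check encodes one real constraint from the diff: for self-attention, the Q, K, and V inputs must come from the same upstream Cast node, while for cross-attention Q is cast separately from the shared K/V source. A standalone restatement of that condition (hypothetical helper name, illustrative string node IDs):

    def casts_consistent(cast_q, cast_k, cast_v, is_cross_attention: bool) -> bool:
        # Q and K must share one Cast for self-attention; for cross-attention the
        # Q path is cast separately from the shared K/V path. K and V always share.
        if cast_q is None or cast_k is None:
            return False
        qk_relation = (cast_q != cast_k) if is_cross_attention else (cast_q == cast_k)
        return qk_relation and cast_k == cast_v

    # Self-attention: all three paths come from the same Cast node.
    assert casts_consistent("cast_0", "cast_0", "cast_0", is_cross_attention=False)
    # Cross-attention: Q is cast separately from K/V.
    assert casts_consistent("cast_q", "cast_kv", "cast_kv", is_cross_attention=True)
    assert not casts_consistent("cast_q", "cast_kv", "cast_kv", is_cross_attention=False)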
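The companion change in the second hunk drops the constant leading `True` (the old `is_torch2` slot) from the tuple returned by `match_qkv_a1111`, so callers now distinguish success from failure with an `is None` test instead of unpacking a flag. A sketch of that contract under assumed, hypothetical names and node labels:

    from typing import NamedTuple, Optional

    class QkvMatch(NamedTuple):
        # Hypothetical typed view of the matcher's new six-element result.
        reshape_qkv: str
        transpose_qkv: str
        reshape_q: str
        matmul_q: str
        matmul_k: str
        matmul_v: str

    def match_qkv(found: bool) -> Optional[QkvMatch]:
        # Returns the matched nodes, or None when the pattern is absent;
        # there is no longer a leading boolean flag in the tuple.
        if not found:
            return None
        return QkvMatch("Reshape_2", "Transpose_1", "Reshape_q", "MatMul_q", "MatMul_k", "MatMul_v")

    match = match_qkv(found=True)
    if match is None:
        raise SystemExit("pattern not found")
    # Tuple unpacking mirrors the new call site in fuse_a1111_fp16.
    reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match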