
Commit

address reviews
titaiwangms committed Apr 26, 2024
1 parent ca03ca8 commit d3912de
Showing 1 changed file with 16 additions and 6 deletions.
onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py
@@ -71,6 +71,8 @@ class AttnSizeConfig:
class AttentionRewriteRule(function_rule.FunctionRewriteRule, abc.ABC):
def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig:
if len(function.outputs) == 3:
+ # Usually the Attention related modules have 3 outputs:
+ # present_value, present_key, attn_output
present_value, _, attn_output = function.outputs

if present_value.shape is None:
raise function_rule.FunctionRewriteError(

@@ -91,6 +93,8 @@ def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig:
hidden_size=hidden_size,
)
elif any("scaled_dot_product_attention" in node.op_type for node in function):
+ # If the Attention related modules use scaled_dot_product_attention,
+ # present_value and present_key are not present in the output.
hidden_size = function.outputs[0].shape[2]

# Get head size and number of heads from the Reshape node.
# Reference:
@@ -116,10 +120,12 @@ def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig:
hidden_size=hidden_size,
)
raise function_rule.FunctionRewriteError(
"Failed to infer head size and number of heads from Reshape nodes."
"Failed to infer head size and number of heads from QKV Reshape nodes. \
Expected 4D shape in the constant node (batch_size, seq_length, num_attention_heads, head_size)."
)
raise function_rule.FunctionRewriteError(

f"Unexpected function structure, got output: {len(function.outputs)}."
f"Attenion modules should have 3 outputs or scaled_dot_product_attention node, \
got output: {len(function.outputs)} and no scaled_dot_product_attention."
)


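Note on the hunks above: the new error message spells out the assumption the size inference relies on. The target-shape constant feeding the QKV Reshape is expected to be 4D, (batch_size, seq_length, num_attention_heads, head_size), so num_attention_heads and head_size can be read off its last two entries and hidden_size recovered as their product. A minimal, self-contained sketch of that arithmetic (illustrative only; the function name is made up and this is not the helper in the file):

def infer_sizes_from_reshape_target(target_shape):
    # Illustrative only: assumes a 4D reshape target of
    # (batch_size, seq_length, num_attention_heads, head_size).
    if len(target_shape) != 4:
        raise ValueError(f"Expected a 4D reshape target, got {target_shape}")
    num_attention_heads, head_size = target_shape[2], target_shape[3]
    hidden_size = num_attention_heads * head_size
    return num_attention_heads, head_size, hidden_size

# Example: a 12-head model with 64-dim heads gives hidden_size 768.
assert infer_sizes_from_reshape_target([2, 128, 12, 64]) == (12, 64, 768)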
@@ -626,6 +632,8 @@ def phi_attention(


class MHAStableDiffusionUnetRewriteRule(AttentionRewriteRule):
"""Rewrite rule for Attention in diffusers."""

FUNCTION_KEYWORD = "Attention"
PACKAGE_NAME = "diffusers"
_version_controller = function_rule.VersionController()
@@ -635,6 +643,8 @@ def __init__(self) -> None:

@_version_controller.register_version()
def _fusion(self, function: ir.Function) -> ir.Function:
+ # Attention inputs could be 6 or 7:
+ # hidden_states, encoder_hidden_states(optional), q_weight, k_weight, v_weight, o_weight, o_bias
if len(function.inputs) != 6 and len(function.inputs) != 7:
raise function_rule.FunctionRewriteError(

f"Unexpected number of inputs. Expected 6 or 7, got {len(function.inputs)}."
@@ -661,7 +671,7 @@ def attention(
)

# NOTE: MHA does not work when Q, K, and V has the same root inputs.
- mha_output, _ = msft_op.Attention(
+ attn_output, _ = msft_op.Attention(

hidden_states,
qkv_weight,
None,
@@ -670,7 +680,7 @@ def attention(
)

# linear projection
- output = op.Add(op.MatMul(mha_output, op.Transpose(o_weight, [1, 0])), o_bias)
+ output = op.Add(op.MatMul(attn_output, op.Transpose(o_weight, [1, 0])), o_bias)
return output


def mha(

@@ -688,15 +698,15 @@ def mha(

# NOTE: Q and K needs to have the sequence length (dim 1) to use
# GQA.
- gqa_output, _, _ = msft_op.MultiHeadAttention(
+ mha_output, _, _ = msft_op.MultiHeadAttention(

q,
k,
v,
None,
None,
num_heads=attn_size_config.num_attention_heads,
)
- attn_output = op.Add(op.MatMul(gqa_output, op.Transpose(o_weight, [1, 0])), o_bias)
+ attn_output = op.Add(op.MatMul(mha_output, op.Transpose(o_weight, [1, 0])), o_bias)
return attn_output


if len(function.inputs) == 6:
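For orientation, the attention() and mha() replacement functions built in the hunks above (a fused com.microsoft Attention or MultiHeadAttention node followed by the output projection) compute roughly the following, expressed in plain NumPy. This is a sketch for illustration only: the packed qkv_weight layout, the (hidden, hidden) o_weight stored so it needs a transpose (matching the op.Transpose(o_weight, [1, 0]) in the diff), and the shapes are assumptions, not the operator specification.

import numpy as np

def attention_reference(hidden_states, qkv_weight, o_weight, o_bias, num_heads):
    # hidden_states: (batch, seq, hidden); qkv_weight: (hidden, 3 * hidden)
    # o_weight: (hidden, hidden), applied transposed; o_bias: (hidden,)
    batch, seq, hidden = hidden_states.shape
    head_size = hidden // num_heads

    qkv = hidden_states @ qkv_weight                    # (batch, seq, 3 * hidden)
    q, k, v = np.split(qkv, 3, axis=-1)

    def split_heads(x):                                 # -> (batch, num_heads, seq, head_size)
        return x.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)

    q, k, v = map(split_heads, (q, k, v))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)      # softmax over the key axis
    attn = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq, hidden)

    # Output projection, mirroring op.Add(op.MatMul(..., op.Transpose(o_weight, [1, 0])), o_bias).
    return attn @ o_weight.T + o_bias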
