Commit

Rewrite stable diffusion unet with com.microsoft attention operators (#1454)

This PR introduces attention rewriting rules for stable diffusion unet.
In summary, the scripted function _scaled_dot_product_attention is replaced
by either com.microsoft.Attention or com.microsoft.MultiHeadAttention.
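
For context, a minimal usage sketch (not part of this change) of how the onnxruntime-specific rewrite rules, including the new stable diffusion unet rule, would typically be applied to an exported UNet model. The entry-point name, its signature, and the file paths below are assumptions for illustration, not something this PR defines.

import onnx

from onnxscript.rewriter import onnxruntime as ort_rewriter

# "unet.onnx" is a hypothetical path to a torch.onnx-exported stable diffusion UNet.
model = onnx.load("unet.onnx")

# Apply the registered com.microsoft function rewrite rules; with this PR the
# scripted attention functions are fused into Attention / MultiHeadAttention nodes.
rewritten = ort_rewriter.rewrite(model)
onnx.save(rewritten, "unet_rewritten.onnx")
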
titaiwangms authored Apr 27, 2024
1 parent 19d2498 commit b1d9a81
Showing 37 changed files with 264 additions and 35 deletions.
1 change: 1 addition & 0 deletions onnxscript/rewriter/onnxruntime/transformers/__init__.py
@@ -12,6 +12,7 @@
multihead_attention.GQALlama2RewriteRule,
multihead_attention.GQALlamaSdpa2RewriteRule,
multihead_attention.AttnPhi15RewriteRule,
multihead_attention.MHAStableDiffusionUnetRewriteRule,
layernorm.LNRewriteRule,
fastgelu.GeluRewriteRule,
biassplitgelu.GegluRewriteRule,
182 changes: 147 additions & 35 deletions onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py
@@ -55,39 +55,77 @@

import onnxscript
from onnxscript import ir
from onnxscript.rewriter import function_rule
from onnxscript.rewriter import _ir_utils, function_rule

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class AttnSizeConfig:
num_attention_heads: int
num_key_value_heads: int
num_key_value_heads: int | None
head_size: int
hidden_size: int


class AttentionRewriteRule(function_rule.FunctionRewriteRule, abc.ABC):
def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig:
if len(function.outputs) != 3:
if len(function.outputs) == 3:
# Usually the Attention related modules have 3 outputs:
# present_value, present_key, attn_output
present_value, _, attn_output = function.outputs
if present_value.shape is None:
raise function_rule.FunctionRewriteError(
"Failed to find shape for present_value."
)
if attn_output.shape is None:
raise function_rule.FunctionRewriteError(
"Failed to find shape for attn_output."
)
head_size = present_value.shape[3]
num_key_value_heads = present_value.shape[1]
hidden_size = attn_output.shape[2]
num_attention_heads = hidden_size // head_size
return AttnSizeConfig(
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_size=head_size,
hidden_size=hidden_size,
)
elif any("scaled_dot_product_attention" in node.op_type for node in function):
# If the Attention related modules use scaled_dot_product_attention,
# present_value and present_key are not present in the output.
hidden_size = function.outputs[0].shape[2]
# Get head size and number of heads from the Reshape node.
# Reference:
# https://github.com/huggingface/diffusers/blob/ae05050db9d37d5af48a6cd0d6510a5ffb1c1cd4/src/diffusers/models/attention_processor.py#L1269
reshape_nodes = [node for node in function if node.op_type == "Reshape"]
assert (
len(reshape_nodes) == 4
), "Expected 3 Reshape nodes for Q, K and V, and 1 reshape node for output of scaled_dot_product_attention."
for reshape_node in reshape_nodes:
constant_node = reshape_node.inputs[1].producer()
assert (
constant_node.op_type == "Constant"
), "Expected the second input to Reshape to be a Constant node."
value = _ir_utils.propagate_const_value(reshape_node.inputs[1])
constant_numpy_value = _ir_utils.get_numpy_from_ir_value(value)
if constant_numpy_value.shape[0] == 4:
num_attention_heads = constant_numpy_value[2]
head_size = constant_numpy_value[3]
return AttnSizeConfig(
num_attention_heads=num_attention_heads,
num_key_value_heads=None,
head_size=head_size,
hidden_size=hidden_size,
)
raise function_rule.FunctionRewriteError(
f"Unexpected number of outputs. Expected 3, got {len(function.outputs)}."
"Failed to infer head size and number of heads from QKV Reshape nodes. \
Expected 4D shape in the constant node (batch_size, seq_length, num_attention_heads, head_size)."
)
present_value, _, attn_output = function.outputs
if present_value.shape is None:
raise function_rule.FunctionRewriteError("Failed to find shape for present_value.")
if attn_output.shape is None:
raise function_rule.FunctionRewriteError("Failed to find shape for attn_output.")
head_size = present_value.shape[3]
num_key_value_heads = present_value.shape[1]
hidden_size = attn_output.shape[2]
num_attention_heads = hidden_size // head_size
return AttnSizeConfig(
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_size=head_size,
hidden_size=hidden_size,
raise function_rule.FunctionRewriteError(
f"Attenion modules should have 3 outputs or scaled_dot_product_attention node, "
f"got output: {len(function.outputs)} and no scaled_dot_product_attention."
)
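
As an aside, an illustrative standalone sketch of the shape arithmetic that the scaled_dot_product_attention branch above relies on; the concrete sizes are made-up example values, not taken from this PR.

import numpy as np

# Output of the Attention function: (batch_size, seq_length, hidden_size)
attn_output_shape = (2, 77, 320)
hidden_size = attn_output_shape[2]

# Target shape constant of the Q/K/V Reshape nodes:
# (batch_size, seq_length, num_attention_heads, head_size)
qkv_reshape_target = np.array([2, 77, 8, 40])
assert qkv_reshape_target.shape[0] == 4  # only the 4D reshape constants are used
num_attention_heads = int(qkv_reshape_target[2])
head_size = int(qkv_reshape_target[3])

# The sizes are consistent: hidden_size == num_attention_heads * head_size
assert num_attention_heads * head_size == hidden_size  # 8 * 40 == 320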


@@ -96,14 +134,11 @@ class MHALlama2RewriteRule(AttentionRewriteRule):
PACKAGE_NAME = "transformers"
_version_controller = function_rule.VersionController()

def __init__(self) -> None:
super().__init__()

@_version_controller.register_version(min_version="4.33", max_version="4.36")
def _fusion_with_4d_cache(self, function: ir.Function) -> ir.Function:
if len(function.input) != 9:
if len(function.inputs) != 9:
raise function_rule.FunctionRewriteError(
f"Unexpected number of inputs. Expected 9, got {len(function.input)}."
f"Unexpected number of inputs. Expected 9, got {len(function.inputs)}."
)

# Infer size configurations from the function.
@@ -172,9 +207,9 @@ def _fusion_with_2d_cache(self, function: ir.Function) -> ir.Function:
# Infer size configurations from the function.
attn_size_config = self.infer_attn_size_config(function)

if len(function.input) != 9:
if len(function.inputs) != 9:
raise function_rule.FunctionRewriteError(
f"Unexpected number of inputs. Expected 9, got {len(function.input)}."
f"Unexpected number of inputs. Expected 9, got {len(function.inputs)}."
)

# Code new pattern with onnxscript.
@@ -234,9 +269,6 @@ class GQALlama2RewriteRule(AttentionRewriteRule):
PACKAGE_NAME = "transformers"
_version_controller = function_rule.VersionController()

def __init__(self) -> None:
super().__init__()

@_version_controller.register_version(min_version="4.33", max_version="4.36")
def _fusion_with_4d_cache(self, function: ir.Function) -> ir.Function:
if len(function.inputs) != 9:
@@ -384,9 +416,6 @@ class GQALlamaSdpa2RewriteRule(AttentionRewriteRule):
PACKAGE_NAME = "transformers"
_version_controller = function_rule.VersionController()

def __init__(self) -> None:
super().__init__()

@_version_controller.register_version(min_version="4.36", max_version="4.38")
def _fusion(self, function: ir.Function) -> ir.Function:
# Infer size configurations from the function.
@@ -451,7 +480,6 @@ def gqa(
def _fusion_without_cos_sin_cache(self, function: ir.Function) -> ir.Function:
# Infer size configurations from the function.
attn_size_config = self.infer_attn_size_config(function)

# Code new pattern with onnxscript.
op = onnxscript.opset18
msft_op = onnxscript.values.Opset("com.microsoft", 1)
@@ -528,9 +556,6 @@ class AttnPhi15RewriteRule(AttentionRewriteRule):
PACKAGE_NAME = "transformers_modules"
_version_controller = function_rule.VersionController()

def __init__(self) -> None:
super().__init__()

@_version_controller.register_version()
def _fusion(self, function: ir.Function) -> ir.Function:
# Infer size configurations from the function.
@@ -592,3 +617,90 @@ def phi_attention(
phi_attention
).to_function_proto()
return ir.serde.deserialize_function(function_proto)


class MHAStableDiffusionUnetRewriteRule(AttentionRewriteRule):
"""Rewrite rule for Attention in diffusers."""

FUNCTION_KEYWORD = "Attention"
PACKAGE_NAME = "diffusers"
_version_controller = function_rule.VersionController()

@_version_controller.register_version()
def _fusion(self, function: ir.Function) -> ir.Function:
        # The Attention module can have 6 or 7 inputs:
        # hidden_states, encoder_hidden_states (optional), q_weight, k_weight, v_weight, o_weight, o_bias
if len(function.inputs) != 6 and len(function.inputs) != 7:
raise function_rule.FunctionRewriteError(
f"Unexpected number of inputs. Expected 6 or 7, got {len(function.inputs)}."
)

# Infer size configurations from the function.
attn_size_config = self.infer_attn_size_config(function)

# Code new pattern with onnxscript.
op = onnxscript.opset18
msft_op = onnxscript.values.Opset("com.microsoft", 1)

def attention(
hidden_states,
q_weight,
k_weight,
v_weight,
o_weight,
o_bias,
):
qkv_weight = op.Transpose(
op.Concat(q_weight, k_weight, v_weight, axis=0),
perm=[1, 0],
)

            # NOTE: MHA does not work when Q, K, and V have the same root inputs.
attn_output, _ = msft_op.Attention(
hidden_states,
qkv_weight,
None,
None,
num_heads=attn_size_config.num_attention_heads,
)

# linear projection
output = op.Add(op.MatMul(attn_output, op.Transpose(o_weight, [1, 0])), o_bias)
return output

def mha(
hidden_states,
encoder_hidden_states,
q_weight,
k_weight,
v_weight,
o_weight,
o_bias,
):
q = op.MatMul(hidden_states, op.Transpose(q_weight, [1, 0]))
k = op.MatMul(encoder_hidden_states, op.Transpose(k_weight, [1, 0]))
v = op.MatMul(encoder_hidden_states, op.Transpose(v_weight, [1, 0]))

            # NOTE: Q and K need to have the sequence length (dim 1) to use GQA.
mha_output, _, _ = msft_op.MultiHeadAttention(
q,
k,
v,
None,
None,
num_heads=attn_size_config.num_attention_heads,
)
attn_output = op.Add(op.MatMul(mha_output, op.Transpose(o_weight, [1, 0])), o_bias)
return attn_output

if len(function.inputs) == 6:
function_proto = onnxscript.script(default_opset=onnxscript.opset18)(
attention
).to_function_proto()
return ir.serde.deserialize_function(function_proto)

function_proto = onnxscript.script(default_opset=onnxscript.opset18)(
mha
).to_function_proto()
return ir.serde.deserialize_function(function_proto)
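
A small standalone sketch (illustration only, with example sizes) of why the self-attention path above packs the three projection weights with Concat(axis=0) followed by Transpose: PyTorch Linear weights are stored as (out_features, in_features), while the packed weight fed to com.microsoft.Attention is laid out as (input_hidden_size, 3 * hidden_size).

import numpy as np

hidden_size = 320  # example value
q_weight = np.zeros((hidden_size, hidden_size))  # nn.Linear weight: (out_features, in_features)
k_weight = np.zeros((hidden_size, hidden_size))
v_weight = np.zeros((hidden_size, hidden_size))

packed = np.concatenate([q_weight, k_weight, v_weight], axis=0)  # (3 * hidden, hidden)
qkv_weight = packed.T                                            # (hidden, 3 * hidden)
assert qkv_weight.shape == (hidden_size, 3 * hidden_size)
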
@@ -55,6 +55,12 @@ def test_sdpa_yi_4_38(self):
"sdpa_yi_4_38", 1, {("com.microsoft", "GroupQueryAttention", "")}
)

@testutils.skip_if_no_cuda("CPU has parity issue.")
def test_attn_stable_diffusion_unet(self):
testutils.test_onnxruntime_rewrite(
"attn_stable_diffusion_unet", 2, {("com.microsoft", "MultiHeadAttention", "")}
)


class AttnParityTest(unittest.TestCase):
def setUp(self):
@@ -66,6 +72,14 @@ def test_attn_phi_1_5(self):
"attn_phi_1_5", 4, {("com.microsoft", "Attention", "")}
)

@testutils.skip_if_no_cuda("CPU has parity issue.")
def test_attn_stable_diffusion_unet_without_encoder_hidden_states(self):
testutils.test_onnxruntime_rewrite(
"attn_stable_diffusion_unet_without_encoder_hidden_states",
2,
{("com.microsoft", "Attention", "")},
)


if __name__ == "__main__":
unittest.main()
Git LFS file not shown (27 files)
