diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index 5125a359f..440f4a111 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -5,6 +5,7 @@ import numpy as np import onnxscript.ir as ir +from onnxscript.optimizer import remove_unused_nodes from onnxscript.rewriter import _ir_utils, pattern # Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. @@ -25,7 +26,7 @@ def __init__(self, name: str, max_pos_id: int): self._max_pos_id = max_pos_id self.remove_nodes = False - def pattern(self, op, x, inv_freq, position_ids): + def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads): position_ids_expanded = op.Unsqueeze(position_ids, 1) position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) freqs = op.MatMul(inv_freq, position_ids_expanded) @@ -35,7 +36,14 @@ def pattern(self, op, x, inv_freq, position_ids): sin = op.Sin(emb) cos_4d = op.Unsqueeze(cos, 1) # convert sin_4d = op.Unsqueeze(sin, 1) - return op.RotaryEmbedding(x, cos_4d, sin_4d, interleaved=0, _domain="ai.onnxruntime.fusion") + return op.RotaryEmbedding( + x, + cos_4d, + sin_4d, + interleaved=interleaved, + num_heads=num_heads, + _domain="ai.onnxruntime.fusion", + ) def check(self, context, inv_freq, position_ids, **_): if not _ir_utils.has_rank(position_ids, 2): @@ -47,7 +55,7 @@ def check(self, context, inv_freq, position_ids, **_): return False return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 - def rewrite(self, op, x, inv_freq, position_ids, **_): + def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_): inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) angles = np.matmul(pos_id_range, inv_freq_values) @@ -59,7 +67,15 @@ def rewrite(self, op, x, inv_freq, position_ids, **_): # cos = op.Gather(cos_2d, position_ids, axis=0) sin_2d = op.Constant(value=ir.tensor(sin_value)) # sin = op.Gather(sin_2d, position_ids, axis=0) - return op.RotaryEmbedding(x, cos_2d, sin_2d, position_ids, interleaved=0, _domain="ai.onnxruntime.fusion") + return op.RotaryEmbedding( + x, + position_ids, + cos_2d, + sin_2d, + interleaved=interleaved, + num_heads=num_heads, + _domain="com.microsoft", + ) _rule = CosSinCacheFusion.rule("CosSinCache", 2048) @@ -70,4 +86,5 @@ def rewrite(self, op, x, inv_freq, position_ids, **_): def fuse_cos_sin_cache(model: ir.Model) -> int: count = cos_sin_cache_rules.apply_to_model(model) print(f"CosSinCache count: {count}") + remove_unused_nodes(model) return count diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 83749cb5d..22e6bfeee 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -28,9 +28,11 @@ def pattern(self, op, x, cos, sin, start1, end1, start2, end2): return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin def check(self, op, x, start1, end1, start2, end2, **_): - # x needs to be a 4D tensor with known last dimension size (== head_size) + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: return False + if not isinstance(x.shape[1], int): + return False head_size = x.shape[3] if not isinstance(head_size, int): return False @@ -45,7 +47,10 @@ def check(self, op, x, start1, end1, start2, end2, **_): ) def rewrite(self, op, x, cos, sin, **_): - return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion") + num_heads = x.shape[1] + return op.RotaryEmbedding( + x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion" + ) _rule = RotaryEmbeddingFusion.rule() @@ -53,6 +58,7 @@ def rewrite(self, op, x, cos, sin, **_): rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) -def fuse_rotary_embedding(model: ir.Model) -> None: +def fuse_rotary_embedding(model: ir.Model) -> int: count = rotary_embedding_rules.apply_to_model(model) print(f"Rotary Embedding count: {count}") + return count