Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Dec 21, 2024
1 parent d874dbc commit a745039
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
25 changes: 21 additions & 4 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import remove_unused_nodes
from onnxscript.rewriter import _ir_utils, pattern

# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops.
Expand All @@ -25,7 +26,7 @@ def __init__(self, name: str, max_pos_id: int):
self._max_pos_id = max_pos_id
self.remove_nodes = False

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute remove_nodes, which was previously defined in superclass
RewriteRuleClassBase
.

def pattern(self, op, x, inv_freq, position_ids):
def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads):
position_ids_expanded = op.Unsqueeze(position_ids, 1)
position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
freqs = op.MatMul(inv_freq, position_ids_expanded)
Expand All @@ -35,7 +36,14 @@ def pattern(self, op, x, inv_freq, position_ids):
sin = op.Sin(emb)
cos_4d = op.Unsqueeze(cos, 1) # convert
sin_4d = op.Unsqueeze(sin, 1)
return op.RotaryEmbedding(x, cos_4d, sin_4d, interleaved=0, _domain="ai.onnxruntime.fusion")
return op.RotaryEmbedding(
x,
cos_4d,
sin_4d,
interleaved=interleaved,
num_heads=num_heads,
_domain="ai.onnxruntime.fusion",
)

def check(self, context, inv_freq, position_ids, **_):
if not _ir_utils.has_rank(position_ids, 2):
Expand All @@ -47,7 +55,7 @@ def check(self, context, inv_freq, position_ids, **_):
return False

Check warning on line 55 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L55

Added line #L55 was not covered by tests
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1

def rewrite(self, op, x, inv_freq, position_ids, **_):
def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_):
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
Expand All @@ -59,7 +67,15 @@ def rewrite(self, op, x, inv_freq, position_ids, **_):
# cos = op.Gather(cos_2d, position_ids, axis=0)
sin_2d = op.Constant(value=ir.tensor(sin_value))
# sin = op.Gather(sin_2d, position_ids, axis=0)
return op.RotaryEmbedding(x, cos_2d, sin_2d, position_ids, interleaved=0, _domain="ai.onnxruntime.fusion")
return op.RotaryEmbedding(
x,
position_ids,
cos_2d,
sin_2d,
interleaved=interleaved,
num_heads=num_heads,
_domain="com.microsoft",
)


_rule = CosSinCacheFusion.rule("CosSinCache", 2048)
Expand All @@ -70,4 +86,5 @@ def rewrite(self, op, x, inv_freq, position_ids, **_):
def fuse_cos_sin_cache(model: ir.Model) -> int:
count = cos_sin_cache_rules.apply_to_model(model)
print(f"CosSinCache count: {count}")
remove_unused_nodes(model)
return count
12 changes: 9 additions & 3 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

def check(self, op, x, start1, end1, start2, end2, **_):
# x needs to be a 4D tensor with known last dimension size (== head_size)
# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
if x is None or x.shape is None or len(x.shape) != 4:
return False

Check warning on line 33 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py#L33

Added line #L33 was not covered by tests
if not isinstance(x.shape[1], int):
return False

Check warning on line 35 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py#L35

Added line #L35 was not covered by tests
head_size = x.shape[3]
if not isinstance(head_size, int):
return False

Check warning on line 38 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py#L38

Added line #L38 was not covered by tests
Expand All @@ -45,14 +47,18 @@ def check(self, op, x, start1, end1, start2, end2, **_):
)

def rewrite(self, op, x, cos, sin, **_):
return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion")
num_heads = x.shape[1]
return op.RotaryEmbedding(
x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
)


_rule = RotaryEmbeddingFusion.rule()

rotary_embedding_rules = pattern.RewriteRuleSet([_rule])


def fuse_rotary_embedding(model: ir.Model) -> None:
def fuse_rotary_embedding(model: ir.Model) -> int:
count = rotary_embedding_rules.apply_to_model(model)
print(f"Rotary Embedding count: {count}")
return count

0 comments on commit a745039

Please sign in to comment.