Skip to content

Commit

Permalink
Add cos sin test
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Dec 20, 2024
1 parent 1fdc19b commit eb916b8
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 11 deletions.
3 changes: 2 additions & 1 deletion onnxscript/rewriter/onnxruntime/xformers/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def _save(model, modelpath):

def ort_run(model_name: str, model, inputs):
providers = ["CPUExecutionProvider"]
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = r"C:\Users\grama\OneDrive - Microsoft\0L-Torch\model\smollm-1L-debug"
with tempfile.TemporaryDirectory() as temp_dir2:

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable temp\_dir2 is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
model_path = os.path.join(temp_dir, f"{model_name}.onnx")
io.save(model, model_path)
# Run model
Expand Down
5 changes: 4 additions & 1 deletion onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def rewrite(self, op, inv_freq, position_ids, **_):
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
cos_value = np.concatenate([cos_value, cos_value], axis=-1)
sin_value = np.sin(angles)
sin_value = np.concatenate([sin_value, sin_value], axis=-1)
cos_2d = op.Constant(value=ir.tensor(cos_value))
cos = op.Gather(cos_2d, position_ids, axis=0)
sin_2d = op.Constant(value=ir.tensor(sin_value))
Expand All @@ -62,6 +64,7 @@ def rewrite(self, op, inv_freq, position_ids, **_):
cos_sin_cache_rules = pattern.RewriteRuleSet([_rule])


def fuse_cos_sin_cache(model: ir.Model) -> None:
def fuse_cos_sin_cache(model: ir.Model) -> int:
count = cos_sin_cache_rules.apply_to_model(model)
print(f"CosSinCache count: {count}")
return count
29 changes: 29 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnx

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'onnx' is not used.

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

onnx imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache


class TestCosSinCacheTransform(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = smollm_test.get_ort_inputs()
original_outputs = ort_run("original", model, inputs)
count = fuse_cos_sin_cache(model)
self.assertGreater(count, 0)
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@

import unittest

import onnx

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization


def model_repr(self):
return f"Model({self.graph.name})"


onnx.ModelProto.__repr__ = model_repr


class TestRmsNormalization(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
Expand Down

0 comments on commit eb916b8

Please sign in to comment.