From eb916b8809fac3ad85e783faf17cfbc3f87ac503 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Fri, 20 Dec 2024 11:16:28 -0800
Subject: [PATCH] Add cos sin test

---
Notes (not part of the commit message):

* cos_sin_cache.py: the rewrite now concatenates the cos and sin tables
  with themselves along the last axis before they are embedded as
  constants, and fuse_cos_sin_cache() returns the count it prints
  (annotation changed from None to int).
* cos_sin_cache_test.py: new test that optimizes the SmolLM test model,
  runs it before and after fuse_cos_sin_cache(), asserts that the returned
  count is greater than zero and that the outputs are close.
* rms_normalization_test.py: drops the onnx import together with the
  ModelProto.__repr__ override that was its only use in the removed lines.
* Review fix: an earlier revision of this patch also changed ort_run() in
  _test_utils.py so that models were saved to a hard-coded developer-local
  directory while the TemporaryDirectory it created went unused. That hunk
  was debugging residue and has been dropped; _test_utils.py is untouched.
* NOTE(review): "import onnx" in the new test file appears unused in the
  lines added here. It is kept so that the blob id on the index line stays
  valid; remove it in a follow-up if the linter flags it.

 .../onnxruntime/xformers/cos_sin_cache.py     |  5 +++-
 .../xformers/cos_sin_cache_test.py            | 29 +++++++++++++++++++
 .../xformers/rms_normalization_test.py        |  9 ------
 3 files changed, 33 insertions(+), 10 deletions(-)
 create mode 100644 onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py

diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
index a0e73730b..7ddae004c 100644
--- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
+++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
@@ -49,7 +49,9 @@ def rewrite(self, op, inv_freq, position_ids, **_):
         pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
         angles = np.matmul(pos_id_range, inv_freq_values)
         cos_value = np.cos(angles)
+        cos_value = np.concatenate([cos_value, cos_value], axis=-1)
         sin_value = np.sin(angles)
+        sin_value = np.concatenate([sin_value, sin_value], axis=-1)
         cos_2d = op.Constant(value=ir.tensor(cos_value))
         cos = op.Gather(cos_2d, position_ids, axis=0)
         sin_2d = op.Constant(value=ir.tensor(sin_value))
@@ -62,6 +64,7 @@ def rewrite(self, op, inv_freq, position_ids, **_):
 cos_sin_cache_rules = pattern.RewriteRuleSet([_rule])
 
 
-def fuse_cos_sin_cache(model: ir.Model) -> None:
+def fuse_cos_sin_cache(model: ir.Model) -> int:
     count = cos_sin_cache_rules.apply_to_model(model)
     print(f"CosSinCache count: {count}")
+    return count
diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
new file mode 100644
index 000000000..9a84f45f1
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import onnx
+
+import onnxscript.optimizer
+from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
+from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
+from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
+
+
+class TestCosSinCacheTransform(unittest.TestCase):
+    def test_smollm(self):
+        smollm_test = _SmollmTestData()
+        model = smollm_test.get_onnx_model()
+        onnxscript.optimizer.optimize(model)
+        inputs = smollm_test.get_ort_inputs()
+        original_outputs = ort_run("original", model, inputs)
+        count = fuse_cos_sin_cache(model)
+        self.assertGreater(count, 0)
+        new_outputs = ort_run("optimized", model, inputs)
+        assert_allclose(new_outputs, original_outputs)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py
index 79a966838..30080474c 100644
--- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py
+++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py
@@ -4,21 +4,12 @@
 
 import unittest
 
-import onnx
-
 import onnxscript.optimizer
 from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
 from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
 from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
 
 
-def model_repr(self):
-    return f"Model({self.graph.name})"
-
-
-onnx.ModelProto.__repr__ = model_repr
-
-
 class TestRmsNormalization(unittest.TestCase):
     def test_smollm(self):
         smollm_test = _SmollmTestData()