-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1fdc19b
commit eb916b8
Showing
4 changed files
with
35 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,8 @@ def _save(model, modelpath): | |
|
||
def ort_run(model_name: str, model, inputs): | ||
providers = ["CPUExecutionProvider"] | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
temp_dir = r"C:\Users\grama\OneDrive - Microsoft\0L-Torch\model\smollm-1L-debug" | ||
with tempfile.TemporaryDirectory() as temp_dir2: | ||
Check warning Code scanning / lintrunner RUFF/F841 Warning
Local variable temp\_dir2 is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable |
||
model_path = os.path.join(temp_dir, f"{model_name}.onnx") | ||
io.save(model, model_path) | ||
# Run model | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
Check warning Code scanning / lintrunner RUFF/format Warning
Run lintrunner -a to apply this patch.
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import onnx | ||
Check notice Code scanning / CodeQL Unused import Note
Import of 'onnx' is not used.
Check warning Code scanning / lintrunner RUFF/F401 Warning
onnx imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import |
||
|
||
import onnxscript.optimizer | ||
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData | ||
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run | ||
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache | ||
|
||
|
||
class TestCosSinCacheTransform(unittest.TestCase): | ||
def test_smollm(self): | ||
smollm_test = _SmollmTestData() | ||
model = smollm_test.get_onnx_model() | ||
onnxscript.optimizer.optimize(model) | ||
inputs = smollm_test.get_ort_inputs() | ||
original_outputs = ort_run("original", model, inputs) | ||
count = fuse_cos_sin_cache(model) | ||
self.assertGreater(count, 0) | ||
new_outputs = ort_run("optimized", model, inputs) | ||
assert_allclose(new_outputs, original_outputs) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters