Merge branch 'main' into titaiwang/diffuser_attention
titaiwangms authored Apr 26, 2024
2 parents 8d6e937 + 997beb2 commit ca03ca8
Showing 6 changed files with 61 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
@@ -78,7 +78,7 @@ jobs:
- name: Pull Test Data
run: git lfs pull
- name: Run tests
run: nox -t ${{ matrix.nox-tag }} --forcecolor -- -v --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml
run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml
env:
CATCH_ORT_SEGFAULT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
CREATE_REPRODUCTION_REPORT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
13 changes: 4 additions & 9 deletions onnxscript/rewriter/cast_constant_of_shape_test.py
@@ -2,7 +2,6 @@

import onnx.checker
import onnx.parser
import onnx.printer

from onnxscript import ir
from onnxscript.rewriter import cast_constant_of_shape
@@ -16,20 +15,17 @@ def test_cast_after_constant_of_shape_is_fused(self):
agraph (int64[2] input_x) => (float16[1, 4] output)
{
constant = ConstantOfShape <value: tensor = float[1] {1.}>(input_x)
temp = Cast <to = 10> (constant)
output = Identity (temp)
output = Cast <to = 10> (constant)
}
"""
)
onnx.checker.check_model(input_model_proto, True)
model = ir.serde.deserialize_model(input_model_proto)
count = cast_constant_of_shape.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 2)
self.assertEqual(len(model.graph), 1)
self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10)
output_model_proto = ir.serde.serialize_model(model)
# TODO: Eliminating `temp` in above example causes a failure.
# Rewriter changes graph output name, doesn't introduce type for it
onnx.checker.check_model(output_model_proto, True)

def test_cast_after_constant_of_shape_without_value_is_fused(self):
@@ -39,15 +35,14 @@ def test_cast_after_constant_of_shape_without_value_is_fused(self):
agraph (int64[2] input_x) => (float16[1, 4] output)
{
constant = ConstantOfShape (input_x)
temp = Cast <to = 10> (constant)
output = Identity (temp)
output = Cast <to = 10> (constant)
}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = cast_constant_of_shape.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 2)
self.assertEqual(len(model.graph), 1)
self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10)
output_model_proto = ir.serde.serialize_model(model)
onnx.checker.check_model(output_model_proto, True)
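
For readers unfamiliar with the rule these tests exercise, here is a minimal standalone sketch of the fusion the updated tests expect, using the same onnxscript APIs that appear in the test file (the `<ir_version>`/opset header is filled in here only because the hunks above truncate it):

```python
import onnx.checker
import onnx.parser

from onnxscript import ir
from onnxscript.rewriter import cast_constant_of_shape

# A tiny model in which a Cast immediately follows ConstantOfShape.
model_proto = onnx.parser.parse_model(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (int64[2] input_x) => (float16[1, 4] output)
    {
        constant = ConstantOfShape <value: tensor = float[1] {1.}>(input_x)
        output = Cast <to = 10> (constant)
    }
    """
)
model = ir.serde.deserialize_model(model_proto)

# Applying the rule set fuses the pair: a single ConstantOfShape remains,
# and its `value` attribute already has dtype 10 (float16), so no Cast is needed.
count = cast_constant_of_shape.rules.apply_to_model(model)
assert count == 1
assert len(model.graph) == 1
assert model.graph[0].attributes["value"].value.dtype == 10

# Serializing back yields a model that passes the ONNX checker.
onnx.checker.check_model(ir.serde.serialize_model(model), True)
```
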
16 changes: 12 additions & 4 deletions onnxscript/rewriter/pattern.py
@@ -910,13 +910,21 @@ def _apply_deltas(
else:
deleted_nodes, inserted_nodes = delta
# Replace deleted nodes with inserted nodes.
# However, we merge the last deleted node and last inserted node
# to avoid replacing the values produced by the last deleted node
# in all places where they are used. So, we reuse the output
# values from the last deleted node and replace the node itself
# TODO: simplify this
last_deleted = deleted_nodes[-1]
last_inserted = inserted_nodes[-1]

for old_value, new_value in zip(last_deleted.outputs, last_inserted.outputs):
# Propagate relevant info from old value to new value
# TODO(Rama): Perhaps we should merge old and new types. As of now, new
# values don't have type information. Note that this could be a problem
# for semantics-altering rewrite-rules: we should allow users to override
# this for such rules.
new_value.type = old_value.type
new_value.shape = old_value.shape
new_value.const_value = old_value.const_value
new_value.name = old_value.name

# Reconnect the users of the deleted node to use the new outputs
_convenience.replace_all_uses_with(last_deleted.outputs, last_inserted.outputs)
# Update graph/function outputs if the node generates output
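
The merging step above can be illustrated in isolation. The sketch below uses hypothetical `Value`/`Node` stand-ins (not onnxscript's `ir` classes) to show the two moves the change performs: copy the metadata that pattern-produced values lack from the deleted node's outputs, then rewire users onto the new outputs (what `_convenience.replace_all_uses_with` does in the real code):

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Value:
    """Hypothetical stand-in for an IR value; not onnxscript's ir.Value."""

    name: str | None = None
    type: Any = None
    shape: Any = None
    const_value: Any = None
    uses: list[Node] = field(default_factory=list)


@dataclass
class Node:
    """Hypothetical stand-in for an IR node."""

    outputs: list[Value] = field(default_factory=list)


def merge_last_nodes(last_deleted: Node, last_inserted: Node) -> None:
    """Copy output metadata from the deleted node onto the inserted one,
    then hand the deleted outputs' uses over to the new outputs."""
    for old_value, new_value in zip(last_deleted.outputs, last_inserted.outputs):
        # Pattern-produced values start out without this information.
        new_value.type = old_value.type
        new_value.shape = old_value.shape
        new_value.const_value = old_value.const_value
        new_value.name = old_value.name
        # Rewire users of the old value to the new one (toy analogue of
        # _convenience.replace_all_uses_with).
        new_value.uses.extend(old_value.uses)
        old_value.uses.clear()
```
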
4 changes: 3 additions & 1 deletion onnxscript/rewriter/pattern_test.py
@@ -2,6 +2,7 @@
import unittest

import numpy as np
import onnx.checker
import onnx.parser

from onnxscript import ir
@@ -337,7 +338,7 @@ def double(op, x):
<ir_version: 7, opset_import: [ "" : 17, "pkg.custom": 1]>
agraph (float[N] x) => (float[M] z)
{
z = afunction (x)
z = pkg.custom.afunction (x)
}
<domain: "pkg.custom", opset_import: [ "" : 17]>
afunction (x) => (z)
@@ -355,6 +356,7 @@ def double(op, x):
self.assertEqual(
model.functions[("pkg.custom", "afunction", "")].opset_imports["custom.domain"], 10
)
onnx.checker.check_model(ir.serde.serialize_model(model))


if __name__ == "__main__":
3 changes: 3 additions & 0 deletions testdata/dort_models/llama_forward.onnx
Git LFS file not shown
39 changes: 38 additions & 1 deletion tests/optimizer/test_models.py
@@ -6,10 +6,12 @@

import numpy as np
import onnx
import onnx.inliner
import onnxruntime
import parameterized

from onnxscript import optimizer
from onnxscript.rewriter import onnxruntime as ort_rewriter
from onnxscript.utils import evaluation_utils

_SKIP_TABLE = {}
@@ -64,6 +66,41 @@ def test_model_runs_and_matches_accuracy_after_optimization(self, model_name):
for output, expected_output in zip(outputs, expected_outputs):
np.testing.assert_allclose(output, expected_output, rtol=1e-3, atol=1e-3)

def test_optimizer_after_inlining(self):
model_dir = pathlib.Path(model_folder_path) / ".." / "dort_models"
filename = model_dir / "llama_forward.onnx"

onnx_model = onnx.load(filename)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# first time
onnx_model = optimizer.optimize(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
onnx_model = ort_rewriter.rewrite(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# inline
onnx_model = onnx.inliner.inline_local_functions(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# second time
onnx_model = optimizer.optimize(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
onnx_model = ort_rewriter.rewrite(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)
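
The new test drives the optimize → rewrite → inline → optimize → rewrite pipeline, constructing an InferenceSession after every stage purely as a load-time sanity check. The same steps condensed into a helper, using only the APIs that appear in the test (the helper names are illustrative):

```python
import onnx
import onnx.inliner
import onnxruntime

from onnxscript import optimizer
from onnxscript.rewriter import onnxruntime as ort_rewriter


def _check_loadable(model: onnx.ModelProto) -> None:
    # Building a session is a cheap way to catch models ORT cannot consume.
    onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )


def optimize_inline_optimize(model: onnx.ModelProto) -> onnx.ModelProto:
    """Hypothetical helper mirroring the steps in test_optimizer_after_inlining."""
    for transform in (
        optimizer.optimize,
        ort_rewriter.rewrite,
        onnx.inliner.inline_local_functions,
        optimizer.optimize,
        ort_rewriter.rewrite,
    ):
        model = transform(model)
        _check_loadable(model)
    return model
```
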
