First series of P0 patterns to optimize llama (#1490)
Signed-off-by: Xavier Dupre <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
1 parent 34e410a · commit 1b2ecf5
Showing 4 changed files with 221 additions and 2 deletions.
onnxscript/rewriter/llama_rule_sets.py
@@ -0,0 +1,88 @@
from __future__ import annotations

import onnxscript.ir as ir
import onnxscript.rewriter.no_op as no_op
import onnxscript.rewriter.pattern as orp

op = orp.onnxop


def transpose_identity_pattern(op, x, perm):
    return op.Transpose(x, perm=perm)


def transpose_identity_check(context, x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool:
    if isinstance(perm, ir.RefAttr):
        return False
    if perm.type == ir.AttributeType.INTS:
        if perm.value == list(range(len(perm.value))):
            return True
    return False


def transpose_identity_rewrite(op, x: ir.Value, perm: ir.Attr | None = None):
    return op.Identity(x)


def transpose_transpose_pattern(op, x, perm1, perm2):
    return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)


def transpose_transpose_check(
    context, x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr
) -> bool:
    if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
        return False
    return True


def _apply_transpose(perm: tuple[int, ...], on: list[int]) -> list[int]:
    assert len(perm) == len(on), "length mismatch"
    res = [-1 for i in on]
    for i, p in enumerate(perm):
        res[i] = on[p]
    return res


def _apply_transposes(perms: list[tuple[int, ...]], on: list[int] | None = None) -> list[int]:
    if on is None:
        on = list(range(len(perms[0])))
    for p in perms:
        on = _apply_transpose(p, on)
    return on
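
# Illustration of the two helpers above: composing two Transpose(perm=[1, 2, 0])
# operations yields the single permutation [2, 0, 1], not the identity.
#   _apply_transpose((1, 2, 0), [0, 1, 2])        -> [1, 2, 0]
#   _apply_transposes([(1, 2, 0), (1, 2, 0)])     -> [2, 0, 1]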


def transpose_transpose_rewrite(op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
    first = list(range(len(perm1.value)))
    last = _apply_transposes([perm1.value, perm2.value])
    if first == last:
        return op.Identity(x)
    return op.Transpose(x, perm=last)


transpose_identity_rule = orp.RewriteRule(
    transpose_identity_pattern, transpose_identity_rewrite, transpose_identity_check
)
transpose_transpose_rule = orp.RewriteRule(
    transpose_transpose_pattern, transpose_transpose_rewrite, transpose_transpose_check
)


def llama_p0_rule_set() -> orp.RewriteRuleSet:
    """Returns a set of rules which should be applied
    before any other one as they usually remove unnecessary computation
    such as multiplication by 1 or two consecutive transposes.

    Returns:
        RewriteRuleSet
    """
    return orp.RewriteRuleSet(
        [
            no_op.mul_by_1_rule,
            no_op.add_0_rule,
            no_op.add_0_rule,
            no_op.div_by_1_rule,
            transpose_identity_rule,
            transpose_transpose_rule,
        ]
    )
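
For context, a minimal sketch of how this rule set is meant to be driven, mirroring the round trip used in the unit test below (deserialize the ONNX model into the onnxscript IR, apply the rules in place, serialize back); the model file names are hypothetical:

import onnx

import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
from onnxscript import ir

# Load an ONNX model, apply the P0 rewrite rules in place, and save the result.
model_proto = onnx.load("llama_block.onnx")  # hypothetical input path
ir_model = ir.serde.deserialize_model(model_proto)
llama_rule_sets.llama_p0_rule_set().apply_to_model(ir_model)
onnx.save(ir.serde.serialize_model(ir_model), "llama_block.optimized.onnx")  # hypothetical output path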
onnxscript/rewriter/llama_rule_sets_test.py
@@ -0,0 +1,130 @@
from __future__ import annotations

import unittest
from typing import Any

import numpy as np
import onnx
import onnx.reference

import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
from onnxscript import ir

FLOAT = onnx.TensorProto.FLOAT


class LlamaRuleSetsTest(unittest.TestCase):
    def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
        feeds: dict[str, Any] = {}
        for i in model.graph.input:
            shape = tuple(d + 2 for d in range(len(i.type.tensor_type.shape.dim)))
            if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
                feeds[i.name] = np.random.randn(*shape).astype(np.float32)
            else:
                raise AssertionError(f"Not implemented for input {i}")
        return feeds

    def _check_model(
        self,
        model: onnx.ModelProto,
        optimized_model: onnx.ModelProto,
        feeds: dict[str, Any] | None = None,
        atol: float = 0.0,
        rtol: float = 1e-7,
    ):
        if not feeds:
            feeds = self._get_random_inputs(model)
        ref = onnx.reference.ReferenceEvaluator(model)
        opt = onnx.reference.ReferenceEvaluator(optimized_model)
        expected = ref.run(None, feeds)
        got = opt.run(None, feeds)
        self.assertEqual(len(expected), len(got))
        for a, b in zip(expected, got):
            np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)

    @classmethod
    def _identity_models(cls):
        models = [
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 1, 2]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Mul", ["X", "one"], ["Y"]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None])],
                    [
                        onnx.numpy_helper.from_array(
                            np.array([1], dtype=np.float32), name="one"
                        )
                    ],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 0]),
                        onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 0]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
        ]
        return models

    def test_llama_p0_rule_set_identity(self):
        for model_proto in self._identity_models():
            ir_model = ir.serde.deserialize_model(model_proto)
            rule_set = llama_rule_sets.llama_p0_rule_set()
            rule_set.apply_to_model(ir_model)
            rewritten_model = ir.serde.serialize_model(ir_model)

            self.assertEqual(["Identity"], [n.op_type for n in rewritten_model.graph.node])
            self._check_model(model_proto, rewritten_model)

    @classmethod
    def _transpose_transpose_models(cls):
        models = [
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]),
                        onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
        ]
        return models

    def test_llama_p0_rule_set_transpose_transpose(self):
        for model_proto in self._transpose_transpose_models():
            ir_model = ir.serde.deserialize_model(model_proto)
            rule_set = llama_rule_sets.llama_p0_rule_set()
            rule_set.apply_to_model(ir_model)
            rewritten_model = ir.serde.serialize_model(ir_model)

            self.assertEqual(["Transpose"], [n.op_type for n in rewritten_model.graph.node])
            self._check_model(model_proto, rewritten_model)


if __name__ == "__main__":
    unittest.main(verbosity=2)
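
Why the second test expects a single Transpose rather than an Identity: the two perm=[1, 2, 0] transposes do not cancel out; they compose to perm=[2, 0, 1]. A quick standalone numpy check of that claim (an illustrative sketch, not taken from the changed files):

import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
once = np.transpose(x, (1, 2, 0))    # first Transpose, perm=[1, 2, 0]
twice = np.transpose(once, (1, 2, 0))  # second Transpose, perm=[1, 2, 0]

# Not the identity: the shape goes (2, 3, 4) -> (4, 2, 3), so the pair cannot
# be removed, only fused into one Transpose with the composed permutation.
assert twice.shape == (4, 2, 3)
assert np.array_equal(twice, np.transpose(x, (2, 0, 1)))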