First series of P0 patterns to optimize llama (#1490)
Signed-off-by: Xavier Dupre <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
xadupre and justinchuby authored May 29, 2024
1 parent 34e410a commit 1b2ecf5
Showing 4 changed files with 221 additions and 2 deletions.
88 changes: 88 additions & 0 deletions onnxscript/rewriter/llama_rule_sets.py
@@ -0,0 +1,88 @@
from __future__ import annotations

import onnxscript.ir as ir
import onnxscript.rewriter.no_op as no_op
import onnxscript.rewriter.pattern as orp

op = orp.onnxop


def transpose_identity_pattern(op, x, perm):
    return op.Transpose(x, perm=perm)


def transpose_identity_check(context, x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool:
    if isinstance(perm, ir.RefAttr):
        return False
    if perm.type == ir.AttributeType.INTS:
        if perm.value == list(range(len(perm.value))):
            return True
    return False


def transpose_identity_rewrite(op, x: ir.Value, perm: ir.Attr | None = None):
    return op.Identity(x)
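

# Illustrative note: a Transpose whose perm is the identity permutation,
# e.g. Transpose(X, perm=[0, 1, 2]), matches this pattern and is replaced
# by Identity(X), as exercised by the first model in the tests below.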


def transpose_transpose_pattern(op, x, perm1, perm2):
    return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)


def transpose_transpose_check(
    context, x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr
) -> bool:
    if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
        return False
    return True


def _apply_transpose(perm: tuple[int, ...], on: list[int]) -> list[int]:
    assert len(perm) == len(on), "length mismatch"
    res = [-1 for _ in on]
    for i, p in enumerate(perm):
        res[i] = on[p]
    return res


def _apply_transposes(perms: list[tuple[int, ...]], on: list[int] | None = None) -> list[int]:
    if on is None:
        on = list(range(len(perms[0])))
    for p in perms:
        on = _apply_transpose(p, on)
    return on
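

# Illustrative example: composing perm=[1, 2, 0] with itself gives
# _apply_transposes([[1, 2, 0], [1, 2, 0]]) == [2, 0, 1], so two such
# consecutive Transpose nodes collapse into one Transpose with perm=[2, 0, 1].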


def transpose_transpose_rewrite(op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
    first = list(range(len(perm1.value)))
    last = _apply_transposes([perm1.value, perm2.value])
    if first == last:
        return op.Identity(x)
    return op.Transpose(x, perm=last)


transpose_identity_rule = orp.RewriteRule(
    transpose_identity_pattern, transpose_identity_rewrite, transpose_identity_check
)
transpose_transpose_rule = orp.RewriteRule(
    transpose_transpose_pattern, transpose_transpose_rewrite, transpose_transpose_check
)


def llama_p0_rule_set() -> orp.RewriteRuleSet:
    """Returns a set of rules that should be applied before any others,
    as they usually remove unnecessary computation such as a multiplication
    by 1 or two consecutive transposes.

    Returns:
        RewriteRuleSet
    """
    return orp.RewriteRuleSet(
        [
            no_op.mul_by_1_rule,
            no_op.add_0_rule,
            no_op.div_by_1_rule,
            transpose_identity_rule,
            transpose_transpose_rule,
        ]
    )
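
For orientation, a minimal usage sketch mirroring how the tests below apply this rule set; the file paths are placeholders and the snippet is illustrative only, not part of the API shown in this diff:

import onnx

import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
from onnxscript import ir

# Load an ONNX model into the onnxscript IR ("model.onnx" is a placeholder path).
model_proto = onnx.load("model.onnx")
ir_model = ir.serde.deserialize_model(model_proto)

# Apply the P0 rules in place, then serialize back to an ONNX ModelProto.
llama_rule_sets.llama_p0_rule_set().apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)
onnx.save(rewritten_model, "model_optimized.onnx")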
130 changes: 130 additions & 0 deletions onnxscript/rewriter/llama_rule_sets_test.py
@@ -0,0 +1,130 @@
from __future__ import annotations

import unittest
from typing import Any

import numpy as np
import onnx
import onnx.reference

import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
from onnxscript import ir

FLOAT = onnx.TensorProto.FLOAT


class LlamaRuleSetsTest(unittest.TestCase):
    def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
        feeds: dict[str, Any] = {}
        for i in model.graph.input:
            # Pick a concrete shape for the symbolic dims, e.g. (2, 3, 4) for a rank-3 input.
            shape = tuple(d + 2 for d in range(len(i.type.tensor_type.shape.dim)))
            if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
                feeds[i.name] = np.random.randn(*shape).astype(np.float32)
            else:
                raise AssertionError(f"Not implemented for input {i}")
        return feeds

    def _check_model(
        self,
        model: onnx.ModelProto,
        optimized_model: onnx.ModelProto,
        feeds: dict[str, Any] | None = None,
        atol: float = 0.0,
        rtol: float = 1e-7,
    ):
        if not feeds:
            feeds = self._get_random_inputs(model)
        ref = onnx.reference.ReferenceEvaluator(model)
        opt = onnx.reference.ReferenceEvaluator(optimized_model)
        expected = ref.run(None, feeds)
        got = opt.run(None, feeds)
        self.assertEqual(len(expected), len(got))
        for a, b in zip(expected, got):
            np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)

    @classmethod
    def _identity_models(cls):
        models = [
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 1, 2]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Mul", ["X", "one"], ["Y"]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None])],
                    [
                        onnx.numpy_helper.from_array(
                            np.array([1], dtype=np.float32), name="one"
                        )
                    ],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 0]),
                        onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 0]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
        ]
        return models

    def test_llama_p0_rule_set_identity(self):
        for model_proto in self._identity_models():
            ir_model = ir.serde.deserialize_model(model_proto)
            rule_set = llama_rule_sets.llama_p0_rule_set()
            rule_set.apply_to_model(ir_model)
            rewritten_model = ir.serde.serialize_model(ir_model)

            self.assertEqual(["Identity"], [n.op_type for n in rewritten_model.graph.node])
            self._check_model(model_proto, rewritten_model)

    @classmethod
    def _transpose_transpose_models(cls):
        models = [
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]),
                        onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
        ]
        return models

    def test_llama_p0_rule_set_transpose_transpose(self):
        for model_proto in self._transpose_transpose_models():
            ir_model = ir.serde.deserialize_model(model_proto)
            rule_set = llama_rule_sets.llama_p0_rule_set()
            rule_set.apply_to_model(ir_model)
            rewritten_model = ir.serde.serialize_model(ir_model)

            self.assertEqual(["Transpose"], [n.op_type for n in rewritten_model.graph.node])
            self._check_model(model_proto, rewritten_model)


if __name__ == "__main__":
    unittest.main(verbosity=2)
3 changes: 2 additions & 1 deletion tests/common/onnx_script_test_case.py
@@ -192,6 +192,7 @@ def run_converter_test(
        onnx_case_model: Optional[onnx.ModelProto] = None,
        *,
        ir_version: int = 9,
        rtol: Optional[float] = None,
    ):
        # FIXME(justinchuby): Defaulting to ir_version 9 because ONNX Runtime supports
        # up to IR version 9 as of 4/2/2024. We should have a better mechanism to
@@ -252,7 +253,7 @@ def run_converter_test(
            raise AssertionError(f"Unable to load model\n{model}") from e
        # input['input_2'] = None
        actual = session.run(None, input)
        np.testing.assert_allclose(actual, param.output, rtol=self.rtol)
        np.testing.assert_allclose(actual, param.output, rtol=rtol or self.rtol)

    def run_eager_test(
        self,
2 changes: 1 addition & 1 deletion tests/functions/gemmgelu_test.py
@@ -59,7 +59,7 @@ def test_gemmgelu(self):
            onnx_script_test_case.FunctionTestParams(gemmgelu.gemmgelu, [a, w, b], [expected])
        ]
        for case in cases:
            self.run_converter_test(case)
            self.run_converter_test(case, rtol=1e-6)
            self.run_eager_test(case)


