Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First series of P0 patterns to optimize llama #1490

Merged
merged 17 commits into from
May 29, 2024
93 changes: 93 additions & 0 deletions onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations
Fixed Show fixed Hide fixed
xadupre marked this conversation as resolved.
Show resolved Hide resolved

import onnxscript.ir as ir
import onnxscript.rewriter.no_op as no_op
import onnxscript.rewriter.pattern as orp
from onnxscript.rewriter import pattern
xadupre marked this conversation as resolved.
Show resolved Hide resolved

_op = pattern.onnxop


def transpose_identity(x, perm):
return _op.Transpose(x, perm=perm)


def transpose_identity_check(x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool:
xadupre marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(perm, ir.RefAttr):
return False

Check warning on line 17 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L17

Added line #L17 was not covered by tests
if perm.type == ir.AttributeType.INTS:
if perm.value == list(range(len(perm.value))):
return True
return False

Check warning on line 21 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L20-L21

Added lines #L20 - L21 were not covered by tests


def transpose_identity_rewrite(op, x: ir.Value, perm: ir.Attr | None = None):
return op.Identity(x)

Check warning on line 25 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L25

Added line #L25 was not covered by tests


def transpose_transpose(x, perm1, perm2):
return _op.Transpose(_op.Transpose(x, perm=perm1), perm=perm2)


def transpose_transpose_check(
x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr
) -> bool:
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
return False
return True

Check warning on line 37 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L36-L37

Added lines #L36 - L37 were not covered by tests


def _apply_transpose(perm: tuple[int, ...], on: list[int | str]) -> list[int | str]:
assert len(perm) == len(on), "length mismatch"

Check warning on line 41 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L41

Added line #L41 was not covered by tests
res = [None for i in on]
for i, p in enumerate(perm):
res[i] = on[p]
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
return res

Check warning on line 45 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L44-L45

Added lines #L44 - L45 were not covered by tests
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed


def _apply_transposes(
perms: list[tuple[int, ...]], on: list[int | str] | None = None
) -> list[int | str]:
if on is None:
on = list(range(len(perms[0])))

Check warning on line 52 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L52

Added line #L52 was not covered by tests
for p in perms:
on = _apply_transpose(p, on)
return on

Check warning on line 55 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L54-L55

Added lines #L54 - L55 were not covered by tests
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed


def transpose_transpose_rewrite(op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
xadupre marked this conversation as resolved.
Show resolved Hide resolved
first = list(range(len(perm1.value)))
last = _apply_transposes([perm1.value, perm2.value])

Check warning on line 60 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L59-L60

Added lines #L59 - L60 were not covered by tests
if first == last:
return op.Identity(x)
return op.Transpose(x, perm=last)

Check warning on line 63 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L62-L63

Added lines #L62 - L63 were not covered by tests


transpose_identity_rule = pattern.RewriteRule(
transpose_identity, transpose_identity_rewrite, transpose_identity_check
)
transpose_transpose_rule = pattern.RewriteRule(
transpose_transpose, transpose_transpose_rewrite, transpose_transpose_check
)


def llama_p0_rule_set(verbose: int = 0) -> orp.RewriteRuleSet:
xadupre marked this conversation as resolved.
Show resolved Hide resolved
"""Returns a set of rules which should be applied
before any other one as they usually remove unnecessary computation
such as the multiplication by 1 or two consecutive transpose.

Args:
verbose: verbosity
Returns:
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
RewriteRuleSet
"""
return orp.RewriteRuleSet(

Check warning on line 84 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L84

Added line #L84 was not covered by tests
[
no_op.mul_by_1_rule,
no_op.add_0_rule,
no_op.add_0_rule,
no_op.div_by_1_rule,
transpose_identity_rule,
transpose_transpose_rule,
]
)
131 changes: 131 additions & 0 deletions onnxscript/rewriter/llama_rule_sets_tests.py
xadupre marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from __future__ import annotations
Fixed Show fixed Hide fixed

import unittest
Fixed Show fixed Hide fixed
from typing import Any

import numpy as np
import onnx
import onnx.reference

import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
from onnxscript import ir

FLOAT = onnx.TensorProto.FLOAT


class LlamaRuleSetsTest(unittest.TestCase):

def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
feeds: dict[str, Any] = {}

Check warning on line 19 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L19

Added line #L19 was not covered by tests
for i in model.graph.input:
shape = tuple(d + 2 for d in range(len(i.type.tensor_type.shape.dim)))
if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
feeds[i.name] = np.random.randn(*shape).astype(np.float32)

Check warning on line 23 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L23

Added line #L23 was not covered by tests
else:
raise AssertionError(f"Not implemented for input {i}")
return feeds

Check warning on line 26 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L25-L26

Added lines #L25 - L26 were not covered by tests

def _check_model(
self,
model: onnx.ModelProto,
optimized_model: onnx.ModelProto,
feeds: dict[str, Any] | None = None,
atol: float = 0.0,
rtol: float = 1e-7,
):
if not feeds:
feeds = self._get_random_inputs(model)
ref = onnx.reference.ReferenceEvaluator(model)
opt = onnx.reference.ReferenceEvaluator(optimized_model)
expected = ref.run(None, feeds)
got = opt.run(None, feeds)
self.assertEqual(len(expected), len(got))

Check warning on line 42 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L37-L42

Added lines #L37 - L42 were not covered by tests
for a, b in zip(expected, got):
np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)

Check warning on line 44 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L44

Added line #L44 was not covered by tests

def _identity_models(self):
models = [

Check warning on line 47 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L47

Added line #L47 was not covered by tests
onnx.helper.make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 1, 2]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
onnx.helper.make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Mul", ["X", "one"], ["Y"]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [None])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None])],
[
onnx.numpy_helper.from_array(
np.array([1], dtype=np.float32), name="one"
)
],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
onnx.helper.make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 0]),
onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 0]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [None, None])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None])],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
]
return models

Check warning on line 88 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L88

Added line #L88 was not covered by tests

def test_llama_p0_rule_set_identity(self):
for model in self._identity_models():

ir_model = ir.serde.deserialize_model(model)
rule_set = llama_rule_sets.llama_p0_rule_set()
rule_set.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)

Check warning on line 96 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L93-L96

Added lines #L93 - L96 were not covered by tests

self.assertEqual(["Identity"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model, rewritten_model)

Check warning on line 99 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L99

Added line #L99 was not covered by tests

def _transpose_transpose_models(self):
xadupre marked this conversation as resolved.
Show resolved Hide resolved
models = [

Check warning on line 102 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L102

Added line #L102 was not covered by tests
onnx.helper.make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]),
onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
]
return models

Check warning on line 116 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L116

Added line #L116 was not covered by tests

def test_llama_p0_rule_set_transpose_transpose(self):
for model in self._transpose_transpose_models():
xadupre marked this conversation as resolved.
Show resolved Hide resolved

ir_model = ir.serde.deserialize_model(model)
rule_set = llama_rule_sets.llama_p0_rule_set()
rule_set.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)

Check warning on line 124 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L121-L124

Added lines #L121 - L124 were not covered by tests

self.assertEqual(["Transpose"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model, rewritten_model)

Check warning on line 127 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L127

Added line #L127 was not covered by tests


if __name__ == "__main__":
unittest.main(verbosity=2)

Check warning on line 131 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L131

Added line #L131 was not covered by tests
Loading