Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First series of P0 patterns to optimize llama #1490

Merged
merged 17 commits into from
May 29, 2024
91 changes: 91 additions & 0 deletions onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations
Fixed Show fixed Hide fixed
xadupre marked this conversation as resolved.
Show resolved Hide resolved

import onnxscript.ir as ir
import onnxscript.rewriter.no_op as no_op
import onnxscript.rewriter.pattern as orp
from onnxscript.rewriter import pattern
xadupre marked this conversation as resolved.
Show resolved Hide resolved

_op = pattern.onnxop


def transpose_identity(x, perm):
return _op.Transpose(x, perm=perm)


def transpose_identity_check(x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool:
xadupre marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(perm, ir.RefAttr):
return False

Check warning on line 17 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L17

Added line #L17 was not covered by tests
if perm.type == ir.AttributeType.INTS:
if perm.value == list(range(len(perm.value))):
return True
return False

Check warning on line 21 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L20-L21

Added lines #L20 - L21 were not covered by tests


def transpose_identity_rewrite(op, x: ir.Value, perm: ir.Attr | None = None):
return op.Identity(x)

Check warning on line 25 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L25

Added line #L25 was not covered by tests


def transpose_transpose(x, perm1, perm2):
return _op.Transpose(_op.Transpose(x, perm=perm1), perm=perm2)


def transpose_transpose_check(
x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr
) -> bool:
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
return False
return True

Check warning on line 37 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L36-L37

Added lines #L36 - L37 were not covered by tests


def _apply_transpose(perm: tuple[int, ...], on: list[int]) -> list[int]:
xadupre marked this conversation as resolved.
Show resolved Hide resolved
assert len(perm) == len(on), "length mismatch"

Check warning on line 41 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L41

Added line #L41 was not covered by tests
res = [-1 for i in on]
for i, p in enumerate(perm):
res[i] = on[p]
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
return res

Check warning on line 45 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L44-L45

Added lines #L44 - L45 were not covered by tests
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed


def _apply_transposes(perms: list[tuple[int, ...]], on: list[int] | None = None) -> list[int]:
xadupre marked this conversation as resolved.
Show resolved Hide resolved
xadupre marked this conversation as resolved.
Show resolved Hide resolved
if on is None:
on = list(range(len(perms[0])))

Check warning on line 50 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L50

Added line #L50 was not covered by tests
for p in perms:
on = _apply_transpose(p, on)
return on

Check warning on line 53 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L52-L53

Added lines #L52 - L53 were not covered by tests
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed


def transpose_transpose_rewrite(op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
xadupre marked this conversation as resolved.
Show resolved Hide resolved
first = list(range(len(perm1.value)))
last = _apply_transposes([perm1.value, perm2.value])

Check warning on line 58 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L57-L58

Added lines #L57 - L58 were not covered by tests
if first == last:
return op.Identity(x)
return op.Transpose(x, perm=last)

Check warning on line 61 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L60-L61

Added lines #L60 - L61 were not covered by tests


transpose_identity_rule = pattern.RewriteRule(
transpose_identity, transpose_identity_rewrite, transpose_identity_check
)
transpose_transpose_rule = pattern.RewriteRule(
transpose_transpose, transpose_transpose_rewrite, transpose_transpose_check
)


def llama_p0_rule_set(verbose: int = 0) -> orp.RewriteRuleSet:
xadupre marked this conversation as resolved.
Show resolved Hide resolved
"""Returns a set of rules which should be applied
before any other one as they usually remove unnecessary computation
such as the multiplication by 1 or two consecutive transpose.

Args:
verbose: verbosity
Returns:
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
RewriteRuleSet
"""
return orp.RewriteRuleSet(

Check warning on line 82 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L82

Added line #L82 was not covered by tests
[
no_op.mul_by_1_rule,
no_op.add_0_rule,
no_op.add_0_rule,
no_op.div_by_1_rule,
transpose_identity_rule,
transpose_transpose_rule,
]
)
128 changes: 128 additions & 0 deletions onnxscript/rewriter/llama_rule_sets_tests.py
xadupre marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations
Fixed Show fixed Hide fixed

import unittest
Fixed Show fixed Hide fixed
from typing import Any

import numpy as np
import onnx
import onnx.reference

import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
from onnxscript import ir

FLOAT = onnx.TensorProto.FLOAT


class LlamaRuleSetsTest(unittest.TestCase):
def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
feeds: dict[str, Any] = {}

Check warning on line 18 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L18

Added line #L18 was not covered by tests
for i in model.graph.input:
shape = tuple(d + 2 for d in range(len(i.type.tensor_type.shape.dim)))
if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
feeds[i.name] = np.random.randn(*shape).astype(np.float32)

Check warning on line 22 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L22

Added line #L22 was not covered by tests
else:
raise AssertionError(f"Not implemented for input {i}")
return feeds

Check warning on line 25 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L24-L25

Added lines #L24 - L25 were not covered by tests

def _check_model(
self,
model: onnx.ModelProto,
optimized_model: onnx.ModelProto,
feeds: dict[str, Any] | None = None,
atol: float = 0.0,
rtol: float = 1e-7,
):
if not feeds:
feeds = self._get_random_inputs(model)
ref = onnx.reference.ReferenceEvaluator(model)
opt = onnx.reference.ReferenceEvaluator(optimized_model)
expected = ref.run(None, feeds)
got = opt.run(None, feeds)
self.assertEqual(len(expected), len(got))

Check warning on line 41 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L36-L41

Added lines #L36 - L41 were not covered by tests
for a, b in zip(expected, got):
np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)

Check warning on line 43 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L43

Added line #L43 was not covered by tests

def _identity_models(self):
models = [

Check warning on line 46 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L46

Added line #L46 was not covered by tests
onnx.helper.make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 1, 2]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
onnx.helper.make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Mul", ["X", "one"], ["Y"]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [None])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None])],
[
onnx.numpy_helper.from_array(
np.array([1], dtype=np.float32), name="one"
)
],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
onnx.helper.make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 0]),
onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 0]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [None, None])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None])],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
]
return models

Check warning on line 87 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L87

Added line #L87 was not covered by tests

def test_llama_p0_rule_set_identity(self):
for model in self._identity_models():
ir_model = ir.serde.deserialize_model(model)
rule_set = llama_rule_sets.llama_p0_rule_set()
rule_set.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)

Check warning on line 94 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L91-L94

Added lines #L91 - L94 were not covered by tests
xadupre marked this conversation as resolved.
Show resolved Hide resolved

self.assertEqual(["Identity"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model, rewritten_model)

Check warning on line 97 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L97

Added line #L97 was not covered by tests

def _transpose_transpose_models(self):
xadupre marked this conversation as resolved.
Show resolved Hide resolved
models = [

Check warning on line 100 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L100

Added line #L100 was not covered by tests
onnx.helper.make_model(
onnx.helper.make_graph(
[
onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]),
onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]),
],
"name",
[onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
[onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
),
opset_imports=[onnx.helper.make_opsetid("", 18)],
),
]
return models

Check warning on line 114 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L114

Added line #L114 was not covered by tests

def test_llama_p0_rule_set_transpose_transpose(self):
for model in self._transpose_transpose_models():
xadupre marked this conversation as resolved.
Show resolved Hide resolved
ir_model = ir.serde.deserialize_model(model)
rule_set = llama_rule_sets.llama_p0_rule_set()
rule_set.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)

Check warning on line 121 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L118-L121

Added lines #L118 - L121 were not covered by tests

self.assertEqual(["Transpose"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model, rewritten_model)

Check warning on line 124 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L124

Added line #L124 was not covered by tests


if __name__ == "__main__":
unittest.main(verbosity=2)

Check warning on line 128 in onnxscript/rewriter/llama_rule_sets_tests.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets_tests.py#L128

Added line #L128 was not covered by tests
Loading