First series of P0 patterns to optimize llama #1490

Merged (17 commits, May 29, 2024)
95 changes: 95 additions & 0 deletions onnxscript/rewriter/llama_rule_sets.py
@@ -0,0 +1,95 @@
from __future__ import annotations

from typing import Sequence

import onnxscript.ir as ir
import onnxscript.rewriter.no_op as no_op
import onnxscript.rewriter.pattern as orp

_op = orp.onnxop


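# Each rule below is assembled from three functions: a pattern to match,
# an optional check that validates the matched attributes, and a rewrite
# that produces the replacement nodes (see the orp.RewriteRule calls near
# the end of this file).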
def transpose_identity(x, perm):
    return _op.Transpose(x, perm=perm)


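# A Transpose whose perm is [0, 1, ..., n-1] (for example perm=[0, 1, 2] on a
# 3-D input) is a no-op: the check below detects this case and the rewrite
# replaces the node with Identity.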
def transpose_identity_check(x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool:
    if isinstance(perm, ir.RefAttr):
        return False
    if perm.type == ir.AttributeType.INTS:
        if perm.value == list(range(len(perm.value))):
            return True
    return False


def transpose_identity_rewrite(op, x: ir.Value, perm: ir.Attr | None = None):
    return op.Identity(x)


def transpose_transpose(x, perm1, perm2):
    return _op.Transpose(_op.Transpose(x, perm=perm1), perm=perm2)


def transpose_transpose_check(
    x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr
) -> bool:
    if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
        return False
    return True


def _apply_transpose(perm: tuple[int, ...], on: Sequence[int]) -> list[int]:
    assert len(perm) == len(on), "length mismatch"
    res = [-1 for _ in on]
    for i, p in enumerate(perm):
        res[i] = on[p]
    return res


def _apply_transposes(
    perms: Sequence[tuple[int, ...]], on: Sequence[int] | None = None
) -> list[int]:
    if on is None:
        on = list(range(len(perms[0])))
    for p in perms:
        on = _apply_transpose(p, on)
    return on
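# Illustrative worked example (not in the original diff): composing
# perm=(1, 2, 0) with itself,
#   _apply_transpose((1, 2, 0), [0, 1, 2]) -> [1, 2, 0]
#   _apply_transpose((1, 2, 0), [1, 2, 0]) -> [2, 0, 1]
# so two Transpose(perm=[1, 2, 0]) nodes collapse into a single
# Transpose(perm=[2, 0, 1]).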


def transpose_transpose_rewrite(op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
    first = list(range(len(perm1.value)))
    last = _apply_transposes([perm1.value, perm2.value])
    if first == last:
        return op.Identity(x)
    return op.Transpose(x, perm=last)


transpose_identity_rule = orp.RewriteRule(
    transpose_identity, transpose_identity_rewrite, transpose_identity_check
)
transpose_transpose_rule = orp.RewriteRule(
    transpose_transpose, transpose_transpose_rewrite, transpose_transpose_check
)


def llama_p0_rule_set(verbose: int = 0) -> orp.RewriteRuleSet:
"""Returns a set of rules which should be applied
before any other one as they usually remove unnecessary computation
such as the multiplication by 1 or two consecutive transpose.

Args:
verbose: verbosity

Returns:
RewriteRuleSet
"""
    return orp.RewriteRuleSet(
        [
            no_op.mul_by_1_rule,
            no_op.add_0_rule,
            no_op.div_by_1_rule,
            transpose_identity_rule,
            transpose_transpose_rule,
        ]
    )
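A minimal usage sketch (mirroring the test file below; the model path is a
placeholder):

    import onnx
    import onnxscript.ir as ir
    import onnxscript.rewriter.llama_rule_sets as llama_rule_sets

    model_proto = onnx.load("model.onnx")  # placeholder path
    ir_model = ir.serde.deserialize_model(model_proto)
    llama_rule_sets.llama_p0_rule_set().apply_to_model(ir_model)
    optimized_proto = ir.serde.serialize_model(ir_model)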
130 changes: 130 additions & 0 deletions onnxscript/rewriter/llama_rule_sets_test.py
@@ -0,0 +1,130 @@
from __future__ import annotations

import unittest
from typing import Any

import numpy as np
import onnx
import onnx.reference

import onnxscript.rewriter.llama_rule_sets as llama_rule_sets
from onnxscript import ir

FLOAT = onnx.TensorProto.FLOAT


class LlamaRuleSetsTest(unittest.TestCase):
    def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
        feeds: dict[str, Any] = {}
        for i in model.graph.input:
            shape = tuple(d + 2 for d in range(len(i.type.tensor_type.shape.dim)))
            if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
                feeds[i.name] = np.random.randn(*shape).astype(np.float32)
            else:
                raise AssertionError(f"Not implemented for input {i}")
        return feeds

    def _check_model(
        self,
        model: onnx.ModelProto,
        optimized_model: onnx.ModelProto,
        feeds: dict[str, Any] | None = None,
        atol: float = 0.0,
        rtol: float = 1e-7,
    ):
        if not feeds:
            feeds = self._get_random_inputs(model)
        ref = onnx.reference.ReferenceEvaluator(model)
        opt = onnx.reference.ReferenceEvaluator(optimized_model)
        expected = ref.run(None, feeds)
        got = opt.run(None, feeds)
        self.assertEqual(len(expected), len(got))
        for a, b in zip(expected, got):
            np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)

    @classmethod
    def _identity_models(cls):
        models = [
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 1, 2]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Mul", ["X", "one"], ["Y"]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None])],
                    [
                        onnx.numpy_helper.from_array(
                            np.array([1], dtype=np.float32), name="one"
                        )
                    ],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 0]),
                        onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 0]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
        ]
        return models

    def test_llama_p0_rule_set_identity(self):
        for model_proto in self._identity_models():
            ir_model = ir.serde.deserialize_model(model_proto)
            rule_set = llama_rule_sets.llama_p0_rule_set()
            rule_set.apply_to_model(ir_model)
            rewritten_model = ir.serde.serialize_model(ir_model)

            self.assertEqual(["Identity"], [n.op_type for n in rewritten_model.graph.node])
            self._check_model(model_proto, rewritten_model)

    @classmethod
    def _transpose_transpose_models(cls):
        models = [
            onnx.helper.make_model(
                onnx.helper.make_graph(
                    [
                        onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]),
                        onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]),
                    ],
                    "name",
                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])],
                ),
                opset_imports=[onnx.helper.make_opsetid("", 18)],
            ),
        ]
        return models

    def test_llama_p0_rule_set_transpose_transpose(self):
        for model_proto in self._transpose_transpose_models():
            ir_model = ir.serde.deserialize_model(model_proto)
            rule_set = llama_rule_sets.llama_p0_rule_set()
            rule_set.apply_to_model(ir_model)
            rewritten_model = ir.serde.serialize_model(ir_model)

            # The two Transpose nodes fuse into a single, non-identity Transpose.
            self.assertEqual(["Transpose"], [n.op_type for n in rewritten_model.graph.node])
            self._check_model(model_proto, rewritten_model)


if __name__ == "__main__":
    unittest.main(verbosity=2)