Skip to content

Commit

Permalink
complete patterns
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Jun 19, 2024
1 parent a52695d commit d926354
Show file tree
Hide file tree
Showing 2 changed files with 935 additions and 0 deletions.
333 changes: 333 additions & 0 deletions onnxscript/rewriter/custom_ops/llm_rule_sets_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,196 @@
op = orp.onnxop


class _CombineBinary(orp.RewriteRuleAsClass):
    """Base class for rules fusing two chained binary element-wise ops.

    Provides the shape-compatibility check shared by the Add/Mul/Sub
    combination rules below (AddAdd, MulMul, AddMul, MulAdd, SubMul, ...).
    """

    @classmethod
    def _same_shape(
        cls, sh1: tuple[int, ...], sh2: tuple[int, ...], broadcast: bool = False
    ) -> bool:
        """Return True if *sh1* and *sh2* are equal or broadcast-compatible.

        With ``broadcast=True``, shapes are left-padded with 1s to the same
        rank, then compared dimension by dimension.  A dimension of 1 may
        match any size, but only on one consistent side: once a shape has
        contributed a non-1 dimension it may no longer broadcast (mixing
        left- and right-broadcasting is rejected).
        """
        if not broadcast:
            return sh1 == sh2
        if len(sh1) != len(sh2):
            # Align ranks by left-padding the shorter shape with 1s.
            rk = max(len(sh1), len(sh2))
            sh1 = (1,) * (rk - len(sh1)) + sh1
            sh2 = (1,) * (rk - len(sh2)) + sh2
        allow_one1 = True
        allow_one2 = True
        for a, b in zip(sh1, sh2):
            if a == b:
                if a != 1:
                    allow_one1 = False
                if b != 1:
                    allow_one2 = False
                continue
            if a == 1 and allow_one1:
                allow_one2 = False
                continue
            if b == 1 and allow_one2:
                allow_one1 = False
                continue
            return False
        return True

    @classmethod
    def check(cls, context, x, y, z) -> bool:
        """Accept the match only when all three inputs have known shapes
        that are pairwise broadcast-compatible."""
        if x.shape is None or y.shape is None or z.shape is None:
            return False
        return cls._same_shape(x.shape, y.shape, broadcast=True) and cls._same_shape(
            y.shape, z.shape, broadcast=True
        )


class CombinedAddAdd1(_CombineBinary):
    """Fuse ``Add(x, Add(y, z))`` into a single AddAdd custom op."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Add(x, op.Add(y, z))

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.AddAdd(x, y, z, domain="ai.onnx.contrib")


class CombinedAddAdd2(CombinedAddAdd1):
    """Same fusion as :class:`CombinedAddAdd1` for ``Add(Add(x, y), z)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Add(op.Add(x, y), z)


class CombinedMulMul1(_CombineBinary):
    """Fuse ``Mul(x, Mul(y, z))`` into a single MulMul custom op."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Mul(x, op.Mul(y, z))

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.MulMul(x, y, z, domain="ai.onnx.contrib")


class CombinedMulMul2(CombinedMulMul1):
    """Same fusion as :class:`CombinedMulMul1` for ``Mul(Mul(x, y), z)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Mul(op.Mul(x, y), z)


class CombinedAddMul1(_CombineBinary):
    """Fuse ``Mul(Add(x, y), z)`` into ``AddMul(x, y, z)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Mul(op.Add(x, y), z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.AddMul(x, y, z, domain="ai.onnx.contrib")


class CombinedAddMul2(_CombineBinary):
    """Fuse ``Mul(x, Add(y, z))`` into ``AddMul(y, z, x)``.

    The operands are reordered so the added pair comes first, matching the
    AddMul custom-op convention (add the first two inputs, then multiply).
    """

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Mul(x, op.Add(y, z))

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.AddMul(y, z, x, domain="ai.onnx.contrib")


class CombinedMulAdd1(_CombineBinary):
    """Fuse ``Add(Mul(x, y), z)`` into ``MulAdd(x, y, z)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Add(op.Mul(x, y), z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.MulAdd(x, y, z, domain="ai.onnx.contrib")


class CombinedMulAdd2(_CombineBinary):
    """Fuse ``Add(x, Mul(y, z))`` into ``MulAdd(y, z, x)``.

    Operands are reordered so the multiplied pair comes first, matching the
    MulAdd custom-op convention (multiply the first two inputs, then add).
    """

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Add(x, op.Mul(y, z))

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.MulAdd(y, z, x, domain="ai.onnx.contrib")


class AddSharedInput1(_CombineBinary):
    """Fuse two Adds sharing their first input — ``Add(x, y), Add(x, z)`` —
    into one two-output ``AddSharedInput(x, y, z)`` custom op."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Add(x, y), op.Add(x, z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.AddSharedInput(x, y, z, domain="ai.onnx.contrib")


class AddSharedInput2(_CombineBinary):
    """Variant of :class:`AddSharedInput1` where the shared input appears
    second in the first Add: ``Add(y, x), Add(x, z)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Add(y, x), op.Add(x, z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.AddSharedInput(x, y, z, domain="ai.onnx.contrib")


class MulSharedInput1(_CombineBinary):
    """Fuse two Muls sharing their first input — ``Mul(x, y), Mul(x, z)`` —
    into one two-output ``MulSharedInput(x, y, z)`` custom op."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Mul(x, y), op.Mul(x, z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.MulSharedInput(x, y, z, domain="ai.onnx.contrib")


class MulSharedInput2(_CombineBinary):
    """Variant of :class:`MulSharedInput1` where the shared input appears
    second in the first Mul: ``Mul(y, x), Mul(x, z)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Mul(y, x), op.Mul(x, z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.MulSharedInput(x, y, z, domain="ai.onnx.contrib")


class CombinedSubMul1(_CombineBinary):
    """Fuse ``Mul(Sub(x, y), z)`` into ``SubMul(x, y, z, negative=0)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Mul(op.Sub(x, y), z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        # negative=0: the subtraction is taken as written, (x - y) * z.
        return op.SubMul(x, y, z, negative=0, domain="ai.onnx.contrib")


class CombinedSubMul2(_CombineBinary):
    """Fuse ``Mul(Sub(y, x), z)`` into ``SubMul(x, y, z, negative=1)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Mul(op.Sub(y, x), z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        # negative=1: the subtraction is reversed, (y - x) * z.
        return op.SubMul(x, y, z, negative=1, domain="ai.onnx.contrib")


class CombinedMulSub1(_CombineBinary):
    """Fuse ``Sub(Mul(x, y), z)`` into ``MulSub(x, y, z, negative=0)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Sub(op.Mul(x, y), z)

    @classmethod
    def rewrite(cls, op, x, y, z):
        # negative=0: x * y - z.
        return op.MulSub(x, y, z, negative=0, domain="ai.onnx.contrib")


class CombinedMulSub2(_CombineBinary):
    """Fuse ``Sub(z, Mul(x, y))`` into ``MulSub(x, y, z, negative=1)``."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Sub(z, op.Mul(x, y))

    @classmethod
    def rewrite(cls, op, x, y, z):
        # negative=1: z - x * y (product subtracted from z).
        return op.MulSub(x, y, z, negative=1, domain="ai.onnx.contrib")


class MaskedScatterNDOfShape(orp.RewriteRuleAsClass):
@classmethod
def pattern(cls, op, shape, indices, updates, tensor, masked, zero, reduction):
Expand Down Expand Up @@ -42,6 +232,97 @@ def rewrite(cls, op, shape, indices, updates, tensor, masked, zero, reduction):
)


class MulSigmoid(orp.RewriteRuleAsClass):
    """Fuse ``Mul(x, Sigmoid(x))`` (SiLU / swish) into MulSigmoid."""

    @classmethod
    def pattern(cls, op, x):
        return op.Mul(x, op.Sigmoid(x))

    @classmethod
    def check(cls, context, x) -> bool:
        # Structural match is sufficient; no extra constraints.
        return True

    @classmethod
    def rewrite(cls, op, x):
        return op.MulSigmoid(x, domain="ai.onnx.contrib")


class NegXPlus1(orp.RewriteRuleAsClass):
    """Fuse ``Sub(1, x)`` — with a scalar constant one — into NegXplus1."""

    @classmethod
    def pattern(cls, op, cst, x):
        return op.Sub(cst, x)

    @classmethod
    def check(cls, context, cst, x) -> bool:
        """Match only when *cst* is a constant of shape ``(1,)`` equal to 1."""
        if cst.const_value is None:
            return False
        if cst.shape != (1,):
            return False
        value = float(cst.const_value.numpy().reshape((1,))[0])
        return value == 1

    @classmethod
    def rewrite(cls, op, cst, x):
        # The constant is folded into the custom op, so only x is passed on.
        return op.NegXplus1(x, domain="ai.onnx.contrib")


class ReplaceZero1(orp.RewriteRuleAsClass):
    """Fuse ``Where(Cast(x, BOOL), x, y)`` into ``ReplaceZero(x, y, equal=1)``.

    Keeps x where it is non-zero and substitutes y where x is zero.
    """

    @classmethod
    def pattern(cls, op, x, y):
        return op.Where(op.Cast(x, to=onnx.TensorProto.BOOL), x, y)

    @classmethod
    def check(cls, context, x, y) -> bool:
        # Structural match is sufficient; no extra constraints.
        return True

    @classmethod
    def rewrite(cls, op, x, y):
        return op.ReplaceZero(x, y, equal=1, domain="ai.onnx.contrib")


class ReplaceZero2(orp.RewriteRuleAsClass):
    """Fuse ``Where(Cast(x, BOOL), y, x)`` into ``ReplaceZero(x, y, equal=0)``.

    Substitutes y where x is non-zero and keeps x (i.e. zero) elsewhere —
    the mirror of :class:`ReplaceZero1`.
    """

    @classmethod
    def pattern(cls, op, x, y):
        return op.Where(op.Cast(x, to=onnx.TensorProto.BOOL), y, x)

    @classmethod
    def check(cls, context, x, y) -> bool:
        # Structural match is sufficient; no extra constraints.
        return True

    @classmethod
    def rewrite(cls, op, x, y):
        return op.ReplaceZero(x, y, equal=0, domain="ai.onnx.contrib")


class Rotary1(orp.RewriteRuleAsClass):
    """Fuse the rotate-half pattern ``Concat(Neg(x2), x1)`` (where x1, x2 is
    an even Split of x on the last axis) into ``Rotary(x, side="left")``."""

    @classmethod
    def pattern(cls, op, x, y):
        # NOTE(review): parameter *y* is unused and check/rewrite only take
        # *x* — confirm the rewriter API tolerates this signature mismatch.
        x1, x2 = op.Split(x, num_outputs=2, axis=-1, outputs=2)
        return op.Concat(op.Neg(x2), x1, axis=-1)

    @classmethod
    def check(cls, context, x) -> bool:
        # Structural match is sufficient; no extra constraints.
        return True

    @classmethod
    def rewrite(cls, op, x):
        return op.Rotary(x, side="left", domain="ai.onnx.contrib")


class Rotary2(orp.RewriteRuleAsClass):
    """Fuse the mirrored rotate-half pattern ``Concat(x2, Neg(x1))`` into
    ``Rotary(x, side="right")`` — see :class:`Rotary1`."""

    @classmethod
    def pattern(cls, op, x, y):
        # NOTE(review): parameter *y* is unused and check/rewrite only take
        # *x* — confirm the rewriter API tolerates this signature mismatch.
        x1, x2 = op.Split(x, num_outputs=2, axis=-1, outputs=2)
        return op.Concat(x2, op.Neg(x1), axis=-1)

    @classmethod
    def check(cls, context, x) -> bool:
        # Structural match is sufficient; no extra constraints.
        return True

    @classmethod
    def rewrite(cls, op, x):
        return op.Rotary(x, side="right", domain="ai.onnx.contrib")


class TransposeCast1(orp.RewriteRuleAsClass):
"""Replaces ``Cast + Transpose(. perm=[1, 0])`` by ``TransposeCast2D``."""

Expand Down Expand Up @@ -90,6 +371,34 @@ def rewrite(cls, op, x, perm, to):
return op.Transpose2DCastFP16(x, domain="ai.onnx.contrib")


class MulAddTranspose(orp.RewriteRuleAsClass):
    """Absorb a trailing ``Transpose(perm=[0, 2, 1, 3])`` into a fused
    MulAdd by setting its ``transposeMiddle`` attribute."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Transpose(op.MulAdd(x, y, z, domain="ai.onnx.contrib"), perm=[0, 2, 1, 3])

    @classmethod
    def check(cls, context, x, y, z) -> bool:
        # Structural match is sufficient; no extra constraints.
        return True

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.MulAdd(x, y, z, transposeMiddle=1, domain="ai.onnx.contrib")


class AddMulTranspose(orp.RewriteRuleAsClass):
    """Absorb a trailing ``Transpose(perm=[0, 2, 1, 3])`` into a fused
    AddMul by setting its ``transposeMiddle`` attribute."""

    @classmethod
    def pattern(cls, op, x, y, z):
        return op.Transpose(op.AddMul(x, y, z, domain="ai.onnx.contrib"), perm=[0, 2, 1, 3])

    @classmethod
    def check(cls, context, x, y, z) -> bool:
        # Structural match is sufficient; no extra constraints.
        return True

    @classmethod
    def rewrite(cls, op, x, y, z):
        return op.AddMul(x, y, z, transposeMiddle=1, domain="ai.onnx.contrib")


def llm_rule_set_cuda() -> orp.RewriteRuleSet:
"""Returns a set of rules fusing nodes into custom kernels.
Expand All @@ -98,7 +407,31 @@ def llm_rule_set_cuda() -> orp.RewriteRuleSet:
"""
return orp.RewriteRuleSet(
[
# orp.make_rewrite_rule_from_class(AddSharedInput1, True),
# orp.make_rewrite_rule_from_class(AddSharedInput2, True),
# orp.make_rewrite_rule_from_class(MulSharedInput1, True),
# orp.make_rewrite_rule_from_class(MulSharedInput2, True),
orp.make_rewrite_rule_from_class(AddMulTranspose),
orp.make_rewrite_rule_from_class(CombinedAddAdd1),
orp.make_rewrite_rule_from_class(CombinedAddAdd2),
orp.make_rewrite_rule_from_class(CombinedAddMul1),
orp.make_rewrite_rule_from_class(CombinedAddMul2),
orp.make_rewrite_rule_from_class(CombinedMulAdd1),
orp.make_rewrite_rule_from_class(CombinedMulAdd2),
orp.make_rewrite_rule_from_class(CombinedMulMul1),
orp.make_rewrite_rule_from_class(CombinedMulMul2),
orp.make_rewrite_rule_from_class(CombinedMulSub1),
orp.make_rewrite_rule_from_class(CombinedMulSub2),
orp.make_rewrite_rule_from_class(CombinedSubMul1),
orp.make_rewrite_rule_from_class(CombinedSubMul2),
orp.make_rewrite_rule_from_class(MaskedScatterNDOfShape),
orp.make_rewrite_rule_from_class(MulAddTranspose),
orp.make_rewrite_rule_from_class(MulSigmoid),
orp.make_rewrite_rule_from_class(NegXPlus1),
orp.make_rewrite_rule_from_class(ReplaceZero1),
orp.make_rewrite_rule_from_class(ReplaceZero2),
orp.make_rewrite_rule_from_class(Rotary1),
orp.make_rewrite_rule_from_class(Rotary2),
orp.make_rewrite_rule_from_class(TransposeCast1),
orp.make_rewrite_rule_from_class(TransposeCast2),
]
Expand Down
Loading

0 comments on commit d926354

Please sign in to comment.