
Commit a52695d: rename
Signed-off-by: Xavier Dupre <[email protected]>
xadupre committed Jun 19, 2024
1 parent 9b5617e commit a52695d
Showing 2 changed files with 125 additions and 4 deletions.
onnxscript/rewriter/custom_ops/llm_rule_sets_cuda.py
@@ -10,6 +10,38 @@
op = orp.onnxop


class MaskedScatterNDOfShape(orp.RewriteRuleAsClass):
    @classmethod
    def pattern(cls, op, shape, indices, updates, tensor, masked, zero, reduction):
        cst = op.ConstantOfShape(shape, value=tensor)
        masked_indices = op.Equal(indices, masked)
        masked_updates = op.Where(masked_indices, zero, updates)
        return op.ScatterND(cst, indices, masked_updates, reduction=reduction)

    @classmethod
    def check(cls, context, shape, indices, updates, tensor, masked, zero, reduction) -> bool:
        if reduction.value != "add":
            return False

        if tensor.value.numpy().reshape((1,)).tolist() != [0]:
            return False

        if zero.const_value is None or zero.const_value.numpy().reshape((1,)).tolist() != [0]:
            return False

        if masked.const_value is None or masked.const_value.numpy().size != 1:
            return False

        return True

    @classmethod
    def rewrite(cls, op, shape, indices, updates, tensor, masked, zero, reduction):
        return op.MaskedScatterNDOfShape(
            shape,
            indices,
            updates,
            maskedValue=int(masked.const_value.numpy().reshape((1,))[0]),
            reduction=reduction.value,
            domain="ai.onnx.contrib",
        )
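
In effect, the rule collapses four standard ONNX nodes into a single custom node. Schematically (a sketch for orientation, not code from this commit):

    # Matched subgraph (pattern):
    #   data = ConstantOfShape(shape, value=0)
    #   mask = Equal(indices, masked)            # masked: scalar constant
    #   upd  = Where(mask, zero, updates)        # zero out masked updates
    #   y    = ScatterND(data, indices, upd, reduction="add")
    #
    # Replacement (rewrite), with maskedValue read from the matched constant:
    #   y = MaskedScatterNDOfShape(shape, indices, updates,
    #                              maskedValue=..., reduction="add",
    #                              domain="ai.onnx.contrib")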


class TransposeCast1(orp.RewriteRuleAsClass):
"""Replaces ``Cast + Transpose(. perm=[1, 0])`` by ``TransposeCast2D``."""

@@ -58,14 +90,15 @@ def rewrite(cls, op, x, perm, to):
        return op.Transpose2DCastFP16(x, domain="ai.onnx.contrib")
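
For intuition, the fused kernel combines a 2D transpose with a cast in a single pass over the data; in numpy terms (a sketch, assuming the FP16 suffix names the output type):

    import numpy as np

    x = np.random.randn(4, 8).astype(np.float32)
    y = x.T.astype(np.float16)  # what a Transpose2DCastFP16 node is expected to produce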


-def llm_rule_set() -> orp.RewriteRuleSet:
+def llm_rule_set_cuda() -> orp.RewriteRuleSet:
    """Returns a set of rules fusing nodes into custom kernels.

    Returns:
        RewriteRuleSet
    """
    return orp.RewriteRuleSet(
        [
            orp.make_rewrite_rule_from_class(MaskedScatterNDOfShape),
            orp.make_rewrite_rule_from_class(TransposeCast1),
            orp.make_rewrite_rule_from_class(TransposeCast2),
        ]
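For context, a minimal usage sketch of how this rule set is applied to a model, mirroring the tests below (the model path is a placeholder):

    import onnx
    import onnxscript.rewriter.custom_ops.llm_rule_sets_cuda as llm_rule_sets_cuda
    from onnxscript import ir

    model_proto = onnx.load("model.onnx")  # placeholder path
    ir_model = ir.serde.deserialize_model(model_proto)
    llm_rule_sets_cuda.llm_rule_set_cuda().apply_to_model(ir_model)
    rewritten = ir.serde.serialize_model(ir_model)  # matched subgraphs now use ai.onnx.contrib ops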
onnxscript/rewriter/custom_ops/llm_rule_sets_cuda_test.py
@@ -8,17 +8,38 @@
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import onnx.reference
import onnx.reference.op_run as op_run
import onnx.reference.ops.op_scatternd as op_scat
import parameterized

-import onnxscript.rewriter.custom_ops.llm_rule_sets as llm_rule_sets
+import onnxscript.rewriter.custom_ops.llm_rule_sets_cuda as llm_rule_sets_cuda
from onnxscript import ir

TFLOAT = onnx.TensorProto.FLOAT
TFLOAT16 = onnx.TensorProto.FLOAT16


class MaskedScatterNDOfShape(op_run.OpRun):
    op_domain = "ai.onnx.contrib"

    def _run(self, shape, indices, updates, reduction=None, maskedValue=None):
        data = np.zeros(shape, dtype=updates.dtype)
        new_updates = np.where(indices == maskedValue, 0, updates)
        y = op_scat._scatter_nd_impl(data, indices, new_updates, reduction=reduction)
        return (y,)


class ScatterNDOfShape(op_run.OpRun):
    op_domain = "ai.onnx.contrib"

    def _run(self, shape, indices, updates, reduction=None, strategy=None):
        data = np.zeros(shape, dtype=updates.dtype)
        y = op_scat._scatter_nd_impl(data, indices, updates, reduction=reduction)
        return (y,)
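
A small worked example (values assumed) of the masked scatter-add these reference classes implement: updates whose index equals maskedValue are zeroed before being scattered, so masked entries contribute nothing.

    import numpy as np

    shape = (4, 3)
    indices = np.array([[0], [2], [-1]], dtype=np.int64)  # -1 marks masked entries
    updates = np.ones((3, 3), dtype=np.float32)

    data = np.zeros(shape, dtype=updates.dtype)
    masked_updates = np.where(indices == -1, 0, updates)  # zero out masked rows
    for pos, upd in zip(indices.reshape(-1), masked_updates):
        data[pos] += upd  # reduction="add"
    # rows 0 and 2 of data are now ones; the masked entry added only zeros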



class Transpose2DCastFP16(op_run.OpRun):
    op_domain = "ai.onnx.contrib"

@@ -48,6 +69,12 @@ def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
                feeds[i.name] = np.random.randn(*shape).astype(np.float32)
            elif i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT16:
                feeds[i.name] = np.random.randn(*shape).astype(np.float16)
            elif i.type.tensor_type.elem_type == onnx.TensorProto.INT64:
                if tuple(shape) == (2,):
                    feeds[i.name] = np.array([7, 5], dtype=np.int64)
                else:
                    feeds[i.name] = np.zeros(tuple(shape)).astype(np.int64)
                    feeds[i.name][::2] = 1
            else:
                raise AssertionError(f"Not implemented for input {i}")

        return feeds
@@ -64,7 +91,13 @@ def _check_model(
        feeds = self._get_random_inputs(model)
        ref = onnx.reference.ReferenceEvaluator(model)
        opt = onnx.reference.ReferenceEvaluator(
-            optimized_model, new_ops=[Transpose2DCastFP16, Transpose2DCastFP32]
+            optimized_model,
+            new_ops=[
+                ScatterNDOfShape,
+                MaskedScatterNDOfShape,
+                Transpose2DCastFP16,
+                Transpose2DCastFP32,
+            ],
        )
        expected = ref.run(None, feeds)
        got = opt.run(None, feeds)
@@ -110,7 +143,7 @@ def _transpose_cast(cls, in_type, cast_before):
    def test_llm_transpose_cast(self, in_type, cast_before):
        model_proto = self._transpose_cast(in_type, cast_before)
        ir_model = ir.serde.deserialize_model(model_proto)
-        rule_set = llm_rule_sets.llm_rule_set()
+        rule_set = llm_rule_sets_cuda.llm_rule_set_cuda()
        rule_set.apply_to_model(ir_model)
        rewritten_model = ir.serde.serialize_model(ir_model)

@@ -120,6 +153,61 @@ def test_llm_transpose_cast(self, in_type, cast_before):
        )
        self._check_model(model_proto, rewritten_model)

    @classmethod
    def _masked_scatternd_of_shape(cls, reduction, itype):
        dtype = np.float32 if itype == onnx.TensorProto.FLOAT else np.float16

        return oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "ConstantOfShape",
                        ["shape"],
                        ["data"],
                        value=onh.from_array(np.array([0], dtype=np.float32)),
                    ),
                    oh.make_node("Equal", ["indices", "mone"], ["masked_indices"]),
                    oh.make_node(
                        "Where",
                        ["masked_indices", "zero", "updates"],
                        ["masked_updates"],
                    ),
                    oh.make_node(
                        "ScatterND",
                        inputs=["data", "indices", "masked_updates"],
                        outputs=["y"],
                        reduction=reduction,
                    ),
                ],
                "nd",
                [
                    oh.make_tensor_value_info("shape", onnx.TensorProto.INT64, [2]),
                    oh.make_tensor_value_info("indices", onnx.TensorProto.INT64, [5, 3, 1]),
                    oh.make_tensor_value_info("updates", itype, [5, 3, 5]),
                ],
                [oh.make_tensor_value_info("y", itype, [None, None])],
                [
                    onh.from_array(np.array([-1], dtype=np.int64), name="mone"),
                    onh.from_array(np.array([0], dtype=dtype), name="zero"),
                ],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
            ir_version=9,
        )
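
One way to exercise the model this helper builds (a sketch, not part of the commit), using onnx's reference evaluator with feeds shaped like those from _get_random_inputs:

    import numpy as np
    import onnx.reference

    # model_proto = cls._masked_scatternd_of_shape("add", onnx.TensorProto.FLOAT)
    indices = np.zeros((5, 3, 1), dtype=np.int64)
    indices[::2] = 1
    feeds = {
        "shape": np.array([7, 5], dtype=np.int64),
        "indices": indices,
        "updates": np.random.randn(5, 3, 5).astype(np.float32),
    }
    # y = onnx.reference.ReferenceEvaluator(model_proto).run(None, feeds)[0]
    # y has shape (7, 5); any index equal to -1 would contribute nothing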

    @parameterized.parameterized.expand([("add", TFLOAT), ("add", TFLOAT16)])
    def test_llm_masked_scatter(self, reduction, itype):
        model_proto = self._masked_scatternd_of_shape(reduction, itype)
        ir_model = ir.serde.deserialize_model(model_proto)
        rule_set = llm_rule_sets_cuda.llm_rule_set_cuda()
        rule_set.apply_to_model(ir_model)
        rewritten_model = ir.serde.serialize_model(ir_model)

        self.assertEqual(
            ["MaskedScatterNDOfShape"], [n.op_type for n in rewritten_model.graph.node]
        )
        self._check_model(model_proto, rewritten_model, atol=1e-2)


if __name__ == "__main__":
    unittest.main(verbosity=2)

