diff --git a/onnxscript/rewriter/custom_ops/llm_rule_sets.py b/onnxscript/rewriter/custom_ops/llm_rule_sets_cuda.py
similarity index 60%
rename from onnxscript/rewriter/custom_ops/llm_rule_sets.py
rename to onnxscript/rewriter/custom_ops/llm_rule_sets_cuda.py
index 964137709..a9e3bd056 100644
--- a/onnxscript/rewriter/custom_ops/llm_rule_sets.py
+++ b/onnxscript/rewriter/custom_ops/llm_rule_sets_cuda.py
@@ -10,6 +10,38 @@
 op = orp.onnxop
 
 
+class MaskedScatterNDOfShape(orp.RewriteRuleAsClass):
+    @classmethod
+    def pattern(cls, op, shape, indices, updates, tensor, masked, zero, reduction):
+        cst = op.ConstantOfShape(shape, value=tensor)
+        masked_indices = op.Equal(indices, masked)
+        masked_updates = op.Where(masked_indices, zero, updates)
+        return op.ScatterND(cst, indices, masked_updates, reduction=reduction)
+
+    @classmethod
+    def check(cls, context, shape, indices, updates, tensor, masked, zero, reduction) -> bool:
+        if reduction.value != "add":
+            return False
+        if tensor.value.numpy().reshape((1,)).tolist() != [0]:
+            return False
+        if zero.const_value is None or zero.const_value.numpy().reshape((1,)).tolist() != [0]:
+            return False
+        if masked.const_value is None or masked.const_value.numpy().size != 1:
+            return False
+        return True
+
+    @classmethod
+    def rewrite(cls, op, shape, indices, updates, tensor, masked, zero, reduction):
+        return op.MaskedScatterNDOfShape(
+            shape,
+            indices,
+            updates,
+            maskedValue=int(masked.const_value.numpy().reshape((1,))[0]),
+            reduction=reduction.value,
+            domain="ai.onnx.contrib",
+        )
+
+
 class TransposeCast1(orp.RewriteRuleAsClass):
     """Replaces ``Cast + Transpose(. perm=[1, 0])`` by ``TransposeCast2D``."""
 
@@ -58,7 +90,7 @@ def rewrite(cls, op, x, perm, to):
         return op.Transpose2DCastFP16(x, domain="ai.onnx.contrib")
 
 
-def llm_rule_set() -> orp.RewriteRuleSet:
+def llm_rule_set_cuda() -> orp.RewriteRuleSet:
     """Returns a set of rules fusing nodes into custom kernels.
 
     Returns:
@@ -66,6 +98,7 @@ def llm_rule_set() -> orp.RewriteRuleSet:
     """
     return orp.RewriteRuleSet(
         [
+            orp.make_rewrite_rule_from_class(MaskedScatterNDOfShape),
            orp.make_rewrite_rule_from_class(TransposeCast1),
            orp.make_rewrite_rule_from_class(TransposeCast2),
        ]
diff --git a/onnxscript/rewriter/custom_ops/llm_rule_sets_test.py b/onnxscript/rewriter/custom_ops/llm_rule_sets_cuda_test.py
similarity index 51%
rename from onnxscript/rewriter/custom_ops/llm_rule_sets_test.py
rename to onnxscript/rewriter/custom_ops/llm_rule_sets_cuda_test.py
index dc00f60cd..db49c9ccc 100644
--- a/onnxscript/rewriter/custom_ops/llm_rule_sets_test.py
+++ b/onnxscript/rewriter/custom_ops/llm_rule_sets_cuda_test.py
@@ -8,17 +8,38 @@
 import numpy as np
 import onnx
 import onnx.helper as oh
+import onnx.numpy_helper as onh
 import onnx.reference
 import onnx.reference.op_run as op_run
+import onnx.reference.ops.op_scatternd as op_scat
 import parameterized
 
-import onnxscript.rewriter.custom_ops.llm_rule_sets as llm_rule_sets
+import onnxscript.rewriter.custom_ops.llm_rule_sets_cuda as llm_rule_sets_cuda
 from onnxscript import ir
 
 TFLOAT = onnx.TensorProto.FLOAT
 TFLOAT16 = onnx.TensorProto.FLOAT16
 
 
+class MaskedScatterNDOfShape(op_run.OpRun):
+    op_domain = "ai.onnx.contrib"
+
+    def _run(self, shape, indices, updates, reduction=None, maskedValue=None):
+        data = np.zeros(shape, dtype=updates.dtype)
+        new_updates = np.where(indices == maskedValue, 0, updates)
+        y = op_scat._scatter_nd_impl(data, indices, new_updates, reduction=reduction)
+        return (y,)
+
+
+class ScatterNDOfShape(op_run.OpRun):
+    op_domain = "ai.onnx.contrib"
+
+    def _run(self, shape, indices, updates, reduction=None, strategy=None):
+        data = np.zeros(shape, dtype=updates.dtype)
+        y = op_scat._scatter_nd_impl(data, indices, updates, reduction=reduction)
+        return (y,)
+
+
 class Transpose2DCastFP16(op_run.OpRun):
     op_domain = "ai.onnx.contrib"
@@ -48,6 +69,12 @@ def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
                 feeds[i.name] = np.random.randn(*shape).astype(np.float32)
             elif i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT16:
                 feeds[i.name] = np.random.randn(*shape).astype(np.float16)
+            elif i.type.tensor_type.elem_type == onnx.TensorProto.INT64:
+                if tuple(shape) == (2,):
+                    feeds[i.name] = np.array([7, 5], dtype=np.int64)
+                else:
+                    feeds[i.name] = np.zeros(tuple(shape)).astype(np.int64)
+                    feeds[i.name][::2] = 1
             else:
                 raise AssertionError(f"Not implemented for input {i}")
         return feeds
@@ -64,7 +91,13 @@ def _check_model(
         feeds = self._get_random_inputs(model)
         ref = onnx.reference.ReferenceEvaluator(model)
         opt = onnx.reference.ReferenceEvaluator(
-            optimized_model, new_ops=[Transpose2DCastFP16, Transpose2DCastFP32]
+            optimized_model,
+            new_ops=[
+                ScatterNDOfShape,
+                MaskedScatterNDOfShape,
+                Transpose2DCastFP16,
+                Transpose2DCastFP32,
+            ],
         )
         expected = ref.run(None, feeds)
         got = opt.run(None, feeds)
@@ -110,7 +143,7 @@ def _transpose_cast(cls, in_type, cast_before):
     def test_llm_transpose_cast(self, in_type, cast_before):
         model_proto = self._transpose_cast(in_type, cast_before)
         ir_model = ir.serde.deserialize_model(model_proto)
-        rule_set = llm_rule_sets.llm_rule_set()
+        rule_set = llm_rule_sets_cuda.llm_rule_set_cuda()
         rule_set.apply_to_model(ir_model)
         rewritten_model = ir.serde.serialize_model(ir_model)
 
@@ -120,6 +153,61 @@ def test_llm_transpose_cast(self, in_type, cast_before):
         )
         self._check_model(model_proto, rewritten_model)
 
+    @classmethod
+    def _masked_scatternd_of_shape(cls, reduction, itype):
+        dtype = np.float32 if itype == onnx.TensorProto.FLOAT else np.float16
+
+        return oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node(
+                        "ConstantOfShape",
+                        ["shape"],
+                        ["data"],
+                        value=onh.from_array(np.array([0], dtype=np.float32)),
+                    ),
+                    oh.make_node("Equal", ["indices", "mone"], ["masked_indices"]),
+                    oh.make_node(
+                        "Where",
+                        ["masked_indices", "zero", "updates"],
+                        ["masked_updates"],
+                    ),
+                    oh.make_node(
+                        "ScatterND",
+                        inputs=["data", "indices", "masked_updates"],
+                        outputs=["y"],
+                        reduction=reduction,
+                    ),
+                ],
+                "nd",
+                [
+                    oh.make_tensor_value_info("shape", onnx.TensorProto.INT64, [2]),
+                    oh.make_tensor_value_info("indices", onnx.TensorProto.INT64, [5, 3, 1]),
+                    oh.make_tensor_value_info("updates", itype, [5, 3, 5]),
+                ],
+                [oh.make_tensor_value_info("y", itype, [None, None])],
+                [
+                    onh.from_array(np.array([-1], dtype=np.int64), name="mone"),
+                    onh.from_array(np.array([0], dtype=dtype), name="zero"),
+                ],
+            ),
+            opset_imports=[oh.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+    @parameterized.parameterized.expand([("add", TFLOAT), ("add", TFLOAT16)])
+    def test_llm_masked_scatter(self, reduction, itype):
+        model_proto = self._masked_scatternd_of_shape(reduction, itype)
+        ir_model = ir.serde.deserialize_model(model_proto)
+        rule_set = llm_rule_sets_cuda.llm_rule_set_cuda()
+        rule_set.apply_to_model(ir_model)
+        rewritten_model = ir.serde.serialize_model(ir_model)
+
+        self.assertEqual(
+            ["MaskedScatterNDOfShape"], [n.op_type for n in rewritten_model.graph.node]
+        )
+        self._check_model(model_proto, rewritten_model, atol=1e-2)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
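
For reference, the new `MaskedScatterNDOfShape` rule collapses the four-node subgraph `ConstantOfShape -> Equal -> Where -> ScatterND(reduction="add")` into a single kernel call. Below is a minimal NumPy sketch of the identity the rule relies on; it mirrors the reference implementation added to the test file. The helper names `unfused_graph` and `fused_kernel` are illustrative only, not part of the onnxscript API.

```python
import numpy as np


def unfused_graph(shape, indices, updates, masked_value=-1):
    """Original subgraph: ConstantOfShape -> Equal -> Where -> ScatterND(add)."""
    data = np.zeros(shape, dtype=updates.dtype)                   # ConstantOfShape(value=0)
    masked = indices == masked_value                              # Equal(indices, mone)
    upd = np.where(masked, np.array(0, updates.dtype), updates)   # Where(masked, zero, updates)
    # ScatterND with reduction="add": accumulate each update row at its index.
    for idx, u in zip(indices.reshape(-1, indices.shape[-1]),
                      upd.reshape(-1, upd.shape[-1])):
        data[tuple(idx)] += u
    return data


def fused_kernel(shape, indices, updates, masked_value=-1):
    """Semantics of the single MaskedScatterNDOfShape op: masked rows add zero."""
    data = np.zeros(shape, dtype=updates.dtype)
    upd = np.where(indices == masked_value, 0, updates)
    for idx, u in zip(indices.reshape(-1, indices.shape[-1]),
                      upd.reshape(-1, upd.shape[-1])):
        data[tuple(idx)] += u
    return data


indices = np.array([[0], [-1], [2]], dtype=np.int64)   # -1 flags rows to ignore
updates = np.random.randn(3, 5).astype(np.float32)
assert np.allclose(unfused_graph((4, 5), indices, updates),
                   fused_kernel((4, 5), indices, updates))
```

The fused form avoids materializing the boolean mask and the intermediate `masked_updates` tensor. Note that the rule's `check` method only fires when `reduction == "add"` and both constants are zero, which is exactly what makes the masked rows side-effect-free: they still scatter, but contribute nothing.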