From 1ee4ee1bc926bed94570720432c0e68db7d57f36 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 10 Oct 2023 17:00:23 -0700 Subject: [PATCH] Implement rand* ops | feat(torchilb) (#1035) Implement rand* ops. I split all ops that can take a dtype into two overloads to avoid trace_only logic. Replaces https://github.com/microsoft/onnxscript/pull/875 --- .../function_libs/torch_lib/ops/core.py | 129 ++++++++-- .../function_libs/torch_lib/extra_opinfo.py | 238 ++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 37 ++- 3 files changed, 372 insertions(+), 32 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4317ee257..7513cf7bc 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6232,47 +6232,138 @@ def aten_rad2deg(self: TensorType) -> TensorType: @torch_op("aten::rand") -def aten_rand(size: Sequence[int], dtype: int = 1) -> TReal: +def aten_rand(size: INT64, dtype: int = FLOAT.dtype) -> TReal: """rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - return op.RandomUniform(shape=size, dtype=dtype) + shaper = op.ConstantOfShape(size) + return op.RandomUniformLike(shaper, dtype=dtype) -def aten_rand_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType: +@torch_op("aten::rand_like") +def aten_rand_like(self: TFloat) -> TFloat: """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - raise NotImplementedError() + return op.RandomUniformLike(self) + +@torch_op("aten::rand_like") +def aten_rand_like_dtype(self: TensorType, dtype: int) -> TensorType: + """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" -def aten_randint(high: int, size: INT64) -> TensorType: - """randint(int high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" + return op.RandomUniformLike(self, dtype=dtype) - raise NotImplementedError() +@torch_op("aten::randint") +def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorType: + """randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" -def aten_randint_like( - self: TensorType, high: int, memory_format: Optional[str] = None + shaper = op.ConstantOfShape(size) + rand = op.RandomUniformLike(shaper) + # Scale to [0, high] first + rand_scaled = op.Mul(rand, op.CastLike(high, rand)) + # Round to ints + rand_int = op.Floor(rand_scaled) + return op.Cast(rand_int, to=dtype) + + +@torch_op("aten::randint.low") +def aten_randint_low( + low: INT64, high: INT64, size: INT64, dtype: int = INT64.dtype ) -> TensorType: - """randint_like(Tensor self, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + """randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" - raise NotImplementedError() + shaper = op.ConstantOfShape(size) + rand = op.RandomUniformLike(shaper) + # Translate to [low, high] first + high = op.Cast(high, to=FLOAT.dtype) + low = op.Cast(low, to=FLOAT.dtype) + rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low) + # Round to ints + rand_int = op.Floor(rand_translated) + return op.Cast(rand_int, to=dtype) + + +@torch_op("aten::randint_like") +def aten_randint_like(self: TensorType, high: INT64) -> IntType: + """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + + self_float = op.Cast(self, to=FLOAT.dtype) + rand = op.RandomUniformLike(self_float) + # Scale to [0, high] first + rand_scaled = op.Mul(rand, op.CastLike(high, rand)) + # Round to ints + rand_int = op.Floor(rand_scaled) + return op.CastLike(rand_int, self) + + +@torch_op("aten::randint_like") +def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> TensorType: + """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + + self_float = op.Cast(self, to=FLOAT.dtype) + rand = op.RandomUniformLike(self_float) + # Scale to [0, high] first + rand_scaled = op.Mul(rand, op.CastLike(high, rand)) + # Round to ints + rand_int = op.Floor(rand_scaled) + return op.Cast(rand_int, to=dtype) + + +@torch_op("aten::randint_like.low_dtype") +def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> IntType: + """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + + This is the TorchLib overload for aten::randint_like.low_dtype when dtype is None. + """ + + self_float = op.Cast(self, to=FLOAT.dtype) + rand = op.RandomUniformLike(self_float) + # Translate to [low, high] first + high = op.Cast(high, to=FLOAT.dtype) + low = op.Cast(low, to=FLOAT.dtype) + rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low) + # Round to ints + rand_int = op.Floor(rand_translated) + return op.CastLike(rand_int, self) + + +@torch_op("aten::randint_like.low_dtype") +def aten_randint_like_low_dtype_dtype( + self: TensorType, low: INT64, high: INT64, dtype: int +) -> TensorType: + """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + + self_float = op.Cast(self, to=FLOAT.dtype) + rand = op.RandomUniformLike(self_float) + # Translate to [low, high] first + high = op.Cast(high, to=FLOAT.dtype) + low = op.Cast(low, to=FLOAT.dtype) + rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low) + # Round to ints + rand_int = op.Floor(rand_translated) + return op.Cast(rand_int, to=dtype) @torch_op("aten::randn") -def aten_randn( - size: Sequence[int], - dtype: int = 1, - requires_grad: bool = False, # pylint: disable=unused-argument -) -> TReal: +def aten_randn(size: INT64, dtype: int = FLOAT.dtype) -> TReal: """randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" - return op.RandomNormal(dtype=dtype, shape=size) + shaper = op.ConstantOfShape(size) + return op.RandomNormalLike(shaper, dtype=dtype) -def aten_randn_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType: +@torch_op("aten::randn_like") +def aten_randn_like(self: TFloat) -> TFloat: """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - raise NotImplementedError() + return op.RandomNormalLike(self) + + +@torch_op("aten::randn_like") +def aten_randn_like_dtype(self: TensorType, dtype: int) -> TensorType: + """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + + return op.RandomNormalLike(self, dtype=dtype) def aten_randperm(n: int) -> TensorType: diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 2228a079a..c274d95f9 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -516,6 +516,156 @@ def sample_inputs_native_dropout( yield opinfo_core.SampleInput(make_arg(case), p=p, train=training) +def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs): + del op_info # Unused + del device # Unused + del requires_grad # Unused + del kwargs # Unused + + shapes = ( + (M,), + (S, S), + (S, S, S), + ) + + for shape in shapes: + yield opinfo_core.SampleInput(shape, kwargs=dict(dtype=dtype)) + + +def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs): + del op_info # Unused + del kwargs # Unused + + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + shapes = ( + (M,), + (S, S), + (S, S, S), + ) + + for shape in shapes: + yield opinfo_core.SampleInput(make_arg(shape)) + + +def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs): + del op_info # Unused + del kwargs # Unused + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=torch.float32, + requires_grad=requires_grad, + ) + shapes = ( + (M,), + (S, S), + (S, S, S), + ) + + for shape in shapes: + yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype)) + + +def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + del self # Unused + + inputs = [ + ((), {}), + ((S, S), {}), + ((0, S, 0), {}), + ((S,), {}), + ] + for shape, kwargs in inputs: + t = torch_testing.make_tensor( + shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad + ) + yield opinfo_core.SampleInput(t, **kwargs) + + +def sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): + del self # Unused + + inputs = [ + ((S,), {"dtype": dtype}), + # Hard-code some dtypes/devices. 
We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), {"dtype": torch.double}), + ] + for shape, kwargs in inputs: + t = torch_testing.make_tensor( + shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad + ) + yield opinfo_core.SampleInput(t, **kwargs) + + +def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With high + yield opinfo_core.SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs) + + +def sample_inputs_randint_low(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With low and high + yield opinfo_core.SampleInput( + low, high, sample.input.shape, *sample.args, **sample.kwargs + ) + + +def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With high + yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) + + +def sample_inputs_randint_like_dtype(self, device, dtype, requires_grad, **kwargs): + high = 10 + + for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): + # With low and high + yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) + + +def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With low and high + yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) + + +def sample_inputs_randint_like_low_dtype_dtype(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): + # With low and high + yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) + + +def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): + del op # Unused + del device # Unused + del requires_grad # Unused + del kwargs # Unused + + shapes = ((M,), (S, S)) + + for shape in shapes: + yield opinfo_core.SampleInput(input=shape, kwargs=dict(dtype=dtype)) + + def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -1133,6 +1283,94 @@ def sample_inputs_scaled_dot_product_flash_attention( sample_inputs_func=sample_inputs_max_pool3d_with_indices, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.rand", + aten_name="rand", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_rand, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.rand_like", + aten_name="rand_like", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_rand_like, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.rand_like__dtype", + op=torch.ops.aten.rand_like, + aten_name="rand_like", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_rand_like_dtype, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.randint", + aten_name="randint", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_randint, + supports_out=False, + ), + opinfo_core.OpInfo( + 
"ops.aten.randint.low", + aten_name="randint.low", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_randint_low, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.randint_like", + aten_name="randint_like", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_randint_like, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.randint_like__dtype", + op=torch.ops.aten.randint_like, + aten_name="randint_like", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_randint_like_dtype, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.randint_like.low_dtype", + aten_name="randint_like.low_dtype", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_randint_like_low_dtype, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.randint_like.low_dtype__dtype", + op=torch.ops.aten.randint_like.low_dtype, + aten_name="randint_like.low_dtype", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.randn", + aten_name="randn", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_randn, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.randn_like", + aten_name="randn", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_like_fns, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.randn_like_dtype", + op=torch.ops.aten.randn_like, + aten_name="randn", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_like_fns_dtype, + supports_out=False, + ), # NOTE: torch.STFT has pre-padding and it's not supported by aten::stft # This custom OpInfo uses aten::stft directly. 
opinfo_core.OpInfo( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index f2e64cce9..2d00771b5 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -333,14 +333,6 @@ def _nonzero_input_wrangler( return args, kwargs -def _randn_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - # Make the size argument as attribute list[int] - kwargs["size"] = args.pop(0).tolist() - return args, kwargs - - def _permute_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1170,16 +1162,35 @@ def _where_input_wrangler( trace_only=True, ), TorchLibOpInfo("pow", core_ops.aten_pow), - # TorchLibOpInfo("rand", core_ops.aten_rand), # no test case in OPS_DB + TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True), + TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True), + TorchLibOpInfo( + "ops.aten.rand_like__dtype", core_ops.aten_rand_like_dtype, nondeterministic=True + ), + TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True), + TorchLibOpInfo("ops.aten.randint.low", core_ops.aten_randint_low, nondeterministic=True), + TorchLibOpInfo("ops.aten.randint_like", core_ops.aten_randint_like, nondeterministic=True), + TorchLibOpInfo( + "ops.aten.randint_like__dtype", core_ops.aten_randint_like_dtype, nondeterministic=True + ), TorchLibOpInfo( - "randn", - core_ops.aten_randn, - input_wrangler=_randn_input_wrangler, + "ops.aten.randint_like.low_dtype", + core_ops.aten_randint_like_low_dtype, nondeterministic=True, - ).xfail( + ), + TorchLibOpInfo( + "ops.aten.randint_like.low_dtype__dtype", + core_ops.aten_randint_like_low_dtype_dtype, + nondeterministic=True, + ), + TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail( dtypes=(torch.float16,), reason="fixme: Shape inference error", ), + TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True), + TorchLibOpInfo( + "ops.aten.randn_like_dtype", core_ops.aten_randn_like_dtype, nondeterministic=True + ), TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal), TorchLibOpInfo( "remainder",
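
Note on the sampling logic used by the new randint overloads (illustration only, not
part of the diff): the ONNX graph they build is RandomUniformLike followed by a
scale/shift, Floor, and Cast; aten_randint is the same with low fixed at 0. The
plain-Python sketch below mirrors that math under the assumption that
RandomUniformLike yields uniform samples in [0, 1) (its default low/high); the name
randint_low_reference is hypothetical and exists only for this sketch.

    import math
    import random

    def randint_low_reference(low: int, high: int, n: int) -> list[int]:
        """Plain-Python mirror of the graph emitted by aten_randint_low."""
        out = []
        for _ in range(n):
            u = random.random()               # RandomUniformLike: uniform sample in [0, 1)
            shifted = u * (high - low) + low  # op.Add(op.Mul(rand, op.Sub(high, low)), low)
            out.append(math.floor(shifted))   # op.Floor, then op.Cast to the integer dtype
        return out

    # Values land in [low, high), e.g. [2, 10) for randint.low(2, 10, ...);
    # with low=0 this reduces to the plain aten_randint mapping.
    print(randint_low_reference(2, 10, 5))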