From b120613a9ecb4098e9e3d68fa055afaa949848e6 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 28 Aug 2023 19:49:54 -0700
Subject: [PATCH 1/7] Rand ops (#1033)

---
 .../function_libs/torch_lib/ops/core.py      | 129 +++++++++++++++---
 .../function_libs/torch_lib/ops_test_data.py |  29 +++-
 2 files changed, 134 insertions(+), 24 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index afdeb56c1..85c2927fc 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -5702,47 +5702,138 @@ def aten_rad2deg(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::rand")
-def aten_rand(size: Sequence[int], dtype: int = 1) -> TReal:
+def aten_rand(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
     """rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    return op.RandomUniform(shape=size, dtype=dtype)
+    shaper = op.ConstantOfShape(size)
+    return op.RandomUniformLike(shaper, dtype=dtype)
 
 
-def aten_rand_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+@torch_op("aten::rand_like")
+def aten_rand_like(self: TFloat) -> TFloat:
     """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.RandomUniformLike(self)
+
+
+@torch_op("aten::rand_like")
+def aten_rand_like_dtype(self: TensorType, dtype: int) -> TensorType:
+    """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
 
-def aten_randint(high: int, size: INT64) -> TensorType:
-    """randint(int high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+    return op.RandomUniformLike(self, dtype=dtype)
 
-    raise NotImplementedError()
 
+@torch_op("aten::randint")
+def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorType:
+    """randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-def aten_randint_like(
-    self: TensorType, high: int, memory_format: Optional[str] = None
+    shaper = op.ConstantOfShape(size)
+    rand = op.RandomUniformLike(shaper)
+    # Scale to [0, high] first
+    rand_scaled = op.Mul(rand, op.CastLike(high, rand))
+    # Round to ints
+    rand_int = op.Floor(rand_scaled)
+    return op.Cast(rand_int, to=dtype)
+
+
+@torch_op("aten::randint.low")
+def aten_randint_low(
+    low: INT64, high: INT64, size: INT64, dtype: int = INT64.dtype
 ) -> TensorType:
-    """randint_like(Tensor self, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+    """randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    raise NotImplementedError()
+    shaper = op.ConstantOfShape(size)
+    rand = op.RandomUniformLike(shaper)
+    # Translate to [low, high] first
+    high = op.Cast(high, to=FLOAT.dtype)
+    low = op.Cast(low, to=FLOAT.dtype)
+    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
+    # Round to ints
+    rand_int = op.Floor(rand_translated)
+    return op.Cast(rand_int, to=dtype)
+
+
+@torch_op("aten::randint_like")
+def aten_randint_like(self: TensorType, high: INT64) -> IntType:
+    """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    self_float = op.Cast(self, to=FLOAT.dtype)
+    rand = op.RandomUniformLike(self_float)
+    # Scale to [0, high] first
+    rand_scaled = op.Mul(rand, op.CastLike(high, rand))
+    # Round to ints
+    rand_int = op.Floor(rand_scaled)
+    return op.CastLike(rand_int, self)
+
+
+@torch_op("aten::randint_like")
+def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> TensorType:
+    """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    self_float = op.Cast(self, to=FLOAT.dtype)
+    rand = op.RandomUniformLike(self_float)
+    # Scale to [0, high] first
+    rand_scaled = op.Mul(rand, op.CastLike(high, rand))
+    # Round to ints
+    rand_int = op.Floor(rand_scaled)
+    return op.Cast(rand_int, to=dtype)
+
+
+@torch_op("aten::randint_like.low_dtype")
+def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> IntType:
+    """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+
+    This is the TorchLib overload for aten::randint_like.low_dtype when dtype is None.
+    """
+
+    self_float = op.Cast(self, to=FLOAT.dtype)
+    rand = op.RandomUniformLike(self_float)
+    # Translate to [low, high] first
+    high = op.Cast(high, to=FLOAT.dtype)
+    low = op.Cast(low, to=FLOAT.dtype)
+    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
+    # Round to ints
+    rand_int = op.Floor(rand_translated)
+    return op.CastLike(rand_int, self)
+
+
+@torch_op("aten::randint_like.low_dtype")
+def aten_randint_like_low_dtype_dtype(
+    self: TensorType, low: INT64, high: INT64, dtype: int
+) -> TensorType:
+    """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    self_float = op.Cast(self, to=FLOAT.dtype)
+    rand = op.RandomUniformLike(self_float)
+    # Translate to [low, high] first
+    high = op.Cast(high, to=FLOAT.dtype)
+    low = op.Cast(low, to=FLOAT.dtype)
+    rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
+    # Round to ints
+    rand_int = op.Floor(rand_translated)
+    return op.Cast(rand_int, to=dtype)
 
 
 @torch_op("aten::randn")
-def aten_randn(
-    size: Sequence[int],
-    dtype: int = 1,
-    requires_grad: bool = False,  # pylint: disable=unused-argument
-) -> TReal:
+def aten_randn(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
     """randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    return op.RandomNormal(dtype=dtype, shape=size)
+    shaper = op.ConstantOfShape(size)
+    return op.RandomNormalLike(shaper, dtype=dtype)
 
 
-def aten_randn_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+@torch_op("aten::randn_like")
+def aten_randn_like(self: TFloat) -> TFloat:
     """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.RandomNormalLike(self)
+
+
+@torch_op("aten::randn_like")
+def aten_randn_like_dtype(self: TensorType, dtype: int) -> TensorType:
+    """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+
+    return op.RandomNormalLike(self, dtype=dtype)
 
 
 def aten_randperm(n: int) -> TensorType:
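
All of the implementations above share one decomposition: ONNX RandomUniform/RandomNormal take their shape as a static attribute, so the dynamic `size` input is first materialized with ConstantOfShape and the random tensor is produced by the "Like" variant; the integer overloads then scale, shift, and floor a uniform sample. A minimal numpy sketch of that arithmetic (numpy stands in for the ONNX ops; this is not code from the patch):

    import numpy as np

    def randint_low_sketch(low: int, high: int, size: tuple) -> np.ndarray:
        rand = np.random.uniform(0.0, 1.0, size=size).astype(np.float32)  # RandomUniformLike
        rand_translated = rand * (high - low) + low                       # Mul(Sub) + Add
        return np.floor(rand_translated).astype(np.int64)                 # Floor + Cast

    sample = randint_low_sketch(2, 10, (3, 4))
    assert sample.min() >= 2 and sample.max() < 10  # values land in [low, high)

Since RandomUniformLike samples from [0, 1), the result lands in the half-open interval [low, high), matching torch's randint semantics despite the closed-interval wording in the inline comments.
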
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 207b8459e..929e8597d 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1137,16 +1137,35 @@ def _where_input_wrangler(
         trace_only=True,
     ),
     TorchLibOpInfo("pow", core_ops.aten_pow),
-    # TorchLibOpInfo("rand", core_ops.aten_rand),  # no test case in OPS_DB
+    TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand),
+    TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
     TorchLibOpInfo(
-        "randn",
-        core_ops.aten_randn,
-        input_wrangler=_randn_input_wrangler,
+        "ops.aten.rand_like__dtype", core_ops.aten_rand_like_dtype, nondeterministic=True
+    ),
+    TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True),
+    TorchLibOpInfo("ops.aten.randint.low", core_ops.aten_randint_low, nondeterministic=True),
+    TorchLibOpInfo("ops.aten.randint_like", core_ops.aten_randint_like, nondeterministic=True),
+    TorchLibOpInfo(
+        "ops.aten.randint_like__dtype", core_ops.aten_randint_like_dtype, nondeterministic=True
+    ),
+    TorchLibOpInfo(
+        "ops.aten.randint_like.low_dtype",
+        core_ops.aten_randint_like_low_dtype,
         nondeterministic=True,
-    ).xfail(
+    ),
+    TorchLibOpInfo(
+        "ops.aten.randint_like.low_dtype__dtype",
+        core_ops.aten_randint_like_low_dtype_dtype,
+        nondeterministic=True,
+    ),
+    TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail(
         dtypes=(torch.float16,),
         reason="fixme: Shape inference error",
     ),
+    TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn, nondeterministic=True),
+    TorchLibOpInfo(
+        "ops.aten.randn_like_dtype", core_ops.aten_randn_like_dtype, nondeterministic=True
+    ),
     TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal),
     TorchLibOpInfo(
         "remainder",
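
Because every op in this group returns random values, the entries are flagged `nondeterministic=True`; a value-level comparison against torch's output can never pass. The check the flag leaves behind is, in effect, the following (a sketch of the assumed fallback behavior, not the actual test-runner code):

    import torch

    def assert_nondeterministic_ok(actual: torch.Tensor, expected: torch.Tensor):
        # Random outputs can only be compared by signature, not by value.
        assert actual.dtype == expected.dtype
        assert actual.shape == expected.shape
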
From 2130301fcd2ffbd080f3b85bf222e70be4b0b0e3 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 10 Oct 2023 15:33:40 +0000
Subject: [PATCH 2/7] Snap

---
 .../function_libs/torch_lib/extra_opinfo.py | 103 ++++++++++++++++++
 1 file changed, 103 insertions(+)

diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 2228a079a..49cd4ca0a 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -516,6 +516,109 @@ def sample_inputs_native_dropout(
         yield opinfo_core.SampleInput(make_arg(case), p=p, train=training)
 
 
+def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs):
+    def op_info  # Unused
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    shapes = (
+        (M,),
+        (S, S),
+        (S, S, S),
+    )
+
+    for shape in shapes:
+        yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype))
+
+
+def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs):
+    def op_info  # Unused
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    shapes = (
+        (M,),
+        (S, S),
+        (S, S, S),
+    )
+
+    for shape in shapes:
+        yield opinfo_core.SampleInput(make_arg(shape))
+
+
+def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs):
+    def op_info  # Unused
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=torch.float32, requires_grad=requires_grad
+    )
+    shapes = (
+        (M,),
+        (S, S),
+        (S, S, S),
+    )
+
+    for shape in shapes:
+        yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype))
+
+
+def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+    inputs = [
+        ((), {}),
+        ((S, S), {}),
+        ((0, S, 0), {}),
+        ((S,), {'dtype': dtype}),
+        # Hard-code some dtypes/devices. We want to test cases where the
+        # (dtype, device) is different from the input's (dtype, device)
+        ((S,), {'dtype': torch.double}),
+        ((S,), ),
+
+    ]
+    for shape, kwargs in inputs:
+        t = torch_testing.make_tensor(shape, dtype=dtype, device=device,
+                                      low=None, high=None,
+                                      requires_grad=requires_grad)
+        yield opinfo_core.SampleInput(t, **kwargs)
+
+
+def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs):
+    high = 10
+
+    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+        # With high
+        yield opinfo_core.SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs)
+
+def sample_inputs_randint_low(self, device, dtype, requires_grad, **kwargs):
+    low = 2
+    high = 10
+
+    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+        # With low and high
+        yield opinfo_core.SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs)
+
+
+def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs):
+    low = 2
+    high = 10
+
+    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+        # With high
+        yield SampleInput(
+            sample.input,
+            high,
+            *sample.args,
+            **sample.kwargs)
+        # With low and high
+        yield SampleInput(
+            get_independent_tensor(sample.input),
+            low,
+            high,
+            *sample.args,
+            **sample.kwargs)
+
+
 def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     del kwargs
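
This snapshot commit does not yet run as-is: `def op_info` should read `del op_info`, and `SampleInput`/`get_independent_tensor` are referenced without their `opinfo_core` qualification, all of which the next commits repair. For reference, the shape a finished generator converges to by the end of the series (M, S, and opinfo_core come from the surrounding module; the name here is illustrative):

    def sample_inputs_rand_final(op_info, device, dtype, requires_grad, **kwargs):
        del op_info, device, requires_grad, kwargs  # Fixed by the OpInfo protocol, unused here
        for shape in ((M,), (S, S), (S, S, S)):
            # For aten::rand the sample "input" is the target shape itself, not a tensor.
            yield opinfo_core.SampleInput(shape, kwargs=dict(dtype=dtype))
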
From 6ee3f246c1294475e3d9396972b05801b9079d1f Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 10 Oct 2023 18:49:11 +0000
Subject: [PATCH 3/7] Tests

---
 .../function_libs/torch_lib/extra_opinfo.py  | 162 +++++++++++++++---
 .../function_libs/torch_lib/ops_test_data.py |   8 -
 2 files changed, 137 insertions(+), 33 deletions(-)

diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 49cd4ca0a..1a7ee5688 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -517,7 +517,7 @@ def sample_inputs_native_dropout(
 
 def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs):
-    def op_info  # Unused
+    del op_info  # Unused
 
     make_arg = functools.partial(
         torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
@@ -533,7 +533,7 @@ def op_info  # Unused
 
 def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs):
-    def op_info  # Unused
+    del op_info  # Unused
 
     make_arg = functools.partial(
         torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
@@ -549,10 +549,13 @@ def op_info  # Unused
 
 def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs):
-    def op_info  # Unused
+    del op_info  # Unused
 
     make_arg = functools.partial(
-        torch_testing.make_tensor, device=device, dtype=torch.float32, requires_grad=requires_grad
+        torch_testing.make_tensor,
+        device=device,
+        dtype=torch.float32,
+        requires_grad=requires_grad,
     )
     shapes = (
         (M,),
@@ -565,24 +568,36 @@ def op_info  # Unused
 
 def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+    del self  # Unused
+
     inputs = [
         ((), {}),
         ((S, S), {}),
         ((0, S, 0), {}),
-        ((S,), {'dtype': dtype}),
+        ((S,),),
+    ]
+    for shape, kwargs in inputs:
+        t = torch_testing.make_tensor(
+            shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad
+        )
+        yield opinfo_core.SampleInput(t, **kwargs)
+
+
+def sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs):
+    del self  # Unused
+
+    inputs = [
+        ((S,), {"dtype": dtype}),
         # Hard-code some dtypes/devices. We want to test cases where the
         # (dtype, device) is different from the input's (dtype, device)
-        ((S,), {'dtype': torch.double}),
-        ((S,), ),
-
+        ((S,), {"dtype": torch.double}),
     ]
     for shape, kwargs in inputs:
-        t = torch_testing.make_tensor(shape, dtype=dtype, device=device,
-                                      low=None, high=None,
-                                      requires_grad=requires_grad)
+        t = torch_testing.make_tensor(
+            shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad
+        )
         yield opinfo_core.SampleInput(t, **kwargs)
 
-
 def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs):
     high = 10
 
@@ -590,33 +605,64 @@ def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs):
         # With high
         yield opinfo_core.SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs)
 
+
 def sample_inputs_randint_low(self, device, dtype, requires_grad, **kwargs):
     low = 2
     high = 10
 
     for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
         # With low and high
-        yield opinfo_core.SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs)
+        yield opinfo_core.SampleInput(
+            low, high, sample.input.shape, *sample.args, **sample.kwargs
+        )
 
 
 def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs):
-    low = 2
     high = 10
 
     for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
         # With high
-        yield SampleInput(
-            sample.input,
-            high,
-            *sample.args,
-            **sample.kwargs)
+        yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs)
+
+
+def sample_inputs_randint_like_dtype(self, device, dtype, requires_grad, **kwargs):
+    high = 10
+
+    for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs):
         # With low and high
-        yield SampleInput(
-            get_independent_tensor(sample.input),
-            low,
-            high,
-            *sample.args,
-            **sample.kwargs)
+        yield opinfo_core.SampleInput(
+            sample.input, high, *sample.args, **sample.kwargs
+        )
+
+
+def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **kwargs):
+    low = 2
+    high = 10
+
+    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+        # With low and high
+        yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs)
+
+
+def sample_inputs_randint_like_low_dtype_dtype(self, device, dtype, requires_grad, **kwargs):
+    low = 2
+    high = 10
+
+    for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs):
+        # With low and high
+        yield opinfo_core.SampleInput(
+            sample.input, low, high, *sample.args, **sample.kwargs
+        )
+
+
+def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs):
+    del op  # Unused
+    del requires_grad  # Unused
+
+    shapes = ([M], [S, S])
+
+    for shape in shapes:
+        yield opinfo_core.SampleInput(input=shape, kwargs=dict(dtype=dtype))
 
 
 def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     del kwargs
@@ -1236,6 +1282,72 @@ def sample_inputs_scaled_dot_product_flash_attention(
         sample_inputs_func=sample_inputs_max_pool3d_with_indices,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.rand",
+        aten_name="rand",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_rand,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.rand_like",
+        aten_name="rand_like",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_rand_like,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.rand_like__dtype",
+        op=torch.ops.aten.rand_like,
+        aten_name="rand_like",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_rand_like_dtype,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.randint",
+        aten_name="randint",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_randint,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.randint.low",
+        aten_name="randint.low",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_randint_low,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.randint_like",
+        aten_name="randint_like",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_randint_like,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.randint_like__dtype",
+        op=torch.ops.aten.randint_like,
+        aten_name="randint_like",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_randint_like_dtype,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.randint_like.low_dtype",
+        aten_name="randint_like.low_dtype",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_randint_like_low_dtype,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.randint_like.low_dtype__dtype",
+        op=torch.ops.aten.randint_like.low_dtype,
+        aten_name="randint_like.low_dtype",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype,
+        supports_out=False,
+    ),
     # NOTE: torch.STFT has pre-padding and it's not supported by aten::stft
     # This custom OpInfo uses aten::stft directly.
     opinfo_core.OpInfo(
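
The `__dtype`-suffixed OpInfo names are synthetic: `op=torch.ops.aten.rand_like` (and friends) points them back at the same ATen overload, but their sample inputs always pass an explicit `dtype`, so the dtype-specialized TorchLib functions get exercised separately from the default ones. In terms of the calls being driven (illustrative; the pairing of call to TorchLib function is done by the test runner):

    import torch

    x = torch.rand(3, 4)
    torch.ops.aten.rand_like(x)                       # covered by ops.aten.rand_like
    torch.ops.aten.rand_like(x, dtype=torch.float16)  # covered by ops.aten.rand_like__dtype
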
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index daf1bc239..f3a10ebfa 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -333,14 +333,6 @@ def _nonzero_input_wrangler(
     return args, kwargs
 
 
-def _randn_input_wrangler(
-    args: list[Any], kwargs: dict[str, Any]
-) -> tuple[list[Any], dict[str, Any]]:
-    # Make the size argument as attribute list[int]
-    kwargs["size"] = args.pop(0).tolist()
-    return args, kwargs
-
-
 def _permute_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
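
`_randn_input_wrangler` existed only to pop the size tensor off the argument list and turn it into a static `list[int]` for RandomNormal's `shape` attribute. After PATCH 1, `aten_randn` takes `size` as a graph input (ConstantOfShape followed by RandomNormalLike), so the wrangler has no caller left. That input-versus-attribute distinction is also what makes data-dependent sizes representable at all, e.g. (illustrative):

    import torch

    def f(x: torch.Tensor) -> torch.Tensor:
        # x.shape[0] is only known at run time, so it cannot be burned into a
        # static shape attribute; it has to flow through the graph as a value.
        return torch.randn(x.shape[0], 2)
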
From 47b65f5fb5d2f75c9a4e18a6e264dcfa6ca65ee2 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 10 Oct 2023 18:52:16 +0000
Subject: [PATCH 4/7] Tests

---
 .../function_libs/torch_lib/extra_opinfo.py | 31 +++++++++++++++----
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 1a7ee5688..64059f6a4 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -598,6 +598,7 @@ def sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs):
         )
         yield opinfo_core.SampleInput(t, **kwargs)
 
+
 def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs):
     high = 10
 
@@ -630,9 +631,7 @@ def sample_inputs_randint_like_dtype(self, device, dtype, requires_grad, **kwarg
 
     for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs):
         # With low and high
-        yield opinfo_core.SampleInput(
-            sample.input, high, *sample.args, **sample.kwargs
-        )
+        yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs)
 
 
 def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **kwargs):
@@ -650,9 +649,7 @@ def sample_inputs_randint_like_low_dtype_dtype(self, device, dtype, requires_gra
 
     for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs):
         # With low and high
-        yield opinfo_core.SampleInput(
-            sample.input, low, high, *sample.args, **sample.kwargs
-        )
+        yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs)
@@ -1348,6 +1345,28 @@ def sample_inputs_scaled_dot_product_flash_attention(
         sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.randn",
+        aten_name="randn",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_randn,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.randn_like",
+        aten_name="randn",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_like_fns,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.randn_like_dtype",
+        op=torch.ops.aten.randn_like,
+        aten_name="randn",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_like_fns_dtype,
+        supports_out=False,
+    ),
     # NOTE: torch.STFT has pre-padding and it's not supported by aten::stft
     # This custom OpInfo uses aten::stft directly.
     opinfo_core.OpInfo(
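
`sample_inputs_randn` passes plain shapes as the sample input, mirroring how the ATen op itself is invoked. Roughly what the new entries exercise (assuming `torch.ops.aten.randn` accepts a size list, as the functional form does):

    import torch

    t = torch.ops.aten.randn([5, 5], dtype=torch.float32)
    like = torch.ops.aten.randn_like(t)
    assert like.shape == t.shape and like.dtype == t.dtype
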
From 9b69176e47b3bc8cd6b54b650dc8a5c035d093c5 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 10 Oct 2023 23:38:55 +0000
Subject: [PATCH 5/7] Fix all tests

---
 .../tests/function_libs/torch_lib/extra_opinfo.py  | 10 ++++------
 .../tests/function_libs/torch_lib/ops_test_data.py |  4 ++--
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 64059f6a4..04819118b 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -519,9 +519,6 @@ def sample_inputs_native_dropout(
 def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs):
     del op_info  # Unused
 
-    make_arg = functools.partial(
-        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
-    )
     shapes = (
         (M,),
         (S, S),
@@ -531,7 +528,7 @@ def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs):
     )
 
     for shape in shapes:
-        yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype))
+        yield opinfo_core.SampleInput(shape, kwargs=dict(dtype=dtype))
 
 
 def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs):
@@ -574,7 +571,7 @@ def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
         ((), {}),
         ((S, S), {}),
         ((0, S, 0), {}),
-        ((S,),),
+        ((S,), {}),
     ]
     for shape, kwargs in inputs:
         t = torch_testing.make_tensor(
@@ -654,9 +651,10 @@ def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs):
     del op  # Unused
+    del device  # Unused
     del requires_grad  # Unused
 
-    shapes = ([M], [S, S])
+    shapes = ((M,), (S, S))
 
     for shape in shapes:
         yield opinfo_core.SampleInput(input=shape, kwargs=dict(dtype=dtype))
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index f3a10ebfa..2d00771b5 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1162,7 +1162,7 @@ def _where_input_wrangler(
         trace_only=True,
     ),
     TorchLibOpInfo("pow", core_ops.aten_pow),
-    TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand),
+    TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
     TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
     TorchLibOpInfo(
         "ops.aten.rand_like__dtype", core_ops.aten_rand_like_dtype, nondeterministic=True
     ),
@@ -1187,7 +1187,7 @@ def _where_input_wrangler(
         dtypes=(torch.float16,),
         reason="fixme: Shape inference error",
     ),
-    TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn, nondeterministic=True),
+    TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True),
     TorchLibOpInfo(
         "ops.aten.randn_like_dtype", core_ops.aten_randn_like_dtype, nondeterministic=True
     ),
     TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal),
     TorchLibOpInfo(
         "remainder",
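
Two of these fixes are easy to misread. `core_ops.aten_randn` to `core_ops.aten_randn_like` corrects a copy-paste slip that had the randn_like test exercising the wrong function, and `((S,),)` to `((S,), {})` matters because the consuming loop unpacks two values per entry, so a one-element tuple fails at iteration time. A standalone reproduction of the latter:

    inputs = [((), {}), ((3,),)]   # second entry lost its kwargs dict
    for shape, kwargs in inputs:   # raises ValueError: not enough values to unpack
        pass
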
From 286b7ec3a4fe8d2c114cf450d66ac4950e1685f1 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 10 Oct 2023 23:40:53 +0000
Subject: [PATCH 6/7] unused

---
 onnxscript/tests/function_libs/torch_lib/extra_opinfo.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 04819118b..43436a541 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -518,6 +518,9 @@ def sample_inputs_native_dropout(
 
 def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs):
     del op_info  # Unused
+    del device  # Unused
+    del requires_grad  # Unused
+    del kwargs  # Unused
 
     shapes = (
         (M,),
@@ -531,6 +534,7 @@ def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs):
 
 def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs):
     del op_info  # Unused
+    del kwargs  # Unused
 
     make_arg = functools.partial(
         torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
@@ -547,6 +551,7 @@ def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs):
 
 def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs):
     del op_info  # Unused
+    del kwargs  # Unused
 
     make_arg = functools.partial(
         torch_testing.make_tensor,
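
The `del` statements are this test suite's idiom for parameters the OpInfo protocol requires but a particular generator does not use: the signature stays uniform across all sample_inputs functions, and linters stop flagging unused arguments. The pattern in isolation (standalone sketch, not from the patch):

    def sample_inputs_example(op_info, device, dtype, requires_grad, **kwargs):
        del op_info, device, requires_grad, kwargs  # Unused, but the signature is fixed
        for shape in ((3,), (3, 4)):
            yield shape, dict(dtype=dtype)
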
From f7c1159046e1f90378eb7b071c817d906c78b2be Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 10 Oct 2023 23:41:42 +0000
Subject: [PATCH 7/7] del

---
 onnxscript/tests/function_libs/torch_lib/extra_opinfo.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 43436a541..c274d95f9 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -658,6 +658,7 @@ def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs):
     del op  # Unused
     del device  # Unused
     del requires_grad  # Unused
+    del kwargs  # Unused
 
     shapes = ((M,), (S, S))
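
With the series applied, random ATen ops lower to RandomUniformLike/RandomNormalLike nodes instead of raising NotImplementedError. An end-to-end sketch of what that enables (torch 2.1-era exporter API; the export call is left commented because it is outside this PR):

    import torch

    class Noise(torch.nn.Module):
        def forward(self, x):
            # aten::rand_like maps to ONNX RandomUniformLike via aten_rand_like above.
            return x + torch.rand_like(x)

    # onnx_program = torch.onnx.dynamo_export(Noise(), torch.ones(2, 2))
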