From a981b8add9a4c7e67ad0d28622b23d4c6a55a76e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 10 Oct 2023 20:12:22 -0700 Subject: [PATCH] Fix randint dtypes | test(torchlib) (#1088) Fix test input types for randint by replacing float input tests to integer inputs. --- .../tests/function_libs/torch_lib/extra_opinfo.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index c274d95f9..87575cea7 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1308,21 +1308,21 @@ def sample_inputs_scaled_dot_product_flash_attention( opinfo_core.OpInfo( "ops.aten.randint", aten_name="randint", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint, supports_out=False, ), opinfo_core.OpInfo( "ops.aten.randint.low", aten_name="randint.low", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_low, supports_out=False, ), opinfo_core.OpInfo( "ops.aten.randint_like", aten_name="randint_like", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_like, supports_out=False, ), @@ -1330,14 +1330,14 @@ def sample_inputs_scaled_dot_product_flash_attention( "ops.aten.randint_like__dtype", op=torch.ops.aten.randint_like, aten_name="randint_like", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_like_dtype, supports_out=False, ), opinfo_core.OpInfo( "ops.aten.randint_like.low_dtype", aten_name="randint_like.low_dtype", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_like_low_dtype, supports_out=False, ), @@ -1345,7 +1345,7 @@ def sample_inputs_scaled_dot_product_flash_attention( "ops.aten.randint_like.low_dtype__dtype", op=torch.ops.aten.randint_like.low_dtype, aten_name="randint_like.low_dtype", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype, supports_out=False, ),