Skip to content

Commit

Permalink
Fix randint dtypes | test(torchlib) (#1088)
Browse files Browse the repository at this point in the history
Fix test input types for randint by replacing float input tests to
integer inputs.
  • Loading branch information
justinchuby authored Oct 11, 2023
1 parent 1ee4ee1 commit a981b8a
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,44 +1308,44 @@ def sample_inputs_scaled_dot_product_flash_attention(
opinfo_core.OpInfo(
"ops.aten.randint",
aten_name="randint",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypes=common_dtype.integral_types(),
sample_inputs_func=sample_inputs_randint,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.randint.low",
aten_name="randint.low",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypes=common_dtype.integral_types(),
sample_inputs_func=sample_inputs_randint_low,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.randint_like",
aten_name="randint_like",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypes=common_dtype.integral_types(),
sample_inputs_func=sample_inputs_randint_like,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.randint_like__dtype",
op=torch.ops.aten.randint_like,
aten_name="randint_like",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypes=common_dtype.integral_types(),
sample_inputs_func=sample_inputs_randint_like_dtype,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.randint_like.low_dtype",
aten_name="randint_like.low_dtype",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypes=common_dtype.integral_types(),
sample_inputs_func=sample_inputs_randint_like_low_dtype,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.randint_like.low_dtype__dtype",
op=torch.ops.aten.randint_like.low_dtype,
aten_name="randint_like.low_dtype",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypes=common_dtype.integral_types(),
sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype,
supports_out=False,
),
Expand Down

0 comments on commit a981b8a

Please sign in to comment.