diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6c65f56b5..ff13c6a49 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4035,13 +4035,6 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy return op.Transpose(self, perm=perm) -@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True) -def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: - new_indices = op.Transpose(op.NonZero(indices[0]), perm=[1, 0]) - new_indices = op.Squeeze(new_indices, axes=[1]) - return op.Gather(self, new_indices, axis=0) - - def aten_index_add( self: TensorType, dim: int, index: TensorType, source: TensorType, alpha: float = 1 ) -> TensorType: diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 123cc58fa..87b1a886d 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -692,31 +692,6 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): ) -def _index_variable_bool(shape, max_indices, device): - if not isinstance(shape, tuple): - shape = (shape,) - index = ( - torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().bool() - ) - return index - - -def sample_inputs_index_bool(op_info, device, dtype, requires_grad, **kwargs): - del op_info # Unused - del kwargs # Unused - make_arg = functools.partial( - torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad - ) - s = 5 - index_bool = _index_variable_bool(s, s, device=device) - test_args = [ - ([index_bool],), - ] - - for args in test_args: - yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args) - - def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs): del op_info # Unused del kwargs # Unused @@ -1961,15 +1936,6 @@ def __init__(self): ), sample_inputs_func=sample_inputs_index, ), - opinfo_core.OpInfo( - "ops.aten.index.Tensor.bool", - aten_name="index.Tensor", - dtypes=common_dtype.all_types_and_complex_and( - torch.bool, torch.float16, torch.bfloat16, torch.chalf - ), - sample_inputs_func=sample_inputs_index_bool, - op=torch.ops.aten.index.Tensor, - ), opinfo_core.OpInfo( "ops.aten.layer_norm", aten_name="layer_norm", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 4b8196b5f..62b3b345f 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -848,7 +848,6 @@ def _where_input_wrangler( # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True), - TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool, trace_only=True), TorchLibOpInfo( "index_put_bool", core_ops.aten_index_put_bool,