Skip to content

Commit

Permalink
Revert "Add support for aten:index op when index is boolean | feat(torchlib)" (#1307)
Browse files Browse the repository at this point in the history

Reverting because this causes the dispatcher in PyTorch to choose the
wrong overload. Reverts #1285
  • Loading branch information
justinchuby authored Mar 19, 2024
1 parent 9b1f2c6 commit ad6faf2
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 42 deletions.
7 changes: 0 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4035,13 +4035,6 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
return op.Transpose(self, perm=perm)


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:
    """Select rows of ``self`` using the first index as a boolean mask (trace-only).

    NOTE(review): only ``indices[0]`` is consulted; assumes a single boolean
    index tensor — TODO confirm against the dispatcher's calling convention.
    """
    mask = indices[0]
    # NonZero emits coordinates of True entries; transpose to one row per hit.
    coords = op.Transpose(op.NonZero(mask), perm=[1, 0])
    # Drop the trailing coordinate axis, leaving a flat list of row ids.
    row_ids = op.Squeeze(coords, axes=[1])
    return op.Gather(self, row_ids, axis=0)


def aten_index_add(
self: TensorType, dim: int, index: TensorType, source: TensorType, alpha: float = 1
) -> TensorType:
Expand Down
34 changes: 0 additions & 34 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,31 +692,6 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
)


def _index_variable_bool(shape, max_indices, device):
if not isinstance(shape, tuple):
shape = (shape,)
index = (
torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().bool()
)
return index


def sample_inputs_index_bool(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInputs exercising aten::index.Tensor with a boolean index."""
    del op_info  # Unused
    del kwargs  # Unused
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    side = 5
    # A random boolean mask over the leading dimension of the data tensor.
    bool_index = _index_variable_bool(side, side, device=device)
    for args in (([bool_index],),):
        yield opinfo_core.SampleInput(make_arg((side, side, side, side)), args=args)


def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
del op_info # Unused
del kwargs # Unused
Expand Down Expand Up @@ -1961,15 +1936,6 @@ def __init__(self):
),
sample_inputs_func=sample_inputs_index,
),
opinfo_core.OpInfo(
"ops.aten.index.Tensor.bool",
aten_name="index.Tensor",
dtypes=common_dtype.all_types_and_complex_and(
torch.bool, torch.float16, torch.bfloat16, torch.chalf
),
sample_inputs_func=sample_inputs_index_bool,
op=torch.ops.aten.index.Tensor,
),
opinfo_core.OpInfo(
"ops.aten.layer_norm",
aten_name="layer_norm",
Expand Down
1 change: 0 additions & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,6 @@ def _where_input_wrangler(
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True),
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool, trace_only=True),
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
Expand Down

0 comments on commit ad6faf2

Please sign in to comment.