Fix implementation of index_put | fix(torchlib) (#1298)
Continuation of #1277 by @xadupre 

The ONNX implementation of index_put differs from the one in the TorchScript
exporter
(https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212).
This PR replaces the current implementation, which fails on one corner case,
with the one from the TorchScript exporter.

---------

Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: xadupre <[email protected]>
Co-authored-by: Xavier Dupré <[email protected]>
justinchuby and xadupre authored Mar 21, 2024
1 parent ad6faf2 commit e6ea34f
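
For context, a minimal eager-mode sketch (ours, not part of the commit) of the behavior the lowering has to match. It uses the multi-dimensional index shape exercised by the new test sample, which the old expand/transpose lowering effectively assumed to be 1-D:

    import torch

    data = torch.zeros(10, 3)
    # A single index tensor of shape (2, 4); indices[0][i, j] selects a row of data.
    indices = (torch.arange(8, dtype=torch.int64).reshape(-1, 4),)
    values = torch.ones(2, 4, 3)

    # index_put writes values[i, j, :] into data[indices[0][i, j], :].
    result = data.index_put(indices, values)
    assert torch.equal(result[:8], torch.ones(8, 3))
    assert torch.equal(result[8:], torch.zeros(2, 3))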
Showing 3 changed files with 49 additions and 20 deletions.
28 changes: 11 additions & 17 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4058,27 +4058,21 @@ def aten_index_put(
     values: TReal,
     accumulate: bool = False,
 ) -> TReal:
-    """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
-    index = op.SequenceAt(indices, 0)  # assume indices only have 1 element
-    # change array([1,3]) to array([[1,1,1,1,1],[3,3,3,3,3]])
-    self_dim_1 = op.Gather(op.Shape(self), 1)
-    index_dim_0 = op.Gather(op.Shape(index), 0)
-    neg_1 = op.Constant(value_ints=[-1])
-    shape = op.Concat(op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0)
-    new_ind = op.Expand(index, shape)
-    new_ind_t = op.Transpose(new_ind)
+    """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+
+    See implementation of `torch.onnx.symbolic_opset11.index_put
+    <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
+    """
+
+    # TODO(justinchuby): Handle when indicies has more than one element
+    index = op.SequenceAt(indices, 0)
+    new_index = op.Unsqueeze(index, [-1])
 
     if op.Cast(accumulate, to=BOOL.dtype):
-        # put values into zeros array first, then add to input
-        zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
-        zeros = op.CastLike(zeros, values)
-        result = op.ScatterElements(zeros, new_ind_t, values)
-        # FIXME: type promotion
-        result = op.CastLike(result, self)
-        result = op.Add(result, self)
+        result = op.ScatterND(self, new_index, values, reduction="add")
     else:
-        result = op.ScatterElements(self, new_ind_t, values)
+        result = op.ScatterND(self, new_index, values)
 
     return result
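
As a sanity check on the new lowering, here is a rough numpy model (our assumption, not torchlib code) of what ScatterND computes here. After op.Unsqueeze(index, [-1]) the index has shape index.shape + (1,), so each innermost entry addresses a whole row of self, and reduction="add" makes repeated rows accumulate:

    import numpy as np

    def scatter_nd(data, indices, updates, reduction=None):
        # indices has shape index.shape + (1,): each innermost vector picks a
        # full row of data; updates supplies one row per index entry.
        out = data.copy()
        for pos in np.ndindex(indices.shape[:-1]):
            row = tuple(indices[pos])
            if reduction == "add":
                out[row] += updates[pos]
            else:
                out[row] = updates[pos]  # order is undefined if rows repeat
        return out

    data = np.zeros((10, 3), dtype=np.float32)
    index = np.arange(8).reshape(2, 4)      # indices[0] from the aten call
    new_index = np.expand_dims(index, -1)   # op.Unsqueeze(index, [-1])
    values = np.ones((2, 4, 3), dtype=np.float32)

    result = scatter_nd(data, new_index, values)  # accumulate=False path
    assert (result[:8] == 1).all() and (result[8:] == 0).all()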
27 changes: 27 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -741,6 +741,26 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
         yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)
 
 
+def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+
+    data = torch_testing.make_tensor(
+        (10, 3),
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+    indices = (torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4)),)
+    values = torch_testing.make_tensor(
+        (2, 4, 3),
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+    yield opinfo_core.SampleInput(data, indices, values)
+
+
 def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):
     del op_info  # unused
     del kwargs
@@ -1936,6 +1956,13 @@ def __init__(self):
         ),
         sample_inputs_func=sample_inputs_index,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.index_put",
+        aten_name="index_put",
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs_index_put,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.layer_norm",
         aten_name="layer_norm",
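
For reference, the accumulate=True behavior that the new lowering maps to ScatterND with reduction="add" looks like this in eager mode (a hedged illustration, not taken from the test suite): duplicate row indices sum their updates.

    import torch

    data = torch.full((10, 3), 2.0)
    indices = (torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]),)  # rows 0..3, twice
    values = torch.ones(2, 4, 3)

    out = data.index_put(indices, values, accumulate=True)
    # Each of rows 0..3 is hit twice: 2.0 + 1.0 + 1.0 = 4.0.
    assert torch.equal(out[:4], torch.full((4, 3), 4.0))
    assert torch.equal(out[4:], torch.full((6, 3), 2.0))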
14 changes: 11 additions & 3 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -853,15 +853,23 @@ def _where_input_wrangler(
         core_ops.aten_index_put_bool,
     ).skip(
         matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool),
-        reason="this Aten overload only support tensor(bool) as args",
+        reason="this Aten overload only supports tensor(bool) as indices",
     ),
     TorchLibOpInfo(
         "index_put",
         core_ops.aten_index_put,
-    ).skip(
+    )
+    .skip(
         matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
-        reason="this Aten overload only support tensor(int) as args",
+        reason="this Aten overload only supports tensor(int) as indices",
+    )
+    .xfail(
+        enabled_if=version_utils.onnxruntime_older_than("1.18"),
+        dtypes=(torch.float16,),
+        matcher=lambda sample: sample.kwargs.get("accumulate") is True,
+        reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'",
+    ),
+    TorchLibOpInfo("ops.aten.index_put", core_ops.aten_index_put),
     TorchLibOpInfo("index_select", core_ops.aten_index_select),
     TorchLibOpInfo("isclose", core_ops.aten_isclose),
     TorchLibOpInfo("isfinite", core_ops.aten_isfinite),
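
The xfail above is gated on the installed onnxruntime version. A sketch of what a helper like version_utils.onnxruntime_older_than could look like (an assumption on our part; the real helper lives in onnxscript and may differ):

    import onnxruntime

    def onnxruntime_older_than(version: str) -> bool:
        # Compare the major.minor of the installed onnxruntime against the
        # threshold, e.g. onnxruntime_older_than("1.18") on ORT 1.17.x -> True.
        installed = tuple(int(p) for p in onnxruntime.__version__.split(".")[:2])
        threshold = tuple(int(p) for p in version.split(".")[:2])
        return installed < threshold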
