Skip to content

Commit

Permalink
[torchlib] torch.where(x) - default overload - is actually not implem…
Browse files Browse the repository at this point in the history
…ented (#1971)

As the title said the overload is not implemented in
[`aten_where`](https://github.com/microsoft/onnxscript/blob/99cf79fd4ab150e3726b36fb3e9104304e203200/onnxscript/function_libs/torch_lib/ops/core.py#L8871C1-L8874C44).
It should be decomposed into `nonzero` function by pytorch.

Now it throws error as there is not enough parameters. Minimal
reproducible example:
```python
import torch

class Model(torch.nn.Module):
	def forward(self, x):
		return torch.where(x)

torch.onnx.export(Model(), (torch.tensor([0, 1, 2, 0, 3]),), dynamo=True)
```
```
<class 'ValueError'>: Required parameter 'self' is not provided. Signature: pkg.onnxscript.torch_lib::aten_where(condition: T_condition, self: TTensor, other: TTensor) -> (TTensor) where T_condition=BOOL, TTensor=INT8 | FLOAT16 | INT16 | INT32 | UINT8 | FLOAT | BOOL | COMPLEX128 | BFLOAT16 | COMPLEX64 | DOUBLE | INT64. Args: (SymbolicTensor('x', type=Tensor(INT64), shape=[5], producer=None, index=None),). Kwargs: {}.
```
As for the tests I would have thought it is handled by the
[ops_test.py](https://github.com/microsoft/onnxscript/tree/main/tests/function_libs/torch_lib)
but apparently it is not.

---
As a side note, the `pylint` is somehow broken for this file (at least).

Co-authored-by: Ti-Tai Wang <[email protected]>
  • Loading branch information
Bludator and titaiwangms authored Dec 17, 2024
1 parent 0aed232 commit f0769c3
Showing 1 changed file with 0 additions and 1 deletion.
1 change: 0 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8861,7 +8861,6 @@ def reshape_to_2d(tensor):

@torch_op(
(
"aten::where",
"aten::where.Scalar",
"aten::where.ScalarSelf",
"aten::where.ScalarOther",
Expand Down

0 comments on commit f0769c3

Please sign in to comment.