Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[torchlib] torch.where(x) - default overload - is actually not implem…
…ented (#1971) As the title said the overload is not implemented in [`aten_where`](https://github.com/microsoft/onnxscript/blob/99cf79fd4ab150e3726b36fb3e9104304e203200/onnxscript/function_libs/torch_lib/ops/core.py#L8871C1-L8874C44). It should be decomposed into `nonzero` function by pytorch. Now it throws error as there is not enough parameters. Minimal reproducible example: ```python import torch class Model(torch.nn.Module): def forward(self, x): return torch.where(x) torch.onnx.export(Model(), (torch.tensor([0, 1, 2, 0, 3]),), dynamo=True) ``` ``` <class 'ValueError'>: Required parameter 'self' is not provided. Signature: pkg.onnxscript.torch_lib::aten_where(condition: T_condition, self: TTensor, other: TTensor) -> (TTensor) where T_condition=BOOL, TTensor=INT8 | FLOAT16 | INT16 | INT32 | UINT8 | FLOAT | BOOL | COMPLEX128 | BFLOAT16 | COMPLEX64 | DOUBLE | INT64. Args: (SymbolicTensor('x', type=Tensor(INT64), shape=[5], producer=None, index=None),). Kwargs: {}. ``` As for the tests I would have thought it is handled by the [ops_test.py](https://github.com/microsoft/onnxscript/tree/main/tests/function_libs/torch_lib) but apparently it is not. --- As a side note, the `pylint` is somehow broken for this file (at least). Co-authored-by: Ti-Tai Wang <[email protected]>
- Loading branch information