From f0769c3bc86810f3c025bf72cae4e2710de0af64 Mon Sep 17 00:00:00 2001 From: Bludator Date: Tue, 17 Dec 2024 01:28:22 +0100 Subject: [PATCH] [torchlib] torch.where(x) - default overload - is actually not implemented (#1971) As the title said the overload is not implemented in [`aten_where`](https://github.com/microsoft/onnxscript/blob/99cf79fd4ab150e3726b36fb3e9104304e203200/onnxscript/function_libs/torch_lib/ops/core.py#L8871C1-L8874C44). It should be decomposed into `nonzero` function by pytorch. Now it throws error as there is not enough parameters. Minimal reproducible example: ```python import torch class Model(torch.nn.Module): def forward(self, x): return torch.where(x) torch.onnx.export(Model(), (torch.tensor([0, 1, 2, 0, 3]),), dynamo=True) ``` ``` : Required parameter 'self' is not provided. Signature: pkg.onnxscript.torch_lib::aten_where(condition: T_condition, self: TTensor, other: TTensor) -> (TTensor) where T_condition=BOOL, TTensor=INT8 | FLOAT16 | INT16 | INT32 | UINT8 | FLOAT | BOOL | COMPLEX128 | BFLOAT16 | COMPLEX64 | DOUBLE | INT64. Args: (SymbolicTensor('x', type=Tensor(INT64), shape=[5], producer=None, index=None),). Kwargs: {}. ``` As for the tests I would have thought it is handled by the [ops_test.py](https://github.com/microsoft/onnxscript/tree/main/tests/function_libs/torch_lib) but apparently it is not. --- As a side note, the `pylint` is somehow broken for this file (at least). Co-authored-by: Ti-Tai Wang --- onnxscript/function_libs/torch_lib/ops/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 63f692954..9de7b170f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8861,7 +8861,6 @@ def reshape_to_2d(tensor): @torch_op( ( - "aten::where", "aten::where.Scalar", "aten::where.ScalarSelf", "aten::where.ScalarOther",