From 937558f6155075224d80f0bc1bc83f91294029bd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 25 Jul 2024 14:35:59 -0700 Subject: [PATCH] [torchlib] Improve aten::fill (#1754) I updated torch-onnx to handle empty `[]` inputs, so the isinstance check is not needed. --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ---- onnxscript/tools/benchmark/benchmark_helpers.py | 2 +- onnxscript/tools/transformers_models/__init__.py | 1 + 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1fc122966..32dcf770e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3622,10 +3622,6 @@ def aten_full( if dtype != -1: fill_value = op.Cast(fill_value, to=dtype) - if isinstance(size, list) and size == []: - # TODO(justinchuby): Handle empty list better than using isinstance - # size can be empty, meaning a scalar - return fill_value size = op.Cast(size, to=INT64.dtype) return op.Expand(fill_value, size) diff --git a/onnxscript/tools/benchmark/benchmark_helpers.py b/onnxscript/tools/benchmark/benchmark_helpers.py index e796a8808..3a874fa46 100644 --- a/onnxscript/tools/benchmark/benchmark_helpers.py +++ b/onnxscript/tools/benchmark/benchmark_helpers.py @@ -287,7 +287,7 @@ def common_export( if exporter == "script": torch.onnx.export( model, - inputs, + inputs, # type: ignore[arg-type] filename, do_constant_folding=False, input_names=[f"input{i}" for i in range(len(inputs))], diff --git a/onnxscript/tools/transformers_models/__init__.py b/onnxscript/tools/transformers_models/__init__.py index fd7a5807a..43dc81e9b 100644 --- a/onnxscript/tools/transformers_models/__init__.py +++ b/onnxscript/tools/transformers_models/__init__.py @@ -41,6 +41,7 @@ def export_to_onnx( prog = torch.onnx.export(model, args, dynamo=True) # pylint: disable=no-value-for-parameter else: prog = torch.onnx.dynamo_export(model, *args) + assert prog is not None model_proto = prog.model_proto if optimize: model_proto = onnxscript.optimizer.optimize(