From be865931938067821ab73f7a9fe3cf1dfa2f8934 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 25 Jul 2024 10:41:26 -0700 Subject: [PATCH 1/2] [torchlib] Improve aten::fill --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1fc122966..32dcf770e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3622,10 +3622,6 @@ def aten_full( if dtype != -1: fill_value = op.Cast(fill_value, to=dtype) - if isinstance(size, list) and size == []: - # TODO(justinchuby): Handle empty list better than using isinstance - # size can be empty, meaning a scalar - return fill_value size = op.Cast(size, to=INT64.dtype) return op.Expand(fill_value, size) From 32439416b2016feab75175e07d5f484975be743e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 25 Jul 2024 19:47:11 +0000 Subject: [PATCH 2/2] lint --- onnxscript/tools/benchmark/benchmark_helpers.py | 2 +- onnxscript/tools/transformers_models/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/tools/benchmark/benchmark_helpers.py b/onnxscript/tools/benchmark/benchmark_helpers.py index e796a8808..3a874fa46 100644 --- a/onnxscript/tools/benchmark/benchmark_helpers.py +++ b/onnxscript/tools/benchmark/benchmark_helpers.py @@ -287,7 +287,7 @@ def common_export( if exporter == "script": torch.onnx.export( model, - inputs, + inputs, # type: ignore[arg-type] filename, do_constant_folding=False, input_names=[f"input{i}" for i in range(len(inputs))], diff --git a/onnxscript/tools/transformers_models/__init__.py b/onnxscript/tools/transformers_models/__init__.py index fd7a5807a..43dc81e9b 100644 --- a/onnxscript/tools/transformers_models/__init__.py +++ b/onnxscript/tools/transformers_models/__init__.py @@ -41,6 +41,7 @@ def export_to_onnx( prog = torch.onnx.export(model, args, dynamo=True) # pylint: disable=no-value-for-parameter else: prog = torch.onnx.dynamo_export(model, *args) + assert prog is not None model_proto = prog.model_proto if optimize: model_proto = onnxscript.optimizer.optimize(