From f26f4d785b66a30339e53ebb452303b6af55c7c4 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 7 Jan 2025 14:30:19 -0800 Subject: [PATCH] Use --opset-version provided by iree-import-onnx. (#776) This option was added in https://github.com/iree-org/iree/commit/d4975713a0aa3ba872807fc16585fb4b5a04e41d. --- shortfin/tests/invocation/conftest.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/shortfin/tests/invocation/conftest.py b/shortfin/tests/invocation/conftest.py index e62373eb5..97bcadaa8 100644 --- a/shortfin/tests/invocation/conftest.py +++ b/shortfin/tests/invocation/conftest.py @@ -8,14 +8,6 @@ import urllib.request -def upgrade_onnx(original_path, converted_path): - import onnx - - original_model = onnx.load_model(original_path) - converted_model = onnx.version_converter.convert_version(original_model, 17) - onnx.save(converted_model, converted_path) - - @pytest.fixture(scope="session") def mobilenet_onnx_path(tmp_path_factory): try: @@ -23,16 +15,14 @@ def mobilenet_onnx_path(tmp_path_factory): except ModuleNotFoundError: raise pytest.skip("onnx python package not available") parent_dir = tmp_path_factory.mktemp("mobilenet_onnx") - orig_onnx_path = parent_dir / "mobilenet_orig.onnx" - upgraded_onnx_path = parent_dir / "mobilenet.onnx" - if not upgraded_onnx_path.exists(): + onnx_path = parent_dir / "mobilenet.onnx" + if not onnx_path.exists(): print("Downloading mobilenet.onnx") urllib.request.urlretrieve( "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", - orig_onnx_path, + onnx_path, ) - upgrade_onnx(orig_onnx_path, upgraded_onnx_path) - return upgraded_onnx_path + return onnx_path @pytest.fixture(scope="session") @@ -47,7 +37,7 @@ def mobilenet_compiled_path(mobilenet_onnx_path, compile_flags): if not vmfb_path.exists(): print("Compiling mobilenet") args = import_onnx.parse_arguments( - ["-o", str(mlir_path), str(mobilenet_onnx_path)] + ["-o", str(mlir_path), str(mobilenet_onnx_path), "--opset-version", "17"] ) import_onnx.main(args) tools.compile_file(