diff --git a/shortfin/tests/invocation/conftest.py b/shortfin/tests/invocation/conftest.py index e62373eb5..97bcadaa8 100644 --- a/shortfin/tests/invocation/conftest.py +++ b/shortfin/tests/invocation/conftest.py @@ -8,14 +8,6 @@ import urllib.request -def upgrade_onnx(original_path, converted_path): - import onnx - - original_model = onnx.load_model(original_path) - converted_model = onnx.version_converter.convert_version(original_model, 17) - onnx.save(converted_model, converted_path) - - @pytest.fixture(scope="session") def mobilenet_onnx_path(tmp_path_factory): try: @@ -23,16 +15,14 @@ def mobilenet_onnx_path(tmp_path_factory): except ModuleNotFoundError: raise pytest.skip("onnx python package not available") parent_dir = tmp_path_factory.mktemp("mobilenet_onnx") - orig_onnx_path = parent_dir / "mobilenet_orig.onnx" - upgraded_onnx_path = parent_dir / "mobilenet.onnx" - if not upgraded_onnx_path.exists(): + onnx_path = parent_dir / "mobilenet.onnx" + if not onnx_path.exists(): print("Downloading mobilenet.onnx") urllib.request.urlretrieve( "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", - orig_onnx_path, + onnx_path, ) - upgrade_onnx(orig_onnx_path, upgraded_onnx_path) - return upgraded_onnx_path + return onnx_path @pytest.fixture(scope="session") @@ -47,7 +37,7 @@ def mobilenet_compiled_path(mobilenet_onnx_path, compile_flags): if not vmfb_path.exists(): print("Compiling mobilenet") args = import_onnx.parse_arguments( - ["-o", str(mlir_path), str(mobilenet_onnx_path)] + ["-o", str(mlir_path), str(mobilenet_onnx_path), "--opset-version", "17"] ) import_onnx.main(args) tools.compile_file(