diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py index 7dffad8f84c83..482a334b12b85 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py @@ -14,6 +14,7 @@ from numpy.testing import assert_allclose from onnx import TensorProto from onnx.checker import check_model +from onnx.defs import onnx_opset_version from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info from onnx.numpy_helper import from_array @@ -91,7 +92,10 @@ def get_model_gemm( ] nodes = [n for n in nodes if n is not None] graph = make_graph(nodes, "gemm", inputs, [d], inits) - onnx_model = make_model(graph, opset_imports=[make_opsetid("", 19)], ir_version=9) + opset_imports = [make_opsetid("", onnx_opset_version() - 1)] + if domain == "com.microsoft": + opset_imports.append(make_opsetid("com.microsoft", 1)) + onnx_model = make_model(graph, opset_imports=opset_imports, ir_version=9) if domain != "com.microsoft": check_model(onnx_model) return onnx_model @@ -268,7 +272,8 @@ def test_combinations(self, shapeA, shapeB, transA, transB): make_tensor_value_info("B", TensorProto.FLOAT, [None, None]), ], [make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])], - ) + ), + opset_imports=[make_opsetid("", 19), make_opsetid("com.microsoft", 1)], ) sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"])