Skip to content

Commit

Permalink
Fix opset import in GemmFloat8 python unit tests (microsoft#18489)
Browse files Browse the repository at this point in the history
### Description
The unit test are failing if a development version of onnx is used. The
opset are set to 19.
  • Loading branch information
xadupre authored and kleiti committed Mar 22, 2024
1 parent 257c72b commit 3c1c007
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from numpy.testing import assert_allclose
from onnx import TensorProto
from onnx.checker import check_model
from onnx.defs import onnx_opset_version
from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info
from onnx.numpy_helper import from_array

Expand Down Expand Up @@ -91,7 +92,10 @@ def get_model_gemm(
]
nodes = [n for n in nodes if n is not None]
graph = make_graph(nodes, "gemm", inputs, [d], inits)
onnx_model = make_model(graph, opset_imports=[make_opsetid("", 19)], ir_version=9)
opset_imports = [make_opsetid("", onnx_opset_version() - 1)]
if domain == "com.microsoft":
opset_imports.append(make_opsetid("com.microsoft", 1))
onnx_model = make_model(graph, opset_imports=opset_imports, ir_version=9)
if domain != "com.microsoft":
check_model(onnx_model)
return onnx_model
Expand Down Expand Up @@ -268,7 +272,8 @@ def test_combinations(self, shapeA, shapeB, transA, transB):
make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
],
[make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])],
)
),
opset_imports=[make_opsetid("", 19), make_opsetid("com.microsoft", 1)],
)

sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
Expand Down

0 comments on commit 3c1c007

Please sign in to comment.