diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
index ed53d2f64..8effd0b28 100644
--- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
+++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
@@ -75,7 +75,8 @@ def test_bfloat16_converted_model_runtime(self):
         onnx.checker.check_model(model_proto_filled_shape_type, full_check=True)
         try:
             ort_session = onnxruntime.InferenceSession(
-                model_proto_filled_shape_type.SerializeToString()
+                model_proto_filled_shape_type.SerializeToString(),
+                providers=["CPUExecutionProvider"],
             )
             v0 = np.random.randn(2, 3, 4).astype(np.float16)
             v1 = np.random.randn(2, 3, 4).astype(np.float16)