Skip to content

Commit

Permalink
Add provider argument to bfloat16 converter test (#1513)
Browse files Browse the repository at this point in the history
In test, adding missed provider information into ORT inference session.
  • Loading branch information
titaiwangms authored May 7, 2024
1 parent c2d1de1 commit ea1eda9
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def test_bfloat16_converted_model_runtime(self):
onnx.checker.check_model(model_proto_filled_shape_type, full_check=True)
try:
ort_session = onnxruntime.InferenceSession(
model_proto_filled_shape_type.SerializeToString()
model_proto_filled_shape_type.SerializeToString(),
providers=["CPUExecutionProvider"],
)
v0 = np.random.randn(2, 3, 4).astype(np.float16)
v1 = np.random.randn(2, 3, 4).astype(np.float16)
Expand Down

0 comments on commit ea1eda9

Please sign in to comment.