diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index 50d0f568f..9a9fef2d3 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -103,7 +103,7 @@ def test_tensor_proto_tensor_bfloat16(self): raw_data=raw_data, ) array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) + np.testing.assert_array_equal(array_from_raw_data.view(ml_dtypes.bfloat16), expected_array) # Test dlpack with self.assertRaises(BufferError): # NumPy does not support bfloat16 in from_dlpack