Skip to content

Commit

Permalink
test_tensor_proto_tensor_bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jul 22, 2024
1 parent 88a62b1 commit 321b62c
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions onnxscript/ir/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def test_tensor_proto_tensor(self, _: str, dtype: int):
self.skipTest("numpy<1.25 does not support bool dtype in from_dlpack")
np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy())

@unittest.skipIf(
version_utils.onnx_older_than("1.17"),
"numpy_helper.to_array was not correctly implemented in onnx<1.17",
)
def test_tensor_proto_tensor_bfloat16(self):
expected_array = np.array(
[[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]], dtype=ml_dtypes.bfloat16
Expand All @@ -95,15 +99,17 @@ def test_tensor_proto_tensor_bfloat16(self):
np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]]),
)
tensor = serde.TensorProtoTensor(tensor_proto)
np.testing.assert_array_equal(tensor.numpy().view(ml_dtypes.bfloat16), expected_array)
np.testing.assert_array_equal(tensor.numpy(), expected_array)
raw_data = tensor.tobytes()
tensor_proto_from_raw_data = onnx.TensorProto(
dims=tensor_proto.dims,
data_type=tensor_proto.data_type,
raw_data=raw_data,
)
array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
np.testing.assert_array_equal(array_from_raw_data.view(ml_dtypes.bfloat16), expected_array)
np.testing.assert_array_equal(
array_from_raw_data.view(ml_dtypes.bfloat16), expected_array
)
# Test dlpack
with self.assertRaises(BufferError):
# NumPy does not support bfloat16 in from_dlpack
Expand Down

0 comments on commit 321b62c

Please sign in to comment.