From 492096e23b458c3a40f96bea1e8b8ac2519cfeb2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 22 Jul 2024 21:02:05 +0000 Subject: [PATCH 1/3] Fix weekly CI pipeline errors 1. Fix type annotation for _add_attribute_to_torchscript_node and removed runtime type checking because int like inputs are not correctly recognized by beartype 2. Fix IR test errors with latest onnx-weekly --- .../graph_building/_graph_building_torch.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 54aa412ff..4fac129ef 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -425,11 +425,19 @@ def eval_function( # type: ignore[override] return self._graph.add_function_call(function, inputs, attributes) -@runtime_typing.checked def _add_attribute_to_torchscript_node( node: torch.Node, key: str, - value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor], + value: Union[ + float, + int, + str, + bytes, + Sequence[float], + Sequence[int], + torch.Tensor, + ir.TensorProtocol, + ], ): """Initializes the right attribute based on type of value.""" if isinstance(value, float): From 88a62b1d0cb6ce179b2980be2e88ea699c5a7060 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 22 Jul 2024 21:04:19 +0000 Subject: [PATCH 2/3] bfloat16 --- onnxscript/ir/serde_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index 50d0f568f..9a9fef2d3 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -103,7 +103,7 @@ def test_tensor_proto_tensor_bfloat16(self): raw_data=raw_data, ) array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) + np.testing.assert_array_equal(array_from_raw_data.view(ml_dtypes.bfloat16), expected_array) # Test dlpack with self.assertRaises(BufferError): # NumPy does not support bfloat16 in from_dlpack From 321b62cc192158a5ba93cca6b57d8906bb8f53bd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 22 Jul 2024 21:10:57 +0000 Subject: [PATCH 3/3] test_tensor_proto_tensor_bfloat16 --- onnxscript/ir/serde_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index 9a9fef2d3..f46756055 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -84,6 +84,10 @@ def test_tensor_proto_tensor(self, _: str, dtype: int): self.skipTest("numpy<1.25 does not support bool dtype in from_dlpack") np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) + @unittest.skipIf( + version_utils.onnx_older_than("1.17"), + "numpy_helper.to_array was not correctly implemented in onnx<1.17", + ) def test_tensor_proto_tensor_bfloat16(self): expected_array = np.array( [[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]], dtype=ml_dtypes.bfloat16 @@ -95,7 +99,7 @@ def test_tensor_proto_tensor_bfloat16(self): np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]]), ) tensor = serde.TensorProtoTensor(tensor_proto) - np.testing.assert_array_equal(tensor.numpy().view(ml_dtypes.bfloat16), expected_array) + np.testing.assert_array_equal(tensor.numpy(), expected_array) raw_data = tensor.tobytes() tensor_proto_from_raw_data = onnx.TensorProto( dims=tensor_proto.dims, @@ -103,7 +107,9 @@ def test_tensor_proto_tensor_bfloat16(self): raw_data=raw_data, ) array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data.view(ml_dtypes.bfloat16), expected_array) + np.testing.assert_array_equal( + array_from_raw_data.view(ml_dtypes.bfloat16), expected_array + ) # Test dlpack with self.assertRaises(BufferError): # NumPy does not support bfloat16 in from_dlpack