diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 54aa412ff..4fac129ef 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -425,11 +425,19 @@ def eval_function( # type: ignore[override] return self._graph.add_function_call(function, inputs, attributes) -@runtime_typing.checked def _add_attribute_to_torchscript_node( node: torch.Node, key: str, - value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor], + value: Union[ + float, + int, + str, + bytes, + Sequence[float], + Sequence[int], + torch.Tensor, + ir.TensorProtocol, + ], ): """Initializes the right attribute based on type of value.""" if isinstance(value, float): diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index 50d0f568f..f46756055 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -84,6 +84,10 @@ def test_tensor_proto_tensor(self, _: str, dtype: int): self.skipTest("numpy<1.25 does not support bool dtype in from_dlpack") np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) + @unittest.skipIf( + version_utils.onnx_older_than("1.17"), + "numpy_helper.to_array was not correctly implemented in onnx<1.17", + ) def test_tensor_proto_tensor_bfloat16(self): expected_array = np.array( [[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]], dtype=ml_dtypes.bfloat16 @@ -95,7 +99,7 @@ def test_tensor_proto_tensor_bfloat16(self): np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]]), ) tensor = serde.TensorProtoTensor(tensor_proto) - np.testing.assert_array_equal(tensor.numpy().view(ml_dtypes.bfloat16), expected_array) + np.testing.assert_array_equal(tensor.numpy(), expected_array) raw_data = tensor.tobytes() tensor_proto_from_raw_data = onnx.TensorProto( dims=tensor_proto.dims, @@ -103,7 +107,9 @@ def test_tensor_proto_tensor_bfloat16(self): raw_data=raw_data, ) array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) + np.testing.assert_array_equal( + array_from_raw_data.view(ml_dtypes.bfloat16), expected_array + ) # Test dlpack with self.assertRaises(BufferError): # NumPy does not support bfloat16 in from_dlpack