diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 0e611562a..b9c04d711 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -132,6 +132,12 @@ def __array__(self, dtype: Any = None) -> np.ndarray: def numpy(self) -> np.ndarray: """Return the tensor as a numpy array. + This is an improved version of onnx.numpy_helper.to_array. + It first reads the data using the dtype corresponding to the tensor + proto data field, then converts it to the correct dtype and shape. + Special cases are bfloat16, complex and int4 where we need to + reinterpret the data. Other types can simply be casted. + When the data type is not supported by numpy, the value is the bit representation of the dtype: @@ -139,6 +145,16 @@ def numpy(self) -> np.ndarray: - ``uint8`` for uint4. - ``uint8`` for 8-bit data types like float8. - ``uint16`` for bfloat16. + + When the data type is a string, this method returns a numpy array + of bytes instead of a numpy array of strings, to follow the ONNX + specification. + + External tensors are not supported by this class. Use + :class:`onnxscript.ir.ExternalTensor` instead. + + Raises: + ValueError: If the data type is UNDEFINED. """ # This is an improved version of onnx.numpy_helper.to_array. # It first reads the data using the dtype corresponding to the tensor @@ -146,9 +162,14 @@ def numpy(self) -> np.ndarray: # Special cases are bfloat16, complex and int4 where we need to # reinterpret the data. Other types can simply be casted. dtype = self.dtype + if dtype == _enums.DataType.UNDEFINED: + raise ValueError("Cannot convert UNDEFINED tensor to numpy array.") + if self._proto.HasField("raw_data"): array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")) # Cannot return now, because we may need to unpack 4bit tensors + if dtype == _enums.DataType.STRING: + return np.array(self._proto.string_data).reshape(self._proto.dims) elif self._proto.int32_data: array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32)) if dtype == _enums.DataType.FLOAT16: @@ -180,7 +201,19 @@ def numpy(self) -> np.ndarray: return array.astype(dtype.numpy()).reshape(self._proto.dims) def tobytes(self) -> bytes: - """Return the tensor as a byte string conformed to the ONNX specification, in little endian.""" + """Return the tensor as a byte string conformed to the ONNX specification, in little endian. + + Raises: + ValueError: If the tensor is a string tensor or an external tensor. + ValueError: If the tensor is of UNDEFINED data type. + """ + if self._proto.data_location == onnx.TensorProto.EXTERNAL: + raise ValueError("Cannot convert external tensor to bytes.") + if self.dtype == _enums.DataType.STRING: + raise ValueError("Cannot convert string tensor to bytes.") + if self.dtype == _enums.DataType.UNDEFINED: + raise ValueError("Cannot convert UNDEFINED tensor to bytes.") + if self._proto.HasField("raw_data"): return self._proto.raw_data if self._proto.float_data: