Skip to content

Commit

Permalink
string and external
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed May 4, 2024
1 parent e9e452e commit fa4239a
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,44 @@ def __array__(self, dtype: Any = None) -> np.ndarray:
def numpy(self) -> np.ndarray:
"""Return the tensor as a numpy array.
This is an improved version of onnx.numpy_helper.to_array.
It first reads the data using the dtype corresponding to the tensor
proto data field, then converts it to the correct dtype and shape.
Special cases are bfloat16, complex and int4 where we need to
reinterpret the data. Other types can simply be casted.
When the data type is not supported by numpy, the value is the bit representation
of the dtype:
- ``int8`` for int4, with the sign bit extended to 8 bits.
- ``uint8`` for uint4.
- ``uint8`` for 8-bit data types like float8.
- ``uint16`` for bfloat16.
When the data type is a string, this method returns a numpy array
of bytes instead of a numpy array of strings, to follow the ONNX
specification.
External tensors are not supported by this class. Use
:class:`onnxscript.ir.ExternalTensor` instead.
Raises:
ValueError: If the data type is UNDEFINED.
"""
# This is an improved version of onnx.numpy_helper.to_array.
# It first reads the data using the dtype corresponding to the tensor
# proto data field, then converts it to the correct dtype and shape.
# Special cases are bfloat16, complex and int4 where we need to
# reinterpret the data. Other types can simply be casted.
dtype = self.dtype
if dtype == _enums.DataType.UNDEFINED:
raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")

if self._proto.HasField("raw_data"):
array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<"))

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'array' is unnecessary as it is
redefined
before this value is used.
This assignment to 'array' is unnecessary as it is
redefined
before this value is used.
This assignment to 'array' is unnecessary as it is
redefined
before this value is used.
This assignment to 'array' is unnecessary as it is
redefined
before this value is used.
This assignment to 'array' is unnecessary as it is
redefined
before this value is used.
This assignment to 'array' is unnecessary as it is
redefined
before this value is used.
# Cannot return now, because we may need to unpack 4bit tensors
if dtype == _enums.DataType.STRING:
return np.array(self._proto.string_data).reshape(self._proto.dims)
elif self._proto.int32_data:
array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
if dtype == _enums.DataType.FLOAT16:
Expand Down Expand Up @@ -180,7 +201,19 @@ def numpy(self) -> np.ndarray:
return array.astype(dtype.numpy()).reshape(self._proto.dims)

def tobytes(self) -> bytes:
"""Return the tensor as a byte string conformed to the ONNX specification, in little endian."""
"""Return the tensor as a byte string conformed to the ONNX specification, in little endian.
Raises:
ValueError: If the tensor is a string tensor or an external tensor.
ValueError: If the tensor is of UNDEFINED data type.
"""
if self._proto.data_location == onnx.TensorProto.EXTERNAL:
raise ValueError("Cannot convert external tensor to bytes.")
if self.dtype == _enums.DataType.STRING:
raise ValueError("Cannot convert string tensor to bytes.")
if self.dtype == _enums.DataType.UNDEFINED:
raise ValueError("Cannot convert UNDEFINED tensor to bytes.")

if self._proto.HasField("raw_data"):
return self._proto.raw_data
if self._proto.float_data:
Expand Down

0 comments on commit fa4239a

Please sign in to comment.