Use IR types to define onnx_types (#1924)

- Use IR types (onnxscript.ir.DataType) to define onnx_types so that it does not depend on the onnx package version (a short usage sketch follows below).
- Also add INT4 and UINT4 types.
- Make some helper functions private.

justinchuby authored Nov 7, 2024
1 parent ec3b140 commit 3a7d6fd
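
In practice the change looks like this. A minimal sketch, not part of the commit: it assumes an onnxscript build that includes this change, and the names are taken from the diff below.

# Minimal sketch (assumption: onnxscript with this commit installed).
import onnxscript.ir as ir
from onnxscript.onnx_types import FLOAT, INT4, UINT4

# Each tensor type now carries an onnxscript.ir.DataType member instead of an
# onnx.TensorProto value, so onnx_types no longer tracks the onnx package version.
assert FLOAT.dtype == ir.DataType.FLOAT
assert INT4.dtype == ir.DataType.INT4
assert UINT4.dtype == ir.DataType.UINT4

# Shape annotations are unchanged; dimensions may be ints, strings, or None.
def scale(x: FLOAT["N", 128]) -> FLOAT["N", 128]:
    ...
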
Showing 2 changed files with 47 additions and 41 deletions.
82 changes: 44 additions & 38 deletions onnxscript/onnx_types.py
@@ -9,29 +9,27 @@
 import onnx
 import onnx.helper
 
-DType = onnx.TensorProto.DataType
+import onnxscript.ir
 
-DimType = Union[int, str, type(None)]
+_DType = onnxscript.ir.DataType
+_DimType = Union[int, str, type(None)]
+_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]
 
+_tensor_type_shape_cache: dict[_DType, TensorType] = {}
+tensor_type_registry: dict[_DType, TensorType] = {}
 
-def check_dim(dim):
+
+def _check_dim(dim):
     if not isinstance(dim, (int, str, type(None))):
         raise TypeError(f"Invalid dimension {dim}")
 
 
-ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]
-
-
-def check_shape(shape):
+def _check_shape(shape):
     if isinstance(shape, tuple):
         for dim in shape:
-            check_dim(dim)
+            _check_dim(dim)
     elif shape != Ellipsis:
-        check_dim(shape)
-
-
-tensor_type_registry: dict[DType, TensorType] = {}
-_tensor_type_shape_cache: dict[DType, TensorType] = {}
+        _check_dim(shape)
 
 
@@ -58,13 +56,13 @@ class TensorType(abc.ABC):
         tensor: FLOAT[128, 1024]
     """
 
-    dtype: ClassVar[DType]
-    shape: ClassVar[Optional[ShapeType]]
+    dtype: ClassVar[_DType]
+    shape: ClassVar[Optional[_ShapeType]]
 
     def __new__(cls):
         raise NotImplementedError("TensorTypes cannot be instantiated")
 
-    def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
+    def __init_subclass__(cls, dtype: _DType, shape: Optional[_ShapeType] = None):
         cls.dtype = dtype
         cls.shape = shape
         if shape is None:
@@ -76,9 +74,9 @@ def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
                )
             tensor_type_registry[dtype] = cls
         else:
-            check_shape(shape)
+            _check_shape(shape)
 
-    def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]:
+    def __class_getitem__(cls, shape: Optional[_ShapeType]) -> type[TensorType]:
         if cls.shape is not None:
             raise ValueError("Invalid usage: shape already specified.")
         if shape is None:
@@ -108,83 +106,91 @@ def to_string(cls) -> str:
         return f"tensor({cls.__name__.lower()})"
 
 
-class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
+class FLOAT(TensorType, dtype=onnxscript.ir.DataType.FLOAT):
     pass
 
 
-class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
+class UINT8(TensorType, dtype=onnxscript.ir.DataType.UINT8):
     pass
 
 
-class INT8(TensorType, dtype=onnx.TensorProto.INT8):
+class INT8(TensorType, dtype=onnxscript.ir.DataType.INT8):
     pass
 
 
-class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
+class UINT16(TensorType, dtype=onnxscript.ir.DataType.UINT16):
     pass
 
 
-class INT16(TensorType, dtype=onnx.TensorProto.INT16):
+class INT16(TensorType, dtype=onnxscript.ir.DataType.INT16):
     pass
 
 
-class INT32(TensorType, dtype=onnx.TensorProto.INT32):
+class INT32(TensorType, dtype=onnxscript.ir.DataType.INT32):
     pass
 
 
-class INT64(TensorType, dtype=onnx.TensorProto.INT64):
+class INT64(TensorType, dtype=onnxscript.ir.DataType.INT64):
     pass
 
 
-class STRING(TensorType, dtype=onnx.TensorProto.STRING):
+class STRING(TensorType, dtype=onnxscript.ir.DataType.STRING):
     pass
 
 
-class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
+class BOOL(TensorType, dtype=onnxscript.ir.DataType.BOOL):
     pass
 
 
-class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
+class FLOAT16(TensorType, dtype=onnxscript.ir.DataType.FLOAT16):
     pass
 
 
-class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
+class DOUBLE(TensorType, dtype=onnxscript.ir.DataType.DOUBLE):
     pass
 
 
-class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
+class UINT32(TensorType, dtype=onnxscript.ir.DataType.UINT32):
     pass
 
 
-class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
+class UINT64(TensorType, dtype=onnxscript.ir.DataType.UINT64):
     pass
 
 
-class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
+class COMPLEX64(TensorType, dtype=onnxscript.ir.DataType.COMPLEX64):
     pass
 
 
-class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
+class COMPLEX128(TensorType, dtype=onnxscript.ir.DataType.COMPLEX128):
     pass
 
 
-class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
+class BFLOAT16(TensorType, dtype=onnxscript.ir.DataType.BFLOAT16):
     pass
 
 
-class FLOAT8E4M3FN(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FN):
+class FLOAT8E4M3FN(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FN):
     pass
 
 
-class FLOAT8E4M3FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FNUZ):
+class FLOAT8E4M3FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FNUZ):
     pass
 
 
-class FLOAT8E5M2(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2):
+class FLOAT8E5M2(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2):
     pass
 
 
-class FLOAT8E5M2FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2FNUZ):
+class FLOAT8E5M2FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2FNUZ):
     pass
 
 
+class INT4(TensorType, dtype=onnxscript.ir.DataType.INT4):
+    pass
+
+
+class UINT4(TensorType, dtype=onnxscript.ir.DataType.UINT4):
+    pass
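
To illustrate the registry and subscripting behavior visible in the hunks above: again a sketch, not part of the commit. It assumes the same build, and that shaped variants store the given shape tuple unchanged, as the tests below suggest.

# Sketch (assumptions: onnxscript with this commit installed; to_string() and
# shape storage behave as the surrounding code and tests indicate).
import onnxscript.ir as ir
from onnxscript import onnx_types

# Unshaped TensorType subclasses register themselves under their IR dtype,
# so the new 4-bit types are discoverable like all the others.
assert onnx_types.tensor_type_registry[ir.DataType.INT4] is onnx_types.INT4
assert onnx_types.tensor_type_registry[ir.DataType.UINT4] is onnx_types.UINT4

# Subscripting yields a shaped variant that keeps the dtype of its base type.
shaped = onnx_types.FLOAT[128, 1024]
assert shaped.dtype == ir.DataType.FLOAT
assert shaped.shape == (128, 1024)

print(onnx_types.FLOAT.to_string())  # -> "tensor(float)"
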
6 changes: 3 additions & 3 deletions tests/onnx_types_test.py
@@ -13,7 +13,7 @@
 
 from parameterized import parameterized
 
-from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry
+from onnxscript.onnx_types import DOUBLE, FLOAT, TensorType, tensor_type_registry
 
 
 class TestOnnxTypes(unittest.TestCase):
@@ -26,7 +26,7 @@ def test_instantiation(self):
             FLOAT[...]()
 
     @parameterized.expand(tensor_type_registry.items())
-    def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]):
+    def test_type_properties(self, dtype: int, tensor_type: type[TensorType]):
         self.assertEqual(tensor_type.dtype, dtype)
         self.assertIsNone(tensor_type.shape)
         self.assertEqual(tensor_type[...].shape, ...)  # type: ignore[index]
@@ -35,7 +35,7 @@ def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]):
         self.assertEqual(tensor_type[1, 2, 3].dtype, dtype)  # type: ignore[index]
 
     @parameterized.expand([(dtype,) for dtype in tensor_type_registry])
-    def test_dtype_bound_to_subclass(self, dtype: DType):
+    def test_dtype_bound_to_subclass(self, dtype: int):
         with self.assertRaises(ValueError):
             type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)
