From 3a7d6fd0657ec4de4172d5dce2806a4dd82e1fa1 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 7 Nov 2024 09:41:49 -0800
Subject: [PATCH] Use IR types to define onnx_types (#1924)

- Use IR types to define onnx_types so that it is not dependent on the onnx package version.
- Also add INT4 and UINT4 types.
- Make some helper functions private.
---
 onnxscript/onnx_types.py | 82 +++++++++++++++++++++-------------------
 tests/onnx_types_test.py |  6 +--
 2 files changed, 47 insertions(+), 41 deletions(-)

diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py
index d4ddb2fe8..5ddb2bbb1 100644
--- a/onnxscript/onnx_types.py
+++ b/onnxscript/onnx_types.py
@@ -9,29 +9,27 @@
 import onnx
 import onnx.helper
 
-DType = onnx.TensorProto.DataType
+import onnxscript.ir
 
-DimType = Union[int, str, type(None)]
+_DType = onnxscript.ir.DataType
+_DimType = Union[int, str, type(None)]
+_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]
 
+_tensor_type_shape_cache: dict[_DType, TensorType] = {}
+tensor_type_registry: dict[_DType, TensorType] = {}
 
-def check_dim(dim):
+
+def _check_dim(dim):
     if not isinstance(dim, (int, str, type(None))):
         raise TypeError(f"Invalid dimension {dim}")
 
 
-ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]
-
-
-def check_shape(shape):
+def _check_shape(shape):
     if isinstance(shape, tuple):
         for dim in shape:
-            check_dim(dim)
+            _check_dim(dim)
     elif shape != Ellipsis:
-        check_dim(shape)
-
-
-tensor_type_registry: dict[DType, TensorType] = {}
-_tensor_type_shape_cache: dict[DType, TensorType] = {}
+        _check_dim(shape)
 
 
 class TensorType(abc.ABC):
@@ -58,13 +56,13 @@ class TensorType(abc.ABC):
         tensor: FLOAT[128, 1024]
     """
 
-    dtype: ClassVar[DType]
-    shape: ClassVar[Optional[ShapeType]]
+    dtype: ClassVar[_DType]
+    shape: ClassVar[Optional[_ShapeType]]
 
     def __new__(cls):
         raise NotImplementedError("TensorTypes cannot be instantiated")
 
-    def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
+    def __init_subclass__(cls, dtype: _DType, shape: Optional[_ShapeType] = None):
         cls.dtype = dtype
         cls.shape = shape
         if shape is None:
@@ -76,9 +74,9 @@ def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
                 )
             tensor_type_registry[dtype] = cls
         else:
-            check_shape(shape)
+            _check_shape(shape)
 
-    def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]:
+    def __class_getitem__(cls, shape: Optional[_ShapeType]) -> type[TensorType]:
         if cls.shape is not None:
             raise ValueError("Invalid usage: shape already specified.")
         if shape is None:
@@ -108,83 +106,91 @@ def to_string(cls) -> str:
         return f"tensor({cls.__name__.lower()})"
 
 
-class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
+class FLOAT(TensorType, dtype=onnxscript.ir.DataType.FLOAT):
+    pass
+
+
+class UINT8(TensorType, dtype=onnxscript.ir.DataType.UINT8):
+    pass
+
+
+class INT8(TensorType, dtype=onnxscript.ir.DataType.INT8):
     pass
 
 
-class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
+class UINT16(TensorType, dtype=onnxscript.ir.DataType.UINT16):
     pass
 
 
-class INT8(TensorType, dtype=onnx.TensorProto.INT8):
+class INT16(TensorType, dtype=onnxscript.ir.DataType.INT16):
     pass
 
 
-class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
+class INT32(TensorType, dtype=onnxscript.ir.DataType.INT32):
     pass
 
 
-class INT16(TensorType, dtype=onnx.TensorProto.INT16):
+class INT64(TensorType, dtype=onnxscript.ir.DataType.INT64):
     pass
 
 
-class INT32(TensorType, dtype=onnx.TensorProto.INT32):
+class STRING(TensorType, dtype=onnxscript.ir.DataType.STRING):
     pass
 
 
-class INT64(TensorType, dtype=onnx.TensorProto.INT64):
+class BOOL(TensorType, dtype=onnxscript.ir.DataType.BOOL):
     pass
 
 
-class STRING(TensorType, dtype=onnx.TensorProto.STRING):
+class FLOAT16(TensorType, dtype=onnxscript.ir.DataType.FLOAT16):
     pass
 
 
-class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
+class DOUBLE(TensorType, dtype=onnxscript.ir.DataType.DOUBLE):
     pass
 
 
-class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
+class UINT32(TensorType, dtype=onnxscript.ir.DataType.UINT32):
     pass
 
 
-class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
+class UINT64(TensorType, dtype=onnxscript.ir.DataType.UINT64):
     pass
 
 
-class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
+class COMPLEX64(TensorType, dtype=onnxscript.ir.DataType.COMPLEX64):
     pass
 
 
-class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
+class COMPLEX128(TensorType, dtype=onnxscript.ir.DataType.COMPLEX128):
     pass
 
 
-class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
+class BFLOAT16(TensorType, dtype=onnxscript.ir.DataType.BFLOAT16):
     pass
 
 
-class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
+class FLOAT8E4M3FN(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FN):
     pass
 
 
-class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
+class FLOAT8E4M3FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FNUZ):
     pass
 
 
-class FLOAT8E4M3FN(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FN):
+class FLOAT8E5M2(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2):
     pass
 
 
-class FLOAT8E4M3FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FNUZ):
+class FLOAT8E5M2FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2FNUZ):
     pass
 
 
-class FLOAT8E5M2(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2):
+class INT4(TensorType, dtype=onnxscript.ir.DataType.INT4):
     pass
 
 
-class FLOAT8E5M2FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2FNUZ):
+class UINT4(TensorType, dtype=onnxscript.ir.DataType.UINT4):
     pass
 
 
diff --git a/tests/onnx_types_test.py b/tests/onnx_types_test.py
index 8e9a96eb5..1f7a98cc1 100644
--- a/tests/onnx_types_test.py
+++ b/tests/onnx_types_test.py
@@ -13,7 +13,7 @@
 
 from parameterized import parameterized
 
-from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry
+from onnxscript.onnx_types import DOUBLE, FLOAT, TensorType, tensor_type_registry
 
 
 class TestOnnxTypes(unittest.TestCase):
@@ -26,7 +26,7 @@ def test_instantiation(self):
             FLOAT[...]()
 
     @parameterized.expand(tensor_type_registry.items())
-    def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]):
+    def test_type_properties(self, dtype: int, tensor_type: type[TensorType]):
         self.assertEqual(tensor_type.dtype, dtype)
         self.assertIsNone(tensor_type.shape)
         self.assertEqual(tensor_type[...].shape, ...)  # type: ignore[index]
@@ -35,7 +35,7 @@ def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]):
         self.assertEqual(tensor_type[1, 2, 3].dtype, dtype)  # type: ignore[index]
 
     @parameterized.expand([(dtype,) for dtype in tensor_type_registry])
-    def test_dtype_bound_to_subclass(self, dtype: DType):
+    def test_dtype_bound_to_subclass(self, dtype: int):
         with self.assertRaises(ValueError):
             type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)
 
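Usage sketch (illustrative, not part of the patch): assuming onnxscript with the IR-backed onnx_types module above is installed and the public names stay importable the way the updated test imports them, the shape-parameterized types and the new 4-bit types behave as sketched below. The expected values are read off the diff rather than taken from a run.

import onnxscript.ir
from onnxscript.onnx_types import FLOAT, INT4, tensor_type_registry

# dtype is now an onnxscript.ir.DataType member instead of an onnx.TensorProto value.
assert FLOAT.dtype == onnxscript.ir.DataType.FLOAT

# Subscripting yields a shape-specialized type, as in the FLOAT[128, 1024] docstring example.
Matrix = FLOAT[128, 1024]
assert Matrix.shape == (128, 1024)
assert Matrix.dtype == FLOAT.dtype

# Unparameterized types register themselves by dtype, now including the 4-bit types.
assert tensor_type_registry[onnxscript.ir.DataType.INT4] is INT4
assert INT4.to_string() == "tensor(int4)"  # per to_string() shown in the diff

Keying dtype and the registry on onnxscript.ir.DataType rather than onnx.TensorProto is what removes the dependency on the installed onnx package version mentioned in the commit message.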