diff --git a/devtools/bundled_program/schema/scalar_type.fbs b/devtools/bundled_program/schema/scalar_type.fbs index fc299ac691..e9c830b972 100644 --- a/devtools/bundled_program/schema/scalar_type.fbs +++ b/devtools/bundled_program/schema/scalar_type.fbs @@ -18,13 +18,19 @@ enum ScalarType : byte { FLOAT = 6, DOUBLE = 7, BOOL = 11, - // TODO(jakeszwe): Verify these are unused and then remove support QINT8 = 12, QUINT8 = 13, QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, BITS16 = 22, + FLOAT8E5M2 = 23, + FLOAT8E4M3FN = 24, + FLOAT8E5M2FNUZ = 25, + FLOAT8E4M3FNUZ = 26, + UINT16 = 27, + UINT32 = 28, + UINT64 = 29, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, diff --git a/devtools/etdump/etdump_flatcc.cpp b/devtools/etdump/etdump_flatcc.cpp index cfd1d2ae14..f34b28e650 100644 --- a/devtools/etdump/etdump_flatcc.cpp +++ b/devtools/etdump/etdump_flatcc.cpp @@ -57,6 +57,8 @@ executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type( return executorch_flatbuffer_ScalarType_BOOL; case exec_aten::ScalarType::Bits16: return executorch_flatbuffer_ScalarType_BITS16; + case exec_aten::ScalarType::UInt16: + return executorch_flatbuffer_ScalarType_UINT16; default: ET_CHECK_MSG( 0, diff --git a/devtools/etdump/scalar_type.fbs b/devtools/etdump/scalar_type.fbs index fc299ac691..e9c830b972 100644 --- a/devtools/etdump/scalar_type.fbs +++ b/devtools/etdump/scalar_type.fbs @@ -18,13 +18,19 @@ enum ScalarType : byte { FLOAT = 6, DOUBLE = 7, BOOL = 11, - // TODO(jakeszwe): Verify these are unused and then remove support QINT8 = 12, QUINT8 = 13, QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, BITS16 = 22, + FLOAT8E5M2 = 23, + FLOAT8E4M3FN = 24, + FLOAT8E5M2FNUZ = 25, + FLOAT8E4M3FNUZ = 26, + UINT16 = 27, + UINT32 = 28, + UINT64 = 29, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, diff --git a/exir/scalar_type.py b/exir/scalar_type.py index 5d41038610..4b0a4f412f 100644 --- a/exir/scalar_type.py +++ b/exir/scalar_type.py @@ -29,3 +29,10 @@ class ScalarType(IntEnum): QUINT4x2 = 16 QUINT2x4 = 17 BITS16 = 22 + FLOAT8E5M2 = 23 + FLOAT8E4M3FN = 24 + FLOAT8E5M2FNUZ = 25 + FLOAT8E4M3FNUZ = 26 + UINT16 = 27 + UINT32 = 28 + UINT64 = 29 diff --git a/exir/tensor.py b/exir/tensor.py index a40bef4e5e..0c5218bb59 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -262,7 +262,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int: torch.qint32: ScalarType.QINT32, torch.bfloat16: ScalarType.BFLOAT16, torch.quint4x2: ScalarType.QUINT4x2, - torch.uint16: ScalarType.BITS16, + torch.uint16: ScalarType.UINT16, } diff --git a/schema/scalar_type.fbs b/schema/scalar_type.fbs index fc299ac691..e9c830b972 100644 --- a/schema/scalar_type.fbs +++ b/schema/scalar_type.fbs @@ -18,13 +18,19 @@ enum ScalarType : byte { FLOAT = 6, DOUBLE = 7, BOOL = 11, - // TODO(jakeszwe): Verify these are unused and then remove support QINT8 = 12, QUINT8 = 13, QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, BITS16 = 22, + FLOAT8E5M2 = 23, + FLOAT8E4M3FN = 24, + FLOAT8E5M2FNUZ = 25, + FLOAT8E4M3FNUZ = 26, + UINT16 = 27, + UINT32 = 28, + UINT64 = 29, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9,