Skip to content

Commit

Permalink
add uint16 to serialization
Browse files Browse the repository at this point in the history
Differential Revision: D65915213

Pull Request resolved: pytorch#6942
  • Loading branch information
JacobSzwejbka authored Nov 19, 2024
1 parent 1de96f8 commit 54feeef
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 4 deletions.
8 changes: 7 additions & 1 deletion devtools/bundled_program/schema/scalar_type.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ enum ScalarType : byte {
FLOAT = 6,
DOUBLE = 7,
BOOL = 11,
// TODO(jakeszwe): Verify these are unused and then remove support
QINT8 = 12,
QUINT8 = 13,
QINT32 = 14,
QUINT4X2 = 16,
QUINT2X4 = 17,
BITS16 = 22,
FLOAT8E5M2 = 23,
FLOAT8E4M3FN = 24,
FLOAT8E5M2FNUZ = 25,
FLOAT8E4M3FNUZ = 26,
UINT16 = 27,
UINT32 = 28,
UINT64 = 29,
// Types currently not implemented.
// COMPLEXHALF = 8,
// COMPLEXFLOAT = 9,
Expand Down
2 changes: 2 additions & 0 deletions devtools/etdump/etdump_flatcc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type(
return executorch_flatbuffer_ScalarType_BOOL;
case exec_aten::ScalarType::Bits16:
return executorch_flatbuffer_ScalarType_BITS16;
case exec_aten::ScalarType::UInt16:
return executorch_flatbuffer_ScalarType_UINT16;
default:
ET_CHECK_MSG(
0,
Expand Down
8 changes: 7 additions & 1 deletion devtools/etdump/scalar_type.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ enum ScalarType : byte {
FLOAT = 6,
DOUBLE = 7,
BOOL = 11,
// TODO(jakeszwe): Verify these are unused and then remove support
QINT8 = 12,
QUINT8 = 13,
QINT32 = 14,
QUINT4X2 = 16,
QUINT2X4 = 17,
BITS16 = 22,
FLOAT8E5M2 = 23,
FLOAT8E4M3FN = 24,
FLOAT8E5M2FNUZ = 25,
FLOAT8E4M3FNUZ = 26,
UINT16 = 27,
UINT32 = 28,
UINT64 = 29,
// Types currently not implemented.
// COMPLEXHALF = 8,
// COMPLEXFLOAT = 9,
Expand Down
7 changes: 7 additions & 0 deletions exir/scalar_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,10 @@ class ScalarType(IntEnum):
QUINT4x2 = 16
QUINT2x4 = 17
BITS16 = 22
FLOAT8E5M2 = 23
FLOAT8E4M3FN = 24
FLOAT8E5M2FNUZ = 25
FLOAT8E4M3FNUZ = 26
UINT16 = 27
UINT32 = 28
UINT64 = 29
2 changes: 1 addition & 1 deletion exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
torch.qint32: ScalarType.QINT32,
torch.bfloat16: ScalarType.BFLOAT16,
torch.quint4x2: ScalarType.QUINT4x2,
torch.uint16: ScalarType.BITS16,
torch.uint16: ScalarType.UINT16,
}


Expand Down
8 changes: 7 additions & 1 deletion schema/scalar_type.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ enum ScalarType : byte {
FLOAT = 6,
DOUBLE = 7,
BOOL = 11,
// TODO(jakeszwe): Verify these are unused and then remove support
QINT8 = 12,
QUINT8 = 13,
QINT32 = 14,
QUINT4X2 = 16,
QUINT2X4 = 17,
BITS16 = 22,
FLOAT8E5M2 = 23,
FLOAT8E4M3FN = 24,
FLOAT8E5M2FNUZ = 25,
FLOAT8E4M3FNUZ = 26,
UINT16 = 27,
UINT32 = 28,
UINT64 = 29,
// Types currently not implemented.
// COMPLEXHALF = 8,
// COMPLEXFLOAT = 9,
Expand Down

0 comments on commit 54feeef

Please sign in to comment.