#4014: added support for uint16 datatype
arakhmati committed Nov 27, 2023
1 parent bd90c9b commit 4872bf0
Showing 8 changed files with 114 additions and 43 deletions.
31 changes: 24 additions & 7 deletions tests/tt_eager/python_api_testing/unit_testing/test_tensor.py
@@ -12,6 +12,7 @@


tt_dtype_to_torch_dtype = {
ttl.tensor.DataType.UINT16: torch.int16,
ttl.tensor.DataType.UINT32: torch.int32,
ttl.tensor.DataType.FLOAT32: torch.float,
ttl.tensor.DataType.BFLOAT16: torch.bfloat16,
@@ -31,7 +31,7 @@
def test_tensor_conversion_between_torch_and_tt(shape, tt_dtype, device):
dtype = tt_dtype_to_torch_dtype[tt_dtype]

if dtype == torch.int32:
if dtype in {torch.int16, torch.int32}:
torch_tensor = torch.randint(0, 1024, shape, dtype=dtype)
else:
torch_tensor = torch.rand(shape, dtype=dtype)
@@ -54,6 +55,7 @@ def test_tensor_conversion_between_torch_and_tt(shape, tt_dtype, device):
@pytest.mark.parametrize(
"tt_dtype",
[
ttl.tensor.DataType.UINT16,
ttl.tensor.DataType.UINT32,
ttl.tensor.DataType.FLOAT32,
ttl.tensor.DataType.BFLOAT16,
@@ -65,7 +67,7 @@ def test_serialization(tmp_path, shape, tt_dtype):

dtype = tt_dtype_to_torch_dtype[tt_dtype]

if dtype == torch.int32:
if dtype in {torch.int16, torch.int32}:
torch_tensor = torch.randint(0, 1024, shape, dtype=dtype)
else:
torch_tensor = torch.rand(shape, dtype=dtype)
@@ -91,6 +93,7 @@ def test_serialization(tmp_path, shape, tt_dtype):
@pytest.mark.parametrize(
"tt_dtype",
[
ttl.tensor.DataType.UINT16,
ttl.tensor.DataType.UINT32,
ttl.tensor.DataType.FLOAT32,
ttl.tensor.DataType.BFLOAT16,
@@ -101,17 +104,31 @@ def test_print(shape, tt_dtype):

dtype = tt_dtype_to_torch_dtype[tt_dtype]

if dtype == torch.int32:
if dtype in {torch.int16, torch.int32}:
torch_tensor = torch.randint(0, 1024, shape, dtype=dtype)
else:
torch_tensor = torch.rand(shape, dtype=dtype)

tt_tensor = ttl.tensor.Tensor(torch_tensor, tt_dtype)
if tt_dtype == ttl.tensor.DataType.UINT32:
assert str(tt_tensor) == "Tensor([ [[[684, 559, 629, 192],\n [835, 763, 707, 359],\n [9, 723, 277, 754]],\n\n [[804, 599, 70, 472],\n [600, 396, 314, 705],\n [486, 551, 87, 174]]]], dtype=uint32 )\n"
if tt_dtype == ttl.tensor.DataType.UINT16:
assert (
str(tt_tensor)
== "Tensor([ [[[684, 559, 629, 192],\n [835, 763, 707, 359],\n [9, 723, 277, 754]],\n\n [[804, 599, 70, 472],\n [600, 396, 314, 705],\n [486, 551, 87, 174]]]], dtype=uint16 )\n"
)
elif tt_dtype == ttl.tensor.DataType.UINT32:
assert (
str(tt_tensor)
== "Tensor([ [[[684, 559, 629, 192],\n [835, 763, 707, 359],\n [9, 723, 277, 754]],\n\n [[804, 599, 70, 472],\n [600, 396, 314, 705],\n [486, 551, 87, 174]]]], dtype=uint32 )\n"
)
elif tt_dtype == ttl.tensor.DataType.FLOAT32:
assert str(tt_tensor) == "Tensor([ [[[0.496257, 0.768222, 0.0884774, 0.13203],\n [0.307423, 0.634079, 0.490093, 0.896445],\n [0.455628, 0.632306, 0.348893, 0.401717]],\n\n [[0.0223258, 0.168859, 0.293888, 0.518522],\n [0.697668, 0.800011, 0.161029, 0.282269],\n [0.681609, 0.915194, 0.3971, 0.874156]]]], dtype=float32 )\n"
assert (
str(tt_tensor)
== "Tensor([ [[[0.496257, 0.768222, 0.0884774, 0.13203],\n [0.307423, 0.634079, 0.490093, 0.896445],\n [0.455628, 0.632306, 0.348893, 0.401717]],\n\n [[0.0223258, 0.168859, 0.293888, 0.518522],\n [0.697668, 0.800011, 0.161029, 0.282269],\n [0.681609, 0.915194, 0.3971, 0.874156]]]], dtype=float32 )\n"
)
elif tt_dtype == ttl.tensor.DataType.BFLOAT16:
assert str(tt_tensor) == "Tensor([ [[[0.671875, 0.183594, 0.457031, 0.75],\n [0.261719, 0.980469, 0.761719, 0.402344],\n [0.0351562, 0.824219, 0.0820312, 0.945312]],\n\n [[0.140625, 0.339844, 0.273438, 0.84375],\n [0.34375, 0.546875, 0.226562, 0.753906],\n [0.898438, 0.152344, 0.339844, 0.679688]]]], dtype=bfloat16 )\n"
assert (
str(tt_tensor)
== "Tensor([ [[[0.671875, 0.183594, 0.457031, 0.75],\n [0.261719, 0.980469, 0.761719, 0.402344],\n [0.0351562, 0.824219, 0.0820312, 0.945312]],\n\n [[0.140625, 0.339844, 0.273438, 0.84375],\n [0.34375, 0.546875, 0.226562, 0.753906],\n [0.898438, 0.152344, 0.339844, 0.679688]]]], dtype=bfloat16 )\n"
)
else:
raise ValueError(f"Unsupported dtype: {tt_dtype}")
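
As a usage sketch, the round trip these tests exercise looks like the following. This is a minimal sketch assuming a build with this commit; the Tensor constructor call matches the tests above, and to_torch() is assumed from the existing tt_lib tensor API. Values stay under 1024, so the int16 stand-in never overflows.

import torch
import tt_lib as ttl

torch_tensor = torch.randint(0, 1024, (1, 2, 3, 4), dtype=torch.int16)

# torch int16 stands in for uint16 on the torch side
tt_tensor = ttl.tensor.Tensor(torch_tensor, ttl.tensor.DataType.UINT16)
round_tripped = tt_tensor.to_torch()

assert torch.equal(torch_tensor, round_tripped)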
12 changes: 6 additions & 6 deletions tt_eager/tensor/serialization.cpp
@@ -54,16 +54,16 @@ OwnedStorage load_owned_storage(ifstream& input_stream, DataType data_type) {
if (data_type == DataType::UINT32 or data_type == DataType::BFLOAT8_B) {
using T = std::uint32_t;
return load_owned_storage<T>(input_stream);
}
else if (data_type == DataType::FLOAT32) {
} else if (data_type == DataType::UINT16) {
using T = std::uint16_t;
return load_owned_storage<T>(input_stream);
} else if (data_type == DataType::FLOAT32) {
using T = float;
return load_owned_storage<T>(input_stream);
}
else if (data_type == DataType::BFLOAT16) {
} else if (data_type == DataType::BFLOAT16) {
using T = bfloat16;
return load_owned_storage<T>(input_stream);
}
else {
} else {
TT_THROW("Unsupported DataType");
}
}
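At the Python level this enables the round trip exercised by test_serialization above; a sketch, with the dump_tensor/load_tensor names assumed from the tt_lib API used in that test:

import torch
import tt_lib as ttl

tt_tensor = ttl.tensor.Tensor(
    torch.randint(0, 1024, (1, 2, 3, 4), dtype=torch.int16),
    ttl.tensor.DataType.UINT16,
)

# write the tensor to disk, then read it back with the same dtype
ttl.tensor.dump_tensor("tensor.bin", tt_tensor)
loaded = ttl.tensor.load_tensor("tensor.bin")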
3 changes: 2 additions & 1 deletion tt_eager/tensor/tensor_impl.cpp
@@ -13,10 +13,11 @@ namespace tensor_impl {

std::ostream& operator<<(std::ostream& os, const DataType& dtype) {
switch (dtype) {
case DataType::BFLOAT8_B: os << "bfloat8_b"; break;
case DataType::BFLOAT16: os << "bfloat16"; break;
case DataType::FLOAT32: os << "float32"; break;
case DataType::UINT16: os << "uint16"; break;
case DataType::UINT32: os << "uint32"; break;
case DataType::BFLOAT8_B: os << "bfloat8_b"; break;
default: throw std::invalid_argument("Unknown data type");
}
return os;
26 changes: 18 additions & 8 deletions tt_eager/tensor/tensor_impl.hpp
@@ -53,12 +53,17 @@ template <typename DataType, template<typename> typename BufferType>
constexpr inline std::vector<uint32_t> pack_vec_into_uint32_vec(const BufferType<DataType>& data_to_pack) {
if constexpr (std::is_same_v<DataType, uint32_t>) {
return std::vector(std::begin(data_to_pack), std::end(data_to_pack));
}
else if constexpr (std::is_same_v<DataType, bfloat16>) {
} else if constexpr (std::is_same_v<DataType, uint16_t>) {
std::vector<uint32_t> output;
for (auto index = 0; index < data_to_pack.size(); index += 2) {
// note: `+` binds tighter than `<<`, so the shift must be parenthesized
auto value = (static_cast<uint32_t>(data_to_pack[index]) << 16) | data_to_pack[index + 1];
output.push_back(value);
}
return output;
} else if constexpr (std::is_same_v<DataType, bfloat16>) {
auto bfloat16_vec = std::vector(std::begin(data_to_pack), std::end(data_to_pack));
return pack_bfloat16_vec_into_uint32_vec(bfloat16_vec);
}
else if constexpr (std::is_same_v<DataType, float>) {
} else if constexpr (std::is_same_v<DataType, float>) {
std::vector<uint32_t> uint32_data;
assert(data_to_pack.size() % 2 == 0);
for (auto i = 0; i < data_to_pack.size(); i += 2) {
@@ -79,11 +84,16 @@ template <typename DataType>
constexpr inline std::vector<DataType> unpack_uint32_vec(std::vector<uint32_t> &data_to_unpack) {
if constexpr (std::is_same_v<DataType, uint32_t>) {
return data_to_unpack;
}
else if constexpr (std::is_same_v<DataType, bfloat16>) {
} else if constexpr (std::is_same_v<DataType, uint16_t>) {
std::vector<DataType> output;
for (auto index = 0; index < data_to_unpack.size(); index++) {
// mask with 0xFFFF (not 0xFF) to keep the full 16-bit halves
output.push_back((data_to_unpack[index] >> 16) & 0xFFFF);
output.push_back(data_to_unpack[index] & 0xFFFF);
}
return output;
} else if constexpr (std::is_same_v<DataType, bfloat16>) {
return unpack_uint32_vec_into_bfloat16_vec(data_to_unpack);
}
else if constexpr (std::is_same_v<DataType, float>) {
} else if constexpr (std::is_same_v<DataType, float>) {
std::vector<float> float_data;
for (auto i = 0; i < data_to_unpack.size(); i++) {
auto unpacked = unpack_two_bfloat16_from_uint32(data_to_unpack[i]);
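For reference, a minimal Python sketch of the uint16-to-uint32 packing scheme implemented above: two 16-bit values per 32-bit word, with the first element of each pair in the high half.

def pack_uint16_pair(first, second):
    # first value goes in the high 16 bits, second in the low 16 bits
    return ((first & 0xFFFF) << 16) | (second & 0xFFFF)

def unpack_uint16_pair(word):
    # inverse of pack_uint16_pair
    return (word >> 16) & 0xFFFF, word & 0xFFFF

assert unpack_uint16_pair(pack_uint16_pair(684, 559)) == (684, 559)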
42 changes: 26 additions & 16 deletions tt_eager/tensor/tensor_impl_wrapper.cpp
@@ -16,7 +16,8 @@ uint32_t element_size_bytes_wrapper(DataType dtype) {
const static std::map<DataType, std::function<uint32_t()>> element_size_bytes_map = {
{DataType::BFLOAT16, &element_size_bytes<bfloat16>},
{DataType::FLOAT32, &element_size_bytes<float>},
{DataType::UINT32, &element_size_bytes<uint32_t>}
{DataType::UINT32, &element_size_bytes<uint32_t>},
{DataType::UINT16, &element_size_bytes<uint16_t>},
};
return element_size_bytes_map.at(dtype)();
}
@@ -26,7 +27,8 @@ uint32_t packed_buffer_size_bytes_wrapper(DataType dtype, uint32_t volume_unpack
{DataType::BFLOAT16, &packed_buffer_size_bytes<bfloat16>},
{DataType::FLOAT32, &packed_buffer_size_bytes<float>},
{DataType::UINT32, &packed_buffer_size_bytes<uint32_t>},
{DataType::BFLOAT8_B, &packed_buffer_size_bytes<uint32_t>}
{DataType::BFLOAT8_B, &packed_buffer_size_bytes<uint32_t>},
{DataType::UINT16, &packed_buffer_size_bytes<uint16_t>},
};
return packed_buffer_size_bytes_map.at(dtype)(volume_unpacked_data);
}
@@ -36,18 +38,21 @@ Tensor to_host_wrapper(const Tensor &tensor) {
{DataType::BFLOAT16, &to_host<bfloat16>},
{DataType::FLOAT32, &to_host<float>},
{DataType::UINT32, &to_host<uint32_t>},
{DataType::BFLOAT8_B, &to_host<uint32_t>}
{DataType::BFLOAT8_B, &to_host<uint32_t>},
{DataType::UINT16, &to_host<uint16_t>},
};
return to_host_map.at(tensor.dtype())(tensor);
}

Tensor to_device_wrapper(const Tensor &tensor, Device *target_device, const MemoryConfig &mem_config) {
const static std::map<DataType, std::function<Tensor(const Tensor &, Device *, const MemoryConfig &)>> to_device_map = {
{DataType::BFLOAT16, &to_device<bfloat16>},
{DataType::FLOAT32, &to_device<float>},
{DataType::UINT32, &to_device<uint32_t>},
{DataType::BFLOAT8_B, &to_device<uint32_t>}
};
const static std::map<DataType, std::function<Tensor(const Tensor &, Device *, const MemoryConfig &)>>
to_device_map = {
{DataType::BFLOAT16, &to_device<bfloat16>},
{DataType::FLOAT32, &to_device<float>},
{DataType::UINT32, &to_device<uint32_t>},
{DataType::BFLOAT8_B, &to_device<uint32_t>},
{DataType::UINT16, &to_device<uint16_t>},
};
return to_device_map.at(tensor.dtype())(tensor, target_device, mem_config);
}

@@ -57,17 +62,20 @@ Tensor to_layout_wrapper(const Tensor &tensor, Layout target_layout) {
{DataType::FLOAT32, &to_layout<float>},
{DataType::UINT32, &to_layout<uint32_t>},
{DataType::BFLOAT8_B, &to_layout_bfloat8_b},
{DataType::UINT16, &to_layout<uint16_t>},
};
return to_layout_map.at(tensor.dtype())(tensor, target_layout);
}

Tensor pad_wrapper(const Tensor &tensor, const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value) {
const static std::map<DataType, std::function<Tensor(const Tensor &, const Shape &, const Shape &, float)>> pad_map = {
{DataType::BFLOAT16, &pad<bfloat16>},
{DataType::FLOAT32, &pad<float>},
{DataType::UINT32, &pad<uint32_t>},
{DataType::BFLOAT8_B, &pad_bfloat8_b},
};
const static std::map<DataType, std::function<Tensor(const Tensor &, const Shape &, const Shape &, float)>>
pad_map = {
{DataType::BFLOAT16, &pad<bfloat16>},
{DataType::FLOAT32, &pad<float>},
{DataType::UINT32, &pad<uint32_t>},
{DataType::BFLOAT8_B, &pad_bfloat8_b},
{DataType::UINT16, &pad<uint16_t>},
};
return pad_map.at(tensor.dtype())(tensor, output_tensor_shape, input_tensor_start, pad_value);
}

@@ -77,6 +85,7 @@ Tensor unpad_wrapper(const Tensor &tensor, const Shape &output_tensor_start, con
{DataType::FLOAT32, &unpad<float>},
{DataType::UINT32, &unpad<uint32_t>},
{DataType::BFLOAT8_B, &unpad_bfloat8_b},
{DataType::UINT16, &unpad<uint16_t>},
};
return unpad_map.at(tensor.dtype())(tensor, output_tensor_start, output_tensor_end);
}
@@ -86,7 +95,8 @@ std::string to_string_wrapper(const Tensor &tensor, Layout print_layout, bool pr
{DataType::BFLOAT16, &to_string<bfloat16>},
{DataType::FLOAT32, &to_string<float>},
{DataType::UINT32, &to_string<uint32_t>},
{DataType::BFLOAT8_B, &to_string<uint32_t>}
{DataType::BFLOAT8_B, &to_string<uint32_t>},
{DataType::UINT16, &to_string<uint16_t>},
};
return to_string_map.at(tensor.dtype())(tensor, print_layout, pretty_print);
}
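Each wrapper above follows the same pattern: a static map from DataType to the matching template instantiation, so supporting a new dtype means adding one entry per map. A Python analogue of the pattern, for illustration only:

ELEMENT_SIZE_BYTES = {
    "bfloat16": 2,
    "float32": 4,
    "uint32": 4,
    "uint16": 2,
}

def element_size_bytes(dtype):
    # like element_size_bytes_map.at(dtype)(), this raises on an unknown dtype
    return ELEMENT_SIZE_BYTES[dtype]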
11 changes: 6 additions & 5 deletions tt_eager/tensor/types.hpp
@@ -31,7 +31,8 @@ enum class DataType {
BFLOAT16 = 0,
FLOAT32 = 1,
UINT32 = 2,
BFLOAT8_B = 3
BFLOAT8_B = 3,
UINT16 = 4,
};

enum class StorageType {
@@ -145,10 +146,10 @@ bool operator==(const MemoryConfig& config_a, const MemoryConfig& config_b);
bool operator!=(const MemoryConfig& config_a, const MemoryConfig& config_b);

using OwnedBuffer = std::variant<
owned_buffer::Buffer<uint16_t>,
owned_buffer::Buffer<uint32_t>,
owned_buffer::Buffer<float>,
owned_buffer::Buffer<bfloat16>
>;
owned_buffer::Buffer<bfloat16>>;
struct OwnedStorage {
OwnedBuffer buffer;

@@ -167,10 +168,10 @@ struct DeviceStorage {
};

using BorrowedBuffer = std::variant<
borrowed_buffer::Buffer<uint16_t>,
borrowed_buffer::Buffer<uint32_t>,
borrowed_buffer::Buffer<float>,
borrowed_buffer::Buffer<bfloat16>
>;
borrowed_buffer::Buffer<bfloat16>>;
struct BorrowedStorage {
BorrowedBuffer buffer;
std::function<void()> on_creation_callback = []{};
9 changes: 9 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp
@@ -187,6 +187,10 @@ void TensorModule(py::module &m_tensor) {
.def(py::self == py::self)
.def(py::self != py::self);

auto py_owned_buffer_for_uint16_t =
py::class_<owned_buffer::Buffer<uint16_t>>(m_tensor, "owned_buffer_for_uint16_t", py::buffer_protocol());
detail::implement_buffer_protocol<owned_buffer::Buffer<uint16_t>, uint16_t>(py_owned_buffer_for_uint16_t);

auto py_owned_buffer_for_uint32_t = py::class_<owned_buffer::Buffer<uint32_t>>(m_tensor, "owned_buffer_for_uint32_t", py::buffer_protocol());
detail::implement_buffer_protocol<owned_buffer::Buffer<uint32_t>, uint32_t>(py_owned_buffer_for_uint32_t);

@@ -196,6 +200,11 @@ void TensorModule(py::module &m_tensor) {
auto py_owned_buffer_for_bfloat16_t = py::class_<owned_buffer::Buffer<bfloat16>>(m_tensor, "owned_buffer_for_bfloat16_t", py::buffer_protocol());
detail::implement_buffer_protocol<owned_buffer::Buffer<bfloat16>, bfloat16>(py_owned_buffer_for_bfloat16_t);

auto py_borrowed_buffer_for_uint16_t = py::class_<borrowed_buffer::Buffer<std::uint16_t>>(
m_tensor, "borrowed_buffer_for_uint16_t", py::buffer_protocol());
detail::implement_buffer_protocol<borrowed_buffer::Buffer<std::uint16_t>, std::uint16_t>(
py_borrowed_buffer_for_uint16_t);

auto py_borrowed_buffer_for_uint32_t = py::class_<borrowed_buffer::Buffer<std::uint32_t>>(m_tensor, "borrowed_buffer_for_uint32_t", py::buffer_protocol());
detail::implement_buffer_protocol<borrowed_buffer::Buffer<std::uint32_t>, std::uint32_t>(py_borrowed_buffer_for_uint32_t);

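These bindings expose the new uint16 buffers through the Python buffer protocol, so host data can be viewed without a copy. A sketch, with the buffer() method name assumed from the tt_lib tensor API and not exercised by this commit's tests:

import numpy as np
import torch
import tt_lib as ttl

tt_tensor = ttl.tensor.Tensor(
    torch.randint(0, 1024, (1, 2, 3, 4), dtype=torch.int16),
    ttl.tensor.DataType.UINT16,
)

# zero-copy view of the host buffer (e.g. owned_buffer_for_uint16_t)
array = np.frombuffer(memoryview(tt_tensor.buffer()), dtype=np.uint16)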
23 changes: 23 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
@@ -55,6 +55,13 @@ Tensor convert_torch_tensor_to_tt_tensor(
}

switch (data_type) {
case DataType::UINT16: {
if (not torch_dtype.equal(torch.attr("int32"))) {
borrow_storage = false;
contiguous_torch_tensor = contiguous_torch_tensor.attr("to")(torch.attr("int16"));
}
break;
}
case DataType::UINT32: {
if (not torch_dtype.equal(torch.attr("int32"))) {
borrow_storage = false;
@@ -87,6 +94,21 @@ Tensor convert_torch_tensor_to_tt_tensor(
auto on_destruction_callback = [tensor = contiguous_torch_tensor] { tensor.dec_ref(); };

switch (data_type) {
case DataType::UINT16: {
auto data_ptr =
reinterpret_cast<uint16_t *>(py::cast<std::size_t>(contiguous_torch_tensor.attr("data_ptr")()));
auto num_elements = py::cast<std::size_t>(contiguous_torch_tensor.attr("numel")());
if (borrow_storage) {
auto storage = BorrowedStorage(
borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback);
return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR);
} else {
std::vector<uint16_t> uint16_t_vector(data_ptr, data_ptr + num_elements);
auto buffer = owned_buffer::create<uint16_t>(std::move(uint16_t_vector));
auto storage = OwnedStorage{std::move(buffer)};
return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR);
}
}
case DataType::UINT32: {
auto data_ptr =
reinterpret_cast<uint32_t *>(py::cast<std::size_t>(contiguous_torch_tensor.attr("data_ptr")()));
@@ -208,6 +230,7 @@ Tensor convert_torch_tensor_to_tt_tensor(
}

const auto tt_dtype_to_torch_dtype = std::map<DataType, py::object> {
{DataType::UINT16, torch.attr("int16")}, // TODO(arakhmati): add DataType::INT16
{DataType::UINT32, torch.attr("int32")}, // TODO(arakhmati): add DataType::INT32
{DataType::FLOAT32, torch.attr("float32")},
{DataType::BFLOAT16, torch.attr("bfloat16")},
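One caveat worth making explicit: torch has no unsigned 16-bit dtype, so int16 stands in for UINT16 (hence the TODO above), and values of 2**15 or more read back negative when viewed through torch. A small illustration:

import torch

# 40000 fits in uint16 but not int16; the same bit pattern reads back negative
value = torch.tensor([40000], dtype=torch.int32).to(torch.int16)
print(value.item())  # -25536, i.e. 40000 - 65536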
