From 5dd537a214156b8a778f7ac63c8dcbb8960fd497 Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Thu, 30 Nov 2023 16:58:08 +0000 Subject: [PATCH] #4014: Add testing for uint16 and uint32 on device - Fix bug with unpack in slow dispatch path for uint16 --- .../python_api_testing/unit_testing/test_tensor.py | 8 +++++++- tt_eager/tensor/tensor_impl.cpp | 11 ++++++----- tt_eager/tensor/tensor_impl.hpp | 4 ++-- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_tensor.py b/tests/tt_eager/python_api_testing/unit_testing/test_tensor.py index 8503ef767f9..aadb67fd545 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_tensor.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_tensor.py @@ -25,6 +25,7 @@ "tt_dtype", [ ttl.tensor.DataType.UINT32, + ttl.tensor.DataType.UINT16, ttl.tensor.DataType.FLOAT32, ttl.tensor.DataType.BFLOAT16, ], @@ -38,7 +39,12 @@ def test_tensor_conversion_between_torch_and_tt(shape, tt_dtype, device): torch_tensor = torch.rand(shape, dtype=dtype) tt_tensor = ttl.tensor.Tensor(torch_tensor, tt_dtype) - if tt_dtype in {ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B}: + if tt_dtype in { + ttl.tensor.DataType.BFLOAT16, + ttl.tensor.DataType.BFLOAT8_B, + ttl.tensor.DataType.UINT32, + ttl.tensor.DataType.UINT16, + }: tt_tensor = tt_tensor.to(device) tt_tensor = tt_tensor.cpu() diff --git a/tt_eager/tensor/tensor_impl.cpp b/tt_eager/tensor/tensor_impl.cpp index 7355b9c1371..0fa637f5528 100644 --- a/tt_eager/tensor/tensor_impl.cpp +++ b/tt_eager/tensor/tensor_impl.cpp @@ -42,7 +42,8 @@ uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, page_size = constants::TILE_HW * size_of_element; } break; - case DataType::UINT32: { + case DataType::UINT32: + case DataType::UINT16: { uint32_t size_of_element = element_size_bytes_wrapper(dtype); page_size = constants::TILE_HW * size_of_element; } @@ -94,21 +95,21 @@ void validate_on_device_dtype_and_layout(Device *device, DataType dtype, Layout // TODO: Get supported layout and dtypes from device auto supported_dtype = [&dtype]() { TT_ASSERT( - (dtype == DataType::BFLOAT16 || dtype == DataType::BFLOAT8_B || dtype == DataType::UINT32) && - "Only BFLOAT16 , BFLOAT8_B or UINT32 is supported on device!" + (dtype == DataType::BFLOAT16 || dtype == DataType::BFLOAT8_B || dtype == DataType::UINT32 || dtype == DataType::UINT16) && + "Only BFLOAT16, BFLOAT8_B, UINT32, or UINT16 is supported on device!" ); }; auto supported_layout = [&dtype, &layout]() { switch (dtype) { case DataType::UINT32: - break; + case DataType::UINT16: case DataType::BFLOAT16: break; case DataType::BFLOAT8_B: TT_ASSERT(layout == Layout::TILE && "Only TILE layout is supported for BFLOAT8_B dtype!"); break; default: - TT_ASSERT(false && "Only BFLOAT16 or BFLOAT8_B is supported on device!"); + TT_ASSERT(false && "Only BFLOAT16, BFLOAT8_B, UINT32, or UINT16 is supported on device!"); break; } }; diff --git a/tt_eager/tensor/tensor_impl.hpp b/tt_eager/tensor/tensor_impl.hpp index 2c07ddf87a3..0782e653399 100644 --- a/tt_eager/tensor/tensor_impl.hpp +++ b/tt_eager/tensor/tensor_impl.hpp @@ -87,8 +87,8 @@ constexpr inline std::vector unpack_uint32_vec(std::vector & } else if constexpr (std::is_same_v) { std::vector output; for (auto index = 0; index < data_to_unpack.size(); index++) { - output.push_back(data_to_unpack[index] >> 16 & 0xFF); - output.push_back(data_to_unpack[index] & 0xFF); + output.push_back(data_to_unpack[index] >> 16); + output.push_back(data_to_unpack[index] & 0xFFFF); } return output; } else if constexpr (std::is_same_v) {