Skip to content

Commit

Permalink
#4014: Add testing for uint16 and uint32 on device
Browse files Browse the repository at this point in the history
- Fix bug with unpack in slow dispatch path for uint16
  • Loading branch information
TT-BrianLiu committed Dec 1, 2023
1 parent caa4859 commit 5dd537a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"tt_dtype",
[
ttl.tensor.DataType.UINT32,
ttl.tensor.DataType.UINT16,
ttl.tensor.DataType.FLOAT32,
ttl.tensor.DataType.BFLOAT16,
],
Expand All @@ -38,7 +39,12 @@ def test_tensor_conversion_between_torch_and_tt(shape, tt_dtype, device):
torch_tensor = torch.rand(shape, dtype=dtype)

tt_tensor = ttl.tensor.Tensor(torch_tensor, tt_dtype)
if tt_dtype in {ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B}:
if tt_dtype in {
ttl.tensor.DataType.BFLOAT16,
ttl.tensor.DataType.BFLOAT8_B,
ttl.tensor.DataType.UINT32,
ttl.tensor.DataType.UINT16,
}:
tt_tensor = tt_tensor.to(device)
tt_tensor = tt_tensor.cpu()

Expand Down
11 changes: 6 additions & 5 deletions tt_eager/tensor/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes,
page_size = constants::TILE_HW * size_of_element;
}
break;
case DataType::UINT32: {
case DataType::UINT32:
case DataType::UINT16: {
uint32_t size_of_element = element_size_bytes_wrapper(dtype);
page_size = constants::TILE_HW * size_of_element;
}
Expand Down Expand Up @@ -94,21 +95,21 @@ void validate_on_device_dtype_and_layout(Device *device, DataType dtype, Layout
// TODO: Get supported layout and dtypes from device
auto supported_dtype = [&dtype]() {
TT_ASSERT(
(dtype == DataType::BFLOAT16 || dtype == DataType::BFLOAT8_B || dtype == DataType::UINT32) &&
"Only BFLOAT16 , BFLOAT8_B or UINT32 is supported on device!"
(dtype == DataType::BFLOAT16 || dtype == DataType::BFLOAT8_B || dtype == DataType::UINT32 || dtype == DataType::UINT16) &&
"Only BFLOAT16, BFLOAT8_B, UINT32, or UINT16 is supported on device!"
);
};
auto supported_layout = [&dtype, &layout]() {
switch (dtype) {
case DataType::UINT32:
break;
case DataType::UINT16:
case DataType::BFLOAT16:
break;
case DataType::BFLOAT8_B:
TT_ASSERT(layout == Layout::TILE && "Only TILE layout is supported for BFLOAT8_B dtype!");
break;
default:
TT_ASSERT(false && "Only BFLOAT16 or BFLOAT8_B is supported on device!");
TT_ASSERT(false && "Only BFLOAT16, BFLOAT8_B, UINT32, or UINT16 is supported on device!");
break;
}
};
Expand Down
4 changes: 2 additions & 2 deletions tt_eager/tensor/tensor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ constexpr inline std::vector<DataType> unpack_uint32_vec(std::vector<uint32_t> &
} else if constexpr (std::is_same_v<DataType, uint16_t>) {
std::vector<DataType> output;
for (auto index = 0; index < data_to_unpack.size(); index++) {
output.push_back(data_to_unpack[index] >> 16 & 0xFF);
output.push_back(data_to_unpack[index] & 0xFF);
output.push_back(data_to_unpack[index] >> 16);
output.push_back(data_to_unpack[index] & 0xFFFF);
}
return output;
} else if constexpr (std::is_same_v<DataType, bfloat16>) {
Expand Down

0 comments on commit 5dd537a

Please sign in to comment.