-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#3219: Added host functions which tilize and untilize bfloat16 vectors
- Loading branch information
Showing
2 changed files
with
230 additions
and
0 deletions.
There are no files selected for viewing
97 changes: 97 additions & 0 deletions
97
tests/tt_metal/tt_metal/unit_tests/host_apis/test_tilize_untilize.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <gtest/gtest.h> | ||
#include "tests/tt_metal/tt_metal/unit_tests/common/basic_fixture.hpp" | ||
#include "tt_metal/common/tilize_untilize.hpp" | ||
|
||
template <bool tilize_first, typename T> | ||
void tilize_untilize_helper(uint max_num_batches, uint max_num_row_tiles, uint max_num_col_tiles, uint TILE_HEIGHT, uint TILE_WIDTH) { | ||
for (uint i = 1; i <= max_num_batches; i++) { | ||
for (uint nrows = TILE_HEIGHT; nrows <= max_num_row_tiles * TILE_HEIGHT; nrows += TILE_HEIGHT) { | ||
for (uint ncols = TILE_WIDTH; ncols <= max_num_col_tiles * TILE_WIDTH; ncols += TILE_WIDTH) { | ||
// Create bfloat16 arange | ||
vector<T> data; | ||
for (float datum = 0; datum < i * nrows * ncols; datum++) { | ||
data.push_back(datum); | ||
} | ||
|
||
vector<T> target = data; | ||
if constexpr (tilize_first) { | ||
tilize(data, nrows, ncols); | ||
ASSERT_FALSE(data == target); | ||
untilize(data, nrows, ncols); | ||
} else { | ||
untilize(data, nrows, ncols); | ||
ASSERT_FALSE(data == target); | ||
tilize(data, nrows, ncols); | ||
} | ||
ASSERT_TRUE(data == target); | ||
} | ||
} | ||
} | ||
} | ||
|
||
// The following run the tilize/untilize APIs and their inverses | ||
TEST_F(BasicFixture, TestTilizeAndThenUntilizeBfloat16) { | ||
uint max_num_batches = 8; | ||
uint max_num_row_tiles = 8; | ||
uint max_num_col_tiles = 8; | ||
uint TILE_HEIGHT = 32; | ||
uint TILE_WIDTH = 32; | ||
|
||
tilize_untilize_helper<true, bfloat16>(max_num_batches, max_num_row_tiles, max_num_col_tiles, TILE_HEIGHT, TILE_WIDTH); | ||
} | ||
|
||
TEST_F(BasicFixture, TestTilizeThrowErrorForNonBfloat16DataType) { | ||
vector<float> vec(1024, 0); | ||
EXPECT_ANY_THROW(tilize(vec, 32, 32)); | ||
} | ||
|
||
TEST_F(BasicFixture, TestTilizeThrowErrorForInvalidTileMandN) { | ||
// m and n are not divisible by tile size | ||
vector<bfloat16> vec(16, 0); | ||
EXPECT_ANY_THROW(tilize(vec, 4, 4)); // m and n not divisible by 32 | ||
EXPECT_ANY_THROW(tilize(vec, 0, 4)); // Cannot have 0 shapes | ||
EXPECT_ANY_THROW(tilize(vec, 4, 0)); | ||
EXPECT_ANY_THROW(tilize(vec, 0, 0)); | ||
} | ||
|
||
TEST_F(BasicFixture, TestTilizeThrowErrorForInvalidVectorShape) { | ||
vector<bfloat16> vec(16, 0); // Size not divisible by 1024 | ||
EXPECT_ANY_THROW(tilize(vec, 32, 32)); // m and n not divisible by 32 | ||
vec = {}; // Cannot have a zero vector either | ||
EXPECT_ANY_THROW(tilize(vec, 32, 32)); // m and n not divisible by 32 | ||
} | ||
|
||
TEST_F(BasicFixture, TestUntilizeThrowErrorForNonBfloat16DataType) { | ||
vector<float> vec(1024, 0); | ||
EXPECT_ANY_THROW(untilize(vec, 32, 32)); | ||
} | ||
|
||
TEST_F(BasicFixture, TestUntilizeThrowErrorForInvalidTileMandN) { | ||
// m and n are not divisible by tile side lengths | ||
vector<bfloat16> vec(16, 0); | ||
EXPECT_ANY_THROW(untilize(vec, 4, 4)); | ||
EXPECT_ANY_THROW(untilize(vec, 0, 4)); | ||
EXPECT_ANY_THROW(untilize(vec, 4, 0)); | ||
EXPECT_ANY_THROW(untilize(vec, 0, 0)); | ||
} | ||
|
||
TEST_F(BasicFixture, TestUntilizeThrowErrorForInvalidVectorShape) { | ||
vector<bfloat16> vec(16, 0); // Size not divisible by 1024 | ||
EXPECT_ANY_THROW(untilize(vec, 32, 32)); // m and n not divisible by 32 | ||
vec = {}; // Cannot have a zero vector either | ||
EXPECT_ANY_THROW(untilize(vec, 32, 32)); // m and n not divisible by 32 | ||
} | ||
|
||
TEST_F(BasicFixture, TestUntilizeAndThenTilizeBfloat16) { | ||
uint max_num_batches = 8; | ||
uint max_num_row_tiles = 8; | ||
uint max_num_col_tiles = 8; | ||
uint TILE_HEIGHT = 32; | ||
uint TILE_WIDTH = 32; | ||
|
||
tilize_untilize_helper<false, bfloat16>(max_num_batches, max_num_row_tiles, max_num_col_tiles, TILE_HEIGHT, TILE_WIDTH); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <vector> | ||
|
||
#include "bfloat16.hpp" | ||
|
||
template <typename T> | ||
void tilize(std::vector<T>& input, uint32_t m, uint32_t n) { | ||
TT_ASSERT(input.size() > 0 and m > 0 and n > 0, "None of the input size, m, nor n can be 0"); | ||
TT_ASSERT((input.size() % (m * n)) == 0, "Input size must be divisible by m and n"); | ||
|
||
std::vector<T> tilized_input; | ||
tilized_input.reserve(input.size()); | ||
|
||
uint32_t block_num_elements = m * n; | ||
uint32_t num_blocks = input.size() / block_num_elements; | ||
|
||
const auto write_face = | ||
[](vector<T>& tilized_input, const vector<T>& input, uint32_t face_height, uint32_t face_width, uint32_t face_idx, uint32_t n) | ||
-> void { | ||
for (uint32_t i = 0; i < face_height; i++) { | ||
for (uint32_t j = 0; j < face_width; j++) { | ||
tilized_input.push_back(input[face_idx + j]); | ||
} | ||
face_idx += n; | ||
} | ||
}; | ||
|
||
if constexpr (std::is_same<T, bfloat16>()) { | ||
uint32_t TILE_HEIGHT = 32; | ||
uint32_t TILE_WIDTH = 32; | ||
uint32_t FACE_HEIGHT = 16; | ||
uint32_t FACE_WIDTH = 16; | ||
uint32_t row_tiles = m / TILE_HEIGHT; | ||
uint32_t col_tiles = n / TILE_WIDTH; | ||
uint32_t row_of_tiles_num_elements = TILE_HEIGHT * n; | ||
TT_ASSERT((m % TILE_HEIGHT == 0) and (n % TILE_WIDTH == 0), "m and n must be divisible by 32"); | ||
uint32_t block_start = 0; | ||
for (size_t i = 0; i < num_blocks; i++) { | ||
uint32_t tile_start = block_start; | ||
for (uint32_t row_tile = 0; row_tile < row_tiles; row_tile++) { | ||
uint32_t row_tile_start = tile_start; | ||
for (uint32_t col_tile = 0; col_tile < col_tiles; col_tile++) { | ||
uint32_t face0_id = row_tile_start; | ||
uint32_t face1_id = face0_id + FACE_WIDTH; | ||
uint32_t face2_id = face0_id + n * FACE_HEIGHT; | ||
uint32_t face3_id = face2_id + FACE_WIDTH; | ||
|
||
write_face(tilized_input, input, FACE_HEIGHT, FACE_WIDTH, face0_id, n); | ||
write_face(tilized_input, input, FACE_HEIGHT, FACE_WIDTH, face1_id, n); | ||
write_face(tilized_input, input, FACE_HEIGHT, FACE_WIDTH, face2_id, n); | ||
write_face(tilized_input, input, FACE_HEIGHT, FACE_WIDTH, face3_id, n); | ||
row_tile_start += TILE_WIDTH; | ||
} | ||
tile_start += row_of_tiles_num_elements; | ||
} | ||
block_start += block_num_elements; | ||
} | ||
} else { | ||
TT_THROW("Invalid type passed into tilize"); | ||
} | ||
|
||
input = std::move(tilized_input); | ||
} | ||
|
||
template <typename T> | ||
void untilize(std::vector<T>& input, uint32_t m, uint32_t n) { | ||
TT_ASSERT(input.size() > 0 and m > 0 and n > 0, "None of the input size, m, nor n can be 0"); | ||
TT_ASSERT((input.size() % (m * n)) == 0, "Input size must be divisible by m and n"); | ||
|
||
std::vector<T> untilized_input; | ||
untilized_input.reserve(input.size()); | ||
|
||
uint32_t block_num_elements = m * n; | ||
uint32_t num_blocks = input.size() / block_num_elements; | ||
|
||
const auto untilize_row = [](vector<T>& untilized_input, | ||
const vector<T>& input, | ||
uint32_t face_height, | ||
uint32_t face_width, | ||
uint32_t tile_idx, | ||
uint32_t TILE_WIDTH, | ||
uint32_t n) -> void { | ||
uint32_t face_num_elements = face_height * face_width; | ||
uint32_t face_start = tile_idx; | ||
for (uint32_t m = 0; m < 2; m++) { | ||
for (uint32_t i = 0; i < face_height; i++) { | ||
uint32_t row_start = face_start + i * face_width; | ||
for (uint32_t j = 0; j < n / TILE_WIDTH; j++) { // Iterates over all the column tiles | ||
// Grab 16 elements from tile j, face 0/2 | ||
for (uint32_t k = 0; k < face_width; k++) { | ||
untilized_input.push_back(input[row_start + k]); | ||
} | ||
|
||
// Grab 16 elements from tile j, face 1/3 | ||
row_start += face_height * face_width; | ||
for (uint32_t k = 0; k < face_width; k++) { | ||
untilized_input.push_back(input[row_start + k]); | ||
} | ||
row_start += face_height * face_width * 3; // If on face 1, need to get to face 0 of next tile, and | ||
// if on face 3, need to get to face 2 of next tile | ||
} | ||
} | ||
face_start += face_height * face_width * 2; // Get to face 2 of current tile | ||
} | ||
}; | ||
|
||
if constexpr (std::is_same<T, bfloat16>()) { | ||
uint32_t TILE_HEIGHT = 32; | ||
uint32_t TILE_WIDTH = 32; | ||
uint32_t FACE_HEIGHT = 16; | ||
uint32_t FACE_WIDTH = 16; | ||
uint32_t row_tiles = m / TILE_HEIGHT; | ||
uint32_t col_tiles = n / TILE_WIDTH; | ||
uint32_t row_of_tiles_num_elements = TILE_HEIGHT * n; | ||
TT_ASSERT((m % TILE_HEIGHT == 0) and (n % TILE_WIDTH == 0), "m and n must be divisible by 32"); | ||
uint32_t block_start = 0; | ||
for (size_t i = 0; i < num_blocks; i++) { | ||
uint32_t row_tile_start = block_start; | ||
for (uint32_t row_tile = 0; row_tile < row_tiles; row_tile++) { | ||
untilize_row(untilized_input, input, FACE_HEIGHT, FACE_WIDTH, row_tile_start, TILE_WIDTH, n); | ||
row_tile_start += row_of_tiles_num_elements; | ||
} | ||
block_start += block_num_elements; | ||
} | ||
} else { | ||
TT_THROW("Invalid type passed into untilize"); | ||
} | ||
|
||
input = std::move(untilized_input); | ||
} |