Skip to content

Commit

Permalink
#3219: Added host functions which tilize and untilize bfloat16 vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
DrJessop committed Nov 28, 2023
1 parent f3dddca commit 8ac12b6
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include "tests/tt_metal/tt_metal/unit_tests/common/basic_fixture.hpp"
#include "tt_metal/common/tilize_untilize.hpp"

template <bool tilize_first, typename T>
void tilize_untilize_helper(uint max_num_batches, uint max_num_row_tiles, uint max_num_col_tiles, uint TILE_HEIGHT, uint TILE_WIDTH) {
for (uint i = 1; i <= max_num_batches; i++) {
for (uint nrows = TILE_HEIGHT; nrows <= max_num_row_tiles * TILE_HEIGHT; nrows += TILE_HEIGHT) {
for (uint ncols = TILE_WIDTH; ncols <= max_num_col_tiles * TILE_WIDTH; ncols += TILE_WIDTH) {
// Create bfloat16 arange
vector<T> data;
for (float datum = 0; datum < i * nrows * ncols; datum++) {
data.push_back(datum);
}

vector<T> target = data;
if constexpr (tilize_first) {
tilize(data, nrows, ncols);
ASSERT_FALSE(data == target);
untilize(data, nrows, ncols);
} else {
untilize(data, nrows, ncols);
ASSERT_FALSE(data == target);
tilize(data, nrows, ncols);
}
ASSERT_TRUE(data == target);
}
}
}
}

// The following run the tilize/untilize APIs and their inverses
TEST_F(BasicFixture, TestTilizeAndThenUntilizeBfloat16) {
uint max_num_batches = 8;
uint max_num_row_tiles = 8;
uint max_num_col_tiles = 8;
uint TILE_HEIGHT = 32;
uint TILE_WIDTH = 32;

tilize_untilize_helper<true, bfloat16>(max_num_batches, max_num_row_tiles, max_num_col_tiles, TILE_HEIGHT, TILE_WIDTH);
}

TEST_F(BasicFixture, TestTilizeThrowErrorForNonBfloat16DataType) {
vector<float> vec(1024, 0);
EXPECT_ANY_THROW(tilize(vec, 32, 32));
}

TEST_F(BasicFixture, TestTilizeThrowErrorForInvalidTileMandN) {
// m and n are not divisible by tile size
vector<bfloat16> vec(16, 0);
EXPECT_ANY_THROW(tilize(vec, 4, 4)); // m and n not divisible by 32
EXPECT_ANY_THROW(tilize(vec, 0, 4)); // Cannot have 0 shapes
EXPECT_ANY_THROW(tilize(vec, 4, 0));
EXPECT_ANY_THROW(tilize(vec, 0, 0));
}

TEST_F(BasicFixture, TestTilizeThrowErrorForInvalidVectorShape) {
vector<bfloat16> vec(16, 0); // Size not divisible by 1024
EXPECT_ANY_THROW(tilize(vec, 32, 32)); // m and n not divisible by 32
vec = {}; // Cannot have a zero vector either
EXPECT_ANY_THROW(tilize(vec, 32, 32)); // m and n not divisible by 32
}

TEST_F(BasicFixture, TestUntilizeThrowErrorForNonBfloat16DataType) {
vector<float> vec(1024, 0);
EXPECT_ANY_THROW(untilize(vec, 32, 32));
}

TEST_F(BasicFixture, TestUntilizeThrowErrorForInvalidTileMandN) {
// m and n are not divisible by tile side lengths
vector<bfloat16> vec(16, 0);
EXPECT_ANY_THROW(untilize(vec, 4, 4));
EXPECT_ANY_THROW(untilize(vec, 0, 4));
EXPECT_ANY_THROW(untilize(vec, 4, 0));
EXPECT_ANY_THROW(untilize(vec, 0, 0));
}

TEST_F(BasicFixture, TestUntilizeThrowErrorForInvalidVectorShape) {
vector<bfloat16> vec(16, 0); // Size not divisible by 1024
EXPECT_ANY_THROW(untilize(vec, 32, 32)); // m and n not divisible by 32
vec = {}; // Cannot have a zero vector either
EXPECT_ANY_THROW(untilize(vec, 32, 32)); // m and n not divisible by 32
}

TEST_F(BasicFixture, TestUntilizeAndThenTilizeBfloat16) {
uint max_num_batches = 8;
uint max_num_row_tiles = 8;
uint max_num_col_tiles = 8;
uint TILE_HEIGHT = 32;
uint TILE_WIDTH = 32;

tilize_untilize_helper<false, bfloat16>(max_num_batches, max_num_row_tiles, max_num_col_tiles, TILE_HEIGHT, TILE_WIDTH);
}
133 changes: 133 additions & 0 deletions tt_metal/common/tilize_untilize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <vector>

#include "bfloat16.hpp"

template <typename T>
void tilize(std::vector<T>& input, uint32_t m, uint32_t n) {
TT_ASSERT(input.size() > 0 and m > 0 and n > 0, "None of the input size, m, nor n can be 0");
TT_ASSERT((input.size() % (m * n)) == 0, "Input size must be divisible by m and n");

std::vector<T> tilized_input;
tilized_input.reserve(input.size());

uint32_t block_num_elements = m * n;
uint32_t num_blocks = input.size() / block_num_elements;

const auto write_face =
[](vector<T>& tilized_input, const vector<T>& input, uint32_t face_height, uint32_t face_width, uint32_t face_idx, uint32_t n)
-> void {
for (uint32_t i = 0; i < face_height; i++) {
for (uint32_t j = 0; j < face_width; j++) {
tilized_input.push_back(input[face_idx + j]);
}
face_idx += n;
}
};

if constexpr (std::is_same<T, bfloat16>()) {
uint32_t TILE_HEIGHT = 32;
uint32_t TILE_WIDTH = 32;
uint32_t FACE_HEIGHT = 16;
uint32_t FACE_WIDTH = 16;
uint32_t row_tiles = m / TILE_HEIGHT;
uint32_t col_tiles = n / TILE_WIDTH;
uint32_t row_of_tiles_num_elements = TILE_HEIGHT * n;
TT_ASSERT((m % TILE_HEIGHT == 0) and (n % TILE_WIDTH == 0), "m and n must be divisible by 32");
uint32_t block_start = 0;
for (size_t i = 0; i < num_blocks; i++) {
uint32_t tile_start = block_start;
for (uint32_t row_tile = 0; row_tile < row_tiles; row_tile++) {
uint32_t row_tile_start = tile_start;
for (uint32_t col_tile = 0; col_tile < col_tiles; col_tile++) {
uint32_t face0_id = row_tile_start;
uint32_t face1_id = face0_id + FACE_WIDTH;
uint32_t face2_id = face0_id + n * FACE_HEIGHT;
uint32_t face3_id = face2_id + FACE_WIDTH;

write_face(tilized_input, input, FACE_HEIGHT, FACE_WIDTH, face0_id, n);
write_face(tilized_input, input, FACE_HEIGHT, FACE_WIDTH, face1_id, n);
write_face(tilized_input, input, FACE_HEIGHT, FACE_WIDTH, face2_id, n);
write_face(tilized_input, input, FACE_HEIGHT, FACE_WIDTH, face3_id, n);
row_tile_start += TILE_WIDTH;
}
tile_start += row_of_tiles_num_elements;
}
block_start += block_num_elements;
}
} else {
TT_THROW("Invalid type passed into tilize");
}

input = std::move(tilized_input);
}

template <typename T>
void untilize(std::vector<T>& input, uint32_t m, uint32_t n) {
TT_ASSERT(input.size() > 0 and m > 0 and n > 0, "None of the input size, m, nor n can be 0");
TT_ASSERT((input.size() % (m * n)) == 0, "Input size must be divisible by m and n");

std::vector<T> untilized_input;
untilized_input.reserve(input.size());

uint32_t block_num_elements = m * n;
uint32_t num_blocks = input.size() / block_num_elements;

const auto untilize_row = [](vector<T>& untilized_input,
const vector<T>& input,
uint32_t face_height,
uint32_t face_width,
uint32_t tile_idx,
uint32_t TILE_WIDTH,
uint32_t n) -> void {
uint32_t face_num_elements = face_height * face_width;
uint32_t face_start = tile_idx;
for (uint32_t m = 0; m < 2; m++) {
for (uint32_t i = 0; i < face_height; i++) {
uint32_t row_start = face_start + i * face_width;
for (uint32_t j = 0; j < n / TILE_WIDTH; j++) { // Iterates over all the column tiles
// Grab 16 elements from tile j, face 0/2
for (uint32_t k = 0; k < face_width; k++) {
untilized_input.push_back(input[row_start + k]);
}

// Grab 16 elements from tile j, face 1/3
row_start += face_height * face_width;
for (uint32_t k = 0; k < face_width; k++) {
untilized_input.push_back(input[row_start + k]);
}
row_start += face_height * face_width * 3; // If on face 1, need to get to face 0 of next tile, and
// if on face 3, need to get to face 2 of next tile
}
}
face_start += face_height * face_width * 2; // Get to face 2 of current tile
}
};

if constexpr (std::is_same<T, bfloat16>()) {
uint32_t TILE_HEIGHT = 32;
uint32_t TILE_WIDTH = 32;
uint32_t FACE_HEIGHT = 16;
uint32_t FACE_WIDTH = 16;
uint32_t row_tiles = m / TILE_HEIGHT;
uint32_t col_tiles = n / TILE_WIDTH;
uint32_t row_of_tiles_num_elements = TILE_HEIGHT * n;
TT_ASSERT((m % TILE_HEIGHT == 0) and (n % TILE_WIDTH == 0), "m and n must be divisible by 32");
uint32_t block_start = 0;
for (size_t i = 0; i < num_blocks; i++) {
uint32_t row_tile_start = block_start;
for (uint32_t row_tile = 0; row_tile < row_tiles; row_tile++) {
untilize_row(untilized_input, input, FACE_HEIGHT, FACE_WIDTH, row_tile_start, TILE_WIDTH, n);
row_tile_start += row_of_tiles_num_elements;
}
block_start += block_num_elements;
}
} else {
TT_THROW("Invalid type passed into untilize");
}

input = std::move(untilized_input);
}

0 comments on commit 8ac12b6

Please sign in to comment.