Skip to content

Commit

Permalink
#0: scalar implementations for unpack bfp8/bfp4
Browse files Browse the repository at this point in the history
Goal is to build on non X86_64 platforms.
  • Loading branch information
joelsmithTT committed Nov 4, 2024
1 parent 7b6c0a6 commit 6b0bca3
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 1 deletion.
111 changes: 111 additions & 0 deletions tt_metal/common/bfloat4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
#include <iostream>
#include <random>
#include <vector>

#if defined(__x86_64__)
#include <immintrin.h>
#endif

#include "tt_metal/common/assert.hpp"
#include "tt_metal/common/tt_backend_api_types.hpp"
Expand All @@ -17,6 +20,38 @@
// TODO: empty struct to facilitate Tensor template logic. Reconsider how/why templating is supported in Tensor
struct bfloat4_b {};

// Scalar decode of one bfp4 datum into an IEEE-754 binary32 float.
// Slow; used for architectures that haven't had a vectorized implementation written yet.
//   bfp4[2:0] = mantissa (explicit leading bit, i.e. the datum encodes mantissa/4 * 2^(exp-127))
//   bfp4[3]   = sign
//   bfp4[7:4] = ignored
// `exp` is the shared exponent of the datum; when `is_exp_a` is true, 112 is added to it
// before it is placed in the float's exponent field.
// A zero mantissa decodes to (signed) zero regardless of `exp`.
// NOTE(review): if `exp` (after rebias) is smaller than the normalization shift, the
// subtraction wraps, exactly as in the original implementation - confirm callers never
// pass such exponents.
inline float convert_bfp4_to_float(uint8_t bfp4, uint32_t exp, bool is_exp_a) {
    static_assert(sizeof(float) == sizeof(uint32_t), "float is expected to be 32 bits wide");

    const uint32_t sign = (bfp4 >> 3) & 0x1;
    uint32_t mantissa = bfp4 & 0x7;

    if (mantissa == 0) {
        exp = 0;
    } else {
        // Normalize: shift left until the leading one sits in bit 2, counting the shifts.
        uint32_t shift_count = 0;
        while (mantissa < 0x4) {
            mantissa <<= 1;
            ++shift_count;
        }
        // Drop the leading one; in binary32 it is implicit.
        mantissa = (mantissa << 1) & 0x7;

        const uint32_t exp_rebias = is_exp_a ? 112u : 0u;
        exp = exp + exp_rebias - shift_count;
    }

    const uint32_t bits = (sign << 31) | (exp << 23) | (mantissa << 20);

    // Copy the bit pattern byte by byte. Dereferencing a reinterpret_cast<float*> to a
    // uint32_t object violates strict aliasing; access through unsigned char does not.
    float value;
    const unsigned char *src = reinterpret_cast<const unsigned char *>(&bits);
    unsigned char *dst = reinterpret_cast<unsigned char *>(&value);
    for (size_t byte = 0; byte < sizeof(value); ++byte) {
        dst[byte] = src[byte];
    }
    return value;
}

// Packs a vector of fp32 values into bfp4 tiles by forwarding to the generic
// block-float packer instantiated for the Bfp4_b data format.
inline std::vector<uint32_t> pack_fp32_vec_as_bfp4_tiles(const std::vector<float> &fp32_vec, bool row_major_input, bool is_exp_a) {
    constexpr auto bfp4_format = tt::DataFormat::Bfp4_b;
    return pack_fp32_vec_as_bfp_tiles<bfp4_format>(fp32_vec, row_major_input, is_exp_a);
}
Expand All @@ -27,6 +62,7 @@ constexpr int log2(int n) {
return log;
}

#if defined(__x86_64__)
inline std::vector<float> unpack_bfp4_tiles_into_float_vec(const std::vector<uint32_t> &bfp_tiles, bool row_major_output, bool is_exp_a) {
ZoneScoped;

Expand Down Expand Up @@ -159,6 +195,81 @@ inline std::vector<float> unpack_bfp4_tiles_into_float_vec(const std::vector<uin
}
return float_vec;
}
#else
// Scalar (non-x86 SIMD) implementation of bfp4 tile unpacking.
//
// Each tile occupies 16 + 128 dwords in `bfp_tiles`: the first 16 dwords hold the shared
// exponents (one byte per group of 16 datums), the remaining 128 dwords hold the data
// (eight 4-bit datums per dword). A tile is walked as 2x2 subtiles of 16x16 datums.
//
// When `row_major_output` is true each datum is written at its row-major position inside
// the 32-wide tile; otherwise datums are appended in the order they are stored.
// `is_exp_a` is forwarded to convert_bfp4_to_float.
inline std::vector<float> unpack_bfp4_tiles_into_float_vec(const std::vector<uint32_t> &bfp_tiles, bool row_major_output, bool is_exp_a) {
    ZoneScoped;

    constexpr int num_elements_in_dword = 8;
    constexpr int data_dwords_per_exp = 16 / num_elements_in_dword;
    constexpr int num_exps_in_dword = 4;
    constexpr int data_dwords_per_exp_dword_log2 = log2(data_dwords_per_exp * num_exps_in_dword);
    constexpr int data_dwords_per_exp_log2 = log2(data_dwords_per_exp);

    constexpr int subtiles_in_tile_row = 2;
    constexpr int subtiles_in_tile_col = 2;
    constexpr int subtile_rows = 16;
    constexpr int subtile_cols = 16;
    constexpr uint32_t num_float_in_tile = subtiles_in_tile_row * subtiles_in_tile_col * subtile_rows * subtile_cols;

    constexpr int num_exp_dwords_in_tile = 16;
    constexpr int num_data_dwords_in_tile = 128;
    constexpr int num_bfp_dwords_in_tile = num_data_dwords_in_tile + num_exp_dwords_in_tile;
    constexpr int num_dwords_per_row = subtile_cols / num_elements_in_dword;
    constexpr int elements_per_step = 2 * num_elements_in_dword;  // sixteen datums (two dwords) per inner iteration

    const uint32_t size_bytes = bfp_tiles.size() * 4;
    const uint32_t single_bfp_tile_size = tile_size(tt::DataFormat::Bfp4_b);
    TT_ASSERT(size_bytes % single_bfp_tile_size == 0);
    const uint32_t num_tiles = size_bytes / single_bfp_tile_size;

    uint32_t fp32_element_index = 0;  // running output position for the non-row-major layout

    std::vector<float> float_vec;
    float_vec.resize(num_tiles * num_float_in_tile);
    for (uint32_t tile_index = 0; tile_index < num_tiles; ++tile_index) {
        const uint32_t tile_dword_offset = num_bfp_dwords_in_tile * tile_index;
        for (int tr = 0; tr < subtiles_in_tile_row; ++tr) {
            for (int tc = 0; tc < subtiles_in_tile_col; ++tc) {
                for (int i = 0; i < subtile_rows; ++i) {
                    const int subtile_r = tr * subtile_rows + i;
                    for (int j = 0; j < subtile_cols; j += elements_per_step) {
                        const int subtile_c = tc * subtile_cols + j;
                        // Each uint32_t contains 8 BFP4 values. Divide data index by 8
                        const int data_index = tr * 64 + tc * 32 + i * num_dwords_per_row + j / num_elements_in_dword;
                        const uint32_t tile_and_data_index = data_index + tile_dword_offset;

                        // Extract the uint32_t value that stores the shared exponent for this set of data.
                        // Each 32 bit word is shared amongst 64 datums
                        const uint32_t exponent_index = (data_index >> data_dwords_per_exp_dword_log2) + tile_dword_offset;
                        const uint32_t exp_word = bfp_tiles.at(exponent_index);

                        // Extract the byte in which the shared exponent is stored. Each byte is shared amongst 16 datums.
                        const uint32_t sub_word_index = (tile_and_data_index >> data_dwords_per_exp_log2) & 0x3;
                        const uint32_t exp = get_byte(exp_word, sub_word_index);

                        uint32_t float_data_index;
                        if (row_major_output) {
                            float_data_index = subtile_c + (32 * subtile_r) + (tile_index * num_float_in_tile);
                        } else {
                            float_data_index = fp32_element_index;
                            fp32_element_index += elements_per_step;
                        }

                        // Sixteen bfp4 values are packed into two consecutive data dwords. Extract each
                        // nibble with shifts (least-significant nibble first) so the result does not
                        // depend on host byte order, and read the dwords through bounds-checked at().
                        const uint32_t first_data_dword = num_exp_dwords_in_tile + tile_and_data_index;
                        for (int word = 0; word < 2; ++word) {
                            const uint32_t packed = bfp_tiles.at(first_data_dword + word);
                            for (int nibble = 0; nibble < num_elements_in_dword; ++nibble) {
                                const uint8_t bfp4 = (packed >> (4 * nibble)) & 0xF;
                                float_vec.at(float_data_index + word * num_elements_in_dword + nibble) =
                                    convert_bfp4_to_float(bfp4, exp, is_exp_a);
                            }
                        }
                    }
                }
            }
        }
    }
    return float_vec;
}
#endif

inline std::vector<uint32_t> create_random_vector_of_bfp4(uint32_t num_bytes, bool is_exp_a, int rand_max_float, int seed, float offset = 0.0f) {
uint32_t single_bfp4_tile_size = tile_size(tt::DataFormat::Bfp4_b);
Expand Down
96 changes: 96 additions & 0 deletions tt_metal/common/bfloat8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
#include <iostream>
#include <random>
#include <vector>

#if defined(__x86_64__)
#include <immintrin.h>
#endif

#include "tt_metal/common/assert.hpp"
#include "tt_metal/common/tt_backend_api_types.hpp"
Expand All @@ -18,6 +21,35 @@
// TODO: empty struct to facilitate Tensor template logic. Reconsider how/why templating is supported in Tensor
struct bfloat8_b {};

// Scalar decode of one bfp8 datum into an IEEE-754 binary32 float.
// Slow; used for architectures that haven't had a vectorized implementation written yet.
//   bfp8[6:0] = mantissa (explicit leading bit, i.e. the datum encodes mantissa/64 * 2^(exp-127))
//   bfp8[7]   = sign
// `exp` is the shared exponent of the datum; when `is_exp_a` is true, 112 is added to it
// before it is placed in the float's exponent field.
// A zero mantissa decodes to (signed) zero regardless of `exp`.
// NOTE(review): if `exp` (after rebias) is smaller than the normalization shift, the
// subtraction wraps, exactly as in the original implementation - confirm callers never
// pass such exponents.
inline float convert_bfp8_to_float(uint8_t bfp8, uint32_t exp, bool is_exp_a) {
    static_assert(sizeof(float) == sizeof(uint32_t), "float is expected to be 32 bits wide");

    const uint32_t sign = (bfp8 >> 7) & 0x1;
    uint32_t mantissa = bfp8 & 0x7F;

    if (mantissa == 0) {
        exp = 0;
    } else {
        // Normalize: shift left until the leading one sits in bit 6, counting the shifts.
        uint32_t shift_count = 0;
        while (mantissa < 0x40) {
            mantissa <<= 1;
            ++shift_count;
        }
        // Drop the leading one; in binary32 it is implicit.
        mantissa = (mantissa << 1) & 0x7F;

        const uint32_t exp_rebias = is_exp_a ? 112u : 0u;
        exp = exp + exp_rebias - shift_count;
    }

    const uint32_t bits = (sign << 31) | (exp << 23) | (mantissa << 16);

    // Copy the bit pattern byte by byte. Dereferencing a reinterpret_cast<float*> to a
    // uint32_t object violates strict aliasing; access through unsigned char does not.
    float value;
    const unsigned char *src = reinterpret_cast<const unsigned char *>(&bits);
    unsigned char *dst = reinterpret_cast<unsigned char *>(&value);
    for (size_t byte = 0; byte < sizeof(value); ++byte) {
        dst[byte] = src[byte];
    }
    return value;
}

template <bool truncate_bfp_mantissa=false>
inline uint8_t convert_u32_to_bfp8(uint32_t input, uint32_t shared_exp, bool is_exp_a) {
//check for both +/- 0.0
Expand Down Expand Up @@ -170,6 +202,7 @@ inline std::vector<uint32_t> pack_fp32_vec_as_bfp8_tiles(const std::vector<float
return packed_result;
}

#if defined(__x86_64__)
inline std::vector<float> unpack_bfp8_tiles_into_float_vec(const std::vector<uint32_t> &bfp8_tiles, bool row_major_output, bool is_exp_a) {
ZoneScoped;

Expand Down Expand Up @@ -261,6 +294,69 @@ inline std::vector<float> unpack_bfp8_tiles_into_float_vec(const std::vector<uin
}
return float_vec;
}
#else
// Scalar (non-x86 SIMD) implementation of bfp8 tile unpacking.
//
// Each tile occupies 16 + 256 dwords in `bfp8_tiles`: the first 16 dwords hold the shared
// exponents (one byte per group of 16 datums), the remaining 256 dwords hold the data
// (four 8-bit datums per dword). A tile is walked as 2x2 subtiles of 16x16 datums.
//
// When `row_major_output` is true each datum is written at its row-major position inside
// the 32-wide tile; otherwise datums are appended in the order they are stored.
// `is_exp_a` is forwarded to convert_bfp8_to_float.
inline std::vector<float> unpack_bfp8_tiles_into_float_vec(const std::vector<uint32_t> &bfp8_tiles, bool row_major_output, bool is_exp_a) {
    ZoneScoped;

    constexpr int num_elements_in_dword = 4;
    constexpr int subtiles_in_tile_row = 2;
    constexpr int subtiles_in_tile_col = 2;
    constexpr int subtile_rows = 16;
    constexpr int subtile_cols = 16;
    constexpr int num_exp_dwords_in_tile = 16;
    constexpr int num_data_dwords_in_tile = 256;
    constexpr uint32_t num_dwords_in_tile = num_data_dwords_in_tile + num_exp_dwords_in_tile;
    constexpr uint32_t num_float_in_tile = subtiles_in_tile_row * subtiles_in_tile_col * subtile_rows * subtile_cols;
    constexpr int elements_per_step = 2 * num_elements_in_dword;  // eight datums (two dwords) per inner iteration

    const uint32_t size_bytes = bfp8_tiles.size() * num_elements_in_dword;  // each uint32_t contains 4 BFP8 values
    const uint32_t single_bfp8_tile_size = tile_size(tt::DataFormat::Bfp8_b);
    TT_ASSERT(size_bytes % single_bfp8_tile_size == 0);
    const uint32_t num_tiles = size_bytes / single_bfp8_tile_size;

    uint32_t fp32_element_index = 0;  // running output position for the non-row-major layout

    std::vector<float> float_vec;
    float_vec.resize(num_tiles * num_float_in_tile);
    for (uint32_t tile_index = 0; tile_index < num_tiles; ++tile_index) {
        const uint32_t tile_dword_offset = num_dwords_in_tile * tile_index;
        for (int tr = 0; tr < subtiles_in_tile_row; ++tr) {
            for (int tc = 0; tc < subtiles_in_tile_col; ++tc) {
                for (int i = 0; i < subtile_rows; ++i) {
                    const int subtile_r = tr * 16 + i;
                    for (int j = 0; j < subtile_cols; j += elements_per_step) {
                        const int subtile_c = tc * 16 + j;
                        // Each uint32_t contains 4 BFP8 values. Divide data index by 4.
                        const int data_index = tr * 128 + tc * 64 + i * 4 + j / 4;
                        const uint32_t tile_and_data_index = data_index + tile_dword_offset;

                        // Extract the uint32_t value that stores the shared exponent for this set of data.
                        // Each 32 bit word is shared amongst 64 datums
                        const uint32_t exponent_index = (data_index >> 4) + tile_dword_offset;
                        const uint32_t exp_word = bfp8_tiles.at(exponent_index);

                        // Extract the byte in which the shared exponent is stored. Each byte is shared amongst 16 datums.
                        const uint32_t sub_word_index = (tile_and_data_index >> 2) & 0x3;
                        const uint32_t exp = get_byte(exp_word, sub_word_index);

                        uint32_t float_data_index;
                        if (row_major_output) {
                            float_data_index = subtile_c + (32 * subtile_r) + (tile_index * num_float_in_tile);
                        } else {
                            float_data_index = fp32_element_index;
                            fp32_element_index += elements_per_step;
                        }

                        // Eight bfp8 values are packed into two consecutive data dwords. Extract each byte
                        // with shifts (least-significant byte first) so the result does not depend on
                        // host byte order, and read the dwords through bounds-checked at().
                        const uint32_t first_data_dword = num_exp_dwords_in_tile + tile_and_data_index;
                        for (int word = 0; word < 2; ++word) {
                            const uint32_t packed = bfp8_tiles.at(first_data_dword + word);
                            for (int byte = 0; byte < num_elements_in_dword; ++byte) {
                                const uint8_t bfp8 = (packed >> (8 * byte)) & 0xFF;
                                float_vec.at(float_data_index + word * num_elements_in_dword + byte) =
                                    convert_bfp8_to_float(bfp8, exp, is_exp_a);
                            }
                        }
                    }
                }
            }
        }
    }
    return float_vec;
}
#endif

inline std::vector<uint32_t> create_random_vector_of_bfp8(uint32_t num_bytes, bool is_exp_a, int rand_max_float, int seed, float offset = 0.0f) {
uint32_t single_bfp8_tile_size = tile_size(tt::DataFormat::Bfp8_b);
Expand Down
1 change: 0 additions & 1 deletion tt_metal/common/blockfloat_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <iostream>
#include <random>
#include <vector>
#include <immintrin.h>

#include "tt_metal/common/assert.hpp"
#include "tt_metal/common/tt_backend_api_types.hpp"
Expand Down

0 comments on commit 6b0bca3

Please sign in to comment.