#0: Enable mixed precision for attn matmul
tt-aho committed Jan 19, 2024
1 parent c3f800e commit 20cbefb
Showing 5 changed files with 60 additions and 51 deletions.
33 changes: 15 additions & 18 deletions tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py
@@ -2,6 +2,7 @@

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import tt_lib as ttl
@@ -27,20 +28,18 @@ def generate_input_shapes():
yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len]


@skip_for_wormhole_b0()
def test_attn_matmul(device):
@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("out_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
def test_attn_matmul(in0_dtype, in1_dtype, out_dtype, device):
torch.manual_seed(0)

for input_shape_a, input_shape_b in generate_input_shapes():
input_tensor_a = torch.randn(input_shape_a).bfloat16()
input_tensor_b = torch.randn(input_shape_b).bfloat16()

tt_input_tensor_a = (
ttl.tensor.Tensor(input_tensor_a, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
)
tt_input_tensor_b = (
ttl.tensor.Tensor(input_tensor_b, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
)
tt_input_tensor_a = ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device)
tt_input_tensor_b = ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device)

compute_grid_size = device.compute_with_storage_grid_size()

@@ -51,7 +50,7 @@ def test_attn_matmul(device):
output_mem_config=ttl.tensor.MemoryConfig(
ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1
),
output_dtype=ttl.tensor.DataType.BFLOAT16,
output_dtype=out_dtype,
)
tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()

@@ -61,20 +60,18 @@ def test_attn_matmul(device):
assert allclose, f"FAILED: {output}"


@skip_for_wormhole_b0()
def test_attn_matmul_with_program_cache(device, use_program_cache):
@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("out_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
def test_attn_matmul_with_program_cache(in0_dtype, in1_dtype, out_dtype, device, use_program_cache):
torch.manual_seed(0)

for input_shape_a, input_shape_b in generate_input_shapes():
input_tensor_a = torch.randn(input_shape_a).bfloat16()
input_tensor_b = torch.randn(input_shape_b).bfloat16()

tt_input_tensor_a = (
ttl.tensor.Tensor(input_tensor_a, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
)
tt_input_tensor_b = (
ttl.tensor.Tensor(input_tensor_b, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
)
tt_input_tensor_a = ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device)
tt_input_tensor_b = ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device)

compute_grid_size = device.compute_with_storage_grid_size()

@@ -85,7 +82,7 @@ def test_attn_matmul_with_program_cache(device, use_program_cache):
output_mem_config=ttl.tensor.MemoryConfig(
ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1
),
output_dtype=ttl.tensor.DataType.BFLOAT16,
output_dtype=out_dtype,
)
tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()

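The tests above now sweep in0, in1, and the output dtype over BFLOAT16 and BFLOAT8_B. Below is a minimal standalone sketch of what one mixed-precision combination looks like at the call site; the op entry point (ttl.operations.primary.transformers.attn_matmul), the keyword names, and the concrete shapes are assumptions here, since the full call is collapsed in this diff.

import torch
import tt_lib as ttl

def run_mixed_precision_attn_matmul(device):
    # Illustrative shapes following generate_input_shapes():
    # in0 is [q_len, q_heads, batch, K], in1 is [batch, kv_heads, K, seq_len]
    input_tensor_a = torch.randn([1, 71, 32, 64]).bfloat16()
    input_tensor_b = torch.randn([32, 1, 64, 128]).bfloat16()

    # Mixed precision: keep in0 in BFLOAT16, store in1 and the output as BFLOAT8_B
    tt_a = ttl.tensor.Tensor(input_tensor_a, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
    tt_b = ttl.tensor.Tensor(input_tensor_b, ttl.tensor.DataType.BFLOAT8_B).to(ttl.tensor.Layout.TILE).to(device)

    tt_out = ttl.operations.primary.transformers.attn_matmul(  # assumed entry point
        tt_a,
        tt_b,
        compute_with_storage_grid_size=device.compute_with_storage_grid_size(),
        output_mem_config=ttl.tensor.MemoryConfig(
            ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1
        ),
        output_dtype=ttl.tensor.DataType.BFLOAT8_B,
    )
    return tt_out.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()

Since BFLOAT8_B is an 8-bit block-float format, comparisons against a torch reference should use tolerances appropriate for the reduced precision rather than the BFLOAT16 thresholds.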
@@ -23,38 +23,43 @@ void MAIN {
uint32_t Kt = get_arg_val<uint32_t>(2);
uint32_t Nt = get_arg_val<uint32_t>(3);

constexpr uint32_t cb_in0 = 0;
constexpr uint32_t cb_in1 = 1;
constexpr uint32_t cb_intermed0 = 24;
constexpr uint32_t cb_intermed1 = 25;
constexpr uint32_t cb_intermed2 = 26;
constexpr uint32_t out_cb_id = 16;

constexpr uint32_t num_rows_in_one_tile = 32;

mm_init(tt::CB::c_in0, tt::CB::c_in1, out_cb_id, transpose_hw);
mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw);

for (uint32_t nb = 0; nb < batch; nb++)
for (uint32_t nb = 0; nb < batch; ++nb)
for (uint32_t mt_C = 0; mt_C < Mt; ++mt_C) // output tile of C
for (uint32_t nt_C = 0; nt_C < Nt; ++nt_C) // output tile index of C
{
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) {
acquire_dst(tt::DstMode::Half);
for (uint32_t kt = 0; kt < Kt; kt++) {
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; ++tile_row_id) {
tile_regs_acquire();
for (uint32_t kt = 0; kt < Kt; ++kt) {
if (tile_row_id == 0) {
cb_wait_front(tt::CB::c_in0, kt+1);
cb_wait_front(cb_in0, kt+1);
}
cb_wait_front(tt::CB::c_in1, onetile);
cb_wait_front(cb_in1, onetile);

matmul_tiles(tt::CB::c_in0, tt::CB::c_in1, kt, 0, 0, transpose_hw);
matmul_tiles(cb_in0, cb_in1, kt, 0, 0, transpose_hw);

cb_pop_front(tt::CB::c_in1, onetile);
cb_pop_front(cb_in1, onetile);
}
tile_regs_commit();

cb_reserve_back(cb_intermed0, onetile);
tile_regs_wait();
pack_tile(0, cb_intermed0);
release_dst(tt::DstMode::Half);
tile_regs_release();
cb_push_back(cb_intermed0, onetile);

// untilize tile and write to CB::c_intermed1
unpack_reconfig_data_format_srca(cb_in1, cb_intermed0);
cb_wait_front(cb_intermed0, onetile);
untilize_init_short(cb_intermed0);
cb_reserve_back(cb_intermed1, 1);
@@ -64,13 +69,16 @@
cb_pop_front(cb_intermed0, 1);
untilize_uninit(cb_intermed0);

unpack_reconfig_data_format_srca(cb_intermed0, cb_in1);
mm_init_short(transpose_hw);
}
cb_pop_front(tt::CB::c_in0, Kt);
cb_pop_front(cb_in0, Kt);

// cb_intermed2 comes from reader; untilized row-major tile
unpack_reconfig_data_format_srca(cb_in1, cb_intermed2);
pack_reconfig_data_format(cb_intermed1, out_cb_id);
cb_wait_front(cb_intermed2, 1);
cb_reserve_back(tt::CB::c_out0, onetile);
cb_reserve_back(out_cb_id, onetile);

// tilize CB::intermed2 and write to CB::c_out0
tilize_init_short(cb_intermed2, 1);
@@ -80,6 +88,9 @@
cb_pop_front(cb_intermed2, 1);
tilize_uninit();

// Hangs when in0 is BFLOAT8_B if we don't force the reconfig
unpack_reconfig_data_format_srca(cb_in1);
pack_reconfig_data_format(out_cb_id, cb_intermed0);
mm_init_short(transpose_hw);
}

@@ -17,17 +17,19 @@ namespace primary {
namespace transformers {


operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Tensor &b, Tensor& output, std::optional<const uint32_t> num_tokens, std::optional<const bool> transpose_hw, CoreCoord compute_with_storage_grid_size, DataType output_dtype) {
operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Tensor &b, Tensor& output, std::optional<const uint32_t> num_tokens, std::optional<const bool> transpose_hw, CoreCoord compute_with_storage_grid_size) {

tt_metal::Program program{};

const auto& ashape = a.shape(), bshape = b.shape();

tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype());
tt::DataFormat in1_data_format = tt_metal::datatype_to_dataformat_converter(b.dtype());
tt::DataFormat interm_data_format = DataFormat::Float16_b;
tt::DataFormat output_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype());
uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format);
uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
uint32_t interm_single_tile_size = tt_metal::detail::TileSize(interm_data_format);
uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_data_format);
MathFidelity math_fidelity = MathFidelity::LoFi;

@@ -74,34 +76,34 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te
uint32_t src1_addr = src1_buffer->address();
uint32_t dst_addr = dst_buffer->address();

uint32_t src0_cb_index = 0;
uint32_t src0_cb_index = CB::c_in0;
uint32_t cb0_num_input_tiles = Kt * 2;
tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(cb0_num_input_tiles * in0_single_tile_size, {{src0_cb_index, in0_data_format}})
.set_page_size(src0_cb_index, in0_single_tile_size);
auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config);

uint32_t src1_cb_index = 1;
uint32_t src1_cb_index = CB::c_in1;
uint32_t cb1_num_input_tiles = 2;
tt_metal::CircularBufferConfig cb_src1_config = tt_metal::CircularBufferConfig(cb1_num_input_tiles * in1_single_tile_size, {{src1_cb_index, output_data_format}})
tt_metal::CircularBufferConfig cb_src1_config = tt_metal::CircularBufferConfig(cb1_num_input_tiles * in1_single_tile_size, {{src1_cb_index, in1_data_format}})
.set_page_size(src1_cb_index, in1_single_tile_size);
auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config);

uint32_t cb_intermed0_index = 24;
tt_metal::CircularBufferConfig cb_interm0_config = tt_metal::CircularBufferConfig(1 * output_single_tile_size, {{cb_intermed0_index, output_data_format}})
.set_page_size(cb_intermed0_index, output_single_tile_size);
auto cb_interm0 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm0_config);
uint32_t cb_intermed0_index = CB::c_intermed0;
tt_metal::CircularBufferConfig cb_interm0_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed0_index, interm_data_format}})
.set_page_size(cb_intermed0_index, interm_single_tile_size);
auto cb_interm0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm0_config);

uint32_t cb_intermed1_index = 25;
tt_metal::CircularBufferConfig cb_interm1_config = tt_metal::CircularBufferConfig(1 * output_single_tile_size, {{cb_intermed1_index, output_data_format}})
.set_page_size(cb_intermed1_index, output_single_tile_size);
auto cb_interm1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm1_config);
uint32_t cb_intermed1_index = CB::c_intermed1;
tt_metal::CircularBufferConfig cb_interm1_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed1_index, interm_data_format}})
.set_page_size(cb_intermed1_index, interm_single_tile_size);
auto cb_interm1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm1_config);

uint32_t cb_intermed2_index = 26;
tt_metal::CircularBufferConfig cb_interm2_config = tt_metal::CircularBufferConfig(1 * output_single_tile_size, {{cb_intermed2_index, output_data_format}})
.set_page_size(cb_intermed2_index, output_single_tile_size);
auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm2_config);
uint32_t cb_intermed2_index = CB::c_intermed2;
tt_metal::CircularBufferConfig cb_interm2_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed2_index, interm_data_format}})
.set_page_size(cb_intermed2_index, interm_single_tile_size);
auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm2_config);

uint32_t output_cb_index = 16; // output operands start at index 16
uint32_t output_cb_index = CB::c_out0; // output operands start at index 16
uint32_t num_output_tiles = 2;
tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig(num_output_tiles * output_single_tile_size, {{output_cb_index, output_data_format}})
.set_page_size(output_cb_index, output_single_tile_size);
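The program factory above now derives a separate data format and tile size for each operand: in0, in1, and the output each use their own format, while the three intermediates used for the untilize/tilize round trip stay pinned at Float16_b, so every circular buffer is sized from the tile size of its own format rather than the output's. A rough sketch of why that matters for CB sizing follows; the byte counts are nominal 32x32-tile figures and are an assumption here, not values taken from this diff (the real values come from tt_metal::detail::TileSize(), which may add header or alignment bytes).

TILE_HW = 32 * 32  # datums per tile

def tile_bytes(data_format: str) -> int:
    # Nominal per-tile footprint under the stated assumptions.
    if data_format == "Float16_b":       # 2 bytes per datum
        return TILE_HW * 2               # 2048 B
    if data_format == "Bfp8_b":          # 1 mantissa byte per datum + shared exponents
        return TILE_HW + TILE_HW // 16   # 1024 B + 64 B of exponents = 1088 B
    raise ValueError(f"unknown format: {data_format}")

# With BFLOAT8_B inputs/outputs, the in/out CBs shrink to roughly half a BFLOAT16
# tile, while cb_intermed0/1/2 keep the Float16_b footprint:
print(tile_bytes("Float16_b"), tile_bytes("Bfp8_b"))  # 2048 1088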
@@ -148,7 +148,6 @@ void AttnMatmul::validate(const std::vector<Tensor>& input_tensors) const {
TT_FATAL((input_tensor_a.layout() == Layout::TILE && input_tensor_b.layout() == Layout::TILE), "Inputs to matmul must be tilized");

// TODO: Uplift to support BFLOAT8_B and mixed precision
TT_FATAL(input_tensor_a.dtype() == tt::tt_metal::DataType::BFLOAT16, "Unsupported data format");
TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE and input_tensor_b.storage_type() == StorageType::DEVICE, "Operands to matmul need to be on device!");
TT_FATAL(input_tensor_a.device() == input_tensor_b.device(), "Operands to matmul need to be on the same device!");
TT_FATAL(input_tensor_a.buffer() != nullptr and input_tensor_b.buffer() != nullptr, "Operands to matmul need to be allocated in buffers on device!");
@@ -197,7 +196,7 @@ std::vector<Shape> AttnMatmul::compute_output_shapes(const std::vector<Tensor>&

std::vector<Tensor> AttnMatmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.dtype(), Layout::TILE, this->output_mem_config);
return operation::generic_create_output_tensors(*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
}

operation::ProgramWithCallbacks AttnMatmul::create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const {
@@ -208,7 +207,7 @@ operation::ProgramWithCallbacks AttnMatmul::create_program(const std::vector<Ten
auto device_compute_with_storage_grid_size = input_tensor_a.device()->compute_with_storage_grid_size();
TT_ASSERT((this->compute_with_storage_grid_size.x <= device_compute_with_storage_grid_size.x && this->compute_with_storage_grid_size.y <= device_compute_with_storage_grid_size.y), "Unsupported grid shape");

return multi_core_attn_matmul(input_tensor_a, input_tensor_b, output_tensor, this->num_tokens, this->transpose_hw, this->compute_with_storage_grid_size, output_dtype);
return multi_core_attn_matmul(input_tensor_a, input_tensor_b, output_tensor, this->num_tokens, this->transpose_hw, this->compute_with_storage_grid_size);
}

tt::stl::reflection::Attributes AttnMatmul::attributes() const {
@@ -21,7 +21,7 @@ namespace transformers {
operation::ProgramWithCallbacks multi_core_split_query_key_value_and_split_heads(const Tensor &input_tensor, std::vector<Tensor> &output, CoreCoord compute_with_storage_grid_size);
operation::ProgramWithCallbacks multi_core_split_query_key_value_and_split_heads_sharded(const Tensor &input_tensor, std::vector<Tensor> &output, CoreCoord compute_with_storage_grid_size);
operation::ProgramWithCallbacks multi_core_concat_heads(const Tensor &input_tensor, Tensor &output_tensor, CoreCoord compute_with_storage_grid_size);
operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Tensor &output_tensor, std::optional<const uint32_t> num_tokens, std::optional<const bool> transpose_hw, CoreCoord compute_with_storage_grid_size, DataType output_dtype);
operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Tensor &output_tensor, std::optional<const uint32_t> num_tokens, std::optional<const bool> transpose_hw, CoreCoord compute_with_storage_grid_size);

struct SplitFusedQKVAndSplitHeads {
CoreCoord compute_with_storage_grid_size;
