#4681: Add new group_attn_matmul (uplift + optimizations of attn_matmul)
- Same as attn_matmul, but supports kv_heads > 1 (see the PyTorch sketch below)
- kv_heads are mcast by 32 cores to all q_heads cores
- Supports interleaved and height-sharded (row- or col-major) layouts for any mix of in0, in1, and output
- The op is fully dynamic across input shapes, similar to eltwise_binary
- Add unit tests for group_attn_matmul and for program caching
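
For intuition, the grouped semantics match the golden reference used in the new unit tests: each KV head is shared by a group of q_heads // kv_heads query heads. A minimal host-side sketch in PyTorch (shapes follow the test inputs; this is only a reference model, not the device implementation):

    import torch

    def group_attn_matmul_golden(in0: torch.Tensor, in1: torch.Tensor) -> torch.Tensor:
        # in0: [q_len, q_heads, batch, K]; in1: [batch, kv_heads, K, seq_len]
        q_heads, kv_heads = in0.shape[1], in1.shape[1]
        # Broadcast each KV head to its group of query heads, then do a batched matmul
        in1 = torch.repeat_interleave(in1, q_heads // kv_heads, dim=1)
        return (in0.transpose(0, 2) @ in1).transpose(0, 2)

    # Example: output shape is [1, 16, 32, 128] for the first test configuration
    out = group_attn_matmul_golden(torch.randn(1, 16, 32, 64), torch.randn(32, 1, 64, 128))
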
TT-BrianLiu committed Jan 25, 2024
1 parent 16d4df0 commit 8a5cc63
Showing 9 changed files with 1,280 additions and 7 deletions.
186 changes: 184 additions & 2 deletions tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py
@@ -4,10 +4,10 @@

import pytest
import torch
-import pytest

import tt_lib as ttl
-from models.utility_functions import print_diff_argmax, comp_pcc
-from models.utility_functions import skip_for_wormhole_b0
+from models.utility_functions import comp_pcc, skip_for_wormhole_b0


def generate_input_shapes():
@@ -90,3 +90,185 @@ def test_attn_matmul_with_program_cache(in0_dtype, in1_dtype, out_dtype, device,

    allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
    assert allclose, f"FAILED: {output}"


@pytest.mark.parametrize(
    "shard_orientation",
    (ttl.tensor.ShardOrientation.ROW_MAJOR, ttl.tensor.ShardOrientation.COL_MAJOR),
)
@pytest.mark.parametrize(
    "output_sharded",
    (False, True),
)
@pytest.mark.parametrize(
    "in1_sharded",
    (False, True),
)
@pytest.mark.parametrize(
    "in0_sharded",
    (False, True),
)
@pytest.mark.parametrize(
    "batch, K, seq_len, q_heads, kv_heads",
    ((32, 64, 128, 16, 1), (32, 64, 128, 32, 2)),
)
def test_group_attn_matmul(
    batch, K, seq_len, q_heads, kv_heads, in0_sharded, in1_sharded, output_sharded, shard_orientation, device
):
    # NOTE: For interleaved kv_heads, batch 64, 96, etc... should be supported
    torch.manual_seed(0)

    compute_grid_size = device.compute_with_storage_grid_size()

    interleaved_mem_config = ttl.tensor.MemoryConfig(
        ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1
    )

    # NOTE: Mixed precision is supported as well
    in0_dtype = ttl.tensor.DataType.BFLOAT16
    in1_dtype = ttl.tensor.DataType.BFLOAT16
    output_dtype = ttl.tensor.DataType.BFLOAT16

    q_len = 1
    input_shape_a = [q_len, q_heads, batch, K]
    input_shape_b = [batch, kv_heads, K, seq_len]

    input_tensor_a = torch.randn(input_shape_a).bfloat16()
    input_tensor_b = torch.randn(input_shape_b).bfloat16()

    tt_input_tensor_a = (
        ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config)
    )
    tt_input_tensor_b = (
        ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config)
    )

    if in0_sharded:
        tt_input_tensor_a = ttl.tensor.interleaved_to_sharded(
            tt_input_tensor_a,
            compute_grid_size,
            [q_len * batch, K],
            ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
            shard_orientation,
        )

    if in1_sharded:
        tt_input_tensor_b = ttl.tensor.interleaved_to_sharded(
            tt_input_tensor_b,
            compute_grid_size,
            [kv_heads * K, seq_len],
            ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
            shard_orientation,
        )

    if output_sharded:
        output_mem_config = ttl.tensor.MemoryConfig(
            memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
            buffer_type=ttl.tensor.BufferType.L1,
        )
    else:
        output_mem_config = interleaved_mem_config

    tt_output_tensor_on_device = ttl.operations.primary.transformers.group_attn_matmul(
        tt_input_tensor_a,
        tt_input_tensor_b,
        compute_with_storage_grid_size=compute_grid_size,
        output_mem_config=output_mem_config,
        output_dtype=output_dtype,
    )
    if output_sharded:
        tt_output_tensor_on_device = ttl.tensor.sharded_to_interleaved(
            tt_output_tensor_on_device, interleaved_mem_config
        )

    tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()

    input_tensor_a = input_tensor_a.to(torch.float)
    input_tensor_b = torch.repeat_interleave(input_tensor_b.to(torch.float), q_heads // kv_heads, dim=1)
    golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2)

    allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
    assert allclose, f"FAILED: {output}"


@pytest.mark.parametrize("sharded", [False, True])
@pytest.mark.parametrize("output_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
def test_group_attn_matmul_with_program_cache(in0_dtype, in1_dtype, output_dtype, sharded, device, use_program_cache):
    torch.manual_seed(0)

    compute_grid_size = device.compute_with_storage_grid_size()

    interleaved_mem_config = ttl.tensor.MemoryConfig(
        ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1
    )

    shard_orientation = ttl.tensor.ShardOrientation.COL_MAJOR  # Only used if sharded

    q_len = 1
    batch = 32 if sharded else 64
    num_cache_entries = 0  # Only track cache entries of group_attn_matmul
    for K, seq_len, q_heads, kv_heads in ((96, 64, 10, 2), (64, 128, 50, 5)):
        input_shape_a = [q_len, q_heads, batch, K]
        input_shape_b = [batch, kv_heads, K, seq_len]

        input_tensor_a = torch.randn(input_shape_a).bfloat16()
        input_tensor_b = torch.randn(input_shape_b).bfloat16()

        tt_input_tensor_a = (
            ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config)
        )
        tt_input_tensor_b = (
            ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config)
        )

        if sharded:
            tt_input_tensor_a = ttl.tensor.interleaved_to_sharded(
                tt_input_tensor_a,
                compute_grid_size,
                [q_len * batch, K],
                ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
                shard_orientation,
            )

            tt_input_tensor_b = ttl.tensor.interleaved_to_sharded(
                tt_input_tensor_b,
                compute_grid_size,
                [kv_heads * K, seq_len],
                ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
                shard_orientation,
            )

            output_mem_config = ttl.tensor.MemoryConfig(
                memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
                buffer_type=ttl.tensor.BufferType.L1,
            )
        else:
            output_mem_config = interleaved_mem_config

        num_cache_entries_start = ttl.program_cache.num_entries()
        tt_output_tensor_on_device = ttl.operations.primary.transformers.group_attn_matmul(
            tt_input_tensor_a,
            tt_input_tensor_b,
            compute_with_storage_grid_size=compute_grid_size,
            output_mem_config=output_mem_config,
            output_dtype=output_dtype,
        )
        num_cache_entries += ttl.program_cache.num_entries() - num_cache_entries_start

        if sharded:
            tt_output_tensor_on_device = ttl.tensor.sharded_to_interleaved(
                tt_output_tensor_on_device, interleaved_mem_config
            )

        tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()

        input_tensor_a = input_tensor_a.to(torch.float)
        input_tensor_b = torch.repeat_interleave(input_tensor_b.to(torch.float), q_heads // kv_heads, dim=1)
        golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2)

        allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
        assert allclose, f"FAILED: {output}"

    assert num_cache_entries == 1
@@ -0,0 +1,46 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"

void kernel_main() {
    uint32_t has_work = get_arg_val<uint32_t>(0);
    if (has_work == 0) return;
    uint32_t dst_addr = get_arg_val<uint32_t>(1);
    uint32_t num_tiles = get_arg_val<uint32_t>(2);
    uint32_t start_id = get_arg_val<uint32_t>(3);

    constexpr uint32_t cb_id_out = get_compile_time_arg_val(0);
    constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;

    #ifdef OUT_SHARDED
    // Sharded output: no explicit writes; just wait until all output tiles have been produced
    cb_wait_front(cb_id_out, num_tiles);
    #else

    // single-tile ublocks
    constexpr uint32_t onetile = 1;
    const uint32_t tile_bytes = get_tile_size(cb_id_out);
    const DataFormat data_format = get_dataformat(cb_id_out);

    const InterleavedAddrGenFast<dst_is_dram> s = {
        .bank_base_address = dst_addr,
        .page_size = tile_bytes,
        .data_format = data_format
    };

    // Write output tiles to the interleaved buffer, iterating forwards or backwards from start_id
    #ifdef BACKWARDS
    uint32_t end_id = start_id - num_tiles;
    for (uint32_t i = start_id; i != end_id; --i) {
    #else
    uint32_t end_id = start_id + num_tiles;
    for (uint32_t i = start_id; i < end_id; ++i) {
    #endif
        cb_wait_front(cb_id_out, onetile);
        uint32_t l1_read_addr = get_read_ptr(cb_id_out);
        noc_async_write_tile(i, s, l1_read_addr);
        noc_async_write_barrier();
        cb_pop_front(cb_id_out, onetile);
    }
    #endif
}
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/module.mk
@@ -120,6 +120,7 @@ TT_DNN_SRCS = \
tt_eager/tt_dnn/op_library/transformer_tms/multi_core_split_query_key_value_and_split_heads/multi_core_split_query_key_value_and_split_heads.cpp \
tt_eager/tt_dnn/op_library/transformer_tms/multi_core_concatenate_heads/multi_core_concatenate_heads.cpp \
tt_eager/tt_dnn/op_library/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp \
tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp \
tt_eager/tt_dnn/op_library/run_operation.cpp \
tt_eager/tt_dnn/op_library/split/split_tiled.cpp \
tt_eager/tt_dnn/op_library/split/split_last_dim_two_chunks_tiled.cpp \
@@ -0,0 +1,101 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include "compute_kernel_api/tile_move_copy.h"
#include "compute_kernel_api/matmul.h"
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/untilize.h"

using std::uint32_t;

// matmul C=A*B using dims MK*KN = MN (row major order)
//
namespace NAMESPACE {
void MAIN {

constexpr uint32_t onetile = 1;

constexpr uint32_t transpose_hw = get_compile_time_arg_val(0);

uint32_t has_work = get_arg_val<uint32_t>(0);
if (has_work == 0) return;
uint32_t batch = get_arg_val<uint32_t>(1);
uint32_t Mt = get_arg_val<uint32_t>(2);
uint32_t Kt = get_arg_val<uint32_t>(3);
uint32_t Nt = get_arg_val<uint32_t>(4);

constexpr uint32_t cb_in0 = 0;
constexpr uint32_t cb_in1 = 1;
constexpr uint32_t cb_intermed0 = 24;
constexpr uint32_t cb_intermed1 = 25;
constexpr uint32_t cb_intermed2 = 26;
constexpr uint32_t out_cb_id = 16;

constexpr uint32_t num_rows_in_one_tile = 32;

mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw);

for (uint32_t nb = 0; nb < batch; ++nb) {
for (uint32_t mt_C = 0; mt_C < Mt; ++mt_C) { // output tile of C
cb_wait_front(cb_in0, Kt);
for (uint32_t nt_C = 0; nt_C < Nt; ++nt_C) { // output tile index of C
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; ++tile_row_id) {
tile_regs_acquire();
for (uint32_t kt = 0; kt < Kt; ++kt) {
cb_wait_front(cb_in1, onetile);

matmul_tiles(cb_in0, cb_in1, kt, 0, 0, transpose_hw);

cb_pop_front(cb_in1, onetile);
}
tile_regs_commit();

cb_reserve_back(cb_intermed0, onetile);
tile_regs_wait();
pack_tile(0, cb_intermed0);
tile_regs_release();
cb_push_back(cb_intermed0, onetile);

// untilize tile and write to CB::c_intermed1
unpack_reconfig_data_format_srca(cb_in1, cb_intermed0);
cb_wait_front(cb_intermed0, onetile);
untilize_init_short(cb_intermed0);
cb_reserve_back(cb_intermed1, 1);
untilize_block(cb_intermed0, 1, cb_intermed1);
cb_push_back(cb_intermed1, 1);

cb_pop_front(cb_intermed0, 1);
untilize_uninit(cb_intermed0);

unpack_reconfig_data_format_srca(cb_intermed0, cb_in1);
mm_init_short(transpose_hw);
}

// cb_intermed2 comes from reader; untilized row-major tile
unpack_reconfig_data_format_srca(cb_in1, cb_intermed2);
pack_reconfig_data_format(cb_intermed1, out_cb_id);
cb_wait_front(cb_intermed2, 1);
cb_reserve_back(out_cb_id, onetile);

// tilize CB::intermed2 and write to CB::c_out0
tilize_init_short(cb_intermed2, 1);
tilize_block(cb_intermed2, 1, out_cb_id);
cb_push_back(out_cb_id, 1);

cb_pop_front(cb_intermed2, 1);
tilize_uninit();

// Hangs when in0 is BFLOAT8_B if we don't force the reconfig
unpack_reconfig_data_format_srca(cb_in1);
pack_reconfig_data_format(out_cb_id, cb_intermed0);
mm_init_short(transpose_hw);
} // Nt

cb_pop_front(cb_in0, Kt);
} // Mt
} // batch

}
} // NAMESPACE
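
The num_rows_in_one_tile loop above mirrors the attn_matmul pattern: each of the 32 rows in an in0 tile corresponds to a different user in the batch, and (assuming the reader streams that user's in1 tiles on each iteration, as the cb_intermed2 comment suggests) only one row of each full-tile matmul result is kept and reassembled into the output tile via the untilize/tilize round-trip. A rough host-side sketch of one output-tile assembly, with hypothetical tile lists, purely for intuition:

    import torch

    TILE = 32  # tile height/width

    def assemble_output_tile(in0_tiles, in1_tiles_per_row):
        # in0_tiles: Kt tiles of shape [32, 32], shared by all 32 rows (users)
        # in1_tiles_per_row: for each of the 32 rows, a list of Kt tiles of shape [32, 32]
        out = torch.zeros(TILE, TILE)
        for row in range(TILE):  # num_rows_in_one_tile loop
            acc = torch.zeros(TILE, TILE)
            for kt, in0_tile in enumerate(in0_tiles):  # accumulate over Kt, like matmul_tiles
                acc = acc + in0_tile @ in1_tiles_per_row[row][kt]
            out[row] = acc[row]  # keep only this user's row of the full-tile result
        return out
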