#7504: add native transpose functionality into create_qkv_heads
  - Also partially addresses #7519
  - TODO: add batch > cores support
  - TODO: add a separate q and kv tensor api (next step)
  - TODO: add cross attention support (sequence length of Q greater than KV)
sjameelTT committed Apr 17, 2024
1 parent 3e77384 commit d755048
Showing 5 changed files with 86 additions and 29 deletions.
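At the reference level, the new option changes only the K output layout: the updated test expects K as [batch, num_kv_heads, head_dim, seq_len] when transposition is on, and [batch, num_kv_heads, seq_len, head_dim] otherwise. A minimal PyTorch sketch of that shape behaviour (sizes and variable names are illustrative, and the split below is simplified — the real fused QKV layout groups heads per KV group):

```python
import torch

batch, seq_len, num_q_heads, num_kv_heads, head_dim = 1, 128, 16, 4, 64
transpose_k = True

# Fused QKV input: [batch, 1, seq_len, (num_q_heads + 2 * num_kv_heads) * head_dim]
qkv = torch.randn(batch, 1, seq_len, (num_q_heads + 2 * num_kv_heads) * head_dim)

# Simplified split: the real fused layout interleaves Q/K/V per KV group,
# so this only demonstrates the resulting output shapes.
q, k, v = torch.split(
    qkv.squeeze(1),
    [num_q_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1,
)
q = q.reshape(batch, seq_len, num_q_heads, head_dim).transpose(-3, -2)   # [b, q_heads, seq, head_dim]
k = k.reshape(batch, seq_len, num_kv_heads, head_dim).transpose(-3, -2)  # [b, kv_heads, seq, head_dim]
v = v.reshape(batch, seq_len, num_kv_heads, head_dim).transpose(-3, -2)  # [b, kv_heads, seq, head_dim]

if transpose_k:
    k = k.transpose(-2, -1)  # [b, kv_heads, head_dim, seq], ready for the Q @ K^T matmul

print(q.shape, k.shape, v.shape)
```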
@@ -21,7 +21,7 @@ def run_create_qkv_heads_test(
cores_h,
cores_w,
device,
transpose_k=False,
transpose_k,
in_mem_config=None,
out_mem_config=None,
):
@@ -73,12 +73,16 @@ def run_create_qkv_heads_test(
in0_t,
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
transpose_k_heads=False,
transpose_k_heads=transpose_k,
output_mem_config=out_mem_config,
)

assert list(q.get_legacy_shape()) == [batch, num_q_heads, seq_len, head_dim]
assert list(k.get_legacy_shape()) == [batch, num_kv_heads, seq_len, head_dim]
if transpose_k:
assert list(k.get_legacy_shape()) == [batch, num_kv_heads, head_dim, seq_len]
else:
assert list(k.get_legacy_shape()) == [batch, num_kv_heads, seq_len, head_dim]

assert list(v.get_legacy_shape()) == [batch, num_kv_heads, seq_len, head_dim]

pyt_got_back_rm_q = tt2torch_tensor(q)
@@ -93,8 +97,15 @@ def run_create_qkv_heads_test(
ref_k = torch.reshape(ref_k, [batch, seq_len, num_kv_heads, head_dim]).transpose(-3, -2)
ref_v = torch.reshape(ref_v, [batch, seq_len, num_kv_heads, head_dim]).transpose(-3, -2)

if transpose_k:
ref_k = torch.transpose(ref_k, -2, -1)

if dtype == ttl.tensor.DataType.BFLOAT8_B:
pcc = 0.99
elif (
dtype == ttl.tensor.DataType.FLOAT32 and transpose_k
): # conversion from fp32 to tf32 when unpack writes to register for compute will decrease pcc in the transpose case
pcc = 0.9999999
else:
pcc = 1.0

@@ -120,6 +131,10 @@ def run_create_qkv_heads_test(
(ttl.tensor.DataType.BFLOAT8_B, ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.FLOAT32),
ids=["BFLOAT8_B", "BFLOAT16", "FLOAT32"],
)
@pytest.mark.parametrize(
"transpose_k",
(True, False),
)
@pytest.mark.parametrize(
"batch, seq_len, num_q_heads, num_kv_heads, head_dim, cores_h, cores_w",
(
@@ -138,6 +153,7 @@ def test_nlp_create_qkv_heads_test(
num_q_heads,
num_kv_heads,
head_dim,
transpose_k,
cores_h,
cores_w,
dtype,
@@ -146,4 +162,6 @@ def test_nlp_create_qkv_heads_test(
if is_grayskull() and dtype == ttl.tensor.DataType.FLOAT32:
pytest.skip("Skipping float32 tests on Grayskull")

run_create_qkv_heads_test(batch, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, cores_h, cores_w, device)
run_create_qkv_heads_test(
batch, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, cores_h, cores_w, device, transpose_k
)
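The relaxed PCC threshold for FLOAT32 with transpose_k reflects that the transpose now runs through the compute engine, where fp32 inputs are unpacked to tf32 and lose mantissa precision, so bit-exact results are no longer expected. As a rough illustration of what a PCC check of this kind measures, here is a plain Pearson-correlation sketch with a crude tf32 emulation; it is an assumption-laden stand-in, not the comparison helper the test actually imports:

```python
import torch

def pearson_pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    """Pearson correlation coefficient between two flattened tensors."""
    e = expected.flatten().double()
    a = actual.flatten().double()
    e = e - e.mean()
    a = a - a.mean()
    return float((e * a).sum() / (e.norm() * a.norm()))

def to_tf32(x: torch.Tensor) -> torch.Tensor:
    # Crude tf32 emulation (assumption for illustration): keep only the top 10
    # of fp32's 23 mantissa bits by masking off the low 13 bits (truncation, no rounding).
    bits = x.view(torch.int32)
    return (bits & ~((1 << 13) - 1)).view(torch.float32)

ref = torch.randn(1, 4, 64, 128)
got = to_tf32(ref)
print(pearson_pcc(ref, got))  # very close to, but not exactly, 1.0
```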
@@ -28,9 +28,14 @@ void kernel_main() {
constexpr uint32_t v_size_per_group_t_bytes = get_compile_time_arg_val(14); // total size of all V heads (expecting 1) in a group

constexpr uint32_t cb_in0 = tt::CB::c_in0;
constexpr uint32_t cb_out0 = tt::CB::c_out0;
constexpr uint32_t cb_out1 = tt::CB::c_out1;
constexpr uint32_t cb_out2 = tt::CB::c_out2;

constexpr uint32_t cb_outq = tt::CB::c_out0;
#ifdef TRANSPOSE_K_HEADS
constexpr uint32_t cb_outk = tt::CB::c_intermed0;
#else
constexpr uint32_t cb_outk = tt::CB::c_out1;
#endif
constexpr uint32_t cb_outv = tt::CB::c_out2;


// copy one entire head_dim tile, then go to next sequence tile and do another head_dim.
Expand All @@ -49,9 +54,9 @@ void kernel_main() {

uint64_t src_noc_addr = get_noc_addr(get_read_ptr(cb_in0));
// re-order q
cb_reserve_back(cb_out0, q_num_tiles);
cb_reserve_back(cb_outq, q_num_tiles);

uint32_t q_write_addr = get_write_ptr(cb_out0);
uint32_t q_write_addr = get_write_ptr(cb_outq);
uint32_t src_noc_addr_offset_outer = 0;

uint32_t group_addr_offset = 0;
@@ -70,33 +75,45 @@ void kernel_main() {
group_addr_offset += group_t_size_bytes;
}
noc_async_read_barrier();
cb_push_back(cb_out0, q_num_tiles);
cb_push_back(cb_outq, q_num_tiles);

// re-order k
cb_reserve_back(cb_out1, k_num_tiles);
uint32_t k_write_addr = get_write_ptr(cb_out1);

cb_reserve_back(cb_outk, k_num_tiles);
uint32_t k_write_addr = get_write_ptr(cb_outk);
group_addr_offset = q_size_per_group_t_bytes;
for (uint32_t k = 0; k < groups_per_block; k++) { // number of kv heads inside the shard
uint32_t head_in_group_offset = 0;
for (uint32_t j = 0; j < k_heads_per_group; j++) { // go to next K heads in the group (expecting only 1 for K)
#ifdef TRANSPOSE_K_HEADS
for (uint32_t k_head_tile_offset = 0; k_head_tile_offset < k_head_size_bytes; k_head_tile_offset += single_tile_size_bytes) { // finish head after sequence length when transposing K
uint32_t seq_tile_offset = 0;
for (uint32_t i = 0; i < block_ht; i++) { // iterate across seq_len dimension tiles
uint64_t k_src_noc_addr = src_noc_addr + seq_tile_offset + head_in_group_offset + group_addr_offset + k_head_tile_offset;
noc_async_read(k_src_noc_addr, k_write_addr, single_tile_size_bytes); // read only one tile since we're transposing
k_write_addr += single_tile_size_bytes; // output address of next K head
seq_tile_offset += block_wt_size_bytes; // go to next tile in seq_len
}
}
#else
uint32_t seq_tile_offset = 0;
for (uint32_t i = 0; i < block_ht; i++) { // iterate across seq_len dimension tiles
uint64_t k_src_noc_addr = src_noc_addr + seq_tile_offset + head_in_group_offset + group_addr_offset;
noc_async_read(k_src_noc_addr, k_write_addr, k_head_size_bytes); // read one head worth of tiles
k_write_addr += k_head_size_bytes; // output address of next K head
seq_tile_offset += block_wt_size_bytes; // go to next tile in seq_len
}
#endif
head_in_group_offset += k_head_size_bytes;
}
group_addr_offset += group_t_size_bytes;
}
noc_async_read_barrier();
cb_push_back(cb_out1, k_num_tiles);

cb_push_back(cb_outk, k_num_tiles);

// re-order v
cb_reserve_back(cb_out2, v_num_tiles);
uint32_t v_write_addr = get_write_ptr(cb_out2);
cb_reserve_back(cb_outv, v_num_tiles);
uint32_t v_write_addr = get_write_ptr(cb_outv);
group_addr_offset = q_size_per_group_t_bytes + k_size_per_group_t_bytes;
for (uint32_t k = 0; k < groups_per_block; k++) { // number of kv heads inside the shard
uint32_t head_in_group_offset = 0;
@@ -113,5 +130,5 @@ void kernel_main() {
group_addr_offset += group_t_size_bytes;
}
noc_async_read_barrier();
cb_push_back(cb_out2, v_num_tiles);
cb_push_back(cb_outv, v_num_tiles);
}
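The TRANSPOSE_K_HEADS path changes only the tile read order for K: rather than copying one full head-width row per sequence tile, the reader walks a single head_dim tile column across all sequence tiles, so tiles land in the destination buffer in [head_dim, seq_len] tile order (each individual tile is transposed afterwards by the compute kernel). A Python sketch of the two orderings under simplified assumptions (one KV group, one K head, tile indices only):

```python
# Tile grid of one K head: block_ht sequence tiles x head_dim_t head-dimension tiles.
block_ht = 4      # seq_len / TILE_HEIGHT
head_dim_t = 2    # head_dim / TILE_WIDTH

def k_read_order(transpose_k_heads: bool):
    order = []
    if transpose_k_heads:
        # Finish one head_dim tile column across the whole sequence before moving on.
        for d in range(head_dim_t):
            for s in range(block_ht):
                order.append((s, d))  # one single-tile read per step
    else:
        # Copy a whole head row (head_dim_t tiles) per sequence tile.
        for s in range(block_ht):
            order.append([(s, d) for d in range(head_dim_t)])  # one head-width read
    return order

print(k_read_order(False))
print(k_read_order(True))
```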
@@ -11,7 +11,7 @@ using namespace tt::constants;
using namespace tt;


static inline operation::ProgramWithCallbacks create_heads_combined_qkv_sharded(const Tensor &input_tensor, const vector<uint32_t> &&heads_per_group, const uint32_t head_dim, const uint32_t groups, std::vector<Tensor> &output) {
static inline operation::ProgramWithCallbacks create_heads_combined_qkv_sharded(const Tensor &input_tensor, const vector<uint32_t> &&heads_per_group, const uint32_t head_dim, const uint32_t groups, std::vector<Tensor> &output, bool transpose_k) {
// groups = kv_heads usually
// heads_per_group = [x 1 1] if qkv since q_heads >= kv_heads and k=v heads but this should be generic
TT_FATAL(head_dim % TILE_WIDTH == 0, fmt::format("head dim {} needs to be a multiple of tile width {}", head_dim, TILE_WIDTH));
@@ -84,11 +84,26 @@ static inline operation::ProgramWithCallbacks create_heads_combined_qkv_sharded(
(std::uint32_t) num_tiles_per_group[2] * single_tile_size, // size of V tiles in each group, in bytes
};

std::map<string, string> reader_defines;
if (transpose_k) {
reader_defines["TRANSPOSE_K_HEADS"] = "1";
}
auto reader_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/nlp_tms/kernels/dataflow/reader_create_qkv_heads_sharded.cpp",
all_cores,
tt_metal::ReaderDataMovementConfig(reader_compile_time_args));
tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines));

if (transpose_k) {
std::vector<uint32_t> compute_args = {
(std::uint32_t) block_ht*num_tiles_per_group[1]*groups_per_block, // number of K tiles
};
auto compute_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transpose_wh_sharded.cpp",
all_cores,
tt_metal::ComputeConfig{.compile_args = compute_args});
}

uint32_t input_size = block_ht*block_wt*single_tile_size;
uint32_t q_size = block_ht*num_tiles_per_group[0]*single_tile_size*groups_per_block;
@@ -112,6 +127,13 @@ static inline operation::ProgramWithCallbacks create_heads_combined_qkv_sharded(
.set_page_size(CB::c_out2, single_tile_size).set_globally_allocated_address(*output[2].buffer());
auto cb_out2_id = CreateCircularBuffer( program, all_cores, c_out2_config );

if (transpose_k) {
auto c_im0_config = CircularBufferConfig(k_size, {{CB::c_intermed0, data_format}})
.set_page_size(CB::c_intermed0, single_tile_size);
auto cb_im0_id = CreateCircularBuffer(program, all_cores, c_im0_config);
}


auto override_runtime_args_callback = [
cb_in0_id,
cb_out0_id,
@@ -172,7 +194,7 @@ namespace tt_metal {
*/
operation::ProgramWithCallbacks multi_core_create_qkv_heads_sharded(const Tensor &input_tensor_qkv, const uint32_t num_q_heads, const uint32_t num_kv_heads, const uint32_t head_dim, const bool transpose_k_heads, std::vector<Tensor>& output, CoreCoord compute_with_storage_grid_size) {
TT_FATAL(num_q_heads % num_kv_heads == 0, fmt::format("num q heads {} / num kv heads {} needs to be a whole number", num_q_heads, num_kv_heads));
return create_heads_combined_qkv_sharded(input_tensor_qkv, {num_q_heads/num_kv_heads, 1, 1}, head_dim, num_kv_heads, output);
return create_heads_combined_qkv_sharded(input_tensor_qkv, {num_q_heads/num_kv_heads, 1, 1}, head_dim, num_kv_heads, output, transpose_k_heads);
}

} // namespace tt_metal
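On the host side, the transpose path adds a TRANSPOSE_K_HEADS define for the reader, a transpose_wh compute kernel over the K tiles, and an intermediate circular buffer (c_intermed0) sized like the K output for the reader to write into instead of c_out1. The K tile count handed to the compute kernel is the per-core product below; the concrete values are assumed, only the formula comes from the diff:

```python
# Assumed per-core values, using the names from the program factory.
block_ht = 4                      # sequence tiles per core
num_tiles_per_group = [8, 2, 2]   # [Q, K, V] tiles per KV group along head_dim
groups_per_block = 1              # KV groups per shard

# Compile-time arg passed to the transpose_wh compute kernel when transpose_k is set.
num_k_tiles = block_ht * num_tiles_per_group[1] * groups_per_block
print(num_k_tiles)  # 8
```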
14 changes: 8 additions & 6 deletions tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.cpp
@@ -279,7 +279,7 @@ void CreateQKVHeads::validate(const std::vector<Tensor> &input_tensors) const {
uint32_t num_h_cores = rm ? bbox.end.y + 1 : bbox.end.x + 1;
uint32_t num_w_cores = rm ? bbox.end.x + 1 : bbox.end.y + 1;

TT_FATAL(this->num_q_heads % this->num_kv_heads == 0, "Number of q heads {} must fit evenly into number of kv heads {}", this->num_q_heads, this->num_kv_heads);
TT_FATAL(this->num_q_heads % this->num_kv_heads == 0, fmt::format("Number of q heads {} must fit evenly into number of kv heads {}", this->num_q_heads, this->num_kv_heads));
TT_FATAL(input_shape[3] % (num_w_cores * TILE_WIDTH) == 0, fmt::format("Flattened hidden dimension {} must be a multiple of width cores {} * tile width {} to ensure that each core gets an even amount of tiles", input_shape[3], num_w_cores, TILE_WIDTH));

TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED);
@@ -332,19 +332,21 @@ std::vector<Tensor> CreateQKVHeads::create_output_tensors(const std::vector<Tens
//uint32_t num_q_heads_per_shard = q_shape[1] / num_w_cores;

uint32_t q_shard_h = q_shape[0] * q_shape[1] * q_shape[2] / num_cores; // want the API to work for different sequence lengths
uint32_t kv_shard_h = k_shape[0] * k_shape[1] * k_shape[2] / num_cores; // want the API to work for different sequence lengths
auto q_spec = ShardSpec(all_cores, {q_shard_h, head_dim}, shard_orientation);
auto kv_spec = ShardSpec(all_cores, {kv_shard_h, head_dim}, shard_orientation);
uint32_t k_shard_h = k_shape[0] * k_shape[1] * k_shape[2] / num_cores; // want the API to work for different sequence lengths
uint32_t v_shard_h = v_shape[0] * v_shape[1] * v_shape[2] / num_cores; // want the API to work for different sequence lengths

auto q_spec = ShardSpec(all_cores, {q_shard_h, q_shape[-1]}, shard_orientation);
auto k_spec = ShardSpec(all_cores, {k_shard_h, k_shape[-1]}, shard_orientation);
auto v_spec = ShardSpec(all_cores, {v_shard_h, v_shape[-1]}, shard_orientation);
// create sharded tensors
auto mem_config_q = this->output_mem_config;
mem_config_q.shard_spec = q_spec;

auto mem_config_k = this->output_mem_config;
mem_config_k.shard_spec = kv_spec;
mem_config_k.shard_spec = k_spec;

auto mem_config_v = this->output_mem_config;
mem_config_v.shard_spec = kv_spec;
mem_config_v.shard_spec = v_spec;

auto out_tensor_q = create_sharded_device_tensor(q_shape, input_tensor.get_dtype(), Layout::TILE, input_tensor.device(), mem_config_q);
auto out_tensor_k = create_sharded_device_tensor(k_shape, input_tensor.get_dtype(), Layout::TILE, input_tensor.device(), mem_config_k);
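The create_output_tensors change is what lets the transposed K shard correctly: each of Q, K and V now derives its shard shape from its own output shape rather than sharing one kv spec with a head_dim-wide shard, so a transposed K (last dimension seq_len instead of head_dim) gets a matching shard width. A sketch of the shard-shape arithmetic with assumed sizes:

```python
# Illustrative shapes; num_cores is the height-sharded core count.
num_cores = 8
batch, seq_len, head_dim = 1, 1024, 64
num_q_heads, num_kv_heads = 16, 4
transpose_k = True

q_shape = [batch, num_q_heads, seq_len, head_dim]
k_shape = (
    [batch, num_kv_heads, head_dim, seq_len] if transpose_k
    else [batch, num_kv_heads, seq_len, head_dim]
)
v_shape = [batch, num_kv_heads, seq_len, head_dim]

def shard_shape(shape):
    # Shard height collapses batch * heads * rows across the cores;
    # shard width is the tensor's own last dimension.
    return [shape[0] * shape[1] * shape[2] // num_cores, shape[-1]]

print(shard_shape(q_shape))  # [2048, 64]
print(shard_shape(k_shape))  # [32, 1024] when transpose_k
print(shard_shape(v_shape))  # [512, 64]
```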
4 changes: 1 addition & 3 deletions ttnn/ttnn/operations/transformer.py
@@ -232,11 +232,9 @@ def split_query_key_value_and_split_heads(
input_tensor,
num_q_heads=num_heads,
num_kv_heads=num_heads,
transpose_k_heads=False,
transpose_k_heads=transpose_key,
output_mem_config=memory_config,
)
if transpose_key:
key = ttnn.experimental.tensor.transpose(key, -2, -1, memory_config)

return query, key, value
else:
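With the transpose handled natively inside the op, the ttnn wrapper now simply forwards transpose_key instead of issuing a separate transpose afterwards. A usage sketch of the public wrapper, assuming an already-open device and a correctly laid-out fused QKV tensor (both omitted here):

```python
import ttnn

# Assumes an open device and a fused QKV tensor `qkv` of shape
# [batch, 1, seq_len, 3 * num_heads * head_dim], laid out as the op expects.
num_heads = 16

query, key, value = ttnn.transformer.split_query_key_value_and_split_heads(
    qkv,
    num_heads=num_heads,
    transpose_key=True,  # key returned as [batch, num_heads, head_dim, seq_len]
)
# query: [batch, num_heads, seq_len, head_dim]
# value: [batch, num_heads, seq_len, head_dim]
# The attention score matmul can consume `key` directly; no extra transpose op is needed.
```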