From 8a5cc63fe7637ceacb2398bb6cbb1165c08962cb Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Fri, 19 Jan 2024 21:57:06 +0000 Subject: [PATCH] #4681: Add new group_attn_matmul (uplift + optimizations of attn_matmul) - Same as attn_matmul but we can have kv_heads > 1 - kv_heads is mcasted by 32 cores to all q_head cores - Supports interleaved and height sharded (row or col) for any mix of in0, in1, or output - This op is fully dynamic across input shape, similar to eltwise_binary - Add unit testing for group_attn_matmul and program caching --- .../unit_testing/test_attn_matmul.py | 186 +++++++- .../writer_transformer_group_attn_matmul.cpp | 46 ++ tt_eager/tt_dnn/module.mk | 1 + .../compute/transformer_group_attn_matmul.cpp | 101 ++++ ...er_mcast_transformer_group_attn_matmul.cpp | 295 ++++++++++++ .../multi_core_group_attn_matmul.cpp | 448 ++++++++++++++++++ .../transformer_tms/transformer_tms.cpp | 169 ++++++- .../transformer_tms/transformer_tms.hpp | 35 ++ .../primary/transformers/module.hpp | 6 + 9 files changed, 1280 insertions(+), 7 deletions(-) create mode 100644 tt_eager/tt_dnn/kernels/dataflow/writer_transformer_group_attn_matmul.cpp create mode 100644 tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp create mode 100644 tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_mcast_transformer_group_attn_matmul.cpp create mode 100644 tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py b/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py index 4ac117ff0c0..9bd9387eacb 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py @@ -4,10 +4,10 @@ import pytest import torch +import pytest import tt_lib as ttl -from models.utility_functions import print_diff_argmax, comp_pcc -from models.utility_functions import skip_for_wormhole_b0 +from models.utility_functions import comp_pcc, skip_for_wormhole_b0 def generate_input_shapes(): @@ -90,3 +90,185 @@ def test_attn_matmul_with_program_cache(in0_dtype, in1_dtype, out_dtype, device, allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) assert allclose, f"FAILED: {output}" + + +@pytest.mark.parametrize( + "shard_orientation", + (ttl.tensor.ShardOrientation.ROW_MAJOR, ttl.tensor.ShardOrientation.COL_MAJOR), +) +@pytest.mark.parametrize( + "output_sharded", + (False, True), +) +@pytest.mark.parametrize( + "in1_sharded", + (False, True), +) +@pytest.mark.parametrize( + "in0_sharded", + (False, True), +) +@pytest.mark.parametrize( + "batch, K, seq_len, q_heads, kv_heads", + ((32, 64, 128, 16, 1), (32, 64, 128, 32, 2)), +) +def test_group_attn_matmul( + batch, K, seq_len, q_heads, kv_heads, in0_sharded, in1_sharded, output_sharded, shard_orientation, device +): + # NOTE: For interleaved kv_heads, batch 64, 96, etc... 
should be supported + torch.manual_seed(0) + + compute_grid_size = device.compute_with_storage_grid_size() + + interleaved_mem_config = ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 + ) + + # NOTE: Mixed precision is supported as well + in0_dtype = ttl.tensor.DataType.BFLOAT16 + in1_dtype = ttl.tensor.DataType.BFLOAT16 + output_dtype = ttl.tensor.DataType.BFLOAT16 + + q_len = 1 + input_shape_a = [q_len, q_heads, batch, K] + input_shape_b = [batch, kv_heads, K, seq_len] + + input_tensor_a = torch.randn(input_shape_a).bfloat16() + input_tensor_b = torch.randn(input_shape_b).bfloat16() + + tt_input_tensor_a = ( + ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) + ) + tt_input_tensor_b = ( + ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) + ) + + if in0_sharded: + tt_input_tensor_a = ttl.tensor.interleaved_to_sharded( + tt_input_tensor_a, + compute_grid_size, + [q_len * batch, K], + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + shard_orientation, + ) + + if in1_sharded: + tt_input_tensor_b = ttl.tensor.interleaved_to_sharded( + tt_input_tensor_b, + compute_grid_size, + [kv_heads * K, seq_len], + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + shard_orientation, + ) + + if output_sharded: + output_mem_config = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + buffer_type=ttl.tensor.BufferType.L1, + ) + else: + output_mem_config = interleaved_mem_config + + tt_output_tensor_on_device = ttl.operations.primary.transformers.group_attn_matmul( + tt_input_tensor_a, + tt_input_tensor_b, + compute_with_storage_grid_size=compute_grid_size, + output_mem_config=output_mem_config, + output_dtype=output_dtype, + ) + if output_sharded: + tt_output_tensor_on_device = ttl.tensor.sharded_to_interleaved( + tt_output_tensor_on_device, interleaved_mem_config + ) + + tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + + input_tensor_a = input_tensor_a.to(torch.float) + input_tensor_b = torch.repeat_interleave(input_tensor_b.to(torch.float), q_heads // kv_heads, dim=1) + golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) + + allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) + assert allclose, f"FAILED: {output}" + + +@pytest.mark.parametrize("sharded", [False, True]) +@pytest.mark.parametrize("output_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +def test_group_attn_matmul_with_program_cache(in0_dtype, in1_dtype, output_dtype, sharded, device, use_program_cache): + torch.manual_seed(0) + + compute_grid_size = device.compute_with_storage_grid_size() + + interleaved_mem_config = ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 + ) + + shard_orientation = ttl.tensor.ShardOrientation.COL_MAJOR # Only used if sharded + + q_len = 1 + batch = 32 if sharded else 64 + num_cache_entries = 0 # Only track cache entries of group_attn_matmul + for K, seq_len, q_heads, kv_heads in ((96, 64, 10, 2), (64, 128, 50, 5)): + input_shape_a = [q_len, q_heads, batch, K] + input_shape_b = [batch, kv_heads, K, seq_len] + + input_tensor_a = 
torch.randn(input_shape_a).bfloat16() + input_tensor_b = torch.randn(input_shape_b).bfloat16() + + tt_input_tensor_a = ( + ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) + ) + tt_input_tensor_b = ( + ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) + ) + + if sharded: + tt_input_tensor_a = ttl.tensor.interleaved_to_sharded( + tt_input_tensor_a, + compute_grid_size, + [q_len * batch, K], + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + shard_orientation, + ) + + tt_input_tensor_b = ttl.tensor.interleaved_to_sharded( + tt_input_tensor_b, + compute_grid_size, + [kv_heads * K, seq_len], + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + shard_orientation, + ) + + output_mem_config = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + buffer_type=ttl.tensor.BufferType.L1, + ) + else: + output_mem_config = interleaved_mem_config + + num_cache_entries_start = ttl.program_cache.num_entries() + tt_output_tensor_on_device = ttl.operations.primary.transformers.group_attn_matmul( + tt_input_tensor_a, + tt_input_tensor_b, + compute_with_storage_grid_size=compute_grid_size, + output_mem_config=output_mem_config, + output_dtype=output_dtype, + ) + num_cache_entries += ttl.program_cache.num_entries() - num_cache_entries_start + + if sharded: + tt_output_tensor_on_device = ttl.tensor.sharded_to_interleaved( + tt_output_tensor_on_device, interleaved_mem_config + ) + + tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + + input_tensor_a = input_tensor_a.to(torch.float) + input_tensor_b = torch.repeat_interleave(input_tensor_b.to(torch.float), q_heads // kv_heads, dim=1) + golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) + + allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) + assert allclose, f"FAILED: {output}" + + assert num_cache_entries == 1 diff --git a/tt_eager/tt_dnn/kernels/dataflow/writer_transformer_group_attn_matmul.cpp b/tt_eager/tt_dnn/kernels/dataflow/writer_transformer_group_attn_matmul.cpp new file mode 100644 index 00000000000..8f82e5c496f --- /dev/null +++ b/tt_eager/tt_dnn/kernels/dataflow/writer_transformer_group_attn_matmul.cpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
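+//
+// Writer kernel for group_attn_matmul. Each core first reads a per-core has_work
+// flag so that cores without an assigned q_head can exit immediately. When the
+// output is sharded (OUT_SHARDED), the kernel only waits for the full block of
+// output tiles to land in the output CB; otherwise it pops tiles one at a time
+// from the output CB and writes them to the interleaved output buffer, walking
+// tile ids forward from start_id (or backward when BACKWARDS is defined).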
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" + +void kernel_main() { + uint32_t has_work = get_arg_val(0); + if (has_work == 0) return; + uint32_t dst_addr = get_arg_val(1); + uint32_t num_tiles = get_arg_val(2); + uint32_t start_id = get_arg_val(3); + + constexpr uint32_t cb_id_out = get_compile_time_arg_val(0); + constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; + + #ifdef OUT_SHARDED + cb_wait_front(cb_id_out, num_tiles); + #else + + // single-tile ublocks + constexpr uint32_t onetile = 1; + const uint32_t tile_bytes = get_tile_size(cb_id_out); + const DataFormat data_format = get_dataformat(cb_id_out); + + const InterleavedAddrGenFast s = { + .bank_base_address = dst_addr, + .page_size = tile_bytes, + .data_format = data_format + }; + + #ifdef BACKWARDS + uint32_t end_id = start_id - num_tiles; + for (uint32_t i = start_id; i != end_id; -- i) { + #else + uint32_t end_id = start_id + num_tiles; + for (uint32_t i = start_id; i < end_id; ++ i) { + #endif + cb_wait_front(cb_id_out, onetile); + uint32_t l1_read_addr = get_read_ptr(cb_id_out); + noc_async_write_tile(i, s, l1_read_addr); + noc_async_write_barrier(); + cb_pop_front(cb_id_out, onetile); + } + #endif +} diff --git a/tt_eager/tt_dnn/module.mk b/tt_eager/tt_dnn/module.mk index 00bb5b9ab90..bfe67abb6a7 100644 --- a/tt_eager/tt_dnn/module.mk +++ b/tt_eager/tt_dnn/module.mk @@ -120,6 +120,7 @@ TT_DNN_SRCS = \ tt_eager/tt_dnn/op_library/transformer_tms/multi_core_split_query_key_value_and_split_heads/multi_core_split_query_key_value_and_split_heads.cpp \ tt_eager/tt_dnn/op_library/transformer_tms/multi_core_concatenate_heads/multi_core_concatenate_heads.cpp \ tt_eager/tt_dnn/op_library/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp \ + tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp \ tt_eager/tt_dnn/op_library/run_operation.cpp \ tt_eager/tt_dnn/op_library/split/split_tiled.cpp \ tt_eager/tt_dnn/op_library/split/split_last_dim_two_chunks_tiled.cpp \ diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp new file mode 100644 index 00000000000..de278457ccc --- /dev/null +++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
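+//
+// Compute kernel for group_attn_matmul. Each output tile packs 32 users along its
+// rows and each user has its own in1 (KV) data, so the kernel produces one full
+// matmul tile per user row: it accumulates over the Kt tiles of in0 against the
+// in1 tiles streamed in by the reader for that user, packs the result, and
+// untilizes it into cb_intermed1 so the reader can keep only the row belonging to
+// that user. Once the reader has assembled all 32 rows in cb_intermed2, that
+// row-major tile is tilized into the output CB. The unpack/pack data-format
+// reconfigs around the untilize/tilize steps keep the intermediates in Float16_b
+// regardless of the in0/in1/output formats.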
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "compute_kernel_api/tile_move_copy.h" +#include "compute_kernel_api/matmul.h" +#include "compute_kernel_api/tilize.h" +#include "compute_kernel_api/untilize.h" + +using std::uint32_t; + +// matmul C=A*B using dims MK*KN = MN (row major order) +// +namespace NAMESPACE { +void MAIN { + + constexpr uint32_t onetile = 1; + + constexpr uint32_t transpose_hw = get_compile_time_arg_val(0); + + uint32_t has_work = get_arg_val(0); + if (has_work == 0) return; + uint32_t batch = get_arg_val(1); + uint32_t Mt = get_arg_val(2); + uint32_t Kt = get_arg_val(3); + uint32_t Nt = get_arg_val(4); + + constexpr uint32_t cb_in0 = 0; + constexpr uint32_t cb_in1 = 1; + constexpr uint32_t cb_intermed0 = 24; + constexpr uint32_t cb_intermed1 = 25; + constexpr uint32_t cb_intermed2 = 26; + constexpr uint32_t out_cb_id = 16; + + constexpr uint32_t num_rows_in_one_tile = 32; + + mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw); + + for (uint32_t nb = 0; nb < batch; ++nb) { + for (uint32_t mt_C = 0; mt_C < Mt; ++mt_C) { // output tile of C + cb_wait_front(cb_in0, Kt); + for (uint32_t nt_C = 0; nt_C < Nt; ++nt_C) { // output tile index of C + for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; ++tile_row_id) { + tile_regs_acquire(); + for (uint32_t kt = 0; kt < Kt; ++kt) { + cb_wait_front(cb_in1, onetile); + + matmul_tiles(cb_in0, cb_in1, kt, 0, 0, transpose_hw); + + cb_pop_front(cb_in1, onetile); + } + tile_regs_commit(); + + cb_reserve_back(cb_intermed0, onetile); + tile_regs_wait(); + pack_tile(0, cb_intermed0); + tile_regs_release(); + cb_push_back(cb_intermed0, onetile); + + // untilize tile and write to CB::c_intermed1 + unpack_reconfig_data_format_srca(cb_in1, cb_intermed0); + cb_wait_front(cb_intermed0, onetile); + untilize_init_short(cb_intermed0); + cb_reserve_back(cb_intermed1, 1); + untilize_block(cb_intermed0, 1, cb_intermed1); + cb_push_back(cb_intermed1, 1); + + cb_pop_front(cb_intermed0, 1); + untilize_uninit(cb_intermed0); + + unpack_reconfig_data_format_srca(cb_intermed0, cb_in1); + mm_init_short(transpose_hw); + } + + // cb_intermed2 comes from reader; untilized row-major tile + unpack_reconfig_data_format_srca(cb_in1, cb_intermed2); + pack_reconfig_data_format(cb_intermed1, out_cb_id); + cb_wait_front(cb_intermed2, 1); + cb_reserve_back(out_cb_id, onetile); + + // tilize CB::intermed2 and write to CB::c_out0 + tilize_init_short(cb_intermed2, 1); + tilize_block(cb_intermed2, 1, out_cb_id); + cb_push_back(out_cb_id, 1); + + cb_pop_front(cb_intermed2, 1); + tilize_uninit(); + + // Hangs when in0 is BFLOAT8_B if we don't force the reconfig + unpack_reconfig_data_format_srca(cb_in1); + pack_reconfig_data_format(out_cb_id, cb_intermed0); + mm_init_short(transpose_hw); + } // Nt + + cb_pop_front(cb_in0, Kt); + } // Mt + } // batch + +} +} // NAMESPACE diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_mcast_transformer_group_attn_matmul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_mcast_transformer_group_attn_matmul.cpp new file mode 100644 index 00000000000..43b6f498edc --- /dev/null +++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_mcast_transformer_group_attn_matmul.cpp @@ -0,0 +1,295 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
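+//
+// Reader kernel for group_attn_matmul. Work is parallelized one Q head per core,
+// and the first 32 cores additionally act as in1 mcast senders, one per user row
+// of a tile: when its row comes up, the sender gathers, for each Kt step, one
+// tile from every KV head of that user (from its in1 shard, or from the
+// interleaved in1 buffer) into cb_id_in2, multicasts that block to all cores
+// (loopback keeps its own copy), and multicasts VALID to the receiver semaphore;
+// the other cores signal readiness on the sender semaphore and wait for VALID.
+// Every core with work then copies the one KV head matching its Q head out of the
+// mcasted block into cb_id_in1 using kv_heads_addr_offset. The reader also feeds
+// in0 (q_heads) tiles when in0 is interleaved, and gathers one untilized row per
+// user from cb_intermed1 into cb_intermed2 so the compute kernel can tilize a
+// full output tile.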
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + uint32_t i = 0; + + uint32_t has_work = get_arg_val(i++); + const bool has_work_bool = has_work == 1; + + uint32_t src0_addr = get_arg_val(i++); + uint32_t src1_addr = get_arg_val(i++); + uint32_t Mt = get_arg_val(i++); + uint32_t Kt = get_arg_val(i++); + uint32_t Nt = get_arg_val(i++); + uint32_t MtKt = get_arg_val(i++); + uint32_t num_kv_heads = get_arg_val(i++); // in1[1] (ie. in1 C) + uint32_t in1_KtNt = get_arg_val(i++); + uint32_t in1_CKtNt_skip = get_arg_val(i++); // 0 if in0 and in1 Kt are the same + uint32_t in1_CKtNt_mul_32 = get_arg_val(i++); + uint32_t blocks = get_arg_val(i++); + uint32_t in0_start_id = get_arg_val(i++); + uint32_t in1_start_id = get_arg_val(i++); + uint32_t kv_heads_addr_offset = get_arg_val(i++); + + uint32_t in1_mcast_dest_noc_start_x = get_arg_val(i++); + uint32_t in1_mcast_dest_noc_start_y = get_arg_val(i++); + uint32_t in1_mcast_dest_noc_end_x = get_arg_val(i++); + uint32_t in1_mcast_dest_noc_end_y = get_arg_val(i++); + uint32_t in1_mcast_num_dests = get_arg_val(i++); + uint32_t in1_mcast_num_cores = get_arg_val(i++); + uint32_t in1_mcast_sender_semaphore_addr = get_arg_val(i++); + uint32_t in1_mcast_receiver_semaphore_addr = get_arg_val(i++); + + uint32_t in1_mcast_sender_size_bytes = get_arg_val(i++); + uint32_t in1_mcast_sender_id = get_arg_val(i++); + uint32_t in1_mcast_sender_num_x = get_arg_val(i++); + uint32_t in1_mcast_sender_num_y = get_arg_val(i++); + volatile tt_l1_ptr uint32_t *in1_mcast_sender_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(i)); i+=in1_mcast_sender_num_x; + volatile tt_l1_ptr uint32_t *in1_mcast_sender_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(i)); i+=in1_mcast_sender_num_y; + + + constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1; + constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1; + #define transpose_hw_bool get_compile_time_arg_val(2) == 1 + constexpr bool row_major = (bool) get_compile_time_arg_val(3); + + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t cb_id_in1 = 1; // copy single KV heads for Q heads + constexpr uint32_t cb_id_in2 = 2; // mcast receiver + constexpr uint32_t cb_id_in3 = 3; // all interleaved or sharded KV heads for one user batch + constexpr uint32_t cb_id_intermed0 = 24; + constexpr uint32_t cb_id_intermed1 = 25; + constexpr uint32_t cb_id_intermed2 = 26; + + constexpr uint32_t onetile = 1; + constexpr uint32_t num_rows_in_one_tile = 32; + const uint32_t in1_tile_bytes = get_tile_size(cb_id_in1); + + #ifdef IN0_SHARDED + if (has_work_bool) { + cb_reserve_back(cb_id_in0, blocks * MtKt); + cb_push_back(cb_id_in0, blocks * MtKt); + } + #else + const uint32_t in0_tile_bytes = get_tile_size(cb_id_in0); + const DataFormat in0_data_format = get_dataformat(cb_id_in0); + const InterleavedAddrGenFast s0 = { + .bank_base_address = src0_addr, + .page_size = in0_tile_bytes, + .data_format = in0_data_format + }; + #endif + + #ifndef IN1_SHARDED + const DataFormat in1_data_format = get_dataformat(cb_id_in1); + const InterleavedAddrGenFast s1 = { + .bank_base_address = src1_addr, + .page_size = in1_tile_bytes, + .data_format = in1_data_format + }; + #endif + + // Mcast setup + // Set ur local VALID value, to be mcasted to destinations flag address after the data has been mcasted + volatile tt_l1_ptr uint32_t* in1_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(in1_mcast_receiver_semaphore_addr); + noc_semaphore_set(in1_mcast_receiver_semaphore_addr_ptr, 
VALID); + // local address that will be atomically incremented by mcast receivers, to know when all receivers are ready + // to receive the mcast + volatile tt_l1_ptr uint32_t* in1_mcast_sender_semaphore_addr_ptr = reinterpret_cast(in1_mcast_sender_semaphore_addr); + + uint64_t in1_mcast_sender_semaphore_noc_addr_vec[num_rows_in_one_tile]; + if constexpr(row_major) { + uint32_t x = 0, y = 0; + for (uint32_t i = 0; i < num_rows_in_one_tile; ++i) { + in1_mcast_sender_semaphore_noc_addr_vec[i] = get_noc_addr(in1_mcast_sender_noc_x[x], in1_mcast_sender_noc_y[y], in1_mcast_sender_semaphore_addr); + ++x; + if (x == in1_mcast_sender_num_x) { + x = 0; + ++y; + } + } + } else { + uint32_t x = 0, y = 0; + for (uint32_t i = 0; i < num_rows_in_one_tile; ++i) { + in1_mcast_sender_semaphore_noc_addr_vec[i] = get_noc_addr(in1_mcast_sender_noc_x[x], in1_mcast_sender_noc_y[y], in1_mcast_sender_semaphore_addr); + ++y; + if (y == in1_mcast_sender_num_y) { + y = 0; + ++x; + } + } + } + + uint64_t in1_multicast_noc_addr = get_noc_multicast_addr( + in1_mcast_dest_noc_start_x, + in1_mcast_dest_noc_start_y, + in1_mcast_dest_noc_end_x, + in1_mcast_dest_noc_end_y, + 0 + ); + + uint64_t in1_mcast_receiver_semaphore_noc_addr = in1_multicast_noc_addr | in1_mcast_receiver_semaphore_addr; + + + // CB write ptr; no pop/push for cb 2 and 3 so write/read ptr's never change + uint32_t l1_write_addr_in2 = get_write_ptr(cb_id_in2); + uint32_t l1_write_addr_in3 = get_write_ptr(cb_id_in3); + uint64_t in1_multicast_data_addr = in1_multicast_noc_addr | l1_write_addr_in2; + uint64_t noc_l1_read_addr_for_kv_heads = get_noc_addr(l1_write_addr_in2 + kv_heads_addr_offset); + + // TODO: Clean this up; don't think this will work if we double buffer intermed 1/2 + uint32_t cb_intermed1_addr_initial = get_read_ptr(cb_id_intermed1); + uint32_t cb_intermed2_addr_initial = get_write_ptr(cb_id_intermed2); + uint32_t cb_intermed1_addr; + uint32_t cb_intermed2_addr; + constexpr uint32_t bfloat16_row_bytes = 64; + + // Only used for interleaved + uint32_t in0_batch = in0_start_id; + uint32_t in1_batch; + uint32_t in0_Mt; + uint32_t in1_Nt; + uint32_t in0_tensor_id; + uint32_t in1_tensor_id; + + // Only used for sharded + // Don't need to track batch because user batch must be 32 (ie. 
Mt must be 1) + uint64_t in1_sharded_cb_noc_addr_Nt = get_noc_addr(l1_write_addr_in3); // Read/write ptr should be the same + uint64_t in1_sharded_cb_noc_addr; + uint32_t Nt_bytes = Nt * in1_tile_bytes; + uint32_t in1_KtNt_bytes = in1_KtNt * in1_tile_bytes; + uint32_t in1_CKtNt_skip_bytes = in1_CKtNt_skip * in1_tile_bytes; + for (uint32_t b = 0; b < blocks; b++) { + in0_Mt = in0_batch; + in1_batch = in1_start_id; + + for (uint32_t m = 0; m < Mt; m++) { + in1_Nt = in1_batch; + + #ifndef IN0_SHARDED + if (has_work_bool) { + in0_tensor_id = in0_Mt; + cb_reserve_back(cb_id_in0, Kt); + for (uint32_t kt = 0; kt < Kt; kt++) { + // Read in0 tile at (mt, kt) + uint32_t l1_write_addr_in0 = get_write_ptr(cb_id_in0); + noc_async_read_tile(in0_tensor_id, s0, l1_write_addr_in0); + noc_async_read_barrier(); + cb_push_back(cb_id_in0, onetile); + + in0_tensor_id++; // in0 is MK + } + } + #endif + + for (uint32_t n = 0; n < Nt; n++) { + cb_intermed1_addr = cb_intermed1_addr_initial; + cb_intermed2_addr = cb_intermed2_addr_initial; + in1_tensor_id = in1_Nt; + in1_sharded_cb_noc_addr = in1_sharded_cb_noc_addr_Nt; + + if (has_work_bool) { + cb_reserve_back(cb_id_intermed2, 1); + } + for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) { + for (uint32_t kt = 0; kt < Kt; kt++) { + // Read in1 tile at (kt, nt) + if (tile_row_id == in1_mcast_sender_id) { + // MCAST SENDER: send all kv_heads in one user batch + #ifdef IN1_SHARDED + // Copy to cb_id_in2 to mcast + uint64_t in1_sharded_cb_current_noc_addr = in1_sharded_cb_noc_addr; + uint32_t in2_current_l1_write_addr = l1_write_addr_in2; + for (uint32_t kv_heads_id = 0; kv_heads_id < num_kv_heads; kv_heads_id++) { + noc_async_read(in1_sharded_cb_current_noc_addr, in2_current_l1_write_addr, in1_tile_bytes); + in1_sharded_cb_current_noc_addr += in1_KtNt_bytes; // Increment by Nt to get to next kv_heads + in2_current_l1_write_addr += in1_tile_bytes; + } + // These indices are local to each core, so don't modify when looping num_rows_in_one_tile + in1_sharded_cb_noc_addr += Nt_bytes; // Kt is in in1[2], so stride is Nt + noc_async_read_barrier(); + #else + uint32_t in1_tensor_current_id = in1_tensor_id; + uint32_t in2_current_l1_write_addr = l1_write_addr_in2; + for (uint32_t kv_heads_id = 0; kv_heads_id < num_kv_heads; kv_heads_id++) { + noc_async_read_tile(in1_tensor_current_id, s1, in2_current_l1_write_addr); + + in1_tensor_current_id += in1_KtNt; // Increment by KtNt to get to next kv_heads + in2_current_l1_write_addr += in1_tile_bytes; + } + noc_async_read_barrier(); + #endif + + // wait until all in1 mcast destinations have atomically incremented the in1 semaphore_addr (i.e. its value should be in1_mcast_num_dests), then reset + // the semaphore_addr value back to zero for the next block + noc_semaphore_wait(in1_mcast_sender_semaphore_addr_ptr, in1_mcast_num_dests); + noc_semaphore_set(in1_mcast_sender_semaphore_addr_ptr, 0); + + // Now we have the block in the CB address, we can mcast to dests! + // num_dests will source, since we are copying to a different local CB as well + noc_async_write_multicast_loopback_src(l1_write_addr_in2, in1_multicast_data_addr, in1_mcast_sender_size_bytes, in1_mcast_num_cores + 1); + + // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf + // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). 
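+                        // Sender-side ordering recap: (1) wait for all receivers to signal ready on the
+                        // sender semaphore and reset it, (2) multicast the gathered KV heads (loopback
+                        // includes this core in the destination grid), (3) multicast VALID to the
+                        // receiver semaphore, (4) barrier so both multicasts complete before the local
+                        // buffers are reused.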
+ + // We should also multicast VALID flag to destinations for receiver semaphore + noc_semaphore_set_multicast(in1_mcast_receiver_semaphore_addr, in1_mcast_receiver_semaphore_noc_addr, in1_mcast_num_cores); + + // Write barrier needed since we mcast to self, and also needed to finish sending mcast flag before we modify locally + noc_async_write_barrier(); + } else { + // MCAST RECEIVER: receive all kv_heads in one user batch + // Set in1 semaphore value to INVALID + noc_semaphore_set(in1_mcast_receiver_semaphore_addr_ptr, INVALID); + + // Atomic increment source core counter + uint64_t in1_mcast_sender_semaphore_noc_addr = in1_mcast_sender_semaphore_noc_addr_vec[tile_row_id]; + noc_semaphore_inc(in1_mcast_sender_semaphore_noc_addr, 1); + + // wait on in1 semaphore value to become VALID (set by mcast sender after it multicasts data) + noc_semaphore_wait(in1_mcast_receiver_semaphore_addr_ptr, VALID); + } + if (has_work_bool) { + // Choose matching kv_heads for q_heads + cb_reserve_back(cb_id_in1, onetile); + noc_async_read(noc_l1_read_addr_for_kv_heads, get_write_ptr(cb_id_in1), in1_tile_bytes); + noc_async_read_barrier(); + cb_push_back(cb_id_in1, onetile); + } + + #if (transpose_hw_bool) + in1_tensor_id++; // Kt is in in1[3], so it is contiguous in memory + #else + in1_tensor_id += Nt; // Kt is in in1[2], so stride is Nt + #endif + } // Kt loop + + if (has_work_bool) { + // Read 32 untilized tiles and select correct rows to reconstruct single correct tile + cb_wait_front(cb_id_intermed1, 1); + noc_async_read(get_noc_addr(cb_intermed1_addr), cb_intermed2_addr, bfloat16_row_bytes); + noc_async_read_barrier(); + cb_pop_front(cb_id_intermed1, 1); + cb_intermed1_addr += bfloat16_row_bytes; + cb_intermed2_addr += bfloat16_row_bytes; + } + + in1_tensor_id += in1_CKtNt_skip; // different depending on transpose_hw + } // 32 tiles loop + + if (has_work_bool) { + cb_push_back(cb_id_intermed2, 1); + } + + // Next tile in Nt + #if (transpose_hw_bool) + in1_Nt += Kt; // next tile in Nt is in in1[2], so stride is Kt + #else + in1_Nt++; + #endif + + in1_sharded_cb_noc_addr_Nt += in1_tile_bytes; + } // Nt loop + + in0_Mt += Kt; + // here, KtNt is the stride of the full in1 tensor (ie. max cache length is incorporated in one of Kt or Nt depending on transpose_hw) + in1_batch += in1_CKtNt_mul_32; // different depending on transpose_hw + } // Mt loop + in0_batch += MtKt; + } // B loop +} diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp new file mode 100644 index 00000000000..b669698683c --- /dev/null +++ b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp @@ -0,0 +1,448 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
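+//
+// Host-side program factory for group_attn_matmul. Kernels and circular buffers
+// are created once over the full device compute grid so the cached program can be
+// reused across input shapes; everything shape-dependent (work split, runtime
+// args, dynamic CB sizes and sharded CB addresses) is computed inside
+// set_runtime_args, which is also installed as the override_runtime_arguments
+// callback for program caching. Work is split one Q head per core, and at least
+// 32 cores are always active because the first 32 cores multicast the KV heads to
+// every core.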
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/transformer_tms/transformer_tms.hpp" +#include "tt_dnn/op_library/work_split.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" + +using namespace tt::constants; +using namespace tt; + +namespace tt { +namespace operations { +namespace primary { +namespace transformers { + + +operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, const Tensor &b, Tensor& output, std::optional num_tokens, std::optional transpose_hw, CoreCoord compute_with_storage_grid_size, const bool row_major) { + + tt_metal::Program program{}; + + const auto& ashape = a.shape(), bshape = b.shape(); + + tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype()); + tt::DataFormat in1_data_format = tt_metal::datatype_to_dataformat_converter(b.dtype()); + tt::DataFormat interm_data_format = DataFormat::Float16_b; + tt::DataFormat output_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype()); + uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); + uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); + uint32_t interm_single_tile_size = tt_metal::detail::TileSize(interm_data_format); + uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_data_format); + MathFidelity math_fidelity = MathFidelity::LoFi; + + tt_metal::Buffer *src0_buffer = a.buffer(); + tt_metal::Buffer *src1_buffer = b.buffer(); + tt_metal::Buffer *dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + // This should allocate a DRAM buffer on the device + tt_metal::Device *device = a.device(); + + // Load kernels on all device cores, because we use cached program for input shapes with changing shapes + CoreCoord device_compute_with_storage_grid = device->compute_with_storage_grid_size(); + auto all_device_cores = CoreRange({0, 0}, {device_compute_with_storage_grid.x - 1, device_compute_with_storage_grid.y - 1}); + + // See set_runtime_args for how input shapes are used; these are the variables needed for setting up kernels and CBs + const bool transpose_hw_bool = transpose_hw.value_or(false); + uint32_t KV_HEADS = bshape[1]; // bshape[0] is user batch + uint32_t Kt = ashape[3]/TILE_WIDTH; + + // Mcast args + auto in1_mcast_sender_semaphore = tt_metal::CreateSemaphore(program, all_device_cores, INVALID); + auto in1_mcast_receiver_semaphore = tt_metal::CreateSemaphore(program, all_device_cores, INVALID); + + // Only first 32 of cores mcast KV heads to match num_rows_in_one_tile in reader kernel, so these coordinates are static if we cache on compute_with_storage_grid_size + // TODO: If this is not the case, then we should set reader_runtime_args to max possible size and update sender noc coordinates based on input + CoreCoord mcast_sender_grid = ((CoreRangeSet) num_cores_to_corerange_set(TILE_HEIGHT, compute_with_storage_grid_size, row_major)).bounding_box().grid_size(); + std::vector in1_mcast_sender_noc_x(mcast_sender_grid.x); + std::vector in1_mcast_sender_noc_y(mcast_sender_grid.y); + for(uint32_t core_idx_x = 0; core_idx_x < mcast_sender_grid.x; ++core_idx_x) { + in1_mcast_sender_noc_x[core_idx_x] = device->worker_core_from_logical_core({core_idx_x, 0}).x; + } + for(uint32_t core_idx_y = 0; core_idx_y < mcast_sender_grid.y; ++core_idx_y) { + in1_mcast_sender_noc_y[core_idx_y] = device->worker_core_from_logical_core({0, 
core_idx_y}).y; + } + + // Set up CBs + const bool in0_is_sharded = a.is_sharded(); + const bool in1_is_sharded = b.is_sharded(); + const bool output_is_sharded = output.is_sharded(); + + // CB for in0 (ie. q_heads) + uint32_t src0_cb_index = CB::c_in0; + CBHandle cb_src0; + if (in0_is_sharded) { + uint32_t cb0_num_input_tiles = a.shard_spec().value().numel() / TILE_HW; // Should be full MtKt and C should be 1 + tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(cb0_num_input_tiles * in0_single_tile_size, {{src0_cb_index, in0_data_format}}) + .set_page_size(src0_cb_index, in0_single_tile_size).set_globally_allocated_address(*src0_buffer); + cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config); + } else { + uint32_t cb0_num_input_tiles = Kt * 2; + tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(cb0_num_input_tiles * in0_single_tile_size, {{src0_cb_index, in0_data_format}}) + .set_page_size(src0_cb_index, in0_single_tile_size); + cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config); + } + + // CB for in1 (ie. one kv_heads matching q_heads); for MQA, can probably optimize away unnecessary copies from cb2 to cb1 + uint32_t src1_cb_index = CB::c_in1; + uint32_t cb1_num_input_tiles = 2; + tt_metal::CircularBufferConfig cb_src1_config = tt_metal::CircularBufferConfig(cb1_num_input_tiles * in1_single_tile_size, {{src1_cb_index, in1_data_format}}) + .set_page_size(src1_cb_index, in1_single_tile_size); + auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config); + + // CB for interleaved/sharded KV heads for mcasting; mcasts to same CB + uint32_t src2_cb_index = CB::c_in2; + uint32_t cb2_num_input_tiles = KV_HEADS; + tt_metal::CircularBufferConfig cb_src2_config = tt_metal::CircularBufferConfig(cb2_num_input_tiles * in1_single_tile_size, {{src2_cb_index, in1_data_format}}) + .set_page_size(src2_cb_index, in1_single_tile_size); + auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src2_config); + + // CB for sharded KV heads + CBHandle cb_src3 = 0; // unused if KV heads is interleaved + if (in1_is_sharded) { + uint32_t src3_cb_index = CB::c_in3; + uint32_t cb3_num_input_tiles = b.shard_spec().value().numel() / TILE_HW; // Should be full CKtNt and batch must be 32 + tt_metal::CircularBufferConfig cb_src3_config = tt_metal::CircularBufferConfig(cb3_num_input_tiles * in1_single_tile_size, {{src3_cb_index, in1_data_format}}) + .set_page_size(src3_cb_index, in1_single_tile_size).set_globally_allocated_address(*src1_buffer); + cb_src3 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src3_config); + } + + // Intermediate CBs for handling untilizing, copying rows, and tilizing to output CB + uint32_t cb_intermed0_index = CB::c_intermed0; + tt_metal::CircularBufferConfig cb_interm0_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed0_index, interm_data_format}}) + .set_page_size(cb_intermed0_index, interm_single_tile_size); + auto cb_interm0 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm0_config); + + uint32_t cb_intermed1_index = CB::c_intermed1; + tt_metal::CircularBufferConfig cb_interm1_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed1_index, interm_data_format}}) + .set_page_size(cb_intermed1_index, interm_single_tile_size); + auto cb_interm1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm1_config); + + 
uint32_t cb_intermed2_index = CB::c_intermed2; + tt_metal::CircularBufferConfig cb_interm2_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed2_index, interm_data_format}}) + .set_page_size(cb_intermed2_index, interm_single_tile_size); + auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm2_config); + + // CB for output (if sharded, full num tiles per core) + uint32_t output_cb_index = CB::c_out0; // output operands start at index 16 + CBHandle cb_output; + if (output_is_sharded) { + uint32_t num_output_tiles = output.shard_spec().value().numel() / TILE_HW; // Should be full MtNt and C should be 1 + tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig(num_output_tiles * output_single_tile_size, {{output_cb_index, output_data_format}}) + .set_page_size(output_cb_index, output_single_tile_size).set_globally_allocated_address(*dst_buffer); + cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_output_config); + } else { + uint32_t num_output_tiles = 2; + tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig(num_output_tiles * output_single_tile_size, {{output_cb_index, output_data_format}}) + .set_page_size(output_cb_index, output_single_tile_size); + cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_output_config); + } + + const bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + const bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = { + (uint32_t) src0_is_dram, + (uint32_t) src1_is_dram, + (uint32_t) transpose_hw_bool, + (uint32_t) row_major, + }; + + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = { + (std::uint32_t) output_cb_index, + (std::uint32_t) dst_is_dram + }; + + std::map reader_kernel_defines; + std::map writer_kernel_defines; + if (in0_is_sharded) { + reader_kernel_defines["IN0_SHARDED"] = "1"; + } + if (in1_is_sharded) { + reader_kernel_defines["IN1_SHARDED"] = "1"; + } + if (output_is_sharded) { + writer_kernel_defines["OUT_SHARDED"] = "1"; + } + + tt_metal::NOC reader_noc = tt_metal::detail::GetPreferredNOCForDRAMRead(tt::Cluster::instance().arch()); // Default is NOC_1 + const bool reader_noc_is_NOC_0 = reader_noc == tt_metal::NOC::NOC_0; + tt_metal::NOC writer_noc = reader_noc_is_NOC_0 ? 
tt_metal::NOC::NOC_1 : tt_metal::NOC::NOC_0; + auto reader_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_mcast_transformer_group_attn_matmul.cpp", + all_device_cores, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = reader_noc, + .compile_args = reader_compile_time_args, + .defines = reader_kernel_defines, + } + ); + + auto writer_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/kernels/dataflow/writer_transformer_group_attn_matmul.cpp", + all_device_cores, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = writer_noc, + .compile_args = writer_compile_time_args, + .defines = writer_kernel_defines, + } + ); + + vector compute_args = { + (uint32_t) transpose_hw_bool, // transpose_hw for matmul_init + }; // bmm compute kernel the B, Mt, Nt are just 3 for loops that technically act as 1 large loop, so only set Nt for simplicity + + auto compute_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp", + all_device_cores, + tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .compile_args = compute_args} + ); + + // Reader runtime args that need updating for each core (after setting some defaults based on shape) + constexpr uint32_t HAS_WORK_VECTOR_IDX = 0; + constexpr uint32_t BLOCKS_VECTOR_IDX = 11; + constexpr uint32_t IN0_START_ID_VECTOR_IDX = 12; + constexpr uint32_t KV_HEADS_ADDR_OFFSET_VECTOR_IDX = 14; + constexpr uint32_t MCAST_SENDER_ID_VECTOR_IDX = 24; + constexpr uint32_t writer_runtime_args_size = 4; + constexpr uint32_t compute_runtime_args_size = 5; + // reader runtime args have 26 fixed args + mcast_sender coords for x and y (TODO: number of sender cores are currently hard-coded to be 32) + const uint32_t reader_runtime_args_size = 27 + in1_mcast_sender_noc_x.size() + in1_mcast_sender_noc_y.size(); + + auto set_runtime_args = [ + num_tokens, + transpose_hw, + row_major, + compute_with_storage_grid_size, + device_compute_with_storage_grid, + reader_id, + writer_id, + compute_kernel_id, + cb_src0, + cb_src2, + cb_src3, + cb_output, + in0_single_tile_size, + in1_single_tile_size, + output_single_tile_size, + in0_is_sharded, + in1_is_sharded, + output_is_sharded, + reader_noc_is_NOC_0, + in1_mcast_sender_semaphore, + in1_mcast_receiver_semaphore, + in1_mcast_sender_noc_x, + in1_mcast_sender_noc_y, + HAS_WORK_VECTOR_IDX, + BLOCKS_VECTOR_IDX, + IN0_START_ID_VECTOR_IDX, + KV_HEADS_ADDR_OFFSET_VECTOR_IDX, + MCAST_SENDER_ID_VECTOR_IDX, + reader_runtime_args_size, + writer_runtime_args_size, + compute_runtime_args_size + ] + ( + Program& program, + const Tensor& a, + const Tensor& b, + const Tensor& output + ) { + tt_metal::Buffer *src0_buffer = a.buffer(); + tt_metal::Buffer *src1_buffer = b.buffer(); + tt_metal::Buffer *dst_buffer = output.buffer(); + + const auto& ashape = a.shape(), bshape = b.shape(); + + tt_metal::Device *device = a.device(); + + // A block of work is one MtNt + uint32_t Q_HEADS = ashape[1]; // ashape[0] is q_len (always 1) and ashape[1] is Q num_heads; only parallelize on this + // Must always have at least 32 cores active since there are always 32 mcast cores for KV_HEADS + // TODO: Currently, we always mcast to at least 32 cores even when Q_HEADS < 32; we can optimize if we pass in proper mcast receiver grid based on Q_HEADS + // TODO: If batch > 32 (ie. 
64), each core handles all batches; only supported for interleaved KV_heads + // TODO: For sharded KV_heads, user batch must be 32 due to how we shard + // TODO: To generalize to allow parallelizing/sharding across generic batch for KV_heads, we need to track sender cores across batch-number of rows instead of 32 + // TODO: Only support one block of work (ie. 1 Q head per core) because each core assumes only one KV_heads to use + auto [num_cores, all_cores, core_group_1, core_group_2, num_output_blocks_per_core_group_1, num_output_blocks_per_core_group_2] = split_work_to_cores(compute_with_storage_grid_size, std::max(Q_HEADS, TILE_HEIGHT), row_major); + TT_FATAL(num_output_blocks_per_core_group_1 == 1 and num_output_blocks_per_core_group_2 == 0, "Group attention matmul only supports one q_heads per core. Increase compute grid size to at least have as many cores as q_heads!"); + + // C = torch.matmul(A.transpose(0, 2) * B).transpose(0, 2) + // MN = MK*KN + // Note, in1 K may not be the same as in0 K. We will read up to in0 K from in1 K for matmul. + const bool transpose_hw_bool = transpose_hw.value_or(false); + const uint32_t num_tokens_val = num_tokens.value_or(0); // should not be nullopt if transpose_hw=true + + uint32_t KV_HEADS = bshape[1]; // bshape[0] is user batch + uint32_t Mt = ashape[2]/TILE_HEIGHT; + uint32_t Kt = ashape[3]/TILE_WIDTH; + // For transpose_hw=true, in1_Kt is same as in0_Kt but on bshape[3] + // For transpose_hw=false, in1_Kt is on bshape[2] but represents the max cache length to read from (ie. may not equal in0_Kt) + uint32_t in1_Kt = transpose_hw_bool ? Kt : bshape[2]/TILE_HEIGHT; + uint32_t Nt = transpose_hw_bool ? num_tokens_val/TILE_HEIGHT : bshape[3]/TILE_WIDTH; + uint32_t MtKt = Mt * Kt; + uint32_t MtNt = Mt * Nt; + // For transpose_hw=true, in1_Kt is max cache length + // For transpose_hw=false, bshape[2] is max cache length + uint32_t in1_KtNt = transpose_hw_bool ? bshape[2]/TILE_HEIGHT * in1_Kt : in1_Kt * Nt; + uint32_t in1_CKtNt = KV_HEADS * in1_KtNt; + uint32_t in1_CKtNt_skip = in1_CKtNt - (transpose_hw_bool ? in1_Kt : Kt * Nt); // Decrement by how much we increment while iterating through Kt + + // Mcast receiver args + CoreRange all_cores_bounding_box = all_cores.bounding_box(); + uint32_t mcast_num_cores = all_cores_bounding_box.size(); + CoreCoord mcast_receiver_grid = all_cores_bounding_box.grid_size(); + CoreCoord top_left_core = all_cores_bounding_box.start; + CoreCoord bottom_right_core = all_cores_bounding_box.end; + CoreCoord top_left_core_physical = device->worker_core_from_logical_core(top_left_core); + CoreCoord bottom_right_core_physical = device->worker_core_from_logical_core(bottom_right_core); + + // Default reader runtime args + std::vector reader_runtime_args = { + 0, // has_work + src0_buffer->address(), + src1_buffer->address(), + Mt, + Kt, + Nt, + MtKt, + KV_HEADS, + in1_KtNt, + in1_CKtNt_skip, // Skip to get next batch for in1 after reading in0 Kt + in1_CKtNt * TILE_HEIGHT, // in1 stride; skips 32 * KtNt in bshape[0] for one block of MtNt + 0, // blocks of work + 0, // in0_start_id + 0, // in1_start_id; always start at 0 for each block of work and let kernels handle id tracking; for sharded, this isn't used + 0, // kv_heads_addr_offset; l1 offset in bytes to identify which kv_heads each q_heads is mapped to + + // mcast args + (uint32_t) (reader_noc_is_NOC_0 ? top_left_core_physical.x : bottom_right_core_physical.x), // in1_mcast_dest_noc_start_x + (uint32_t) (reader_noc_is_NOC_0 ? 
top_left_core_physical.y : bottom_right_core_physical.y), // in1_mcast_dest_noc_start_y + (uint32_t) (reader_noc_is_NOC_0 ? bottom_right_core_physical.x : top_left_core_physical.x), // in1_mcast_dest_noc_end_x + (uint32_t) (reader_noc_is_NOC_0 ? bottom_right_core_physical.y : top_left_core_physical.y), // in1_mcast_dest_noc_end_y + num_cores - 1, // in1_mcast_num_dests + mcast_num_cores - 1, // in1_mcast_num_cores + in1_mcast_sender_semaphore, + in1_mcast_receiver_semaphore, + KV_HEADS * in1_single_tile_size, // in1_mcast_sender_size_bytes + 0, // in1_mcast_sender_id + (uint32_t) in1_mcast_sender_noc_x.size(), // in1_mcast_sender_num_x + (uint32_t) in1_mcast_sender_noc_y.size(), // in1_mcast_sender_num_y + }; + // TODO: Length of these variables should be static in length since we hard-code 32 mcast sender cores and cache on compute_with_storage_grid_size + reader_runtime_args.insert(reader_runtime_args.end(), in1_mcast_sender_noc_x.begin(), in1_mcast_sender_noc_x.end()); + reader_runtime_args.insert(reader_runtime_args.end(), in1_mcast_sender_noc_y.begin(), in1_mcast_sender_noc_y.end()); + + std::vector cores = grid_to_cores_with_noop( + all_cores_bounding_box.end.x, + all_cores_bounding_box.end.y, + device_compute_with_storage_grid.x, + device_compute_with_storage_grid.y, + row_major + ); + uint32_t g1_numcores = core_group_1.num_cores(); + uint32_t g2_numcores = core_group_2.num_cores(); + + std::vector> all_reader_runtime_args = { cores.size(), reader_runtime_args }; + std::vector> all_writer_runtime_args = { cores.size(), std::vector(writer_runtime_args_size) }; + std::vector> all_compute_runtime_args = { cores.size(), std::vector(compute_runtime_args_size) }; + + // Set runtime args + uint32_t num_output_blocks_per_core; + for (uint32_t i = 0, num_blocks_written = 0; i < num_cores; i++){ + const CoreCoord &core = cores.at(i); + + if (i < g1_numcores) { + num_output_blocks_per_core = num_output_blocks_per_core_group_1; + } else { + num_output_blocks_per_core = num_output_blocks_per_core_group_2; + } + + uint32_t kv_heads_id = i / (Q_HEADS / KV_HEADS); + // Runtime method of turning off kernels/code blocks + // Needed because some cores only have partial readers for reading kv_heads + uint32_t has_work = i < Q_HEADS; + + // Update core dependent runtime args + reader_runtime_args[HAS_WORK_VECTOR_IDX] = has_work; + reader_runtime_args[BLOCKS_VECTOR_IDX] = num_output_blocks_per_core; + reader_runtime_args[IN0_START_ID_VECTOR_IDX] = num_blocks_written * MtKt; + reader_runtime_args[KV_HEADS_ADDR_OFFSET_VECTOR_IDX] = kv_heads_id * in1_single_tile_size; + reader_runtime_args[MCAST_SENDER_ID_VECTOR_IDX] = i; + + // Update runtime_args vectors + all_reader_runtime_args[i] = reader_runtime_args; + all_writer_runtime_args[i] = { + has_work, + dst_buffer->address(), + num_output_blocks_per_core * MtNt, // num_tiles + num_blocks_written * MtNt, // start_id + }; + all_compute_runtime_args[i] = { + has_work, + num_output_blocks_per_core, // B + Mt, // Mt + Kt, // Kt + Nt, // Nt + }; + + num_blocks_written += num_output_blocks_per_core; + } + + SetRuntimeArgs(program, reader_id, cores, all_reader_runtime_args); + SetRuntimeArgs(program, writer_id, cores, all_writer_runtime_args); + SetRuntimeArgs(program, compute_kernel_id, cores, all_compute_runtime_args); + + // Update dynamic CBs + uint32_t cb2_num_input_tiles = KV_HEADS; + UpdateCircularBufferTotalSize(program, cb_src2, cb2_num_input_tiles * in1_single_tile_size); + + if (in0_is_sharded) { + uint32_t cb0_num_input_tiles = 
a.shard_spec().value().numel() / TILE_HW; // Should be full MtKt and C should be 1 + UpdateDynamicCircularBufferAddress(program, cb_src0, *src0_buffer); + UpdateCircularBufferTotalSize(program, cb_src0, cb0_num_input_tiles * in0_single_tile_size); + } + if (in1_is_sharded) { + uint32_t cb3_num_input_tiles = b.shard_spec().value().numel() / TILE_HW; // Should be full CKtNt and batch must be 32 + UpdateDynamicCircularBufferAddress(program, cb_src3, *src1_buffer); + UpdateCircularBufferTotalSize(program, cb_src3, cb3_num_input_tiles * in1_single_tile_size); + } + if (output_is_sharded) { + uint32_t num_output_tiles = output.shard_spec().value().numel() / TILE_HW; // Should be full MtNt and C should be 1 + UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); + UpdateCircularBufferTotalSize(program, cb_output, num_output_tiles * output_single_tile_size); + } + }; + + set_runtime_args(program, a, b, output); + + auto override_runtime_arguments_callback = [ + set_runtime_args + ] + ( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors + ) { + const auto& output_tensor = output_tensors.size() == 1 ? output_tensors.at(0) : input_tensors.at(0); + + set_runtime_args(program, input_tensors.at(0), input_tensors.at(1), output_tensor); + }; + + return {std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; +} + +} // namespace transformers +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp index 537c578465d..cb898492469 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp @@ -4,6 +4,7 @@ #include "tt_dnn/op_library/transformer_tms/transformer_tms.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" +#include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" @@ -53,12 +54,10 @@ std::vector SplitFusedQKVAndSplitHeads::create_output_tensors(const std: // shard spec uint32_t per_core_M_qv = (num_heads / num_cores_y) * M; // 768 uint32_t per_core_N_qv = K; // 64 - ShardSpec shard_spec_qv = ShardSpec{.shard_grid=all_cores, - .shard_shape={per_core_M_qv, per_core_N_qv}, .shard_orientation=ShardOrientation::COL_MAJOR}; + ShardSpec shard_spec_qv = ShardSpec{.grid=all_cores, .shape={per_core_M_qv, per_core_N_qv}, .orientation=ShardOrientation::COL_MAJOR}; uint32_t per_core_M_k = (num_heads / num_cores_y) * K; // 128 uint32_t per_core_N_k = M; // 384 - ShardSpec shard_spec_k = ShardSpec{.shard_grid=all_cores, - .shard_shape={per_core_M_k, per_core_N_k}, .shard_orientation=ShardOrientation::COL_MAJOR}; + ShardSpec shard_spec_k = ShardSpec{.grid=all_cores, .shape={per_core_M_k, per_core_N_k}, .orientation=ShardOrientation::COL_MAJOR}; // create sharded tensors auto mem_config_qv = this->output_mem_config; mem_config_qv.shard_spec = shard_spec_qv; @@ -224,7 +223,6 @@ tt::stl::reflection::Attributes AttnMatmul::attributes() const { }; } - const operation::Hash AttnMatmul::compute_program_hash(const std::vector &input_tensors) const { return operation::hash_operation( this->transpose_hw, @@ -236,6 +234,167 @@ const operation::Hash AttnMatmul::compute_program_hash(const std::vector input_tensors.at(1).dtype()); } + +void GroupAttnMatmul::validate(const std::vector& input_tensors) const { + // input_a: [q_len, 
q_heads, batch, head_dim] + // input_b: [batch, kv_heads, head_dim, kv_len] + // intermediate: [q_heads, batch, batch, kv_len] + // output: [q_len, q_heads, batch, kv_len] + + TT_FATAL(input_tensors.size() == 2); + const auto& input_tensor_a = input_tensors.at(0); + const auto& input_tensor_b = input_tensors.at(1); + TT_FATAL((input_tensor_a.layout() == Layout::TILE && input_tensor_b.layout() == Layout::TILE), "Inputs to matmul must be tilized"); + + // TODO: Uplift to support BFLOAT8_B and mixed precision + TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE and input_tensor_b.storage_type() == StorageType::DEVICE, "Operands to matmul need to be on device!"); + TT_FATAL(input_tensor_a.device() == input_tensor_b.device(), "Operands to matmul need to be on the same device!"); + TT_FATAL(input_tensor_a.buffer() != nullptr and input_tensor_b.buffer() != nullptr, "Operands to matmul need to be allocated in buffers on device!"); + + const auto ashape = input_tensor_a.shape(); + const auto bshape = input_tensor_b.shape(); + TT_FATAL((ashape[0] == 1), "Input q_len must be 1!"); + TT_FATAL((ashape[1] % bshape[1] == 0), "Number of q_heads must be divisible by kv_heads!"); + TT_FATAL((ashape[2] == bshape[0]), "Num of users must match!"); + + const auto num_cores_used = std::max(ashape[1], TILE_HEIGHT); // Need at least 32 cores for mcasting KV heads + TT_FATAL((num_cores_used <= this->compute_with_storage_grid_size.x * this->compute_with_storage_grid_size.y), "Compute grid size is too small for group attention matmul! For now, we require at most 1 q_heads per core."); + + + // Any sharded memory configs must be HEIGHT_SHARDED and have the same orientation + ShardOrientation shard_orientation = row_major ? ShardOrientation::ROW_MAJOR : ShardOrientation::COL_MAJOR; + if (input_tensor_a.is_sharded()) { + TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); + TT_FATAL(input_tensor_a.shard_spec().value().orientation == shard_orientation, "Any sharded memory configs must have the same shard orientation as one another!"); + auto shard_shape = input_tensor_a.shard_spec().value().shape; + TT_FATAL(shard_shape[0] == input_tensor_a.shape()[2]); + TT_FATAL(shard_shape[1] == input_tensor_a.shape()[3]); + } + if (input_tensor_b.is_sharded()) { + TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); + TT_FATAL(input_tensor_b.shard_spec().value().orientation == shard_orientation, "Any sharded memory configs must have the same shard orientation as one another!"); + auto shard_shape = input_tensor_b.shard_spec().value().shape; + TT_FATAL(shard_shape[0] == input_tensor_b.shape()[1] * input_tensor_b.shape()[2]); + TT_FATAL(shard_shape[1] == input_tensor_b.shape()[3]); + TT_FATAL(input_tensor_b.shape()[0] == 32, "Only batch 32 is supported for KV sharded!"); + } + if (this->output_mem_config.is_sharded()) { + TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); + + // If user passes in output_mem_config with shard_spec, assert that it is the same as the one calculated in GroupAttnMatmul::create_output_tensors + if (this->output_mem_config.shard_spec.has_value()) { + const Shape output_shape = this->compute_output_shapes(input_tensors).at(0); + const uint32_t num_cores = output_shape[1]; + CoreRangeSet all_cores = num_cores_to_corerange_set(num_cores, this->compute_with_storage_grid_size, row_major); + + auto shard_shape = this->output_mem_config.shard_spec.value().shape; + 
TT_FATAL(this->output_mem_config.shard_spec.value().grid == all_cores, "Shard spec in output mem config must match shard spec calculated in GroupAttnMatmul::create_output_tensors!"); + TT_FATAL(this->output_mem_config.shard_spec.value().orientation == shard_orientation, "Any sharded memory configs must have the same shard orientation as one another!"); + TT_FATAL(shard_shape[0] == output_shape[2]); + TT_FATAL(shard_shape[1] == output_shape[3]); + } + } + + bool read_from_kv_cache = false; + if (this->num_tokens.has_value() or this->transpose_hw.has_value()) { + TT_FATAL((this->num_tokens.has_value() and this->transpose_hw.has_value()), "Must provide num_tokens and transpose_hw flag if we are reading from cache for in1!"); + TT_FATAL(this->num_tokens.value() % 32 == 0, "Number of tokens must be divisble by 32!"); + read_from_kv_cache = true; + } + + if (read_from_kv_cache) { + if (this->transpose_hw.value()) { + TT_FATAL(ashape[3] == bshape[3] && "For pre-attention matmul, dimension K for B is in B.shape[3], so A.shape[3] must match B.shape[3]"); // A.K == B.K + } else { + TT_FATAL(ashape[3] == this->num_tokens && "For post-attention matmul, dimension K (A.shape[3]) is the kv_seq_len in this case and must match the length of the cache we read"); // A.K == B.K + } + } else { + TT_FATAL(ashape[3] == bshape[2] && "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in attn_matmul op"); // A.K == B.K + } +} + +std::vector GroupAttnMatmul::compute_output_shapes(const std::vector& input_tensors) const { + // input_a: [q_len, q_heads, batch, head_dim] + // input_b: [batch, kv_heads, head_dim, kv_len] + // intermediate: [q_heads, batch, batch, kv_len] + // output: [q_len, q_heads, batch, kv_len] + const auto& input_tensor_a = input_tensors.at(0); + const auto& input_tensor_b = input_tensors.at(1); + const auto ashape = input_tensor_a.shape(); + const auto bshape = input_tensor_b.shape(); + + uint32_t N = bshape[3]; + if (this->transpose_hw.value_or(false)) { + N = this->num_tokens.value(); + } + + return {Shape{1, ashape[1], ashape[2], N}}; +} + +std::vector GroupAttnMatmul::create_output_tensors(const std::vector& input_tensors) const { + const auto& input_tensor_a = input_tensors.at(0); + const auto& input_tensor_b = input_tensors.at(1); + if (this->output_mem_config.is_sharded()) { + auto output_mem_config = this->output_mem_config; + if (this->output_mem_config.shard_spec.has_value()) { + output_mem_config.shard_spec = this->output_mem_config.shard_spec.value(); + } else { + const Shape output_shape = this->compute_output_shapes(input_tensors).at(0); + const uint32_t num_cores = output_shape[1]; + CoreRangeSet all_cores = num_cores_to_corerange_set(num_cores, this->compute_with_storage_grid_size, row_major); + + ShardOrientation shard_orientation = row_major ? 
+            ShardSpec shard_spec = ShardSpec{.grid=all_cores, .shape={output_shape[2], output_shape[3]}, .orientation=shard_orientation};
+            output_mem_config.shard_spec = shard_spec;
+        }
+        return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), this->output_dtype, Layout::TILE, input_tensor_a.device(), output_mem_config)};
+    } else {
+        return operation::generic_create_output_tensors(*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
+    }
+}
+
+operation::ProgramWithCallbacks GroupAttnMatmul::create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const {
+    const auto& input_tensor_a = input_tensors.at(0);
+    const auto& input_tensor_b = input_tensors.at(1);
+    auto& output_tensor = output_tensors.at(0);
+
+    auto device_compute_with_storage_grid_size = input_tensor_a.device()->compute_with_storage_grid_size();
+    TT_ASSERT((this->compute_with_storage_grid_size.x <= device_compute_with_storage_grid_size.x && this->compute_with_storage_grid_size.y <= device_compute_with_storage_grid_size.y), "Unsupported grid shape");
+
+    return multi_core_group_attn_matmul(input_tensor_a, input_tensor_b, output_tensor, this->num_tokens, this->transpose_hw, this->compute_with_storage_grid_size, this->row_major);
+}
+
+tt::stl::reflection::Attributes GroupAttnMatmul::attributes() const {
+    return {
+        {"transpose_hw", this->transpose_hw},
+        {"compute_with_storage_grid_size", this->compute_with_storage_grid_size.str()},
+        {"output_mem_config", this->output_mem_config},
+        {"output_dtype", this->output_dtype},
+        {"row_major", this->row_major},
+    };
+}
+
+
+const operation::Hash GroupAttnMatmul::compute_program_hash(const std::vector<Tensor> &input_tensors) const {
+    const auto& input_tensor_a = input_tensors.at(0);
+    const auto& input_tensor_b = input_tensors.at(1);
+
+    return operation::hash_operation<GroupAttnMatmul>(
+        this->transpose_hw,
+        this->compute_with_storage_grid_size.str(),
+        this->output_mem_config.memory_layout,
+        this->output_mem_config.buffer_type,
+        this->output_dtype,
+        this->row_major,
+        input_tensor_a.memory_config().memory_layout,
+        input_tensor_a.memory_config().buffer_type,
+        input_tensor_a.dtype(),
+        input_tensor_b.memory_config().memory_layout,
+        input_tensor_b.memory_config().buffer_type,
+        input_tensor_b.dtype()
+    );
+}
+
 } // namespace transformers
 } // namespace primary
 } // namespace operations
diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp
index 4e092f765db..33620cb15cc 100644
--- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp
@@ -21,7 +21,9 @@ namespace transformers {
 operation::ProgramWithCallbacks multi_core_split_query_key_value_and_split_heads(const Tensor &input_tensor, std::vector<Tensor> &output, CoreCoord compute_with_storage_grid_size);
 operation::ProgramWithCallbacks multi_core_split_query_key_value_and_split_heads_sharded(const Tensor &input_tensor, std::vector<Tensor> &output, CoreCoord compute_with_storage_grid_size);
 operation::ProgramWithCallbacks multi_core_concat_heads(const Tensor &input_tensor, Tensor &output_tensor, CoreCoord compute_with_storage_grid_size);
+// TODO: Group attention matmul will support sharding, mcasting, and should be faster; we should make attn_matmul (ie. KV heads = 1) a special case of group_attn_matmul and run the same op
 operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Tensor &output_tensor, std::optional<const uint32_t> num_tokens, std::optional<const bool> transpose_hw, CoreCoord compute_with_storage_grid_size);
+operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Tensor &output_tensor, std::optional<const uint32_t> num_tokens, std::optional<const bool> transpose_hw, CoreCoord compute_with_storage_grid_size, const bool row_major);

 struct SplitFusedQKVAndSplitHeads {
     CoreCoord compute_with_storage_grid_size;
@@ -81,6 +83,39 @@ inline Tensor attn_matmul_from_cache(const Tensor &input_tensor_a, const Tensor
     return operation::run(AttnMatmul{num_tokens_rounded_up_to_32, transpose_hw, compute_with_storage_grid_size, mem_config, output_dtype.value_or(input_tensor_a.dtype())}, {input_tensor_a, input_tensor_b}).at(0);
 }

+// TODO: Should we support option to read directly from cache (with optional transpose_hw)?
+struct GroupAttnMatmul {
+    std::optional<const uint32_t> num_tokens;
+    std::optional<const bool> transpose_hw;
+    CoreCoord compute_with_storage_grid_size;
+    MemoryConfig output_mem_config;
+    DataType output_dtype;
+    const bool row_major; // Specifies how work is distributed across cores
+
+    void validate(const std::vector<Tensor>& input_tensors) const;
+    std::vector<Shape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
+    tt::stl::reflection::Attributes attributes() const;
+    const operation::Hash compute_program_hash(const std::vector<Tensor> &input_tensors) const;
+};
+
+inline Tensor group_attn_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, const CoreCoord& compute_with_storage_grid_size, const MemoryConfig& mem_config, std::optional<const DataType> output_dtype=std::nullopt) {
+    bool row_major = false;
+    // GroupAttnMatmul::validate will check that any sharded memory configs have same orientation
+    if (input_tensor_a.is_sharded()) {
+        row_major = input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR;
+    } else if (input_tensor_b.is_sharded()) {
+        row_major = input_tensor_b.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR;
+    } else if (mem_config.is_sharded()) {
+        if (mem_config.shard_spec.has_value()) {
+            row_major = mem_config.shard_spec.value().orientation == ShardOrientation::ROW_MAJOR;
+        }
+    }
+
+    return operation::run(GroupAttnMatmul{std::nullopt, std::nullopt, compute_with_storage_grid_size, mem_config, output_dtype.value_or(input_tensor_a.dtype()), row_major}, {input_tensor_a, input_tensor_b}).at(0);
+}
+
 } // namespace transformers
 } // namespace primary
diff --git a/tt_eager/tt_lib/csrc/operations/primary/transformers/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/transformers/module.hpp
index 2e04335239b..f3be5a30b29 100644
--- a/tt_eager/tt_lib/csrc/operations/primary/transformers/module.hpp
+++ b/tt_eager/tt_lib/csrc/operations/primary/transformers/module.hpp
@@ -41,6 +41,12 @@ void py_module(py::module& m_transformers) {
         py::arg().noconvert(), py::arg().noconvert(), py::arg("num_tokens").noconvert(), py::arg("transpose_hw").noconvert(), py::arg("compute_with_storage_grid_size").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_dtype").noconvert() = std::nullopt, R"doc(
         Performs the same matmul as attn_matmul, but fuses additional functionality for reading in in1. For in1, read num_tokens (rounded up to 32) from full cache along in1.shape()[2] (num_tokens must be > 0 and <= max_cache_len). For example, 64 tokens will be read for 32 < token_idx <= 64. Additional option to apply transpose_hw to in1 for pre-attention matmul with transpose_hw=true. For post-attention matmul, transpose_hw should be false.
     )doc");
+    m_transformers.def("group_attn_matmul", &group_attn_matmul,
+        py::arg().noconvert(), py::arg().noconvert(), py::arg("compute_with_storage_grid_size").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_dtype").noconvert() = std::nullopt, R"doc(
+        Performs a special pre-softmax matmul with [q_len, q_heads, batch, head_dim] and [batch, kv_heads, head_dim, kv_len]. q_len must be 1 and q_heads must be divisible by kv_heads. If kv_heads is sharded, then batch must be 32; otherwise, batch can be any multiple of 32. An intermediate value of [q_heads, batch, batch, kv_len] is produced (only on device cores). Batch dim from Z and Y is combined by taking the 1st, 2nd, ..., and 32nd row of Y from the batches in Z. Final output tensor is [1, q_heads, batch, kv_len]. In PyTorch, this is equivalent to:
+        B = torch.repeat_interleave(B, q_heads // kv_heads, dim=1)
+        torch.matmul(A.transpose(0, 2), B).transpose(0, 2). Similar concept for post-softmax matmul.
+    )doc");

     py::class_<SoftmaxDefaultProgramConfig>(m_transformers, "SoftmaxDefaultProgramConfig")
         .def(py::init<>());
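
For reference, the PyTorch equivalence stated in the group_attn_matmul docstring can be checked host-side with a minimal sketch like the one below. The shape values are illustrative assumptions, and A/B are plain torch tensors rather than device tensors:

    import torch

    # Illustrative shapes: q_len=1, q_heads=16, kv_heads=2, batch=32, head_dim=64, kv_len=96
    q_len, q_heads, kv_heads, batch, head_dim, kv_len = 1, 16, 2, 32, 64, 96

    A = torch.randn(q_len, q_heads, batch, head_dim)    # [q_len, q_heads, batch, head_dim]
    B = torch.randn(batch, kv_heads, head_dim, kv_len)  # [batch, kv_heads, head_dim, kv_len]

    # Broadcast each kv_head across its group of q_heads, then batch the matmul over (batch, q_heads)
    B_expanded = torch.repeat_interleave(B, q_heads // kv_heads, dim=1)  # [batch, q_heads, head_dim, kv_len]
    out = torch.matmul(A.transpose(0, 2), B_expanded).transpose(0, 2)    # [q_len, q_heads, batch, kv_len]

    assert out.shape == (q_len, q_heads, batch, kv_len)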
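The shard-shape checks in GroupAttnMatmul::validate can also be summarized as a small helper. This is an illustrative sketch only (the function name and return layout are made up here); it mirrors the TT_FATAL conditions for HEIGHT_SHARDED in0, in1, and output, with shapes given in elements before tilization:

    def expected_height_shard_shapes(a_shape, b_shape):
        # Mirrors the TT_FATAL shard checks in GroupAttnMatmul::validate (illustrative sketch only).
        q_len, q_heads, batch, head_dim = a_shape
        b_batch, kv_heads, b_head_dim, kv_len = b_shape
        assert q_len == 1 and q_heads % kv_heads == 0 and batch == b_batch and head_dim == b_head_dim
        return {
            "in0": (batch, head_dim),              # one q_head per core: shard is [batch, head_dim]
            "in1": (kv_heads * head_dim, kv_len),  # one user (batch) per core: all kv_heads stacked
            "output": (batch, kv_len),             # one q_head per core: shard is [batch, kv_len]
        }

    print(expected_height_shard_shapes([1, 16, 32, 64], [32, 2, 64, 96]))
    # -> {'in0': (32, 64), 'in1': (128, 96), 'output': (32, 96)}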
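GroupAttnMatmul::compute_program_hash folds only configuration, memory layouts, buffer types, and dtypes into the program-cache key, so runs that differ only in input shape can reuse a compiled program. A rough Python stand-in for that key is sketched below; the dict fields are assumptions for illustration, not the real operation::hash_operation API:

    def group_attn_matmul_cache_key(op, in0, in1):
        # Python stand-in for the fields hashed in GroupAttnMatmul::compute_program_hash:
        # configuration, layouts, and dtypes only -- tensor shapes are deliberately absent.
        return hash((
            op["transpose_hw"],
            op["compute_with_storage_grid_size"],
            op["output_mem_layout"], op["output_buffer_type"], op["output_dtype"],
            op["row_major"],
            in0["mem_layout"], in0["buffer_type"], in0["dtype"],
            in1["mem_layout"], in1["buffer_type"], in1["dtype"],
        ))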