From a2d4f5749216f928d1e675658bcb8bc2306c2022 Mon Sep 17 00:00:00 2001 From: yugaoTT Date: Fri, 1 Dec 2023 21:24:22 +0000 Subject: [PATCH 1/5] #0: rebase to main --- .../layernorm/kernels/compute/layernorm.cpp | 144 ++++++++++-------- .../op_library/layernorm/layernorm_op.cpp | 22 +-- .../softmax/kernels/compute/softmax.cpp | 4 +- 3 files changed, 92 insertions(+), 78 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp index d55d4c4c178..66733778886 100644 --- a/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp +++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp @@ -16,6 +16,7 @@ #include "compute_kernel_api/layernorm.h" #include "compute_kernel_api/tile_move_copy.h" + // SPLIT REDUCE across Cores namespace NAMESPACE { void MAIN { @@ -29,6 +30,9 @@ void MAIN { constexpr uint32_t subblock_w = get_compile_time_arg_val(6); constexpr uint32_t num_subblocks_w = get_compile_time_arg_val(7); const bool is_allgather_worker = get_compile_time_arg_val(8) == 1; + constexpr uint32_t num_tiles_per_allgather_worker = get_compile_time_arg_val(9); + constexpr uint32_t num_tiles_per_block = get_compile_time_arg_val(10); + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0, tt::CB::c_intermed0); @@ -51,9 +55,9 @@ void MAIN { constexpr uint32_t cb_ex2 = tt::CB::dataflow4; // E[(x-E[x])^2] global reduce constexpr uint32_t cb_ex_external2 = tt::CB::dataflow5; constexpr uint32_t cb_ex_global = tt::CB::dataflow7; // E[x] global reduce - constexpr uint32_t cb_xmm2 = tt::CB::c_intermed2; // xmm^2 + constexpr uint32_t cb_xmm2 = cb_x; // xmm^2 constexpr uint32_t cb_ex2pe = tt::CB::c_intermed3; // E[(x-E[x])^2]+eps - constexpr uint32_t cb_fusion = tt::CB::c_intermed4; // stream gamma/beta + constexpr uint32_t cb_fusion = cb_xmm; // stream gamma/beta constexpr uint32_t cb_out = tt::CB::c_out0; int index_subblock_w_offset = 0; @@ -65,15 +69,15 @@ void MAIN { #else constexpr int cb_in = cb_in0; #endif - constexpr int cb_im = (do_gamma | do_beta) ? cb_fusion : cb_out; + constexpr int cb_im = (do_gamma | do_beta) ? cb_x : cb_out; constexpr int cb_outgamma = do_beta ? cb_fusion : cb_out; + // pre-add x + y + #ifdef FUSE_PRE_ADD + unpack_reconfig_data_format(tt::CB::c_in0, tt::CB::c_in0); + pack_reconfig_data_format(tt::CB::c_intermed0); + add_tiles_init(); for (uint32_t i = 0; i < block_h; i++) { - // pre-add x + y - #ifdef FUSE_PRE_ADD - unpack_reconfig_data_format(tt::CB::c_in0, tt::CB::c_in0); - pack_reconfig_data_format(tt::CB::c_intermed0); - add_tiles_init(); index_subblock_w_offset = 0; for (uint32_t j = 0; j < num_subblocks_w; j++) { tile_regs_acquire(); @@ -91,14 +95,20 @@ void MAIN { cb_push_back(cb_in, subblock_w); index_subblock_w_offset += subblock_w; } - cb_wait_front(cb_in, block_w+index_h_offset); - unpack_reconfig_data_format(tt::CB::c_intermed0, tt::CB::c_intermed0); - #endif + index_h_offset += block_w; + } + unpack_reconfig_data_format(tt::CB::c_intermed0, tt::CB::c_intermed0); + cb_wait_front(cb_in, num_tiles_per_block); - // E[x], - reduce_init_delta(REDUCE_OP, REDUCE_DIM); - cb_wait_front(cb_scaler, 1); - cb_reserve_back(cb_ex_partial, 1); + // UNPACK(( DPRINT << TSLICE(cb_in, 0, SliceRange::h0_w0_32()) << ENDL() )); + #endif + + // E[x], + index_h_offset = 0; + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + cb_wait_front(cb_scaler, 1); + cb_reserve_back(cb_ex_partial, block_h); + for (uint32_t i = 0; i < block_h; i++) { tile_regs_acquire(); for (uint32_t w = 0; w < block_w; w++) { reduce_tile(REDUCE_OP, REDUCE_DIM, cb_in, cb_scaler, w+index_h_offset, scaler0, dst0); @@ -107,10 +117,10 @@ void MAIN { tile_regs_wait(); pack_tile(dst0, cb_ex_partial); tile_regs_release(); - reduce_revert_delta(); - cb_push_back(cb_ex_partial, 1); index_h_offset += block_w; } + reduce_revert_delta(); + cb_push_back(cb_ex_partial, block_h); // global reduce, cb_ex <-- cb_ex_external, cb_ex_partial if constexpr(is_allgather_worker) { @@ -138,10 +148,10 @@ void MAIN { } cb_wait_front(cb_ex_global, block_h); + // x - E[x] index_h_offset = 0; + sub_bcast_cols_init_short(); for (uint32_t i = 0; i < block_h; i++) { - // x - E[x] - sub_bcast_cols_init_short(); index_subblock_w_offset = 0; for (uint32_t j = 0; j < num_subblocks_w; j++) { tile_regs_acquire(); @@ -161,10 +171,12 @@ void MAIN { } cb_pop_front(cb_ex_global, 1); cb_pop_front(cb_in, block_w); - cb_wait_front(cb_xmm, block_w+index_h_offset); + } + cb_wait_front(cb_xmm, num_tiles_per_block); - // (x - E[x])^2, cb_mm2 <-- cb_xmm - mul_tiles_init(); + // (x - E[x])^2, cb_mm2 <-- cb_xmm + mul_tiles_init(); + for (uint32_t i = 0; i < block_h; i++) { index_subblock_w_offset = 0; for (uint32_t j = 0; j < num_subblocks_w; j++) { tile_regs_acquire(); @@ -182,25 +194,27 @@ void MAIN { cb_push_back(cb_xmm2, subblock_w); index_subblock_w_offset += subblock_w; } + } + cb_wait_front(cb_xmm2, num_tiles_per_block); - // Var(x) - cb_reserve_back(cb_ex_partial2, 1); - reduce_init_delta(REDUCE_OP, REDUCE_DIM); + // Var(x) + cb_reserve_back(cb_ex_partial2, block_h); + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + index_h_offset = 0; + for (uint32_t i = 0; i < block_h; i++) { tile_regs_acquire(); - cb_wait_front(cb_xmm2, block_w); for (uint32_t w = 0; w < block_w; w++) { - reduce_tile(REDUCE_OP, REDUCE_DIM, cb_xmm2, cb_scaler, w, scaler0, dst0); + reduce_tile(REDUCE_OP, REDUCE_DIM, cb_xmm2, cb_scaler, w+index_h_offset, scaler0, dst0); } tile_regs_commit(); tile_regs_wait(); pack_tile(dst0, cb_ex_partial2); tile_regs_release(); - reduce_revert_delta(); - cb_push_back(cb_ex_partial2, 1); - cb_pop_front(cb_xmm2, block_w); - cb_wait_front(cb_ex_partial2, 1); index_h_offset += block_w; } + reduce_revert_delta(); + cb_pop_front(cb_xmm2, num_tiles_per_block); + cb_push_back(cb_ex_partial2, block_h); // global reduce, cb_ex <-- cb_ex_external, cb_ex_partial if constexpr(is_allgather_worker) { @@ -252,22 +266,20 @@ void MAIN { } cb_wait_front(cb_ex_global, block_h); + + if constexpr(do_gamma == 0 && do_beta == 0) { + pack_reconfig_data_format(cb_out); + } + // (x - Ex) * 1/[sqrt(Var + eps)] + mul_bcast_cols_init_short(); index_h_offset = 0; for (uint32_t i = 0; i < block_h; i++) { - // (x - Ex) * 1/[sqrt(Var + eps)] - if constexpr(do_gamma == 0 && do_beta == 0) { - pack_reconfig_data_format(cb_out); - } else { - pack_reconfig_data_format(tt::CB::c_intermed0); - } - cb_wait_front(cb_xmm, block_w); - mul_bcast_cols_init_short(); index_subblock_w_offset = 0; for (uint32_t j = 0; j < num_subblocks_w; j++) { tile_regs_acquire(); for (uint32_t w = 0; w < subblock_w; w++) { - index = w + index_subblock_w_offset; - mul_tiles_bcast_cols(cb_xmm, cb_ex_global, index, 0, w); + index = w + index_subblock_w_offset + index_h_offset; + mul_tiles_bcast_cols(cb_xmm, cb_ex_global, index, i, w); } tile_regs_commit(); cb_reserve_back(cb_im, subblock_w); @@ -279,23 +291,27 @@ void MAIN { cb_push_back(cb_im, subblock_w); index_subblock_w_offset += subblock_w; } - cb_pop_front(cb_ex_global, 1); - cb_pop_front(cb_xmm, block_w); - cb_wait_front(cb_im, block_w); - if constexpr(do_gamma) { - if constexpr(do_beta == 0) { - pack_reconfig_data_format(cb_out); - } - cb_wait_front(cb_im, block_w); - cb_wait_front(cb_gamma, block_w); - mul_bcast_rows_init_short(); + index_h_offset += block_w; + } + cb_pop_front(cb_ex_global, block_h); + cb_pop_front(cb_xmm, num_tiles_per_block); + cb_wait_front(cb_im, num_tiles_per_block); + + if constexpr(do_gamma) { + if constexpr(do_beta == 0) { + pack_reconfig_data_format(cb_out); + } + mul_bcast_rows_init_short(); + cb_wait_front(cb_gamma, block_w); + index_h_offset = 0; + for (uint32_t i = 0; i < block_h; i++) { index_subblock_w_offset = 0; for (uint32_t j = 0; j < num_subblocks_w; j++) { tile_regs_acquire(); for (uint32_t w = 0; w < subblock_w; w++) { index = w + index_subblock_w_offset; - mul_tiles_bcast_rows(cb_im, cb_gamma, index, index, w); + mul_tiles_bcast_rows(cb_im, cb_gamma, index+index_h_offset, index, w); } tile_regs_commit(); cb_reserve_back(cb_outgamma, subblock_w); @@ -307,20 +323,24 @@ void MAIN { cb_push_back(cb_outgamma, subblock_w); index_subblock_w_offset += subblock_w; } - cb_pop_front(cb_im, block_w); + index_h_offset += block_w; } + cb_pop_front(cb_im, num_tiles_per_block); + cb_wait_front(cb_outgamma, num_tiles_per_block); + } - if constexpr(do_beta) { - pack_reconfig_data_format(cb_out); - cb_wait_front(cb_beta, block_w); - cb_wait_front(cb_fusion, block_w); - add_bcast_rows_init_short(); + if constexpr(do_beta) { + pack_reconfig_data_format(cb_out); + add_bcast_rows_init_short(); + cb_wait_front(cb_beta, block_w); + index_h_offset = 0; + for (uint32_t i = 0; i < block_h; i++) { index_subblock_w_offset = 0; for (uint32_t j = 0; j < num_subblocks_w; j++) { tile_regs_acquire(); for (uint32_t w = 0; w < subblock_w; w++) { index = w + index_subblock_w_offset; - add_tiles_bcast_rows(cb_fusion, cb_beta, index, index, w); + add_tiles_bcast_rows(cb_fusion, cb_beta, index + index_h_offset, index, w); } tile_regs_commit(); cb_reserve_back(cb_out, subblock_w); @@ -332,10 +352,10 @@ void MAIN { cb_push_back(cb_out, subblock_w); index_subblock_w_offset += subblock_w; } - cb_pop_front(cb_fusion, block_w); - cb_wait_front(cb_out, block_w+index_h_offset); + index_h_offset += block_w; } - index_h_offset += block_w; + cb_pop_front(cb_fusion, num_tiles_per_block); + cb_wait_front(cb_out, num_tiles_per_block); } } diff --git a/tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp b/tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp index 84b27e9e734..50ca928e2db 100644 --- a/tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp +++ b/tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp @@ -770,6 +770,8 @@ operation::ProgramWithCallbacks layernorm_sharded_( subblock_wt, num_subblocks_w, 1, + num_rows_per_all_to_all_worker, + block_ht * block_wt }; std::vector all_to_all_except_top_compute_compile_time_args = { 0, @@ -780,7 +782,9 @@ operation::ProgramWithCallbacks layernorm_sharded_( block_wt, subblock_wt, num_subblocks_w, - 1 + 1, + num_rows_per_all_to_all_worker, + block_ht * block_wt }; std::vector not_all_to_all_compute_compile_time_args = { 0, @@ -791,7 +795,9 @@ operation::ProgramWithCallbacks layernorm_sharded_( block_wt, subblock_wt, num_subblocks_w, - 0 + 0, + num_rows_per_all_to_all_worker, + block_ht * block_wt }; // compute kernel bool fp32_dest_acc_en = false; @@ -904,24 +910,12 @@ operation::ProgramWithCallbacks layernorm_sharded_( tt_metal::CircularBufferConfig ex_global_cb_config = tt_metal::CircularBufferConfig(ex_global_CB_size, {{ex_global_cb_index, cb_data_format}}) .set_page_size(ex_global_cb_index, single_tile_size); auto cb_ex_global = tt_metal::CreateCircularBuffer(program, all_cores, ex_global_cb_config); - // xmm2 - uint32_t xmm2_cb_index; - xmm2_cb_index = CB::c_intermed2; - tt_metal::CircularBufferConfig xmm2_cb_config = tt_metal::CircularBufferConfig(xmm2_CB_size, {{xmm2_cb_index, cb_data_format}}) - .set_page_size(xmm2_cb_index, single_tile_size); - auto cb_xmm2 = tt_metal::CreateCircularBuffer(program, all_cores, xmm2_cb_config); // ex2pe uint32_t cb_ex2pe_index; cb_ex2pe_index = CB::c_intermed3; tt_metal::CircularBufferConfig ex2pe_cb_config = tt_metal::CircularBufferConfig(ex2pe_CB_size, {{cb_ex2pe_index, cb_data_format}}) .set_page_size(cb_ex2pe_index, single_tile_size); auto cb_ex2pe = tt_metal::CreateCircularBuffer(program, all_cores, ex2pe_cb_config); - // fusion - uint32_t cb_fusion_index; - cb_fusion_index = CB::c_intermed4; - tt_metal::CircularBufferConfig fusion_cb_config = tt_metal::CircularBufferConfig(fusion_CB_size, {{cb_fusion_index, cb_data_format}}) - .set_page_size(cb_fusion_index, single_tile_size); - auto cb_fusion = tt_metal::CreateCircularBuffer(program, all_cores, fusion_cb_config); // out uint32_t output_cb_index = CB::c_out0; // output operands start at index 16 tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, {{output_cb_index, out_data_format}}) diff --git a/tt_eager/tt_dnn/op_library/softmax/kernels/compute/softmax.cpp b/tt_eager/tt_dnn/op_library/softmax/kernels/compute/softmax.cpp index 3a8cd14ec9f..74a191cdf40 100644 --- a/tt_eager/tt_dnn/op_library/softmax/kernels/compute/softmax.cpp +++ b/tt_eager/tt_dnn/op_library/softmax/kernels/compute/softmax.cpp @@ -64,7 +64,7 @@ void MAIN { unpack_reconfig_data_format(cb_scale_mask, cb_fused_attn); // fused attn - cb_wait_front(cb_scale_mask, block_w); + // cb_wait_front(cb_scale_mask, block_w); cb_wait_front(cb_fused_attn, block_w); index_subblock_w_offset = 0; for (uint32_t j = 0; j < num_subblocks_w; j++) { @@ -111,7 +111,7 @@ void MAIN { // sum(exp(x)) ACQ(); reduce_init_delta(REDUCE_OP, REDUCE_DIM); - cb_wait_front(cb_exps, block_w); + // cb_wait_front(cb_exps, block_w); cb_wait_front(cb_bcast_scaler, 1); cb_reserve_back(cb_recipsumexps, 1); for (uint32_t w = 0; w < block_w; w++) { From 66657dd35e129be3bc69ee73f3a614b69997a86e Mon Sep 17 00:00:00 2001 From: yugaoTT Date: Fri, 1 Dec 2023 21:07:31 +0000 Subject: [PATCH 2/5] #3629: fix bug in LN, bert pcc back to normal --- .../op_library/layernorm/kernels/compute/layernorm.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp index 66733778886..7118f988c56 100644 --- a/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp +++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp @@ -99,8 +99,6 @@ void MAIN { } unpack_reconfig_data_format(tt::CB::c_intermed0, tt::CB::c_intermed0); cb_wait_front(cb_in, num_tiles_per_block); - - // UNPACK(( DPRINT << TSLICE(cb_in, 0, SliceRange::h0_w0_32()) << ENDL() )); #endif // E[x], @@ -176,6 +174,7 @@ void MAIN { // (x - E[x])^2, cb_mm2 <-- cb_xmm mul_tiles_init(); + index_h_offset = 0; for (uint32_t i = 0; i < block_h; i++) { index_subblock_w_offset = 0; for (uint32_t j = 0; j < num_subblocks_w; j++) { @@ -194,6 +193,7 @@ void MAIN { cb_push_back(cb_xmm2, subblock_w); index_subblock_w_offset += subblock_w; } + index_h_offset += block_w; } cb_wait_front(cb_xmm2, num_tiles_per_block); @@ -260,7 +260,7 @@ void MAIN { pack_tile(dst0, cb_ex2pe); cb_push_back(cb_ex2pe, 1); tile_regs_release(); - cb_wait_front(cb_ex2pe, 1); + cb_wait_front(cb_ex2pe, 1+i); } } From bfb5b4e330d46afa2fb9c1b3d542de88733e2a4e Mon Sep 17 00:00:00 2001 From: yugaoTT Date: Fri, 1 Dec 2023 21:11:35 +0000 Subject: [PATCH 3/5] #3629: remove redundant wait_front --- .../tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp b/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp index 7118f988c56..7e0dbd7b387 100644 --- a/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp +++ b/tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp @@ -260,7 +260,6 @@ void MAIN { pack_tile(dst0, cb_ex2pe); cb_push_back(cb_ex2pe, 1); tile_regs_release(); - cb_wait_front(cb_ex2pe, 1+i); } } From aa374ea5e900d69c2e892a5ead50a0f47d05d4e9 Mon Sep 17 00:00:00 2001 From: yugaoTT Date: Fri, 1 Dec 2023 21:17:29 +0000 Subject: [PATCH 4/5] #0: remove extra wait_front --- .../op_library/transformer_tms/compute/transpose_wh_sharded.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/compute/transpose_wh_sharded.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/compute/transpose_wh_sharded.cpp index e3fb92f1eda..4436c1b6a5a 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/compute/transpose_wh_sharded.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/compute/transpose_wh_sharded.cpp @@ -31,7 +31,6 @@ void MAIN { cb_push_back(cb_out1, 1); cb_pop_front(cb_im0, 1); - cb_wait_front(cb_out1, 1); } From cbe97c2feed87d9637a40b9a600b58d07bc3067f Mon Sep 17 00:00:00 2001 From: yugaoTT Date: Fri, 1 Dec 2023 22:46:00 +0000 Subject: [PATCH 5/5] #0: set target fps to 330 --- models/demos/metal_BERT_large_11/tests/test_perf_bert11.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/demos/metal_BERT_large_11/tests/test_perf_bert11.py b/models/demos/metal_BERT_large_11/tests/test_perf_bert11.py index a2fb59ddcb8..d315f7bde69 100644 --- a/models/demos/metal_BERT_large_11/tests/test_perf_bert11.py +++ b/models/demos/metal_BERT_large_11/tests/test_perf_bert11.py @@ -171,7 +171,7 @@ def test_perf_virtual_machine( @pytest.mark.models_performance_bare_metal @pytest.mark.parametrize( "expected_inference_time, expected_compile_time, inference_iterations", - ([0.0375, 10, 10],), + ([0.0364, 10, 10],), ) def test_perf_bare_metal( use_program_cache,