Opt LN_sharded and SMX_sharded #4147

Merged: 5 commits, Dec 7, 2023
@@ -171,7 +171,7 @@ def test_perf_virtual_machine(
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
"expected_inference_time, expected_compile_time, inference_iterations",
([0.0375, 10, 10],),
([0.0364, 10, 10],),
Contributor
To tighten the inference-time tolerance, please run the model 3 times on CI, average the inference times, and allow a 15% tolerance.
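A minimal sketch of that calibration, with placeholder run times chosen to land near the 0.0364 used above (not real CI measurements):

```python
# Hypothetical calibration of expected_inference_time from three CI runs.
# The measured times below are placeholders, not actual CI numbers.
measured_times = [0.0312, 0.0315, 0.0322]  # seconds per run

mean_time = sum(measured_times) / len(measured_times)
expected_inference_time = mean_time * 1.15  # mean plus a 15% tolerance

print(f"expected_inference_time = {expected_inference_time:.4f}")  # -> 0.0364
```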

Contributor Author
Thanks for the advice. I ran it 3 times and it passed with ~338 fps.
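For context, throughput and per-iteration latency are related through the batch size; a rough sanity check under an assumed batch size (the actual value is not stated in this thread):

```python
# Hypothetical sanity check relating the reported ~338 fps to inference time.
# batch_size is an assumption; the real value is not given in this thread.
batch_size = 12
fps = 338
inference_time = batch_size / fps  # ~0.0355 s per iteration under these assumptions

print(f"inference_time ~ {inference_time:.4f} s")  # below the 0.0364 bound above
```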

Contributor Author
Hey @farbabi, if there is nothing else that needs changes or testing, could I get an approval? Thanks!

)
def test_perf_bare_metal(
use_program_cache,
tt_eager/tt_dnn/op_library/layernorm/kernels/compute/layernorm.cpp (82 additions, 63 deletions)
@@ -16,6 +16,7 @@
#include "compute_kernel_api/layernorm.h"
#include "compute_kernel_api/tile_move_copy.h"


// SPLIT REDUCE across Cores
namespace NAMESPACE {
void MAIN {
@@ -29,6 +30,9 @@ void MAIN {
constexpr uint32_t subblock_w = get_compile_time_arg_val(6);
constexpr uint32_t num_subblocks_w = get_compile_time_arg_val(7);
const bool is_allgather_worker = get_compile_time_arg_val(8) == 1;
constexpr uint32_t num_tiles_per_allgather_worker = get_compile_time_arg_val(9);
constexpr uint32_t num_tiles_per_block = get_compile_time_arg_val(10);


binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0, tt::CB::c_intermed0);

@@ -51,9 +55,9 @@ void MAIN {
constexpr uint32_t cb_ex2 = tt::CB::dataflow4; // E[(x-E[x])^2] global reduce
constexpr uint32_t cb_ex_external2 = tt::CB::dataflow5;
constexpr uint32_t cb_ex_global = tt::CB::dataflow7; // E[x] global reduce
constexpr uint32_t cb_xmm2 = tt::CB::c_intermed2; // xmm^2
constexpr uint32_t cb_xmm2 = cb_x; // xmm^2
constexpr uint32_t cb_ex2pe = tt::CB::c_intermed3; // E[(x-E[x])^2]+eps
constexpr uint32_t cb_fusion = tt::CB::c_intermed4; // stream gamma/beta
constexpr uint32_t cb_fusion = cb_xmm; // stream gamma/beta
constexpr uint32_t cb_out = tt::CB::c_out0;

int index_subblock_w_offset = 0;
@@ -65,15 +69,15 @@ void MAIN {
#else
constexpr int cb_in = cb_in0;
#endif
constexpr int cb_im = (do_gamma | do_beta) ? cb_fusion : cb_out;
constexpr int cb_im = (do_gamma | do_beta) ? cb_x : cb_out;
constexpr int cb_outgamma = do_beta ? cb_fusion : cb_out;

// pre-add x + y
#ifdef FUSE_PRE_ADD
unpack_reconfig_data_format(tt::CB::c_in0, tt::CB::c_in0);
pack_reconfig_data_format(tt::CB::c_intermed0);
add_tiles_init();
for (uint32_t i = 0; i < block_h; i++) {
// pre-add x + y
#ifdef FUSE_PRE_ADD
unpack_reconfig_data_format(tt::CB::c_in0, tt::CB::c_in0);
pack_reconfig_data_format(tt::CB::c_intermed0);
add_tiles_init();
index_subblock_w_offset = 0;
for (uint32_t j = 0; j < num_subblocks_w; j++) {
tile_regs_acquire();
@@ -91,14 +95,18 @@ void MAIN {
cb_push_back(cb_in, subblock_w);
index_subblock_w_offset += subblock_w;
}
cb_wait_front(cb_in, block_w+index_h_offset);
unpack_reconfig_data_format(tt::CB::c_intermed0, tt::CB::c_intermed0);
#endif
index_h_offset += block_w;
}
unpack_reconfig_data_format(tt::CB::c_intermed0, tt::CB::c_intermed0);
cb_wait_front(cb_in, num_tiles_per_block);
#endif

// E[x],
reduce_init_delta<false>(REDUCE_OP, REDUCE_DIM);
cb_wait_front(cb_scaler, 1);
cb_reserve_back(cb_ex_partial, 1);
// E[x],
index_h_offset = 0;
reduce_init_delta<false>(REDUCE_OP, REDUCE_DIM);
cb_wait_front(cb_scaler, 1);
cb_reserve_back(cb_ex_partial, block_h);
for (uint32_t i = 0; i < block_h; i++) {
tile_regs_acquire();
for (uint32_t w = 0; w < block_w; w++) {
reduce_tile(REDUCE_OP, REDUCE_DIM, cb_in, cb_scaler, w+index_h_offset, scaler0, dst0);
@@ -107,10 +115,10 @@ void MAIN {
tile_regs_wait();
pack_tile(dst0, cb_ex_partial);
tile_regs_release();
reduce_revert_delta();
cb_push_back(cb_ex_partial, 1);
index_h_offset += block_w;
}
reduce_revert_delta();
cb_push_back(cb_ex_partial, block_h);

// global reduce, cb_ex <-- cb_ex_external, cb_ex_partial
if constexpr(is_allgather_worker) {
@@ -138,10 +146,10 @@
}
cb_wait_front(cb_ex_global, block_h);

// x - E[x]
index_h_offset = 0;
sub_bcast_cols_init_short();
for (uint32_t i = 0; i < block_h; i++) {
// x - E[x]
sub_bcast_cols_init_short();
index_subblock_w_offset = 0;
for (uint32_t j = 0; j < num_subblocks_w; j++) {
tile_regs_acquire();
@@ -161,10 +169,13 @@ void MAIN {
}
cb_pop_front(cb_ex_global, 1);
cb_pop_front(cb_in, block_w);
cb_wait_front(cb_xmm, block_w+index_h_offset);
}
cb_wait_front(cb_xmm, num_tiles_per_block);

// (x - E[x])^2, cb_mm2 <-- cb_xmm
mul_tiles_init();
// (x - E[x])^2, cb_mm2 <-- cb_xmm
mul_tiles_init();
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
index_subblock_w_offset = 0;
for (uint32_t j = 0; j < num_subblocks_w; j++) {
tile_regs_acquire();
@@ -182,25 +193,28 @@ void MAIN {
cb_push_back(cb_xmm2, subblock_w);
index_subblock_w_offset += subblock_w;
}
index_h_offset += block_w;
}
cb_wait_front(cb_xmm2, num_tiles_per_block);

// Var(x)
cb_reserve_back(cb_ex_partial2, 1);
reduce_init_delta<false>(REDUCE_OP, REDUCE_DIM);
// Var(x)
cb_reserve_back(cb_ex_partial2, block_h);
reduce_init_delta<false>(REDUCE_OP, REDUCE_DIM);
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
tile_regs_acquire();
cb_wait_front(cb_xmm2, block_w);
for (uint32_t w = 0; w < block_w; w++) {
reduce_tile(REDUCE_OP, REDUCE_DIM, cb_xmm2, cb_scaler, w, scaler0, dst0);
reduce_tile(REDUCE_OP, REDUCE_DIM, cb_xmm2, cb_scaler, w+index_h_offset, scaler0, dst0);
}
tile_regs_commit();
tile_regs_wait();
pack_tile(dst0, cb_ex_partial2);
tile_regs_release();
reduce_revert_delta();
cb_push_back(cb_ex_partial2, 1);
cb_pop_front(cb_xmm2, block_w);
cb_wait_front(cb_ex_partial2, 1);
index_h_offset += block_w;
}
reduce_revert_delta();
cb_pop_front(cb_xmm2, num_tiles_per_block);
cb_push_back(cb_ex_partial2, block_h);

// global reduce, cb_ex <-- cb_ex_external, cb_ex_partial
if constexpr(is_allgather_worker) {
@@ -246,28 +260,25 @@
pack_tile(dst0, cb_ex2pe);
cb_push_back(cb_ex2pe, 1);
tile_regs_release();
cb_wait_front(cb_ex2pe, 1);
}

}
cb_wait_front(cb_ex_global, block_h);


if constexpr(do_gamma == 0 && do_beta == 0) {
pack_reconfig_data_format(cb_out);
}
// (x - Ex) * 1/[sqrt(Var + eps)]
mul_bcast_cols_init_short();
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
// (x - Ex) * 1/[sqrt(Var + eps)]
if constexpr(do_gamma == 0 && do_beta == 0) {
pack_reconfig_data_format(cb_out);
} else {
pack_reconfig_data_format(tt::CB::c_intermed0);
}
cb_wait_front(cb_xmm, block_w);
mul_bcast_cols_init_short();
index_subblock_w_offset = 0;
for (uint32_t j = 0; j < num_subblocks_w; j++) {
tile_regs_acquire();
for (uint32_t w = 0; w < subblock_w; w++) {
index = w + index_subblock_w_offset;
mul_tiles_bcast_cols(cb_xmm, cb_ex_global, index, 0, w);
index = w + index_subblock_w_offset + index_h_offset;
mul_tiles_bcast_cols(cb_xmm, cb_ex_global, index, i, w);
}
tile_regs_commit();
cb_reserve_back(cb_im, subblock_w);
@@ -279,23 +290,27 @@ void MAIN {
cb_push_back(cb_im, subblock_w);
index_subblock_w_offset += subblock_w;
}
cb_pop_front(cb_ex_global, 1);
cb_pop_front(cb_xmm, block_w);
cb_wait_front(cb_im, block_w);

if constexpr(do_gamma) {
if constexpr(do_beta == 0) {
pack_reconfig_data_format(cb_out);
}
cb_wait_front(cb_im, block_w);
cb_wait_front(cb_gamma, block_w);
mul_bcast_rows_init_short();
index_h_offset += block_w;
}
cb_pop_front(cb_ex_global, block_h);
cb_pop_front(cb_xmm, num_tiles_per_block);
cb_wait_front(cb_im, num_tiles_per_block);

if constexpr(do_gamma) {
if constexpr(do_beta == 0) {
pack_reconfig_data_format(cb_out);
}
mul_bcast_rows_init_short();
cb_wait_front(cb_gamma, block_w);
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
index_subblock_w_offset = 0;
for (uint32_t j = 0; j < num_subblocks_w; j++) {
tile_regs_acquire();
for (uint32_t w = 0; w < subblock_w; w++) {
index = w + index_subblock_w_offset;
mul_tiles_bcast_rows(cb_im, cb_gamma, index, index, w);
mul_tiles_bcast_rows(cb_im, cb_gamma, index+index_h_offset, index, w);
}
tile_regs_commit();
cb_reserve_back(cb_outgamma, subblock_w);
@@ -307,20 +322,24 @@ void MAIN {
cb_push_back(cb_outgamma, subblock_w);
index_subblock_w_offset += subblock_w;
}
cb_pop_front(cb_im, block_w);
index_h_offset += block_w;
}
cb_pop_front(cb_im, num_tiles_per_block);
cb_wait_front(cb_outgamma, num_tiles_per_block);
}

if constexpr(do_beta) {
pack_reconfig_data_format(cb_out);
cb_wait_front(cb_beta, block_w);
cb_wait_front(cb_fusion, block_w);
add_bcast_rows_init_short();
if constexpr(do_beta) {
pack_reconfig_data_format(cb_out);
add_bcast_rows_init_short();
cb_wait_front(cb_beta, block_w);
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
index_subblock_w_offset = 0;
for (uint32_t j = 0; j < num_subblocks_w; j++) {
tile_regs_acquire();
for (uint32_t w = 0; w < subblock_w; w++) {
index = w + index_subblock_w_offset;
add_tiles_bcast_rows(cb_fusion, cb_beta, index, index, w);
add_tiles_bcast_rows(cb_fusion, cb_beta, index + index_h_offset, index, w);
}
tile_regs_commit();
cb_reserve_back(cb_out, subblock_w);
@@ -332,10 +351,10 @@ void MAIN {
cb_push_back(cb_out, subblock_w);
index_subblock_w_offset += subblock_w;
}
cb_pop_front(cb_fusion, block_w);
cb_wait_front(cb_out, block_w+index_h_offset);
index_h_offset += block_w;
}
index_h_offset += block_w;
cb_pop_front(cb_fusion, num_tiles_per_block);
cb_wait_front(cb_out, num_tiles_per_block);
}

}
tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp (8 additions, 14 deletions)
@@ -770,6 +770,8 @@ operation::ProgramWithCallbacks layernorm_sharded_(
subblock_wt,
num_subblocks_w,
1,
num_rows_per_all_to_all_worker,
block_ht * block_wt
};
std::vector<uint32_t> all_to_all_except_top_compute_compile_time_args = {
0,
@@ -780,7 +782,9 @@
block_wt,
subblock_wt,
num_subblocks_w,
1
1,
num_rows_per_all_to_all_worker,
block_ht * block_wt
};
std::vector<uint32_t> not_all_to_all_compute_compile_time_args = {
0,
@@ -791,7 +795,9 @@
block_wt,
subblock_wt,
num_subblocks_w,
0
0,
num_rows_per_all_to_all_worker,
block_ht * block_wt
};
// compute kernel
bool fp32_dest_acc_en = false;
@@ -904,24 +910,12 @@
tt_metal::CircularBufferConfig ex_global_cb_config = tt_metal::CircularBufferConfig(ex_global_CB_size, {{ex_global_cb_index, cb_data_format}})
.set_page_size(ex_global_cb_index, single_tile_size);
auto cb_ex_global = tt_metal::CreateCircularBuffer(program, all_cores, ex_global_cb_config);
// xmm2
uint32_t xmm2_cb_index;
xmm2_cb_index = CB::c_intermed2;
tt_metal::CircularBufferConfig xmm2_cb_config = tt_metal::CircularBufferConfig(xmm2_CB_size, {{xmm2_cb_index, cb_data_format}})
.set_page_size(xmm2_cb_index, single_tile_size);
auto cb_xmm2 = tt_metal::CreateCircularBuffer(program, all_cores, xmm2_cb_config);
// ex2pe
uint32_t cb_ex2pe_index;
cb_ex2pe_index = CB::c_intermed3;
tt_metal::CircularBufferConfig ex2pe_cb_config = tt_metal::CircularBufferConfig(ex2pe_CB_size, {{cb_ex2pe_index, cb_data_format}})
.set_page_size(cb_ex2pe_index, single_tile_size);
auto cb_ex2pe = tt_metal::CreateCircularBuffer(program, all_cores, ex2pe_cb_config);
// fusion
uint32_t cb_fusion_index;
cb_fusion_index = CB::c_intermed4;
tt_metal::CircularBufferConfig fusion_cb_config = tt_metal::CircularBufferConfig(fusion_CB_size, {{cb_fusion_index, cb_data_format}})
.set_page_size(cb_fusion_index, single_tile_size);
auto cb_fusion = tt_metal::CreateCircularBuffer(program, all_cores, fusion_cb_config);
// out
uint32_t output_cb_index = CB::c_out0; // output operands start at index 16
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, {{output_cb_index, out_data_format}})
@@ -64,7 +64,7 @@ void MAIN {
unpack_reconfig_data_format(cb_scale_mask, cb_fused_attn);

// fused attn
cb_wait_front(cb_scale_mask, block_w);
// cb_wait_front(cb_scale_mask, block_w);
cb_wait_front(cb_fused_attn, block_w);
index_subblock_w_offset = 0;
for (uint32_t j = 0; j < num_subblocks_w; j++) {
@@ -111,7 +111,7 @@ void MAIN {
// sum(exp(x))
ACQ();
reduce_init_delta<false>(REDUCE_OP, REDUCE_DIM);
cb_wait_front(cb_exps, block_w);
// cb_wait_front(cb_exps, block_w);
cb_wait_front(cb_bcast_scaler, 1);
cb_reserve_back(cb_recipsumexps, 1);
for (uint32_t w = 0; w < block_w; w++) {
@@ -31,7 +31,6 @@ void MAIN {

cb_push_back(cb_out1, 1);
cb_pop_front(cb_im0, 1);
cb_wait_front(cb_out1, 1);


}