From e85e46876d0818ab787b1290884be743fbf2366e Mon Sep 17 00:00:00 2001
From: Colman Glagovich
Date: Thu, 9 May 2024 16:28:03 +0000
Subject: [PATCH] #8189: optimize layernorm sharded and eltwise_binary_multicore override runtime args

(cherry picked from commit cb5301802e71a03da61d4a3077bf56cf39973fbd)

#8189: eltwise_binary set_runtime_args from 35us -> 8us, more room for improvement.

(cherry picked from commit 86d926479b816182d2d7025b10e8e0d62468a95b)

#8524: Fix eltwise_binary override runtime args for case where num of cores changes across cached hits.
---
 .../eltwise_binary_op_multi_core.cpp       | 442 +++++++++++-------
 .../multi_core/layernorm_op_multi_core.cpp |  35 +-
 2 files changed, 290 insertions(+), 187 deletions(-)

diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp
index 8cccc026c78..b79fa274583 100644
--- a/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp
+++ b/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp
@@ -17,6 +17,233 @@ using namespace tt::constants;
 
 namespace tt {
 
 namespace tt_metal {
 
+template <bool initialize_args>
+inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args(
+    Program& program,
+    const Tensor& a,
+    const Tensor& b,
+    const Tensor& output,
+    const KernelHandle binary_reader_kernel_id,
+    const KernelHandle unary_writer_kernel_id,
+    const KernelHandle eltwise_binary_kernel_id,
+    const CBHandle cb_src0,
+    const CBHandle cb_src1,
+    const CBHandle cb_output,
+    const CoreCoord compute_with_storage_grid_size,
+    const uint32_t src0_single_tile_size,
+    const uint32_t src1_single_tile_size,
+    const uint32_t dst_single_tile_size){
+
+    auto src_buffer_a = a.buffer();
+    auto src_buffer_b = b.buffer();
+    auto dst_buffer = output.buffer();
+
+    CoreRangeSet all_cores({}), core_group_1({}), core_group_2({});
+
+    std::optional<ShardSpec> shard_spec = std::nullopt;
+    bool src0_sharded = a.memory_config().is_sharded();
+    bool src1_sharded = b.memory_config().is_sharded();
+    bool out_sharded = output.memory_config().is_sharded();
+
+    bool block_sharded = false;
+    if (src0_sharded) {
+        shard_spec = a.shard_spec().value();
+        block_sharded = a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED;
+    } else if (src1_sharded) {
+        shard_spec = b.shard_spec().value();
+        block_sharded = b.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED;
+    } else if (out_sharded) {
+        shard_spec = output.shard_spec().value();
+        block_sharded = output.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED;
+    }
+
+    uint32_t num_tiles = a.volume() / TILE_HW;
+
+    uint32_t num_cores_x = compute_with_storage_grid_size.x;
+    uint32_t num_cores_y = compute_with_storage_grid_size.y;
+    uint32_t num_cores, num_tiles_per_core_group_1, num_tiles_per_core_group_2;
+    uint32_t num_cores_total = num_cores_x * num_cores_y;
+
+    uint32_t block_size_per_core_group_1 = 1, block_size_per_core_group_2 = 1, max_block_size = 1;
+
+    uint32_t block_cnt_per_core_group_1, block_cnt_per_core_group_2;
+
+    bool row_major;
+    uint32_t block_height = 0, block_width = 0, block_size = 0, output_width = 0, last_unpadded_block_height = 0, last_unpadded_block_width = 0;
+    CoreCoord end_core;
+    vector<CoreCoord> cores;
+
+    if (shard_spec.has_value()) {
+        all_cores = shard_spec.value().grid;
+        num_cores = all_cores.num_cores();
+        core_group_1 = all_cores;
+        core_group_2 = CoreRangeSet({});
+        num_tiles_per_core_group_1 =
shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; + num_tiles_per_core_group_2 = 0; + block_size_per_core_group_1 = find_max_block_size(num_tiles_per_core_group_1); + max_block_size = block_size_per_core_group_1; + + block_cnt_per_core_group_1 = num_tiles_per_core_group_1 / block_size_per_core_group_1; + block_cnt_per_core_group_2 = num_tiles_per_core_group_2 / block_size_per_core_group_2; + row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR; + if (block_sharded) { + block_height = shard_spec.value().shape[0] / TILE_HEIGHT; + block_width = shard_spec.value().shape[1] / TILE_WIDTH; + block_size = block_width * block_height; + end_core = (*shard_spec.value().grid.ranges().begin()).end; + output_width = output.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t output_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; + last_unpadded_block_height = block_height - (round_up(output_height, block_height) - output_height); + last_unpadded_block_width = block_width - (round_up(output_width, block_width) - output_width); + } + auto bbox = core_group_1.bounding_box(); + cores = grid_to_cores_with_noop(bbox.end.x, bbox.end.y, num_cores_x, num_cores_y, row_major); + } else { + row_major = true; + std::tie(num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major); + block_cnt_per_core_group_1 = num_tiles_per_core_group_1; + block_cnt_per_core_group_2 = num_tiles_per_core_group_2; + cores = grid_to_cores(num_cores_x * num_cores_y, num_cores_x, num_cores_y, row_major); + } + + uint32_t g1_numcores = core_group_1.num_cores(); + uint32_t g2_numcores = core_group_2.num_cores(); + + + std::vector< std::vector > binary_reader_args; + std::vector< std::vector > eltwise_binary_args; + std::vector< std::vector > unary_writer_args; + if constexpr(initialize_args) { + binary_reader_args = { cores.size(), std::vector(4) }; + eltwise_binary_args = { cores.size(), std::vector(2) }; + if (block_sharded and not out_sharded) + unary_writer_args = { cores.size(), std::vector(7) }; + else + unary_writer_args = { cores.size(), std::vector(3) }; + } + + auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); + auto& cached_eltwise_args = GetRuntimeArgs(program, eltwise_binary_kernel_id); + auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); + + for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; ++i){ + const CoreCoord &core = cores.at(i); + uint32_t num_tiles_per_core = 0; + uint32_t block_cnt_per_core = 0; + uint32_t block_size_per_core = 0; + if (i < g1_numcores) { + num_tiles_per_core = num_tiles_per_core_group_1; + block_cnt_per_core = block_cnt_per_core_group_1; + block_size_per_core = block_size_per_core_group_1; + } else if (i < num_cores) { + num_tiles_per_core = num_tiles_per_core_group_2; + block_cnt_per_core = block_cnt_per_core_group_2; + block_size_per_core = block_size_per_core_group_2; + } else { + // Zero out non-working cores RT args. Only necessary in override + // since initialization pushes zero vectors to unused cores. 
+ if constexpr (!initialize_args) { + auto& reader_args = cached_reader_args.at(core.x).at(core.y); + reader_args[2] = 0; + auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); + eltwise_args[0] = 0; + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[1] = 0; + } + continue; + } + if constexpr(initialize_args) { + binary_reader_args[i] = {src_buffer_a->address(), src_buffer_b->address(), num_tiles_per_core, num_tiles_read}; + eltwise_binary_args[i] = {block_cnt_per_core, block_size_per_core}; + } else { + auto& reader_args = cached_reader_args.at(core.x).at(core.y); + reader_args[0] = src_buffer_a->address(); + reader_args[1] = src_buffer_b->address(); + reader_args[2] = num_tiles_per_core; + reader_args[3] = num_tiles_read; + auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); + eltwise_args[0] = block_cnt_per_core; + eltwise_args[1] = block_size_per_core; + } + if (block_sharded and not out_sharded) { + uint32_t block_start_width_offset; + uint32_t block_start_height_offset; + uint32_t unpadded_block_height = block_height; + uint32_t unpadded_block_width = block_width; + if (row_major) { + block_start_width_offset = core.x * block_width; + block_start_height_offset = core.y * block_height; + if (core.x == end_core.x) { + unpadded_block_width = last_unpadded_block_width; + } + if (core.y == end_core.y) { + unpadded_block_height = last_unpadded_block_height; + } + } else { + block_start_width_offset = core.y * block_width; + block_start_height_offset = core.x * block_height; + if (core.y == end_core.y) { + unpadded_block_width = last_unpadded_block_width; + } + if (core.x == end_core.x) { + unpadded_block_height = last_unpadded_block_height; + } + } + if constexpr(initialize_args) { + unary_writer_args[i] = { dst_buffer->address(), + block_height, + block_width, + unpadded_block_height, + unpadded_block_width, + output_width, + block_size, + block_start_height_offset * output_width + block_start_width_offset, + 0 }; + } else { + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[0] = dst_buffer->address(); + writer_args[1] = block_height; + writer_args[2] = block_width; + writer_args[3] = unpadded_block_height; + writer_args[4] = unpadded_block_width; + writer_args[5] = output_width; + writer_args[6] = block_size; + writer_args[7] = block_start_height_offset * output_width + block_start_width_offset; + writer_args[8] = 0; + } + } else { + if constexpr(initialize_args) { + unary_writer_args[i] = { dst_buffer->address(), num_tiles_per_core, num_tiles_read }; + } else { + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[0] = dst_buffer->address(); + writer_args[1] = num_tiles_per_core; + writer_args[2] = num_tiles_read; + } + } + num_tiles_read += num_tiles_per_core; + } + + if constexpr(initialize_args) { + SetRuntimeArgs(program, binary_reader_kernel_id, cores, binary_reader_args); + SetRuntimeArgs(program, eltwise_binary_kernel_id, cores, eltwise_binary_args); + SetRuntimeArgs(program, unary_writer_kernel_id, cores, unary_writer_args); + } + + if (src0_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer_a); + UpdateCircularBufferTotalSize(program, cb_src0, num_tiles_per_core_group_1 * src0_single_tile_size); + } + if (src1_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_src1, *src_buffer_b); + UpdateCircularBufferTotalSize(program, cb_src1, num_tiles_per_core_group_1 * src1_single_tile_size); + } + if (out_sharded) { + 
UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); + UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * dst_single_tile_size); + } + +} operation::ProgramWithCallbacks eltwise_binary_multi_core(const Tensor &a, const Tensor &b, const Tensor& output, BinaryOpType op_type, const std::optional> fused_activations) { @@ -151,184 +378,35 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const Tensor &a, const tt_metal::ComputeConfig{.defines = eltwise_defines} ); - auto set_runtime_args = [ - binary_reader_kernel_id, - unary_writer_kernel_id, - eltwise_binary_kernel_id, - cb_src0, - cb_src1, - cb_output, - compute_with_storage_grid_size, - src0_single_tile_size, - src1_single_tile_size, - dst_single_tile_size - ] - ( - Program& program, - const Tensor& a, - const Tensor& b, - const Tensor& output - ) { - auto src_buffer_a = a.buffer(); - auto src_buffer_b = b.buffer(); - auto dst_buffer = output.buffer(); - - CoreRangeSet all_cores({}), core_group_1({}), core_group_2({}); - - std::optional shard_spec = std::nullopt; - bool src0_sharded = a.memory_config().is_sharded(); - bool src1_sharded = b.memory_config().is_sharded(); - bool out_sharded = output.memory_config().is_sharded(); - - bool block_sharded = false; - if (src0_sharded) { - shard_spec = a.shard_spec().value(); - block_sharded = a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - } else if (src1_sharded) { - shard_spec = b.shard_spec().value(); - block_sharded = b.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - } else if (out_sharded) { - shard_spec = output.shard_spec().value(); - block_sharded = output.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - } - - uint32_t num_tiles = a.volume() / TILE_HW; - - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - uint32_t num_cores, num_tiles_per_core_group_1, num_tiles_per_core_group_2; - - uint32_t block_size_per_core_group_1 = 1, block_size_per_core_group_2 = 1, max_block_size = 1; - - uint32_t block_cnt_per_core_group_1, block_cnt_per_core_group_2; - - bool row_major; - uint32_t block_height = 0, block_width = 0, block_size = 0, output_width = 0, last_unpadded_block_height = 0, last_unpadded_block_width = 0; - CoreCoord end_core; - vector cores; - - if (shard_spec.has_value()) { - all_cores = shard_spec.value().grid; - num_cores = all_cores.num_cores(); - core_group_1 = all_cores; - core_group_2 = CoreRangeSet({}); - num_tiles_per_core_group_1 = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; - num_tiles_per_core_group_2 = 0; - block_size_per_core_group_1 = find_max_block_size(num_tiles_per_core_group_1); - max_block_size = block_size_per_core_group_1; - - block_cnt_per_core_group_1 = num_tiles_per_core_group_1 / block_size_per_core_group_1; - block_cnt_per_core_group_2 = num_tiles_per_core_group_2 / block_size_per_core_group_2; - row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR; - if (block_sharded) { - block_height = shard_spec.value().shape[0] / TILE_HEIGHT; - block_width = shard_spec.value().shape[1] / TILE_WIDTH; - block_size = block_width * block_height; - end_core = (*shard_spec.value().grid.ranges().begin()).end; - output_width = output.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t output_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; - last_unpadded_block_height = block_height - (round_up(output_height, block_height) - 
output_height); - last_unpadded_block_width = block_width - (round_up(output_width, block_width) - output_width); - } - auto bbox = core_group_1.bounding_box(); - cores = grid_to_cores_with_noop(bbox.end.x, bbox.end.y, num_cores_x, num_cores_y, row_major); - } else { - row_major = true; - std::tie(num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major); - block_cnt_per_core_group_1 = num_tiles_per_core_group_1; - block_cnt_per_core_group_2 = num_tiles_per_core_group_2; - cores = grid_to_cores(num_cores_x * num_cores_y, num_cores_x, num_cores_y, row_major); - } - - uint32_t g1_numcores = core_group_1.num_cores(); - uint32_t g2_numcores = core_group_2.num_cores(); - - std::vector< std::vector > binary_reader_args = { cores.size(), std::vector(4) }; - std::vector< std::vector > eltwise_binary_args = { cores.size(), std::vector(2) }; - std::vector< std::vector > unary_writer_args; - if (block_sharded and not out_sharded) - unary_writer_args = { cores.size(), std::vector(7) }; - else - unary_writer_args = { cores.size(), std::vector(3) }; - - for (uint32_t i = 0, num_tiles_read = 0; i < num_cores; ++i){ - const CoreCoord &core = cores.at(i); - uint32_t num_tiles_per_core = 0; - uint32_t block_cnt_per_core = 0; - uint32_t block_size_per_core = 0; - if (i < g1_numcores) { - num_tiles_per_core = num_tiles_per_core_group_1; - block_cnt_per_core = block_cnt_per_core_group_1; - block_size_per_core = block_size_per_core_group_1; - } else { - num_tiles_per_core = num_tiles_per_core_group_2; - block_cnt_per_core = block_cnt_per_core_group_2; - block_size_per_core = block_size_per_core_group_2; - } - binary_reader_args[i] = {src_buffer_a->address(), src_buffer_b->address(), num_tiles_per_core, num_tiles_read}; - eltwise_binary_args[i] = {block_cnt_per_core, block_size_per_core}; - if (block_sharded and not out_sharded) { - uint32_t block_start_width_offset; - uint32_t block_start_height_offset; - uint32_t unpadded_block_height = block_height; - uint32_t unpadded_block_width = block_width; - if (row_major) { - block_start_width_offset = core.x * block_width; - block_start_height_offset = core.y * block_height; - if (core.x == end_core.x) { - unpadded_block_width = last_unpadded_block_width; - } - if (core.y == end_core.y) { - unpadded_block_height = last_unpadded_block_height; - } - } else { - block_start_width_offset = core.y * block_width; - block_start_height_offset = core.x * block_height; - if (core.y == end_core.y) { - unpadded_block_width = last_unpadded_block_width; - } - if (core.x == end_core.x) { - unpadded_block_height = last_unpadded_block_height; - } - } - unary_writer_args[i] = { dst_buffer->address(), - block_height, - block_width, - unpadded_block_height, - unpadded_block_width, - output_width, - block_size, - block_start_height_offset * output_width + block_start_width_offset, - 0 }; - } else { - unary_writer_args[i] = { dst_buffer->address(), num_tiles_per_core, num_tiles_read }; - } - num_tiles_read += num_tiles_per_core; - } - - SetRuntimeArgs(program, binary_reader_kernel_id, cores, binary_reader_args); - SetRuntimeArgs(program, eltwise_binary_kernel_id, cores, eltwise_binary_args); - SetRuntimeArgs(program, unary_writer_kernel_id, cores, unary_writer_args); - - if (src0_sharded) { - UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer_a); - UpdateCircularBufferTotalSize(program, cb_src0, num_tiles_per_core_group_1 * src0_single_tile_size); - } - 
if (src1_sharded) {
-            UpdateDynamicCircularBufferAddress(program, cb_src1, *src_buffer_b);
-            UpdateCircularBufferTotalSize(program, cb_src1, num_tiles_per_core_group_1 * src1_single_tile_size);
-        }
-        if (out_sharded) {
-            UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer);
-            UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * dst_single_tile_size);
-        }
-    };
+    set_eltwise_binary_runtime_args<true>(
+        program,
+        a,
+        b,
+        output,
+        binary_reader_kernel_id,
+        unary_writer_kernel_id,
+        eltwise_binary_kernel_id,
+        cb_src0,
+        cb_src1,
+        cb_output,
+        compute_with_storage_grid_size,
+        src0_single_tile_size,
+        src1_single_tile_size,
+        dst_single_tile_size);
 
-    set_runtime_args(program, a, b, output);
     auto override_runtime_arguments_callback = [
-            set_runtime_args
+            binary_reader_kernel_id,
+            unary_writer_kernel_id,
+            eltwise_binary_kernel_id,
+            cb_src0,
+            cb_src1,
+            cb_output,
+            compute_with_storage_grid_size,
+            src0_single_tile_size,
+            src1_single_tile_size,
+            dst_single_tile_size
         ]
     (
         const void* operation,
@@ -342,7 +420,21 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const Tensor &a, const
         auto src_buffer_b = input_tensors.at(1).buffer();
 
         const auto& output_tensor = output_tensors.size() == 1 ? output_tensors.at(0) : input_tensors.at(0);
 
-        set_runtime_args(program, input_tensors.at(0), input_tensors.at(1), output_tensor);
+        set_eltwise_binary_runtime_args<false>(
+            program,
+            input_tensors.at(0),
+            input_tensors.at(1),
+            output_tensor,
+            binary_reader_kernel_id,
+            unary_writer_kernel_id,
+            eltwise_binary_kernel_id,
+            cb_src0,
+            cb_src1,
+            cb_output,
+            compute_with_storage_grid_size,
+            src0_single_tile_size,
+            src1_single_tile_size,
+            dst_single_tile_size);
     };
     return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback};
 }
diff --git a/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
index ddccda7eb68..669865bf413 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
@@ -1168,6 +1168,8 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
 
     auto override_runtime_args_callback = [
             writer_kernel_ids,
+            writer_mcast_sender_kernels_id,
+            writer_mcast_receiver_kernels_id,
             cb_in0,
             cb_in1,
             cb_output,
@@ -1180,11 +1182,11 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
         const std::vector<std::optional<const Tensor>>& optional_input_tensors,
         const std::vector<Tensor>& output_tensors
     ) {
-        auto src_buffer_a = input_tensors.at(0).buffer();
-        auto b_tensor = optional_input_tensors.at(0);
-        auto gamma_tensor = optional_input_tensors.at(1);
-        auto beta_tensor = optional_input_tensors.at(2);
-        auto dst_buffer = output_tensors.at(0).buffer();
+        const auto src_buffer_a = input_tensors.at(0).buffer();
+        const auto b_tensor = optional_input_tensors.at(0);
+        const auto gamma_tensor = optional_input_tensors.at(1);
+        const auto beta_tensor = optional_input_tensors.at(2);
+        const auto dst_buffer = output_tensors.at(0).buffer();
 
         UpdateDynamicCircularBufferAddress(program, cb_in0, *src_buffer_a);
 
@@ -1194,18 +1196,27 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
 
         UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer);
 
+        auto& writer_sender_args_by_core = GetRuntimeArgs(program, writer_mcast_sender_kernels_id);
+        auto& writer_receiver_args_by_core = GetRuntimeArgs(program,
writer_mcast_receiver_kernels_id); + + const auto gamma_address = gamma_tensor.has_value() ? gamma_tensor.value().buffer()->address() : 0; + const auto beta_address = beta_tensor.has_value() ? beta_tensor.value().buffer()->address() : 0; + + for (uint32_t i = 0; i < cores.size(); ++i) { const CoreCoord& core = cores[i]; - auto writer_kernel_id = writer_kernel_ids.at(i); + const auto writer_kernel_id = writer_kernel_ids.at(i); - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + if (writer_kernel_id == writer_mcast_sender_kernels_id) { + auto& runtime_args = writer_sender_args_by_core[core.x][core.y]; + runtime_args[3] = gamma_address; + runtime_args[4] = beta_address; - if (gamma_tensor.has_value()) { - runtime_args[3] = gamma_tensor.value().buffer()->address(); - } - if (beta_tensor.has_value()) { - runtime_args[4] = beta_tensor.value().buffer()->address(); + } else if (writer_kernel_id == writer_mcast_receiver_kernels_id) { + auto& runtime_args = writer_receiver_args_by_core[core.x][core.y]; + runtime_args[3] = gamma_address; + runtime_args[4] = beta_address; } } };
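
Note on the pattern used in both files above: program creation still builds fresh argument vectors and installs them with SetRuntimeArgs (the initialize_args == true path), while a program-cache hit fetches the already-installed per-core vectors once per kernel via GetRuntimeArgs and patches in place only the entries that can change between invocations (buffer addresses, per-core tile counts), zeroing the work counts of cores that fall out of the split when the core count shrinks. The reported drop from 35us to 8us for eltwise_binary's set_runtime_args appears to come largely from skipping the per-core vector construction and the per-kernel SetRuntimeArgs calls on that hot path. Below is a minimal, self-contained sketch of the idea using toy stand-in types; Program, get_runtime_args, and override_reader_args here are illustrative names only, not the tt_metal API.

    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Toy stand-in: args[kernel_id][core_x][core_y] -> runtime args for that core.
    struct Program {
        std::vector<std::vector<std::vector<std::vector<uint32_t>>>> args;
    };

    // Analogue of GetRuntimeArgs(program, kernel_id): one lookup per kernel,
    // returning a mutable reference to the cached per-core argument vectors.
    std::vector<std::vector<std::vector<uint32_t>>>& get_runtime_args(Program& p, std::size_t kernel_id) {
        return p.args.at(kernel_id);
    }

    // Override path: patch only the entries that can change between cache hits
    // (buffer addresses, per-core tile counts) instead of rebuilding every
    // vector and re-setting it, which is what the old per-call lambda did.
    void override_reader_args(
        Program& p,
        std::size_t reader_kernel_id,
        uint32_t src0_addr,
        uint32_t src1_addr,
        const std::vector<std::pair<std::size_t, std::size_t>>& cores,
        const std::vector<uint32_t>& tiles_per_core) {
        auto& per_core_args = get_runtime_args(p, reader_kernel_id);
        uint32_t tiles_read = 0;
        for (std::size_t i = 0; i < cores.size(); ++i) {
            auto& args = per_core_args.at(cores[i].first).at(cores[i].second);
            args[0] = src0_addr;          // addresses can move between invocations
            args[1] = src1_addr;
            args[2] = tiles_per_core[i];  // 0 for cores dropped by a new work split
            args[3] = tiles_read;
            tiles_read += tiles_per_core[i];
        }
    }

In the real kernels the same in-place patching is applied to the compute and writer arguments, and for non-working cores only the work-count entries are cleared, since the initialization path already pushed zeroed vectors to those cores.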