From 8d994716bafb8457e0e7d92b376fe427606a65ec Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Tue, 14 May 2024 21:11:32 +0000 Subject: [PATCH] #7530: Uplift unique rt args to update directly into the cmd queue for fast dispatch. TODO: Fix/uplift common runtime args update to remove looping in assemble_runtime_args --- .../multi_core/all_gather_op_multi_core.cpp | 22 ++++++------ .../multi_core/concat_op_multi_core.cpp | 2 +- .../single_core/concat_op_single_core.cpp | 2 +- .../op_library/moreh_adam/moreh_adam.cpp | 6 ++-- .../op_library/moreh_adamw/moreh_adamw.cpp | 6 ++-- .../moreh_arange/moreh_arange_op.cpp | 3 +- .../moreh_clip_grad_norm_step1.cpp | 9 ++--- .../moreh_clip_grad_norm_step2.cpp | 9 ++--- .../moreh_clip_grad_norm_step3.cpp | 6 ++-- .../moreh_cumsum_nc/moreh_cumsum_nc.cpp | 6 ++-- .../single_core/moreh_dot_op_single_core.cpp | 9 ++--- .../moreh_dot_backward_op_single_core.cpp | 9 ++--- .../moreh_getitem_rm/moreh_getitem_rm.cpp | 6 ++-- .../moreh_getitem_tilized.cpp | 12 +++---- .../moreh_groupnorm/moreh_groupnorm.cpp | 6 ++-- ...reh_groupnorm_backward_gamma_beta_grad.cpp | 6 ++-- .../moreh_groupnorm_backward_input_grad.cpp | 6 ++-- .../moreh_layernorm/moreh_layernorm_op.cpp | 6 ++-- ...reh_layernorm_backward_gamma_beta_grad.cpp | 6 ++-- .../moreh_layernorm_backward_input_grad.cpp | 6 ++-- .../moreh_mean_nc/moreh_mean_nc.cpp | 6 ++-- .../moreh_mean_backward.cpp | 6 ++-- .../moreh_nll_loss_step1.cpp | 6 ++-- .../moreh_nll_loss_step2.cpp | 6 ++-- .../moreh_norm/moreh_norm_h/moreh_norm_h.cpp | 9 ++--- .../moreh_norm_other/moreh_norm_other.cpp | 9 ++--- .../moreh_norm/moreh_norm_w/moreh_norm_w.cpp | 9 ++--- .../moreh_norm_backward.cpp | 9 ++--- .../tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp | 2 +- .../softmax_c_large/softmax_c_large.cpp | 6 ++-- .../softmax_h_large/softmax_h_large.cpp | 6 ++-- .../softmax_h_small/softmax_h_small.cpp | 6 ++-- .../softmax_w_large/softmax_w_large.cpp | 6 ++-- .../softmax_w_small/softmax_w_small.cpp | 6 ++-- .../softmax_backward_c_large.cpp | 6 ++-- .../softmax_backward_h_large.cpp | 6 ++-- .../softmax_backward_h_small.cpp | 6 ++-- .../softmax_backward_w_large.cpp | 6 ++-- .../softmax_backward_w_small.cpp | 6 ++-- .../non_zero_indices_op_single_core.cpp | 3 +- .../op_library/prod/prod_nc/prod_nc.cpp | 6 ++-- .../multi_core/sharded_op_multi_core.cpp | 6 ++-- .../unpad/multi_core/unpad_op_multi_core.cpp | 16 ++++----- .../single_core/unpad_op_single_core.cpp | 12 +++---- tt_metal/host_api.hpp | 10 +++--- tt_metal/impl/dispatch/command_queue.cpp | 34 ++++++++---------- tt_metal/impl/kernels/kernel.cpp | 27 ++++++++++++-- tt_metal/impl/kernels/kernel.hpp | 36 +++++++++++++++++++ tt_metal/tt_metal.cpp | 8 ++--- 49 files changed, 197 insertions(+), 226 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp index d80f18bee02..9690f1f52f4 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp @@ -1157,42 +1157,42 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& const std::vector>& optional_input_tensors, const std::vector& output_tensors ) { - bool is_sharded = input_tensors.at(0).is_sharded(); - const auto& input = input_tensors.at(0); - const auto& output = output_tensors.at(0); + bool is_sharded = input_tensors[0].is_sharded(); + const auto& input = input_tensors[0]; + const auto& output = output_tensors[0]; for (uint32_t i = 0; i < total_worker_core_pairs_used; ++i) { if (is_sharded) { auto &worker_reader_sender_runtime_args = GetRuntimeArgs(program, worker_reader_sender_kernels.at(i), all_worker_sender_cores.at(i)); - worker_reader_sender_runtime_args.at(7) = input.buffer()->address(); + worker_reader_sender_runtime_args[7] = input.buffer()->address(); uint32_t num_dest_cores = worker_reader_sender_runtime_args.at(12); - worker_reader_sender_runtime_args.at(12 + num_dest_cores + 4) = output.buffer()->address(); + worker_reader_sender_runtime_args[12 + num_dest_cores + 4] = output.buffer()->address(); log_trace(tt::LogOp, "override worker_reader_sender_runtime_args:"); for (uint32_t j = 0; j < worker_reader_sender_runtime_args.size(); ++j) { log_trace(tt::LogOp, "\tworker_reader_sender_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j)); } auto &worker_writer_sender_runtime_args = GetRuntimeArgs(program, worker_writer_sender_kernels.at(i), all_worker_sender_cores.at(i)); - worker_writer_sender_runtime_args.at(12) = output.buffer()->address(); + worker_writer_sender_runtime_args[12] = output.buffer()->address(); log_trace(tt::LogOp, "override worker_writer_sender_runtime_args:"); for (uint32_t j = 0; j < worker_writer_sender_runtime_args.size(); ++j) { log_trace(tt::LogOp, "\tworker_writer_sender_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j)); } auto &worker_writer_receiver_runtime_args = GetRuntimeArgs(program, worker_writer_receiver_kernels.at(i), all_worker_receiver_cores.at(i)); - worker_writer_receiver_runtime_args.at(10) = output.buffer()->address(); + worker_writer_receiver_runtime_args[10] = output.buffer()->address(); log_trace(tt::LogOp, "override worker_writer_receiver_runtime_args:"); for (uint32_t j = 0; j < worker_writer_receiver_runtime_args.size(); ++j) { log_trace(tt::LogOp, "\tworker_writer_receiver_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j)); } } else { auto &worker_reader_sender_runtime_args = GetRuntimeArgs(program, worker_reader_sender_kernels.at(i), all_worker_sender_cores.at(i)); - worker_reader_sender_runtime_args.at(0) = input.buffer()->address(); - worker_reader_sender_runtime_args.at(1) = output.buffer()->address(); + worker_reader_sender_runtime_args[0] = input.buffer()->address(); + worker_reader_sender_runtime_args[1] = output.buffer()->address(); auto &worker_writer_sender_runtime_args = GetRuntimeArgs(program, worker_writer_sender_kernels.at(i), all_worker_sender_cores.at(i)); - worker_writer_sender_runtime_args.at(0) = output.buffer()->address(); + worker_writer_sender_runtime_args[0] = output.buffer()->address(); auto &worker_writer_receiver_runtime_args = GetRuntimeArgs(program, worker_writer_receiver_kernels.at(i), all_worker_receiver_cores.at(i)); - worker_writer_receiver_runtime_args.at(0) = output.buffer()->address(); + worker_writer_receiver_runtime_args[0] = output.buffer()->address(); } } }; diff --git a/tt_eager/tt_dnn/op_library/concat/multi_core/concat_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/concat/multi_core/concat_op_multi_core.cpp index 25b50079de8..9f0adca240c 100644 --- a/tt_eager/tt_dnn/op_library/concat/multi_core/concat_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/concat/multi_core/concat_op_multi_core.cpp @@ -729,7 +729,7 @@ operation::ProgramWithCallbacks concat_multi_core( for (const auto &core : cores) { { auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); - std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 3); + std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.data() + 3); } { diff --git a/tt_eager/tt_dnn/op_library/concat/single_core/concat_op_single_core.cpp b/tt_eager/tt_dnn/op_library/concat/single_core/concat_op_single_core.cpp index 08fee4b2da8..ca19f74f8f4 100644 --- a/tt_eager/tt_dnn/op_library/concat/single_core/concat_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/concat/single_core/concat_op_single_core.cpp @@ -179,7 +179,7 @@ operation::ProgramWithCallbacks concat_single_core( { auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); - std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 3); + std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.data() + 3); } { diff --git a/tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp b/tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp index 294a1a6ff78..2feeb064f42 100644 --- a/tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp @@ -213,7 +213,7 @@ operation::ProgramWithCallbacks moreh_adam_( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = param_buffer->address(); runtime_args[1] = grad_buffer->address(); runtime_args[2] = exp_avg_buffer->address(); @@ -221,11 +221,10 @@ operation::ProgramWithCallbacks moreh_adam_( if (max_exp_avg_sq_buffer != nullptr) { runtime_args[4] = max_exp_avg_sq_buffer->address(); } - tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = param_buffer->address(); runtime_args[1] = grad_buffer->address(); runtime_args[2] = exp_avg_buffer->address(); @@ -233,7 +232,6 @@ operation::ProgramWithCallbacks moreh_adam_( if (max_exp_avg_sq_buffer != nullptr) { runtime_args[4] = max_exp_avg_sq_buffer->address(); } - tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp b/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp index 98b985efa16..c17cc662a59 100644 --- a/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp @@ -213,7 +213,7 @@ operation::ProgramWithCallbacks moreh_adamw_( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = param_buffer->address(); runtime_args[1] = grad_buffer->address(); runtime_args[2] = exp_avg_buffer->address(); @@ -221,11 +221,10 @@ operation::ProgramWithCallbacks moreh_adamw_( if (max_exp_avg_sq_buffer != nullptr) { runtime_args[4] = max_exp_avg_sq_buffer->address(); } - tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = param_buffer->address(); runtime_args[1] = grad_buffer->address(); runtime_args[2] = exp_avg_buffer->address(); @@ -233,7 +232,6 @@ operation::ProgramWithCallbacks moreh_adamw_( if (max_exp_avg_sq_buffer != nullptr) { runtime_args[4] = max_exp_avg_sq_buffer->address(); } - tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp b/tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp index 2ec1e46c6b5..bc720e4918c 100644 --- a/tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp @@ -114,9 +114,8 @@ operation::ProgramWithCallbacks moreh_arange_( CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp index 2bc2e23ff67..6b2a5608c7d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp @@ -196,23 +196,20 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_tensors.at(i).buffer()->address(); runtime_args[3] = *reinterpret_cast(&decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = output_address; - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, compute_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernels_id, core); runtime_args[1] = p; runtime_args[2] = static_cast(p_is_negative); - SetRuntimeArgs(program, compute_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp index 841df274fc9..1eb2da1cb5f 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp @@ -133,23 +133,20 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl( const auto output_address = input_tensors.at(1).buffer()->address(); { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, single_core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, single_core); runtime_args[0] = input_address; runtime_args[3] = *reinterpret_cast(&decimal); - SetRuntimeArgs(program, reader_kernels_id, single_core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, single_core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, single_core); runtime_args[0] = output_address; - SetRuntimeArgs(program, writer_kernels_id, single_core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, compute_kernels_id, single_core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernels_id, single_core); runtime_args[1] = p; runtime_args[2] = static_cast(p_is_negative); - SetRuntimeArgs(program, compute_kernels_id, single_core, runtime_args); } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp index 2a735f8c277..e2298da64a5 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp @@ -140,16 +140,14 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffers.at(i)->address(); runtime_args[2] = clip_coef_clamped_address; - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = input_buffers.at(i)->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp b/tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp index c45af58e5c2..111e53bd6dc 100644 --- a/tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp @@ -183,15 +183,13 @@ operation::ProgramWithCallbacks moreh_cumsum_nc( for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = input_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp b/tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp index 6765aab01a6..366bec441cd 100644 --- a/tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp @@ -118,24 +118,21 @@ operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Ten uint32_t num_tiles = input_tensors.at(0).volume() / TILE_HW; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer_a->address(); runtime_args[1] = src_buffer_b->address(); runtime_args[2] = num_tiles; - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[0] = num_tiles; - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); runtime_args[1] = 1; - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } }; return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; diff --git a/tt_eager/tt_dnn/op_library/moreh_dot_backward/single_core/moreh_dot_backward_op_single_core.cpp b/tt_eager/tt_dnn/op_library/moreh_dot_backward/single_core/moreh_dot_backward_op_single_core.cpp index f0748e9ac8c..9e316fe6933 100644 --- a/tt_eager/tt_dnn/op_library/moreh_dot_backward/single_core/moreh_dot_backward_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_dot_backward/single_core/moreh_dot_backward_op_single_core.cpp @@ -179,29 +179,26 @@ operation::ProgramWithCallbacks moreh_dot_backward_single_core( CoreCoord core = {0, 0}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = (std::uint32_t)has_input_grad; runtime_args[1] = (std::uint32_t)has_input_grad; runtime_args[2] = src0_buffer->address(); runtime_args[3] = src1_buffer->address(); runtime_args[4] = src2_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[0] = (std::uint32_t)has_input_grad; runtime_args[1] = (std::uint32_t)has_input_grad; - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = (std::uint32_t)has_input_grad; runtime_args[1] = (std::uint32_t)has_input_grad; runtime_args[2] = dst0_address; runtime_args[3] = dst1_address; - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } }; return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; diff --git a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp index 56d797cf685..9e4805f7e02 100644 --- a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp @@ -230,19 +230,17 @@ operation::ProgramWithCallbacks moreh_getitem_rm( CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer->address(); runtime_args[1] = index_info[0].address; runtime_args[2] = index_info[1].address; runtime_args[3] = index_info[2].address; runtime_args[4] = index_info[3].address; - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp index 9c1c6cb0f8c..f17c1da531d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp @@ -293,19 +293,17 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer->address(); runtime_args[1] = index_info[0].address; runtime_args[2] = index_info[1].address; runtime_args[3] = index_info[2].address; runtime_args[4] = index_info[3].address; - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; @@ -545,19 +543,17 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer->address(); runtime_args[1] = index_info[0].address; runtime_args[2] = index_info[1].address; runtime_args[3] = index_info[2].address; runtime_args[4] = index_info[3].address; - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm.cpp b/tt_eager/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm.cpp index e0f656e357b..be86f98b6dc 100644 --- a/tt_eager/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm.cpp @@ -336,7 +336,7 @@ operation::ProgramWithCallbacks moreh_groupnorm_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); if (gamma_buffer != nullptr) { runtime_args[2] = gamma_buffer->address(); @@ -344,11 +344,10 @@ operation::ProgramWithCallbacks moreh_groupnorm_impl( if (beta_buffer != nullptr) { runtime_args[5] = beta_buffer->address(); } - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = ouput_buffer->address(); if (mean_buffer != nullptr) { runtime_args[2] = mean_buffer->address(); @@ -356,7 +355,6 @@ operation::ProgramWithCallbacks moreh_groupnorm_impl( if (rstd_buffer != nullptr) { runtime_args[5] = rstd_buffer->address(); } - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/gamma_beta_grad/moreh_groupnorm_backward_gamma_beta_grad.cpp b/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/gamma_beta_grad/moreh_groupnorm_backward_gamma_beta_grad.cpp index b526b63aef8..44cf94e8d90 100644 --- a/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/gamma_beta_grad/moreh_groupnorm_backward_gamma_beta_grad.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/gamma_beta_grad/moreh_groupnorm_backward_gamma_beta_grad.cpp @@ -280,23 +280,21 @@ operation::ProgramWithCallbacks moreh_groupnorm_backward_gamma_beta_grad_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = output_grad_buffer->address(); runtime_args[2] = input_buffer->address(); runtime_args[4] = mean_buffer->address(); runtime_args[6] = rstd_buffer->address(); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); if (gamma_grad_buffer != nullptr) { runtime_args[0] = gamma_grad_buffer->address(); } if (beta_grad_buffer != nullptr) { runtime_args[3] = beta_grad_buffer->address(); } - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/input_grad/moreh_groupnorm_backward_input_grad.cpp b/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/input_grad/moreh_groupnorm_backward_input_grad.cpp index 073a3157970..b990eedff0e 100644 --- a/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/input_grad/moreh_groupnorm_backward_input_grad.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/input_grad/moreh_groupnorm_backward_input_grad.cpp @@ -292,7 +292,7 @@ operation::ProgramWithCallbacks moreh_groupnorm_backward_input_grad_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = output_grad_buffer->address(); runtime_args[2] = input_buffer->address(); runtime_args[4] = mean_buffer->address(); @@ -300,13 +300,11 @@ operation::ProgramWithCallbacks moreh_groupnorm_backward_input_grad_impl( if (gamma_buffer != nullptr) { runtime_args[8] = gamma_buffer->address(); } - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = input_grad_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp b/tt_eager/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp index f379a4f462b..0645bd41cf2 100644 --- a/tt_eager/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp @@ -386,7 +386,7 @@ operation::ProgramWithCallbacks moreh_layernorm_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); if (gamma_buffer != nullptr) { runtime_args[6] = gamma_buffer->address(); @@ -394,11 +394,10 @@ operation::ProgramWithCallbacks moreh_layernorm_impl( if (beta_buffer != nullptr) { runtime_args[7] = beta_buffer->address(); } - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = ouput_buffer->address(); if (mean_buffer != nullptr) { runtime_args[1] = mean_buffer->address(); @@ -406,7 +405,6 @@ operation::ProgramWithCallbacks moreh_layernorm_impl( if (rstd_buffer != nullptr) { runtime_args[2] = rstd_buffer->address(); } - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp b/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp index dae24fe6fcb..d25ec17a1ad 100644 --- a/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp @@ -278,23 +278,21 @@ operation::ProgramWithCallbacks moreh_layernorm_backward_gamma_beta_grad_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = output_grad_buffer->address(); runtime_args[1] = input_buffer->address(); runtime_args[2] = mean_buffer->address(); runtime_args[3] = rstd_buffer->address(); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); if (gamma_grad_buffer != nullptr) { runtime_args[0] = gamma_grad_buffer->address(); } if (beta_grad_buffer != nullptr) { runtime_args[1] = beta_grad_buffer->address(); } - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp b/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp index 7d4575802a2..ea191cb71b4 100644 --- a/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp @@ -324,7 +324,7 @@ operation::ProgramWithCallbacks moreh_layernorm_backward_input_grad_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = output_grad_buffer->address(); runtime_args[1] = input_buffer->address(); runtime_args[2] = mean_buffer->address(); @@ -332,13 +332,11 @@ operation::ProgramWithCallbacks moreh_layernorm_backward_input_grad_impl( if (gamma_buffer != nullptr) { runtime_args[4] = gamma_buffer->address(); } - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = input_grad_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_nc/moreh_mean_nc.cpp b/tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_nc/moreh_mean_nc.cpp index 51e20000bf1..9114b8a58ff 100644 --- a/tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_nc/moreh_mean_nc.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_nc/moreh_mean_nc.cpp @@ -175,15 +175,13 @@ operation::ProgramWithCallbacks moreh_mean_nc(const Tensor &input, const Tensor for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = input_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward.cpp b/tt_eager/tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward.cpp index f49a2b862a6..bb2e2b6a45d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward.cpp @@ -200,15 +200,13 @@ operation::ProgramWithCallbacks moreh_mean_backward_program(const Tensor &output for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_grad_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp index 32e9e223a6a..d0d0ef84e48 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp @@ -168,15 +168,13 @@ operation::ProgramWithCallbacks moreh_nll_loss_step1_impl(const Tensor &input, c CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp index 5e21e0595c6..d687a6d5f91 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp @@ -176,15 +176,13 @@ operation::ProgramWithCallbacks moreh_nll_loss_step2_impl(const Tensor &input, c CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_h/moreh_norm_h.cpp b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_h/moreh_norm_h.cpp index 6b93956f7de..51a17f0cef1 100644 --- a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_h/moreh_norm_h.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_h/moreh_norm_h.cpp @@ -221,17 +221,15 @@ operation::ProgramWithCallbacks moreh_norm_h_impl(const Tensor &input, float p, CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); runtime_args[2] = *reinterpret_cast(&decimal); runtime_args[3] = *reinterpret_cast(&recip_p_decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { @@ -243,12 +241,11 @@ operation::ProgramWithCallbacks moreh_norm_h_impl(const Tensor &input, float p, } else { TT_THROW("Core not in specified core ranges."); } - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[3] = floored_p; runtime_args[4] = static_cast(p_is_negative); runtime_args[5] = floored_recip_p; runtime_args[6] = static_cast(recip_p_is_negative); - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_other/moreh_norm_other.cpp b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_other/moreh_norm_other.cpp index 73704e360b6..5d9b734fa80 100644 --- a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_other/moreh_norm_other.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_other/moreh_norm_other.cpp @@ -225,17 +225,15 @@ operation::ProgramWithCallbacks moreh_norm_other_impl(const Tensor &input, float CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); runtime_args[2] = *reinterpret_cast(&decimal); runtime_args[3] = *reinterpret_cast(&recip_p_decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { @@ -247,12 +245,11 @@ operation::ProgramWithCallbacks moreh_norm_other_impl(const Tensor &input, float } else { TT_THROW("Core not in specified core ranges."); } - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[2] = floored_p; runtime_args[3] = static_cast(p_is_negative); runtime_args[4] = floored_recip_p; runtime_args[5] = static_cast(recip_p_is_negative); - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_w/moreh_norm_w.cpp b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_w/moreh_norm_w.cpp index 894e11f95a9..c6c3e17ba98 100644 --- a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_w/moreh_norm_w.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_w/moreh_norm_w.cpp @@ -220,17 +220,15 @@ operation::ProgramWithCallbacks moreh_norm_w_impl(const Tensor &input, float p, CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); runtime_args[2] = *reinterpret_cast(&decimal); runtime_args[3] = *reinterpret_cast(&recip_p_decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { @@ -242,12 +240,11 @@ operation::ProgramWithCallbacks moreh_norm_w_impl(const Tensor &input, float p, } else { TT_THROW("Core not in specified core ranges."); } - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[3] = floored_p; runtime_args[4] = static_cast(p_is_negative); runtime_args[5] = floored_recip_p; runtime_args[6] = static_cast(recip_p_is_negative); - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward.cpp b/tt_eager/tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward.cpp index 8fcabd0748f..33b9ebd164f 100644 --- a/tt_eager/tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward.cpp @@ -274,18 +274,16 @@ operation::ProgramWithCallbacks moreh_norm_backward_( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); runtime_args[2] = output_buffer->address(); runtime_args[4] = output_grad_buffer->address(); runtime_args[6] = *reinterpret_cast(&decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = input_grad_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { @@ -297,12 +295,11 @@ operation::ProgramWithCallbacks moreh_norm_backward_( } else { TT_THROW("Core not in specified core ranges."); } - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[5] = floored_p; runtime_args[6] = static_cast(p_is_negative); runtime_args[7] = floored_p_minus_one; runtime_args[8] = static_cast(p_minus_one_is_negative); - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp b/tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp index e6025ce4452..7567cc94100 100644 --- a/tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp @@ -219,7 +219,7 @@ operation::ProgramWithCallbacks moreh_sgd_( } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_ids, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_ids, core); runtime_args[0] = param_out_address; if (has_momentum_buffer) { diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp index f3921b28a30..79e0b62a5e0 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp @@ -150,15 +150,13 @@ operation::ProgramWithCallbacks moreh_softmax_c_large(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp index 3f878ba0a1f..abcb6b19409 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp @@ -144,15 +144,13 @@ operation::ProgramWithCallbacks moreh_softmax_h_large(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp index 4ab132eca6a..77523098d4c 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp @@ -166,15 +166,13 @@ operation::ProgramWithCallbacks moreh_softmax_h_small(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp index cede75c98a8..f1ae31c7dd6 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp @@ -145,15 +145,13 @@ operation::ProgramWithCallbacks moreh_softmax_w_large(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp index c37530a6cdf..bf90b8d47b0 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp @@ -166,15 +166,13 @@ operation::ProgramWithCallbacks moreh_softmax_w_small(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp index 90682747295..5752781a893 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp @@ -157,16 +157,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_c_large(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp index 423f7709696..859867d17f0 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp @@ -152,16 +152,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_large(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp index bdb1a941271..44df2758698 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp @@ -174,16 +174,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_small(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp index 4f316e4d921..78ce4ceecfa 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp @@ -152,16 +152,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_large(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp index cf29f8ce9d9..a834f5e4acd 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp @@ -175,16 +175,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_small(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/non_zero_indices/single_core/non_zero_indices_op_single_core.cpp b/tt_eager/tt_dnn/op_library/non_zero_indices/single_core/non_zero_indices_op_single_core.cpp index 0f4451f54ed..65d72e96387 100644 --- a/tt_eager/tt_dnn/op_library/non_zero_indices/single_core/non_zero_indices_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/non_zero_indices/single_core/non_zero_indices_op_single_core.cpp @@ -106,14 +106,13 @@ operation::ProgramWithCallbacks non_zero_indices_single_core(const Tensor &input uint32_t alignment_base = 32/input.element_size(); uint32_t aligned_elements = div_up(input.get_legacy_shape()[-1] , alignment_base) * alignment_base; uint32_t actual_elements = input.get_legacy_shape()[-1]; - auto runtime_args = tt_metal::GetRuntimeArgs(program, kernel_id, core); + auto& runtime_args = tt_metal::GetRuntimeArgs(program, kernel_id, core); runtime_args[0] = input.buffer()->address(); runtime_args[1] = output_0.buffer()->address(); runtime_args[2] = output_1.buffer()->address(); runtime_args[3] = aligned_elements; runtime_args[4] = actual_elements; runtime_args[5] = input.element_size(); - tt_metal::SetRuntimeArgs(program, kernel_id, core, runtime_args); }; return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; } diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp index 78781395077..88d499a8460 100644 --- a/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp +++ b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp @@ -180,15 +180,13 @@ operation::ProgramWithCallbacks prod_nc_format(const Tensor &input, const Tensor for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = input_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp index 7e6267c814d..cbe78eafd92 100644 --- a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp @@ -877,12 +877,10 @@ operation::ProgramWithCallbacks reshard_multi_core(const Tensor& input, Tensor& auto all_cores = output_shard_spec.grid; auto cores = corerange_to_cores(all_cores, std::nullopt, output_shard_spec.orientation == ShardOrientation::ROW_MAJOR); for (auto core: cores) { - auto runtime_args_0 = GetRuntimeArgs(program, kernel_id_0, core); - auto runtime_args_1 = GetRuntimeArgs(program, kernel_id_1, core); + auto &runtime_args_0 = GetRuntimeArgs(program, kernel_id_0, core); + auto &runtime_args_1 = GetRuntimeArgs(program, kernel_id_1, core); runtime_args_0[grid.x + grid.y] = input.buffer()->address(); runtime_args_1[grid.x + grid.y] = input.buffer()->address(); - tt_metal::SetRuntimeArgs(program, kernel_id_0, core, runtime_args_0); - tt_metal::SetRuntimeArgs(program, kernel_id_1, core, runtime_args_1); } UpdateDynamicCircularBufferAddress(program, cb_dst0, *output.buffer()); }; diff --git a/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp index c9724c9950c..1387791d805 100644 --- a/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp @@ -268,9 +268,9 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( uint32_t num_padded_Yt = (num_total_Yt - num_unpadded_Yt) * num_total_Xt; const auto set_common_reader_args = [&]( - std::vector & reader_common_args, - uint32_t * num_unpadded_tiles_per_dim, - uint32_t * num_padded_tiles_per_dim) __attribute__((always_inline)) { + uint32_t* reader_common_args, + uint32_t* num_unpadded_tiles_per_dim, + uint32_t* num_padded_tiles_per_dim) __attribute__((always_inline)) { reader_common_args[0] = input_buffer->address(); num_unpadded_tiles_per_dim[0] = num_unpadded_Xt; num_unpadded_tiles_per_dim[1] = num_unpadded_Yt; @@ -289,7 +289,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( }; const auto set_reader_rt_args = [&]( - std::vector & reader_rt_args, + uint32_t* reader_rt_args, const uint32_t* num_unpadded_tiles_per_dim, const uint32_t* num_padded_tiles_per_dim, const uint32_t& num_tiles_per_core, @@ -311,14 +311,14 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( std::vector reader_common_args(1 + num_dims * 2); uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; - set_common_reader_args(reader_common_args, num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } auto& reader_common_args = GetCommonRuntimeArgs(program, unary_reader_kernel_id); uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; if constexpr (!initialize_args) { - set_common_reader_args(reader_common_args, num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); } uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start); @@ -352,7 +352,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( if constexpr (initialize_args) { std::vector reader_kernel_args(2 + num_dims); set_reader_rt_args( - reader_kernel_args, + reader_kernel_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim, num_tiles_per_core, @@ -362,7 +362,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( } else { auto& reader_kernel_args = reader_kernel_args_by_core[core.x][core.y]; set_reader_rt_args( - reader_kernel_args, + reader_kernel_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim, num_tiles_per_core, diff --git a/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp b/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp index 8cceba36fa7..c4a32a76518 100644 --- a/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp @@ -185,7 +185,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( uint32_t num_padded_Yt = (num_total_Yt - num_unpadded_Yt) * num_total_Xt; const auto set_common_reader_args = [&]( - std::vector & reader_common_args, + uint32_t * reader_common_args, uint32_t * num_unpadded_tiles_per_dim, uint32_t * num_padded_tiles_per_dim) __attribute__((always_inline)) { reader_common_args[0] = input_buffer->address(); @@ -206,7 +206,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( }; const auto set_reader_rt_args = [&]( - std::vector & reader_rt_args, + uint32_t * reader_rt_args, const uint32_t* num_unpadded_tiles_per_dim, const uint32_t* num_padded_tiles_per_dim, const uint32_t& num_tiles_per_core, @@ -228,14 +228,14 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( std::vector reader_common_args(1 + num_dims * 2); uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; - set_common_reader_args(reader_common_args, num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } auto& reader_common_args = GetCommonRuntimeArgs(program, unary_reader_kernel_id); uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; if constexpr (!initialize_args) { - set_common_reader_args(reader_common_args, num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); } uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start); @@ -243,7 +243,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( if constexpr (initialize_args) { std::vector reader_kernel_args(2 + num_dims); set_reader_rt_args( - reader_kernel_args, + reader_kernel_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim, num_unpadded_tiles, @@ -253,7 +253,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( } else { auto& reader_kernel_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); set_reader_rt_args( - reader_kernel_args, + reader_kernel_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim, num_unpadded_tiles, diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index 0ba0cb9d70f..def697295cb 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -334,7 +334,7 @@ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const /** * Get the runtime args for a kernel. * - * Return value: std::vector & + * Return value: uint32_t * * * | Argument | Description | Type | Valid Range | Required | * |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------| @@ -342,19 +342,19 @@ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | * | logical_core | The location of the Tensix core where the runtime args will be written | const CoreCoord & | Any logical Tensix core coordinate | Yes | */ -std::vector& GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core); +RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core); /** * Get the runtime args for a kernel. * - * Return value: std::vector< std::vector< std::vector> > & + * Return value: std::vector< std::vector< RuntimeArgsData > > & * * | Argument | Description | Type | Valid Range | Required | * |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------| * | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes | * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | */ -std::vector< std::vector< std::vector> > & GetRuntimeArgs(const Program &program, KernelHandle kernel_id); +std::vector< std::vector< RuntimeArgsData > > & GetRuntimeArgs(const Program &program, KernelHandle kernel_id); /** * Get the common runtime args for a kernel. @@ -366,7 +366,7 @@ std::vector< std::vector< std::vector> > & GetRuntimeArgs(const Progra * | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes | * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | */ -std::vector& GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id); +std::vector & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id); /** * Update specific entries of the runtime args vector for a kernel using the command queue. This API must be used when Asynchronous Command Queue Mode is enabled. diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 913eca885fe..dc074b035ec 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -312,6 +312,7 @@ void generate_dispatch_write_packed( const std::vector& sub_cmds, const std::vector>& rt_data_and_sizes, const uint32_t& max_runtime_args_len, + std::vector>& rt_args_data, const uint32_t max_prefetch_command_size, const uint32_t id) { static_assert( @@ -352,8 +353,8 @@ void generate_dispatch_write_packed( uint32_t num_packed_cmds = std::min(num_packed_cmds_in_seq, max_packed_cmds); uint32_t rt_payload_sizeB = get_runtime_payload_sizeB(num_packed_cmds, max_runtime_args_len, unicast); uint32_t cmd_sequence_sizeB = align(sizeof(CQPrefetchCmd) + rt_payload_sizeB, PCIE_ALIGNMENT); - HostMemDeviceCommand command_sequence(cmd_sequence_sizeB); - command_sequence.add_dispatch_write_packed( + runtime_args_command_sequences.emplace_back(cmd_sequence_sizeB); + runtime_args_command_sequences.back().add_dispatch_write_packed( num_packed_cmds, l1_arg_base_addr, max_runtime_args_len * sizeof(uint32_t), @@ -361,12 +362,12 @@ void generate_dispatch_write_packed( sub_cmds, rt_data_and_sizes, offset_idx); - runtime_args_command_sequences.emplace_back(command_sequence); uint32_t data_offset = (uint32_t)get_runtime_args_data_offset(num_packed_cmds, max_runtime_args_len, unicast); const uint32_t data_inc = align(max_runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT); for (uint32_t i = offset_idx; i < offset_idx + num_packed_cmds; ++i) { cmd_mapping[(uint64_t)id << 32 | sub_cmds[i].noc_xy_addr] = { runtime_args_command_sequences.size() - 1, data_offset}; + rt_args_data[i].get().rt_args_data = (uint32_t *)((char *)runtime_args_command_sequences.back().data() + data_offset); data_offset += data_inc; } num_packed_cmds_in_seq -= num_packed_cmds; @@ -419,6 +420,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { uint32_t num_dsts = unique_processor_to_l1_arg_base_addr.size(); std::vector> unique_sub_cmds(num_dsts); std::vector>> unique_rt_data_and_sizes(num_dsts); + std::vector>> unique_rt_args_data(num_dsts); std::vector unique_max_runtime_args_len(num_dsts, 0); std::vector>> common_sub_cmds(program.num_kernels()); std::vector>> common_rt_data_and_sizes(program.num_kernels()); + std::vector>> common_rt_args_data(program.num_kernels()); std::vector common_max_runtime_args_len(program.num_kernels(), 0); std::vector common_processor_to_l1_arg_base_addr(program.num_kernels()); @@ -445,12 +448,14 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { unique_processors.insert(processor_idx); unique_sub_cmds[processor_idx].reserve(kernel->cores_with_runtime_args().size()); unique_rt_data_and_sizes[processor_idx].reserve(kernel->cores_with_runtime_args().size()); + unique_rt_args_data[processor_idx].reserve(kernel->cores_with_runtime_args().size()); for (const auto& core_coord : kernel->cores_with_runtime_args()) { // can make a vector of unicast encodings here CoreCoord physical_core = device->physical_core_from_logical_core(core_coord, kernel->get_kernel_core_type()); uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core); const auto& runtime_args_data = kernel->runtime_args(core_coord); + unique_rt_args_data[processor_idx].emplace_back(kernel->runtime_args_data(core_coord)); // 2, 17, could be differnet len here unique_sub_cmds[processor_idx].emplace_back( @@ -475,6 +480,8 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { auto& unicast_sub_cmd = std::get>(common_sub_cmds[kernel_id]); unicast_sub_cmd.reserve(kernel->logical_cores().size()); + common_rt_data_and_sizes[kernel_id].reserve(kernel->logical_cores().size()); + common_rt_args_data[kernel_id].reserve(kernel->logical_cores().size()); for (auto& core_coord : kernel->logical_cores()) { // can make a vector of unicast encodings here CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord); @@ -483,6 +490,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding}); common_rt_data_and_sizes[kernel_id].emplace_back( common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); + common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()); } } else { vector> dst_noc_multicast_info = @@ -493,11 +501,13 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { auto& multicast_sub_cmd = std::get>(common_sub_cmds[kernel_id]); multicast_sub_cmd.reserve(dst_noc_multicast_info.size()); + common_rt_data_and_sizes[kernel_id].reserve(dst_noc_multicast_info.size()); for (const auto& mcast_dests : dst_noc_multicast_info) { multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{ .noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second}); common_rt_data_and_sizes[kernel_id].emplace_back( common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); + common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()); } } common_max_runtime_args_len[kernel_id] = @@ -521,6 +531,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { unique_sub_cmds[processor_idx], unique_rt_data_and_sizes[processor_idx], unique_max_runtime_args_len[processor_idx], + unique_rt_args_data[processor_idx], max_prefetch_command_size, processor_idx); } @@ -534,6 +545,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { sub_cmds, common_rt_data_and_sizes[kernel_id], common_max_runtime_args_len[kernel_id], + common_rt_args_data[kernel_id], max_prefetch_command_size, kernel_id); }, @@ -542,22 +554,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { } else { for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) { auto kernel = detail::GetKernel(program, kernel_id); - - if (!kernel->cores_with_runtime_args().empty()) { - uint32_t processor_idx = - static_cast::type>(kernel->processor()); - for (const auto& core_coord : kernel->cores_with_runtime_args()) { - uint32_t noc_xy_encoding = get_noc_unicast_encoding( - this->device->physical_core_from_logical_core(core_coord, kernel->get_kernel_core_type())); - const auto& data_loc = - program.command_indices - .processor_to_cmd_mapping[(uint64_t)processor_idx << 32 | noc_xy_encoding]; - const auto& runtime_args_data = kernel->runtime_args(core_coord); - this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence( - data_loc.second, runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t)); - } - } - const auto& common_rt_args = kernel->common_runtime_args(); if (common_rt_args.size() > 0) { if (kernel->get_kernel_core_type() == CoreType::ETH) { diff --git a/tt_metal/impl/kernels/kernel.cpp b/tt_metal/impl/kernels/kernel.cpp index f7b9ac5dfc5..46146783031 100644 --- a/tt_metal/impl/kernels/kernel.cpp +++ b/tt_metal/impl/kernels/kernel.cpp @@ -41,6 +41,7 @@ Kernel::Kernel(const std::string &kernel_path_file_name, const CoreRangeSet &cor } } this->core_to_runtime_args_ = { max_x+1, std::vector< std::vector > (max_y+1, std::vector() ) }; + this->core_to_runtime_args_data_ = { max_x+1, std::vector< RuntimeArgsData > (max_y+1, RuntimeArgsData{} ) }; } std::string Kernel::name() const { @@ -157,14 +158,28 @@ std::vector& Kernel::runtime_args(const CoreCoord &logical_core) { return this->core_to_runtime_args_[logical_core.x][logical_core.y]; } +RuntimeArgsData &Kernel::runtime_args_data(const CoreCoord &logical_core) { + // TODO (abhullar): Should this check only be enabled in debug mode? + TT_FATAL( logical_core.x < this->core_to_runtime_args_.size() && logical_core.y < this->core_to_runtime_args_[logical_core.x].size(), "Cannot get runtime args for kernel {} that is not placed on core {}", this->name(), logical_core.str()); + return this->core_to_runtime_args_data_[logical_core.x][logical_core.y]; +} + std::vector< std::vector< std::vector> > & Kernel::runtime_args() { return this->core_to_runtime_args_; } +std::vector< std::vector > & Kernel::runtime_args_data() { + return this->core_to_runtime_args_data_; +} + std::vector& Kernel::common_runtime_args() { return this->common_runtime_args_; } +RuntimeArgsData & Kernel::common_runtime_args_data() { + return this->common_runtime_args_data_; +} + std::pair DataMovementKernel::get_runtime_args_range() const { std::pair arg_base_to_result_base; switch (this->config_.processor) { @@ -227,9 +242,16 @@ void Kernel::set_runtime_args(const CoreCoord &logical_core, const std::vectorvalidate_runtime_args_size(runtime_args.size(), this->common_runtime_args_.size(), logical_core); auto &set_rt_args = this->core_to_runtime_args_[logical_core.x][logical_core.y]; + // TODO: Only allow setting once TT_FATAL(set_rt_args.empty() or set_rt_args.size() == runtime_args.size(), "Illegal Runtime Args: Number of runtime args cannot be modified!"); - set_rt_args = runtime_args; - this->core_with_runtime_args_.insert( logical_core ); + if (set_rt_args.empty()) { + set_rt_args = runtime_args; + this->core_to_runtime_args_data_[logical_core.x][logical_core.y] = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}; + this->core_with_runtime_args_.insert( logical_core ); + } else { + std::memcpy(this->core_to_runtime_args_data_[logical_core.x][logical_core.y].rt_args_data, runtime_args.data(), runtime_args.size() * sizeof(uint32_t)); + } + } void Kernel::set_common_runtime_args(const std::vector &common_runtime_args) { @@ -242,6 +264,7 @@ void Kernel::set_common_runtime_args(const std::vector &common_runtime auto &set_rt_args = this->common_runtime_args_; TT_FATAL(set_rt_args.empty() or set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!"); set_rt_args = common_runtime_args; + this->common_runtime_args_data_ = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}; } void DataMovementKernel::set_build_options(JitBuildOptions& build_options) const { diff --git a/tt_metal/impl/kernels/kernel.hpp b/tt_metal/impl/kernels/kernel.hpp index 4bf9df3e3aa..ede1686e538 100644 --- a/tt_metal/impl/kernels/kernel.hpp +++ b/tt_metal/impl/kernels/kernel.hpp @@ -23,6 +23,37 @@ namespace tt_metal { using Config = std::variant; +struct RuntimeArgsData { + uint32_t * rt_args_data; + size_t rt_args_size; + + inline uint32_t & operator[](size_t index) { + TT_ASSERT(index < rt_args_size, "Index specified is larger than runtime args size"); + return this->rt_args_data[index]; + } + inline const uint32_t& operator[](size_t index) const { + TT_ASSERT(index < rt_args_size, "Index specified is larger than runtime args size"); + return this->rt_args_data[index]; + } + inline uint32_t & at(size_t index) { + TT_FATAL(index < rt_args_size, "Index specified is larger than runtime args size"); + return this->rt_args_data[index]; + } + inline const uint32_t& at(size_t index) const { + TT_FATAL(index < rt_args_size, "Index specified is larger than runtime args size"); + return this->rt_args_data[index]; + } + inline uint32_t * data() noexcept { + return rt_args_data; + } + inline const uint32_t * data() const noexcept { + return rt_args_data; + } + inline size_t size() const noexcept{ + return rt_args_size; + } +}; + class Kernel : public JitBuildSettings { public: Kernel(const std::string &kernel_path_file_name, const CoreRangeSet &core_range_set, const std::vector &compile_args, const std::map&defines); @@ -50,8 +81,11 @@ class Kernel : public JitBuildSettings { void update_runtime_arg( const CoreCoord &logical_core, size_t idx, uint32_t value); std::vector & runtime_args(const CoreCoord &logical_core); + RuntimeArgsData& runtime_args_data(const CoreCoord &logical_core); std::vector< std::vector< std::vector> > & runtime_args(); + std::vector< std::vector< RuntimeArgsData > > & runtime_args_data(); std::vector & common_runtime_args(); + RuntimeArgsData& common_runtime_args_data(); std::map defines() const { return defines_; } @@ -96,7 +130,9 @@ class Kernel : public JitBuildSettings { uint16_t binary_size16_; std::vector compile_time_args_; std::vector< std::vector< std::vector> > core_to_runtime_args_; + std::vector< std::vector< RuntimeArgsData> > core_to_runtime_args_data_; std::vector common_runtime_args_; + RuntimeArgsData common_runtime_args_data_; std::set core_with_runtime_args_; std::size_t max_runtime_args_per_core_; // For validation CoreCoord core_with_max_runtime_args_; // For validation diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index ad6b686b9c7..fe19f14dd98 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -951,14 +951,14 @@ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const } -std::vector & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { +RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - return detail::GetKernel(program, kernel_id)->runtime_args(logical_core); + return detail::GetKernel(program, kernel_id)->runtime_args_data(logical_core); } -std::vector< std::vector< std::vector> > & GetRuntimeArgs(const Program &program, KernelHandle kernel_id) { +std::vector< std::vector< RuntimeArgsData> >& GetRuntimeArgs(const Program &program, KernelHandle kernel_id) { TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - return detail::GetKernel(program, kernel_id)->runtime_args(); + return detail::GetKernel(program, kernel_id)->runtime_args_data(); } std::vector & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) {