diff --git a/docs/aspell-dictionary.pws b/docs/aspell-dictionary.pws index b78450ae343..b1206699cb2 100644 --- a/docs/aspell-dictionary.pws +++ b/docs/aspell-dictionary.pws @@ -169,7 +169,6 @@ UnaryOpType UpdateCircularBufferPageSize UpdateCircularBufferTotalSize UpdateDynamicCircularBufferAddress -UpdateRuntimeArgs VC VCs WH diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst index 678e9ffc51d..661262e22dc 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst @@ -9,12 +9,12 @@ Runtime Arguments .. doxygenfunction:: SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const std::vector< CoreCoord > & core_spec, const std::vector> runtime_args) -.. doxygenfunction:: UpdateRuntimeArgs(Device* device, const std::shared_ptr kernel, const CoreCoord &core_coord, std::vector &update_idx, std::shared_ptr runtime_args) - .. doxygenfunction:: GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) .. doxygenfunction:: GetRuntimeArgs(const Program &program, KernelHandle kernel_id) .. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector &runtime_args) +.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args) + .. doxygenfunction:: GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_HostAsyncCQ.cpp b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_HostAsyncCQ.cpp index 5575dfaa27e..c2b08908a21 100644 --- a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_HostAsyncCQ.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_HostAsyncCQ.cpp @@ -276,75 +276,6 @@ TEST_F(CommandQueueFixture, TestAsyncBufferRW) { command_queue.set_mode(current_mode); } -TEST_F(CommandQueueFixture, TestAsyncSetAndUpdateRuntimeArgs) { - // Test Asynchronous buffer allocation and SetRuntimeArgs API - auto& command_queue = this->device_->command_queue(); - auto current_mode = CommandQueue::default_mode(); - command_queue.set_mode(CommandQueue::CommandQueueMode::ASYNC); - - uint32_t buf_size = 4096; - uint32_t page_size = 4096; - CoreCoord core = {0, 0}; - // Initialize kernels in program - Program program; - auto reader = CreateKernel( - program, - "tt_metal/kernels/dataflow/reader_binary_diff_lengths.cpp", - core, - DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}); - auto writer = CreateKernel( - program, - "tt_metal/kernels/dataflow/writer_unary.cpp", - core, - DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default}); - - // Asynchronously allocate buffers on device - auto src0 = std::make_shared(this->device_, buf_size, page_size, BufferType::DRAM); - auto src1 = std::make_shared(this->device_, buf_size, page_size, BufferType::DRAM); - auto dst = std::make_shared(this->device_, buf_size, page_size, BufferType::DRAM); - // Asynchronously set the runtime args based on (potentially unallocated) buffer addrs - std::shared_ptr writer_runtime_args = std::make_shared(); - *writer_runtime_args = {src0.get(), src1.get()}; - std::shared_ptr reader_runtime_args= std::make_shared(); - *reader_runtime_args = {dst.get()}; - SetRuntimeArgs(this->device_, detail::GetKernel(program, writer), core, writer_runtime_args); - SetRuntimeArgs(this->device_, detail::GetKernel(program, reader), core, reader_runtime_args); - Finish(this->device_->command_queue()); - - auto resolved_writer_args = detail::GetKernel(program, writer)->runtime_args(core); - auto resolved_reader_args = detail::GetKernel(program, reader)->runtime_args(core); - - EXPECT_EQ(resolved_writer_args.size(), 2); - EXPECT_EQ(resolved_reader_args.size(), 1); - EXPECT_EQ(resolved_writer_args[0], src0->address()); - EXPECT_EQ(resolved_writer_args[1], src1->address()); - EXPECT_EQ(resolved_reader_args[0], dst->address()); - - // Create new buffers and update the runtime args based on their address - auto src2 = std::make_shared(this->device_, buf_size, page_size, BufferType::DRAM); - auto src3 = std::make_shared(this->device_, buf_size, page_size, BufferType::DRAM); - auto dst1 = std::make_shared(this->device_, buf_size, page_size, BufferType::DRAM); - - *writer_runtime_args = {src2.get(), src3.get()}; - *reader_runtime_args = {dst1.get()}; - std::vector writer_update_idx = {0, 1}; - std::vector reader_update_idx = {0}; - // Asynchronously update the runtime args based on (potentially unallocated) buffer addrs - UpdateRuntimeArgs(this->device_, detail::GetKernel(program, writer), core, writer_update_idx, writer_runtime_args); - UpdateRuntimeArgs(this->device_, detail::GetKernel(program, reader), core, reader_update_idx, reader_runtime_args); - Finish(this->device_->command_queue()); - - resolved_writer_args = detail::GetKernel(program, writer)->runtime_args(core); - resolved_reader_args = detail::GetKernel(program, reader)->runtime_args(core); - - EXPECT_EQ(resolved_writer_args.size(), 2); - EXPECT_EQ(resolved_reader_args.size(), 1); - EXPECT_EQ(resolved_writer_args[0], src2->address()); - EXPECT_EQ(resolved_writer_args[1], src3->address()); - EXPECT_EQ(resolved_reader_args[0], dst1->address()); - command_queue.set_mode(current_mode); -} - TEST_F(CommandQueueFixture, TestAsyncCBAllocation) { // Test asynchronous allocation of buffers and their assignment to CBs auto& command_queue = this->device_->command_queue(); diff --git a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp index d80f18bee02..9690f1f52f4 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp @@ -1157,42 +1157,42 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& const std::vector>& optional_input_tensors, const std::vector& output_tensors ) { - bool is_sharded = input_tensors.at(0).is_sharded(); - const auto& input = input_tensors.at(0); - const auto& output = output_tensors.at(0); + bool is_sharded = input_tensors[0].is_sharded(); + const auto& input = input_tensors[0]; + const auto& output = output_tensors[0]; for (uint32_t i = 0; i < total_worker_core_pairs_used; ++i) { if (is_sharded) { auto &worker_reader_sender_runtime_args = GetRuntimeArgs(program, worker_reader_sender_kernels.at(i), all_worker_sender_cores.at(i)); - worker_reader_sender_runtime_args.at(7) = input.buffer()->address(); + worker_reader_sender_runtime_args[7] = input.buffer()->address(); uint32_t num_dest_cores = worker_reader_sender_runtime_args.at(12); - worker_reader_sender_runtime_args.at(12 + num_dest_cores + 4) = output.buffer()->address(); + worker_reader_sender_runtime_args[12 + num_dest_cores + 4] = output.buffer()->address(); log_trace(tt::LogOp, "override worker_reader_sender_runtime_args:"); for (uint32_t j = 0; j < worker_reader_sender_runtime_args.size(); ++j) { log_trace(tt::LogOp, "\tworker_reader_sender_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j)); } auto &worker_writer_sender_runtime_args = GetRuntimeArgs(program, worker_writer_sender_kernels.at(i), all_worker_sender_cores.at(i)); - worker_writer_sender_runtime_args.at(12) = output.buffer()->address(); + worker_writer_sender_runtime_args[12] = output.buffer()->address(); log_trace(tt::LogOp, "override worker_writer_sender_runtime_args:"); for (uint32_t j = 0; j < worker_writer_sender_runtime_args.size(); ++j) { log_trace(tt::LogOp, "\tworker_writer_sender_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j)); } auto &worker_writer_receiver_runtime_args = GetRuntimeArgs(program, worker_writer_receiver_kernels.at(i), all_worker_receiver_cores.at(i)); - worker_writer_receiver_runtime_args.at(10) = output.buffer()->address(); + worker_writer_receiver_runtime_args[10] = output.buffer()->address(); log_trace(tt::LogOp, "override worker_writer_receiver_runtime_args:"); for (uint32_t j = 0; j < worker_writer_receiver_runtime_args.size(); ++j) { log_trace(tt::LogOp, "\tworker_writer_receiver_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j)); } } else { auto &worker_reader_sender_runtime_args = GetRuntimeArgs(program, worker_reader_sender_kernels.at(i), all_worker_sender_cores.at(i)); - worker_reader_sender_runtime_args.at(0) = input.buffer()->address(); - worker_reader_sender_runtime_args.at(1) = output.buffer()->address(); + worker_reader_sender_runtime_args[0] = input.buffer()->address(); + worker_reader_sender_runtime_args[1] = output.buffer()->address(); auto &worker_writer_sender_runtime_args = GetRuntimeArgs(program, worker_writer_sender_kernels.at(i), all_worker_sender_cores.at(i)); - worker_writer_sender_runtime_args.at(0) = output.buffer()->address(); + worker_writer_sender_runtime_args[0] = output.buffer()->address(); auto &worker_writer_receiver_runtime_args = GetRuntimeArgs(program, worker_writer_receiver_kernels.at(i), all_worker_receiver_cores.at(i)); - worker_writer_receiver_runtime_args.at(0) = output.buffer()->address(); + worker_writer_receiver_runtime_args[0] = output.buffer()->address(); } } }; diff --git a/tt_eager/tt_dnn/op_library/concat/multi_core/concat_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/concat/multi_core/concat_op_multi_core.cpp index 25b50079de8..9f0adca240c 100644 --- a/tt_eager/tt_dnn/op_library/concat/multi_core/concat_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/concat/multi_core/concat_op_multi_core.cpp @@ -729,7 +729,7 @@ operation::ProgramWithCallbacks concat_multi_core( for (const auto &core : cores) { { auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); - std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 3); + std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.data() + 3); } { diff --git a/tt_eager/tt_dnn/op_library/concat/single_core/concat_op_single_core.cpp b/tt_eager/tt_dnn/op_library/concat/single_core/concat_op_single_core.cpp index 08fee4b2da8..ca19f74f8f4 100644 --- a/tt_eager/tt_dnn/op_library/concat/single_core/concat_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/concat/single_core/concat_op_single_core.cpp @@ -179,7 +179,7 @@ operation::ProgramWithCallbacks concat_single_core( { auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); - std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 3); + std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.data() + 3); } { diff --git a/tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp b/tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp index 294a1a6ff78..2feeb064f42 100644 --- a/tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp @@ -213,7 +213,7 @@ operation::ProgramWithCallbacks moreh_adam_( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = param_buffer->address(); runtime_args[1] = grad_buffer->address(); runtime_args[2] = exp_avg_buffer->address(); @@ -221,11 +221,10 @@ operation::ProgramWithCallbacks moreh_adam_( if (max_exp_avg_sq_buffer != nullptr) { runtime_args[4] = max_exp_avg_sq_buffer->address(); } - tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = param_buffer->address(); runtime_args[1] = grad_buffer->address(); runtime_args[2] = exp_avg_buffer->address(); @@ -233,7 +232,6 @@ operation::ProgramWithCallbacks moreh_adam_( if (max_exp_avg_sq_buffer != nullptr) { runtime_args[4] = max_exp_avg_sq_buffer->address(); } - tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp b/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp index 98b985efa16..c17cc662a59 100644 --- a/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp @@ -213,7 +213,7 @@ operation::ProgramWithCallbacks moreh_adamw_( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = param_buffer->address(); runtime_args[1] = grad_buffer->address(); runtime_args[2] = exp_avg_buffer->address(); @@ -221,11 +221,10 @@ operation::ProgramWithCallbacks moreh_adamw_( if (max_exp_avg_sq_buffer != nullptr) { runtime_args[4] = max_exp_avg_sq_buffer->address(); } - tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = param_buffer->address(); runtime_args[1] = grad_buffer->address(); runtime_args[2] = exp_avg_buffer->address(); @@ -233,7 +232,6 @@ operation::ProgramWithCallbacks moreh_adamw_( if (max_exp_avg_sq_buffer != nullptr) { runtime_args[4] = max_exp_avg_sq_buffer->address(); } - tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp b/tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp index 2ec1e46c6b5..bc720e4918c 100644 --- a/tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp @@ -114,9 +114,8 @@ operation::ProgramWithCallbacks moreh_arange_( CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp index 2bc2e23ff67..6b2a5608c7d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp @@ -196,23 +196,20 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_tensors.at(i).buffer()->address(); runtime_args[3] = *reinterpret_cast(&decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = output_address; - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, compute_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernels_id, core); runtime_args[1] = p; runtime_args[2] = static_cast(p_is_negative); - SetRuntimeArgs(program, compute_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp index 841df274fc9..1eb2da1cb5f 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp @@ -133,23 +133,20 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl( const auto output_address = input_tensors.at(1).buffer()->address(); { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, single_core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, single_core); runtime_args[0] = input_address; runtime_args[3] = *reinterpret_cast(&decimal); - SetRuntimeArgs(program, reader_kernels_id, single_core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, single_core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, single_core); runtime_args[0] = output_address; - SetRuntimeArgs(program, writer_kernels_id, single_core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, compute_kernels_id, single_core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernels_id, single_core); runtime_args[1] = p; runtime_args[2] = static_cast(p_is_negative); - SetRuntimeArgs(program, compute_kernels_id, single_core, runtime_args); } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp index 2a735f8c277..e2298da64a5 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp @@ -140,16 +140,14 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffers.at(i)->address(); runtime_args[2] = clip_coef_clamped_address; - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = input_buffers.at(i)->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp b/tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp index c45af58e5c2..111e53bd6dc 100644 --- a/tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp @@ -183,15 +183,13 @@ operation::ProgramWithCallbacks moreh_cumsum_nc( for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = input_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp b/tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp index 6765aab01a6..366bec441cd 100644 --- a/tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp @@ -118,24 +118,21 @@ operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Ten uint32_t num_tiles = input_tensors.at(0).volume() / TILE_HW; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer_a->address(); runtime_args[1] = src_buffer_b->address(); runtime_args[2] = num_tiles; - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[0] = num_tiles; - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); runtime_args[1] = 1; - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } }; return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; diff --git a/tt_eager/tt_dnn/op_library/moreh_dot_backward/single_core/moreh_dot_backward_op_single_core.cpp b/tt_eager/tt_dnn/op_library/moreh_dot_backward/single_core/moreh_dot_backward_op_single_core.cpp index f0748e9ac8c..9e316fe6933 100644 --- a/tt_eager/tt_dnn/op_library/moreh_dot_backward/single_core/moreh_dot_backward_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_dot_backward/single_core/moreh_dot_backward_op_single_core.cpp @@ -179,29 +179,26 @@ operation::ProgramWithCallbacks moreh_dot_backward_single_core( CoreCoord core = {0, 0}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = (std::uint32_t)has_input_grad; runtime_args[1] = (std::uint32_t)has_input_grad; runtime_args[2] = src0_buffer->address(); runtime_args[3] = src1_buffer->address(); runtime_args[4] = src2_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[0] = (std::uint32_t)has_input_grad; runtime_args[1] = (std::uint32_t)has_input_grad; - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = (std::uint32_t)has_input_grad; runtime_args[1] = (std::uint32_t)has_input_grad; runtime_args[2] = dst0_address; runtime_args[3] = dst1_address; - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } }; return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; diff --git a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp index 56d797cf685..9e4805f7e02 100644 --- a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp @@ -230,19 +230,17 @@ operation::ProgramWithCallbacks moreh_getitem_rm( CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer->address(); runtime_args[1] = index_info[0].address; runtime_args[2] = index_info[1].address; runtime_args[3] = index_info[2].address; runtime_args[4] = index_info[3].address; - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp index 9c1c6cb0f8c..f17c1da531d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp @@ -293,19 +293,17 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer->address(); runtime_args[1] = index_info[0].address; runtime_args[2] = index_info[1].address; runtime_args[3] = index_info[2].address; runtime_args[4] = index_info[3].address; - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; @@ -545,19 +543,17 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer->address(); runtime_args[1] = index_info[0].address; runtime_args[2] = index_info[1].address; runtime_args[3] = index_info[2].address; runtime_args[4] = index_info[3].address; - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm.cpp b/tt_eager/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm.cpp index e0f656e357b..be86f98b6dc 100644 --- a/tt_eager/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm.cpp @@ -336,7 +336,7 @@ operation::ProgramWithCallbacks moreh_groupnorm_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); if (gamma_buffer != nullptr) { runtime_args[2] = gamma_buffer->address(); @@ -344,11 +344,10 @@ operation::ProgramWithCallbacks moreh_groupnorm_impl( if (beta_buffer != nullptr) { runtime_args[5] = beta_buffer->address(); } - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = ouput_buffer->address(); if (mean_buffer != nullptr) { runtime_args[2] = mean_buffer->address(); @@ -356,7 +355,6 @@ operation::ProgramWithCallbacks moreh_groupnorm_impl( if (rstd_buffer != nullptr) { runtime_args[5] = rstd_buffer->address(); } - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/gamma_beta_grad/moreh_groupnorm_backward_gamma_beta_grad.cpp b/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/gamma_beta_grad/moreh_groupnorm_backward_gamma_beta_grad.cpp index b526b63aef8..44cf94e8d90 100644 --- a/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/gamma_beta_grad/moreh_groupnorm_backward_gamma_beta_grad.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/gamma_beta_grad/moreh_groupnorm_backward_gamma_beta_grad.cpp @@ -280,23 +280,21 @@ operation::ProgramWithCallbacks moreh_groupnorm_backward_gamma_beta_grad_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = output_grad_buffer->address(); runtime_args[2] = input_buffer->address(); runtime_args[4] = mean_buffer->address(); runtime_args[6] = rstd_buffer->address(); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); if (gamma_grad_buffer != nullptr) { runtime_args[0] = gamma_grad_buffer->address(); } if (beta_grad_buffer != nullptr) { runtime_args[3] = beta_grad_buffer->address(); } - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/input_grad/moreh_groupnorm_backward_input_grad.cpp b/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/input_grad/moreh_groupnorm_backward_input_grad.cpp index 073a3157970..b990eedff0e 100644 --- a/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/input_grad/moreh_groupnorm_backward_input_grad.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_groupnorm_backward/input_grad/moreh_groupnorm_backward_input_grad.cpp @@ -292,7 +292,7 @@ operation::ProgramWithCallbacks moreh_groupnorm_backward_input_grad_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = output_grad_buffer->address(); runtime_args[2] = input_buffer->address(); runtime_args[4] = mean_buffer->address(); @@ -300,13 +300,11 @@ operation::ProgramWithCallbacks moreh_groupnorm_backward_input_grad_impl( if (gamma_buffer != nullptr) { runtime_args[8] = gamma_buffer->address(); } - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = input_grad_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp b/tt_eager/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp index f379a4f462b..0645bd41cf2 100644 --- a/tt_eager/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp @@ -386,7 +386,7 @@ operation::ProgramWithCallbacks moreh_layernorm_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); if (gamma_buffer != nullptr) { runtime_args[6] = gamma_buffer->address(); @@ -394,11 +394,10 @@ operation::ProgramWithCallbacks moreh_layernorm_impl( if (beta_buffer != nullptr) { runtime_args[7] = beta_buffer->address(); } - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = ouput_buffer->address(); if (mean_buffer != nullptr) { runtime_args[1] = mean_buffer->address(); @@ -406,7 +405,6 @@ operation::ProgramWithCallbacks moreh_layernorm_impl( if (rstd_buffer != nullptr) { runtime_args[2] = rstd_buffer->address(); } - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp b/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp index dae24fe6fcb..d25ec17a1ad 100644 --- a/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp @@ -278,23 +278,21 @@ operation::ProgramWithCallbacks moreh_layernorm_backward_gamma_beta_grad_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = output_grad_buffer->address(); runtime_args[1] = input_buffer->address(); runtime_args[2] = mean_buffer->address(); runtime_args[3] = rstd_buffer->address(); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); if (gamma_grad_buffer != nullptr) { runtime_args[0] = gamma_grad_buffer->address(); } if (beta_grad_buffer != nullptr) { runtime_args[1] = beta_grad_buffer->address(); } - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp b/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp index 7d4575802a2..ea191cb71b4 100644 --- a/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp @@ -324,7 +324,7 @@ operation::ProgramWithCallbacks moreh_layernorm_backward_input_grad_impl( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = output_grad_buffer->address(); runtime_args[1] = input_buffer->address(); runtime_args[2] = mean_buffer->address(); @@ -332,13 +332,11 @@ operation::ProgramWithCallbacks moreh_layernorm_backward_input_grad_impl( if (gamma_buffer != nullptr) { runtime_args[4] = gamma_buffer->address(); } - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = input_grad_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_nc/moreh_mean_nc.cpp b/tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_nc/moreh_mean_nc.cpp index 51e20000bf1..9114b8a58ff 100644 --- a/tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_nc/moreh_mean_nc.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_nc/moreh_mean_nc.cpp @@ -175,15 +175,13 @@ operation::ProgramWithCallbacks moreh_mean_nc(const Tensor &input, const Tensor for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = input_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward.cpp b/tt_eager/tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward.cpp index f49a2b862a6..bb2e2b6a45d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward.cpp @@ -200,15 +200,13 @@ operation::ProgramWithCallbacks moreh_mean_backward_program(const Tensor &output for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_grad_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp index 32e9e223a6a..d0d0ef84e48 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp @@ -168,15 +168,13 @@ operation::ProgramWithCallbacks moreh_nll_loss_step1_impl(const Tensor &input, c CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp index 5e21e0595c6..d687a6d5f91 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp @@ -176,15 +176,13 @@ operation::ProgramWithCallbacks moreh_nll_loss_step2_impl(const Tensor &input, c CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_h/moreh_norm_h.cpp b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_h/moreh_norm_h.cpp index 6b93956f7de..51a17f0cef1 100644 --- a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_h/moreh_norm_h.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_h/moreh_norm_h.cpp @@ -221,17 +221,15 @@ operation::ProgramWithCallbacks moreh_norm_h_impl(const Tensor &input, float p, CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); runtime_args[2] = *reinterpret_cast(&decimal); runtime_args[3] = *reinterpret_cast(&recip_p_decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { @@ -243,12 +241,11 @@ operation::ProgramWithCallbacks moreh_norm_h_impl(const Tensor &input, float p, } else { TT_THROW("Core not in specified core ranges."); } - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[3] = floored_p; runtime_args[4] = static_cast(p_is_negative); runtime_args[5] = floored_recip_p; runtime_args[6] = static_cast(recip_p_is_negative); - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_other/moreh_norm_other.cpp b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_other/moreh_norm_other.cpp index 73704e360b6..5d9b734fa80 100644 --- a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_other/moreh_norm_other.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_other/moreh_norm_other.cpp @@ -225,17 +225,15 @@ operation::ProgramWithCallbacks moreh_norm_other_impl(const Tensor &input, float CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); runtime_args[2] = *reinterpret_cast(&decimal); runtime_args[3] = *reinterpret_cast(&recip_p_decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { @@ -247,12 +245,11 @@ operation::ProgramWithCallbacks moreh_norm_other_impl(const Tensor &input, float } else { TT_THROW("Core not in specified core ranges."); } - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[2] = floored_p; runtime_args[3] = static_cast(p_is_negative); runtime_args[4] = floored_recip_p; runtime_args[5] = static_cast(recip_p_is_negative); - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_w/moreh_norm_w.cpp b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_w/moreh_norm_w.cpp index 894e11f95a9..c6c3e17ba98 100644 --- a/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_w/moreh_norm_w.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_norm/moreh_norm_w/moreh_norm_w.cpp @@ -220,17 +220,15 @@ operation::ProgramWithCallbacks moreh_norm_w_impl(const Tensor &input, float p, CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); runtime_args[2] = *reinterpret_cast(&decimal); runtime_args[3] = *reinterpret_cast(&recip_p_decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { @@ -242,12 +240,11 @@ operation::ProgramWithCallbacks moreh_norm_w_impl(const Tensor &input, float p, } else { TT_THROW("Core not in specified core ranges."); } - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[3] = floored_p; runtime_args[4] = static_cast(p_is_negative); runtime_args[5] = floored_recip_p; runtime_args[6] = static_cast(recip_p_is_negative); - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward.cpp b/tt_eager/tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward.cpp index 8fcabd0748f..33b9ebd164f 100644 --- a/tt_eager/tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward.cpp @@ -274,18 +274,16 @@ operation::ProgramWithCallbacks moreh_norm_backward_( CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); runtime_args[0] = input_buffer->address(); runtime_args[2] = output_buffer->address(); runtime_args[4] = output_grad_buffer->address(); runtime_args[6] = *reinterpret_cast(&decimal); - SetRuntimeArgs(program, reader_kernels_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core); runtime_args[0] = input_grad_buffer->address(); - SetRuntimeArgs(program, writer_kernels_id, core, runtime_args); } { @@ -297,12 +295,11 @@ operation::ProgramWithCallbacks moreh_norm_backward_( } else { TT_THROW("Core not in specified core ranges."); } - auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core); runtime_args[5] = floored_p; runtime_args[6] = static_cast(p_is_negative); runtime_args[7] = floored_p_minus_one; runtime_args[8] = static_cast(p_minus_one_is_negative); - SetRuntimeArgs(program, compute_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp b/tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp index e6025ce4452..7567cc94100 100644 --- a/tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp @@ -219,7 +219,7 @@ operation::ProgramWithCallbacks moreh_sgd_( } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_ids, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_ids, core); runtime_args[0] = param_out_address; if (has_momentum_buffer) { diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp index f3921b28a30..79e0b62a5e0 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp @@ -150,15 +150,13 @@ operation::ProgramWithCallbacks moreh_softmax_c_large(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp index 3f878ba0a1f..abcb6b19409 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp @@ -144,15 +144,13 @@ operation::ProgramWithCallbacks moreh_softmax_h_large(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp index 4ab132eca6a..77523098d4c 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp @@ -166,15 +166,13 @@ operation::ProgramWithCallbacks moreh_softmax_h_small(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp index cede75c98a8..f1ae31c7dd6 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp @@ -145,15 +145,13 @@ operation::ProgramWithCallbacks moreh_softmax_w_large(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp index c37530a6cdf..bf90b8d47b0 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp @@ -166,15 +166,13 @@ operation::ProgramWithCallbacks moreh_softmax_w_small(const Tensor &input, const CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp index 90682747295..5752781a893 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp @@ -157,16 +157,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_c_large(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp index 423f7709696..859867d17f0 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp @@ -152,16 +152,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_large(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp index bdb1a941271..44df2758698 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp @@ -174,16 +174,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_small(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp index 4f316e4d921..78ce4ceecfa 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp @@ -152,16 +152,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_large(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp index cf29f8ce9d9..a834f5e4acd 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp @@ -175,16 +175,14 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_small(const Tensor &out CoreCoord core = {icore / core_h, icore % core_h}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = output_dram_buffer->address(); runtime_args[1] = output_grad_dram_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = input_grad_dram_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/non_zero_indices/single_core/non_zero_indices_op_single_core.cpp b/tt_eager/tt_dnn/op_library/non_zero_indices/single_core/non_zero_indices_op_single_core.cpp index 0f4451f54ed..65d72e96387 100644 --- a/tt_eager/tt_dnn/op_library/non_zero_indices/single_core/non_zero_indices_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/non_zero_indices/single_core/non_zero_indices_op_single_core.cpp @@ -106,14 +106,13 @@ operation::ProgramWithCallbacks non_zero_indices_single_core(const Tensor &input uint32_t alignment_base = 32/input.element_size(); uint32_t aligned_elements = div_up(input.get_legacy_shape()[-1] , alignment_base) * alignment_base; uint32_t actual_elements = input.get_legacy_shape()[-1]; - auto runtime_args = tt_metal::GetRuntimeArgs(program, kernel_id, core); + auto& runtime_args = tt_metal::GetRuntimeArgs(program, kernel_id, core); runtime_args[0] = input.buffer()->address(); runtime_args[1] = output_0.buffer()->address(); runtime_args[2] = output_1.buffer()->address(); runtime_args[3] = aligned_elements; runtime_args[4] = actual_elements; runtime_args[5] = input.element_size(); - tt_metal::SetRuntimeArgs(program, kernel_id, core, runtime_args); }; return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; } diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp index 78781395077..88d499a8460 100644 --- a/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp +++ b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp @@ -180,15 +180,13 @@ operation::ProgramWithCallbacks prod_nc_format(const Tensor &input, const Tensor for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = input_buffer->address(); - SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); } { - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = output_buffer->address(); - SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); } } }; diff --git a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp index 7e6267c814d..cbe78eafd92 100644 --- a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp @@ -877,12 +877,10 @@ operation::ProgramWithCallbacks reshard_multi_core(const Tensor& input, Tensor& auto all_cores = output_shard_spec.grid; auto cores = corerange_to_cores(all_cores, std::nullopt, output_shard_spec.orientation == ShardOrientation::ROW_MAJOR); for (auto core: cores) { - auto runtime_args_0 = GetRuntimeArgs(program, kernel_id_0, core); - auto runtime_args_1 = GetRuntimeArgs(program, kernel_id_1, core); + auto &runtime_args_0 = GetRuntimeArgs(program, kernel_id_0, core); + auto &runtime_args_1 = GetRuntimeArgs(program, kernel_id_1, core); runtime_args_0[grid.x + grid.y] = input.buffer()->address(); runtime_args_1[grid.x + grid.y] = input.buffer()->address(); - tt_metal::SetRuntimeArgs(program, kernel_id_0, core, runtime_args_0); - tt_metal::SetRuntimeArgs(program, kernel_id_1, core, runtime_args_1); } UpdateDynamicCircularBufferAddress(program, cb_dst0, *output.buffer()); }; diff --git a/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp index c9724c9950c..ce3d43d58c4 100644 --- a/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp @@ -268,9 +268,9 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( uint32_t num_padded_Yt = (num_total_Yt - num_unpadded_Yt) * num_total_Xt; const auto set_common_reader_args = [&]( - std::vector & reader_common_args, - uint32_t * num_unpadded_tiles_per_dim, - uint32_t * num_padded_tiles_per_dim) __attribute__((always_inline)) { + uint32_t* reader_common_args, + uint32_t* num_unpadded_tiles_per_dim, + uint32_t* num_padded_tiles_per_dim) __attribute__((always_inline)) { reader_common_args[0] = input_buffer->address(); num_unpadded_tiles_per_dim[0] = num_unpadded_Xt; num_unpadded_tiles_per_dim[1] = num_unpadded_Yt; @@ -289,7 +289,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( }; const auto set_reader_rt_args = [&]( - std::vector & reader_rt_args, + uint32_t* reader_rt_args, const uint32_t* num_unpadded_tiles_per_dim, const uint32_t* num_padded_tiles_per_dim, const uint32_t& num_tiles_per_core, @@ -311,14 +311,15 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( std::vector reader_common_args(1 + num_dims * 2); uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; - set_common_reader_args(reader_common_args, num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } auto& reader_common_args = GetCommonRuntimeArgs(program, unary_reader_kernel_id); uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; if constexpr (!initialize_args) { - set_common_reader_args(reader_common_args, num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start); @@ -352,7 +353,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( if constexpr (initialize_args) { std::vector reader_kernel_args(2 + num_dims); set_reader_rt_args( - reader_kernel_args, + reader_kernel_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim, num_tiles_per_core, @@ -362,7 +363,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( } else { auto& reader_kernel_args = reader_kernel_args_by_core[core.x][core.y]; set_reader_rt_args( - reader_kernel_args, + reader_kernel_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim, num_tiles_per_core, diff --git a/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp b/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp index 8cceba36fa7..e7672c18b12 100644 --- a/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp @@ -185,7 +185,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( uint32_t num_padded_Yt = (num_total_Yt - num_unpadded_Yt) * num_total_Xt; const auto set_common_reader_args = [&]( - std::vector & reader_common_args, + uint32_t * reader_common_args, uint32_t * num_unpadded_tiles_per_dim, uint32_t * num_padded_tiles_per_dim) __attribute__((always_inline)) { reader_common_args[0] = input_buffer->address(); @@ -206,7 +206,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( }; const auto set_reader_rt_args = [&]( - std::vector & reader_rt_args, + uint32_t * reader_rt_args, const uint32_t* num_unpadded_tiles_per_dim, const uint32_t* num_padded_tiles_per_dim, const uint32_t& num_tiles_per_core, @@ -228,14 +228,15 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( std::vector reader_common_args(1 + num_dims * 2); uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; - set_common_reader_args(reader_common_args, num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } auto& reader_common_args = GetCommonRuntimeArgs(program, unary_reader_kernel_id); uint32_t* num_unpadded_tiles_per_dim = reader_common_args.data() + 1; uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; if constexpr (!initialize_args) { - set_common_reader_args(reader_common_args, num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start); @@ -243,7 +244,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( if constexpr (initialize_args) { std::vector reader_kernel_args(2 + num_dims); set_reader_rt_args( - reader_kernel_args, + reader_kernel_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim, num_unpadded_tiles, @@ -253,7 +254,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( } else { auto& reader_kernel_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); set_reader_rt_args( - reader_kernel_args, + reader_kernel_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim, num_unpadded_tiles, diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index 0ba0cb9d70f..bf1188d5603 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -331,10 +331,24 @@ void SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const */ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector &runtime_args); +/** + * Set common (shared by all cores) runtime args for a kernel that are sent to all cores during runtime. This API needs to be called to update the common runtime args for the kernel. + * Maximum of 255 allowed runtime args per core (unique and common runtime args count toward same limit). + * + * Return value: void + * + * | Argument | Description | Type | Valid Range | Required | + * |--------------|------------------------------------------------------------------------|--------------------------------------------------------|---------------------------------------------------------------------|----------| + * | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes | + * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | + * | runtime_args | The runtime args to be written | const RuntimeArgsData & | | Yes | + */ +void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args); + /** * Get the runtime args for a kernel. * - * Return value: std::vector & + * Return value: uint32_t * * * | Argument | Description | Type | Valid Range | Required | * |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------| @@ -342,46 +356,32 @@ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | * | logical_core | The location of the Tensix core where the runtime args will be written | const CoreCoord & | Any logical Tensix core coordinate | Yes | */ -std::vector& GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core); +RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core); /** * Get the runtime args for a kernel. * - * Return value: std::vector< std::vector< std::vector> > & + * Return value: std::vector< std::vector< RuntimeArgsData > > & * * | Argument | Description | Type | Valid Range | Required | * |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------| * | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes | * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | */ -std::vector< std::vector< std::vector> > & GetRuntimeArgs(const Program &program, KernelHandle kernel_id); +std::vector< std::vector< RuntimeArgsData > > & GetRuntimeArgs(const Program &program, KernelHandle kernel_id); /** * Get the common runtime args for a kernel. + * Note that you must call SetCommonRuntimeArgs after updating the returned value to propagate the update. * - * Return value: std::vector & + * Return value: RuntimeArgsData & * * | Argument | Description | Type | Valid Range | Required | * |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------| * | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes | * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | */ -std::vector& GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id); - -/** - * Update specific entries of the runtime args vector for a kernel using the command queue. This API must be used when Asynchronous Command Queue Mode is enabled. - * - * Return Value: void - * - * | Argument | Description | Type | Valid Range | Required | - * |--------------|------------------------------------------------------------------------|-------------------------------|--------------------------------------------------------------|----------| - * | cq | The command queue used to send the runtime args update | CommandQueue & | | Yes | - * | kernel | The kernel for which the runtime args must be updated | std::shared_ptr | | Yes | - * | core_coord | The core receiving the runtime args update | const CoreCoord & | A single core running the kernel | Yes | - * | update_idx | The runtime arg vector indices that must be updated | std::vector & | Each index in this vector must be less than num runtime args | Yes | - * | runtime_args | Updated runtime args | std::shared_ptr | 1:1 Mapping between each entry and the indices in update_idx | Yes | - */ -void UpdateRuntimeArgs(Device* device, const std::shared_ptr kernel, const CoreCoord &core_coord, std::vector &update_idx, std::shared_ptr runtime_args); +RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id); /** * Reads a buffer from the device diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 913eca885fe..a9a45955508 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -38,6 +38,7 @@ std::condition_variable finish_cv; namespace tt::tt_metal { +// TODO: Delete entries when programs are deleted to save memory thread_local std::unordered_map> EnqueueProgramCommand::runtime_args_command_sequences = {}; uint32_t get_noc_unicast_encoding(const CoreCoord &coord) { return NOC_XY_ENCODING(NOC_X(coord.x), NOC_Y(coord.y)); } @@ -307,11 +308,11 @@ void EnqueueProgramCommand::assemble_preamble_commands() { template void generate_dispatch_write_packed( std::vector& runtime_args_command_sequences, - std::unordered_map>& cmd_mapping, const uint32_t& l1_arg_base_addr, const std::vector& sub_cmds, const std::vector>& rt_data_and_sizes, const uint32_t& max_runtime_args_len, + std::vector>& rt_args_data, const uint32_t max_prefetch_command_size, const uint32_t id) { static_assert( @@ -352,8 +353,8 @@ void generate_dispatch_write_packed( uint32_t num_packed_cmds = std::min(num_packed_cmds_in_seq, max_packed_cmds); uint32_t rt_payload_sizeB = get_runtime_payload_sizeB(num_packed_cmds, max_runtime_args_len, unicast); uint32_t cmd_sequence_sizeB = align(sizeof(CQPrefetchCmd) + rt_payload_sizeB, PCIE_ALIGNMENT); - HostMemDeviceCommand command_sequence(cmd_sequence_sizeB); - command_sequence.add_dispatch_write_packed( + runtime_args_command_sequences.emplace_back(cmd_sequence_sizeB); + runtime_args_command_sequences.back().add_dispatch_write_packed( num_packed_cmds, l1_arg_base_addr, max_runtime_args_len * sizeof(uint32_t), @@ -361,12 +362,10 @@ void generate_dispatch_write_packed( sub_cmds, rt_data_and_sizes, offset_idx); - runtime_args_command_sequences.emplace_back(command_sequence); uint32_t data_offset = (uint32_t)get_runtime_args_data_offset(num_packed_cmds, max_runtime_args_len, unicast); const uint32_t data_inc = align(max_runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT); for (uint32_t i = offset_idx; i < offset_idx + num_packed_cmds; ++i) { - cmd_mapping[(uint64_t)id << 32 | sub_cmds[i].noc_xy_addr] = { - runtime_args_command_sequences.size() - 1, data_offset}; + rt_args_data[i].get().rt_args_data = (uint32_t *)((char *)runtime_args_command_sequences.back().data() + data_offset); data_offset += data_inc; } num_packed_cmds_in_seq -= num_packed_cmds; @@ -419,6 +418,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { uint32_t num_dsts = unique_processor_to_l1_arg_base_addr.size(); std::vector> unique_sub_cmds(num_dsts); std::vector>> unique_rt_data_and_sizes(num_dsts); + std::vector>> unique_rt_args_data(num_dsts); std::vector unique_max_runtime_args_len(num_dsts, 0); std::vector>> common_sub_cmds(program.num_kernels()); std::vector>> common_rt_data_and_sizes(program.num_kernels()); + std::vector>> common_rt_args_data(program.num_kernels()); std::vector common_max_runtime_args_len(program.num_kernels(), 0); std::vector common_processor_to_l1_arg_base_addr(program.num_kernels()); @@ -445,12 +446,14 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { unique_processors.insert(processor_idx); unique_sub_cmds[processor_idx].reserve(kernel->cores_with_runtime_args().size()); unique_rt_data_and_sizes[processor_idx].reserve(kernel->cores_with_runtime_args().size()); + unique_rt_args_data[processor_idx].reserve(kernel->cores_with_runtime_args().size()); for (const auto& core_coord : kernel->cores_with_runtime_args()) { // can make a vector of unicast encodings here CoreCoord physical_core = device->physical_core_from_logical_core(core_coord, kernel->get_kernel_core_type()); uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core); const auto& runtime_args_data = kernel->runtime_args(core_coord); + unique_rt_args_data[processor_idx].emplace_back(kernel->runtime_args_data(core_coord)); // 2, 17, could be differnet len here unique_sub_cmds[processor_idx].emplace_back( @@ -475,6 +478,9 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { auto& unicast_sub_cmd = std::get>(common_sub_cmds[kernel_id]); unicast_sub_cmd.reserve(kernel->logical_cores().size()); + common_rt_data_and_sizes[kernel_id].reserve(kernel->logical_cores().size()); + common_rt_args_data[kernel_id].reserve(kernel->logical_cores().size()); + uint32_t i = 0; for (auto& core_coord : kernel->logical_cores()) { // can make a vector of unicast encodings here CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord); @@ -483,6 +489,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding}); common_rt_data_and_sizes[kernel_id].emplace_back( common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); + common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]); } } else { vector> dst_noc_multicast_info = @@ -493,11 +500,14 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { auto& multicast_sub_cmd = std::get>(common_sub_cmds[kernel_id]); multicast_sub_cmd.reserve(dst_noc_multicast_info.size()); + common_rt_data_and_sizes[kernel_id].reserve(dst_noc_multicast_info.size()); + uint32_t i = 0; for (const auto& mcast_dests : dst_noc_multicast_info) { multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{ .noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second}); common_rt_data_and_sizes[kernel_id].emplace_back( common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); + common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]); } } common_max_runtime_args_len[kernel_id] = @@ -516,11 +526,11 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { for (const uint32_t& processor_idx : unique_processors) { generate_dispatch_write_packed( this->runtime_args_command_sequences[program.id], - program.command_indices.processor_to_cmd_mapping, unique_processor_to_l1_arg_base_addr[processor_idx], unique_sub_cmds[processor_idx], unique_rt_data_and_sizes[processor_idx], unique_max_runtime_args_len[processor_idx], + unique_rt_args_data[processor_idx], max_prefetch_command_size, processor_idx); } @@ -529,62 +539,16 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { [&](auto&& sub_cmds) { generate_dispatch_write_packed( this->runtime_args_command_sequences[program.id], - program.command_indices.kernel_to_cmd_mapping, common_processor_to_l1_arg_base_addr[kernel_id], sub_cmds, common_rt_data_and_sizes[kernel_id], common_max_runtime_args_len[kernel_id], + common_rt_args_data[kernel_id], max_prefetch_command_size, kernel_id); }, common_sub_cmds[kernel_id]); } - } else { - for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) { - auto kernel = detail::GetKernel(program, kernel_id); - - if (!kernel->cores_with_runtime_args().empty()) { - uint32_t processor_idx = - static_cast::type>(kernel->processor()); - for (const auto& core_coord : kernel->cores_with_runtime_args()) { - uint32_t noc_xy_encoding = get_noc_unicast_encoding( - this->device->physical_core_from_logical_core(core_coord, kernel->get_kernel_core_type())); - const auto& data_loc = - program.command_indices - .processor_to_cmd_mapping[(uint64_t)processor_idx << 32 | noc_xy_encoding]; - const auto& runtime_args_data = kernel->runtime_args(core_coord); - this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence( - data_loc.second, runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t)); - } - } - - const auto& common_rt_args = kernel->common_runtime_args(); - if (common_rt_args.size() > 0) { - if (kernel->get_kernel_core_type() == CoreType::ETH) { - for (const auto& core_coord : kernel->logical_cores()) { - uint32_t noc_xy_encoding = - get_noc_unicast_encoding(this->device->ethernet_core_from_logical_core(core_coord)); - const auto& data_loc = - program.command_indices.kernel_to_cmd_mapping[(uint64_t)kernel_id << 32 | noc_xy_encoding]; - const auto& runtime_args_data = kernel->runtime_args(core_coord); - this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence( - data_loc.second, runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t)); - } - } else { - for (const auto& core_range : kernel->logical_coreranges()) { - CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start); - CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end); - - uint32_t noc_multicast_encoding = get_noc_multcast_encoding(physical_start, physical_end); - const auto& data_loc = - program.command_indices - .kernel_to_cmd_mapping[(uint64_t)kernel_id << 32 | noc_multicast_encoding]; - this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence( - data_loc.second, common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); - } - } - } - } } } @@ -1867,40 +1831,6 @@ void EnqueueAddBufferToProgram(CommandQueue& cq, std::variant resolved_runtime_args = {}; - resolved_runtime_args.reserve((*runtime_args_md.runtime_args_ptr).size()); - - for (const auto& arg : *(runtime_args_md.runtime_args_ptr)) { - std::visit([&resolved_runtime_args] (auto&& a) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - resolved_runtime_args.push_back(a -> address()); - } else { - resolved_runtime_args.push_back(a); - } - }, arg); - } - auto& kernel_runtime_args = runtime_args_md.kernel->runtime_args(runtime_args_md.core_coord); - for (const auto& idx : runtime_args_md.update_idx) { - kernel_runtime_args[idx] = resolved_runtime_args[idx]; - } -} - -void EnqueueUpdateRuntimeArgs(CommandQueue& cq, const std::shared_ptr kernel, const CoreCoord &core_coord, std::vector &update_idx, std::shared_ptr runtime_args_ptr, bool blocking) { - auto runtime_args_md = RuntimeArgsMetadata { - .core_coord = core_coord, - .runtime_args_ptr = runtime_args_ptr, - .kernel = kernel, - .update_idx = update_idx, - }; - cq.run_command(CommandInterface { - .type = EnqueueCommandType::UPDATE_RUNTIME_ARGS, - .blocking = blocking, - .runtime_args_md = runtime_args_md, - }); -} - void EnqueueSetRuntimeArgsImpl(const RuntimeArgsMetadata& runtime_args_md) { std::vector resolved_runtime_args = {}; resolved_runtime_args.reserve((*runtime_args_md.runtime_args_ptr).size()); @@ -2385,10 +2315,6 @@ void CommandQueue::run_command_impl(const CommandInterface& command) { TT_ASSERT(command.runtime_args_md.has_value(), "Must provide RuntimeArgs Metdata!"); EnqueueSetRuntimeArgsImpl(command.runtime_args_md.value()); break; - case EnqueueCommandType::UPDATE_RUNTIME_ARGS: - TT_ASSERT(command.runtime_args_md.has_value(), "Must provide RuntimeArgs Metdata!"); - EnqueueUpdateRuntimeArgsImpl(command.runtime_args_md.value()); - break; case EnqueueCommandType::ADD_BUFFER_TO_PROGRAM: TT_ASSERT(command.buffer.has_value(), "Must provide a buffer!"); TT_ASSERT(command.program.has_value(), "Must provide a program!"); diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp index 10a12a329c8..7a413e0f8b3 100644 --- a/tt_metal/impl/dispatch/command_queue.hpp +++ b/tt_metal/impl/dispatch/command_queue.hpp @@ -39,7 +39,6 @@ enum class EnqueueCommandType { GET_BUF_ADDR, ADD_BUFFER_TO_PROGRAM, SET_RUNTIME_ARGS, - UPDATE_RUNTIME_ARGS, ENQUEUE_PROGRAM, ENQUEUE_TRACE, ENQUEUE_RECORD_EVENT, @@ -634,7 +633,6 @@ void EnqueueAllocateBuffer(CommandQueue& cq, Buffer* buffer, bool bottom_up, boo void EnqueueDeallocateBuffer(CommandQueue& cq, Allocator& allocator, uint32_t device_address, BufferType buffer_type, bool blocking); void EnqueueGetBufferAddr(CommandQueue& cq, uint32_t* dst_buf_addr, const Buffer* buffer, bool blocking); void EnqueueSetRuntimeArgs(CommandQueue& cq, const std::shared_ptr kernel, const CoreCoord &core_coord, std::shared_ptr runtime_args_ptr, bool blocking); -void EnqueueUpdateRuntimeArgs(CommandQueue& cq, const std::shared_ptr kernel, const CoreCoord &core_coord, std::vector &update_idx, std::shared_ptr runtime_args_ptr, bool blocking); void EnqueueAddBufferToProgram(CommandQueue& cq, std::variant, std::shared_ptr> buffer, std::variant, std::shared_ptr> program, bool blocking); } // namespace tt::tt_metal diff --git a/tt_metal/impl/kernels/kernel.cpp b/tt_metal/impl/kernels/kernel.cpp index 17c9e0efe3a..7924839955c 100644 --- a/tt_metal/impl/kernels/kernel.cpp +++ b/tt_metal/impl/kernels/kernel.cpp @@ -41,6 +41,7 @@ Kernel::Kernel(const std::string &kernel_path_file_name, const CoreRangeSet &cor } } this->core_to_runtime_args_ = { max_x+1, std::vector< std::vector > (max_y+1, std::vector() ) }; + this->core_to_runtime_args_data_ = { max_x+1, std::vector< RuntimeArgsData > (max_y+1, RuntimeArgsData{} ) }; } std::string Kernel::name() const { @@ -144,27 +145,34 @@ std::string Kernel::compute_hash() const { ); } -void Kernel::update_runtime_arg( const CoreCoord &logical_core, size_t idx, uint32_t value){ - ZoneScoped; - auto & v = this->core_to_runtime_args_[logical_core.x][logical_core.y]; - TT_ASSERT( idx < v.size(), "Runtime arg offset {} for Core {} out of bounds", idx, logical_core.str()); - v[idx] = value; -} - std::vector& Kernel::runtime_args(const CoreCoord &logical_core) { // TODO (abhullar): Should this check only be enabled in debug mode? TT_FATAL( logical_core.x < this->core_to_runtime_args_.size() && logical_core.y < this->core_to_runtime_args_[logical_core.x].size(), "Cannot get runtime args for kernel {} that is not placed on core {}", this->name(), logical_core.str()); return this->core_to_runtime_args_[logical_core.x][logical_core.y]; } +RuntimeArgsData &Kernel::runtime_args_data(const CoreCoord &logical_core) { + // TODO (abhullar): Should this check only be enabled in debug mode? + TT_FATAL( logical_core.x < this->core_to_runtime_args_.size() && logical_core.y < this->core_to_runtime_args_[logical_core.x].size(), "Cannot get runtime args for kernel {} that is not placed on core {}", this->name(), logical_core.str()); + return this->core_to_runtime_args_data_[logical_core.x][logical_core.y]; +} + std::vector< std::vector< std::vector> > & Kernel::runtime_args() { return this->core_to_runtime_args_; } +std::vector< std::vector > & Kernel::runtime_args_data() { + return this->core_to_runtime_args_data_; +} + std::vector& Kernel::common_runtime_args() { return this->common_runtime_args_; } +std::vector & Kernel::common_runtime_args_data() { + return this->common_runtime_args_data_; +} + std::pair DataMovementKernel::get_runtime_args_range() const { std::pair arg_base_to_result_base; switch (this->config_.processor) { @@ -220,16 +228,23 @@ void Kernel::set_runtime_args(const CoreCoord &logical_core, const std::vectoris_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name()); // Keep state for validation, to be able to check from both set_runtime_args() and set_common_runtime_args() APIs. - if (runtime_args.size() > max_runtime_args_per_core_) { - max_runtime_args_per_core_ = runtime_args.size(); - core_with_max_runtime_args_ = logical_core; - } - this->validate_runtime_args_size(runtime_args.size(), this->common_runtime_args_.size(), logical_core); auto &set_rt_args = this->core_to_runtime_args_[logical_core.x][logical_core.y]; - TT_FATAL(set_rt_args.empty() or set_rt_args.size() == runtime_args.size(), "Illegal Runtime Args: Number of runtime args cannot be modified!"); - set_rt_args = runtime_args; - this->core_with_runtime_args_.insert( logical_core ); + // TODO: Only allow setting once + if (set_rt_args.empty()) { + if (runtime_args.size() > max_runtime_args_per_core_) { + max_runtime_args_per_core_ = runtime_args.size(); + core_with_max_runtime_args_ = logical_core; + } + this->validate_runtime_args_size(runtime_args.size(), this->common_runtime_args_.size(), logical_core); + set_rt_args = runtime_args; + this->core_to_runtime_args_data_[logical_core.x][logical_core.y] = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}; + this->core_with_runtime_args_.insert( logical_core ); + } else { + TT_FATAL(set_rt_args.size() == runtime_args.size(), "Illegal Runtime Args: Number of runtime args cannot be modified!"); + std::memcpy(this->core_to_runtime_args_data_[logical_core.x][logical_core.y].rt_args_data, runtime_args.data(), runtime_args.size() * sizeof(uint32_t)); + } + } void Kernel::set_common_runtime_args(const std::vector &common_runtime_args) { @@ -238,10 +253,26 @@ void Kernel::set_common_runtime_args(const std::vector &common_runtime // Should this check only be enabled in debug mode? // TT_FATAL(this->is_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name()); - this->validate_runtime_args_size(max_runtime_args_per_core_, common_runtime_args.size(), core_with_max_runtime_args_); auto &set_rt_args = this->common_runtime_args_; - TT_FATAL(set_rt_args.empty() or set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!"); + TT_FATAL(set_rt_args.empty(), "Illegal Common Runtime Args: Can only set common runtime args once. Get and modify args in place instead."); + this->validate_runtime_args_size(max_runtime_args_per_core_, common_runtime_args.size(), core_with_max_runtime_args_); set_rt_args = common_runtime_args; + this->common_runtime_args_data_ = std::vector(this->core_range_set_.ranges().size(), RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}); +} + +void Kernel::set_common_runtime_args(const RuntimeArgsData& common_runtime_args) { + + // TODO (abhullar): If we don't include this check then user can write runtime args to a core that the kernel is not placed on. + // Should this check only be enabled in debug mode? + // TT_FATAL(this->is_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name()); + + auto &set_rt_args = this->common_runtime_args_; + TT_FATAL(!set_rt_args.empty() and set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!"); + for (auto& rt_args_data : this->common_runtime_args_data_) { + if (common_runtime_args.data() != rt_args_data.data()) { + memcpy(rt_args_data.data(), common_runtime_args.data(), common_runtime_args.size() * sizeof(uint32_t)); + } + } } void DataMovementKernel::set_build_options(JitBuildOptions& build_options) const { diff --git a/tt_metal/impl/kernels/kernel.hpp b/tt_metal/impl/kernels/kernel.hpp index 87e3a7e6d4c..ea5cb9d4510 100644 --- a/tt_metal/impl/kernels/kernel.hpp +++ b/tt_metal/impl/kernels/kernel.hpp @@ -23,6 +23,37 @@ namespace tt_metal { using Config = std::variant; +struct RuntimeArgsData { + uint32_t * rt_args_data; + size_t rt_args_size; + + inline uint32_t & operator[](size_t index) { + TT_ASSERT(index < rt_args_size, "Index specified is larger than runtime args size"); + return this->rt_args_data[index]; + } + inline const uint32_t& operator[](size_t index) const { + TT_ASSERT(index < rt_args_size, "Index specified is larger than runtime args size"); + return this->rt_args_data[index]; + } + inline uint32_t & at(size_t index) { + TT_FATAL(index < rt_args_size, "Index specified is larger than runtime args size"); + return this->rt_args_data[index]; + } + inline const uint32_t& at(size_t index) const { + TT_FATAL(index < rt_args_size, "Index specified is larger than runtime args size"); + return this->rt_args_data[index]; + } + inline uint32_t * data() noexcept { + return rt_args_data; + } + inline const uint32_t * data() const noexcept { + return rt_args_data; + } + inline size_t size() const noexcept{ + return rt_args_size; + } +}; + class Kernel : public JitBuildSettings { public: Kernel(const std::string &kernel_path_file_name, const CoreRangeSet &core_range_set, const std::vector &compile_args, const std::map&defines); @@ -47,11 +78,12 @@ class Kernel : public JitBuildSettings { const std::set& cores_with_runtime_args() const { return core_with_runtime_args_; } - void update_runtime_arg( const CoreCoord &logical_core, size_t idx, uint32_t value); - std::vector & runtime_args(const CoreCoord &logical_core); + RuntimeArgsData & runtime_args_data(const CoreCoord &logical_core); std::vector< std::vector< std::vector> > & runtime_args(); + std::vector< std::vector< RuntimeArgsData > > & runtime_args_data(); std::vector & common_runtime_args(); + std::vector & common_runtime_args_data(); std::map defines() const { return defines_; } @@ -74,6 +106,7 @@ class Kernel : public JitBuildSettings { void validate_runtime_args_size(size_t num_unique_rt_args, size_t num_common_rt_args, const CoreCoord& logical_core); void set_runtime_args(const CoreCoord &logical_core, const std::vector &runtime_args); void set_common_runtime_args(const std::vector &runtime_args); + void set_common_runtime_args(const RuntimeArgsData& common_runtime_args); int get_watcher_kernel_id() { return watcher_kernel_id_; } @@ -96,7 +129,9 @@ class Kernel : public JitBuildSettings { uint16_t binary_size16_; std::vector compile_time_args_; std::vector< std::vector< std::vector> > core_to_runtime_args_; + std::vector< std::vector< RuntimeArgsData> > core_to_runtime_args_data_; std::vector common_runtime_args_; + std::vector common_runtime_args_data_; std::set core_with_runtime_args_; std::size_t max_runtime_args_per_core_; // For validation CoreCoord core_with_max_runtime_args_; // For validation diff --git a/tt_metal/impl/program/program_device_map.hpp b/tt_metal/impl/program/program_device_map.hpp index ae9fb53e504..288717f9aa9 100644 --- a/tt_metal/impl/program/program_device_map.hpp +++ b/tt_metal/impl/program/program_device_map.hpp @@ -42,8 +42,4 @@ struct ProgramTransferInfo { struct ProgramCommandIndices { std::uint32_t cb_configs_payload_start; // device_commands - // pair of cmd idx, rt arg offset - // Currently we only really need the base cmd idx since they are sequential, and the rt arg len is currently the same for all splits - std::unordered_map> processor_to_cmd_mapping; - std::unordered_map> kernel_to_cmd_mapping; }; diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index b3c4092598b..18752e7da23 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -949,25 +949,28 @@ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const } } - -std::vector & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { - TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - return detail::GetKernel(program, kernel_id)->runtime_args(logical_core); +void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args) { + ZoneScoped; + TT_FATAL( not CommandQueue::async_mode_set(), "This variant of SetCommonRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); + if (runtime_args.size() != 0) { + detail::GetKernel(program, kernel_id)->set_common_runtime_args(runtime_args); + } } -std::vector< std::vector< std::vector> > & GetRuntimeArgs(const Program &program, KernelHandle kernel_id) { + +RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - return detail::GetKernel(program, kernel_id)->runtime_args(); + return detail::GetKernel(program, kernel_id)->runtime_args_data(logical_core); } -std::vector & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) { +std::vector< std::vector< RuntimeArgsData> >& GetRuntimeArgs(const Program &program, KernelHandle kernel_id) { TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - return detail::GetKernel(program, kernel_id)->common_runtime_args(); + return detail::GetKernel(program, kernel_id)->runtime_args_data(); } -void UpdateRuntimeArgs(Device* device, const std::shared_ptr kernel, const CoreCoord &core_coord, std::vector &update_idx, std::shared_ptr runtime_args) { - detail::DispatchStateCheck(not device->using_slow_dispatch()); - EnqueueUpdateRuntimeArgs(device->command_queue(), kernel, core_coord, update_idx, runtime_args, false); +RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) { + TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); + return detail::GetKernel(program, kernel_id)->common_runtime_args_data().at(0); } } // namespace tt_metal