diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst index 678e9ffc51d..031807b082f 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst @@ -17,4 +17,6 @@ Runtime Arguments .. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector &runtime_args) +.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args) + .. doxygenfunction:: GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) diff --git a/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp index 1387791d805..ce3d43d58c4 100644 --- a/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp @@ -319,6 +319,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; if constexpr (!initialize_args) { set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start); diff --git a/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp b/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp index c4a32a76518..e7672c18b12 100644 --- a/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/single_core/unpad_op_single_core.cpp @@ -236,6 +236,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; if constexpr (!initialize_args) { set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); + SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start); diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index def697295cb..683b358a388 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -331,6 +331,20 @@ void SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const */ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector &runtime_args); +/** + * Set common (shared by all cores) runtime args for a kernel that are sent to all cores during runtime. This API needs to be called to update the common runtime args for the kernel. + * Maximum of 255 allowed runtime args per core (unique and common runtime args count toward same limit). + * + * Return value: void + * + * | Argument | Description | Type | Valid Range | Required | + * |--------------|------------------------------------------------------------------------|--------------------------------------------------------|---------------------------------------------------------------------|----------| + * | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes | + * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | + * | runtime_args | The runtime args to be written | const RuntimeArgsData & | | Yes | + */ +void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args); + /** * Get the runtime args for a kernel. * @@ -358,15 +372,16 @@ std::vector< std::vector< RuntimeArgsData > > & GetRuntimeArgs(const Program &pr /** * Get the common runtime args for a kernel. + * Note that you must call SetCommonRuntimeArgs after updating the returned value to propagate the update. * - * Return value: std::vector & + * Return value: RuntimeArgsData & * * | Argument | Description | Type | Valid Range | Required | * |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------| * | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes | * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | */ -std::vector & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id); +RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id); /** * Update specific entries of the runtime args vector for a kernel using the command queue. This API must be used when Asynchronous Command Queue Mode is enabled. diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index dc074b035ec..ecd91dc8927 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -38,6 +38,7 @@ std::condition_variable finish_cv; namespace tt::tt_metal { +// TODO: Delete entries when programs are deleted to save memory thread_local std::unordered_map> EnqueueProgramCommand::runtime_args_command_sequences = {}; uint32_t get_noc_unicast_encoding(const CoreCoord &coord) { return NOC_XY_ENCODING(NOC_X(coord.x), NOC_Y(coord.y)); } @@ -307,7 +308,6 @@ void EnqueueProgramCommand::assemble_preamble_commands() { template void generate_dispatch_write_packed( std::vector& runtime_args_command_sequences, - std::unordered_map>& cmd_mapping, const uint32_t& l1_arg_base_addr, const std::vector& sub_cmds, const std::vector>& rt_data_and_sizes, @@ -365,8 +365,6 @@ void generate_dispatch_write_packed( uint32_t data_offset = (uint32_t)get_runtime_args_data_offset(num_packed_cmds, max_runtime_args_len, unicast); const uint32_t data_inc = align(max_runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT); for (uint32_t i = offset_idx; i < offset_idx + num_packed_cmds; ++i) { - cmd_mapping[(uint64_t)id << 32 | sub_cmds[i].noc_xy_addr] = { - runtime_args_command_sequences.size() - 1, data_offset}; rt_args_data[i].get().rt_args_data = (uint32_t *)((char *)runtime_args_command_sequences.back().data() + data_offset); data_offset += data_inc; } @@ -482,6 +480,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { unicast_sub_cmd.reserve(kernel->logical_cores().size()); common_rt_data_and_sizes[kernel_id].reserve(kernel->logical_cores().size()); common_rt_args_data[kernel_id].reserve(kernel->logical_cores().size()); + uint32_t i = 0; for (auto& core_coord : kernel->logical_cores()) { // can make a vector of unicast encodings here CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord); @@ -490,7 +489,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding}); common_rt_data_and_sizes[kernel_id].emplace_back( common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); - common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()); + common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]); } } else { vector> dst_noc_multicast_info = @@ -502,12 +501,13 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { std::get>(common_sub_cmds[kernel_id]); multicast_sub_cmd.reserve(dst_noc_multicast_info.size()); common_rt_data_and_sizes[kernel_id].reserve(dst_noc_multicast_info.size()); + uint32_t i = 0; for (const auto& mcast_dests : dst_noc_multicast_info) { multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{ .noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second}); common_rt_data_and_sizes[kernel_id].emplace_back( common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); - common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()); + common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]); } } common_max_runtime_args_len[kernel_id] = @@ -526,7 +526,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { for (const uint32_t& processor_idx : unique_processors) { generate_dispatch_write_packed( this->runtime_args_command_sequences[program.id], - program.command_indices.processor_to_cmd_mapping, unique_processor_to_l1_arg_base_addr[processor_idx], unique_sub_cmds[processor_idx], unique_rt_data_and_sizes[processor_idx], @@ -540,7 +539,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { [&](auto&& sub_cmds) { generate_dispatch_write_packed( this->runtime_args_command_sequences[program.id], - program.command_indices.kernel_to_cmd_mapping, common_processor_to_l1_arg_base_addr[kernel_id], sub_cmds, common_rt_data_and_sizes[kernel_id], @@ -551,36 +549,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { }, common_sub_cmds[kernel_id]); } - } else { - for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) { - auto kernel = detail::GetKernel(program, kernel_id); - const auto& common_rt_args = kernel->common_runtime_args(); - if (common_rt_args.size() > 0) { - if (kernel->get_kernel_core_type() == CoreType::ETH) { - for (const auto& core_coord : kernel->logical_cores()) { - uint32_t noc_xy_encoding = - get_noc_unicast_encoding(this->device->ethernet_core_from_logical_core(core_coord)); - const auto& data_loc = - program.command_indices.kernel_to_cmd_mapping[(uint64_t)kernel_id << 32 | noc_xy_encoding]; - const auto& runtime_args_data = kernel->runtime_args(core_coord); - this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence( - data_loc.second, runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t)); - } - } else { - for (const auto& core_range : kernel->logical_coreranges()) { - CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start); - CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end); - - uint32_t noc_multicast_encoding = get_noc_multcast_encoding(physical_start, physical_end); - const auto& data_loc = - program.command_indices - .kernel_to_cmd_mapping[(uint64_t)kernel_id << 32 | noc_multicast_encoding]; - this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence( - data_loc.second, common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); - } - } - } - } } } diff --git a/tt_metal/impl/kernels/kernel.cpp b/tt_metal/impl/kernels/kernel.cpp index d0943e8722f..2b1985c4e53 100644 --- a/tt_metal/impl/kernels/kernel.cpp +++ b/tt_metal/impl/kernels/kernel.cpp @@ -176,7 +176,7 @@ std::vector& Kernel::common_runtime_args() { return this->common_runtime_args_; } -RuntimeArgsData & Kernel::common_runtime_args_data() { +std::vector & Kernel::common_runtime_args_data() { return this->common_runtime_args_data_; } @@ -235,20 +235,20 @@ void Kernel::set_runtime_args(const CoreCoord &logical_core, const std::vectoris_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name()); // Keep state for validation, to be able to check from both set_runtime_args() and set_common_runtime_args() APIs. - if (runtime_args.size() > max_runtime_args_per_core_) { - max_runtime_args_per_core_ = runtime_args.size(); - core_with_max_runtime_args_ = logical_core; - } - this->validate_runtime_args_size(runtime_args.size(), this->common_runtime_args_.size(), logical_core); auto &set_rt_args = this->core_to_runtime_args_[logical_core.x][logical_core.y]; // TODO: Only allow setting once - TT_FATAL(set_rt_args.empty() or set_rt_args.size() == runtime_args.size(), "Illegal Runtime Args: Number of runtime args cannot be modified!"); if (set_rt_args.empty()) { + if (runtime_args.size() > max_runtime_args_per_core_) { + max_runtime_args_per_core_ = runtime_args.size(); + core_with_max_runtime_args_ = logical_core; + } + this->validate_runtime_args_size(runtime_args.size(), this->common_runtime_args_.size(), logical_core); set_rt_args = runtime_args; this->core_to_runtime_args_data_[logical_core.x][logical_core.y] = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}; this->core_with_runtime_args_.insert( logical_core ); } else { + TT_FATAL(set_rt_args.size() == runtime_args.size(), "Illegal Runtime Args: Number of runtime args cannot be modified!"); std::memcpy(this->core_to_runtime_args_data_[logical_core.x][logical_core.y].rt_args_data, runtime_args.data(), runtime_args.size() * sizeof(uint32_t)); } @@ -260,11 +260,26 @@ void Kernel::set_common_runtime_args(const std::vector &common_runtime // Should this check only be enabled in debug mode? // TT_FATAL(this->is_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name()); - this->validate_runtime_args_size(max_runtime_args_per_core_, common_runtime_args.size(), core_with_max_runtime_args_); auto &set_rt_args = this->common_runtime_args_; - TT_FATAL(set_rt_args.empty() or set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!"); + TT_FATAL(set_rt_args.empty(), "Illegal Common Runtime Args: Can only set common runtime args once. Get and modify args in place instead."); + this->validate_runtime_args_size(max_runtime_args_per_core_, common_runtime_args.size(), core_with_max_runtime_args_); set_rt_args = common_runtime_args; - this->common_runtime_args_data_ = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}; + this->common_runtime_args_data_ = std::vector(this->core_range_set_.ranges().size(), RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}); +} + +void Kernel::set_common_runtime_args(const RuntimeArgsData& common_runtime_args) { + + // TODO (abhullar): If we don't include this check then user can write runtime args to a core that the kernel is not placed on. + // Should this check only be enabled in debug mode? + // TT_FATAL(this->is_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name()); + + auto &set_rt_args = this->common_runtime_args_; + TT_FATAL(!set_rt_args.empty() and set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!"); + for (auto& rt_args_data : this->common_runtime_args_data_) { + if (common_runtime_args.data() != rt_args_data.data()) { + memcpy(rt_args_data.data(), common_runtime_args.data(), common_runtime_args.size() * sizeof(uint32_t)); + } + } } void DataMovementKernel::set_build_options(JitBuildOptions& build_options) const { diff --git a/tt_metal/impl/kernels/kernel.hpp b/tt_metal/impl/kernels/kernel.hpp index 1b3bcc1b05b..16919ce5310 100644 --- a/tt_metal/impl/kernels/kernel.hpp +++ b/tt_metal/impl/kernels/kernel.hpp @@ -81,11 +81,11 @@ class Kernel : public JitBuildSettings { void update_runtime_arg( const CoreCoord &logical_core, size_t idx, uint32_t value); std::vector & runtime_args(const CoreCoord &logical_core); - RuntimeArgsData& runtime_args_data(const CoreCoord &logical_core); + RuntimeArgsData & runtime_args_data(const CoreCoord &logical_core); std::vector< std::vector< std::vector> > & runtime_args(); std::vector< std::vector< RuntimeArgsData > > & runtime_args_data(); std::vector & common_runtime_args(); - RuntimeArgsData& common_runtime_args_data(); + std::vector & common_runtime_args_data(); std::map defines() const { return defines_; } @@ -108,6 +108,7 @@ class Kernel : public JitBuildSettings { void validate_runtime_args_size(size_t num_unique_rt_args, size_t num_common_rt_args, const CoreCoord& logical_core); void set_runtime_args(const CoreCoord &logical_core, const std::vector &runtime_args); void set_common_runtime_args(const std::vector &runtime_args); + void set_common_runtime_args(const RuntimeArgsData& common_runtime_args); int get_watcher_kernel_id() { return watcher_kernel_id_; } @@ -132,7 +133,7 @@ class Kernel : public JitBuildSettings { std::vector< std::vector< std::vector> > core_to_runtime_args_; std::vector< std::vector< RuntimeArgsData> > core_to_runtime_args_data_; std::vector common_runtime_args_; - RuntimeArgsData common_runtime_args_data_; + std::vector common_runtime_args_data_; std::set core_with_runtime_args_; std::size_t max_runtime_args_per_core_; // For validation CoreCoord core_with_max_runtime_args_; // For validation diff --git a/tt_metal/impl/program/program_device_map.hpp b/tt_metal/impl/program/program_device_map.hpp index ae9fb53e504..288717f9aa9 100644 --- a/tt_metal/impl/program/program_device_map.hpp +++ b/tt_metal/impl/program/program_device_map.hpp @@ -42,8 +42,4 @@ struct ProgramTransferInfo { struct ProgramCommandIndices { std::uint32_t cb_configs_payload_start; // device_commands - // pair of cmd idx, rt arg offset - // Currently we only really need the base cmd idx since they are sequential, and the rt arg len is currently the same for all splits - std::unordered_map> processor_to_cmd_mapping; - std::unordered_map> kernel_to_cmd_mapping; }; diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index fb33ef6112b..900b103f8a9 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -949,8 +949,16 @@ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const } } +void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args) { + ZoneScoped; + TT_FATAL( not CommandQueue::async_mode_set(), "This variant of SetCommonRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); + if (runtime_args.size() != 0) { + detail::GetKernel(program, kernel_id)->set_common_runtime_args(runtime_args); + } +} + -RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { +RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); return detail::GetKernel(program, kernel_id)->runtime_args_data(logical_core); } @@ -960,9 +968,9 @@ std::vector< std::vector< RuntimeArgsData> >& GetRuntimeArgs(const Program &prog return detail::GetKernel(program, kernel_id)->runtime_args_data(); } -std::vector & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) { +RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) { TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - return detail::GetKernel(program, kernel_id)->common_runtime_args(); + return detail::GetKernel(program, kernel_id)->common_runtime_args_data().at(0); } void UpdateRuntimeArgs(Device* device, const std::shared_ptr kernel, const CoreCoord &core_coord, std::vector &update_idx, std::shared_ptr runtime_args) {