diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst index 661262e22dc..1594750e10e 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/runtime_args/runtime_args.rst @@ -15,6 +15,4 @@ Runtime Arguments .. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector &runtime_args) -.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args) - .. doxygenfunction:: GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) diff --git a/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp index ce3d43d58c4..1387791d805 100644 --- a/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/multi_core/unpad_op_multi_core.cpp @@ -319,7 +319,6 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile( uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims; if constexpr (!initialize_args) { set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim); - SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args); } uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start); diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index 5f91cbbd885..fd5d63d95e8 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -331,19 +331,6 @@ void SetRuntimeArgs(Device* device, const std::shared_ptr kernel, const */ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector &runtime_args); -/** - * Set common (shared by all cores) runtime args for a kernel that are sent to all cores during runtime. This API needs to be called to update the common runtime args for the kernel. - * Maximum of 255 allowed runtime args per core (unique and common runtime args count toward same limit). - * - * Return value: void - * - * | Argument | Description | Type | Valid Range | Required | - * |--------------|------------------------------------------------------------------------|--------------------------------------------------------|---------------------------------------------------------------------|----------| - * | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes | - * | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes | - * | runtime_args | The runtime args to be written | const RuntimeArgsData & | | Yes | - */ -void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args); /** * Get the runtime args for a kernel. @@ -372,7 +359,6 @@ std::vector< std::vector< RuntimeArgsData > > & GetRuntimeArgs(const Program &pr /** * Get the common runtime args for a kernel. - * Note that you must call SetCommonRuntimeArgs after updating the returned value to propagate the update. * * Return value: RuntimeArgsData & * diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 590d6a17daa..344078791cc 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -307,25 +307,28 @@ void generate_dispatch_write_packed( const uint32_t& max_runtime_args_len, std::vector>& rt_args_data, const uint32_t max_prefetch_command_size, - const uint32_t id) { + const uint32_t id, + bool no_stride=false) { static_assert( std::is_same::value or std::is_same::value); - thread_local static auto get_runtime_payload_sizeB = [](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast) { + thread_local static auto get_runtime_payload_sizeB = [](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast, bool no_stride) { uint32_t sub_cmd_sizeB = is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd); uint32_t dispatch_cmd_sizeB = sizeof(CQDispatchCmd) + align(num_packed_cmds * sub_cmd_sizeB, L1_ALIGNMENT); uint32_t aligned_runtime_data_sizeB = - num_packed_cmds * align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT); + (no_stride ? 1 : num_packed_cmds) * align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT); return dispatch_cmd_sizeB + aligned_runtime_data_sizeB; }; - thread_local static auto get_max_num_packed_cmds = [](uint32_t runtime_args_len, uint32_t max_size, bool is_unicast) { + thread_local static auto get_max_num_packed_cmds = [](uint32_t runtime_args_len, uint32_t max_size, bool is_unicast, bool no_stride) { uint32_t sub_cmd_sizeB = is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd); // Approximate calculation due to alignment max_size = max_size - sizeof(CQPrefetchCmd) - PCIE_ALIGNMENT - sizeof(CQDispatchCmd) - L1_ALIGNMENT; uint32_t max_num_packed_cmds = + no_stride ? + (max_size - align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT)) / sub_cmd_sizeB : max_size / (align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT) + sub_cmd_sizeB); return max_num_packed_cmds; }; @@ -340,11 +343,14 @@ void generate_dispatch_write_packed( constexpr bool unicast = std::is_same::value; uint32_t num_packed_cmds_in_seq = sub_cmds.size(); - uint32_t max_packed_cmds = get_max_num_packed_cmds(max_runtime_args_len, max_prefetch_command_size, unicast); + uint32_t max_packed_cmds = get_max_num_packed_cmds(max_runtime_args_len, max_prefetch_command_size, unicast, no_stride); uint32_t offset_idx = 0; + if (no_stride) { + TT_FATAL(max_packed_cmds >= num_packed_cmds_in_seq); + } while (num_packed_cmds_in_seq != 0) { uint32_t num_packed_cmds = std::min(num_packed_cmds_in_seq, max_packed_cmds); - uint32_t rt_payload_sizeB = get_runtime_payload_sizeB(num_packed_cmds, max_runtime_args_len, unicast); + uint32_t rt_payload_sizeB = get_runtime_payload_sizeB(num_packed_cmds, max_runtime_args_len, unicast, no_stride); uint32_t cmd_sequence_sizeB = align(sizeof(CQPrefetchCmd) + rt_payload_sizeB, PCIE_ALIGNMENT); runtime_args_command_sequences.emplace_back(cmd_sequence_sizeB); runtime_args_command_sequences.back().add_dispatch_write_packed( @@ -354,10 +360,12 @@ void generate_dispatch_write_packed( rt_payload_sizeB, sub_cmds, rt_data_and_sizes, - offset_idx); + offset_idx, + no_stride); uint32_t data_offset = (uint32_t)get_runtime_args_data_offset(num_packed_cmds, max_runtime_args_len, unicast); const uint32_t data_inc = align(max_runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT); - for (uint32_t i = offset_idx; i < offset_idx + num_packed_cmds; ++i) { + uint32_t num_data_copies = no_stride ? 1 : num_packed_cmds; + for (uint32_t i = offset_idx; i < offset_idx + num_data_copies; ++i) { rt_args_data[i].get().rt_args_data = (uint32_t *)((char *)runtime_args_command_sequences.back().data() + data_offset); data_offset += data_inc; } @@ -383,44 +391,21 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { const uint32_t max_prefetch_command_size = dispatch_constants::get(dispatch_core_type).max_prefetch_command_size(); - auto get_runtime_payload_sizeB = [](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast) { - uint32_t sub_cmd_sizeB = - is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd); - uint32_t dispatch_cmd_sizeB = sizeof(CQDispatchCmd) + align(num_packed_cmds * sub_cmd_sizeB, L1_ALIGNMENT); - uint32_t aligned_runtime_data_sizeB = - num_packed_cmds * align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT); - return dispatch_cmd_sizeB + aligned_runtime_data_sizeB; - }; - auto get_max_num_packed_cmds = [](uint32_t runtime_args_len, uint32_t max_size, bool is_unicast) { - uint32_t sub_cmd_sizeB = - is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd); - // Approximate calculation due to alignment - max_size = max_size - sizeof(CQPrefetchCmd) - PCIE_ALIGNMENT - sizeof(CQDispatchCmd) - L1_ALIGNMENT; - uint32_t max_num_packed_cmds = - max_size / (align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT) + sub_cmd_sizeB); - return max_num_packed_cmds; - }; - auto get_runtime_args_data_offset = [](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast) { - uint32_t sub_cmd_sizeB = - is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd); - uint32_t dispatch_cmd_sizeB = sizeof(CQDispatchCmd) + align(num_packed_cmds * sub_cmd_sizeB, L1_ALIGNMENT); - return sizeof(CQPrefetchCmd) + dispatch_cmd_sizeB; - }; - - uint32_t num_dsts = unique_processor_to_l1_arg_base_addr.size(); - std::vector> unique_sub_cmds(num_dsts); - std::vector>> unique_rt_data_and_sizes(num_dsts); - std::vector>> unique_rt_args_data(num_dsts); - std::vector unique_max_runtime_args_len(num_dsts, 0); + uint32_t num_processors = unique_processor_to_l1_arg_base_addr.size(); + std::vector> unique_sub_cmds(num_processors); + std::vector>> unique_rt_data_and_sizes(num_processors); + std::vector>> unique_rt_args_data(num_processors); + std::vector unique_max_runtime_args_len(num_processors, 0); + uint32_t num_kernels = program.num_kernels(); std::vector, std::vector>> - common_sub_cmds(program.num_kernels()); - std::vector>> common_rt_data_and_sizes(program.num_kernels()); - std::vector>> common_rt_args_data(program.num_kernels()); - std::vector common_max_runtime_args_len(program.num_kernels(), 0); - std::vector common_processor_to_l1_arg_base_addr(program.num_kernels()); + common_sub_cmds(num_kernels); + std::vector>> common_rt_data_and_sizes(num_kernels); + std::vector>> common_rt_args_data(num_kernels); + std::vector common_max_runtime_args_len(num_kernels, 0); + std::vector common_processor_to_l1_arg_base_addr(num_kernels); std::set unique_processors; std::set common_kernels; @@ -464,24 +449,22 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { uint32_t common_args_addr = unique_processor_to_l1_arg_base_addr[processor_idx] + kernel->get_common_runtime_args_offset(); common_processor_to_l1_arg_base_addr[kernel_id] = common_args_addr; + common_rt_data_and_sizes[kernel_id].emplace_back( + common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); + common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()); + common_max_runtime_args_len[kernel_id] = (uint32_t)common_rt_args.size(); if (kernel->get_kernel_core_type() == CoreType::ETH) { common_sub_cmds[kernel_id].emplace>( std::vector()); auto& unicast_sub_cmd = std::get>(common_sub_cmds[kernel_id]); unicast_sub_cmd.reserve(kernel->logical_cores().size()); - common_rt_data_and_sizes[kernel_id].reserve(kernel->logical_cores().size()); - common_rt_args_data[kernel_id].reserve(kernel->logical_cores().size()); - uint32_t i = 0; for (auto& core_coord : kernel->logical_cores()) { // can make a vector of unicast encodings here CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord); uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core); unicast_sub_cmd.emplace_back( CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding}); - common_rt_data_and_sizes[kernel_id].emplace_back( - common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); - common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]); } } else { vector> dst_noc_multicast_info = @@ -492,24 +475,18 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { auto& multicast_sub_cmd = std::get>(common_sub_cmds[kernel_id]); multicast_sub_cmd.reserve(dst_noc_multicast_info.size()); - common_rt_data_and_sizes[kernel_id].reserve(dst_noc_multicast_info.size()); - uint32_t i = 0; for (const auto& mcast_dests : dst_noc_multicast_info) { multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{ .noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second}); - common_rt_data_and_sizes[kernel_id].emplace_back( - common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t)); - common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]); } } - common_max_runtime_args_len[kernel_id] = - std::max(common_max_runtime_args_len[kernel_id], (uint32_t)common_rt_args.size()); } } - // Reserve 2x as we pontentially split the cmds due to not fitting in one prefetch cmd + // Reserve 2x for unique rtas as we pontentially split the cmds due to not fitting in one prefetch cmd + // Common rtas are always expected to fit in one prefetch cmd this->cached_program_command_sequences[program.id].runtime_args_command_sequences = {}; this->cached_program_command_sequences[program.id].runtime_args_command_sequences.reserve( - 2 * (unique_processors.size() + common_kernels.size())); + 2 * unique_processors.size() + common_kernels.size()); std::vector> runtime_args_data_index; runtime_args_data_index.reserve(2 * (unique_processors.size() + common_kernels.size())); // Array of cmd idx, # sub cmds, rt arg offset, rt arg len @@ -524,7 +501,8 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { unique_max_runtime_args_len[processor_idx], unique_rt_args_data[processor_idx], max_prefetch_command_size, - processor_idx); + processor_idx, + false); } for (const uint32_t& kernel_id : common_kernels) { std::visit( @@ -537,7 +515,8 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { common_max_runtime_args_len[kernel_id], common_rt_args_data[kernel_id], max_prefetch_command_size, - kernel_id); + kernel_id, + true); }, common_sub_cmds[kernel_id]); } diff --git a/tt_metal/impl/dispatch/device_command.hpp b/tt_metal/impl/dispatch/device_command.hpp index 08916379c41..d527e439511 100644 --- a/tt_metal/impl/dispatch/device_command.hpp +++ b/tt_metal/impl/dispatch/device_command.hpp @@ -312,7 +312,8 @@ class DeviceCommand { uint32_t payload_sizeB, const std::vector &sub_cmds, const std::vector> &data_collection, - const uint32_t offset_idx = 0) { + const uint32_t offset_idx = 0, + const bool no_stride = false) { static_assert(std::is_same::value or std::is_same::value); bool multicast = std::is_same::value; @@ -325,7 +326,7 @@ class DeviceCommand { auto initialize_write_packed_cmd = [&](CQDispatchCmd *write_packed_cmd) { write_packed_cmd->base.cmd_id = CQ_DISPATCH_CMD_WRITE_PACKED; write_packed_cmd->write_packed.flags = - multicast ? CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST : CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE; + (multicast ? CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST : CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE) | (no_stride ? CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE : CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE); write_packed_cmd->write_packed.count = num_sub_cmds; write_packed_cmd->write_packed.addr = common_addr; write_packed_cmd->write_packed.size = packed_data_sizeB; @@ -350,7 +351,8 @@ class DeviceCommand { // copy the actual data increment_sizeB = align(packed_data_sizeB, L1_ALIGNMENT); - for (uint32_t i = offset_idx; i < offset_idx + num_sub_cmds; ++i) { + uint32_t num_data_copies = no_stride ? 1 : num_sub_cmds; + for (uint32_t i = offset_idx; i < offset_idx + num_data_copies; ++i) { this->memcpy((char*)this->cmd_region + this->cmd_write_offsetB, data_collection[i].first, data_collection[i].second); this->cmd_write_offsetB += increment_sizeB; } diff --git a/tt_metal/impl/kernels/kernel.cpp b/tt_metal/impl/kernels/kernel.cpp index 7ebf1d3bed5..c945e1a0387 100644 --- a/tt_metal/impl/kernels/kernel.cpp +++ b/tt_metal/impl/kernels/kernel.cpp @@ -169,7 +169,7 @@ std::vector& Kernel::common_runtime_args() { return this->common_runtime_args_; } -std::vector & Kernel::common_runtime_args_data() { +RuntimeArgsData & Kernel::common_runtime_args_data() { return this->common_runtime_args_data_; } @@ -248,31 +248,11 @@ void Kernel::set_runtime_args(const CoreCoord &logical_core, const std::vector &common_runtime_args) { - - // TODO (abhullar): If we don't include this check then user can write runtime args to a core that the kernel is not placed on. - // Should this check only be enabled in debug mode? - // TT_FATAL(this->is_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name()); - auto &set_rt_args = this->common_runtime_args_; TT_FATAL(set_rt_args.empty(), "Illegal Common Runtime Args: Can only set common runtime args once. Get and modify args in place instead."); this->validate_runtime_args_size(max_runtime_args_per_core_, common_runtime_args.size(), core_with_max_runtime_args_); set_rt_args = common_runtime_args; - this->common_runtime_args_data_ = std::vector(this->core_range_set_.ranges().size(), RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}); -} - -void Kernel::set_common_runtime_args(const RuntimeArgsData& common_runtime_args) { - - // TODO (abhullar): If we don't include this check then user can write runtime args to a core that the kernel is not placed on. - // Should this check only be enabled in debug mode? - // TT_FATAL(this->is_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name()); - - auto &set_rt_args = this->common_runtime_args_; - TT_FATAL(!set_rt_args.empty() and set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!"); - for (auto& rt_args_data : this->common_runtime_args_data_) { - if (common_runtime_args.data() != rt_args_data.data()) { - std::memcpy((void *)rt_args_data.data(),(void *) common_runtime_args.data(), common_runtime_args.size() * sizeof(uint32_t)); - } - } + this->common_runtime_args_data_ = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()}; } void DataMovementKernel::set_build_options(JitBuildOptions& build_options) const { diff --git a/tt_metal/impl/kernels/kernel.hpp b/tt_metal/impl/kernels/kernel.hpp index ea5cb9d4510..a9c577aa271 100644 --- a/tt_metal/impl/kernels/kernel.hpp +++ b/tt_metal/impl/kernels/kernel.hpp @@ -83,7 +83,7 @@ class Kernel : public JitBuildSettings { std::vector< std::vector< std::vector> > & runtime_args(); std::vector< std::vector< RuntimeArgsData > > & runtime_args_data(); std::vector & common_runtime_args(); - std::vector & common_runtime_args_data(); + RuntimeArgsData & common_runtime_args_data(); std::map defines() const { return defines_; } @@ -106,7 +106,6 @@ class Kernel : public JitBuildSettings { void validate_runtime_args_size(size_t num_unique_rt_args, size_t num_common_rt_args, const CoreCoord& logical_core); void set_runtime_args(const CoreCoord &logical_core, const std::vector &runtime_args); void set_common_runtime_args(const std::vector &runtime_args); - void set_common_runtime_args(const RuntimeArgsData& common_runtime_args); int get_watcher_kernel_id() { return watcher_kernel_id_; } @@ -131,7 +130,7 @@ class Kernel : public JitBuildSettings { std::vector< std::vector< std::vector> > core_to_runtime_args_; std::vector< std::vector< RuntimeArgsData> > core_to_runtime_args_data_; std::vector common_runtime_args_; - std::vector common_runtime_args_data_; + RuntimeArgsData common_runtime_args_data_; std::set core_with_runtime_args_; std::size_t max_runtime_args_per_core_; // For validation CoreCoord core_with_max_runtime_args_; // For validation diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index ce31d1557f0..115b29412f6 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -927,15 +927,6 @@ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const } } -void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args) { - ZoneScoped; - TT_FATAL( not CommandQueue::async_mode_set(), "This variant of SetCommonRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - if (runtime_args.size() != 0) { - detail::GetKernel(program, kernel_id)->set_common_runtime_args(runtime_args); - } -} - - RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) { TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); return detail::GetKernel(program, kernel_id)->runtime_args_data(logical_core); @@ -948,7 +939,7 @@ std::vector< std::vector< RuntimeArgsData> >& GetRuntimeArgs(const Program &prog RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id) { TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch."); - return detail::GetKernel(program, kernel_id)->common_runtime_args_data().at(0); + return detail::GetKernel(program, kernel_id)->common_runtime_args_data(); } uint32_t BeginTraceCapture(Device *device, const uint8_t cq_id, const uint32_t trace_buff_size) {