Skip to content

Commit

Permalink
#0: Update common RT args to use no stride flag for packed cmd. Remov…
Browse files Browse the repository at this point in the history
…e SetCommonRuntimeArgs requirement when hitting cache since we now have a direct 1 to 1 mapping into device cmd
  • Loading branch information
tt-aho committed May 23, 2024
1 parent 2632a5c commit 4926418
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,4 @@ Runtime Arguments

.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector<uint32_t> &runtime_args)

.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args)

.. doxygenfunction:: GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id)
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile(
uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims;
if constexpr (!initialize_args) {
set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim);
SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args);
}

uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start);
Expand Down
14 changes: 0 additions & 14 deletions tt_metal/host_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,19 +331,6 @@ void SetRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const
*/
void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector<uint32_t> &runtime_args);

/**
* Set common (shared by all cores) runtime args for a kernel that are sent to all cores during runtime. This API needs to be called to update the common runtime args for the kernel.
* Maximum of 255 allowed runtime args per core (unique and common runtime args count toward same limit).
*
* Return value: void
*
* | Argument | Description | Type | Valid Range | Required |
* |--------------|------------------------------------------------------------------------|--------------------------------------------------------|---------------------------------------------------------------------|----------|
* | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes |
* | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes |
* | runtime_args | The runtime args to be written | const RuntimeArgsData & | | Yes |
*/
void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args);

/**
* Get the runtime args for a kernel.
Expand Down Expand Up @@ -372,7 +359,6 @@ std::vector< std::vector< RuntimeArgsData > > & GetRuntimeArgs(const Program &pr

/**
* Get the common runtime args for a kernel.
* Note that you must call SetCommonRuntimeArgs after updating the returned value to propagate the update.
*
* Return value: RuntimeArgsData &
*
Expand Down
97 changes: 38 additions & 59 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,25 +307,28 @@ void generate_dispatch_write_packed(
const uint32_t& max_runtime_args_len,
std::vector<std::reference_wrapper<RuntimeArgsData>>& rt_args_data,
const uint32_t max_prefetch_command_size,
const uint32_t id) {
const uint32_t id,
bool no_stride=false) {
static_assert(
std::is_same<PackedSubCmd, CQDispatchWritePackedUnicastSubCmd>::value or
std::is_same<PackedSubCmd, CQDispatchWritePackedMulticastSubCmd>::value);

thread_local static auto get_runtime_payload_sizeB = [](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast) {
thread_local static auto get_runtime_payload_sizeB = [](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast, bool no_stride) {
uint32_t sub_cmd_sizeB =
is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd);
uint32_t dispatch_cmd_sizeB = sizeof(CQDispatchCmd) + align(num_packed_cmds * sub_cmd_sizeB, L1_ALIGNMENT);
uint32_t aligned_runtime_data_sizeB =
num_packed_cmds * align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT);
(no_stride ? 1 : num_packed_cmds) * align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT);
return dispatch_cmd_sizeB + aligned_runtime_data_sizeB;
};
thread_local static auto get_max_num_packed_cmds = [](uint32_t runtime_args_len, uint32_t max_size, bool is_unicast) {
thread_local static auto get_max_num_packed_cmds = [](uint32_t runtime_args_len, uint32_t max_size, bool is_unicast, bool no_stride) {
uint32_t sub_cmd_sizeB =
is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd);
// Approximate calculation due to alignment
max_size = max_size - sizeof(CQPrefetchCmd) - PCIE_ALIGNMENT - sizeof(CQDispatchCmd) - L1_ALIGNMENT;
uint32_t max_num_packed_cmds =
no_stride ?
(max_size - align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT)) / sub_cmd_sizeB :
max_size / (align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT) + sub_cmd_sizeB);
return max_num_packed_cmds;
};
Expand All @@ -340,11 +343,14 @@ void generate_dispatch_write_packed(
constexpr bool unicast = std::is_same<PackedSubCmd, CQDispatchWritePackedUnicastSubCmd>::value;

uint32_t num_packed_cmds_in_seq = sub_cmds.size();
uint32_t max_packed_cmds = get_max_num_packed_cmds(max_runtime_args_len, max_prefetch_command_size, unicast);
uint32_t max_packed_cmds = get_max_num_packed_cmds(max_runtime_args_len, max_prefetch_command_size, unicast, no_stride);
uint32_t offset_idx = 0;
if (no_stride) {
TT_FATAL(max_packed_cmds >= num_packed_cmds_in_seq);
}
while (num_packed_cmds_in_seq != 0) {
uint32_t num_packed_cmds = std::min(num_packed_cmds_in_seq, max_packed_cmds);
uint32_t rt_payload_sizeB = get_runtime_payload_sizeB(num_packed_cmds, max_runtime_args_len, unicast);
uint32_t rt_payload_sizeB = get_runtime_payload_sizeB(num_packed_cmds, max_runtime_args_len, unicast, no_stride);
uint32_t cmd_sequence_sizeB = align(sizeof(CQPrefetchCmd) + rt_payload_sizeB, PCIE_ALIGNMENT);
runtime_args_command_sequences.emplace_back(cmd_sequence_sizeB);
runtime_args_command_sequences.back().add_dispatch_write_packed<PackedSubCmd>(
Expand All @@ -354,10 +360,12 @@ void generate_dispatch_write_packed(
rt_payload_sizeB,
sub_cmds,
rt_data_and_sizes,
offset_idx);
offset_idx,
no_stride);
uint32_t data_offset = (uint32_t)get_runtime_args_data_offset(num_packed_cmds, max_runtime_args_len, unicast);
const uint32_t data_inc = align(max_runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT);
for (uint32_t i = offset_idx; i < offset_idx + num_packed_cmds; ++i) {
uint32_t num_data_copies = no_stride ? 1 : num_packed_cmds;
for (uint32_t i = offset_idx; i < offset_idx + num_data_copies; ++i) {
rt_args_data[i].get().rt_args_data = (uint32_t *)((char *)runtime_args_command_sequences.back().data() + data_offset);
data_offset += data_inc;
}
Expand All @@ -383,44 +391,21 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
const uint32_t max_prefetch_command_size =
dispatch_constants::get(dispatch_core_type).max_prefetch_command_size();

auto get_runtime_payload_sizeB = [](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast) {
uint32_t sub_cmd_sizeB =
is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd);
uint32_t dispatch_cmd_sizeB = sizeof(CQDispatchCmd) + align(num_packed_cmds * sub_cmd_sizeB, L1_ALIGNMENT);
uint32_t aligned_runtime_data_sizeB =
num_packed_cmds * align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT);
return dispatch_cmd_sizeB + aligned_runtime_data_sizeB;
};
auto get_max_num_packed_cmds = [](uint32_t runtime_args_len, uint32_t max_size, bool is_unicast) {
uint32_t sub_cmd_sizeB =
is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd);
// Approximate calculation due to alignment
max_size = max_size - sizeof(CQPrefetchCmd) - PCIE_ALIGNMENT - sizeof(CQDispatchCmd) - L1_ALIGNMENT;
uint32_t max_num_packed_cmds =
max_size / (align(runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT) + sub_cmd_sizeB);
return max_num_packed_cmds;
};
auto get_runtime_args_data_offset = [](uint32_t num_packed_cmds, uint32_t runtime_args_len, bool is_unicast) {
uint32_t sub_cmd_sizeB =
is_unicast ? sizeof(CQDispatchWritePackedUnicastSubCmd) : sizeof(CQDispatchWritePackedMulticastSubCmd);
uint32_t dispatch_cmd_sizeB = sizeof(CQDispatchCmd) + align(num_packed_cmds * sub_cmd_sizeB, L1_ALIGNMENT);
return sizeof(CQPrefetchCmd) + dispatch_cmd_sizeB;
};

uint32_t num_dsts = unique_processor_to_l1_arg_base_addr.size();
std::vector<std::vector<CQDispatchWritePackedUnicastSubCmd>> unique_sub_cmds(num_dsts);
std::vector<std::vector<std::pair<const void*, uint32_t>>> unique_rt_data_and_sizes(num_dsts);
std::vector<std::vector<std::reference_wrapper<RuntimeArgsData>>> unique_rt_args_data(num_dsts);
std::vector<uint32_t> unique_max_runtime_args_len(num_dsts, 0);
uint32_t num_processors = unique_processor_to_l1_arg_base_addr.size();
std::vector<std::vector<CQDispatchWritePackedUnicastSubCmd>> unique_sub_cmds(num_processors);
std::vector<std::vector<std::pair<const void*, uint32_t>>> unique_rt_data_and_sizes(num_processors);
std::vector<std::vector<std::reference_wrapper<RuntimeArgsData>>> unique_rt_args_data(num_processors);
std::vector<uint32_t> unique_max_runtime_args_len(num_processors, 0);

uint32_t num_kernels = program.num_kernels();
std::vector<std::variant<
std::vector<CQDispatchWritePackedMulticastSubCmd>,
std::vector<CQDispatchWritePackedUnicastSubCmd>>>
common_sub_cmds(program.num_kernels());
std::vector<std::vector<std::pair<const void*, uint32_t>>> common_rt_data_and_sizes(program.num_kernels());
std::vector<std::vector<std::reference_wrapper<RuntimeArgsData>>> common_rt_args_data(program.num_kernels());
std::vector<uint32_t> common_max_runtime_args_len(program.num_kernels(), 0);
std::vector<uint32_t> common_processor_to_l1_arg_base_addr(program.num_kernels());
common_sub_cmds(num_kernels);
std::vector<std::vector<std::pair<const void*, uint32_t>>> common_rt_data_and_sizes(num_kernels);
std::vector<std::vector<std::reference_wrapper<RuntimeArgsData>>> common_rt_args_data(num_kernels);
std::vector<uint32_t> common_max_runtime_args_len(num_kernels, 0);
std::vector<uint32_t> common_processor_to_l1_arg_base_addr(num_kernels);

std::set<uint32_t> unique_processors;
std::set<uint32_t> common_kernels;
Expand Down Expand Up @@ -464,24 +449,22 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
uint32_t common_args_addr =
unique_processor_to_l1_arg_base_addr[processor_idx] + kernel->get_common_runtime_args_offset();
common_processor_to_l1_arg_base_addr[kernel_id] = common_args_addr;
common_rt_data_and_sizes[kernel_id].emplace_back(
common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data());
common_max_runtime_args_len[kernel_id] = (uint32_t)common_rt_args.size();
if (kernel->get_kernel_core_type() == CoreType::ETH) {
common_sub_cmds[kernel_id].emplace<std::vector<CQDispatchWritePackedUnicastSubCmd>>(
std::vector<CQDispatchWritePackedUnicastSubCmd>());
auto& unicast_sub_cmd =
std::get<std::vector<CQDispatchWritePackedUnicastSubCmd>>(common_sub_cmds[kernel_id]);
unicast_sub_cmd.reserve(kernel->logical_cores().size());
common_rt_data_and_sizes[kernel_id].reserve(kernel->logical_cores().size());
common_rt_args_data[kernel_id].reserve(kernel->logical_cores().size());
uint32_t i = 0;
for (auto& core_coord : kernel->logical_cores()) {
// can make a vector of unicast encodings here
CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord);
uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core);
unicast_sub_cmd.emplace_back(
CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding});
common_rt_data_and_sizes[kernel_id].emplace_back(
common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]);
}
} else {
vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info =
Expand All @@ -492,24 +475,18 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
auto& multicast_sub_cmd =
std::get<std::vector<CQDispatchWritePackedMulticastSubCmd>>(common_sub_cmds[kernel_id]);
multicast_sub_cmd.reserve(dst_noc_multicast_info.size());
common_rt_data_and_sizes[kernel_id].reserve(dst_noc_multicast_info.size());
uint32_t i = 0;
for (const auto& mcast_dests : dst_noc_multicast_info) {
multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{
.noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second});
common_rt_data_and_sizes[kernel_id].emplace_back(
common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]);
}
}
common_max_runtime_args_len[kernel_id] =
std::max(common_max_runtime_args_len[kernel_id], (uint32_t)common_rt_args.size());
}
}
// Reserve 2x as we pontentially split the cmds due to not fitting in one prefetch cmd
// Reserve 2x for unique rtas as we pontentially split the cmds due to not fitting in one prefetch cmd
// Common rtas are always expected to fit in one prefetch cmd
this->cached_program_command_sequences[program.id].runtime_args_command_sequences = {};
this->cached_program_command_sequences[program.id].runtime_args_command_sequences.reserve(
2 * (unique_processors.size() + common_kernels.size()));
2 * unique_processors.size() + common_kernels.size());
std::vector<std::pair<uint32_t, uint32_t>> runtime_args_data_index;
runtime_args_data_index.reserve(2 * (unique_processors.size() + common_kernels.size()));
// Array of cmd idx, # sub cmds, rt arg offset, rt arg len
Expand All @@ -524,7 +501,8 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
unique_max_runtime_args_len[processor_idx],
unique_rt_args_data[processor_idx],
max_prefetch_command_size,
processor_idx);
processor_idx,
false);
}
for (const uint32_t& kernel_id : common_kernels) {
std::visit(
Expand All @@ -537,7 +515,8 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
common_max_runtime_args_len[kernel_id],
common_rt_args_data[kernel_id],
max_prefetch_command_size,
kernel_id);
kernel_id,
true);
},
common_sub_cmds[kernel_id]);
}
Expand Down
8 changes: 5 additions & 3 deletions tt_metal/impl/dispatch/device_command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ class DeviceCommand {
uint32_t payload_sizeB,
const std::vector<PackedSubCmd> &sub_cmds,
const std::vector<std::pair<const void *, uint32_t>> &data_collection,
const uint32_t offset_idx = 0) {
const uint32_t offset_idx = 0,
const bool no_stride = false) {
static_assert(std::is_same<PackedSubCmd, CQDispatchWritePackedUnicastSubCmd>::value or std::is_same<PackedSubCmd, CQDispatchWritePackedMulticastSubCmd>::value);
bool multicast = std::is_same<PackedSubCmd, CQDispatchWritePackedMulticastSubCmd>::value;

Expand All @@ -325,7 +326,7 @@ class DeviceCommand {
auto initialize_write_packed_cmd = [&](CQDispatchCmd *write_packed_cmd) {
write_packed_cmd->base.cmd_id = CQ_DISPATCH_CMD_WRITE_PACKED;
write_packed_cmd->write_packed.flags =
multicast ? CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST : CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE;
(multicast ? CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST : CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE) | (no_stride ? CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE : CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE);
write_packed_cmd->write_packed.count = num_sub_cmds;
write_packed_cmd->write_packed.addr = common_addr;
write_packed_cmd->write_packed.size = packed_data_sizeB;
Expand All @@ -350,7 +351,8 @@ class DeviceCommand {

// copy the actual data
increment_sizeB = align(packed_data_sizeB, L1_ALIGNMENT);
for (uint32_t i = offset_idx; i < offset_idx + num_sub_cmds; ++i) {
uint32_t num_data_copies = no_stride ? 1 : num_sub_cmds;
for (uint32_t i = offset_idx; i < offset_idx + num_data_copies; ++i) {
this->memcpy((char*)this->cmd_region + this->cmd_write_offsetB, data_collection[i].first, data_collection[i].second);
this->cmd_write_offsetB += increment_sizeB;
}
Expand Down
Loading

0 comments on commit 4926418

Please sign in to comment.