Skip to content

Commit

Permalink
#7530: First pass for common rt args. See if we can make it more user…
Browse files Browse the repository at this point in the history
… friendly
  • Loading branch information
tt-aho committed May 15, 2024
1 parent 8d99471 commit c541511
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile(
uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims;
if constexpr (!initialize_args) {
set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim);
SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args);
}

uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start);
Expand Down
14 changes: 14 additions & 0 deletions tt_metal/host_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,20 @@ void SetRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const
*/
void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector<uint32_t> &runtime_args);

/**
* Set common (shared by all cores) runtime args for a kernel that are sent to all cores during runtime. This API needs to be called to update the common runtime args for the kernel.
* Maximum of 255 allowed runtime args per core (unique and common runtime args count toward same limit).
*
* This overload takes the args as a RuntimeArgsData view instead of a std::vector<uint32_t>.
* It can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch.
* If runtime_args is empty the call does nothing; otherwise the number of args must match the number of common
* runtime args already set on the kernel (the count cannot be modified).
*
* Return value: void
*
* | Argument | Description | Type | Valid Range | Required |
* |--------------|------------------------------------------------------------------------|--------------------------------------------------------|---------------------------------------------------------------------|----------|
* | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes |
* | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes |
* | runtime_args | The runtime args to be written | const RuntimeArgsData & | | Yes |
*/
void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args);

/**
* Get the runtime args for a kernel.
*
Expand Down
42 changes: 5 additions & 37 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ std::condition_variable finish_cv;

namespace tt::tt_metal {

// TODO: Delete entries when programs are deleted to save memory
thread_local std::unordered_map<uint64_t, std::vector<HostMemDeviceCommand>> EnqueueProgramCommand::runtime_args_command_sequences = {};

uint32_t get_noc_unicast_encoding(const CoreCoord &coord) { return NOC_XY_ENCODING(NOC_X(coord.x), NOC_Y(coord.y)); }
Expand Down Expand Up @@ -307,7 +308,6 @@ void EnqueueProgramCommand::assemble_preamble_commands() {
template <typename PackedSubCmd>
void generate_dispatch_write_packed(
std::vector<HostMemDeviceCommand>& runtime_args_command_sequences,
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>>& cmd_mapping,
const uint32_t& l1_arg_base_addr,
const std::vector<PackedSubCmd>& sub_cmds,
const std::vector<std::pair<const void*, uint32_t>>& rt_data_and_sizes,
Expand Down Expand Up @@ -365,8 +365,6 @@ void generate_dispatch_write_packed(
uint32_t data_offset = (uint32_t)get_runtime_args_data_offset(num_packed_cmds, max_runtime_args_len, unicast);
const uint32_t data_inc = align(max_runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT);
for (uint32_t i = offset_idx; i < offset_idx + num_packed_cmds; ++i) {
cmd_mapping[(uint64_t)id << 32 | sub_cmds[i].noc_xy_addr] = {
runtime_args_command_sequences.size() - 1, data_offset};
rt_args_data[i].get().rt_args_data = (uint32_t *)((char *)runtime_args_command_sequences.back().data() + data_offset);
data_offset += data_inc;
}
Expand Down Expand Up @@ -482,6 +480,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
unicast_sub_cmd.reserve(kernel->logical_cores().size());
common_rt_data_and_sizes[kernel_id].reserve(kernel->logical_cores().size());
common_rt_args_data[kernel_id].reserve(kernel->logical_cores().size());
uint32_t i = 0;
for (auto& core_coord : kernel->logical_cores()) {
// can make a vector of unicast encodings here
CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord);
Expand All @@ -490,7 +489,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding});
common_rt_data_and_sizes[kernel_id].emplace_back(
common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data());
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]);
}
} else {
vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info =
Expand All @@ -502,12 +501,13 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
std::get<std::vector<CQDispatchWritePackedMulticastSubCmd>>(common_sub_cmds[kernel_id]);
multicast_sub_cmd.reserve(dst_noc_multicast_info.size());
common_rt_data_and_sizes[kernel_id].reserve(dst_noc_multicast_info.size());
uint32_t i = 0;
for (const auto& mcast_dests : dst_noc_multicast_info) {
multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{
.noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second});
common_rt_data_and_sizes[kernel_id].emplace_back(
common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data());
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]);
}
}
common_max_runtime_args_len[kernel_id] =
Expand All @@ -526,7 +526,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
for (const uint32_t& processor_idx : unique_processors) {
generate_dispatch_write_packed(
this->runtime_args_command_sequences[program.id],
program.command_indices.processor_to_cmd_mapping,
unique_processor_to_l1_arg_base_addr[processor_idx],
unique_sub_cmds[processor_idx],
unique_rt_data_and_sizes[processor_idx],
Expand All @@ -540,7 +539,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
[&](auto&& sub_cmds) {
generate_dispatch_write_packed(
this->runtime_args_command_sequences[program.id],
program.command_indices.kernel_to_cmd_mapping,
common_processor_to_l1_arg_base_addr[kernel_id],
sub_cmds,
common_rt_data_and_sizes[kernel_id],
Expand All @@ -551,36 +549,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
},
common_sub_cmds[kernel_id]);
}
} else {
for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) {
auto kernel = detail::GetKernel(program, kernel_id);
const auto& common_rt_args = kernel->common_runtime_args();
if (common_rt_args.size() > 0) {
if (kernel->get_kernel_core_type() == CoreType::ETH) {
for (const auto& core_coord : kernel->logical_cores()) {
uint32_t noc_xy_encoding =
get_noc_unicast_encoding(this->device->ethernet_core_from_logical_core(core_coord));
const auto& data_loc =
program.command_indices.kernel_to_cmd_mapping[(uint64_t)kernel_id << 32 | noc_xy_encoding];
const auto& runtime_args_data = kernel->runtime_args(core_coord);
this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence(
data_loc.second, runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t));
}
} else {
for (const auto& core_range : kernel->logical_coreranges()) {
CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start);
CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end);

uint32_t noc_multicast_encoding = get_noc_multcast_encoding(physical_start, physical_end);
const auto& data_loc =
program.command_indices
.kernel_to_cmd_mapping[(uint64_t)kernel_id << 32 | noc_multicast_encoding];
this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence(
data_loc.second, common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
}
}
}
}
}
}

Expand Down
27 changes: 24 additions & 3 deletions tt_metal/impl/kernels/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ std::vector<uint32_t>& Kernel::common_runtime_args() {
return this->common_runtime_args_;
}

RuntimeArgsData & Kernel::common_runtime_args_data() {
std::vector<RuntimeArgsData> & Kernel::common_runtime_args_data() {
return this->common_runtime_args_data_;
}

Expand Down Expand Up @@ -263,8 +263,29 @@ void Kernel::set_common_runtime_args(const std::vector<uint32_t> &common_runtime
this->validate_runtime_args_size(max_runtime_args_per_core_, common_runtime_args.size(), core_with_max_runtime_args_);
auto &set_rt_args = this->common_runtime_args_;
TT_FATAL(set_rt_args.empty() or set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!");
set_rt_args = common_runtime_args;
this->common_runtime_args_data_ = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()};
if (set_rt_args.empty()) {
set_rt_args = common_runtime_args;
this->common_runtime_args_data_ = std::vector<RuntimeArgsData>(this->core_range_set_.ranges().size(), RuntimeArgsData{set_rt_args.data(), set_rt_args.size()});
} else {
for (auto& rt_args_data : this->common_runtime_args_data_) {
memcpy(rt_args_data.data(), common_runtime_args.data(), common_runtime_args.size() * sizeof(uint32_t));
}
}
}

// Overwrites the kernel's common runtime args from a RuntimeArgsData view.
//
// The incoming view must hold exactly as many args as are already stored in
// common_runtime_args_; otherwise TT_FATAL fires. The values are then copied
// into every entry of common_runtime_args_data_, skipping any entry whose
// storage is the very buffer being passed in (nothing to copy in that case).
//
// Note: no check is made here that the kernel is actually placed on any core.
void Kernel::set_common_runtime_args(const RuntimeArgsData& common_runtime_args) {
    const auto &current_args = this->common_runtime_args_;
    TT_FATAL(current_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!");

    const auto *src = common_runtime_args.data();
    const auto num_bytes = common_runtime_args.size() * sizeof(uint32_t);
    for (auto& dst : this->common_runtime_args_data_) {
        if (dst.data() == src) {
            // Caller handed us this entry's own buffer; it is already up to date.
            continue;
        }
        memcpy(dst.data(), src, num_bytes);
    }
}

void DataMovementKernel::set_build_options(JitBuildOptions& build_options) const {
Expand Down
7 changes: 4 additions & 3 deletions tt_metal/impl/kernels/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ class Kernel : public JitBuildSettings {
void update_runtime_arg( const CoreCoord &logical_core, size_t idx, uint32_t value);

std::vector<uint32_t> & runtime_args(const CoreCoord &logical_core);
RuntimeArgsData& runtime_args_data(const CoreCoord &logical_core);
RuntimeArgsData & runtime_args_data(const CoreCoord &logical_core);
std::vector< std::vector< std::vector<uint32_t>> > & runtime_args();
std::vector< std::vector< RuntimeArgsData > > & runtime_args_data();
std::vector<uint32_t> & common_runtime_args();
RuntimeArgsData& common_runtime_args_data();
std::vector<RuntimeArgsData> & common_runtime_args_data();

std::map<std::string, std::string> defines() const { return defines_; }

Expand All @@ -108,6 +108,7 @@ class Kernel : public JitBuildSettings {
void validate_runtime_args_size(size_t num_unique_rt_args, size_t num_common_rt_args, const CoreCoord& logical_core);
void set_runtime_args(const CoreCoord &logical_core, const std::vector<uint32_t> &runtime_args);
void set_common_runtime_args(const std::vector<uint32_t> &runtime_args);
void set_common_runtime_args(const RuntimeArgsData& common_runtime_args);

int get_watcher_kernel_id() { return watcher_kernel_id_; }

Expand All @@ -132,7 +133,7 @@ class Kernel : public JitBuildSettings {
std::vector< std::vector< std::vector<uint32_t>> > core_to_runtime_args_;
std::vector< std::vector< RuntimeArgsData> > core_to_runtime_args_data_;
std::vector<uint32_t> common_runtime_args_;
RuntimeArgsData common_runtime_args_data_;
std::vector<RuntimeArgsData> common_runtime_args_data_;
std::set<CoreCoord> core_with_runtime_args_;
std::size_t max_runtime_args_per_core_; // For validation
CoreCoord core_with_max_runtime_args_; // For validation
Expand Down
4 changes: 0 additions & 4 deletions tt_metal/impl/program/program_device_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,4 @@ struct ProgramTransferInfo {

// Per-program bookkeeping used to locate data inside the program's device commands.
struct ProgramCommandIndices {
std::uint32_t cb_configs_payload_start; // device_commands
// pair of cmd idx, rt arg offset
// Currently we only really need the base cmd idx since they are sequential, and the rt arg len is currently the same for all splits
// Both maps are keyed by `(uint64_t)id << 32 | noc_xy_addr` (id in the upper 32 bits, NOC encoding in the lower 32).
// The mapped pair is {index of the command within the program's runtime-args command sequence,
// offset of the runtime-arg data within that command}.
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> processor_to_cmd_mapping;
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> kernel_to_cmd_mapping;
};
8 changes: 8 additions & 0 deletions tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,14 @@ void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const
}
}

// Updates a kernel's common runtime args from a RuntimeArgsData view.
//
// Asserts that asynchronous SW command queues are not enabled, then forwards
// the args to the kernel identified by kernel_id. An empty view is a no-op.
void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args) {
    ZoneScoped;
    TT_FATAL( not CommandQueue::async_mode_set(), "This variant of SetCommonRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch.");
    if (runtime_args.size() == 0) {
        // Nothing to write.
        return;
    }
    const auto kernel = detail::GetKernel(program, kernel_id);
    kernel->set_common_runtime_args(runtime_args);
}


RuntimeArgsData & GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core) {
TT_FATAL( not CommandQueue::async_mode_set(), "GetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast Dispatch.");
Expand Down

0 comments on commit c541511

Please sign in to comment.