Skip to content

Commit

Permalink
#7530: Add support for directly updating common runtime args into the…
Browse files Browse the repository at this point in the history
… device command

Currently, the device cmd holds one copy of the common args per mcast/core group, so the common args live in multiple locations.
As a result, users must call Get and then Set for common args to update all of those locations.
TODO: have the device cmd retain only one copy of the data, so that the mapping is one-to-one and the Set requirement can be removed.
  • Loading branch information
tt-aho committed May 15, 2024
1 parent e11eded commit 44f1791
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ Runtime Arguments

.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector<uint32_t> &runtime_args)

.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args)

.. doxygenfunction:: GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id)
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile(
uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims;
if constexpr (!initialize_args) {
set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim);
SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args);
}

uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ inline __attribute__((always_inline)) void set_unpad_runtime_args_tile(
uint32_t* num_padded_tiles_per_dim = num_unpadded_tiles_per_dim + num_dims;
if constexpr (!initialize_args) {
set_common_reader_args(reader_common_args.data(), num_unpadded_tiles_per_dim, num_padded_tiles_per_dim);
SetCommonRuntimeArgs(program, unary_reader_kernel_id, reader_common_args);
}

uint32_t start_offset = get_tiled_start_offset(input_tensor, output_tensor_start);
Expand Down
19 changes: 17 additions & 2 deletions tt_metal/host_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,20 @@ void SetRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const
*/
void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector<uint32_t> &runtime_args);

/**
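* Update the common (shared by all cores) runtime args of a kernel that are sent to all cores during runtime. Call this after modifying the RuntimeArgsData returned by GetCommonRuntimeArgs in place, so that the update is propagated to all locations holding the kernel's common runtime args.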
* Maximum of 255 allowed runtime args per core (unique and common runtime args count toward same limit).
*
* Return value: void
*
* | Argument | Description | Type | Valid Range | Required |
* |--------------|------------------------------------------------------------------------|--------------------------------------------------------|---------------------------------------------------------------------|----------|
* | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes |
* | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes |
* | runtime_args | The runtime args to be written | const RuntimeArgsData & | | Yes |
*/
void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args);

/**
* Get the runtime args for a kernel.
*
Expand Down Expand Up @@ -358,15 +372,16 @@ std::vector< std::vector< RuntimeArgsData > > & GetRuntimeArgs(const Program &pr

/**
* Get the common runtime args for a kernel.
* Note that you must call SetCommonRuntimeArgs after updating the returned value to propagate the update.
*
* Return value: std::vector<uint32_t> &
* Return value: RuntimeArgsData &
*
* | Argument | Description | Type | Valid Range | Required |
* |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------|
* | program | The program containing kernels, circular buffers, semaphores | const Program & | | Yes |
* | kernel_id | ID of the kernel that will receive the runtime args | KernelHandle (uint64_t) | | Yes |
*/
std::vector<uint32_t> & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id);
RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id);

/**
* Update specific entries of the runtime args vector for a kernel using the command queue. This API must be used when Asynchronous Command Queue Mode is enabled.
Expand Down
42 changes: 5 additions & 37 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ std::condition_variable finish_cv;

namespace tt::tt_metal {

// TODO: Delete entries when programs are deleted to save memory
thread_local std::unordered_map<uint64_t, std::vector<HostMemDeviceCommand>> EnqueueProgramCommand::runtime_args_command_sequences = {};

uint32_t get_noc_unicast_encoding(const CoreCoord &coord) { return NOC_XY_ENCODING(NOC_X(coord.x), NOC_Y(coord.y)); }
Expand Down Expand Up @@ -307,7 +308,6 @@ void EnqueueProgramCommand::assemble_preamble_commands() {
template <typename PackedSubCmd>
void generate_dispatch_write_packed(
std::vector<HostMemDeviceCommand>& runtime_args_command_sequences,
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>>& cmd_mapping,
const uint32_t& l1_arg_base_addr,
const std::vector<PackedSubCmd>& sub_cmds,
const std::vector<std::pair<const void*, uint32_t>>& rt_data_and_sizes,
Expand Down Expand Up @@ -365,8 +365,6 @@ void generate_dispatch_write_packed(
uint32_t data_offset = (uint32_t)get_runtime_args_data_offset(num_packed_cmds, max_runtime_args_len, unicast);
const uint32_t data_inc = align(max_runtime_args_len * sizeof(uint32_t), L1_ALIGNMENT);
for (uint32_t i = offset_idx; i < offset_idx + num_packed_cmds; ++i) {
cmd_mapping[(uint64_t)id << 32 | sub_cmds[i].noc_xy_addr] = {
runtime_args_command_sequences.size() - 1, data_offset};
rt_args_data[i].get().rt_args_data = (uint32_t *)((char *)runtime_args_command_sequences.back().data() + data_offset);
data_offset += data_inc;
}
Expand Down Expand Up @@ -482,6 +480,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
unicast_sub_cmd.reserve(kernel->logical_cores().size());
common_rt_data_and_sizes[kernel_id].reserve(kernel->logical_cores().size());
common_rt_args_data[kernel_id].reserve(kernel->logical_cores().size());
uint32_t i = 0;
for (auto& core_coord : kernel->logical_cores()) {
// can make a vector of unicast encodings here
CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord);
Expand All @@ -490,7 +489,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding});
common_rt_data_and_sizes[kernel_id].emplace_back(
common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data());
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]);
}
} else {
vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info =
Expand All @@ -502,12 +501,13 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
std::get<std::vector<CQDispatchWritePackedMulticastSubCmd>>(common_sub_cmds[kernel_id]);
multicast_sub_cmd.reserve(dst_noc_multicast_info.size());
common_rt_data_and_sizes[kernel_id].reserve(dst_noc_multicast_info.size());
uint32_t i = 0;
for (const auto& mcast_dests : dst_noc_multicast_info) {
multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{
.noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second});
common_rt_data_and_sizes[kernel_id].emplace_back(
common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data());
common_rt_args_data[kernel_id].emplace_back(kernel->common_runtime_args_data()[i++]);
}
}
common_max_runtime_args_len[kernel_id] =
Expand All @@ -526,7 +526,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
for (const uint32_t& processor_idx : unique_processors) {
generate_dispatch_write_packed(
this->runtime_args_command_sequences[program.id],
program.command_indices.processor_to_cmd_mapping,
unique_processor_to_l1_arg_base_addr[processor_idx],
unique_sub_cmds[processor_idx],
unique_rt_data_and_sizes[processor_idx],
Expand All @@ -540,7 +539,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
[&](auto&& sub_cmds) {
generate_dispatch_write_packed(
this->runtime_args_command_sequences[program.id],
program.command_indices.kernel_to_cmd_mapping,
common_processor_to_l1_arg_base_addr[kernel_id],
sub_cmds,
common_rt_data_and_sizes[kernel_id],
Expand All @@ -551,36 +549,6 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
},
common_sub_cmds[kernel_id]);
}
} else {
for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) {
auto kernel = detail::GetKernel(program, kernel_id);
const auto& common_rt_args = kernel->common_runtime_args();
if (common_rt_args.size() > 0) {
if (kernel->get_kernel_core_type() == CoreType::ETH) {
for (const auto& core_coord : kernel->logical_cores()) {
uint32_t noc_xy_encoding =
get_noc_unicast_encoding(this->device->ethernet_core_from_logical_core(core_coord));
const auto& data_loc =
program.command_indices.kernel_to_cmd_mapping[(uint64_t)kernel_id << 32 | noc_xy_encoding];
const auto& runtime_args_data = kernel->runtime_args(core_coord);
this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence(
data_loc.second, runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t));
}
} else {
for (const auto& core_range : kernel->logical_coreranges()) {
CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start);
CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end);

uint32_t noc_multicast_encoding = get_noc_multcast_encoding(physical_start, physical_end);
const auto& data_loc =
program.command_indices
.kernel_to_cmd_mapping[(uint64_t)kernel_id << 32 | noc_multicast_encoding];
this->runtime_args_command_sequences[program.id][data_loc.first].update_cmd_sequence(
data_loc.second, common_rt_args.data(), common_rt_args.size() * sizeof(uint32_t));
}
}
}
}
}
}

Expand Down
35 changes: 25 additions & 10 deletions tt_metal/impl/kernels/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ std::vector<uint32_t>& Kernel::common_runtime_args() {
return this->common_runtime_args_;
}

RuntimeArgsData & Kernel::common_runtime_args_data() {
std::vector<RuntimeArgsData> & Kernel::common_runtime_args_data() {
return this->common_runtime_args_data_;
}

Expand Down Expand Up @@ -235,20 +235,20 @@ void Kernel::set_runtime_args(const CoreCoord &logical_core, const std::vector<u
// TT_FATAL(this->is_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name());

// Keep state for validation, to be able to check from both set_runtime_args() and set_common_runtime_args() APIs.
if (runtime_args.size() > max_runtime_args_per_core_) {
max_runtime_args_per_core_ = runtime_args.size();
core_with_max_runtime_args_ = logical_core;
}

this->validate_runtime_args_size(runtime_args.size(), this->common_runtime_args_.size(), logical_core);
auto &set_rt_args = this->core_to_runtime_args_[logical_core.x][logical_core.y];
// TODO: Only allow setting once
TT_FATAL(set_rt_args.empty() or set_rt_args.size() == runtime_args.size(), "Illegal Runtime Args: Number of runtime args cannot be modified!");
if (set_rt_args.empty()) {
if (runtime_args.size() > max_runtime_args_per_core_) {
max_runtime_args_per_core_ = runtime_args.size();
core_with_max_runtime_args_ = logical_core;
}
this->validate_runtime_args_size(runtime_args.size(), this->common_runtime_args_.size(), logical_core);
set_rt_args = runtime_args;
this->core_to_runtime_args_data_[logical_core.x][logical_core.y] = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()};
this->core_with_runtime_args_.insert( logical_core );
} else {
TT_FATAL(set_rt_args.size() == runtime_args.size(), "Illegal Runtime Args: Number of runtime args cannot be modified!");
std::memcpy(this->core_to_runtime_args_data_[logical_core.x][logical_core.y].rt_args_data, runtime_args.data(), runtime_args.size() * sizeof(uint32_t));
}

Expand All @@ -260,11 +260,26 @@ void Kernel::set_common_runtime_args(const std::vector<uint32_t> &common_runtime
// Should this check only be enabled in debug mode?
// TT_FATAL(this->is_on_logical_core(logical_core), "Cannot set runtime args for core {} since kernel {} is not placed on it!", logical_core.str(), this->name());

this->validate_runtime_args_size(max_runtime_args_per_core_, common_runtime_args.size(), core_with_max_runtime_args_);
auto &set_rt_args = this->common_runtime_args_;
TT_FATAL(set_rt_args.empty() or set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!");
TT_FATAL(set_rt_args.empty(), "Illegal Common Runtime Args: Can only set common runtime args once. Get and modify args in place instead.");
this->validate_runtime_args_size(max_runtime_args_per_core_, common_runtime_args.size(), core_with_max_runtime_args_);
set_rt_args = common_runtime_args;
this->common_runtime_args_data_ = RuntimeArgsData{set_rt_args.data(), set_rt_args.size()};
this->common_runtime_args_data_ = std::vector<RuntimeArgsData>(this->core_range_set_.ranges().size(), RuntimeArgsData{set_rt_args.data(), set_rt_args.size()});
}

// Propagates updated common runtime arg values to every RuntimeArgsData entry held in
// common_runtime_args_data_.
//
// common_runtime_args: the updated values. Its size must equal the number of common runtime
//                      args already stored in common_runtime_args_.
//
// Fatals (TT_FATAL) when no common runtime args have been set yet (common_runtime_args_ is
// empty), or when the number of args differs from the number already stored.
void Kernel::set_common_runtime_args(const RuntimeArgsData& common_runtime_args) {
    const auto &set_rt_args = this->common_runtime_args_;
    // Report the two failure modes separately so the message matches the actual problem.
    TT_FATAL(!set_rt_args.empty(), "Illegal Common Runtime Args: Common runtime args must be set before they can be updated!");
    TT_FATAL(set_rt_args.size() == common_runtime_args.size(), "Illegal Common Runtime Args: Number of common runtime args cannot be modified!");

    // Same byte count for every destination; compute it once.
    const std::size_t num_bytes = common_runtime_args.size() * sizeof(uint32_t);
    for (auto& rt_args_data : this->common_runtime_args_data_) {
        // Skip the entry whose storage is the source itself: nothing to copy, and memcpy
        // must not be given identical source and destination ranges.
        if (common_runtime_args.data() != rt_args_data.data()) {
            std::memcpy(rt_args_data.data(), common_runtime_args.data(), num_bytes);
        }
    }
}

void DataMovementKernel::set_build_options(JitBuildOptions& build_options) const {
Expand Down
7 changes: 4 additions & 3 deletions tt_metal/impl/kernels/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ class Kernel : public JitBuildSettings {
void update_runtime_arg( const CoreCoord &logical_core, size_t idx, uint32_t value);

std::vector<uint32_t> & runtime_args(const CoreCoord &logical_core);
RuntimeArgsData& runtime_args_data(const CoreCoord &logical_core);
RuntimeArgsData & runtime_args_data(const CoreCoord &logical_core);
std::vector< std::vector< std::vector<uint32_t>> > & runtime_args();
std::vector< std::vector< RuntimeArgsData > > & runtime_args_data();
std::vector<uint32_t> & common_runtime_args();
RuntimeArgsData& common_runtime_args_data();
std::vector<RuntimeArgsData> & common_runtime_args_data();

std::map<std::string, std::string> defines() const { return defines_; }

Expand All @@ -108,6 +108,7 @@ class Kernel : public JitBuildSettings {
void validate_runtime_args_size(size_t num_unique_rt_args, size_t num_common_rt_args, const CoreCoord& logical_core);
void set_runtime_args(const CoreCoord &logical_core, const std::vector<uint32_t> &runtime_args);
void set_common_runtime_args(const std::vector<uint32_t> &runtime_args);
void set_common_runtime_args(const RuntimeArgsData& common_runtime_args);

int get_watcher_kernel_id() { return watcher_kernel_id_; }

Expand All @@ -132,7 +133,7 @@ class Kernel : public JitBuildSettings {
std::vector< std::vector< std::vector<uint32_t>> > core_to_runtime_args_;
std::vector< std::vector< RuntimeArgsData> > core_to_runtime_args_data_;
std::vector<uint32_t> common_runtime_args_;
RuntimeArgsData common_runtime_args_data_;
std::vector<RuntimeArgsData> common_runtime_args_data_;
std::set<CoreCoord> core_with_runtime_args_;
std::size_t max_runtime_args_per_core_; // For validation
CoreCoord core_with_max_runtime_args_; // For validation
Expand Down
4 changes: 0 additions & 4 deletions tt_metal/impl/program/program_device_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,4 @@ struct ProgramTransferInfo {

struct ProgramCommandIndices {
std::uint32_t cb_configs_payload_start; // device_commands
// pair of cmd idx, rt arg offset
// Currently we only really need the base cmd idx since they are sequential, and the rt arg len is currently the same for all splits
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> processor_to_cmd_mapping;
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> kernel_to_cmd_mapping;
};
Loading

0 comments on commit 44f1791

Please sign in to comment.