Skip to content

Commit

Permalink
#7530: Uplift unique rt args to update directly into the cmd queue fo…
Browse files Browse the repository at this point in the history
…r fast dispatch. TODO: Fix/uplift common runtime args update to remove looping in assemble_runtime_args
  • Loading branch information
tt-aho committed May 15, 2024
1 parent 3716890 commit 8d99471
Show file tree
Hide file tree
Showing 49 changed files with 197 additions and 226 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1157,42 +1157,42 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<Tensor>& output_tensors
) {
bool is_sharded = input_tensors.at(0).is_sharded();
const auto& input = input_tensors.at(0);
const auto& output = output_tensors.at(0);
bool is_sharded = input_tensors[0].is_sharded();
const auto& input = input_tensors[0];
const auto& output = output_tensors[0];
for (uint32_t i = 0; i < total_worker_core_pairs_used; ++i) {
if (is_sharded) {
auto &worker_reader_sender_runtime_args = GetRuntimeArgs(program, worker_reader_sender_kernels.at(i), all_worker_sender_cores.at(i));
worker_reader_sender_runtime_args.at(7) = input.buffer()->address();
worker_reader_sender_runtime_args[7] = input.buffer()->address();
uint32_t num_dest_cores = worker_reader_sender_runtime_args.at(12);
worker_reader_sender_runtime_args.at(12 + num_dest_cores + 4) = output.buffer()->address();
worker_reader_sender_runtime_args[12 + num_dest_cores + 4] = output.buffer()->address();
log_trace(tt::LogOp, "override worker_reader_sender_runtime_args:");
for (uint32_t j = 0; j < worker_reader_sender_runtime_args.size(); ++j) {
log_trace(tt::LogOp, "\tworker_reader_sender_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j));
}

auto &worker_writer_sender_runtime_args = GetRuntimeArgs(program, worker_writer_sender_kernels.at(i), all_worker_sender_cores.at(i));
worker_writer_sender_runtime_args.at(12) = output.buffer()->address();
worker_writer_sender_runtime_args[12] = output.buffer()->address();
log_trace(tt::LogOp, "override worker_writer_sender_runtime_args:");
for (uint32_t j = 0; j < worker_writer_sender_runtime_args.size(); ++j) {
log_trace(tt::LogOp, "\tworker_writer_sender_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j));
}

auto &worker_writer_receiver_runtime_args = GetRuntimeArgs(program, worker_writer_receiver_kernels.at(i), all_worker_receiver_cores.at(i));
worker_writer_receiver_runtime_args.at(10) = output.buffer()->address();
worker_writer_receiver_runtime_args[10] = output.buffer()->address();
log_trace(tt::LogOp, "override worker_writer_receiver_runtime_args:");
for (uint32_t j = 0; j < worker_writer_receiver_runtime_args.size(); ++j) {
log_trace(tt::LogOp, "\tworker_writer_receiver_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j));
}
} else {
auto &worker_reader_sender_runtime_args = GetRuntimeArgs(program, worker_reader_sender_kernels.at(i), all_worker_sender_cores.at(i));
worker_reader_sender_runtime_args.at(0) = input.buffer()->address();
worker_reader_sender_runtime_args.at(1) = output.buffer()->address();
worker_reader_sender_runtime_args[0] = input.buffer()->address();
worker_reader_sender_runtime_args[1] = output.buffer()->address();
auto &worker_writer_sender_runtime_args = GetRuntimeArgs(program, worker_writer_sender_kernels.at(i), all_worker_sender_cores.at(i));
worker_writer_sender_runtime_args.at(0) = output.buffer()->address();
worker_writer_sender_runtime_args[0] = output.buffer()->address();

auto &worker_writer_receiver_runtime_args = GetRuntimeArgs(program, worker_writer_receiver_kernels.at(i), all_worker_receiver_cores.at(i));
worker_writer_receiver_runtime_args.at(0) = output.buffer()->address();
worker_writer_receiver_runtime_args[0] = output.buffer()->address();
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ operation::ProgramWithCallbacks concat_multi_core(
for (const auto &core : cores) {
{
auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 3);
std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.data() + 3);
}

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ operation::ProgramWithCallbacks concat_single_core(

{
auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 3);
std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.data() + 3);
}

{
Expand Down
6 changes: 2 additions & 4 deletions tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,27 +213,25 @@ operation::ProgramWithCallbacks moreh_adam_(
CoreCoord core = {i / num_cores_y, i % num_cores_y};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = param_buffer->address();
runtime_args[1] = grad_buffer->address();
runtime_args[2] = exp_avg_buffer->address();
runtime_args[3] = exp_avg_sq_buffer->address();
if (max_exp_avg_sq_buffer != nullptr) {
runtime_args[4] = max_exp_avg_sq_buffer->address();
}
tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = param_buffer->address();
runtime_args[1] = grad_buffer->address();
runtime_args[2] = exp_avg_buffer->address();
runtime_args[3] = exp_avg_sq_buffer->address();
if (max_exp_avg_sq_buffer != nullptr) {
runtime_args[4] = max_exp_avg_sq_buffer->address();
}
tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down
6 changes: 2 additions & 4 deletions tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,27 +213,25 @@ operation::ProgramWithCallbacks moreh_adamw_(
CoreCoord core = {i / num_cores_y, i % num_cores_y};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = param_buffer->address();
runtime_args[1] = grad_buffer->address();
runtime_args[2] = exp_avg_buffer->address();
runtime_args[3] = exp_avg_sq_buffer->address();
if (max_exp_avg_sq_buffer != nullptr) {
runtime_args[4] = max_exp_avg_sq_buffer->address();
}
tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = param_buffer->address();
runtime_args[1] = grad_buffer->address();
runtime_args[2] = exp_avg_buffer->address();
runtime_args[3] = exp_avg_sq_buffer->address();
if (max_exp_avg_sq_buffer != nullptr) {
runtime_args[4] = max_exp_avg_sq_buffer->address();
}
tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down
3 changes: 1 addition & 2 deletions tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,8 @@ operation::ProgramWithCallbacks moreh_arange_(
CoreCoord core = {icore / core_h, icore % core_h};

{
auto runtime_args = GetRuntimeArgs(program, kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, kernel_id, core);
runtime_args[0] = src_dram_buffer->address();
SetRuntimeArgs(program, kernel_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,23 +196,20 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
CoreCoord core = {i / num_cores_y, i % num_cores_y};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
runtime_args[0] = input_tensors.at(i).buffer()->address();
runtime_args[3] = *reinterpret_cast<uint32_t*>(&decimal);
SetRuntimeArgs(program, reader_kernels_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core);
runtime_args[0] = output_address;
SetRuntimeArgs(program, writer_kernels_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, compute_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, compute_kernels_id, core);
runtime_args[1] = p;
runtime_args[2] = static_cast<uint32_t>(p_is_negative);
SetRuntimeArgs(program, compute_kernels_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,20 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
const auto output_address = input_tensors.at(1).buffer()->address();

{
auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, single_core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, single_core);
runtime_args[0] = input_address;
runtime_args[3] = *reinterpret_cast<uint32_t*>(&decimal);
SetRuntimeArgs(program, reader_kernels_id, single_core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, single_core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, single_core);
runtime_args[0] = output_address;
SetRuntimeArgs(program, writer_kernels_id, single_core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, compute_kernels_id, single_core);
auto &runtime_args = GetRuntimeArgs(program, compute_kernels_id, single_core);
runtime_args[1] = p;
runtime_args[2] = static_cast<uint32_t>(p_is_negative);
SetRuntimeArgs(program, compute_kernels_id, single_core, runtime_args);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,14 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
CoreCoord core = {i / num_cores_y, i % num_cores_y};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
runtime_args[0] = input_buffers.at(i)->address();
runtime_args[2] = clip_coef_clamped_address;
SetRuntimeArgs(program, reader_kernels_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core);
runtime_args[0] = input_buffers.at(i)->address();
SetRuntimeArgs(program, writer_kernels_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,13 @@ operation::ProgramWithCallbacks moreh_cumsum_nc(
for (uint32_t i = 0; i < num_cores_to_be_used; ++i) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = input_buffer->address();
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = output_buffer->address();
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,24 +118,21 @@ operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Ten
uint32_t num_tiles = input_tensors.at(0).volume() / TILE_HW;

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = src_buffer_a->address();
runtime_args[1] = src_buffer_b->address();
runtime_args[2] = num_tiles;
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
auto& runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
runtime_args[0] = num_tiles;
SetRuntimeArgs(program, compute_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = dst_buffer->address();
runtime_args[1] = 1;
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
};
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,29 +179,26 @@ operation::ProgramWithCallbacks moreh_dot_backward_single_core(

CoreCoord core = {0, 0};
{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = (std::uint32_t)has_input_grad;
runtime_args[1] = (std::uint32_t)has_input_grad;
runtime_args[2] = src0_buffer->address();
runtime_args[3] = src1_buffer->address();
runtime_args[4] = src2_buffer->address();
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
runtime_args[0] = (std::uint32_t)has_input_grad;
runtime_args[1] = (std::uint32_t)has_input_grad;
SetRuntimeArgs(program, compute_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = (std::uint32_t)has_input_grad;
runtime_args[1] = (std::uint32_t)has_input_grad;
runtime_args[2] = dst0_address;
runtime_args[3] = dst1_address;
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
};
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,19 +230,17 @@ operation::ProgramWithCallbacks moreh_getitem_rm(
CoreCoord core = {icore / core_h, icore % core_h};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = src_buffer->address();
runtime_args[1] = index_info[0].address;
runtime_args[2] = index_info[1].address;
runtime_args[3] = index_info[2].address;
runtime_args[4] = index_info[3].address;
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = dst_buffer->address();
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,19 +293,17 @@ operation::ProgramWithCallbacks moreh_getitem_tilized(
CoreCoord core = {icore / core_h, icore % core_h};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = src_buffer->address();
runtime_args[1] = index_info[0].address;
runtime_args[2] = index_info[1].address;
runtime_args[3] = index_info[2].address;
runtime_args[4] = index_info[3].address;
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = dst_buffer->address();
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down Expand Up @@ -545,19 +543,17 @@ operation::ProgramWithCallbacks moreh_getitem_tilized(
CoreCoord core = {icore / core_h, icore % core_h};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = src_buffer->address();
runtime_args[1] = index_info[0].address;
runtime_args[2] = index_info[1].address;
runtime_args[3] = index_info[2].address;
runtime_args[4] = index_info[3].address;
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = dst_buffer->address();
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down
Loading

0 comments on commit 8d99471

Please sign in to comment.