Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have runtime args update directly into device cmd for FD #8504

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/aspell-dictionary.pws
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ UnaryOpType
UpdateCircularBufferPageSize
UpdateCircularBufferTotalSize
UpdateDynamicCircularBufferAddress
UpdateRuntimeArgs
VC
VCs
WH
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ Runtime Arguments

.. doxygenfunction:: SetRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const std::vector< CoreCoord > & core_spec, const std::vector<std::shared_ptr<RuntimeArgs>> runtime_args)

.. doxygenfunction:: UpdateRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const CoreCoord &core_coord, std::vector<uint32_t> &update_idx, std::shared_ptr<RuntimeArgs> runtime_args)

.. doxygenfunction:: GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core)

.. doxygenfunction:: GetRuntimeArgs(const Program &program, KernelHandle kernel_id)

.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector<uint32_t> &runtime_args)

.. doxygenfunction:: SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const RuntimeArgsData &runtime_args)

.. doxygenfunction:: GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id)
Original file line number Diff line number Diff line change
Expand Up @@ -276,75 +276,6 @@ TEST_F(CommandQueueFixture, TestAsyncBufferRW) {
command_queue.set_mode(current_mode);
}

TEST_F(CommandQueueFixture, TestAsyncSetAndUpdateRuntimeArgs) {
    // Exercises SetRuntimeArgs and UpdateRuntimeArgs while the command queue runs in
    // ASYNC mode: runtime args are handed over as Buffer pointers and, after Finish(),
    // the args stored on each kernel must equal the corresponding buffer addresses.
    auto& cq = this->device_->command_queue();
    auto saved_mode = CommandQueue::default_mode();
    cq.set_mode(CommandQueue::CommandQueueMode::ASYNC);

    uint32_t buffer_bytes = 4096;
    uint32_t page_bytes = 4096;
    CoreCoord core = {0, 0};

    // One data-movement kernel per RISC-V processor on the single test core.
    Program program;
    auto reader = CreateKernel(
        program,
        "tt_metal/kernels/dataflow/reader_binary_diff_lengths.cpp",
        core,
        DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
    auto writer = CreateKernel(
        program,
        "tt_metal/kernels/dataflow/writer_unary.cpp",
        core,
        DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default});

    // Returns a copy of the runtime args currently stored on the given kernel for `core`.
    auto stored_args = [&](auto kernel_handle) {
        return detail::GetKernel(program, kernel_handle)->runtime_args(core);
    };

    // Phase 1: create the first set of DRAM buffers and set runtime args from their
    // (potentially not yet allocated) addresses.
    auto first_src0 = std::make_shared<Buffer>(this->device_, buffer_bytes, page_bytes, BufferType::DRAM);
    auto first_src1 = std::make_shared<Buffer>(this->device_, buffer_bytes, page_bytes, BufferType::DRAM);
    auto first_dst = std::make_shared<Buffer>(this->device_, buffer_bytes, page_bytes, BufferType::DRAM);

    auto writer_args = std::make_shared<RuntimeArgs>();
    *writer_args = {first_src0.get(), first_src1.get()};
    auto reader_args = std::make_shared<RuntimeArgs>();
    *reader_args = {first_dst.get()};

    SetRuntimeArgs(this->device_, detail::GetKernel(program, writer), core, writer_args);
    SetRuntimeArgs(this->device_, detail::GetKernel(program, reader), core, reader_args);
    Finish(cq);

    const auto writer_after_set = stored_args(writer);
    const auto reader_after_set = stored_args(reader);

    EXPECT_EQ(writer_after_set.size(), 2);
    EXPECT_EQ(reader_after_set.size(), 1);
    EXPECT_EQ(writer_after_set[0], first_src0->address());
    EXPECT_EQ(writer_after_set[1], first_src1->address());
    EXPECT_EQ(reader_after_set[0], first_dst->address());

    // Phase 2: create a second set of buffers and update every arg index in place.
    auto second_src0 = std::make_shared<Buffer>(this->device_, buffer_bytes, page_bytes, BufferType::DRAM);
    auto second_src1 = std::make_shared<Buffer>(this->device_, buffer_bytes, page_bytes, BufferType::DRAM);
    auto second_dst = std::make_shared<Buffer>(this->device_, buffer_bytes, page_bytes, BufferType::DRAM);

    *writer_args = {second_src0.get(), second_src1.get()};
    *reader_args = {second_dst.get()};
    std::vector<uint32_t> writer_indices = {0, 1};
    std::vector<uint32_t> reader_indices = {0};

    UpdateRuntimeArgs(this->device_, detail::GetKernel(program, writer), core, writer_indices, writer_args);
    UpdateRuntimeArgs(this->device_, detail::GetKernel(program, reader), core, reader_indices, reader_args);
    Finish(cq);

    const auto writer_after_update = stored_args(writer);
    const auto reader_after_update = stored_args(reader);

    EXPECT_EQ(writer_after_update.size(), 2);
    EXPECT_EQ(reader_after_update.size(), 1);
    EXPECT_EQ(writer_after_update[0], second_src0->address());
    EXPECT_EQ(writer_after_update[1], second_src1->address());
    EXPECT_EQ(reader_after_update[0], second_dst->address());

    // Put the queue back into the mode captured at the start of the test.
    cq.set_mode(saved_mode);
}

TEST_F(CommandQueueFixture, TestAsyncCBAllocation) {
// Test asynchronous allocation of buffers and their assignment to CBs
auto& command_queue = this->device_->command_queue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1157,42 +1157,42 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor&
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<Tensor>& output_tensors
) {
bool is_sharded = input_tensors.at(0).is_sharded();
const auto& input = input_tensors.at(0);
const auto& output = output_tensors.at(0);
bool is_sharded = input_tensors[0].is_sharded();
const auto& input = input_tensors[0];
const auto& output = output_tensors[0];
for (uint32_t i = 0; i < total_worker_core_pairs_used; ++i) {
if (is_sharded) {
auto &worker_reader_sender_runtime_args = GetRuntimeArgs(program, worker_reader_sender_kernels.at(i), all_worker_sender_cores.at(i));
worker_reader_sender_runtime_args.at(7) = input.buffer()->address();
worker_reader_sender_runtime_args[7] = input.buffer()->address();
uint32_t num_dest_cores = worker_reader_sender_runtime_args.at(12);
worker_reader_sender_runtime_args.at(12 + num_dest_cores + 4) = output.buffer()->address();
worker_reader_sender_runtime_args[12 + num_dest_cores + 4] = output.buffer()->address();
log_trace(tt::LogOp, "override worker_reader_sender_runtime_args:");
for (uint32_t j = 0; j < worker_reader_sender_runtime_args.size(); ++j) {
log_trace(tt::LogOp, "\tworker_reader_sender_runtime_args[{}]: {}", j, worker_reader_sender_runtime_args.at(j));
}

auto &worker_writer_sender_runtime_args = GetRuntimeArgs(program, worker_writer_sender_kernels.at(i), all_worker_sender_cores.at(i));
worker_writer_sender_runtime_args.at(12) = output.buffer()->address();
worker_writer_sender_runtime_args[12] = output.buffer()->address();
log_trace(tt::LogOp, "override worker_writer_sender_runtime_args:");
// Trace each overridden writer-sender runtime arg. The value must come from the same
// container that bounds the loop and is named in the message; reading the reader-sender
// args here logged the wrong data and could index past the end when the sizes differ.
for (uint32_t j = 0; j < worker_writer_sender_runtime_args.size(); ++j) {
    log_trace(tt::LogOp, "\tworker_writer_sender_runtime_args[{}]: {}", j, worker_writer_sender_runtime_args.at(j));
}

auto &worker_writer_receiver_runtime_args = GetRuntimeArgs(program, worker_writer_receiver_kernels.at(i), all_worker_receiver_cores.at(i));
worker_writer_receiver_runtime_args.at(10) = output.buffer()->address();
worker_writer_receiver_runtime_args[10] = output.buffer()->address();
log_trace(tt::LogOp, "override worker_writer_receiver_runtime_args:");
// Trace each overridden writer-receiver runtime arg. The value must come from the same
// container that bounds the loop and is named in the message; reading the reader-sender
// args here logged the wrong data and could index past the end when the sizes differ.
for (uint32_t j = 0; j < worker_writer_receiver_runtime_args.size(); ++j) {
    log_trace(tt::LogOp, "\tworker_writer_receiver_runtime_args[{}]: {}", j, worker_writer_receiver_runtime_args.at(j));
}
} else {
auto &worker_reader_sender_runtime_args = GetRuntimeArgs(program, worker_reader_sender_kernels.at(i), all_worker_sender_cores.at(i));
worker_reader_sender_runtime_args.at(0) = input.buffer()->address();
worker_reader_sender_runtime_args.at(1) = output.buffer()->address();
worker_reader_sender_runtime_args[0] = input.buffer()->address();
worker_reader_sender_runtime_args[1] = output.buffer()->address();
auto &worker_writer_sender_runtime_args = GetRuntimeArgs(program, worker_writer_sender_kernels.at(i), all_worker_sender_cores.at(i));
worker_writer_sender_runtime_args.at(0) = output.buffer()->address();
worker_writer_sender_runtime_args[0] = output.buffer()->address();

auto &worker_writer_receiver_runtime_args = GetRuntimeArgs(program, worker_writer_receiver_kernels.at(i), all_worker_receiver_cores.at(i));
worker_writer_receiver_runtime_args.at(0) = output.buffer()->address();
worker_writer_receiver_runtime_args[0] = output.buffer()->address();
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ operation::ProgramWithCallbacks concat_multi_core(
for (const auto &core : cores) {
{
auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 3);
std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.data() + 3);
}

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ operation::ProgramWithCallbacks concat_single_core(

{
auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core);
std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 3);
std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.data() + 3);
}

{
Expand Down
6 changes: 2 additions & 4 deletions tt_eager/tt_dnn/op_library/moreh_adam/moreh_adam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,27 +213,25 @@ operation::ProgramWithCallbacks moreh_adam_(
CoreCoord core = {i / num_cores_y, i % num_cores_y};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = param_buffer->address();
runtime_args[1] = grad_buffer->address();
runtime_args[2] = exp_avg_buffer->address();
runtime_args[3] = exp_avg_sq_buffer->address();
if (max_exp_avg_sq_buffer != nullptr) {
runtime_args[4] = max_exp_avg_sq_buffer->address();
}
tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = param_buffer->address();
runtime_args[1] = grad_buffer->address();
runtime_args[2] = exp_avg_buffer->address();
runtime_args[3] = exp_avg_sq_buffer->address();
if (max_exp_avg_sq_buffer != nullptr) {
runtime_args[4] = max_exp_avg_sq_buffer->address();
}
tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down
6 changes: 2 additions & 4 deletions tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,27 +213,25 @@ operation::ProgramWithCallbacks moreh_adamw_(
CoreCoord core = {i / num_cores_y, i % num_cores_y};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = param_buffer->address();
runtime_args[1] = grad_buffer->address();
runtime_args[2] = exp_avg_buffer->address();
runtime_args[3] = exp_avg_sq_buffer->address();
if (max_exp_avg_sq_buffer != nullptr) {
runtime_args[4] = max_exp_avg_sq_buffer->address();
}
tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = param_buffer->address();
runtime_args[1] = grad_buffer->address();
runtime_args[2] = exp_avg_buffer->address();
runtime_args[3] = exp_avg_sq_buffer->address();
if (max_exp_avg_sq_buffer != nullptr) {
runtime_args[4] = max_exp_avg_sq_buffer->address();
}
tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down
3 changes: 1 addition & 2 deletions tt_eager/tt_dnn/op_library/moreh_arange/moreh_arange_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,8 @@ operation::ProgramWithCallbacks moreh_arange_(
CoreCoord core = {icore / core_h, icore % core_h};

{
auto runtime_args = GetRuntimeArgs(program, kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, kernel_id, core);
runtime_args[0] = src_dram_buffer->address();
SetRuntimeArgs(program, kernel_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,23 +196,20 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
CoreCoord core = {i / num_cores_y, i % num_cores_y};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
runtime_args[0] = input_tensors.at(i).buffer()->address();
runtime_args[3] = *reinterpret_cast<uint32_t*>(&decimal);
SetRuntimeArgs(program, reader_kernels_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core);
runtime_args[0] = output_address;
SetRuntimeArgs(program, writer_kernels_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, compute_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, compute_kernels_id, core);
runtime_args[1] = p;
runtime_args[2] = static_cast<uint32_t>(p_is_negative);
SetRuntimeArgs(program, compute_kernels_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,20 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
const auto output_address = input_tensors.at(1).buffer()->address();

{
auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, single_core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, single_core);
runtime_args[0] = input_address;
runtime_args[3] = *reinterpret_cast<uint32_t*>(&decimal);
SetRuntimeArgs(program, reader_kernels_id, single_core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, single_core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, single_core);
runtime_args[0] = output_address;
SetRuntimeArgs(program, writer_kernels_id, single_core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, compute_kernels_id, single_core);
auto &runtime_args = GetRuntimeArgs(program, compute_kernels_id, single_core);
runtime_args[1] = p;
runtime_args[2] = static_cast<uint32_t>(p_is_negative);
SetRuntimeArgs(program, compute_kernels_id, single_core, runtime_args);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,14 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
CoreCoord core = {i / num_cores_y, i % num_cores_y};

{
auto runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
runtime_args[0] = input_buffers.at(i)->address();
runtime_args[2] = clip_coef_clamped_address;
SetRuntimeArgs(program, reader_kernels_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernels_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernels_id, core);
runtime_args[0] = input_buffers.at(i)->address();
SetRuntimeArgs(program, writer_kernels_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,13 @@ operation::ProgramWithCallbacks moreh_cumsum_nc(
for (uint32_t i = 0; i < num_cores_to_be_used; ++i) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = input_buffer->address();
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = output_buffer->address();
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,24 +118,21 @@ operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Ten
uint32_t num_tiles = input_tensors.at(0).volume() / TILE_HW;

{
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = src_buffer_a->address();
runtime_args[1] = src_buffer_b->address();
runtime_args[2] = num_tiles;
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
auto& runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
runtime_args[0] = num_tiles;
SetRuntimeArgs(program, compute_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = dst_buffer->address();
runtime_args[1] = 1;
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
};
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
Expand Down
Loading
Loading