Skip to content

Commit

Permalink
#0: Remove deprecated UpdateRuntimeArgs api
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed May 15, 2024
1 parent f612388 commit b759289
Show file tree
Hide file tree
Showing 9 changed files with 0 additions and 141 deletions.
1 change: 0 additions & 1 deletion docs/aspell-dictionary.pws
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ UnaryOpType
UpdateCircularBufferPageSize
UpdateCircularBufferTotalSize
UpdateDynamicCircularBufferAddress
UpdateRuntimeArgs
VC
VCs
WH
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ Runtime Arguments

.. doxygenfunction:: SetRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const std::vector< CoreCoord > & core_spec, const std::vector<std::shared_ptr<RuntimeArgs>> runtime_args)

.. doxygenfunction:: UpdateRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const CoreCoord &core_coord, std::vector<uint32_t> &update_idx, std::shared_ptr<RuntimeArgs> runtime_args)

.. doxygenfunction:: GetRuntimeArgs(const Program &program, KernelHandle kernel_id, const CoreCoord &logical_core)

.. doxygenfunction:: GetRuntimeArgs(const Program &program, KernelHandle kernel_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,75 +276,6 @@ TEST_F(CommandQueueFixture, TestAsyncBufferRW) {
command_queue.set_mode(current_mode);
}

TEST_F(CommandQueueFixture, TestAsyncSetAndUpdateRuntimeArgs) {
// Test Asynchronous buffer allocation and SetRuntimeArgs API
auto& command_queue = this->device_->command_queue();
auto current_mode = CommandQueue::default_mode();
command_queue.set_mode(CommandQueue::CommandQueueMode::ASYNC);

uint32_t buf_size = 4096;
uint32_t page_size = 4096;
CoreCoord core = {0, 0};
// Initialize kernels in program
Program program;
auto reader = CreateKernel(
program,
"tt_metal/kernels/dataflow/reader_binary_diff_lengths.cpp",
core,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
auto writer = CreateKernel(
program,
"tt_metal/kernels/dataflow/writer_unary.cpp",
core,
DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default});

// Asynchronously allocate buffers on device
auto src0 = std::make_shared<Buffer>(this->device_, buf_size, page_size, BufferType::DRAM);
auto src1 = std::make_shared<Buffer>(this->device_, buf_size, page_size, BufferType::DRAM);
auto dst = std::make_shared<Buffer>(this->device_, buf_size, page_size, BufferType::DRAM);
// Asynchronously set the runtime args based on (potentially unallocated) buffer addrs
std::shared_ptr<RuntimeArgs> writer_runtime_args = std::make_shared<RuntimeArgs>();
*writer_runtime_args = {src0.get(), src1.get()};
std::shared_ptr<RuntimeArgs> reader_runtime_args= std::make_shared<RuntimeArgs>();
*reader_runtime_args = {dst.get()};
SetRuntimeArgs(this->device_, detail::GetKernel(program, writer), core, writer_runtime_args);
SetRuntimeArgs(this->device_, detail::GetKernel(program, reader), core, reader_runtime_args);
Finish(this->device_->command_queue());

auto resolved_writer_args = detail::GetKernel(program, writer)->runtime_args(core);
auto resolved_reader_args = detail::GetKernel(program, reader)->runtime_args(core);

EXPECT_EQ(resolved_writer_args.size(), 2);
EXPECT_EQ(resolved_reader_args.size(), 1);
EXPECT_EQ(resolved_writer_args[0], src0->address());
EXPECT_EQ(resolved_writer_args[1], src1->address());
EXPECT_EQ(resolved_reader_args[0], dst->address());

// Create new buffers and update the runtime args based on their address
auto src2 = std::make_shared<Buffer>(this->device_, buf_size, page_size, BufferType::DRAM);
auto src3 = std::make_shared<Buffer>(this->device_, buf_size, page_size, BufferType::DRAM);
auto dst1 = std::make_shared<Buffer>(this->device_, buf_size, page_size, BufferType::DRAM);

*writer_runtime_args = {src2.get(), src3.get()};
*reader_runtime_args = {dst1.get()};
std::vector<uint32_t> writer_update_idx = {0, 1};
std::vector<uint32_t> reader_update_idx = {0};
// Asynchronously update the runtime args based on (potentially unallocated) buffer addrs
UpdateRuntimeArgs(this->device_, detail::GetKernel(program, writer), core, writer_update_idx, writer_runtime_args);
UpdateRuntimeArgs(this->device_, detail::GetKernel(program, reader), core, reader_update_idx, reader_runtime_args);
Finish(this->device_->command_queue());

resolved_writer_args = detail::GetKernel(program, writer)->runtime_args(core);
resolved_reader_args = detail::GetKernel(program, reader)->runtime_args(core);

EXPECT_EQ(resolved_writer_args.size(), 2);
EXPECT_EQ(resolved_reader_args.size(), 1);
EXPECT_EQ(resolved_writer_args[0], src2->address());
EXPECT_EQ(resolved_writer_args[1], src3->address());
EXPECT_EQ(resolved_reader_args[0], dst1->address());
command_queue.set_mode(current_mode);
}

TEST_F(CommandQueueFixture, TestAsyncCBAllocation) {
// Test asynchronous allocation of buffers and their assignment to CBs
auto& command_queue = this->device_->command_queue();
Expand Down
15 changes: 0 additions & 15 deletions tt_metal/host_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,21 +383,6 @@ std::vector< std::vector< RuntimeArgsData > > & GetRuntimeArgs(const Program &pr
*/
RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id);

/**
* Update specific entries of the runtime args vector for a kernel using the command queue. This API must be used when Asynchronous Command Queue Mode is enabled.
*
* Return Value: void
*
* | Argument | Description | Type | Valid Range | Required |
* |--------------|------------------------------------------------------------------------|-------------------------------|--------------------------------------------------------------|----------|
* | cq | The command queue used to send the runtime args update | CommandQueue & | | Yes |
* | kernel | The kernel for which the runtime args must be updated | std::shared_ptr<Kernel> | | Yes |
* | core_coord | The core receiving the runtime args update | const CoreCoord & | A single core running the kernel | Yes |
* | update_idx | The runtime arg vector indices that must be updated | std::vector<uint32_t> & | Each index in this vector must be less than num runtime args | Yes |
* | runtime_args | Updated runtime args | std::shared_ptr<RuntimeArgs> | 1:1 Mapping between each entry and the indices in update_idx | Yes |
*/
void UpdateRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const CoreCoord &core_coord, std::vector<uint32_t> &update_idx, std::shared_ptr<RuntimeArgs> runtime_args);

/**
* Reads a buffer from the device
*
Expand Down
38 changes: 0 additions & 38 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1831,40 +1831,6 @@ void EnqueueAddBufferToProgram(CommandQueue& cq, std::variant<std::reference_wra
// });
}

void EnqueueUpdateRuntimeArgsImpl (const RuntimeArgsMetadata& runtime_args_md) {
std::vector<uint32_t> resolved_runtime_args = {};
resolved_runtime_args.reserve((*runtime_args_md.runtime_args_ptr).size());

for (const auto& arg : *(runtime_args_md.runtime_args_ptr)) {
std::visit([&resolved_runtime_args] (auto&& a) {
using T = std::decay_t<decltype(a)>;
if constexpr (std::is_same_v<T, Buffer*>) {
resolved_runtime_args.push_back(a -> address());
} else {
resolved_runtime_args.push_back(a);
}
}, arg);
}
auto& kernel_runtime_args = runtime_args_md.kernel->runtime_args(runtime_args_md.core_coord);
for (const auto& idx : runtime_args_md.update_idx) {
kernel_runtime_args[idx] = resolved_runtime_args[idx];
}
}

void EnqueueUpdateRuntimeArgs(CommandQueue& cq, const std::shared_ptr<Kernel> kernel, const CoreCoord &core_coord, std::vector<uint32_t> &update_idx, std::shared_ptr<RuntimeArgs> runtime_args_ptr, bool blocking) {
auto runtime_args_md = RuntimeArgsMetadata {
.core_coord = core_coord,
.runtime_args_ptr = runtime_args_ptr,
.kernel = kernel,
.update_idx = update_idx,
};
cq.run_command(CommandInterface {
.type = EnqueueCommandType::UPDATE_RUNTIME_ARGS,
.blocking = blocking,
.runtime_args_md = runtime_args_md,
});
}

void EnqueueSetRuntimeArgsImpl(const RuntimeArgsMetadata& runtime_args_md) {
std::vector<uint32_t> resolved_runtime_args = {};
resolved_runtime_args.reserve((*runtime_args_md.runtime_args_ptr).size());
Expand Down Expand Up @@ -2349,10 +2315,6 @@ void CommandQueue::run_command_impl(const CommandInterface& command) {
TT_ASSERT(command.runtime_args_md.has_value(), "Must provide RuntimeArgs Metdata!");
EnqueueSetRuntimeArgsImpl(command.runtime_args_md.value());
break;
case EnqueueCommandType::UPDATE_RUNTIME_ARGS:
TT_ASSERT(command.runtime_args_md.has_value(), "Must provide RuntimeArgs Metdata!");
EnqueueUpdateRuntimeArgsImpl(command.runtime_args_md.value());
break;
case EnqueueCommandType::ADD_BUFFER_TO_PROGRAM:
TT_ASSERT(command.buffer.has_value(), "Must provide a buffer!");
TT_ASSERT(command.program.has_value(), "Must provide a program!");
Expand Down
2 changes: 0 additions & 2 deletions tt_metal/impl/dispatch/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ enum class EnqueueCommandType {
GET_BUF_ADDR,
ADD_BUFFER_TO_PROGRAM,
SET_RUNTIME_ARGS,
UPDATE_RUNTIME_ARGS,
ENQUEUE_PROGRAM,
ENQUEUE_TRACE,
ENQUEUE_RECORD_EVENT,
Expand Down Expand Up @@ -634,7 +633,6 @@ void EnqueueAllocateBuffer(CommandQueue& cq, Buffer* buffer, bool bottom_up, boo
void EnqueueDeallocateBuffer(CommandQueue& cq, Allocator& allocator, uint32_t device_address, BufferType buffer_type, bool blocking);
void EnqueueGetBufferAddr(CommandQueue& cq, uint32_t* dst_buf_addr, const Buffer* buffer, bool blocking);
void EnqueueSetRuntimeArgs(CommandQueue& cq, const std::shared_ptr<Kernel> kernel, const CoreCoord &core_coord, std::shared_ptr<RuntimeArgs> runtime_args_ptr, bool blocking);
void EnqueueUpdateRuntimeArgs(CommandQueue& cq, const std::shared_ptr<Kernel> kernel, const CoreCoord &core_coord, std::vector<uint32_t> &update_idx, std::shared_ptr<RuntimeArgs> runtime_args_ptr, bool blocking);
void EnqueueAddBufferToProgram(CommandQueue& cq, std::variant<std::reference_wrapper<Buffer>, std::shared_ptr<Buffer>> buffer, std::variant<std::reference_wrapper<Program>, std::shared_ptr<Program>> program, bool blocking);

} // namespace tt::tt_metal
Expand Down
7 changes: 0 additions & 7 deletions tt_metal/impl/kernels/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,6 @@ std::string Kernel::compute_hash() const {
);
}

void Kernel::update_runtime_arg( const CoreCoord &logical_core, size_t idx, uint32_t value){
ZoneScoped;
auto & v = this->core_to_runtime_args_[logical_core.x][logical_core.y];
TT_ASSERT( idx < v.size(), "Runtime arg offset {} for Core {} out of bounds", idx, logical_core.str());
v[idx] = value;
}

std::vector<uint32_t>& Kernel::runtime_args(const CoreCoord &logical_core) {
// TODO (abhullar): Should this check only be enabled in debug mode?
TT_FATAL( logical_core.x < this->core_to_runtime_args_.size() && logical_core.y < this->core_to_runtime_args_[logical_core.x].size(), "Cannot get runtime args for kernel {} that is not placed on core {}", this->name(), logical_core.str());
Expand Down
2 changes: 0 additions & 2 deletions tt_metal/impl/kernels/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ class Kernel : public JitBuildSettings {

const std::set<CoreCoord>& cores_with_runtime_args() const { return core_with_runtime_args_; }

void update_runtime_arg( const CoreCoord &logical_core, size_t idx, uint32_t value);

std::vector<uint32_t> & runtime_args(const CoreCoord &logical_core);
RuntimeArgsData & runtime_args_data(const CoreCoord &logical_core);
std::vector< std::vector< std::vector<uint32_t>> > & runtime_args();
Expand Down
5 changes: 0 additions & 5 deletions tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -973,11 +973,6 @@ RuntimeArgsData & GetCommonRuntimeArgs(const Program &program, KernelHandle kern
return detail::GetKernel(program, kernel_id)->common_runtime_args_data().at(0);
}

void UpdateRuntimeArgs(Device* device, const std::shared_ptr<Kernel> kernel, const CoreCoord &core_coord, std::vector<uint32_t> &update_idx, std::shared_ptr<RuntimeArgs> runtime_args) {
detail::DispatchStateCheck(not device->using_slow_dispatch());
EnqueueUpdateRuntimeArgs(device->command_queue(), kernel, core_coord, update_idx, runtime_args, false);
}

} // namespace tt_metal

} // namespace tt

0 comments on commit b759289

Please sign in to comment.