Skip to content

Commit

Permalink
#4941: Convert command header to struct for easier maintainability
Browse files Browse the repository at this point in the history
  • Loading branch information
DrJessop committed Jan 25, 2024
1 parent 66d2a9b commit 4f2c631
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 133 deletions.
16 changes: 7 additions & 9 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ void EnqueueRestartCommand::process() {
const DeviceCommand cmd = this->assemble_device_command(0);
uint32_t cmd_size = DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND;
this->manager.issue_queue_reserve_back(cmd_size, this->command_queue_channel);
this->manager.cq_write(cmd.get_desc().data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
this->manager.cq_write(cmd.data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
this->manager.issue_queue_push_back(cmd_size, false, this->command_queue_channel);
}

Expand Down Expand Up @@ -448,7 +448,7 @@ void EnqueueReadBufferCommand::process() {
const DeviceCommand cmd = this->assemble_device_command(this->read_buffer_addr);

this->manager.issue_queue_reserve_back(DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, this->command_queue_id);
this->manager.cq_write(cmd.get_desc().data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
this->manager.cq_write(cmd.data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
this->manager.issue_queue_push_back(DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, LAZY_COMMAND_QUEUE_MODE, this->command_queue_id);
}

Expand Down Expand Up @@ -578,7 +578,7 @@ void EnqueueWriteBufferCommand::process() {
uint32_t cmd_size = DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND + data_size_in_bytes;
this->manager.issue_queue_reserve_back(cmd_size, this->command_queue_id);

this->manager.cq_write(cmd.get_desc().data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
this->manager.cq_write(cmd.data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
uint32_t unpadded_src_offset = this->dst_page_index * this->buffer.page_size();

if (this->buffer.page_size() % 32 != 0 and this->buffer.page_size() != this->buffer.size()) {
Expand All @@ -594,8 +594,6 @@ void EnqueueWriteBufferCommand::process() {
}

this->manager.issue_queue_push_back(cmd_size, LAZY_COMMAND_QUEUE_MODE, this->command_queue_id);

auto cmd_desc = cmd.get_desc();
}

EnqueueProgramCommand::EnqueueProgramCommand(
Expand Down Expand Up @@ -733,7 +731,7 @@ void EnqueueProgramCommand::process() {
uint32_t data_size_in_bytes = cmd.get_data_size();
const uint32_t cmd_size = DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND + data_size_in_bytes;
this->manager.issue_queue_reserve_back(cmd_size, this->command_queue_id);
this->manager.cq_write(cmd.get_desc().data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
this->manager.cq_write(cmd.data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);

bool tracing = this->trace.has_value();
vector<uint32_t> trace_host_data;
Expand Down Expand Up @@ -789,7 +787,7 @@ void FinishCommand::process() {
const DeviceCommand cmd = this->assemble_device_command(0);
uint32_t cmd_size = DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND;
this->manager.issue_queue_reserve_back(cmd_size, this->command_queue_id);
this->manager.cq_write(cmd.get_desc().data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
this->manager.cq_write(cmd.data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);
this->manager.issue_queue_push_back(cmd_size, false, this->command_queue_id);
}

Expand All @@ -814,7 +812,7 @@ void EnqueueWrapCommand::process() {

const DeviceCommand cmd = this->assemble_device_command(0);
this->manager.issue_queue_reserve_back(wrap_packet_size_bytes, this->command_queue_id);
this->manager.cq_write(cmd.get_desc().data(), wrap_packet_size_bytes, write_ptr);
this->manager.cq_write(cmd.data(), wrap_packet_size_bytes, write_ptr);
if (this->wrap_region == DeviceCommand::WrapRegion::COMPLETION) {
// Wrap the read pointers for completion queue because device will start writing data at head of completion queue and there are no more reads to be done at current completion queue write pointer
// If we don't wrap the read then the subsequent read buffer command may attempt to read past the total command queue size
Expand Down Expand Up @@ -1140,7 +1138,7 @@ void Trace::create_replay() {
for (auto& [device_command, data, command_type, num_data_bytes]: this->history) {
uint32_t issue_write_ptr = manager.get_issue_queue_write_ptr(command_queue_id);
device_command.update_buffer_transfer_src(0, issue_write_ptr + DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND);
manager.cq_write(device_command.get_desc().data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, issue_write_ptr);
manager.cq_write(device_command.data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, issue_write_ptr);
manager.issue_queue_push_back(DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, lazy_push, command_queue_id);

uint32_t host_data_size = align(data.size() * sizeof(uint32_t), 16);
Expand Down
102 changes: 49 additions & 53 deletions tt_metal/impl/dispatch/device_command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,77 +6,75 @@

#include "tt_metal/common/logger.hpp"
#include "tt_metal/common/assert.hpp"
#include <atomic>

DeviceCommand::DeviceCommand() {
for (uint32_t idx = 0; idx < DeviceCommand::NUM_ENTRIES_IN_COMMAND_HEADER; idx++) {
this->desc[idx] = 0;
}
this->buffer_transfer_idx = DeviceCommand::NUM_ENTRIES_IN_COMMAND_HEADER;
this->buffer_transfer_idx = 0;
this->program_transfer_idx = this->buffer_transfer_idx + DeviceCommand::NUM_POSSIBLE_BUFFER_TRANSFERS *
DeviceCommand::NUM_ENTRIES_PER_BUFFER_TRANSFER_INSTRUCTION;

this->desc[this->sharded_buffer_num_cores_idx] = 1;
// Not sure why this the default, but not changing behaviour
this->set_sharded_buffer_num_cores(1);
}

void DeviceCommand::set_restart() { this->desc[this->restart_idx] = 1; }
void DeviceCommand::set_restart() { this->packet.header.restart = 1; }

void DeviceCommand::set_issue_queue_size(uint32_t new_issue_queue_size) { this->desc[this->new_issue_queue_size_idx] = new_issue_queue_size; }
void DeviceCommand::set_issue_queue_size(uint32_t new_issue_queue_size) { this->packet.header.new_issue_queue_size = new_issue_queue_size; }

void DeviceCommand::set_completion_queue_size(uint32_t new_completion_queue_size) { this->desc[this->new_completion_queue_size_idx] = new_completion_queue_size; }
void DeviceCommand::set_completion_queue_size(uint32_t new_completion_queue_size) { this->packet.header.new_completion_queue_size = new_completion_queue_size; }

void DeviceCommand::set_wrap(WrapRegion wrap_region) { this->desc[this->wrap_idx] = (uint32_t)wrap_region; }
void DeviceCommand::set_wrap(WrapRegion wrap_region) { this->packet.header.wrap = (uint32_t)wrap_region; }

void DeviceCommand::set_finish() { this->desc[this->finish_idx] = 1; }
void DeviceCommand::set_finish() { this->packet.header.finish = 1; }

void DeviceCommand::set_num_workers(const uint32_t num_workers) { this->desc.at(this->num_workers_idx) = num_workers; }
void DeviceCommand::set_num_workers(const uint32_t num_workers) { this->packet.header.num_workers = num_workers; }

void DeviceCommand::set_is_program() { this->desc[this->is_program_buffer_idx] = 1; }
void DeviceCommand::set_is_program() { this->packet.header.is_program_buffer = 1; }

void DeviceCommand::set_stall() { this->desc[this->stall_idx] = 1; }
void DeviceCommand::set_stall() { this->packet.header.stall = 1; }

void DeviceCommand::set_page_size(const uint32_t page_size) { this->desc[this->page_size_idx] = page_size; }
void DeviceCommand::set_page_size(const uint32_t page_size) { this->packet.header.page_size = page_size; }

void DeviceCommand::set_producer_cb_size(const uint32_t cb_size) { this->desc[this->producer_cb_size_idx] = cb_size; }
void DeviceCommand::set_producer_cb_size(const uint32_t cb_size) { this->packet.header.producer_cb_size = cb_size; }

void DeviceCommand::set_consumer_cb_size(const uint32_t cb_size) { this->desc[this->consumer_cb_size_idx] = cb_size; }
void DeviceCommand::set_consumer_cb_size(const uint32_t cb_size) { this->packet.header.consumer_cb_size = cb_size; }

void DeviceCommand::set_producer_cb_num_pages(const uint32_t cb_num_pages) { this->desc[this->producer_cb_num_pages_idx] = cb_num_pages; }
void DeviceCommand::set_producer_cb_num_pages(const uint32_t cb_num_pages) { this->packet.header.producer_cb_num_pages = cb_num_pages; }

void DeviceCommand::set_consumer_cb_num_pages(const uint32_t cb_num_pages) { this->desc[this->consumer_cb_num_pages_idx] = cb_num_pages; }
void DeviceCommand::set_consumer_cb_num_pages(const uint32_t cb_num_pages) { this->packet.header.consumer_cb_num_pages = cb_num_pages; }

void DeviceCommand::set_num_pages(uint32_t num_pages) { this->desc[this->num_pages_idx] = num_pages; }
void DeviceCommand::set_num_pages(uint32_t num_pages) { this->packet.header.num_pages = num_pages; }

void DeviceCommand::set_sharded_buffer_num_cores(uint32_t num_cores) { this->desc[this->sharded_buffer_num_cores_idx] = num_cores; }
void DeviceCommand::set_sharded_buffer_num_cores(uint32_t num_cores) { this->packet.header.sharded_buffer_num_cores = num_cores; }

void DeviceCommand::set_num_pages(const DeviceCommand::TransferType transfer_type, const uint32_t num_pages) {
switch (transfer_type) {
case DeviceCommand::TransferType::RUNTIME_ARGS:
this->desc[this->num_runtime_arg_pages_idx] = num_pages;
this->packet.header.num_runtime_arg_pages = num_pages;
break;
case DeviceCommand::TransferType::CB_CONFIGS:
this->desc[this->num_cb_config_pages_idx] = num_pages;
this->packet.header.num_cb_config_pages = num_pages;
break;
case DeviceCommand::TransferType::PROGRAM_PAGES:
this->desc[this->num_program_pages_idx] = num_pages;
this->packet.header.num_program_pages = num_pages;
break;
case DeviceCommand::TransferType::GO_SIGNALS:
this->desc[this->num_go_signal_pages_idx] = num_pages;
this->packet.header.num_go_signal_pages = num_pages;
break;
default:
TT_ASSERT(false, "Invalid transfer type.");
}
}

void DeviceCommand::set_data_size(const uint32_t data_size) { this->desc[this->data_size_idx] = data_size; }
void DeviceCommand::set_data_size(const uint32_t data_size) { this->packet.header.data_size = data_size; }

uint32_t DeviceCommand::get_data_size() const { return this->desc[this->data_size_idx]; }
uint32_t DeviceCommand::get_data_size() const { return this->packet.header.data_size; }

void DeviceCommand::set_producer_consumer_transfer_num_pages(const uint32_t producer_consumer_transfer_num_pages) {
this->desc[this->producer_consumer_transfer_num_pages_idx] = producer_consumer_transfer_num_pages;
this->packet.header.producer_consumer_transfer_num_pages = producer_consumer_transfer_num_pages;
}

void DeviceCommand::update_buffer_transfer_src(const uint8_t buffer_transfer_idx, const uint32_t new_src) {
this->desc[DeviceCommand::NUM_ENTRIES_IN_COMMAND_HEADER + buffer_transfer_idx * DeviceCommand::NUM_ENTRIES_PER_BUFFER_TRANSFER_INSTRUCTION] = new_src;
this->packet.data[DeviceCommand::NUM_ENTRIES_IN_COMMAND_HEADER + buffer_transfer_idx * DeviceCommand::NUM_ENTRIES_PER_BUFFER_TRANSFER_INSTRUCTION] = new_src;
}


Expand All @@ -91,23 +89,23 @@ void DeviceCommand::add_buffer_transfer_instruction_preamble(
const uint32_t dst_page_index
)
{
this->desc[this->buffer_transfer_idx] = src;
this->desc[this->buffer_transfer_idx + 1] = dst;
this->desc[this->buffer_transfer_idx + 2] = num_pages;
this->desc[this->buffer_transfer_idx + 3] = padded_page_size;
this->desc[this->buffer_transfer_idx + 4] = src_buf_type;
this->desc[this->buffer_transfer_idx + 5] = dst_buf_type;
this->desc[this->buffer_transfer_idx + 6] = src_page_index;
this->desc[this->buffer_transfer_idx + 7] = dst_page_index;
this->packet.data[this->buffer_transfer_idx] = src;
this->packet.data[this->buffer_transfer_idx + 1] = dst;
this->packet.data[this->buffer_transfer_idx + 2] = num_pages;
this->packet.data[this->buffer_transfer_idx + 3] = padded_page_size;
this->packet.data[this->buffer_transfer_idx + 4] = src_buf_type;
this->packet.data[this->buffer_transfer_idx + 5] = dst_buf_type;
this->packet.data[this->buffer_transfer_idx + 6] = src_page_index;
this->packet.data[this->buffer_transfer_idx + 7] = dst_page_index;

}

void DeviceCommand::add_buffer_transfer_instruction_postamble(){
this->buffer_transfer_idx += DeviceCommand::NUM_ENTRIES_PER_BUFFER_TRANSFER_INSTRUCTION;

this->desc[this->num_buffer_transfers_idx]++;
this->packet.header.num_buffer_transfers++;
TT_ASSERT(
this->desc[this->num_buffer_transfers_idx] <= DeviceCommand::NUM_POSSIBLE_BUFFER_TRANSFERS,
this->packet.header.num_buffer_transfers <= DeviceCommand::NUM_POSSIBLE_BUFFER_TRANSFERS,
"Surpassing the limit of {} on possible buffer transfers in a single command",
DeviceCommand::NUM_POSSIBLE_BUFFER_TRANSFERS);
}
Expand All @@ -127,7 +125,6 @@ void DeviceCommand::add_buffer_transfer_interleaved_instruction(
src_page_index, dst_page_index
);
this->add_buffer_transfer_instruction_postamble();

}


Expand Down Expand Up @@ -158,35 +155,34 @@ void DeviceCommand::add_buffer_transfer_sharded_instruction(
uint32_t num_shards = core_id_x.size();
uint32_t idx_offset = COMMAND_PTR_SHARD_IDX;
for (auto shard_id = 0; shard_id < num_shards; shard_id++) {
this->desc[this->buffer_transfer_idx + idx_offset++] = num_pages_in_shard[shard_id];
this->desc[this->buffer_transfer_idx + idx_offset++] = core_id_x[shard_id];
this->desc[this->buffer_transfer_idx + idx_offset++] = core_id_y[shard_id];
this->packet.data[this->buffer_transfer_idx + idx_offset++] = num_pages_in_shard[shard_id];
this->packet.data[this->buffer_transfer_idx + idx_offset++] = core_id_x[shard_id];
this->packet.data[this->buffer_transfer_idx + idx_offset++] = core_id_y[shard_id];
}


this->add_buffer_transfer_instruction_postamble();
}

void DeviceCommand::write_program_entry(const uint32_t value) {
this->desc.at(this->program_transfer_idx) = value;
this->packet.data.at(this->program_transfer_idx) = value;
this->program_transfer_idx++;
}

void DeviceCommand::add_write_page_partial_instruction(
const uint32_t num_bytes, const uint32_t dst, const uint32_t dst_noc, const uint32_t num_receivers, const bool advance, const bool linked) {

// This 'at' does size checking
this->desc.at(this->program_transfer_idx + 5) = linked;
this->packet.data.at(this->program_transfer_idx + 5) = linked;

this->desc[this->program_transfer_idx] = num_bytes;
this->desc[this->program_transfer_idx + 1] = dst;
this->desc[this->program_transfer_idx + 2] = dst_noc;
this->desc[this->program_transfer_idx + 3] = num_receivers;
this->desc[this->program_transfer_idx + 4] = advance;
this->packet.data[this->program_transfer_idx] = num_bytes;
this->packet.data[this->program_transfer_idx + 1] = dst;
this->packet.data[this->program_transfer_idx + 2] = dst_noc;
this->packet.data[this->program_transfer_idx + 3] = num_receivers;
this->packet.data[this->program_transfer_idx + 4] = advance;

this->program_transfer_idx += 6;
}

const std::array<uint32_t, DeviceCommand::NUM_ENTRIES_IN_DEVICE_COMMAND>& DeviceCommand::get_desc() const {
return this->desc;
void* DeviceCommand::data() const {
return (void*)&this->packet;
}
62 changes: 34 additions & 28 deletions tt_metal/impl/dispatch/device_command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@
#include "dev_mem_map.h"
#include "tt_metal/hostdevcommon/common_runtime_address_map.h"

struct CommandHeader {
uint32_t wrap = 0;
uint32_t finish = 0;
uint32_t num_workers = 0;
uint32_t num_buffer_transfers = 0;
uint32_t is_program_buffer = 0;
uint32_t stall = 0;
uint32_t page_size = 0;
uint32_t producer_cb_size = 0;
uint32_t consumer_cb_size = 0;
uint32_t producer_cb_num_pages = 0;
uint32_t consumer_cb_num_pages = 0;
uint32_t num_pages = 0;
uint32_t num_runtime_arg_pages = 0;
uint32_t num_cb_config_pages = 0;
uint32_t num_program_pages = 0;
uint32_t num_go_signal_pages = 0;
uint32_t data_size = 0;
uint32_t producer_consumer_transfer_num_pages = 0;
uint32_t sharded_buffer_num_cores = 0;
uint32_t restart = 0;
uint32_t new_issue_queue_size = 0;
uint32_t new_completion_queue_size = 0;
};

class DeviceCommand {
public:
DeviceCommand();
Expand All @@ -18,41 +43,16 @@ class DeviceCommand {
//TODO: investigate other num_cores
static constexpr uint32_t MAX_HUGEPAGE_SIZE = 1 << 30; // 1GB;
static constexpr uint32_t NUM_MAX_CORES = 108; //12 x 9
static constexpr uint32_t NUM_ENTRIES_IN_COMMAND_HEADER = 22;
static constexpr uint32_t NUM_ENTRIES_IN_COMMAND_HEADER = sizeof(CommandHeader) / sizeof(uint32_t);
static constexpr uint32_t NUM_ENTRIES_IN_DEVICE_COMMAND = 5632;
static constexpr uint32_t NUM_BYTES_IN_DEVICE_COMMAND = NUM_ENTRIES_IN_DEVICE_COMMAND * sizeof(uint32_t);
static constexpr uint32_t PROGRAM_PAGE_SIZE = 2048;
static constexpr uint32_t NUM_ENTRIES_PER_BUFFER_TRANSFER_INSTRUCTION = COMMAND_PTR_SHARD_IDX + NUM_MAX_CORES*NUM_ENTRIES_PER_SHARD;
static constexpr uint32_t NUM_POSSIBLE_BUFFER_TRANSFERS = 2;

// Ensure any changes to this device command have asserts modified/extended
static_assert(NUM_ENTRIES_IN_COMMAND_HEADER == 22);
static_assert((NUM_BYTES_IN_DEVICE_COMMAND % 32) == 0);

// Command header
static constexpr uint32_t wrap_idx = 0;
static constexpr uint32_t finish_idx = 1;
static constexpr uint32_t num_workers_idx = 2;
static constexpr uint32_t num_buffer_transfers_idx = 3;
static constexpr uint32_t is_program_buffer_idx = 4;
static constexpr uint32_t stall_idx = 5;
static constexpr uint32_t page_size_idx = 6;
static constexpr uint32_t producer_cb_size_idx = 7;
static constexpr uint32_t consumer_cb_size_idx = 8;
static constexpr uint32_t producer_cb_num_pages_idx = 9;
static constexpr uint32_t consumer_cb_num_pages_idx = 10;
static constexpr uint32_t num_pages_idx = 11;
static constexpr uint32_t num_runtime_arg_pages_idx = 12;
static constexpr uint32_t num_cb_config_pages_idx = 13;
static constexpr uint32_t num_program_pages_idx = 14;
static constexpr uint32_t num_go_signal_pages_idx = 15;
static constexpr uint32_t data_size_idx = 16;
static constexpr uint32_t producer_consumer_transfer_num_pages_idx = 17;
static constexpr uint32_t sharded_buffer_num_cores_idx = 18;
static constexpr uint32_t restart_idx = 19;
static constexpr uint32_t new_issue_queue_size_idx = 20;
static constexpr uint32_t new_completion_queue_size_idx = 21;

// Denotes which portion of the command queue needs to be wrapped
enum class WrapRegion : uint8_t {
NONE = 0,
Expand Down Expand Up @@ -135,10 +135,9 @@ class DeviceCommand {
const bool advance,
const bool linked);

const std::array<uint32_t, NUM_ENTRIES_IN_DEVICE_COMMAND>& get_desc() const;
void* data() const;

private:
std::array<uint32_t, DeviceCommand::NUM_ENTRIES_IN_DEVICE_COMMAND> desc;
uint32_t buffer_transfer_idx;
uint32_t program_transfer_idx;
void add_buffer_transfer_instruction_preamble(
Expand All @@ -152,4 +151,11 @@ class DeviceCommand {
const uint32_t dst_page_index
);
void add_buffer_transfer_instruction_postamble();

struct packet_ {
CommandHeader header;
std::array<uint32_t, DeviceCommand::NUM_ENTRIES_IN_DEVICE_COMMAND - DeviceCommand::NUM_ENTRIES_IN_COMMAND_HEADER> data;
};

packet_ packet;
};
Loading

0 comments on commit 4f2c631

Please sign in to comment.