Skip to content

Commit

Permalink
#7530: Add no stride flag to dispatch packed write
Browse files Browse the repository at this point in the history
  • Loading branch information
pgkeller committed May 21, 2024
1 parent 59fd6ce commit 6e6b6aa
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 20 deletions.
50 changes: 39 additions & 11 deletions tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -588,21 +588,30 @@ inline void generate_random_paged_payload(Device *device,
inline void generate_random_packed_payload(vector<uint32_t>& cmds,
vector<CoreCoord>& worker_cores,
DeviceData& data,
uint32_t size_words) {
uint32_t size_words,
bool repeat = false) {

static uint32_t coherent_count = 0;
const uint32_t bank_id = 0; // No interleaved pages here.

bool first_core = true;
vector<uint32_t>results;
CoreCoord first_worker = worker_cores[0];
for (uint32_t i = 0; i < size_words; i++) {
uint32_t datum = (use_coherent_data_g) ? ((first_worker.x << 16) | (first_worker.y << 24) | coherent_count++) : std::rand();
results.push_back(datum);
}
for (CoreCoord core : worker_cores) {
for (uint32_t i = 0; i < size_words; i++) {
uint32_t datum = (use_coherent_data_g) ? ((core.x << 16) | (core.y << 24) | coherent_count++) : std::rand();

cmds.push_back(datum);
data.push_one(core, bank_id, datum);
data.push_one(core, bank_id, results[i]);
if (!repeat || first_core) {
cmds.push_back(results[i]);
}
}

cmds.resize(padded_size(cmds.size(), 4)); // XXXXX L1_ALIGNMENT16/sizeof(uint)
data.pad(core, bank_id, 16); // L1_ALIGNMENT16
first_core = false;
}
}

Expand Down Expand Up @@ -709,7 +718,8 @@ inline void add_dispatcher_packed_cmd(Device *device,
vector<CoreCoord>& worker_cores,
DeviceData& device_data,
CQDispatchCmd cmd,
uint32_t size_words) {
uint32_t size_words,
bool repeat = false) {

size_t prior_end = debug_prologue(cmds);

Expand All @@ -720,7 +730,7 @@ inline void add_dispatcher_packed_cmd(Device *device,
}
cmds.resize(padded_size(cmds.size(), 4)); // XXXXX L1_ALIGNMENT16/sizeof(uint)

generate_random_packed_payload(cmds, worker_cores, device_data, size_words);
generate_random_packed_payload(cmds, worker_cores, device_data, size_words, repeat);

debug_epilogue(cmds, prior_end);
}
Expand Down Expand Up @@ -843,7 +853,8 @@ inline void gen_dispatcher_packed_write_cmd(Device *device,
vector<uint32_t>& cmds,
vector<CoreCoord>& worker_cores,
DeviceData& device_data,
uint32_t size_words) {
uint32_t size_words,
bool repeat = false) {

// Pad w/ blank data until all workers are at the same address
device_data.relevel(CoreType::WORKER);
Expand All @@ -852,12 +863,15 @@ inline void gen_dispatcher_packed_write_cmd(Device *device,
memset(&cmd, 0, sizeof(CQDispatchCmd));

cmd.base.cmd_id = CQ_DISPATCH_CMD_WRITE_PACKED;
cmd.write_packed.is_multicast = 0;
cmd.write_packed.flags = repeat ? CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE : CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE;
cmd.write_packed.count = worker_cores.size();
cmd.write_packed.addr = device_data.get_result_data_addr(worker_cores[0]);
cmd.write_packed.size = size_words * sizeof(uint32_t);

add_dispatcher_packed_cmd(device, cmds, worker_cores, device_data, cmd, size_words);
uint32_t sub_cmds_size = padded_size(worker_cores.size() * sizeof(uint32_t), sizeof(CQDispatchCmd));
TT_FATAL(repeat == false || size_words * sizeof(uint32_t) + sizeof(CQDispatchCmd) + sub_cmds_size <= dispatch_buffer_page_size_g);

add_dispatcher_packed_cmd(device, cmds, worker_cores, device_data, cmd, size_words, repeat);
}

inline void gen_rnd_dispatcher_packed_write_cmd(Device *device,
Expand Down Expand Up @@ -887,8 +901,22 @@ inline void gen_rnd_dispatcher_packed_write_cmd(Device *device,
}
}

bool repeat = std::rand() % 2;
if (repeat) {
// TODO fix this if/when we add mcast
uint32_t sub_cmds_size = padded_size(gets_data.size() * sizeof(uint32_t), 16); // L1_ALIGNMENT16
if (xfer_size_bytes + sizeof (CQDispatchCmd) + sub_cmds_size > dispatch_buffer_page_size_g) {
static bool warned = false;
if (!warned) {
log_warning(tt::LogTest, "Clamping packed_write cmd w/ stride=0 size to fit a dispatch page. Adjust max/min xfer sizes for reliable perf data");
warned = true;
}
xfer_size_bytes = dispatch_buffer_page_size_g - sizeof (CQDispatchCmd) - sub_cmds_size;
}
}

gen_dispatcher_packed_write_cmd(device, cmds, gets_data, device_data,
xfer_size_bytes / sizeof(uint32_t));
xfer_size_bytes / sizeof(uint32_t), repeat);
}

inline void gen_dispatcher_host_write_cmd(vector<uint32_t>& cmds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,10 @@ void gen_smoke_test(Device *device,
gen_dispatcher_packed_write_cmd(device, dispatch_cmds, worker_cores, device_data, 12);
add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds);

dispatch_cmds.resize(0);
gen_dispatcher_packed_write_cmd(device, dispatch_cmds, worker_cores, device_data, 12, true);
add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds);

dispatch_cmds.resize(0);
worker_cores.resize(0);
worker_cores.push_back(first_worker_g);
Expand Down
6 changes: 5 additions & 1 deletion tt_metal/impl/dispatch/cq_commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,12 @@ struct CQDispatchWritePagedCmd {
uint32_t pages;
} __attribute__((packed));


constexpr uint32_t CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE = 0x00;
constexpr uint32_t CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST = 0x01;
constexpr uint32_t CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE = 0x02;
struct CQDispatchWritePackedCmd {
uint8_t is_multicast;
uint8_t flags; // see above
uint16_t count; // number of sub-cmds (max 1020 unicast, 510 mcast). Max num sub-cmds = (dispatch_constants::TRANSFER_PAGE_SIZE - sizeof(CQDispatchCmd)) / sizeof(CQDispatchWritePacked*castSubCmd)
uint32_t addr; // common memory address across all packed SubCmds
uint16_t size; // size of each packet, stride is padded to L1 alignment and less than dispatch_cb_page_size
Expand Down
3 changes: 2 additions & 1 deletion tt_metal/impl/dispatch/device_command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ class DeviceCommand {

auto initialize_write_packed_cmd = [&](CQDispatchCmd *write_packed_cmd) {
write_packed_cmd->base.cmd_id = CQ_DISPATCH_CMD_WRITE_PACKED;
write_packed_cmd->write_packed.is_multicast = multicast;
write_packed_cmd->write_packed.flags =
multicast ? CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST : CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NONE;
write_packed_cmd->write_packed.count = num_sub_cmds;
write_packed_cmd->write_packed.addr = common_addr;
write_packed_cmd->write_packed.size = packed_data_sizeB;
Expand Down
19 changes: 12 additions & 7 deletions tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ void process_write_paged() {
// Since all subcmds all appear in the first page and given the size restrictions
// this command can't be too many pages. All pages are released at the end
template<bool mcast, typename WritePackedSubCmd>
void process_write_packed() {
void process_write_packed(uint32_t flags) {
volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr;

uint32_t count = cmd->write_packed.count;
Expand All @@ -560,7 +560,9 @@ void process_write_packed() {

uint32_t data_ptr = cmd_ptr + sizeof(CQDispatchCmd) + count * sizeof(WritePackedSubCmd);
data_ptr = round_up_pow2(data_ptr, L1_NOC_ALIGNMENT);
uint32_t stride = round_up_pow2(xfer_size, L1_NOC_ALIGNMENT);
uint32_t stride = (flags & CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE) ? 0 : round_up_pow2(xfer_size, L1_NOC_ALIGNMENT);
DPRINT << data_ptr << " " << cmd_ptr << " " << xfer_size << " " << dispatch_cb_page_size << ENDL();
ASSERT(stride != 0 || data_ptr - cmd_ptr + xfer_size <= dispatch_cb_page_size);

volatile uint32_t tt_l1_ptr *l1_addr = (uint32_t *)(cmd_ptr + sizeof(CQDispatchCmd));
cq_noc_async_write_init_state<CQ_NOC_snDL, mcast>(0, dst_addr, xfer_size);
Expand Down Expand Up @@ -766,11 +768,14 @@ static inline bool process_cmd_d(uint32_t& cmd_ptr) {
break;

case CQ_DISPATCH_CMD_WRITE_PACKED:
DPRINT << "cmd_write_packed" << ENDL();
if (cmd->write_packed.is_multicast) {
process_write_packed<true, CQDispatchWritePackedMulticastSubCmd>();
} else {
process_write_packed<false, CQDispatchWritePackedUnicastSubCmd>();
{
DPRINT << "cmd_write_packed" << ENDL();
uint32_t flags = cmd->write_packed.flags;
if (flags & CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST) {
process_write_packed<true, CQDispatchWritePackedMulticastSubCmd>(flags);
} else {
process_write_packed<false, CQDispatchWritePackedUnicastSubCmd>(flags);
}
}
break;

Expand Down

0 comments on commit 6e6b6aa

Please sign in to comment.