Skip to content

Commit

Permalink
#3718: Link multicasts that use the same path to avoid multiple path reservations in a row
Browse files Browse the repository at this point in the history
  • Loading branch information
DrJessop committed Nov 20, 2023
1 parent 2bf8ee9 commit afb882d
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 49 deletions.
5 changes: 3 additions & 2 deletions tt_metal/hw/inc/dataflow_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,8 @@ void noc_async_write_multicast(
std::uint32_t src_local_l1_addr,
std::uint64_t dst_noc_addr_multicast,
std::uint32_t size,
std::uint32_t num_dests) {
std::uint32_t num_dests,
bool linked = false) {
DEBUG_STATUS('N', 'M', 'W', 'W');
DEBUG_SANITIZE_NOC_MULTI_ADDR(dst_noc_addr_multicast, size);
DEBUG_SANITIZE_WORKER_ADDR(src_local_l1_addr, size);
Expand All @@ -1219,7 +1220,7 @@ void noc_async_write_multicast(
size,
NOC_MULTICAST_WRITE_VC,
true,
false,
linked,
num_dests);
DEBUG_STATUS('N', 'M', 'W', 'D');
}
Expand Down
99 changes: 64 additions & 35 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ ProgramMap ConstructProgramMap(const Device* device, Program& program) {
uint32_t dst,
vector<transfer_info>& transfers,
vector<uint32_t>& num_transfers_per_page,
const vector<pair<uint32_t, uint32_t>>& dst_noc_transfer_info) -> uint32_t {
const vector<pair<uint32_t, uint32_t>>& dst_noc_transfer_info,
bool linked = false) -> uint32_t {
while (num_bytes) {
uint32_t num_bytes_left_in_page = DeviceCommand::PROGRAM_PAGE_SIZE - (src % DeviceCommand::PROGRAM_PAGE_SIZE);
uint32_t num_bytes_in_transfer = std::min(num_bytes_left_in_page, num_bytes);
Expand All @@ -60,7 +61,7 @@ ProgramMap ConstructProgramMap(const Device* device, Program& program) {
uint32_t transfer_instruction_idx = 1;
for (const auto& [dst_noc_encoding, num_receivers] : dst_noc_transfer_info) {
bool last = transfer_instruction_idx == dst_noc_transfer_info.size();
transfer_info transfer_instruction = {.size_in_bytes = num_bytes_in_transfer, .dst = dst, .dst_noc_encoding = dst_noc_encoding, .num_receivers = num_receivers, .last_transfer_in_group = last};
transfer_info transfer_instruction = {.size_in_bytes = num_bytes_in_transfer, .dst = dst, .dst_noc_encoding = dst_noc_encoding, .num_receivers = num_receivers, .last_transfer_in_group = last, .linked = linked};
transfers.push_back(transfer_instruction);
num_transfers_within_page++;
transfer_instruction_idx++;
Expand Down Expand Up @@ -154,32 +155,54 @@ ProgramMap ConstructProgramMap(const Device* device, Program& program) {

// Step 3: Determine the transfer information for each program binary
src = 0; // Restart src since it begins in a new page
for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) {
const Kernel* kernel = detail::GetKernel(program, kernel_id);
vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info =
extract_dst_noc_multicast_info(kernel->core_range_set().ranges());
for (const KernelGroup &kg: program.get_kernel_groups()) {

vector<RISCV> sub_kernels;
if (kernel->processor() == RISCV::COMPUTE) {
sub_kernels = {RISCV::TRISC0, RISCV::TRISC1, RISCV::TRISC2};
} else {
sub_kernels = {kernel->processor()};
}
vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info =
extract_dst_noc_multicast_info(kg.core_ranges.ranges());

uint32_t sub_kernel_index = 0;
for (const ll_api::memory& kernel_bin : kernel->binaries(device->id())) {
kernel_bin.process_spans([&](vector<uint32_t>::const_iterator mem_ptr, uint64_t dst, uint32_t len) {
uint32_t num_bytes = len * sizeof(uint32_t);
if ((dst & MEM_LOCAL_BASE) == MEM_LOCAL_BASE) {
dst = (dst & ~MEM_LOCAL_BASE) + processor_to_local_mem_addr.at(sub_kernels[sub_kernel_index]);
} else if ((dst & MEM_NCRISC_IRAM_BASE) == MEM_NCRISC_IRAM_BASE) {
dst = (dst & ~MEM_NCRISC_IRAM_BASE) + MEM_NCRISC_INIT_IRAM_L1_BASE;
}
// So far, we don't support linking optimizations for kernel groups
// which use multiple core ranges
bool linked = dst_noc_multicast_info.size() == 1;

vector<KernelID> kernel_ids;
if (kg.riscv0_id) kernel_ids.push_back(kg.riscv0_id.value());
if (kg.riscv1_id) kernel_ids.push_back(kg.riscv1_id.value());
if (kg.compute_id) kernel_ids.push_back(kg.compute_id.value());

uint32_t src_copy = src;
for (size_t i = 0; i < kernel_ids.size(); i++) {
KernelID kernel_id = kernel_ids[i];
vector<RISCV> sub_kernels;
const Kernel* kernel = detail::GetKernel(program, kernel_id);
if (kernel->processor() == RISCV::COMPUTE) {
sub_kernels = {RISCV::TRISC0, RISCV::TRISC1, RISCV::TRISC2};
} else {
sub_kernels = {kernel->processor()};
}

src = update_program_page_transfers(
src, num_bytes, dst, program_page_transfers, num_transfers_in_program_pages, dst_noc_multicast_info);
});
sub_kernel_index++;
uint32_t sub_kernel_index = 0;
const auto& binaries = kernel->binaries(device->id());
for (size_t j = 0; j < binaries.size(); j++) {
const ll_api::memory& kernel_bin = binaries[j];

uint32_t k = 0;
uint32_t num_spans = kernel_bin.num_spans();
kernel_bin.process_spans([&](vector<uint32_t>::const_iterator mem_ptr, uint64_t dst, uint32_t len) {
linked &= (i != kernel_ids.size() - 1) or (j != binaries.size() - 1) or (k != num_spans - 1);

uint32_t num_bytes = len * sizeof(uint32_t);
if ((dst & MEM_LOCAL_BASE) == MEM_LOCAL_BASE) {
dst = (dst & ~MEM_LOCAL_BASE) + processor_to_local_mem_addr.at(sub_kernels[sub_kernel_index]);
} else if ((dst & MEM_NCRISC_IRAM_BASE) == MEM_NCRISC_IRAM_BASE) {
dst = (dst & ~MEM_NCRISC_IRAM_BASE) + MEM_NCRISC_INIT_IRAM_L1_BASE;
}

src = update_program_page_transfers(
src, num_bytes, dst, program_page_transfers, num_transfers_in_program_pages, dst_noc_multicast_info, linked);
k++;
});
sub_kernel_index++;
}
}
}

Expand Down Expand Up @@ -230,14 +253,20 @@ ProgramMap ConstructProgramMap(const Device* device, Program& program) {

// Create a vector of all program binaries/cbs/semaphores
uint32_t program_page_idx = 0;
for (size_t kernel_id = 0; kernel_id < program.num_kernels(); kernel_id++) {
const Kernel* kernel = detail::GetKernel(program, kernel_id);

for (const ll_api::memory& kernel_bin : kernel->binaries(device->id())) {
kernel_bin.process_spans([&](vector<uint32_t>::const_iterator mem_ptr, uint64_t dst, uint32_t len) {
std::copy(mem_ptr, mem_ptr + len, program_pages.begin() + program_page_idx);
program_page_idx = align(program_page_idx + len, noc_transfer_alignment_in_bytes / sizeof(uint32_t));
});
for (const KernelGroup &kg: program.get_kernel_groups()) {
vector<KernelID> kernel_ids;
if (kg.riscv0_id) kernel_ids.push_back(kg.riscv0_id.value());
if (kg.riscv1_id) kernel_ids.push_back(kg.riscv1_id.value());
if (kg.compute_id) kernel_ids.push_back(kg.compute_id.value());
for (KernelID kernel_id: kernel_ids) {
const Kernel* kernel = detail::GetKernel(program, kernel_id);

for (const ll_api::memory& kernel_bin : kernel->binaries(device->id())) {
kernel_bin.process_spans([&](vector<uint32_t>::const_iterator mem_ptr, uint64_t dst, uint32_t len) {
std::copy(mem_ptr, mem_ptr + len, program_pages.begin() + program_page_idx);
program_page_idx = align(program_page_idx + len, noc_transfer_alignment_in_bytes / sizeof(uint32_t));
});
}
}
}

Expand Down Expand Up @@ -440,8 +469,8 @@ const DeviceCommand EnqueueProgramCommand::assemble_device_command(uint32_t host
uint32_t num_transfers_in_page = num_transfers_per_page[j];
command.write_program_entry(num_transfers_in_page);
for (uint32_t k = 0; k < num_transfers_in_page; k++) {
const auto [num_bytes, dst, dst_noc, num_receivers, last_multicast_in_group] = transfers_in_pages[i];
command.add_write_page_partial_instruction(num_bytes, dst, dst_noc, num_receivers, last_multicast_in_group);
const auto [num_bytes, dst, dst_noc, num_receivers, last_multicast_in_group, linked] = transfers_in_pages[i];
command.add_write_page_partial_instruction(num_bytes, dst, dst_noc, num_receivers, last_multicast_in_group, linked);
i++;
}
}
Expand Down
1 change: 1 addition & 0 deletions tt_metal/impl/dispatch/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct transfer_info {
uint32_t dst_noc_encoding;
uint32_t num_receivers;
bool last_transfer_in_group;
bool linked;
};

struct ProgramMap {
Expand Down
9 changes: 4 additions & 5 deletions tt_metal/impl/dispatch/device_command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,18 @@ void DeviceCommand::write_program_entry(const uint32_t value) {
}

void DeviceCommand::add_write_page_partial_instruction(
const uint32_t num_bytes, const uint32_t dst, const uint32_t dst_noc, const uint32_t num_receivers, const bool advance) {
const uint32_t num_bytes, const uint32_t dst, const uint32_t dst_noc, const uint32_t num_receivers, const bool advance, const bool linked) {

// This 'at' does size checking
this->desc.at(this->program_transfer_idx + 4) = advance;
this->desc.at(this->program_transfer_idx + 5) = linked;

this->desc[this->program_transfer_idx] = num_bytes;
this->desc[this->program_transfer_idx + 1] = dst;
this->desc[this->program_transfer_idx + 2] = dst_noc;
this->desc[this->program_transfer_idx + 3] = num_receivers;
this->desc[this->program_transfer_idx + 4] = advance;

// std::cout << "WRITE PAGE PARTIAL AT " << this->program_transfer_idx << ": " << num_bytes << ", " << dst << ", " << dst_noc << std::endl;

this->program_transfer_idx += 5;
this->program_transfer_idx += 6;
}

const std::array<uint32_t, DeviceCommand::NUM_ENTRIES_IN_DEVICE_COMMAND>& DeviceCommand::get_desc() const {
Expand Down
3 changes: 2 additions & 1 deletion tt_metal/impl/dispatch/device_command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class DeviceCommand {
const uint32_t dst,
const uint32_t dst_noc,
const uint32_t num_receivers,
const bool advance);
const bool advance,
const bool linked);

const std::array<uint32_t, NUM_ENTRIES_IN_DEVICE_COMMAND>& get_desc() const;

Expand Down
12 changes: 6 additions & 6 deletions tt_metal/impl/dispatch/kernels/command_queue_consumer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ FORCE_INLINE void write_buffers(
}

template <bool multicast>
FORCE_INLINE void write_program_page(uint32_t page_addr, volatile tt_l1_ptr uint32_t*& command_ptr) {
FORCE_INLINE void write_program_page(uint32_t page_addr, volatile tt_l1_ptr uint32_t*& command_ptr, bool last_page) {
uint32_t num_transfers = command_ptr[0];
command_ptr++;
uint32_t src = page_addr;
Expand All @@ -98,18 +98,18 @@ FORCE_INLINE void write_program_page(uint32_t page_addr, volatile tt_l1_ptr uint
uint32_t dst = command_ptr[1];
uint32_t dst_noc = command_ptr[2];
uint32_t num_recv = command_ptr[3];

// advance is false if we are sending the same data to different rectangles of workers
bool last_transfer_in_group = command_ptr[4];
bool linked = (not (last_page & last_transfer_in_group)) & command_ptr[5];

uint64_t dst_noc_addr = (uint64_t(dst_noc) << 32) | dst;

if constexpr (multicast) {
noc_async_write_multicast(src, dst_noc_addr, num_bytes, num_recv);
noc_async_write_multicast(src, dst_noc_addr, num_bytes, num_recv, linked);
} else {
noc_async_write_one_packet(src, dst_noc_addr, num_bytes);
}

command_ptr += 5;
command_ptr += 6;
if (last_transfer_in_group) {
src = align(src + num_bytes, 16);
}
Expand All @@ -132,7 +132,7 @@ FORCE_INLINE void program_page_transfer(
multicore_cb_wait_front(db_buf_switch, num_to_write);
uint32_t src_addr = get_read_ptr(db_buf_switch);
for (uint32_t i = 0; i < num_to_write; i++) {
write_program_page<multicast>(src_addr, command_ptr);
write_program_page<multicast>(src_addr, command_ptr, i == num_to_write - 1);
src_addr += DeviceCommand::PROGRAM_PAGE_SIZE;
}
page_idx += num_to_write;
Expand Down
2 changes: 2 additions & 0 deletions tt_metal/llrt/tt_memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class memory {

// Returns the number of elements held in data_.
size_t size() const {
    return data_.size();
}

// Returns the number of entries held in link_spans_.
size_t num_spans() const {
    return link_spans_.size();
}

// Read from file
void fill_from_discontiguous_hex(std::istream& is);

Expand Down

0 comments on commit afb882d

Please sign in to comment.