From 1b0640f917409a712d2600e6d3406ba3608174af Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Tue, 22 Oct 2024 17:33:55 +0000 Subject: [PATCH] #13655: Initial FD refactor to support sub devices Support multiple dispatch entries for worker->dispatch sync Update dispatch d/s to have a semaphore per dispatch entry to enable syncing on specific worker counts Update LaunchMessageRingBufferState and WorkerConfigBufferMgr to be tracked per sub_device Update various FD commands to support syncing on multiple sub devices: - ERB, EWB, ERE will be updated to take in a list of sub devices for blocking on in the future. Currently will sync all sub_devices - Trace will currently track all sub devices. Potential to track specific sub devices (could be automatic) in the future - EP is currently hardcoded to sub device 0. This will be updated to determine the used sub devices in the future --- .../tools/profiler/test_device_profiler.py | 4 +- .../dispatch/test_dispatcher.cpp | 1 + .../dispatch/test_prefetcher.cpp | 6 + .../test_kernels/dataflow/dram_copy.cpp | 2 +- .../test_kernels/misc/watcher_asserts.cpp | 2 +- tt_metal/hw/firmware/src/brisc.cc | 4 +- tt_metal/hw/firmware/src/erisc.cc | 6 +- tt_metal/hw/firmware/src/idle_erisc.cc | 2 +- tt_metal/hw/inc/dev_msgs.h | 2 +- tt_metal/impl/device/device.cpp | 159 +++++--- tt_metal/impl/device/device.hpp | 16 +- tt_metal/impl/dispatch/command_queue.cpp | 353 +++++++++++++----- tt_metal/impl/dispatch/command_queue.hpp | 34 +- .../impl/dispatch/command_queue_interface.hpp | 43 ++- tt_metal/impl/dispatch/cq_commands.hpp | 10 +- tt_metal/impl/dispatch/debug_tools.cpp | 9 +- tt_metal/impl/dispatch/device_command.hpp | 20 +- .../impl/dispatch/kernels/cq_dispatch.cpp | 51 ++- .../dispatch/kernels/cq_dispatch_slave.cpp | 66 +++- .../impl/dispatch/kernels/cq_prefetch.cpp | 2 +- tt_metal/impl/program/program.cpp | 10 +- tt_metal/impl/trace/trace_buffer.hpp | 10 +- 22 files changed, 573 insertions(+), 239 deletions(-) diff --git 
a/tests/tt_metal/tools/profiler/test_device_profiler.py b/tests/tt_metal/tools/profiler/test_device_profiler.py index 4132d4d90bc..736e944cbe6 100644 --- a/tests/tt_metal/tools/profiler/test_device_profiler.py +++ b/tests/tt_metal/tools/profiler/test_device_profiler.py @@ -167,11 +167,11 @@ def test_dispatch_cores(): REF_COUNT_DICT = { "grayskull": { "Tensix CQ Dispatch": 16, - "Tensix CQ Prefetch": 24, + "Tensix CQ Prefetch": 25, }, "wormhole_b0": { "Tensix CQ Dispatch": 16, - "Tensix CQ Prefetch": 24, + "Tensix CQ Prefetch": 25, }, } diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_dispatcher.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_dispatcher.cpp index ae6c2cf33a3..bddd81b5b91 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_dispatcher.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_dispatcher.cpp @@ -477,6 +477,7 @@ int main(int argc, char **argv) { 0, // prefetch_downstream_buffer_pages num_compute_cores, // max_write_packed_cores 0, + dispatch_constants::DISPATCH_MESSAGE_ENTRIES, 0, 0, 0, diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp index ba38f8ac8db..6f2539d1651 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp @@ -1912,6 +1912,7 @@ void configure_for_single_chip(Device *device, prefetch_downstream_buffer_pages, num_compute_cores, // max_write_packed_cores 0, + dispatch_constants::DISPATCH_MESSAGE_ENTRIES, 0, 0, 0, @@ -1932,6 +1933,7 @@ void configure_for_single_chip(Device *device, dispatch_compile_args[12] = dispatch_downstream_cb_sem; dispatch_compile_args[13] = dispatch_h_cb_sem; dispatch_compile_args[14] = dispatch_d_preamble_size; + dispatch_compile_args[21] = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; CoreCoord 
phys_dispatch_d_downstream_core = packetized_path_en_g ? phys_dispatch_relay_mux_core : phys_dispatch_h_core; configure_kernel_variant(program, @@ -1952,6 +1954,7 @@ void configure_for_single_chip(Device *device, dispatch_compile_args[12] = dispatch_h_cb_sem; dispatch_compile_args[13] = dispatch_downstream_cb_sem; dispatch_compile_args[14] = 0; // preamble size + dispatch_compile_args[21] = 1; // unused: dispatch_d only. max_num_worker_sems is used for array sizing, set to 1 CoreCoord phys_dispatch_h_upstream_core = packetized_path_en_g ? phys_dispatch_relay_demux_core : phys_dispatch_core; configure_kernel_variant(program, @@ -2655,6 +2658,7 @@ void configure_for_multi_chip(Device *device, prefetch_downstream_buffer_pages, num_compute_cores, 0, + dispatch_constants::DISPATCH_MESSAGE_ENTRIES, 0, 0, 0, @@ -2675,6 +2679,7 @@ void configure_for_multi_chip(Device *device, dispatch_compile_args[12] = dispatch_downstream_cb_sem; dispatch_compile_args[13] = dispatch_h_cb_sem; dispatch_compile_args[14] = dispatch_d_preamble_size; + dispatch_compile_args[21] = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; CoreCoord phys_dispatch_d_downstream_core = packetized_path_en_g ? phys_dispatch_relay_mux_core : phys_dispatch_h_core; configure_kernel_variant(program_r, @@ -2694,6 +2699,7 @@ void configure_for_multi_chip(Device *device, dispatch_compile_args[12] = dispatch_h_cb_sem; dispatch_compile_args[13] = dispatch_downstream_cb_sem; dispatch_compile_args[14] = 0; // preamble size + dispatch_compile_args[21] = 1; // unused: dispatch_d only. max_num_worker_sems is used for array sizing, set to 1 CoreCoord phys_dispatch_h_upstream_core = packetized_path_en_g ? 
phys_dispatch_relay_demux_core : phys_dispatch_core; configure_kernel_variant(program, diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/dram_copy.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/dram_copy.cpp index 78a989fdab7..13c5c4c40c5 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/dram_copy.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/dram_copy.cpp @@ -34,7 +34,7 @@ void kernel_main() { tt_l1_ptr mailboxes_t* const mailboxes = (tt_l1_ptr mailboxes_t*)(MEM_MAILBOX_BASE); #endif uint64_t dispatch_addr = NOC_XY_ADDR(NOC_X(mailboxes->go_message.master_x), - NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR); + NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset); noc_fast_atomic_increment(noc_index, NCRISC_AT_CMD_BUF, dispatch_addr, NOC_UNICAST_WRITE_VC, 1, 31, false); #endif diff --git a/tests/tt_metal/tt_metal/test_kernels/misc/watcher_asserts.cpp b/tests/tt_metal/tt_metal/test_kernels/misc/watcher_asserts.cpp index 6c623db7eb3..13406c2423b 100644 --- a/tests/tt_metal/tt_metal/test_kernels/misc/watcher_asserts.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/misc/watcher_asserts.cpp @@ -41,7 +41,7 @@ void MAIN { #endif uint64_t dispatch_addr = NOC_XY_ADDR(NOC_X(mailboxes->go_message.master_x), - NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR); + NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset); noc_fast_atomic_increment(noc_index, NCRISC_AT_CMD_BUF, dispatch_addr, NOC_UNICAST_WRITE_VC, 1, 31 /*wrap*/, false /*linked*/); } #else diff --git a/tt_metal/hw/firmware/src/brisc.cc b/tt_metal/hw/firmware/src/brisc.cc index a3b22ccfdf1..c69ba4ad7dd 100644 --- a/tt_metal/hw/firmware/src/brisc.cc +++ b/tt_metal/hw/firmware/src/brisc.cc @@ -374,7 +374,7 @@ int main() { // For future proofing, the noc_index value is initialized to 0, to ensure an invalid NOC txn is not issued. 
uint64_t dispatch_addr = NOC_XY_ADDR(NOC_X(mailboxes->go_message.master_x), - NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR); + NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset); mailboxes->go_message.signal = RUN_MSG_DONE; // Notify dispatcher that this has been done DEBUG_SANITIZE_NOC_ADDR(noc_index, dispatch_addr, 4); @@ -453,7 +453,7 @@ int main() { launch_msg_address->kernel_config.enables = 0; uint64_t dispatch_addr = NOC_XY_ADDR(NOC_X(mailboxes->go_message.master_x), - NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR); + NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset); DEBUG_SANITIZE_NOC_ADDR(noc_index, dispatch_addr, 4); noc_fast_atomic_increment( noc_index, diff --git a/tt_metal/hw/firmware/src/erisc.cc b/tt_metal/hw/firmware/src/erisc.cc index 2c1f978b994..b67876cda47 100644 --- a/tt_metal/hw/firmware/src/erisc.cc +++ b/tt_metal/hw/firmware/src/erisc.cc @@ -83,7 +83,7 @@ void __attribute__((noinline)) Application(void) { launch_msg_address->kernel_config.enables = 0; uint64_t dispatch_addr = NOC_XY_ADDR(NOC_X(mailboxes->go_message.master_x), - NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR); + NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset); internal_::notify_dispatch_core_done(dispatch_addr); mailboxes->launch_msg_rd_ptr = (launch_msg_rd_ptr + 1) & (launch_msg_buffer_num_entries - 1); // Only executed if watcher is enabled. 
Ensures that we don't report stale data due to invalid launch messages in the ring buffer @@ -94,9 +94,9 @@ void __attribute__((noinline)) Application(void) { } else if (go_message_signal == RUN_MSG_RESET_READ_PTR) { // Reset the launch message buffer read ptr mailboxes->launch_msg_rd_ptr = 0; - int64_t dispatch_addr = + uint64_t dispatch_addr = NOC_XY_ADDR(NOC_X(mailboxes->go_message.master_x), - NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR); + NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset); mailboxes->go_message.signal = RUN_MSG_DONE; internal_::notify_dispatch_core_done(dispatch_addr); } else { diff --git a/tt_metal/hw/firmware/src/idle_erisc.cc b/tt_metal/hw/firmware/src/idle_erisc.cc index 518b33f544c..43164366df0 100644 --- a/tt_metal/hw/firmware/src/idle_erisc.cc +++ b/tt_metal/hw/firmware/src/idle_erisc.cc @@ -145,7 +145,7 @@ int main() { launch_msg_address->kernel_config.enables = 0; uint64_t dispatch_addr = NOC_XY_ADDR(NOC_X(mailboxes->go_message.master_x), - NOC_Y(mailboxes->go_message.master_x), DISPATCH_MESSAGE_ADDR); + NOC_Y(mailboxes->go_message.master_y), DISPATCH_MESSAGE_ADDR + mailboxes->go_message.dispatch_message_offset); DEBUG_SANITIZE_NOC_ADDR(noc_index, dispatch_addr, 4); noc_fast_atomic_increment(noc_index, NCRISC_AT_CMD_BUF, dispatch_addr, NOC_UNICAST_WRITE_VC, 1, 31 /*wrap*/, false /*linked*/); mailboxes->launch_msg_rd_ptr = (launch_msg_rd_ptr + 1) & (launch_msg_buffer_num_entries - 1); diff --git a/tt_metal/hw/inc/dev_msgs.h b/tt_metal/hw/inc/dev_msgs.h index 0b027259c6a..69a57f90d3a 100644 --- a/tt_metal/hw/inc/dev_msgs.h +++ b/tt_metal/hw/inc/dev_msgs.h @@ -109,7 +109,7 @@ struct kernel_config_msg_t { } __attribute__((packed)); struct go_msg_t { - volatile uint8_t pad; + volatile uint8_t dispatch_message_offset; volatile uint8_t master_x; volatile uint8_t master_y; volatile uint8_t signal; // INIT, GO, DONE, RESET_RD_PTR diff --git 
a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index e536c9e940a..22270d6497a 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -53,12 +53,12 @@ bool Device::is_inactive_ethernet_core(CoreCoord logical_core) const { return inactive_ethernet_cores.find(logical_core) != inactive_ethernet_cores.end(); } -uint32_t Device::num_eth_worker_cores() const { - return this->num_eth_worker_cores_; +uint32_t Device::num_eth_worker_cores(uint32_t sub_device_index) const { + return this->num_eth_worker_cores_[sub_device_index]; } -uint32_t Device::num_worker_cores() const { - return this->num_worker_cores_; +uint32_t Device::num_worker_cores(uint32_t sub_device_index) const { + return this->num_worker_cores_[sub_device_index]; } std::vector Device::get_noc_encoding_for_active_eth_cores(NOC noc_index) { @@ -199,8 +199,11 @@ void Device::initialize_cluster() { this->clear_l1_state(); } int ai_clk = tt::Cluster::instance().get_device_aiclk(this->id_); - this->num_worker_cores_ = this->compute_with_storage_grid_size().x * this->compute_with_storage_grid_size().y; - this->num_eth_worker_cores_ = this->get_active_ethernet_cores(true).size(); + // TODO: This will be changed to be updated when setting up sub devices + this->num_worker_cores_.fill(0); + this->num_eth_worker_cores_.fill(0); + this->num_worker_cores_[0] = this->compute_with_storage_grid_size().x * this->compute_with_storage_grid_size().y; + this->num_eth_worker_cores_[0] = this->get_active_ethernet_cores(true).size(); log_info(tt::LogMetal, "AI CLK for device {} is: {} MHz", this->id_, ai_clk); } @@ -1204,7 +1207,7 @@ void Device::update_workers_build_settings(std::vector(device_worker_variants[DispatchWorkerType::PREFETCH_D][dispatch_d_idx]); // 1 to 1 mapping bw prefetch_d and dispatch_d auto dispatch_s_settings = std::get<1>(device_worker_variants[DispatchWorkerType::DISPATCH_S][dispatch_d_idx]); // 1 to 1 mapping bw dispatch_s and dispatch_d @@ -1578,10 +1584,11 
@@ void Device::update_workers_build_settings(std::vectorget_noc_multicast_encoding(dispatch_s_noc_index, tensix_worker_physical_grid); tt_cxy_pair dispatch_s_location = dispatch_core_manager::instance().dispatcher_s_core(device_id, channel, cq_id); @@ -2150,6 +2161,7 @@ void Device::compile_command_queue_programs() { uint32_t dev_completion_queue_wr_ptr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_WR); uint32_t dev_completion_queue_rd_ptr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_RD); uint32_t dispatch_message_addr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + uint32_t max_dispatch_message_entries = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; const uint32_t prefetch_sync_sem = tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, prefetch_core, 0, dispatch_core_type); const uint32_t prefetch_sem = tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, prefetch_core, dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages(), dispatch_core_type); @@ -2161,7 +2173,7 @@ void Device::compile_command_queue_programs() { CoreCoord dispatch_s_physical_core = {0xff, 0xff}; uint32_t dispatch_s_buffer_base = 0xff; uint32_t dispatch_s_sem = 0xff; // used by dispatch_s to sync with prefetch - uint32_t dispatch_s_sync_sem_id = 0xff; // used by dispatch_d to signal that dispatch_s can send go signal + uint32_t dispatch_s_sync_sem_base_addr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM); // used by dispatch_d to signal that dispatch_s can send go signal if (this->dispatch_s_enabled()) { // Skip allocating dispatch_s for multi-CQ configurations with ethernet dispatch dispatch_s_core = dispatch_core_manager::instance().dispatcher_s_core(device_id, channel, cq_id); 
@@ -2176,7 +2188,6 @@ void Device::compile_command_queue_programs() { dispatch_s_buffer_base = dispatch_buffer_base; } dispatch_s_sem = tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_s_core, 0, dispatch_core_type); // used by dispatch_s to sync with prefetch - dispatch_s_sync_sem_id = tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_s_core, 0, dispatch_core_type); // used by dispatch_d to signal that dispatch_s can send go signal } log_debug(LogDevice, "Dispatching out of {} cores", magic_enum::enum_name(dispatch_core_type)); @@ -2262,7 +2273,8 @@ void Device::compile_command_queue_programs() { 0, // unused prefetch_local_downstream_sem_addr 0, // unused prefetch_downstream_buffer_pages num_compute_cores, // max_write_packed_cores - dispatch_s_sync_sem_id, // used to notify dispatch_s that its safe to send a go signal + dispatch_s_sync_sem_base_addr, // used to notify dispatch_s that its safe to send a go signal + max_dispatch_message_entries, this->get_noc_multicast_encoding(my_noc_index, tensix_worker_physical_grid), // used by dispatch_d to mcast go signals when dispatch_s is not enabled tensix_worker_go_signal_addr, // used by dispatch_d to mcast go signals when dispatch_s is not enabled eth_worker_go_signal_addr, // used by dispatch_d to mcast go signals when dispatch_s is not enabled @@ -2296,13 +2308,14 @@ void Device::compile_command_queue_programs() { dispatch_constants::get(dispatch_core_type).dispatch_s_buffer_size(), dispatch_s_sem, prefetch_dispatch_s_sync_sem, - dispatch_s_sync_sem_id, + dispatch_s_sync_sem_base_addr, this->get_noc_multicast_encoding(NOC::NOC_1, tensix_worker_physical_grid), tensix_num_worker_cores, tensix_worker_go_signal_addr, eth_worker_go_signal_addr, dispatch_core_type == CoreType::ETH, - dispatch_message_addr + dispatch_message_addr, + max_dispatch_message_entries, }; configure_kernel_variant( *command_queue_program_ptr, @@ -2698,6 +2711,7 @@ void Device::configure_command_queue_programs() { 
} uint32_t prefetch_q_base = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::UNRESERVED); + uint32_t dispatch_message_entries = dispatch_constants::DISPATCH_MESSAGE_ENTRIES; for (uint8_t cq_id = 0; cq_id < num_hw_cqs; cq_id++) { tt_cxy_pair prefetch_location = dispatch_core_manager::instance().prefetcher_core(device_id, channel, cq_id); tt_cxy_pair completion_q_writer_location = dispatch_core_manager::instance().completion_queue_writer_core(device_id, channel, cq_id); @@ -2719,7 +2733,8 @@ void Device::configure_command_queue_programs() { uint32_t prefetch_q_pcie_rd_ptr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD); uint32_t completion_q_wr_ptr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_WR); uint32_t completion_q_rd_ptr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_RD); - uint32_t dispatch_message_addr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + uint32_t dispatch_s_sync_sem_base_addr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM); + uint32_t dispatch_message_base_addr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); uint32_t completion_q0_last_event_ptr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); uint32_t completion_q1_last_event_ptr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); std::vector prefetch_q_pcie_rd_ptr_addr_data = {get_absolute_cq_offset(channel, cq_id, cq_size) + cq_start}; @@ 
-2742,16 +2757,24 @@ void Device::configure_command_queue_programs() { detail::WriteToDeviceL1(mmio_device, completion_q_writer_location, completion_q1_last_event_ptr, zero, dispatch_core_type); // Initialize address where workers signal completion to dispatch core(s). - if (this->distributed_dispatcher()) { - // Ethernet dispatch with a single CQ. dispatch_s and dispatch_d are on different cores. Initialize counter for both to zero. - tt_cxy_pair dispatch_s_location = dispatch_core_manager::instance().dispatcher_s_core(device_id, channel, cq_id); - detail::WriteToDeviceL1(this, dispatch_s_location, dispatch_message_addr, zero, dispatch_core_type); - } - detail::WriteToDeviceL1(mmio_device, dispatch_location, dispatch_message_addr, zero, dispatch_core_type); - if (device_id != mmio_device_id) { - tt_cxy_pair dispatch_d_location = dispatch_core_manager::instance().dispatcher_d_core(device_id, channel, cq_id); - dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device_id); - detail::WriteToDeviceL1(this, dispatch_d_location, dispatch_message_addr, zero, dispatch_core_type); + // TODO: Should only initialize dispatch_s_sync_sem if this->dispatch_s_enabled()? + for (uint32_t i = 0; i < dispatch_message_entries; i++) { + uint32_t dispatch_s_sync_sem_addr = dispatch_s_sync_sem_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + if (this->distributed_dispatcher()) { + // Ethernet dispatch with a single CQ. dispatch_s and dispatch_d are on different cores. Initialize counter for both to zero. 
+ tt_cxy_pair dispatch_s_location = dispatch_core_manager::instance().dispatcher_s_core(device_id, channel, cq_id); + detail::WriteToDeviceL1(this, dispatch_s_location, dispatch_s_sync_sem_addr, zero, dispatch_core_type); + detail::WriteToDeviceL1(this, dispatch_s_location, dispatch_message_addr, zero, dispatch_core_type); + } + detail::WriteToDeviceL1(mmio_device, dispatch_location, dispatch_s_sync_sem_addr, zero, dispatch_core_type); + detail::WriteToDeviceL1(mmio_device, dispatch_location, dispatch_message_addr, zero, dispatch_core_type); + if (device_id != mmio_device_id) { + tt_cxy_pair dispatch_d_location = dispatch_core_manager::instance().dispatcher_d_core(device_id, channel, cq_id); + CoreType remote_dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device_id); + detail::WriteToDeviceL1(this, dispatch_d_location, dispatch_s_sync_sem_addr, zero, remote_dispatch_core_type); + detail::WriteToDeviceL1(this, dispatch_d_location, dispatch_message_addr, zero, remote_dispatch_core_type); + } } } @@ -2845,6 +2868,7 @@ void Device::init_command_queue_device() { // TODO: Move this inside the command queue for (auto& hw_cq : this->hw_command_queues_) { hw_cq->set_unicast_only_cores_on_dispatch(this->get_noc_encoding_for_active_eth_cores(this->dispatch_s_enabled() ? 
NOC::NOC_1 : NOC::NOC_0)); + hw_cq->set_num_worker_sems_on_dispatch(this->num_sub_devices()); } // Added this for safety while debugging hangs with FD v1.3 tunnel to R, should experiment with removing it // tt::Cluster::instance().l1_barrier(this->id()); @@ -2873,7 +2897,7 @@ bool Device::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t t this->initialize_allocator(l1_small_size, trace_region_size, l1_bank_remap); this->initialize_build(); // Reset the launch_message ring buffer state seen on host, since its reset on device, each time FW is initialized - this->worker_launch_message_buffer_state.reset(); + std::for_each(this->worker_launch_message_buffer_state.begin(), this->worker_launch_message_buffer_state.end(), std::mem_fn(&LaunchMessageRingBufferState::reset)); // For minimal setup, don't initialize FW, watcher, dprint. They won't work if we're attaching to a hung chip. if (minimal) return true; @@ -3102,6 +3126,29 @@ void Device::check_allocator_is_initialized() const { } } +void Device::reset_num_sub_devices(uint32_t num_sub_devices) { + TT_FATAL((num_sub_devices >=1 && num_sub_devices <= dispatch_constants::DISPATCH_MESSAGE_ENTRIES), "Illegal number of sub devices specified"); + // Finish all running programs + Synchronize(this); + + // Set new number of worker sems on dispatch_s + for (auto& hw_cq : this->hw_command_queues_) { + // Only need to reset launch messages once, so reset on cq 0 + TT_FATAL(!hw_cq->manager.get_bypass_mode(), "Cannot reset worker state during trace capture"); + hw_cq->reset_worker_state(hw_cq->id == 0); + hw_cq->set_num_worker_sems_on_dispatch(num_sub_devices); + } + // Reset the config buffer mgr (is this needed?) 
+ this->sysmem_manager_->reset_config_buffer_mgr(num_sub_devices); + // Reset the launch_message ring buffer state seen on host + std::for_each(this->worker_launch_message_buffer_state.begin(), this->worker_launch_message_buffer_state.begin() + num_sub_devices, std::mem_fn(&LaunchMessageRingBufferState::reset)); + num_sub_devices_ = num_sub_devices; +} + +uint32_t Device::num_sub_devices() const { + return num_sub_devices_; +} + uint32_t Device::num_banks(const BufferType &buffer_type) const { this->check_allocator_is_initialized(); return allocator::num_banks(*this->allocator_, buffer_type); diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index dce53a1eae8..3a7515d4535 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -138,9 +138,9 @@ class Device { bool is_inactive_ethernet_core(CoreCoord logical_core) const; - uint32_t num_eth_worker_cores() const; + uint32_t num_eth_worker_cores(uint32_t sub_device_index) const; - uint32_t num_worker_cores() const; + uint32_t num_worker_cores(uint32_t sub_device_index) const; std::tuple get_connected_ethernet_core(CoreCoord eth_core) const { return tt::Cluster::instance().get_connected_ethernet_core(std::make_tuple(this->id_, eth_core)); @@ -158,6 +158,10 @@ class Device { void update_workers_build_settings(std::vector>> &device_worker_variants); + void reset_num_sub_devices(uint32_t num_sub_devices); + + uint32_t num_sub_devices() const; + uint32_t num_banks(const BufferType &buffer_type) const; uint32_t bank_size(const BufferType &buffer_type) const; @@ -301,14 +305,14 @@ class Device { uint32_t worker_thread_core; uint32_t completion_queue_reader_core; std::unique_ptr sysmem_manager_; - LaunchMessageRingBufferState worker_launch_message_buffer_state; + std::array worker_launch_message_buffer_state; uint8_t num_hw_cqs_; std::vector> command_queue_programs; bool using_fast_dispatch; program_cache::detail::ProgramCache program_cache; - uint32_t 
num_worker_cores_; - uint32_t num_eth_worker_cores_; + std::array num_worker_cores_; + std::array num_eth_worker_cores_; // Program cache interface. Syncrhonize with worker worker threads before querying or // modifying this structure, since worker threads use this for compiling ops void enable_program_cache() { @@ -348,6 +352,8 @@ class Device { void MarkAllocationsUnsafe(); void MarkAllocationsSafe(); std::unordered_map> trace_buffer_pool_; + // Temporary until actual sub_device implementation is added + uint32_t num_sub_devices_ = 1; }; } // namespace v0 diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index e4ee5405f07..efc146e9fcd 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -39,9 +39,6 @@ using std::set; using std::shared_ptr; using std::unique_ptr; -std::mutex finish_mutex; -std::condition_variable finish_cv; - namespace tt::tt_metal { namespace detail { @@ -73,7 +70,7 @@ EnqueueReadBufferCommand::EnqueueReadBufferCommand( Buffer& buffer, void* dst, SystemMemoryManager& manager, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, uint32_t src_page_index, std::optional pages_to_read) : command_queue_id(command_queue_id), @@ -109,7 +106,7 @@ void EnqueueReadShardedBufferCommand::add_prefetch_relay(HugepageDeviceCommand& void EnqueueReadBufferCommand::process() { // accounts for padding uint32_t cmd_sequence_sizeB = - CQ_PREFETCH_CMD_BARE_MIN_SIZE + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT + CQ_PREFETCH_CMD_BARE_MIN_SIZE * this->expected_num_workers_completed.size() + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT CQ_PREFETCH_CMD_BARE_MIN_SIZE + // CQ_PREFETCH_CMD_STALL CQ_PREFETCH_CMD_BARE_MIN_SIZE + // CQ_PREFETCH_CMD_RELAY_INLINE_NOFLUSH + CQ_DISPATCH_CMD_WRITE_LINEAR_HOST CQ_PREFETCH_CMD_BARE_MIN_SIZE; // CQ_PREFETCH_CMD_RELAY_LINEAR or CQ_PREFETCH_CMD_RELAY_PAGED @@ -118,10 +115,20 @@ 
void EnqueueReadBufferCommand::process() { HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); - uint32_t dispatch_message_addr = dispatch_constants::get( - this->dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + uint32_t dispatch_message_base_addr = dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + uint32_t last_index = this->expected_num_workers_completed.size() - 1; + // We only need the write barrier + prefetch stall for the last wait cmd + for (uint32_t i = 0; i < last_index; ++i) { + auto [offset_index, workers_completed] = this->expected_num_workers_completed[i]; + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); + command_sequence.add_dispatch_wait( + false, dispatch_message_addr, workers_completed); + + } + auto [offset_index, workers_completed] = this->expected_num_workers_completed[last_index]; + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); command_sequence.add_dispatch_wait_with_prefetch_stall( - true, dispatch_message_addr, this->expected_num_workers_completed); + true, dispatch_message_addr, workers_completed); uint32_t padded_page_size = this->buffer.aligned_page_size(); bool flush_prefetch = false; @@ -146,7 +153,7 @@ EnqueueWriteBufferCommand::EnqueueWriteBufferCommand( const void* src, SystemMemoryManager& manager, bool issue_wait, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, uint32_t bank_base_address, uint32_t padded_page_size, uint32_t dst_page_index, @@ -277,7 +284,7 @@ void EnqueueWriteBufferCommand::process() { // CQ_DISPATCH_CMD_WRITE_LINEAR) data_size_bytes; if (this->issue_wait) { - cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE; // 
CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT + cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE * this->expected_num_workers_completed.size(); // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT } void* cmd_region = this->manager.issue_queue_reserve(cmd_sequence_sizeB, this->command_queue_id); @@ -285,9 +292,13 @@ void EnqueueWriteBufferCommand::process() { HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); if (this->issue_wait) { - uint32_t dispatch_message_addr = dispatch_constants::get( + uint32_t dispatch_message_base_addr = dispatch_constants::get( this->dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - command_sequence.add_dispatch_wait(false, dispatch_message_addr, this->expected_num_workers_completed); + for (uint32_t i = 0; i < this->expected_num_workers_completed.size(); ++i) { + auto [offset_index, workers_completed] = this->expected_num_workers_completed[i]; + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); + command_sequence.add_dispatch_wait(false, dispatch_message_addr, workers_completed); + } } this->add_dispatch_write(command_sequence); @@ -305,7 +316,7 @@ void EnqueueWriteBufferCommand::process() { } inline uint32_t get_packed_write_max_unicast_sub_cmds(Device* device) { - return device->num_worker_cores(); + return device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y; } // EnqueueProgramCommand Section @@ -319,20 +330,23 @@ EnqueueProgramCommand::EnqueueProgramCommand( SystemMemoryManager& manager, uint32_t expected_num_workers_completed, uint32_t multicast_cores_launch_message_wptr, - uint32_t unicast_cores_launch_message_wptr) : + uint32_t unicast_cores_launch_message_wptr, + uint32_t sub_device_id) : command_queue_id(command_queue_id), noc_index(noc_index), manager(manager), 
expected_num_workers_completed(expected_num_workers_completed), program(program), - dispatch_core(dispatch_core) { + dispatch_core(dispatch_core), + multicast_cores_launch_message_wptr(multicast_cores_launch_message_wptr), + unicast_cores_launch_message_wptr(unicast_cores_launch_message_wptr), + sub_device_id(sub_device_id) { this->device = device; this->dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); this->packed_write_max_unicast_sub_cmds = get_packed_write_max_unicast_sub_cmds(this->device); this->dispatch_message_addr = dispatch_constants::get( - this->dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - this->multicast_cores_launch_message_wptr = multicast_cores_launch_message_wptr; - this->unicast_cores_launch_message_wptr = unicast_cores_launch_message_wptr; + this->dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE) + + dispatch_constants::get(this->dispatch_core_type).get_dispatch_message_offset(this->sub_device_id); } void EnqueueProgramCommand::assemble_preamble_commands(ProgramCommandSequence& program_command_sequence, std::vector& kernel_config_addrs) { @@ -1253,7 +1267,9 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro DispatcherSelect dispatcher_for_go_signal = DispatcherSelect::DISPATCH_MASTER; if (this->device->dispatch_s_enabled()) { // dispatch_d signals dispatch_s to send the go signal, use a barrier if there are cores active - device_command_sequence.add_notify_dispatch_s_go_signal_cmd(program_transfer_info.num_active_cores > 0); + uint16_t index_bitmask = 0; + index_bitmask |= 1 << this->sub_device_id; + device_command_sequence.add_notify_dispatch_s_go_signal_cmd(program_transfer_info.num_active_cores > 0, index_bitmask); dispatcher_for_go_signal = DispatcherSelect::DISPATCH_SLAVE; } else { // Wait Noc Write Barrier, wait for binaries/configs and launch_msg to be written to 
worker cores @@ -1265,6 +1281,7 @@ void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& pro run_program_go_signal.signal = RUN_MSG_GO; run_program_go_signal.master_x = (uint8_t)this->dispatch_core.x; run_program_go_signal.master_y = (uint8_t)this->dispatch_core.y; + run_program_go_signal.dispatch_message_offset = (uint8_t)dispatch_constants::get(this->dispatch_core_type).get_dispatch_message_offset(this->sub_device_id); uint32_t write_offset_bytes = device_command_sequence.write_offset_bytes(); device_command_sequence.add_dispatch_go_signal_mcast(this->expected_num_workers_completed, go_signal_mcast_flag, *reinterpret_cast(&run_program_go_signal), this->dispatch_message_addr, dispatcher_for_go_signal); program_command_sequence.mcast_go_signal_cmd_ptr = &((CQDispatchCmd*) ((uint32_t*)device_command_sequence.data() + (write_offset_bytes + sizeof(CQPrefetchCmd)) / sizeof(uint32_t)))->mcast; @@ -1313,6 +1330,7 @@ void EnqueueProgramCommand::update_device_commands(ProgramCommandSequence& cache run_program_go_signal.signal = RUN_MSG_GO; run_program_go_signal.master_x = (uint8_t)this->dispatch_core.x; run_program_go_signal.master_y = (uint8_t)this->dispatch_core.y; + run_program_go_signal.dispatch_message_offset = (uint8_t)dispatch_constants::get(this->dispatch_core_type).get_dispatch_message_offset(this->sub_device_id); cached_program_command_sequence.mcast_go_signal_cmd_ptr->go_signal = *reinterpret_cast(&run_program_go_signal); cached_program_command_sequence.mcast_go_signal_cmd_ptr->wait_count = this->expected_num_workers_completed; } @@ -1336,8 +1354,7 @@ void EnqueueProgramCommand::write_program_command_sequence(const ProgramCommandS uint32_t total_fetch_size_bytes = stall_fetch_size_bytes + preamble_fetch_size_bytes + runtime_args_fetch_size_bytes + program_fetch_size_bytes; - CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->device->id()); - if (total_fetch_size_bytes <= 
dispatch_constants::get(dispatch_core_type).max_prefetch_command_size()) { + if (total_fetch_size_bytes <= dispatch_constants::get(this->dispatch_core_type).max_prefetch_command_size()) { this->manager.issue_queue_reserve(total_fetch_size_bytes, this->command_queue_id); uint32_t write_ptr = this->manager.get_issue_queue_write_ptr(this->command_queue_id); @@ -1459,24 +1476,26 @@ void EnqueueProgramCommand::write_program_command_sequence(const ProgramCommandS void EnqueueProgramCommand::process() { + // TODO: Finalize needs to be by mesh manager bool is_finalized = program.is_finalized(); if (not is_finalized) { program.finalize(device); } + auto& config_buffer_mgr = this->manager.get_config_buffer_mgr(this->sub_device_id); const std::pair&> reservation = - this->manager.get_config_buffer_mgr().reserve(program.get_program_config_sizes()); + config_buffer_mgr.reserve(program.get_program_config_sizes()); bool stall_first = reservation.first.need_sync; // Note: since present implementation always stalls, we always free up to "now" - this->manager.get_config_buffer_mgr().free(reservation.first.sync_count); + config_buffer_mgr.free(reservation.first.sync_count); uint32_t num_workers = 0; if (program.runs_on_noc_multicast_only_cores()) { - num_workers += device->num_worker_cores(); + num_workers += device->num_worker_cores(this->sub_device_id); } if (program.runs_on_noc_unicast_only_cores()) { - num_workers += device->num_eth_worker_cores(); + num_workers += device->num_eth_worker_cores(this->sub_device_id); } - this->manager.get_config_buffer_mgr().alloc( + config_buffer_mgr.alloc( this->expected_num_workers_completed + num_workers); std::vector& kernel_config_addrs = reservation.second; @@ -1543,7 +1562,7 @@ EnqueueRecordEventCommand::EnqueueRecordEventCommand( NOC noc_index, SystemMemoryManager& manager, uint32_t event_id, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, bool clear_count, bool write_barrier) : 
command_queue_id(command_queue_id), @@ -1569,7 +1588,7 @@ void EnqueueRecordEventCommand::process() { uint32_t packed_write_sizeB = align(sizeof(CQPrefetchCmd) + packed_event_payload_sizeB, pcie_alignment); uint32_t cmd_sequence_sizeB = - CQ_PREFETCH_CMD_BARE_MIN_SIZE + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT + CQ_PREFETCH_CMD_BARE_MIN_SIZE * this->expected_num_workers_completed.size() + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT packed_write_sizeB + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WRITE_PACKED + unicast subcmds + event // payload align( @@ -1581,11 +1600,22 @@ void EnqueueRecordEventCommand::process() { HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->device->id()); - uint32_t dispatch_message_addr = dispatch_constants::get( + uint32_t dispatch_message_base_addr = dispatch_constants::get( dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + uint32_t last_index = this->expected_num_workers_completed.size() - 1; + // We only need the write barrier for the last wait cmd + for (uint32_t i = 0; i < last_index; ++i) { + auto [offset_index, workers_completed] = this->expected_num_workers_completed[i]; + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); + command_sequence.add_dispatch_wait( + false, dispatch_message_addr, workers_completed, this->clear_count); + + } + auto [offset_index, workers_completed] = this->expected_num_workers_completed[last_index]; + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); command_sequence.add_dispatch_wait( - this->write_barrier, dispatch_message_addr, this->expected_num_workers_completed, this->clear_count); + this->write_barrier, 
dispatch_message_addr, workers_completed, this->clear_count); CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->device->id()); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device->id()); @@ -1674,7 +1704,7 @@ EnqueueTraceCommand::EnqueueTraceCommand( SystemMemoryManager& manager, std::shared_ptr& desc, Buffer& buffer, - uint32_t& expected_num_workers_completed, + std::array & expected_num_workers_completed, NOC noc_index, CoreCoord dispatch_core) : command_queue_id(command_queue_id), @@ -1688,59 +1718,71 @@ EnqueueTraceCommand::EnqueueTraceCommand( dispatch_core(dispatch_core) {} void EnqueueTraceCommand::process() { + uint32_t num_sub_devices = device->num_sub_devices(); uint32_t cmd_sequence_sizeB = this->device->dispatch_s_enabled() * CQ_PREFETCH_CMD_BARE_MIN_SIZE + // dispatch_d -> dispatch_s sem update (send only if dispatch_s is running) - CQ_PREFETCH_CMD_BARE_MIN_SIZE + // go signal cmd + (CQ_PREFETCH_CMD_BARE_MIN_SIZE + // go signal cmd CQ_PREFETCH_CMD_BARE_MIN_SIZE + // wait to ensure that reset go signal was processed (dispatch_d) // when dispatch_s and dispatch_d are running on 2 cores, workers update dispatch_s. dispatch_s is responsible for resetting worker count // and giving dispatch_d the latest worker state. 
This is encapsulated in the dispatch_s wait command (only to be sent when dispatch is distributed // on 2 cores) - (this->device->distributed_dispatcher()) * CQ_PREFETCH_CMD_BARE_MIN_SIZE + + (this->device->distributed_dispatcher()) * CQ_PREFETCH_CMD_BARE_MIN_SIZE) * num_sub_devices + CQ_PREFETCH_CMD_BARE_MIN_SIZE; // CQ_PREFETCH_CMD_EXEC_BUF - uint8_t go_signal_mcast_flag = 0; - if (desc->num_traced_programs_needing_go_signal_multicast) { - go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_MCAST; - } - if (desc->num_traced_programs_needing_go_signal_unicast) { - go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_UNICAST; - } void* cmd_region = this->manager.issue_queue_reserve(cmd_sequence_sizeB, this->command_queue_id); HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); DispatcherSelect dispatcher_for_go_signal = DispatcherSelect::DISPATCH_MASTER; if (this->device->dispatch_s_enabled()) { - command_sequence.add_notify_dispatch_s_go_signal_cmd(false); + uint16_t index_bitmask = 0; + for (uint32_t i = 0; i < num_sub_devices; ++i) { + index_bitmask |= 1 << i; + } + command_sequence.add_notify_dispatch_s_go_signal_cmd(false, index_bitmask); dispatcher_for_go_signal = DispatcherSelect::DISPATCH_SLAVE; } + CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); + uint32_t dispatch_message_base_addr = dispatch_constants::get( + dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); go_msg_t reset_launch_message_read_ptr_go_signal; reset_launch_message_read_ptr_go_signal.signal = RUN_MSG_RESET_READ_PTR; reset_launch_message_read_ptr_go_signal.master_x = (uint8_t)this->dispatch_core.x; reset_launch_message_read_ptr_go_signal.master_y = (uint8_t)this->dispatch_core.y; - CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); - uint32_t dispatch_message_addr = dispatch_constants::get( - 
dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - // Wait to ensure that all kernels have completed. Then send the reset_rd_ptr go_signal. - command_sequence.add_dispatch_go_signal_mcast(this->expected_num_workers_completed, go_signal_mcast_flag, *reinterpret_cast(&reset_launch_message_read_ptr_go_signal), dispatch_message_addr, dispatcher_for_go_signal); - if (desc->num_traced_programs_needing_go_signal_multicast) { - this->expected_num_workers_completed += device->num_worker_cores(); - } - if (desc->num_traced_programs_needing_go_signal_unicast) { - this->expected_num_workers_completed += device->num_eth_worker_cores(); + for (uint32_t i = 0; i < num_sub_devices; ++i) { + uint8_t go_signal_mcast_flag = 0; + if (desc->descriptors[i].num_traced_programs_needing_go_signal_multicast) { + go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_MCAST; + } + if (desc->descriptors[i].num_traced_programs_needing_go_signal_unicast) { + go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_UNICAST; + } + reset_launch_message_read_ptr_go_signal.dispatch_message_offset = (uint8_t)dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + // Wait to ensure that all kernels have completed. Then send the reset_rd_ptr go_signal. 
+ command_sequence.add_dispatch_go_signal_mcast(this->expected_num_workers_completed[i], go_signal_mcast_flag, *reinterpret_cast(&reset_launch_message_read_ptr_go_signal), dispatch_message_addr, dispatcher_for_go_signal); + if (desc->descriptors[i].num_traced_programs_needing_go_signal_multicast) { + this->expected_num_workers_completed[i] += device->num_worker_cores(i); + } + if (desc->descriptors[i].num_traced_programs_needing_go_signal_unicast) { + this->expected_num_workers_completed[i] += device->num_eth_worker_cores(i); + } } // Wait to ensure that all workers have reset their read_ptr. dispatch_d will stall until all workers have completed this step, before sending kernel config data to workers // or notifying dispatch_s that its safe to send the go_signal. // Clear the dispatch <--> worker semaphore, since trace starts at 0. - if (this->device->distributed_dispatcher()) { + for (uint32_t i = 0; i < num_sub_devices; ++i) { + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + if (this->device->distributed_dispatcher()) { + command_sequence.add_dispatch_wait( + false, dispatch_message_addr, this->expected_num_workers_completed[i], this->clear_count, false, true, 1); + } command_sequence.add_dispatch_wait( - false, dispatch_message_addr, this->expected_num_workers_completed, this->clear_count, false, true, 1); + false, dispatch_message_addr, this->expected_num_workers_completed[i], this->clear_count); } - command_sequence.add_dispatch_wait( - false, dispatch_message_addr, this->expected_num_workers_completed, this->clear_count); if (this->clear_count) { - this->expected_num_workers_completed = 0; + std::fill(this->expected_num_workers_completed.begin(), this->expected_num_workers_completed.begin() + num_sub_devices, 0); } uint32_t page_size = buffer.page_size(); @@ -1834,7 +1876,7 @@ HWCommandQueue::HWCommandQueue(Device* device, uint32_t id, NOC noc_index) : 
this->completion_queue_thread = std::move(completion_queue_thread); // Set the affinity of the completion queue reader. set_device_thread_affinity(this->completion_queue_thread, device->completion_queue_reader_core); - this->expected_num_workers_completed = 0; + this->expected_num_workers_completed.fill(0); } void HWCommandQueue::set_unicast_only_cores_on_dispatch(const std::vector& unicast_only_noc_encodings) { @@ -1848,6 +1890,84 @@ void HWCommandQueue::set_unicast_only_cores_on_dispatch(const std::vectormanager.fetch_queue_write(cmd_sequence_sizeB, this->id); } +void HWCommandQueue::set_num_worker_sems_on_dispatch(uint32_t num_worker_sems) { + // Not needed for regular dispatch kernel + if (!this->device->dispatch_s_enabled()) { + return; + } + uint32_t cmd_sequence_sizeB = align(CQ_PREFETCH_CMD_BARE_MIN_SIZE, PCIE_ALIGNMENT); + void* cmd_region = this->manager.issue_queue_reserve(cmd_sequence_sizeB, this->id); + HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); + command_sequence.add_dispatch_set_num_worker_sems(num_worker_sems, DispatcherSelect::DISPATCH_SLAVE); + this->manager.issue_queue_push_back(cmd_sequence_sizeB, this->id); + this->manager.fetch_queue_reserve_back(this->id); + this->manager.fetch_queue_write(cmd_sequence_sizeB, this->id); +} + +void HWCommandQueue::reset_worker_state(bool reset_launch_msg_state) { + uint32_t num_sub_devices = device->num_sub_devices(); + uint32_t cmd_sequence_sizeB = + reset_launch_msg_state * this->device->dispatch_s_enabled() * CQ_PREFETCH_CMD_BARE_MIN_SIZE + // dispatch_d -> dispatch_s sem update (send only if dispatch_s is running) + (reset_launch_msg_state * CQ_PREFETCH_CMD_BARE_MIN_SIZE + // go signal cmd + CQ_PREFETCH_CMD_BARE_MIN_SIZE + // wait to ensure that reset go signal was processed (dispatch_d) + // when dispatch_s and dispatch_d are running on 2 cores, workers update dispatch_s. dispatch_s is responsible for resetting worker count + // and giving dispatch_d the latest worker state. 
This is encapsulated in the dispatch_s wait command (only to be sent when dispatch is distributed + // on 2 cores) + this->device->distributed_dispatcher() * CQ_PREFETCH_CMD_BARE_MIN_SIZE) * num_sub_devices; + void* cmd_region = this->manager.issue_queue_reserve(cmd_sequence_sizeB, this->id); + HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); + bool clear_count = true; + DispatcherSelect dispatcher_for_go_signal = DispatcherSelect::DISPATCH_MASTER; + CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); + uint32_t dispatch_message_base_addr = dispatch_constants::get( + dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + if (reset_launch_msg_state) { + if (device->dispatch_s_enabled()) { + uint16_t index_bitmask = 0; + for (uint32_t i = 0; i < num_sub_devices; ++i) { + index_bitmask |= 1 << i; + } + command_sequence.add_notify_dispatch_s_go_signal_cmd(false, index_bitmask); + dispatcher_for_go_signal = DispatcherSelect::DISPATCH_SLAVE; + } + go_msg_t reset_launch_message_read_ptr_go_signal; + reset_launch_message_read_ptr_go_signal.signal = RUN_MSG_RESET_READ_PTR; + reset_launch_message_read_ptr_go_signal.master_x = (uint8_t)this->physical_enqueue_program_dispatch_core.x; + reset_launch_message_read_ptr_go_signal.master_y = (uint8_t)this->physical_enqueue_program_dispatch_core.y; + uint8_t go_signal_mcast_flag = 0; + go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_MCAST; + go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_UNICAST; + for (uint32_t i = 0; i < num_sub_devices; ++i) { + + reset_launch_message_read_ptr_go_signal.dispatch_message_offset = (uint8_t)dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + // Wait to ensure that all kernels have completed. 
Then send the reset_rd_ptr go_signal. + command_sequence.add_dispatch_go_signal_mcast(expected_num_workers_completed[i], go_signal_mcast_flag, *reinterpret_cast(&reset_launch_message_read_ptr_go_signal), dispatch_message_addr, dispatcher_for_go_signal); + expected_num_workers_completed[i] += device->num_worker_cores(i); + expected_num_workers_completed[i] += device->num_eth_worker_cores(i); + } + } + // Wait to ensure that all workers have reset their read_ptr. dispatch_d will stall until all workers have completed this step, before sending kernel config data to workers + // or notifying dispatch_s that its safe to send the go_signal. + // Clear the dispatch <--> worker semaphore, since trace starts at 0. + for (uint32_t i = 0; i < num_sub_devices; ++i) { + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + if (device->distributed_dispatcher()) { + command_sequence.add_dispatch_wait( + false, dispatch_message_addr, expected_num_workers_completed[i], clear_count, false, true, 1); + } + command_sequence.add_dispatch_wait( + false, dispatch_message_addr, expected_num_workers_completed[i], clear_count); + } + this->manager.issue_queue_push_back(cmd_sequence_sizeB, this->id); + this->manager.fetch_queue_reserve_back(this->id); + this->manager.fetch_queue_write(cmd_sequence_sizeB, this->id); + + if (clear_count) { + std::fill(expected_num_workers_completed.begin(), expected_num_workers_completed.begin() + num_sub_devices, 0); + } +} + HWCommandQueue::~HWCommandQueue() { ZoneScopedN("HWCommandQueue_destructor"); if (this->exit_condition) { @@ -1911,6 +2031,13 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin uint32_t unpadded_dst_offset = 0; uint32_t src_page_index = 0; + // TODO: We can take in the meshes to wait on + std::vector> expected_workers_completed; + expected_workers_completed.reserve(this->device->num_sub_devices()); + for (uint32_t i = 0; 
i < this->device->num_sub_devices(); ++i) { + expected_workers_completed.emplace_back(i, this->expected_num_workers_completed[i]); + } + if (is_sharded(buffer.buffer_layout())) { const bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1]; const auto& buffer_page_mapping = width_split ? buffer.get_buffer_page_mapping() : nullptr; @@ -1954,7 +2081,7 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin buffer, dst, this->manager, - this->expected_num_workers_completed, + expected_workers_completed, cores[core_id], bank_base_address, src_page_index, @@ -1988,7 +2115,7 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin buffer, dst, this->manager, - this->expected_num_workers_completed, + expected_workers_completed, src_page_index, pages_to_read); @@ -2053,6 +2180,13 @@ void HWCommandQueue::enqueue_write_buffer(Buffer& buffer, const void* src, bool uint32_t dst_page_index = 0; + // TODO: We can take in the meshes to wait on + std::vector> expected_workers_completed; + expected_workers_completed.reserve(this->device->num_sub_devices()); + for (uint32_t i = 0; i < this->device->num_sub_devices(); ++i) { + expected_workers_completed.emplace_back(i, this->expected_num_workers_completed[i]); + } + if (is_sharded(buffer.buffer_layout())) { const bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1]; const auto& buffer_page_mapping = width_split ? 
buffer.get_buffer_page_mapping() : nullptr; @@ -2120,7 +2254,7 @@ void HWCommandQueue::enqueue_write_buffer(Buffer& buffer, const void* src, bool src, this->manager, issue_wait, - this->expected_num_workers_completed, + expected_workers_completed, address, buffer_page_mapping, cores[core_id], @@ -2211,7 +2345,7 @@ void HWCommandQueue::enqueue_write_buffer(Buffer& buffer, const void* src, bool src, this->manager, issue_wait, - this->expected_num_workers_completed, + expected_workers_completed, bank_base_address, page_size_to_write, dst_page_index, @@ -2251,24 +2385,27 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) { } #endif + // TODO: We should determine the meshes used by the program. Hardcoded here for now + uint32_t sub_device_id = 0; + // Snapshot of expected workers from previous programs, used for dispatch_wait cmd generation. - uint32_t expected_workers_completed = this->manager.get_bypass_mode() ? this->trace_ctx->num_completion_worker_cores - : this->expected_num_workers_completed; + uint32_t expected_workers_completed = this->manager.get_bypass_mode() ? 
this->trace_ctx->descriptors[sub_device_id].num_completion_worker_cores + : this->expected_num_workers_completed[sub_device_id]; if (this->manager.get_bypass_mode()) { if (program.runs_on_noc_multicast_only_cores()) { - this->trace_ctx->num_traced_programs_needing_go_signal_multicast++; - this->trace_ctx->num_completion_worker_cores += device->num_worker_cores(); + this->trace_ctx->descriptors[sub_device_id].num_traced_programs_needing_go_signal_multicast++; + this->trace_ctx->descriptors[sub_device_id].num_completion_worker_cores += device->num_worker_cores(sub_device_id); } if (program.runs_on_noc_unicast_only_cores()) { - this->trace_ctx->num_traced_programs_needing_go_signal_unicast++; - this->trace_ctx->num_completion_worker_cores += device->num_eth_worker_cores(); + this->trace_ctx->descriptors[sub_device_id].num_traced_programs_needing_go_signal_unicast++; + this->trace_ctx->descriptors[sub_device_id].num_completion_worker_cores += device->num_eth_worker_cores(sub_device_id); } } else { if (program.runs_on_noc_multicast_only_cores()) { - this->expected_num_workers_completed += device->num_worker_cores(); + this->expected_num_workers_completed[sub_device_id] += device->num_worker_cores(sub_device_id); } if (program.runs_on_noc_unicast_only_cores()) { - this->expected_num_workers_completed += device->num_eth_worker_cores(); + this->expected_num_workers_completed[sub_device_id] += device->num_eth_worker_cores(sub_device_id); } } @@ -2281,14 +2418,15 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) { this->manager, expected_workers_completed, // The assembled program command will encode the location of the launch messages in the ring buffer - this->device->worker_launch_message_buffer_state.get_mcast_wptr(), - this->device->worker_launch_message_buffer_state.get_unicast_wptr()); + this->device->worker_launch_message_buffer_state[sub_device_id].get_mcast_wptr(), + 
this->device->worker_launch_message_buffer_state[sub_device_id].get_unicast_wptr(), + sub_device_id); // Update wptrs for tensix and eth launch message in the device class if (program.runs_on_noc_multicast_only_cores()) { - this->device->worker_launch_message_buffer_state.inc_mcast_wptr(1); + this->device->worker_launch_message_buffer_state[sub_device_id].inc_mcast_wptr(1); } if (program.runs_on_noc_unicast_only_cores()) { - this->device->worker_launch_message_buffer_state.inc_unicast_wptr(1); + this->device->worker_launch_message_buffer_state[sub_device_id].inc_unicast_wptr(1); } this->enqueue_command(command, blocking); @@ -2326,19 +2464,30 @@ void HWCommandQueue::enqueue_record_event(const std::shared_ptr& event, b event->device = this->device; event->ready = true; // what does this mean??? + // TODO: This should take in the meshes to wait on + uint32_t num_sub_devices = this->device->num_sub_devices(); + + std::vector> expected_workers_completed; + expected_workers_completed.reserve(num_sub_devices); + for (uint32_t i = 0; i < num_sub_devices; ++i) { + expected_workers_completed.emplace_back(i, this->expected_num_workers_completed[i]); + } + auto command = EnqueueRecordEventCommand( this->id, this->device, this->noc_index, this->manager, event->event_id, - this->expected_num_workers_completed, + expected_workers_completed, clear_count, true); this->enqueue_command(command, false); if (clear_count) { - this->expected_num_workers_completed = 0; + for (uint32_t i = 0; i < num_sub_devices; ++i) { + this->expected_num_workers_completed[i] = 0; + } } this->issued_completion_q_reads.push( std::make_shared(std::in_place_type, event->event_id)); @@ -2366,11 +2515,14 @@ void HWCommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) { this->enqueue_command(command, false); // Increment the expected worker cores counter due to trace programs completion - this->expected_num_workers_completed += trace_inst->desc->num_completion_worker_cores; - // After trace 
runs, the rdptr on each worker will be incremented by the number of programs in the trace - // Update the wptr on host to match state - this->device->worker_launch_message_buffer_state.set_mcast_wptr(trace_inst->desc->num_traced_programs_needing_go_signal_multicast); - this->device->worker_launch_message_buffer_state.set_unicast_wptr(trace_inst->desc->num_traced_programs_needing_go_signal_unicast); + for (const auto& [index, desc]: trace_inst->desc->descriptors) { + this->expected_num_workers_completed[index] += desc.num_completion_worker_cores; + // After trace runs, the rdptr on each worker will be incremented by the number of programs in the trace + // Update the wptr on host to match state + this->device->worker_launch_message_buffer_state[index].set_mcast_wptr(desc.num_traced_programs_needing_go_signal_multicast); + this->device->worker_launch_message_buffer_state[index].set_unicast_wptr(desc.num_traced_programs_needing_go_signal_unicast); + } + if (blocking) { this->finish(); @@ -2675,39 +2827,52 @@ void HWCommandQueue::record_begin(const uint32_t tid, std::shared_ptrdevice->id()); - uint32_t dispatch_message_addr = dispatch_constants::get( + uint32_t dispatch_message_base_addr = dispatch_constants::get( dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - if (this->device->distributed_dispatcher()) { - // wait on dispatch_s before issuing counter reset - command_sequence.add_dispatch_wait(false, dispatch_message_addr, this->expected_num_workers_completed, true, false, true, 1); + + // Currently Trace will track all sub_devices + // Potentially support tracking only used sub_devices in the future + uint32_t num_sub_devices = this->device->num_sub_devices(); + for (uint32_t i = 0; i < num_sub_devices; ++i) { + uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); + if (this->device->distributed_dispatcher()) { + // wait on 
dispatch_s before issuing counter reset + command_sequence.add_dispatch_wait(false, dispatch_message_addr, this->expected_num_workers_completed[i], true, false, true, 1); + } + // dispatch_d waits for latest non-zero counter from dispatch_s and then clears its local counter + command_sequence.add_dispatch_wait(false, dispatch_message_addr, this->expected_num_workers_completed[i], true); } - // dispatch_d waits for latest non-zero counter from dispatch_s and then clears its local counter - command_sequence.add_dispatch_wait(false, dispatch_message_addr, this->expected_num_workers_completed, true); this->manager.issue_queue_push_back(cmd_sequence_sizeB, this->id); this->manager.fetch_queue_reserve_back(this->id); this->manager.fetch_queue_write(cmd_sequence_sizeB, this->id); - this->expected_num_workers_completed = 0; + std::fill(this->expected_num_workers_completed.begin(), this->expected_num_workers_completed.begin() + num_sub_devices, 0); // Record commands using bypass mode this->tid = tid; this->trace_ctx = ctx; // Record original value of launch msg wptr - this->multicast_cores_launch_message_wptr_reset = this->device->worker_launch_message_buffer_state.get_mcast_wptr(); - this->unicast_cores_launch_message_wptr_reset = this->device->worker_launch_message_buffer_state.get_unicast_wptr(); - // Set launch msg wptr to 0. Every time trace runs on device, it will ensure that the workers - // reset their rptr to be in sync with device. - this->device->worker_launch_message_buffer_state.reset(); + for (uint32_t i = 0; i < num_sub_devices; ++i) { + this->multicast_cores_launch_message_wptr_reset = this->device->worker_launch_message_buffer_state[i].get_mcast_wptr(); + this->unicast_cores_launch_message_wptr_reset = this->device->worker_launch_message_buffer_state[i].get_unicast_wptr(); + // Set launch msg wptr to 0. Every time trace runs on device, it will ensure that the workers + // reset their rptr to be in sync with device. 
+ this->device->worker_launch_message_buffer_state[i].reset(); + } this->manager.set_bypass_mode(true, true); // start } void HWCommandQueue::record_end() { this->tid = std::nullopt; this->trace_ctx = nullptr; + // Currently Trace will track all sub_devices + uint32_t num_sub_devices = this->device->num_sub_devices(); // Reset the launch msg wptrs to their original value, so device can run programs after a trace // was captured. This is needed since trace capture modifies the wptr state on host, even though device // doesn't run any programs. - this->device->worker_launch_message_buffer_state.set_mcast_wptr(this->multicast_cores_launch_message_wptr_reset); - this->device->worker_launch_message_buffer_state.set_unicast_wptr(this->unicast_cores_launch_message_wptr_reset); + for (uint32_t i = 0; i < num_sub_devices; ++i) { + this->device->worker_launch_message_buffer_state[i].set_mcast_wptr(this->multicast_cores_launch_message_wptr_reset); + this->device->worker_launch_message_buffer_state[i].set_unicast_wptr(this->unicast_cores_launch_message_wptr_reset); + } this->manager.set_bypass_mode(false, false); // stop } diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp index ba0316502fc..c5e55297506 100644 --- a/tt_metal/impl/dispatch/command_queue.hpp +++ b/tt_metal/impl/dispatch/command_queue.hpp @@ -73,7 +73,7 @@ class EnqueueReadBufferCommand : public Command { Device* device; uint32_t command_queue_id; NOC noc_index; - uint32_t expected_num_workers_completed; + const std::vector>& expected_num_workers_completed; uint32_t src_page_index; uint32_t pages_to_read; @@ -86,7 +86,7 @@ class EnqueueReadBufferCommand : public Command { Buffer& buffer, void* dst, SystemMemoryManager& manager, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, uint32_t src_page_index = 0, std::optional pages_to_read = std::nullopt); @@ -109,7 +109,7 @@ class EnqueueReadInterleavedBufferCommand : public 
EnqueueReadBufferCommand { Buffer& buffer, void* dst, SystemMemoryManager& manager, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, uint32_t src_page_index = 0, std::optional pages_to_read = std::nullopt) : EnqueueReadBufferCommand( @@ -138,7 +138,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand { Buffer& buffer, void* dst, SystemMemoryManager& manager, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, const CoreCoord& core, uint32_t bank_base_address, uint32_t src_page_index = 0, @@ -173,7 +173,7 @@ class EnqueueWriteBufferCommand : public Command { NOC noc_index; const void* src; const Buffer& buffer; - uint32_t expected_num_workers_completed; + const std::vector>& expected_num_workers_completed; uint32_t bank_base_address; uint32_t padded_page_size; uint32_t dst_page_index; @@ -189,7 +189,7 @@ class EnqueueWriteBufferCommand : public Command { const void* src, SystemMemoryManager& manager, bool issue_wait, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, uint32_t bank_base_address, uint32_t padded_page_size, uint32_t dst_page_index = 0, @@ -216,7 +216,7 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand { const void* src, SystemMemoryManager& manager, bool issue_wait, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, uint32_t bank_base_address, uint32_t padded_page_size, uint32_t dst_page_index = 0, @@ -255,7 +255,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand { const void* src, SystemMemoryManager& manager, bool issue_wait, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, uint32_t bank_base_address, const std::shared_ptr& buffer_page_mapping, const CoreCoord& core, @@ -295,6 +295,8 @@ class EnqueueProgramCommand : public Command { 
uint32_t dispatch_message_addr; uint32_t multicast_cores_launch_message_wptr = 0; uint32_t unicast_cores_launch_message_wptr = 0; + // TODO: There will be multiple ids once programs support spanning multiple sub_devices + uint32_t sub_device_id = 0; public: @@ -307,7 +309,8 @@ class EnqueueProgramCommand : public Command { SystemMemoryManager& manager, uint32_t expected_num_workers_completed, uint32_t multicast_cores_launch_message_wptr, - uint32_t unicast_cores_launch_message_wptr); + uint32_t unicast_cores_launch_message_wptr, + uint32_t sub_device_id); void assemble_preamble_commands(ProgramCommandSequence& program_command_sequence, std::vector& kernel_config_addrs); void assemble_stall_commands(ProgramCommandSequence& program_command_sequence, bool prefetch_stall); @@ -331,7 +334,7 @@ class EnqueueRecordEventCommand : public Command { NOC noc_index; SystemMemoryManager& manager; uint32_t event_id; - uint32_t expected_num_workers_completed; + const std::vector>& expected_num_workers_completed; bool clear_count; bool write_barrier; @@ -342,7 +345,7 @@ class EnqueueRecordEventCommand : public Command { NOC noc_index, SystemMemoryManager& manager, uint32_t event_id, - uint32_t expected_num_workers_completed, + const std::vector>& expected_num_workers_completed, bool clear_count = false, bool write_barrier = true); @@ -384,7 +387,7 @@ class EnqueueTraceCommand : public Command { Device* device; SystemMemoryManager& manager; std::shared_ptr& desc; - uint32_t& expected_num_workers_completed; + std::array& expected_num_workers_completed; bool clear_count; NOC noc_index; CoreCoord dispatch_core; @@ -395,7 +398,7 @@ class EnqueueTraceCommand : public Command { SystemMemoryManager& manager, std::shared_ptr& desc, Buffer& buffer, - uint32_t& expected_num_workers_completed, + std::array& expected_num_workers_completed, NOC noc_index, CoreCoord dispatch_core); @@ -496,6 +499,9 @@ class HWCommandQueue { void record_begin(const uint32_t tid, std::shared_ptr ctx); void 
record_end(); void set_unicast_only_cores_on_dispatch(const std::vector& unicast_only_noc_encodings); + void set_num_worker_sems_on_dispatch(uint32_t num_worker_sems); + void reset_worker_state(bool reset_launch_msg_state); + private: uint32_t id; uint32_t size_B; @@ -506,7 +512,7 @@ class HWCommandQueue { // Expected value of DISPATCH_MESSAGE_ADDR in dispatch core L1 // Value in L1 incremented by worker to signal completion to dispatch. Value on host is set on each enqueue program // call - uint32_t expected_num_workers_completed; + std::array expected_num_workers_completed; volatile bool exit_condition; volatile bool dprint_server_hang = false; diff --git a/tt_metal/impl/dispatch/command_queue_interface.hpp b/tt_metal/impl/dispatch/command_queue_interface.hpp index bf8ac017030..872453f3fef 100644 --- a/tt_metal/impl/dispatch/command_queue_interface.hpp +++ b/tt_metal/impl/dispatch/command_queue_interface.hpp @@ -34,8 +34,9 @@ enum class CommandQueueDeviceAddrType : uint8_t { // Max of 2 CQs. 
COMPLETION_Q*_LAST_EVENT_PTR track the last completed event in the respective CQs COMPLETION_Q0_LAST_EVENT = 4, COMPLETION_Q1_LAST_EVENT = 5, - DISPATCH_MESSAGE = 6, - UNRESERVED = 7 + DISPATCH_S_SYNC_SEM = 6, + DISPATCH_MESSAGE = 7, + UNRESERVED = 8 }; enum class CommandQueueHostAddrType : uint8_t { @@ -64,8 +65,12 @@ struct dispatch_constants { return *inst; } + using prefetch_q_entry_type = uint16_t; + static constexpr uint8_t MAX_NUM_HW_CQS = 2; - typedef uint16_t prefetch_q_entry_type; + static constexpr uint32_t DISPATCH_MESSAGE_ENTRIES = 16; + static constexpr uint32_t DISPATCH_MESSAGES_MAX_OFFSET = std::numeric_limits::max(); + static constexpr uint32_t PREFETCH_Q_LOG_MINSIZE = 4; static constexpr uint32_t LOG_TRANSFER_PAGE_SIZE = 12; @@ -128,6 +133,12 @@ struct dispatch_constants { return tt::utils::underlying_type(host_addr) * hal.get_alignment(HalMemType::HOST); } + uint32_t get_dispatch_message_offset(uint32_t index) const { + TT_ASSERT(index < DISPATCH_MESSAGE_ENTRIES); + uint32_t offset = index * hal.get_alignment(HalMemType::L1); + return offset; + } + private: dispatch_constants(const CoreType &core_type, const uint32_t num_hw_cqs) { TT_ASSERT(core_type == CoreType::WORKER or core_type == CoreType::ETH); @@ -160,6 +171,7 @@ struct dispatch_constants { TT_ASSERT(cmddat_q_size_ >= 2 * max_prefetch_command_size_); TT_ASSERT(scratch_db_size_ % 2 == 0); TT_ASSERT((dispatch_buffer_block_size & (dispatch_buffer_block_size - 1)) == 0); + TT_ASSERT(DISPATCH_MESSAGE_ENTRIES <= DISPATCH_MESSAGES_MAX_OFFSET / L1_ALIGNMENT + 1, "Number of dispatch message entries exceeds max representable offset"); uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST); uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); @@ -171,8 +183,10 @@ struct dispatch_constants { device_cq_addr_sizes_[dev_addr_idx] = sizeof(uint32_t); } else if (dev_addr_type == CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD) { device_cq_addr_sizes_[dev_addr_idx] = L1_ALIGNMENT - 
sizeof(uint32_t); + } else if (dev_addr_type == CommandQueueDeviceAddrType::DISPATCH_S_SYNC_SEM) { + device_cq_addr_sizes_[dev_addr_idx] = DISPATCH_MESSAGE_ENTRIES * L1_ALIGNMENT; } else if (dev_addr_type == CommandQueueDeviceAddrType::DISPATCH_MESSAGE) { - device_cq_addr_sizes_[dev_addr_idx] = 32; // Should this be 2x L1_ALIGNMENT? + device_cq_addr_sizes_[dev_addr_idx] = DISPATCH_MESSAGE_ENTRIES * L1_ALIGNMENT; } else { device_cq_addr_sizes_[dev_addr_idx] = L1_ALIGNMENT; } @@ -449,7 +463,7 @@ class SystemMemoryManager { std::vector bypass_buffer; uint32_t bypass_buffer_write_offset; - WorkerConfigBufferMgr config_buffer_mgr; + std::array config_buffer_mgr; public: SystemMemoryManager(chip_id_t device_id, uint8_t num_hw_cqs) : @@ -531,11 +545,7 @@ class SystemMemoryManager { std::vector temp_mutexes(num_hw_cqs); cq_to_event_locks.swap(temp_mutexes); - for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { - this->config_buffer_mgr.init_add_core( - hal.get_dev_addr(hal.get_programmable_core_type(index), HalL1MemAddrType::KERNEL_CONFIG), - hal.get_dev_size(hal.get_programmable_core_type(index), HalL1MemAddrType::KERNEL_CONFIG)); - } + this->reset_config_buffer_mgr(config_buffer_mgr.size()); } uint32_t get_next_event(const uint8_t cq_id) { @@ -845,7 +855,18 @@ class SystemMemoryManager { this->prefetch_q_dev_ptrs[cq_id] += sizeof(dispatch_constants::prefetch_q_entry_type); } - WorkerConfigBufferMgr& get_config_buffer_mgr() { return config_buffer_mgr; } + WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index) { return config_buffer_mgr[index]; } + + void reset_config_buffer_mgr(const uint32_t max_index) { + for (uint32_t cfg_index = 0; cfg_index < max_index; cfg_index++) { + this->config_buffer_mgr[cfg_index] = WorkerConfigBufferMgr(); + for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { + this->config_buffer_mgr[cfg_index].init_add_core( + hal.get_dev_addr(hal.get_programmable_core_type(index), 
HalL1MemAddrType::KERNEL_CONFIG), + hal.get_dev_size(hal.get_programmable_core_type(index), HalL1MemAddrType::KERNEL_CONFIG)); + } + } + } }; diff --git a/tt_metal/impl/dispatch/cq_commands.hpp b/tt_metal/impl/dispatch/cq_commands.hpp index a2a0399fb5e..f7848a83c86 100644 --- a/tt_metal/impl/dispatch/cq_commands.hpp +++ b/tt_metal/impl/dispatch/cq_commands.hpp @@ -51,6 +51,7 @@ enum CQDispatchCmdId : uint8_t { CQ_DISPATCH_CMD_SEND_GO_SIGNAL = 15, CQ_DISPATCH_NOTIFY_SLAVE_GO_SIGNAL = 16, CQ_DISPATCH_SET_UNICAST_ONLY_CORES = 17, + CQ_DISPATCH_SET_NUM_WORKER_SEMS = 18, CQ_DISPATCH_CMD_MAX_COUNT, // for checking legal IDs }; @@ -268,10 +269,16 @@ struct CQDispatchGoSignalMcastCmd { struct CQDispatchNotifySlaveGoSignalCmd { // sends a counter update to dispatch_s when it sees this cmd uint8_t wait; // if true, issue a write barrier before sending signal to dispatch_s - uint16_t pad2; + uint16_t index_bitmask; uint32_t pad3; } __attribute__((packed)); +struct CQDispatchSetNumWorkerSemsCmd { + uint8_t pad1; + uint16_t pad2; + uint32_t num_worker_sems; +} __attribute__ ((packed)); + struct CQDispatchCmd { CQDispatchBaseCmd base; @@ -288,6 +295,7 @@ struct CQDispatchCmd { CQDispatchGoSignalMcastCmd mcast; CQDispatchSetUnicastOnlyCoresCmd set_unicast_only_cores; CQDispatchNotifySlaveGoSignalCmd notify_dispatch_s_go_signal; + CQDispatchSetNumWorkerSemsCmd set_num_worker_sems; } __attribute__((packed)); }; diff --git a/tt_metal/impl/dispatch/debug_tools.cpp b/tt_metal/impl/dispatch/debug_tools.cpp index ea5141443b6..04003fd41ab 100644 --- a/tt_metal/impl/dispatch/debug_tools.cpp +++ b/tt_metal/impl/dispatch/debug_tools.cpp @@ -178,6 +178,14 @@ uint32_t dump_dispatch_cmd(CQDispatchCmd *cmd, uint32_t cmd_addr, std::ofstream val(cmd->debug.stride)); break; case CQ_DISPATCH_CMD_DELAY: cq_file << fmt::format(" (delay={})", val(cmd->delay.delay)); break; + case CQ_DISPATCH_SET_UNICAST_ONLY_CORES: + cq_file << fmt::format( + " (num_unicast_only_cores={})", 
val(cmd->set_unicast_only_cores.num_unicast_only_cores)); + break; + case CQ_DISPATCH_SET_NUM_WORKER_SEMS: + cq_file << fmt::format( + " (num_worker_sems={})", val(cmd->set_num_worker_sems.num_worker_sems)); + break; // These commands don't have any additional data to dump. case CQ_DISPATCH_CMD_ILLEGAL: break; case CQ_DISPATCH_CMD_GO: break; @@ -185,7 +193,6 @@ uint32_t dump_dispatch_cmd(CQDispatchCmd *cmd, uint32_t cmd_addr, std::ofstream case CQ_DISPATCH_CMD_EXEC_BUF_END: break; case CQ_DISPATCH_CMD_SEND_GO_SIGNAL: break; case CQ_DISPATCH_NOTIFY_SLAVE_GO_SIGNAL: break; - case CQ_DISPATCH_SET_UNICAST_ONLY_CORES: break; case CQ_DISPATCH_CMD_TERMINATE: break; case CQ_DISPATCH_CMD_SET_WRITE_OFFSET: break; default: TT_THROW("Unrecognized dispatch command: {}", cmd_id); break; diff --git a/tt_metal/impl/dispatch/device_command.hpp b/tt_metal/impl/dispatch/device_command.hpp index c665b63f99c..c7b39abeb7a 100644 --- a/tt_metal/impl/dispatch/device_command.hpp +++ b/tt_metal/impl/dispatch/device_command.hpp @@ -273,13 +273,14 @@ class DeviceCommand { this->cmd_write_offsetB = align(this->cmd_write_offsetB, PCIE_ALIGNMENT); } - void add_notify_dispatch_s_go_signal_cmd(uint8_t wait) { + void add_notify_dispatch_s_go_signal_cmd(uint8_t wait, uint16_t index_bitmask) { // Command to have dispatch_master send a notification to dispatch_slave this->add_prefetch_relay_inline(true, sizeof(CQDispatchCmd), DispatcherSelect::DISPATCH_MASTER); auto initialize_sem_update_cmd = [&](CQDispatchCmd *sem_update_cmd) { *sem_update_cmd = {}; sem_update_cmd->base.cmd_id = CQ_DISPATCH_NOTIFY_SLAVE_GO_SIGNAL; sem_update_cmd->notify_dispatch_s_go_signal.wait = wait; + sem_update_cmd->notify_dispatch_s_go_signal.index_bitmask = index_bitmask; }; CQDispatchCmd *dispatch_s_sem_update_dst = this->reserve_space(sizeof(CQDispatchCmd)); if constexpr (hugepage_write) { @@ -398,6 +399,23 @@ class DeviceCommand { this->add_data(noc_encodings.data(), data_sizeB, increment_sizeB); } + void 
add_dispatch_set_num_worker_sems(const uint32_t num_worker_sems, DispatcherSelect dispatcher_type) { + this->add_prefetch_relay_inline(true, sizeof(CQDispatchCmd), dispatcher_type); + auto initialize_set_num_worker_sems_cmd = [&] (CQDispatchCmd *set_num_worker_sems_cmd) { + set_num_worker_sems_cmd->base.cmd_id = CQ_DISPATCH_SET_NUM_WORKER_SEMS; + set_num_worker_sems_cmd->set_num_worker_sems.num_worker_sems = num_worker_sems; + }; + CQDispatchCmd *set_num_worker_sems_cmd_dst = this->reserve_space(sizeof(CQDispatchCmd)); + if constexpr (hugepage_write) { + alignas(MEMCPY_ALIGNMENT) CQDispatchCmd set_num_worker_sems_cmd; + initialize_set_num_worker_sems_cmd(&set_num_worker_sems_cmd); + this->memcpy(set_num_worker_sems_cmd_dst, &set_num_worker_sems_cmd, sizeof(CQDispatchCmd)); + } else { + initialize_set_num_worker_sems_cmd(set_num_worker_sems_cmd_dst); + } + this->cmd_write_offsetB = align(this->cmd_write_offsetB, this->pcie_alignment); + } + void add_dispatch_set_write_offsets(uint32_t write_offset0, uint32_t write_offset1, uint32_t write_offset2) { this->add_prefetch_relay_inline(true, sizeof(CQDispatchCmd)); auto initialize_write_offset_cmd = [&](CQDispatchCmd *write_offset_cmd) { diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp index 384a1793a7d..3e8ce0d6b48 100644 --- a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp @@ -41,16 +41,17 @@ constexpr uint32_t prefetch_h_noc_xy = get_compile_time_arg_val(16); constexpr uint32_t prefetch_h_local_downstream_sem_addr = get_compile_time_arg_val(17); constexpr uint32_t prefetch_h_max_credits = get_compile_time_arg_val(18); constexpr uint32_t packed_write_max_unicast_sub_cmds = get_compile_time_arg_val(19); // Number of cores in compute grid -constexpr uint32_t dispatch_s_sem_id = get_compile_time_arg_val(20); -constexpr uint32_t worker_mcast_grid = get_compile_time_arg_val(21); -constexpr uint32_t 
mcast_go_signal_addr = get_compile_time_arg_val(22); -constexpr uint32_t unicast_go_signal_addr = get_compile_time_arg_val(23); -constexpr uint32_t distributed_dispatcher = get_compile_time_arg_val(24); -constexpr uint32_t host_completion_q_wr_ptr = get_compile_time_arg_val(25); -constexpr uint32_t dev_completion_q_wr_ptr = get_compile_time_arg_val(26); -constexpr uint32_t dev_completion_q_rd_ptr = get_compile_time_arg_val(27); -constexpr uint32_t is_d_variant = get_compile_time_arg_val(28); -constexpr uint32_t is_h_variant = get_compile_time_arg_val(29); +constexpr uint32_t dispatch_s_sync_sem_base_addr = get_compile_time_arg_val(20); +constexpr uint32_t max_num_worker_sems = get_compile_time_arg_val(21); // maximum number of worker semaphores +constexpr uint32_t worker_mcast_grid = get_compile_time_arg_val(22); +constexpr uint32_t mcast_go_signal_addr = get_compile_time_arg_val(23); +constexpr uint32_t unicast_go_signal_addr = get_compile_time_arg_val(24); +constexpr uint32_t distributed_dispatcher = get_compile_time_arg_val(25); +constexpr uint32_t host_completion_q_wr_ptr = get_compile_time_arg_val(26); +constexpr uint32_t dev_completion_q_wr_ptr = get_compile_time_arg_val(27); +constexpr uint32_t dev_completion_q_rd_ptr = get_compile_time_arg_val(28); +constexpr uint32_t is_d_variant = get_compile_time_arg_val(29); +constexpr uint32_t is_h_variant = get_compile_time_arg_val(30); constexpr uint8_t upstream_noc_index = UPSTREAM_NOC_INDEX; constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y)); @@ -858,14 +859,22 @@ void process_notify_dispatch_s_go_signal_cmd() { DPRINT << " DISPATCH_S_NOTIFY BARRIER\n"; noc_async_write_barrier(); } - if constexpr (distributed_dispatcher) { - uint64_t dispatch_s_notify_addr = get_noc_addr_helper(dispatch_s_noc_xy, get_semaphore(dispatch_s_sem_id)); - static uint32_t num_go_signals_safe_to_send = 1; - noc_inline_dw_write(dispatch_s_notify_addr, num_go_signals_safe_to_send); - 
num_go_signals_safe_to_send++; - } else { - tt_l1_ptr uint32_t* notify_ptr = (uint32_t tt_l1_ptr*)(get_semaphore(dispatch_s_sem_id)); - *notify_ptr = (*notify_ptr) + 1; + uint16_t index_bitmask = cmd->notify_dispatch_s_go_signal.index_bitmask; + + while(index_bitmask != 0) { + uint32_t set_index = __builtin_ctz(index_bitmask); + uint32_t dispatch_s_sync_sem_addr = dispatch_s_sync_sem_base_addr + set_index * L1_ALIGNMENT; + if constexpr (distributed_dispatcher) { + static uint32_t num_go_signals_safe_to_send[max_num_worker_sems] = {0}; + uint64_t dispatch_s_notify_addr = get_noc_addr_helper(dispatch_s_noc_xy, dispatch_s_sync_sem_addr); + num_go_signals_safe_to_send[set_index]++; + noc_inline_dw_write(dispatch_s_notify_addr, num_go_signals_safe_to_send[set_index]); + } else { + tt_l1_ptr uint32_t* notify_ptr = (uint32_t tt_l1_ptr*)(dispatch_s_sync_sem_addr); + *notify_ptr = (*notify_ptr) + 1; + } + // Unset the bit + index_bitmask &= index_bitmask - 1; } cmd_ptr += sizeof(CQDispatchCmd); } @@ -974,6 +983,12 @@ static inline bool process_cmd_d(uint32_t &cmd_ptr, uint32_t* l1_cache, uint32_t process_set_unicast_only_cores(); break; + case CQ_DISPATCH_SET_NUM_WORKER_SEMS: + DPRINT << "cmd_set_num_worker_sems" << ENDL(); + // This command is only used by dispatch_s + cmd_ptr += sizeof(CQDispatchCmd); + break; + case CQ_DISPATCH_CMD_SET_WRITE_OFFSET: DPRINT << "write offset: " << cmd->set_write_offset.offset0 << " " << diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch_slave.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch_slave.cpp index 3ba5a9454fd..198fd833a25 100644 --- a/tt_metal/impl/dispatch/kernels/cq_dispatch_slave.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_dispatch_slave.cpp @@ -30,13 +30,14 @@ constexpr uint32_t cb_log_page_size = get_compile_time_arg_val(1); constexpr uint32_t cb_size = get_compile_time_arg_val(2); constexpr uint32_t my_dispatch_cb_sem_id = get_compile_time_arg_val(3); constexpr uint32_t upstream_dispatch_cb_sem_id = 
get_compile_time_arg_val(4); -constexpr uint32_t dispatch_s_sync_sem_id = get_compile_time_arg_val(5); +constexpr uint32_t dispatch_s_sync_sem_base_addr = get_compile_time_arg_val(5); constexpr uint32_t worker_mcast_grid = get_compile_time_arg_val(6); constexpr uint32_t num_worker_cores_to_mcast = get_compile_time_arg_val(7); constexpr uint32_t mcast_go_signal_addr = get_compile_time_arg_val(8); constexpr uint32_t unicast_go_signal_addr = get_compile_time_arg_val(9); constexpr uint32_t distributed_dispatcher = get_compile_time_arg_val(10); // dispatch_s and dispatch_d running on different cores -constexpr uint32_t worker_sem_addr = get_compile_time_arg_val(11); // workers update the semaphore at this location to signal completion +constexpr uint32_t worker_sem_base_addr = get_compile_time_arg_val(11); // workers update the semaphore at this location to signal completion +constexpr uint32_t max_num_worker_sems = get_compile_time_arg_val(12); // maximum number of worker semaphores constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y)); constexpr uint32_t dispatch_d_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y)); @@ -56,7 +57,9 @@ static int num_unicast_cores = -1; // When dispatch_d and dispatch_s run on separate cores, dispatch_s gets the go signal update from workers. // dispatch_s is responsible for sending the latest worker completion count to dispatch_d. // To minimize the number of writes from dispatch_s to dispatch_d, locally track dispatch_d's copy. 
-static uint32_t worker_count_update_for_dispatch_d = 0; +static uint32_t worker_count_update_for_dispatch_d[max_num_worker_sems] = {0}; + +static uint32_t num_worker_sems = 1; FORCE_INLINE void dispatch_s_wr_reg_cmd_buf_init() { @@ -102,7 +105,8 @@ void dispatch_s_noc_inline_dw_write(uint64_t addr, uint32_t val, uint8_t noc_id, FORCE_INLINE void wait_for_workers(volatile CQDispatchCmd tt_l1_ptr *cmd) { - volatile tt_l1_ptr uint32_t* worker_sem = reinterpret_cast(worker_sem_addr); + uint8_t dispatch_message_offset = *((uint8_t *)&cmd->mcast.go_signal + offsetof(go_msg_t, dispatch_message_offset)); + volatile tt_l1_ptr uint32_t* worker_sem = reinterpret_cast(worker_sem_base_addr + dispatch_message_offset); while (wrap_gt(cmd->mcast.wait_count, *worker_sem)); } @@ -110,12 +114,18 @@ template FORCE_INLINE void update_worker_completion_count_on_dispatch_d() { if constexpr(distributed_dispatcher) { - uint32_t num_workers_signalling_completion = *reinterpret_cast(worker_sem_addr); - if (num_workers_signalling_completion != worker_count_update_for_dispatch_d) { - worker_count_update_for_dispatch_d = num_workers_signalling_completion; - uint64_t dispatch_d_dst = get_noc_addr_helper(dispatch_d_noc_xy, worker_sem_addr); - dispatch_s_noc_inline_dw_write(dispatch_d_dst, num_workers_signalling_completion, my_noc_index); - if constexpr (flush_write) { + bool write = false; + for (uint32_t i = 0, worker_sem_addr = worker_sem_base_addr; i < num_worker_sems; ++i, worker_sem_addr += L1_ALIGNMENT) { + uint32_t num_workers_signalling_completion = *reinterpret_cast(worker_sem_addr); + if (num_workers_signalling_completion != worker_count_update_for_dispatch_d[i]) { + worker_count_update_for_dispatch_d[i] = num_workers_signalling_completion; + uint64_t dispatch_d_dst = get_noc_addr_helper(dispatch_d_noc_xy, worker_sem_addr); + dispatch_s_noc_inline_dw_write(dispatch_d_dst, num_workers_signalling_completion, my_noc_index); + write = true; + } + } + if constexpr (flush_write) { + if 
(write) { noc_async_writes_flushed(); } } @@ -151,13 +161,7 @@ void process_go_signal_mcast_cmd() { volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; // Get semaphore that will be update by dispatch_d, signalling that it's safe to send a go signal volatile tt_l1_ptr uint32_t* sync_sem_addr = - reinterpret_cast(get_semaphore(dispatch_s_sync_sem_id)); - // The location of the go signal embedded in the command does not meet NOC alignment requirements. - // cmd_ptr is guaranteed to meet the alignment requirements, since it is written to by prefetcher over NOC. - // Copy the go signal from an unaligned location to an aligned (cmd_ptr) location. This is safe as long as we - // can guarantee that copying the go signal does not corrupt any other command fields, which is true (see CQDispatchGoSignalMcastCmd). - volatile uint32_t tt_l1_ptr* aligned_go_signal_storage = (volatile uint32_t tt_l1_ptr*)cmd_ptr; - *aligned_go_signal_storage = cmd->mcast.go_signal; + reinterpret_cast(dispatch_s_sync_sem_base_addr + (cmd->mcast.wait_addr - worker_sem_base_addr)); // Wait for notification from dispatch_d, signalling that it's safe to send the go signal while (wrap_ge(num_mcasts_sent, *sync_sem_addr)) { @@ -167,6 +171,14 @@ void process_go_signal_mcast_cmd() { num_mcasts_sent++; // Go signal sent -> update counter // Wait until workers have completed before sending go signal wait_for_workers(cmd); + + // The location of the go signal embedded in the command does not meet NOC alignment requirements. + // cmd_ptr is guaranteed to meet the alignment requirements, since it is written to by prefetcher over NOC. + // Copy the go signal from an unaligned location to an aligned (cmd_ptr) location. This is safe as long as we + // can guarantee that copying the go signal does not corrupt any other command fields, which is true (see CQDispatchGoSignalMcastCmd). 
+ volatile uint32_t tt_l1_ptr* aligned_go_signal_storage = (volatile uint32_t tt_l1_ptr*)cmd_ptr; + *aligned_go_signal_storage = cmd->mcast.go_signal; + // send go signal update here if (cmd->mcast.mcast_flag & GoSignalMcastSettings::SEND_MCAST) { uint64_t dst = get_noc_addr_helper(worker_mcast_grid, mcast_go_signal_addr); @@ -200,10 +212,15 @@ void set_go_signal_unicast_only_cores() { FORCE_INLINE void process_dispatch_s_wait_cmd() { + static constexpr uint32_t worker_sem_max_addr = worker_sem_base_addr + (max_num_worker_sems - 1) * L1_ALIGNMENT; + volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; // Limited Usage of Wait CMD: dispatch_s should get a wait command only if it's not on the // same core as dispatch_d and is used to clear the worker count - ASSERT(cmd->wait.clear_count && (cmd->wait.addr == worker_sem_addr) && distributed_dispatcher); + ASSERT(cmd->wait.clear_count && distributed_dispatcher); + uint32_t worker_sem_addr = cmd->wait.addr; + ASSERT(worker_sem_addr >= worker_sem_base_addr && worker_sem_addr <= worker_sem_max_addr); + uint32_t index = (worker_sem_addr - worker_sem_base_addr) / L1_ALIGNMENT; volatile tt_l1_ptr uint32_t* worker_sem = reinterpret_cast(worker_sem_addr); // Wait for workers to complete while (wrap_gt(cmd->wait.count, *worker_sem)); @@ -211,7 +228,15 @@ void process_dispatch_s_wait_cmd() { // dispatch_d will clear it's own counter update_worker_completion_count_on_dispatch_d(); *worker_sem = 0; - worker_count_update_for_dispatch_d = 0; // Local worker count update for dispatch_d should reflect state of worker semaphore on dispatch_s + worker_count_update_for_dispatch_d[index] = 0; // Local worker count update for dispatch_d should reflect state of worker semaphore on dispatch_s + cmd_ptr += sizeof(CQDispatchCmd); +} + +FORCE_INLINE +void set_num_worker_sems() { + volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; + num_worker_sems = 
cmd->set_num_worker_sems.num_worker_sems; + ASSERT(num_worker_sems <= max_num_worker_sems); cmd_ptr += sizeof(CQDispatchCmd); } @@ -234,6 +259,9 @@ void kernel_main() { case CQ_DISPATCH_SET_UNICAST_ONLY_CORES: set_go_signal_unicast_only_cores(); break; + case CQ_DISPATCH_SET_NUM_WORKER_SEMS: + set_num_worker_sems(); + break; case CQ_DISPATCH_CMD_WAIT: process_dispatch_s_wait_cmd(); break; diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp index acf40f655ea..cf5dada487d 100644 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp @@ -13,7 +13,7 @@ #include "tt_metal/impl/dispatch/kernels/cq_common.hpp" #include "debug/dprint.h" -typedef uint16_t prefetch_q_entry_type; +using prefetch_q_entry_type = uint16_t; constexpr uint32_t downstream_cb_base = get_compile_time_arg_val(0); constexpr uint32_t downstream_cb_log_page_size = get_compile_time_arg_val(1); diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index a35923aa65e..5d93965b72f 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -1432,9 +1432,10 @@ uint32_t detail::Program_::get_sem_base_addr(Device *device, CoreCoord logical_c CoreCoord phys_core = device->physical_core_from_logical_core(logical_core, core_type); HalProgrammableCoreType programmable_core_type = device->get_programmable_core_type(phys_core); uint32_t index = hal.get_programmable_core_type_index(programmable_core_type); - + // TODO: We should determine the meshes used by the program. Hardcoded here for now + uint32_t sub_device_id = 0; uint32_t base_addr = device->using_fast_dispatch ? 
- device->sysmem_manager().get_config_buffer_mgr().get_last_slot_addr(programmable_core_type) : + device->sysmem_manager().get_config_buffer_mgr(sub_device_id).get_last_slot_addr(programmable_core_type) : hal.get_dev_addr(programmable_core_type, HalL1MemAddrType::KERNEL_CONFIG); return base_addr + this->program_configs_[index].sem_offset; @@ -1449,9 +1450,10 @@ uint32_t detail::Program_::get_cb_base_addr(Device *device, CoreCoord logical_co CoreCoord phys_core = device->physical_core_from_logical_core(logical_core, core_type); HalProgrammableCoreType programmable_core_type = device->get_programmable_core_type(phys_core); uint32_t index = hal.get_programmable_core_type_index(programmable_core_type); - + // TODO: We should determine the meshes used by the program. Hardcoded here for now + uint32_t sub_device_id = 0; uint32_t base_addr = device->using_fast_dispatch ? - device->sysmem_manager().get_config_buffer_mgr().get_last_slot_addr(programmable_core_type) : + device->sysmem_manager().get_config_buffer_mgr(sub_device_id).get_last_slot_addr(programmable_core_type) : hal.get_dev_addr(programmable_core_type, HalL1MemAddrType::KERNEL_CONFIG); return base_addr + this->program_configs_[index].cb_offset; diff --git a/tt_metal/impl/trace/trace_buffer.hpp b/tt_metal/impl/trace/trace_buffer.hpp index fce464a3b8c..4338a7d3c78 100644 --- a/tt_metal/impl/trace/trace_buffer.hpp +++ b/tt_metal/impl/trace/trace_buffer.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -16,9 +17,12 @@ namespace tt::tt_metal { namespace detail { struct TraceDescriptor { - uint32_t num_completion_worker_cores = 0; - uint32_t num_traced_programs_needing_go_signal_multicast = 0; - uint32_t num_traced_programs_needing_go_signal_unicast = 0; + struct Descriptor { + uint32_t num_completion_worker_cores = 0; + uint32_t num_traced_programs_needing_go_signal_multicast = 0; + uint32_t num_traced_programs_needing_go_signal_unicast = 0; + }; + std::unordered_map descriptors; 
std::vector data; }; } // namespace detail