From aa949613cd7b59307617a4b635eec5e3b88d16c6 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Tue, 4 Jun 2024 04:30:01 +0000 Subject: [PATCH] #7724: Add prototype for autonomous streams for use in tunneller Streams are autonomous data movement hardware engines present on every tensix and erisc core. Typically, streams are only capable of moving data in a static pattern from a predetermined ordering of senders/receivers. Luckily, for the tunneling use case, the producers and consumers are always the same and we only need to make sure we can forward messages indefinitely. This prototype is the first step to enable streams in the dispatch datapath so that we may recover erisc cores for use by kernels. Since the stream can run autonomously with this setup, we can initialize it such that it implements tunnelling behaviour without erisc overhead. With the exception of some bandwidth sharing (L1 and ethernet) on the margins, a user kernel would never know the stream is busy working as the tunneler. Indefinite message forwarding can be accomplished by creating two phases in the autonomous stream's blob and making the second phase point its next phase to the start of the first phase. This way, with the stream configured to auto-configure and auto-advance, it will end up looping forever. The remaining challenge is to ensure that we can safely reset/teardown the stream so that the next time a program runs on the hardware, the remote sender dispatch core is able to establish a handshake with the relay stream. If it kept running in the background, the dispatch code path would have no idea how to intercept it and establish communication with it. Therefore we reset any time we need to teardown the dispatch datapath. Streams are opaque and brittle, and this is not an originally intended use-case for them. However, ironically, it seems to map best with all of the other limitations provided with streams. 
=== Phase Selection === Streams are very finicky and have an undesirable trait where even if they are reset, they expect the next phase they handshake on to be different. So if in a prior run the sender finished on phase 1 and the relay finished on phase 1, then neither stream should start on phase 1 on the next run. For this reason, on stream startup, the FW inspects the stream's current phase and, based on that, chooses a valid next starting phase. It sends this starting phase information to its sender stream, if it has one. The same is done for the downstream direction so receivers know which `remote_src_phase` to handshake on. === Resets === After every run, we must tear down and reset the streams so they are ready to use and able to handshake properly the next time a program uses the AI accelerator. To reset properly, we need to ensure a few things: 1) The relay stream is *not* processing any data at the time of reset - in other words, the full datapath should be flushed before reset 2) There should be no acks pending to be sent upstream. The receiver/relay kernels verify this by checking the stream-active status and a special debug register. In a fully fleshed out design, this reset should ideally be done before stream construction. Additionally, it must also be done in the event of program failure (e.g. ctrl^C, sigkill, etc.). === Limitations === There are some limitations that will always be true: - max_message_size == min(stream_buffer_size, sender_buffer_size) - streams expect a header present for every message - streams expect the entire message to be resident when send is started There are currently some known limitations (may be lifted in future): - min # messages per phase = 128 - fewer leads to deterministic handshake hang - this hang deterministically happens after min(num_phase_ranges, 24) runs. 
24 also happens to be the number of dest ready table entries for WH although it's unclear if this is a pure coincidence - disabling the dest_ready_table leads to immediate handshake hang and so wasn't pursued further - max # messages per phase = 2048 - This is due to how the phase range selection is implemented --- .../streams/stream_io_kernel_helpers.hpp | 135 +++ .../dataflow/streams/stream_relay.cpp | 325 ++++++ .../streams/stream_relay_remote_receiver.cpp | 272 +++++ .../stream_relay_remote_receiver_writer.cpp | 37 + .../streams/stream_relay_remote_sender.cpp | 364 +++++++ .../stream_relay_remote_sender_reader.cpp | 66 ++ .../unit_tests_fast_dispatch/CMakeLists.txt | 1 + .../streams/test_autonomous_relay_streams.cpp | 973 ++++++++++++++++++ .../inc/wormhole/noc/noc_overlay_parameters.h | 1 + 9 files changed, 2174 insertions(+) create mode 100644 tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp create mode 100644 tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp create mode 100644 tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp create mode 100644 tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp create mode 100644 tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp create mode 100644 tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender_reader.cpp create mode 100644 tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp new file mode 100644 index 00000000000..0df88172c89 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "dataflow_api.h" +#include "stream_interface.h" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +struct stream_state_t { + const uint32_t local_data_buffer_base_address; + const uint32_t local_msg_info_ptr_base_address; + + uint32_t local_phase_id; + uint32_t messages_per_phase; + uint32_t msg_info_wrptr_addr; + + uint32_t num_tiles_sent; + uint32_t tile_header_num_msgs; + + uint32_t local_buffer_base_addr; + uint32_t local_buffer_size; + uint32_t local_msg_info_ptr; + uint32_t local_buffer_read_offset; + + uint32_t remote_buffer_base_addr; + uint32_t remote_buffer_size; + uint32_t remote_msg_info_ptr; + uint32_t remote_buffer_write_offset; + + uint32_t remote_phase_id; + + uint32_t get_current_local_buffer_address() const { + return local_data_buffer_base_address + local_buffer_read_offset; + } +}; + +struct phase_iterator_t { + phase_iterator_t(uint32_t start_phase, uint32_t max_phase) : + phase_id(start_phase), max_phase(max_phase), start_phase(start_phase) {} + uint32_t phase_id; + uint32_t max_phase; + uint32_t start_phase; + + FORCE_INLINE uint32_t get() const { return phase_id; } + + FORCE_INLINE void increment() { phase_id = phase_id == max_phase ? 
start_phase : phase_id + 1; } +}; + +struct noc_endpoint_info_t { + uint32_t data_noc_id; + uint32_t update_noc_id; + uint32_t noc_x; + uint32_t noc_y; +}; + +#define STREAM_CFG(field, val) ((val) << (field)) + +#define AUTO_CFG_HEADER(next_phase_num_cfg_reg_writes, curr_phase_num_msgs, phase_num_incr) \ + ((uint32_t)(((next_phase_num_cfg_reg_writes) << 24) | ((curr_phase_num_msgs) << 12) | (phase_num_incr))) + +#define STREAM_REMOTE_DEST(dest_x, dest_y, dest_stream_id) \ + (((dest_x) << STREAM_REMOTE_DEST_X) | ((dest_y) << STREAM_REMOTE_DEST_Y) | \ + ((dest_stream_id) << STREAM_REMOTE_DEST_STREAM_ID)) + +#define STREAM_REMOTE_SRC(src_x, src_y, src_stream_id) \ + (((src_x) << STREAM_REMOTE_SRC_X) | ((src_y) << STREAM_REMOTE_SRC_Y) | ((src_stream_id) << REMOTE_SRC_STREAM_ID)) + +FORCE_INLINE uint32_t +blob_header_dw(uint32_t next_phase_num_cfg_reg_writes, uint32_t curr_phase_num_msgs, uint32_t phase_num_incr) { + return (next_phase_num_cfg_reg_writes << 24) | (curr_phase_num_msgs << 12) | phase_num_incr; +} + +FORCE_INLINE void stream_phase_blob_run( + uint32_t stream_id, volatile uint32_t *blob_start_addr, uint32_t start_phase_num_cfg_regs) { + NOC_STREAM_WRITE_REG(stream_id, STREAM_PHASE_AUTO_CFG_PTR_REG_INDEX, reinterpret_cast(blob_start_addr)); + NOC_STREAM_WRITE_REG( + stream_id, STREAM_PHASE_AUTO_CFG_HEADER_REG_INDEX, start_phase_num_cfg_regs << NEXT_PHASE_NUM_CFG_REG_WRITES); + NOC_STREAM_WRITE_REG( + stream_id, + STREAM_MISC_CFG_REG_INDEX, + (0x1 << PHASE_AUTO_CONFIG) | (1 << NEXT_PHASE_SRC_CHANGE) | (1 << NEXT_PHASE_DEST_CHANGE)); +} +FORCE_INLINE void stream_phase_blob_run( + uint32_t stream_id, + volatile uint32_t *blob_start_addr, + uint32_t num_messages_per_phase, + uint32_t start_phase_num_cfg_regs) { + NOC_STREAM_WRITE_REG(stream_id, STREAM_PHASE_AUTO_CFG_PTR_REG_INDEX, reinterpret_cast(blob_start_addr)); + + NOC_STREAM_WRITE_REG( + stream_id, + STREAM_PHASE_AUTO_CFG_HEADER_REG_INDEX, + blob_header_dw(start_phase_num_cfg_regs, num_messages_per_phase, 
1)); + NOC_STREAM_WRITE_REG( + stream_id, + STREAM_MISC_CFG_REG_INDEX, + (0x1 << PHASE_AUTO_ADVANCE) | (0x1 << PHASE_AUTO_CONFIG) | (1 << NEXT_PHASE_SRC_CHANGE) | + (1 << NEXT_PHASE_DEST_CHANGE)); + NOC_STREAM_WRITE_REG(stream_id, STREAM_PHASE_ADVANCE_REG_INDEX, 1); +} + +FORCE_INLINE uint32_t blob_cfg_dw(uint32_t reg_index, uint32_t reg_val) { return (reg_val << 8) | reg_index; } + +FORCE_INLINE uint32_t set_blob_reg_field(uint32_t blob_dw, uint32_t field_width, uint32_t field_offset, uint32_t val) { + uint32_t mask = ((1 << field_width) - 1) << field_offset; + return (blob_dw & ~mask) | ((val << field_offset) & mask); +} + +FORCE_INLINE uint32_t get_first_available_phase_out_of_reset(uint32_t stream_id) { + uint32_t stream_phase_coming_out_of_reset = stream_get_curr_phase(stream_id); + return ( + stream_phase_coming_out_of_reset < 4096 ? 4096 : 1); +} + +FORCE_INLINE uint32_t notify_remote_receiver_of_starting_phase( + uint32_t stream_id, uint32_t local_buffer_addr, uint64_t remote_receiver_noc_addr) { + uint32_t starting_phase = get_first_available_phase_out_of_reset(stream_id); + ASSERT(starting_phase > 0); + *reinterpret_cast(local_buffer_addr) = starting_phase; + noc_async_write(local_buffer_addr, remote_receiver_noc_addr, sizeof(uint32_t)); + // noc_semaphore_set_remote(local_buffer_addr, remote_receiver_noc_addr); + noc_async_writes_flushed(); + return starting_phase; +} + +FORCE_INLINE uint32_t wait_for_remote_source_starting_phase(volatile uint32_t *addr) { + while (*addr == 0) { + asm volatile("nop"); + } + return *addr; +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp new file mode 100644 index 00000000000..e6e23c33fa7 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp @@ -0,0 +1,325 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "stream_interface.h" +#include "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +void kernel_main() { + // Work to do before productizable: + // - Test phase advance + // - test > 2k messages (and > 4k messages) + // - Test variable sized messages + // - Test rerun after test completion (without reset) + // - Currently a bug where the phase ID persists from prior run + // + + uint32_t arg_idx = 0; + + uint32_t relay_stream_overlay_blob_addr = get_arg_val(arg_idx++); + uint32_t stream_id = get_arg_val(arg_idx++); + uint32_t stream_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_buffer_size = get_arg_val(arg_idx++); + uint32_t stream_tile_header_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_tile_header_max_num_messages = get_arg_val(arg_idx++); + + uint32_t remote_src_noc_x = get_arg_val(arg_idx++); + uint32_t remote_src_noc_y = get_arg_val(arg_idx++); + uint32_t remote_src_stream_id = get_arg_val(arg_idx++); + uint32_t remote_src_noc_id = get_arg_val(arg_idx++); + + uint32_t remote_dest_noc_x = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_y = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_stream_id = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_id = get_arg_val(arg_idx++); + uint32_t remote_dest_buf_addr = get_arg_val(arg_idx++); + uint32_t remote_dest_buf_size_4B_words = get_arg_val(arg_idx++); + uint32_t remote_dest_tile_header_buffer_addr = get_arg_val(arg_idx++); + volatile uint32_t* tx_rx_done_semaphore_addr = + reinterpret_cast(get_arg_val(arg_idx++)); + bool is_first_relay_stream_in_chain = get_arg_val(arg_idx++) == 1; + + uint32_t remote_src_start_phase_addr = get_arg_val(arg_idx++); + uint32_t dest_remote_src_start_phase_addr = get_arg_val(arg_idx++); + + *tx_rx_done_semaphore_addr = 0; // should already be set to 0, but why 
not... + // use stream_buffer_addr as temporary storage just for this initial setup + + const uint32_t local_first_phase = notify_remote_receiver_of_starting_phase( + stream_id, + stream_buffer_addr + 16, // local storage to hold the phase while async send in progress, 16B for noc alignment + get_noc_addr(remote_dest_noc_x, remote_dest_noc_y, dest_remote_src_start_phase_addr)); + const uint32_t local_second_phase = local_first_phase + 1; + + // If first relay, we'd expect this to be stream_tile_header_max_num_messages + STARTING_PHASE because the + // remote_sender (FW managed) is programmed as one phase per message and there are + // `stream_tile_header_max_num_messages` messages in this stream's phase. If second relay, we'd expect this to be + // SECOND_PHASE + const uint32_t first_phase_remote_src_phase = + wait_for_remote_source_starting_phase(reinterpret_cast(remote_src_start_phase_addr)); + const uint32_t second_phase_remote_src_phase = + is_first_relay_stream_in_chain ? stream_tile_header_max_num_messages + first_phase_remote_src_phase + : first_phase_remote_src_phase + 1; + + // Setup the stream phases + volatile uint32_t* stream_phases_start = reinterpret_cast(relay_stream_overlay_blob_addr); + + // + // phase 1 + // + + const uint32_t stream_phase_1_start = reinterpret_cast(stream_phases_start); + volatile uint32_t* stream_phase_1_reg_addr = reinterpret_cast(stream_phase_1_start) + 1; + + // Local stream buffer address register + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_BUF_START_REG_INDEX, stream_buffer_addr >> 4); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // Local stream buffer size register + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_BUF_SIZE_REG_INDEX, stream_buffer_size >> 4); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // msg info rdptr + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_MSG_INFO_PTR_REG_INDEX, stream_tile_header_buffer_addr >> 4); + stream_phase_1_reg_addr++; + 
*stream_phase_1_reg_addr = 0; + + // msg info wrptr + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_MSG_INFO_WR_PTR_REG_INDEX, stream_tile_header_buffer_addr >> 4); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // Local stream buffer size register + *stream_phase_1_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX, remote_dest_tile_header_buffer_addr >> 4); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // STREAM_MISC_CFG_REG_INDEX + const uint32_t remote_src_update_noc_id = 1 - remote_src_noc_id; + uint32_t stream_msc_cfg_reg = 0; + stream_msc_cfg_reg = + set_blob_reg_field(stream_msc_cfg_reg, INCOMING_DATA_NOC_WIDTH, INCOMING_DATA_NOC, remote_src_noc_id); + stream_msc_cfg_reg = + set_blob_reg_field(stream_msc_cfg_reg, OUTGOING_DATA_NOC_WIDTH, OUTGOING_DATA_NOC, remote_dest_noc_id); + stream_msc_cfg_reg = set_blob_reg_field( + stream_msc_cfg_reg, REMOTE_SRC_UPDATE_NOC_WIDTH, REMOTE_SRC_UPDATE_NOC, remote_src_update_noc_id); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, REMOTE_SOURCE_WIDTH, REMOTE_SOURCE, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, REMOTE_RECEIVER_WIDTH, REMOTE_RECEIVER, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, PHASE_AUTO_CONFIG_WIDTH, PHASE_AUTO_CONFIG, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, PHASE_AUTO_ADVANCE_WIDTH, PHASE_AUTO_ADVANCE, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, DATA_AUTO_SEND_WIDTH, DATA_AUTO_SEND, 1); + stream_msc_cfg_reg = + set_blob_reg_field(stream_msc_cfg_reg, NEXT_PHASE_DEST_CHANGE_WIDTH, NEXT_PHASE_DEST_CHANGE, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, NEXT_PHASE_SRC_CHANGE_WIDTH, NEXT_PHASE_SRC_CHANGE, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, UNICAST_VC_REG_WIDTH, UNICAST_VC_REG, 0); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, REG_UPDATE_VC_REG_WIDTH, REG_UPDATE_VC_REG, 1); + 
stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, DATA_BUF_NO_FLOW_CTRL_WIDTH, DATA_BUF_NO_FLOW_CTRL, 0); + stream_msc_cfg_reg = + set_blob_reg_field(stream_msc_cfg_reg, DEST_DATA_BUF_NO_FLOW_CTRL_WIDTH, DEST_DATA_BUF_NO_FLOW_CTRL, 0); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, REMOTE_SRC_IS_MCAST_WIDTH, REMOTE_SRC_IS_MCAST, 0); + stream_msc_cfg_reg = set_blob_reg_field( + stream_msc_cfg_reg, NO_PREV_PHASE_OUTGOING_DATA_FLUSH_WIDTH, NO_PREV_PHASE_OUTGOING_DATA_FLUSH, 0); + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_MISC_CFG_REG_INDEX, stream_msc_cfg_reg); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // remote src + // Remote src noc x/y is based on the update noc (because it sends updates, NOT data, to src, so it needs update + // noc) + uint32_t stream_remote_src_reg = 0; + uint32_t data_noc_in_src_noc_x = + remote_src_update_noc_id == 0 ? remote_src_noc_x : noc_size_x - 1 - remote_src_noc_x; + uint32_t data_noc_in_src_noc_y = + remote_src_update_noc_id == 0 ? remote_src_noc_y : noc_size_y - 1 - remote_src_noc_y; + stream_remote_src_reg = set_blob_reg_field( + stream_remote_src_reg, STREAM_REMOTE_SRC_X_WIDTH, STREAM_REMOTE_SRC_X, data_noc_in_src_noc_x); + stream_remote_src_reg = set_blob_reg_field( + stream_remote_src_reg, STREAM_REMOTE_SRC_Y_WIDTH, STREAM_REMOTE_SRC_Y, data_noc_in_src_noc_y); + stream_remote_src_reg = set_blob_reg_field( + stream_remote_src_reg, REMOTE_SRC_STREAM_ID_WIDTH, REMOTE_SRC_STREAM_ID, remote_src_stream_id); + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_REMOTE_SRC_REG_INDEX, stream_remote_src_reg); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // remote dest + // Remote dest noc x/y is NOT based on the update noc (because it is sending data to the dest, so it needs data noc) + uint32_t stream_remote_dest_reg = 0; + uint32_t data_noc_out_dest_noc_x = remote_dest_noc_id == 0 ? 
remote_dest_noc_x : noc_size_x - 1 - remote_dest_noc_x; + uint32_t data_noc_out_dest_noc_y = remote_dest_noc_id == 0 ? remote_dest_noc_y : noc_size_y - 1 - remote_dest_noc_y; + stream_remote_dest_reg = set_blob_reg_field( + stream_remote_dest_reg, STREAM_REMOTE_DEST_X_WIDTH, STREAM_REMOTE_DEST_X, data_noc_out_dest_noc_x); + stream_remote_dest_reg = set_blob_reg_field( + stream_remote_dest_reg, STREAM_REMOTE_DEST_Y_WIDTH, STREAM_REMOTE_DEST_Y, data_noc_out_dest_noc_y); + stream_remote_dest_reg = set_blob_reg_field( + stream_remote_dest_reg, + STREAM_REMOTE_DEST_STREAM_ID_WIDTH, + STREAM_REMOTE_DEST_STREAM_ID, + remote_dest_noc_stream_id); + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_REMOTE_DEST_REG_INDEX, stream_remote_dest_reg); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // remote_dest buf start + uint32_t stream_remote_dest_buf_start_reg_val = 0; + stream_remote_dest_buf_start_reg_val = set_blob_reg_field( + stream_remote_dest_buf_start_reg_val, + DRAM_WRITES__SCRATCH_1_PTR_LO_WIDTH, + DRAM_WRITES__SCRATCH_1_PTR_LO, + remote_dest_buf_addr >> 4); + *stream_phase_1_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_BUF_START_REG_INDEX, stream_remote_dest_buf_start_reg_val); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // remote_dest buf size + uint32_t stream_remote_dest_buf_size_reg = 0; + stream_remote_dest_buf_size_reg = set_blob_reg_field( + stream_remote_dest_buf_size_reg, + REMOTE_DEST_BUF_SIZE_WORDS_WIDTH, + REMOTE_DEST_BUF_SIZE_WORDS, + remote_dest_buf_size_4B_words >> 4); + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX, stream_remote_dest_buf_size_reg); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_CURR_PHASE_BASE_REG_INDEX, 0); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_REMOTE_SRC_PHASE_REG_INDEX, first_phase_remote_src_phase); + stream_phase_1_reg_addr++; + 
*stream_phase_1_reg_addr = 0; + + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_CURR_PHASE_REG_INDEX, local_first_phase); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_MEM_BUF_SPACE_AVAILABLE_ACK_THRESHOLD_REG_INDEX, 0); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // + // phase 2 - we're unrolling one iteration of the first phase, so the second phase is mostly identical + // + volatile uint32_t* const stream_phase_2_start = stream_phase_1_reg_addr; + volatile uint32_t* stream_phase_2_stream_reg_addr = stream_phase_2_start + 1; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_BUF_START_REG_INDEX, stream_buffer_addr >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + // Local stream buffer size register + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_BUF_SIZE_REG_INDEX, stream_buffer_size >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + // msg info rdptr + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_MSG_INFO_PTR_REG_INDEX, stream_tile_header_buffer_addr >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + // msg info wrptr + *stream_phase_2_stream_reg_addr = + blob_cfg_dw(STREAM_MSG_INFO_WR_PTR_REG_INDEX, stream_tile_header_buffer_addr >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX, remote_dest_tile_header_buffer_addr >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_MISC_CFG_REG_INDEX, stream_msc_cfg_reg); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_REMOTE_SRC_REG_INDEX, stream_remote_src_reg); + stream_phase_2_stream_reg_addr++; + 
*stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_REMOTE_DEST_REG_INDEX, stream_remote_dest_reg); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_BUF_START_REG_INDEX, stream_remote_dest_buf_start_reg_val); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX, stream_remote_dest_buf_size_reg); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_CURR_PHASE_BASE_REG_INDEX, 0); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_CURR_PHASE_REG_INDEX, local_second_phase); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_MEM_BUF_SPACE_AVAILABLE_ACK_THRESHOLD_REG_INDEX, 0); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_PHASE_AUTO_CFG_PTR_BASE_REG_INDEX, 0); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_REMOTE_SRC_PHASE_REG_INDEX, second_phase_remote_src_phase); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_PHASE_AUTO_CFG_PTR_REG_INDEX, stream_phase_1_start); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + const uint32_t phase_1_num_cfg_regs = + ((reinterpret_cast(stream_phase_1_reg_addr) >> 2) - (stream_phase_1_start >> 2)) - 1; + uint32_t phase_2_num_cfg_regs = ((reinterpret_cast(stream_phase_2_stream_reg_addr) >> 2) - + (reinterpret_cast(stream_phase_2_start) >> 2)) - + 1; + + // We're supposed to 
put the **next** phase num config registers in the **current** phase's blob header. This means + // we need to flip the register counts between the two phases for their headers So in a sequence of 3 phases, the + // header blob on phase 1 would need the #cfg regs for phase 2. Phase 2's cfg header blob would need the #cfg regs + // for phase 3 and for phase 3, the #cfg regs in the header blob would be 0 (since no phase follows it) In our case, + // we just need to point to the opposite phase's #cfg regs + *reinterpret_cast(stream_phase_1_start) = + blob_header_dw(phase_2_num_cfg_regs, stream_tile_header_max_num_messages, 1); + *stream_phase_2_start = blob_header_dw(phase_1_num_cfg_regs, stream_tile_header_max_num_messages, 1); + + // Now kick off the stream + stream_phase_blob_run( + stream_id, + reinterpret_cast(stream_phase_1_start), + stream_tile_header_max_num_messages, + phase_1_num_cfg_regs); + + // Wait for sender and receiver to signal completion + while (*tx_rx_done_semaphore_addr != 2) { + asm volatile("nop"); + } + + // Now teardown the stream + // Unknown if it's safe to reset the stream while it's in a state before active + while ((NOC_STREAM_READ_REG(stream_id, STREAM_DEBUG_STATUS_REG_INDEX + 9) >> MEM_WORD_ADDR_WIDTH) != 0 || + !stream_phase_is_active(stream_id)) { + asm volatile("nop"); + } + + stream_reset(stream_id); + ASSERT(!assert_check(stream_id, false)); + for (auto ptr = reinterpret_cast(stream_phase_1_start); ptr != stream_phase_2_stream_reg_addr; + ptr++) { + *ptr = 0; + } +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp new file mode 100644 index 00000000000..a21474a048a --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp @@ -0,0 +1,272 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include + +#include "dataflow_api.h" +#include "stream_interface.h" +#include "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +// THESE TWO FUNCTIONS WERE ONLY VALID FOR WORMHOLE_B0 AND MAY NOT WORK WITH BLACKHOLE!!! +// STREAM_RECEIVER_ENDPOINT_MULTI_TILE_CLEAR_REG_INDEX is aliased to STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX for +// whb0 +inline bool is_stream_receiver_endpoint_tile_clearing_finished(uint32_t stream_id) { + return (NOC_STREAM_READ_REG(stream_id, STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX) == 0); +} +inline void stream_receiver_endpoint_tiles_clear_b0(uint32_t stream_id, uint32_t num_tiles) { + uint32_t clr_val = num_tiles; + clr_val *= 2; + clr_val = (~clr_val) + 1; + NOC_STREAM_WRITE_REG(stream_id, STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX, clr_val); +} +////////////////////////////////////////////////////////////////////////////////////////// + +uint32_t get_receiver_stream_config_reg(uint32_t data_noc_id, uint32_t update_noc, bool drain_after_phase_send) { + uint32_t stream_cfg_reg = 0; + bool next_phase_src_dest_change = drain_after_phase_send ? 
1 : 0; + stream_cfg_reg |= STREAM_CFG(INCOMING_DATA_NOC, data_noc_id) | STREAM_CFG(REMOTE_SRC_UPDATE_NOC, update_noc) | + STREAM_CFG(RECEIVER_ENDPOINT, 1) | STREAM_CFG(REMOTE_SOURCE, 1) | + STREAM_CFG(NEXT_PHASE_SRC_CHANGE, next_phase_src_dest_change) | + STREAM_CFG(NEXT_PHASE_DEST_CHANGE, next_phase_src_dest_change) | + STREAM_CFG(PHASE_AUTO_ADVANCE, 0) | STREAM_CFG(DATA_AUTO_SEND, 0) | + STREAM_CFG(REG_UPDATE_VC_REG, 1); + + return stream_cfg_reg; +} + +FORCE_INLINE bool messages_are_available(uint32_t stream_id, stream_state_t &stream_state) { + uint32_t wrptr = NOC_STREAM_READ_REG(stream_id, STREAM_MSG_INFO_WR_PTR_REG_INDEX); + uint32_t rdptr = NOC_STREAM_READ_REG(stream_id, STREAM_MSG_INFO_PTR_REG_INDEX); + uint32_t internal_rdptr = stream_state.local_msg_info_ptr >> 4; + bool messages_available = internal_rdptr < wrptr; + return messages_available; +} + +FORCE_INLINE void flush_message_from_stream_buffer( + uint32_t stream_id, stream_state_t &stream_state, uint32_t msg_size_bytes) { + stream_receiver_endpoint_tiles_clear_b0(stream_id, 1); + while (!is_stream_receiver_endpoint_tile_clearing_finished(stream_id)) { + asm volatile(""); + } +} + +FORCE_INLINE uint32_t +get_next_available_stream_message_size_in_bytes(stream_state_t &stream_state, uint32_t stream_id) { + uint32_t msg_info_byte_ptr = stream_state.local_msg_info_ptr; + uint32_t msg_size_bytes = *reinterpret_cast(msg_info_byte_ptr) << 4; + ASSERT(msg_size_bytes > 0); + return msg_size_bytes; +} + +FORCE_INLINE std::tuple get_next_message_info(uint32_t stream_id, stream_state_t &stream_state) { + uint32_t rdptr_offset = NOC_STREAM_READ_REG(stream_id, STREAM_RD_PTR_REG_INDEX) << 4; + uint32_t addr = rdptr_offset + stream_state.local_data_buffer_base_address; + ASSERT((rdptr_offset & 0xF) == 0); + ASSERT((addr & 0xF) == 0); + return {addr, get_next_available_stream_message_size_in_bytes(stream_state, stream_id)}; +} + +FORCE_INLINE void advance_stream_state_struct( + uint32_t stream_id, stream_state_t 
&stream_state, uint32_t msg_size_bytes) { + uint32_t next_offset = stream_state.local_buffer_read_offset + msg_size_bytes; + if (next_offset >= stream_state.local_buffer_size) { + next_offset -= stream_state.local_buffer_size; + } + stream_state.local_buffer_read_offset = next_offset; + stream_state.local_msg_info_ptr += (1 << 4); +} + +FORCE_INLINE void advance_phase( + noc_endpoint_info_t const &remote_endpoint_info, stream_state_t &state, uint32_t stream_id) { + // This is remote receiver, so it sends messages (updates) to remote source, NOT data, so it uses + // the update noc to communicate to remote src instead of the data noc. Therefore, we need to set remote + // src x/y based on the update noc. + uint32_t translated_remote_noc_x = remote_endpoint_info.update_noc_id == 0 + ? remote_endpoint_info.noc_x + : noc_size_x - 1 - remote_endpoint_info.noc_x; + uint32_t translated_remote_noc_y = remote_endpoint_info.update_noc_id == 0 + ? remote_endpoint_info.noc_y + : noc_size_y - 1 - remote_endpoint_info.noc_y; + + NOC_STREAM_WRITE_REG(stream_id, STREAM_CURR_PHASE_BASE_REG_INDEX, 0); + NOC_STREAM_WRITE_REG(stream_id, STREAM_CURR_PHASE_REG_INDEX, ((uint32_t)state.local_phase_id)); + NOC_STREAM_WRITE_REG(stream_id, STREAM_BUF_START_REG_INDEX, ((uint32_t)state.local_buffer_base_addr) >> 4); + NOC_STREAM_WRITE_REG(stream_id, STREAM_BUF_SIZE_REG_INDEX, state.local_buffer_size >> 4); + NOC_STREAM_WRITE_REG( + stream_id, + STREAM_REMOTE_SRC_REG_INDEX, + STREAM_REMOTE_SRC(translated_remote_noc_x, translated_remote_noc_y, stream_id)); + NOC_STREAM_WRITE_REG(stream_id, STREAM_REMOTE_SRC_PHASE_REG_INDEX, ((uint32_t)state.remote_phase_id)); + + NOC_STREAM_WRITE_REG(stream_id, STREAM_MEM_BUF_SPACE_AVAILABLE_ACK_THRESHOLD_REG_INDEX, 0); + NOC_STREAM_WRITE_REG(stream_id, STREAM_MSG_INFO_PTR_REG_INDEX, ((uint32_t)state.local_msg_info_ptr) >> 4); + NOC_STREAM_WRITE_REG(stream_id, STREAM_MSG_INFO_WR_PTR_REG_INDEX, ((uint32_t)state.local_msg_info_ptr) >> 4); + + NOC_STREAM_WRITE_REG( 
+ stream_id, + STREAM_MISC_CFG_REG_INDEX, + get_receiver_stream_config_reg(remote_endpoint_info.data_noc_id, remote_endpoint_info.update_noc_id, true)); + + NOC_STREAM_WRITE_REG( + stream_id, STREAM_PHASE_AUTO_CFG_HEADER_REG_INDEX, AUTO_CFG_HEADER(0, state.messages_per_phase, 0)); + NOC_STREAM_WRITE_REG(stream_id, STREAM_PHASE_ADVANCE_REG_INDEX, 0x1); +} + +FORCE_INLINE void advance_stream_to_next_message( + noc_endpoint_info_t const &remote_endpoint_info, + stream_state_t &state, + uint32_t stream_id, + uint32_t msg_size_bytes, + phase_iterator_t &local_phase_iterator, + phase_iterator_t &remote_phase_iterator) { + advance_stream_state_struct(stream_id, state, msg_size_bytes); + flush_message_from_stream_buffer(stream_id, state, msg_size_bytes); + + if (state.num_tiles_sent == state.tile_header_num_msgs - 1) { + remote_phase_iterator.increment(); + state.remote_phase_id = remote_phase_iterator.get(); + local_phase_iterator.increment(); + state.local_phase_id = local_phase_iterator.get(); + state.num_tiles_sent = 0; + state.local_msg_info_ptr = state.local_msg_info_ptr_base_address; + + advance_phase(remote_endpoint_info, state, stream_id); + state.local_buffer_read_offset = 0; + } else { + state.num_tiles_sent++; + } +} + +FORCE_INLINE void copy_message_to_cb_blocking( + uint32_t cb, uint32_t msg_addr, uint32_t msg_size_bytes, stream_state_t &stream_state) { + uint32_t cb_write_addr = get_write_ptr(cb); + uint64_t dest_noc_addr = get_noc_addr(cb_write_addr); + ASSERT((dest_noc_addr & 0xF) == 0); + ASSERT((msg_addr & 0xF) == 0); + uint32_t distance_until_end = + stream_state.local_buffer_size - (msg_addr - stream_state.local_data_buffer_base_address); + uint32_t bytes_to_copy = std::min(distance_until_end, msg_size_bytes); + + noc_async_write(msg_addr, dest_noc_addr, bytes_to_copy); + if (bytes_to_copy < msg_size_bytes) { + uint32_t bytes_to_copy_second = msg_size_bytes - bytes_to_copy; + noc_async_write( + stream_state.local_data_buffer_base_address, dest_noc_addr 
+ bytes_to_copy, bytes_to_copy_second); + uint32_t num_words = bytes_to_copy_second >> 2; + } + noc_async_write_barrier(); +} + +void kernel_main() { + uint32_t arg_idx = 0; + + uint32_t num_messages_to_forward = get_arg_val(arg_idx++); + + uint32_t stream_id = get_arg_val(arg_idx++); + uint32_t stream_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_buffer_size = get_arg_val(arg_idx++); + uint32_t stream_tile_header_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_tile_header_max_num_messages = get_arg_val(arg_idx++); + + uint32_t remote_src_noc_x = get_arg_val(arg_idx++); + uint32_t remote_src_noc_y = get_arg_val(arg_idx++); + uint32_t remote_src_noc_stream_id = get_arg_val(arg_idx++); + uint32_t remote_src_data_noc_id = get_arg_val(arg_idx++); + uint32_t remote_src_buffer_addr = get_arg_val(arg_idx++); + uint32_t remote_src_buffer_size_4B_words = get_arg_val(arg_idx++); + uint32_t remote_src_tile_header_buffer_addr = get_arg_val(arg_idx++); + + uint32_t relay_done_semaphore_addr = get_arg_val(arg_idx++); + uint32_t other_relay_core_to_signal_x = get_arg_val(arg_idx++); + uint32_t other_relay_core_to_signal_y = get_arg_val(arg_idx++); + uint32_t other_relay_done_semaphore = get_arg_val(arg_idx++); + + uint32_t sender_noc_x = get_arg_val(arg_idx++); + uint32_t sender_noc_y = get_arg_val(arg_idx++); + uint32_t sender_wait_finish_semaphore = get_arg_val(arg_idx++); + uint32_t remote_src_start_phase_addr = get_arg_val(arg_idx++); + + const uint32_t first_phase_remote_src_phase = + wait_for_remote_source_starting_phase(reinterpret_cast(remote_src_start_phase_addr)); + const uint32_t second_phase_remote_src_phase = first_phase_remote_src_phase + 1; + const uint32_t local_first_phase = get_first_available_phase_out_of_reset(stream_id); + const uint32_t local_second_phase = local_first_phase; + + auto local_phase_iterator = phase_iterator_t(local_first_phase, local_second_phase); + auto remote_phase_iterator = phase_iterator_t(first_phase_remote_src_phase, 
second_phase_remote_src_phase); + + stream_state_t stream_state{ + stream_buffer_addr, + stream_tile_header_buffer_addr, + + local_phase_iterator.get(), // phase_id + stream_tile_header_max_num_messages, + + stream_tile_header_buffer_addr, // msg_info_wrptr_addr; + + 0, // num_tiles_sent; + stream_tile_header_max_num_messages, // tile_header_num_msgs; + + stream_buffer_addr, // dest_buffer_base_addr; + stream_buffer_size, // dest_buffer_size; + stream_tile_header_buffer_addr, // dest_msg_info_ptr; + + 0, // src_buffer_read_offset; + + remote_src_buffer_addr, // src_buffer_base_addr; + remote_src_buffer_size_4B_words, // src_buffer_size; + remote_src_tile_header_buffer_addr, // src_msg_info_ptr; + + 0, // dest_buffer_write_offset; + remote_phase_iterator.get(), // receiver start phase + }; + + ASSERT((stream_state.local_data_buffer_base_address & 0xf) == 0); + + auto remote_noc_info_desc = + noc_endpoint_info_t{remote_src_data_noc_id, 1 - remote_src_data_noc_id, remote_src_noc_x, remote_src_noc_y}; + + advance_phase(remote_noc_info_desc, stream_state, stream_id); + + auto cb = tt::CB::c_in0; + stream_state.local_buffer_base_addr = stream_buffer_addr; + + for (uint32_t i = 0; i < num_messages_to_forward; i++) { + cb_reserve_back(cb, 1); + + while (!messages_are_available(stream_id, stream_state)) { + asm volatile("nop"); + } + + auto const &[msg_addr, msg_size_bytes] = get_next_message_info(stream_id, stream_state); + ASSERT(msg_size_bytes > 0); + ASSERT(msg_size_bytes <= stream_state.local_buffer_size); + + copy_message_to_cb_blocking(cb, msg_addr, msg_size_bytes, stream_state); + + cb_push_back(cb, 1); + + stream_relay_tiles(stream_id, 1, msg_size_bytes >> 4); + advance_stream_to_next_message( + remote_noc_info_desc, stream_state, stream_id, msg_size_bytes, local_phase_iterator, remote_phase_iterator); + } + + noc_semaphore_inc(get_noc_addr(sender_noc_x, sender_noc_y, sender_wait_finish_semaphore), 1); + + while ((NOC_STREAM_READ_REG(stream_id, 
STREAM_DEBUG_STATUS_REG_INDEX + 9) >> MEM_WORD_ADDR_WIDTH) != 0) { + asm volatile("nop"); + } + + stream_reset(stream_id); + + noc_semaphore_inc( + get_noc_addr(remote_noc_info_desc.noc_x, remote_noc_info_desc.noc_y, relay_done_semaphore_addr), 1); + noc_semaphore_inc( + get_noc_addr(other_relay_core_to_signal_x, other_relay_core_to_signal_y, other_relay_done_semaphore), 1); + + ASSERT(!assert_check(stream_id, false)); +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp new file mode 100644 index 00000000000..470ef6a4264 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +void kernel_main() { + uint32_t arg_idx = 0; + + constexpr uint32_t msg_hdr_size = get_compile_time_arg_val(0); + + uint32_t output_buffer_addr = get_arg_val(arg_idx++); + uint32_t cb_page_size = get_arg_val(arg_idx++); + uint32_t num_pages = get_arg_val(arg_idx++); + + uint32_t write_page_size = cb_page_size - msg_hdr_size; + const InterleavedAddrGen dest_addr_gen = { + .bank_base_address = output_buffer_addr, .page_size = write_page_size}; + + auto cb = tt::CB::c_in0; + for (uint32_t i = 0; i < num_pages; i++) { + cb_wait_front(cb, 1); + // NOTE THAT msg_hdr_size is doubled on host side to maintain alignment for DRAM reads/writes in THIS TEST ONLY + uint32_t src_start = get_read_ptr(cb) + msg_hdr_size; + + uint64_t dst_noc_addr = get_noc_addr(i, dest_addr_gen); + noc_async_write(src_start, dst_noc_addr, write_page_size); + + noc_async_write_barrier(); + cb_pop_front(cb, 1); + } + +} diff --git 
a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp new file mode 100644 index 00000000000..606930d73ff --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp @@ -0,0 +1,364 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "stream_interface.h" +#include "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +////////// +/// FUTURE OPTIMIZATIONS +/////////// +// 1) Don't update message info rd/wrptrs. Instead, just write message size into the next corresponding message info +// buffer entry 2) Use stream registers to track # messages sent 3) For contiguous messages, use a single stream phase +// to send them back to back then only do one wait for flush at the end + +////////// +// Q/A W/ Djordje + Extra Notes +// +// 1) DON'T set any of the STREAM_REMOTE_DEST_* registers if NEXT_PHASE_SRC_CHANGE is false +// 2) stream_phase_advance_wait can be used to wait for the current phase to complete +// -> in the scheme for this producer, it'll end up waiting until the message is sent out of L1 +// 3) How does initial stream handshake happen? +// -> Stream has hidden registers: curr_phase_src/dest_change. When comming out of reset, these are set true +// This value is sticky and the next_phase_src/dest_change will override it for the next phase +/////// + +uint32_t get_sender_stream_config_reg(uint32_t tx_noc_id, uint32_t rx_src_update_noc, bool drain_after_phase_send) { + uint32_t stream_cfg_reg = 0; + bool next_phase_src_dest_change = drain_after_phase_send ? 
1 : 0; + stream_cfg_reg |= STREAM_CFG(OUTGOING_DATA_NOC, tx_noc_id) | STREAM_CFG(REMOTE_SRC_UPDATE_NOC, rx_src_update_noc) | + STREAM_CFG(SOURCE_ENDPOINT, 1) | STREAM_CFG(REMOTE_RECEIVER, 1) | + STREAM_CFG(NEXT_PHASE_SRC_CHANGE, next_phase_src_dest_change) | + STREAM_CFG(NEXT_PHASE_DEST_CHANGE, next_phase_src_dest_change) | + STREAM_CFG(PHASE_AUTO_ADVANCE, 0) | STREAM_CFG(DATA_AUTO_SEND, 0) | + STREAM_CFG(REG_UPDATE_VC_REG, 1); + + return stream_cfg_reg; +} + +FORCE_INLINE void write_message_size_to_message_info_buffer( + stream_state_t const &stream_state, uint32_t message_size_noc_words) { + ASSERT((message_size_noc_words << 4) <= stream_state.local_buffer_size); + if (!((message_size_noc_words << 4) <= stream_state.local_buffer_size)) { + DPRINT << "YIKES\n"; + } + *reinterpret_cast(stream_state.local_msg_info_ptr) = message_size_noc_words; +} + +FORCE_INLINE void reset_stream_message_info_buffer_rdptr(stream_state_t &stream_state, uint32_t stream_id) { + stream_state.local_msg_info_ptr = stream_state.local_msg_info_ptr_base_address; + NOC_STREAM_WRITE_REG( + stream_id, STREAM_MSG_INFO_PTR_REG_INDEX, ((uint32_t)(stream_state.local_msg_info_ptr_base_address >> 4))); + NOC_STREAM_WRITE_REG( + stream_id, STREAM_MSG_INFO_WR_PTR_REG_INDEX, (((uint32_t)stream_state.local_msg_info_ptr_base_address >> 4))); +} +FORCE_INLINE void advance_stream_message_info_buffer_wrptr( + stream_state_t &stream_state, uint32_t stream_id, uint32_t message_size) { + stream_state.local_msg_info_ptr += (1 << 4); + stream_state.local_buffer_read_offset += message_size; + if (stream_state.local_buffer_read_offset >= stream_state.local_buffer_size) { + stream_state.local_buffer_read_offset -= stream_state.local_buffer_size; + } +} + +FORCE_INLINE void wait_for_stream_write_complete(uint32_t sender_stream_id) { + while (!stream_phase_advance_wait(sender_stream_id)) { + asm volatile("nop"); + } +} + +FORCE_INLINE void copy_from_cb_to_stream_buffer( + stream_state_t &stream_state, uint32_t 
message_base, uint32_t message_size_noc_words) { + ASSERT((message_size_noc_words << 4) <= stream_state.local_buffer_size); + if (!((message_size_noc_words << 4) <= stream_state.local_buffer_size)) { + DPRINT << "YIKES2\n"; + } + uint32_t message_size_size_in_bytes = message_size_noc_words << 4; + uint32_t bytes_to_copy = + std::min(stream_state.local_buffer_size - stream_state.local_buffer_read_offset, message_size_size_in_bytes); + noc_async_write(message_base, get_noc_addr(stream_state.get_current_local_buffer_address()), bytes_to_copy); + ASSERT(stream_state.local_buffer_size + stream_state.local_buffer_read_offset >= bytes_to_copy); + if (!(stream_state.local_buffer_size + stream_state.local_buffer_read_offset >= bytes_to_copy)) { + DPRINT << "YIKES3\n"; + } + + if (bytes_to_copy < message_size_size_in_bytes) { + uint32_t second_bytes_to_copy = message_size_size_in_bytes - bytes_to_copy; + noc_async_write( + message_base + bytes_to_copy, get_noc_addr(stream_state.local_buffer_base_addr), second_bytes_to_copy); + } + noc_async_write_barrier(); +} + +FORCE_INLINE void hang_toggle(volatile uint32_t *hang_toggle_semaphore) { + return; + while (*hang_toggle_semaphore == 0) { + asm volatile(""); + } + *hang_toggle_semaphore = 0; +} + +FORCE_INLINE void stream_noc_write( + stream_state_t &stream_state, + uint32_t message_base, + uint32_t sender_stream_id, + uint32_t dest_addr, + uint32_t remote_noc_x, + uint32_t remote_noc_y, + uint32_t dest_noc_id, + uint32_t dest_tile_header_buffer_addr, + uint32_t local_start_phase, + bool very_first_message, + volatile uint32_t *hang_toggle_semaphore, + uint32_t message_id) { + const uint32_t tiles_per_phase = stream_state.messages_per_phase; + + uint32_t message_size_noc_words = *reinterpret_cast(message_base); + + uint32_t dest_noc_reg = 0; + uint32_t num_tiles = stream_state.num_tiles_sent; + const bool send_last_message_and_drain = num_tiles == (stream_state.tile_header_num_msgs - 1); + + bool first_message = num_tiles == 0; 
+ + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_CURR_PHASE_BASE_REG_INDEX, 0); + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_CURR_PHASE_REG_INDEX, stream_state.local_phase_id); + + if (first_message) { + reset_stream_message_info_buffer_rdptr(stream_state, sender_stream_id); + stream_state.local_buffer_read_offset = 0; + } + copy_from_cb_to_stream_buffer(stream_state, message_base, message_size_noc_words); + + if (message_id < 10) { + hang_toggle(hang_toggle_semaphore); + } + + uint32_t rx_src_update_noc = 1 - dest_noc_id; + if (send_last_message_and_drain) { + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_MISC_CFG_REG_INDEX, + get_sender_stream_config_reg(dest_noc_id, rx_src_update_noc, true)); + + } else if (first_message) { + // ASSERT(stream_state.remote_buffer_base_addr + stream_state.local_buffer_size <= + // stream_state.remote_buffer_size || + // stream_state.remote_buffer_size + (stream_state.tile_header_num_msgs << 4) <= + // stream_state.remote_buffer_base_addr); + + uint32_t rx_src_update_noc = 1 - dest_noc_id; + uint32_t translated_remote_noc_x = dest_noc_id == 0 ? remote_noc_x : noc_size_x - 1 - remote_noc_x; + uint32_t translated_remote_noc_y = dest_noc_id == 0 ? 
remote_noc_y : noc_size_y - 1 - remote_noc_y; + uint32_t dest_stream_id = sender_stream_id; + + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_BUF_START_REG_INDEX, + ((uint32_t)stream_state.get_current_local_buffer_address()) >> 4); + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_BUF_SIZE_REG_INDEX, stream_state.local_buffer_size >> 4); + + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_REMOTE_DEST_REG_INDEX, + STREAM_REMOTE_DEST(translated_remote_noc_x, translated_remote_noc_y, dest_stream_id)); + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_HI_REG_INDEX, 0); + NOC_STREAM_WRITE_REG( + sender_stream_id, STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX, stream_state.remote_msg_info_ptr >> 4); + + // DPRINT << "STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX: " << (uint32_t)(stream_state.remote_msg_info_ptr >> + // 4) << "\n"; + NOC_STREAM_WRITE_REG( + sender_stream_id, STREAM_REMOTE_DEST_BUF_START_REG_INDEX, stream_state.remote_buffer_base_addr >> 4); + // Inserting an assert here causes test to pass + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_REMOTE_DEST_BUF_START_HI_REG_INDEX, + (stream_state.remote_buffer_base_addr / MEM_WORD_WIDTH) >> MEM_WORD_ADDR_WIDTH); + NOC_STREAM_WRITE_REG_FIELD( + sender_stream_id, + STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX, + REMOTE_DEST_BUF_SIZE_WORDS, + stream_state.remote_buffer_size >> 4); + + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_MISC_CFG_REG_INDEX, + get_sender_stream_config_reg(dest_noc_id, rx_src_update_noc, false)); + } + + if (first_message) { + // DPRINT << "Msg info ptr: " << (uint32_t)stream_state.local_msg_info_ptr << "\n"; + } + if (very_first_message) { + hang_toggle(hang_toggle_semaphore); + } + + write_message_size_to_message_info_buffer(stream_state, message_size_noc_words); + advance_stream_message_info_buffer_wrptr(stream_state, sender_stream_id, message_size_noc_words << 4); + + NOC_STREAM_WRITE_REG( + sender_stream_id, STREAM_PHASE_AUTO_CFG_HEADER_REG_INDEX, 
AUTO_CFG_HEADER(0, 1 /*tiles_per_phase*/, 1)); + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_PHASE_ADVANCE_REG_INDEX, 0x1); + + if (first_message) { + // wait for handshake to complete + while (!stream_phase_is_active(sender_stream_id)) { + asm volatile(""); + } + } + + if (very_first_message) { + hang_toggle(hang_toggle_semaphore); + } + + if (send_last_message_and_drain) { + // We only wrap around to 0 when the remote receiver relay stream has finished its second phase. We need to do + // this to avoid any handshake bugs we might hit if the second phase of relay must sync with phase 1 of the + // producer (this) since the relay will handshake with phase 1 of the producer (this) stream for relay stream's + // first phase too + num_tiles = 0; + stream_state.remote_phase_id = 3 - stream_state.remote_phase_id; // will alternate between 1 and 2 + // Remote phase was already updated so the condition is inverted + stream_state.local_phase_id = + (stream_state.remote_phase_id == 1) ? local_start_phase : stream_state.local_phase_id + 1; + } else { + num_tiles++; + stream_state.local_phase_id++; + } + + stream_relay_tiles(sender_stream_id, 1, message_size_noc_words); + wait_for_stream_write_complete(sender_stream_id); + + if (very_first_message) { + hang_toggle(hang_toggle_semaphore); + } + + stream_state.num_tiles_sent = num_tiles; +} + +void kernel_main() { + uint32_t arg_idx = 0; + + uint32_t num_messages_to_forward = get_arg_val(arg_idx++); + + uint32_t stream_id = get_arg_val(arg_idx++); + uint32_t stream_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_buffer_size = get_arg_val(arg_idx++); + uint32_t stream_tile_header_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_tile_header_max_num_messages = get_arg_val(arg_idx++); + + uint32_t remote_dest_noc_x = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_y = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_stream_id = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_id = get_arg_val(arg_idx++); + 
uint32_t remote_dest_buffer_addr = get_arg_val(arg_idx++); + uint32_t remote_dest_buffer_size_4B_words = get_arg_val(arg_idx++); + uint32_t remote_dest_tile_header_buffer_addr = get_arg_val(arg_idx++); + + uint32_t relay_done_semaphore_addr = get_arg_val(arg_idx++); + uint32_t other_relay_core_to_signal_x = get_arg_val(arg_idx++); + uint32_t other_relay_core_to_signal_y = get_arg_val(arg_idx++); + uint32_t other_relay_done_semaphore = get_arg_val(arg_idx++); + + uint32_t wait_receiver_semaphore = get_arg_val(arg_idx++); + *reinterpret_cast(wait_receiver_semaphore) = 0; + + uint32_t first_relay_remote_src_start_phase_addr = get_arg_val(arg_idx++); + volatile uint32_t *hang_toggle_semaphore = reinterpret_cast(get_arg_val(arg_idx++)); + + uint32_t local_starting_phase = + notify_remote_receiver_of_starting_phase( + stream_id, + stream_buffer_addr, + get_noc_addr(remote_dest_noc_x, remote_dest_noc_y, first_relay_remote_src_start_phase_addr)) - + 1; + + // clear the buffers + for (uint32_t i = 0; i < stream_buffer_size / sizeof(uint32_t); i++) { + reinterpret_cast(stream_buffer_addr)[i] = 0; + } + for (uint32_t i = 0; i < stream_tile_header_max_num_messages * 4; i++) { + reinterpret_cast(stream_tile_header_buffer_addr)[i] = 0; + } + + stream_state_t stream_state{ + stream_buffer_addr, + stream_tile_header_buffer_addr, + + local_starting_phase, // phase_id + stream_tile_header_max_num_messages, // messages_per_phase + + stream_tile_header_buffer_addr, // msg_info_wrptr_addr; + + 0, // num_tiles_sent; + stream_tile_header_max_num_messages, // tile_header_num_msgs; + + stream_buffer_addr, // src_buffer_base_addr; + stream_buffer_size, // src_buffer_size; + stream_tile_header_buffer_addr, // src_msg_info_ptr; + 0, // src_buffer_read_offset; + + remote_dest_buffer_addr, // dest_buffer_base_addr; + remote_dest_buffer_size_4B_words, // dest_buffer_size; + remote_dest_tile_header_buffer_addr, // dest_msg_info_ptr; + 0, // dest_buffer_write_offset; + + 1, // receiver_phase; // 
receiver start phase // don't need the true value + }; + + DPRINT << "hang_toggle_semaphore: " << (uint32_t)hang_toggle_semaphore << "\n"; + + hang_toggle(hang_toggle_semaphore); + + auto cb = tt::CB::c_in0; + bool very_first_message = true; + + uint32_t message_id = 0; + uint32_t count = 0; + for (uint32_t i = 0; i < num_messages_to_forward; i++) { + cb_wait_front(cb, 1); + uint32_t src_addr = get_read_ptr(cb); + stream_noc_write( + stream_state, + src_addr, + stream_id, + stream_state.remote_buffer_base_addr, + remote_dest_noc_x, + remote_dest_noc_y, + remote_dest_noc_id, + remote_dest_tile_header_buffer_addr, + local_starting_phase, + very_first_message, + hang_toggle_semaphore, + message_id); + + cb_pop_front(cb, 1); + // if (count == 1000) { + // DPRINT << "Sent " << i << " messages\n"; + // count = 0; + // } else { + // count++; + // } + very_first_message = false; + message_id++; + } + + // Reset sequence is that both the remote sender and remote receiver streams of the relay + // should reset first so that no data is in flight. 
Sender and receiver must ensure that no + // payloads are in flight to the relay stream(s) before sending the reset signal to the relay + // core + noc_semaphore_wait(reinterpret_cast(wait_receiver_semaphore), 1); + + stream_reset(stream_id); + + noc_semaphore_inc( + get_noc_addr(other_relay_core_to_signal_x, other_relay_core_to_signal_y, other_relay_done_semaphore), 1); + noc_semaphore_inc(get_noc_addr(remote_dest_noc_x, remote_dest_noc_y, relay_done_semaphore_addr), 1); + + ASSERT(!assert_check(stream_id, false)); +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender_reader.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender_reader.cpp new file mode 100644 index 00000000000..2127013baac --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender_reader.cpp @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include + +#include "dataflow_api.h" +#include "debug/dprint.h" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +void kernel_main() { + uint32_t arg_idx = 0; + + constexpr uint32_t msg_hdr_size = get_compile_time_arg_val(0); + constexpr bool enable_page_size_variations = get_compile_time_arg_val(1) == 1; + + const uint32_t input_buffer_addr = get_arg_val(arg_idx++); + const uint32_t cb_page_size = get_arg_val(arg_idx++); + const uint32_t num_pages = get_arg_val(arg_idx++); + + constexpr uint32_t num_sizes = 8; + std::array sub_sizes = {}; + for (uint32_t i = 0; i < num_sizes; i++) { + sub_sizes[i] = get_arg_val(arg_idx++); + } + + const uint32_t read_page_size = cb_page_size - msg_hdr_size; + const InterleavedAddrGen src_addr_gen = {.bank_base_address = input_buffer_addr, .page_size = read_page_size}; + + auto cb = tt::CB::c_in0; + + uint32_t sub_index = 0; + + for (uint32_t i = 0; i < num_pages; i++) { + cb_reserve_back(cb, 1); + volatile uint32_t 
*page_header_addr = reinterpret_cast(get_write_ptr(cb)); + // NOTE THAT msg_hdr_size is doubled on host side to maintain alignment for the DRAM reads in THIS TEST ONLY + uint32_t data_out_start = reinterpret_cast(page_header_addr) + msg_hdr_size; + uint64_t src_noc_addr = get_noc_addr(i, src_addr_gen); + uint32_t message_header_size = + (read_page_size >> 4) + 2; // one for header one for padding to maintain noc word alignment + if (enable_page_size_variations) { + if (message_header_size < sub_sizes[sub_index] || sub_index >= 8) { + DPRINT << "REMOTE SENDER READER ERROR!\n"; + } + message_header_size -= sub_sizes[sub_index]; + sub_index = sub_index == num_sizes - 1 ? 0 : sub_index + 1; + } + page_header_addr[0] = message_header_size; + page_header_addr[1] = 0; + page_header_addr[2] = 0; + page_header_addr[3] = 0; + page_header_addr[4] = 0; + page_header_addr[5] = 0; + page_header_addr[6] = 0; + page_header_addr[7] = 0; + + noc_async_read(src_noc_addr, data_out_start, read_page_size); + + // TODO: upgrade to look at the writes acked counter instead + noc_async_read_barrier(); + cb_push_back(cb, 1); + } +} diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt index 1dd900cc7d4..dd25ee844ad 100644 --- a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt @@ -10,6 +10,7 @@ set(UNIT_TESTS_FD_SRC ${CMAKE_CURRENT_SOURCE_DIR}/multichip/test_eth_EnqueueProgram.cpp ${CMAKE_CURRENT_SOURCE_DIR}/multichip/test_eth_ring_gather_EnqueueProgram.cpp ${CMAKE_CURRENT_SOURCE_DIR}/pipelining/basic_pipeline.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/streams/test_autonomous_relay_streams.cpp ) add_executable(unit_tests_fast_dispatch ${UNIT_TESTS_FD_SRC} $) diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp 
b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp new file mode 100644 index 00000000000..2c963a0796d --- /dev/null +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp @@ -0,0 +1,973 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include +#include +#include + +#include "device/tt_arch_types.h" +#include "gtest/gtest.h" +#include "tests/tt_metal/tt_metal/unit_tests_fast_dispatch/common/command_queue_fixture.hpp" +#include "tt_metal/common/logger.hpp" +// #include "impl/device/device.hpp" +#include "impl/kernels/data_types.hpp" +#include "impl/kernels/kernel_types.hpp" +#include "tt_metal/common/core_coord.h" +#include "tt_metal/common/math.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/kernels/kernel.hpp" +#include "tt_metal/test_utils/comparison.hpp" +#include "tt_metal/test_utils/df/df.hpp" +#include "tt_metal/test_utils/env_vars.hpp" +// #include "tt_metal/test_utils/print_helpers.hpp" +#include "tt_metal/detail/persistent_kernel_cache.hpp" +#include "tt_metal/test_utils/stimulus.hpp" + +using tt::tt_metal::Device; + +constexpr uint32_t num_sizes = 8; +namespace tt { + +namespace tt_metal { + +struct hop_eth_sockets { + chip_id_t receiver_device_id; + CoreCoord receiver_core; + chip_id_t sender_device_id; + CoreCoord sender_core; +}; + +struct stream_config_t { + uint32_t buffer_addr; + uint32_t buffer_size; // in bytes + uint32_t tile_header_buffer_addr; + uint32_t tile_header_num_msgs; + uint32_t tile_header_buffer_size; // in bytes +}; + +struct stream_builder_spec_t { + uint32_t buffer_size_bytes; + uint32_t tile_header_buffer_size_bytes; +}; + +constexpr uint32_t relay_stream_id = 32; +constexpr uint32_t tile_header_size = 32; // needs to provide noc word alignment +// constexpr uint32_t tile_header_size = 16; +constexpr 
uint32_t noc_word_size = 16; + +// Reads data from input +std::vector get_sender_reader_rt_args( + Device* device, + uint32_t input_buffer_addr, + uint32_t page_size_plus_header, + uint32_t num_messages_to_read, + std::array const& sub_sizes) { + auto args = std::vector{input_buffer_addr, page_size_plus_header, num_messages_to_read}; + for (auto const& sub_size : sub_sizes) { + args.push_back(sub_size); + } + return args; +} +// sender stream data mover kernel +std::vector get_sender_writer_rt_args( + Device* device, + uint32_t num_messages, + uint32_t relay_done_semaphore, + CoreCoord const& relay_core, + uint32_t sender_noc_id, + stream_config_t const& sender_stream_config, + stream_config_t const& relay_stream_config, + CoreCoord const& other_relay_to_notify_when_done, + uint32_t other_relay_done_semaphore, + uint32_t sender_wait_for_receiver_semaphore, + uint32_t first_relay_remote_src_start_phase_addr, + uint32_t hang_toggle_addr) { + return std::vector{ + num_messages, + + relay_stream_id, + sender_stream_config.buffer_addr, + sender_stream_config.buffer_size, + sender_stream_config.tile_header_buffer_addr, + relay_stream_config.tile_header_num_msgs, + + static_cast(device->worker_core_from_logical_core(relay_core).x), + static_cast(device->worker_core_from_logical_core(relay_core).y), + relay_stream_id, + sender_noc_id, + + relay_stream_config.buffer_addr, + relay_stream_config.buffer_size, + relay_stream_config.tile_header_buffer_addr, + + relay_done_semaphore, + static_cast(device->worker_core_from_logical_core(other_relay_to_notify_when_done).x), + static_cast(device->worker_core_from_logical_core(other_relay_to_notify_when_done).y), + other_relay_done_semaphore, + + static_cast(sender_wait_for_receiver_semaphore), + first_relay_remote_src_start_phase_addr, + hang_toggle_addr}; +} + +std::vector get_relay_rt_args( + Device* device, + uint32_t relay_stream_overlay_blob_addr, + uint32_t relay_done_semaphore, + CoreCoord const& sender_core, + CoreCoord 
const& receiver_core, + uint32_t sender_noc_id, + uint32_t receiver_noc_id, + // stream_config_t const& sender_stream_config, + stream_config_t const& relay_stream_config, + stream_config_t const& receiver_stream_config, + uint32_t remote_src_start_phase_addr, + uint32_t dest_remote_src_start_phase_addr, + bool is_first_relay_in_chain) { + return std::vector{ + static_cast(relay_stream_overlay_blob_addr), + static_cast(relay_stream_id), + static_cast(relay_stream_config.buffer_addr), + static_cast(relay_stream_config.buffer_size), + static_cast(relay_stream_config.tile_header_buffer_addr), + static_cast(relay_stream_config.tile_header_num_msgs), + + // noc0 address + static_cast(device->worker_core_from_logical_core(sender_core).x), + static_cast(device->worker_core_from_logical_core(sender_core).y), + static_cast(relay_stream_id), + static_cast(sender_noc_id), + + static_cast(device->worker_core_from_logical_core(receiver_core).x), + static_cast(device->worker_core_from_logical_core(receiver_core).y), + static_cast(relay_stream_id), + static_cast(receiver_noc_id), + static_cast(receiver_stream_config.buffer_addr), + static_cast(receiver_stream_config.buffer_size), + static_cast(receiver_stream_config.tile_header_buffer_addr), + + static_cast(relay_done_semaphore), + static_cast(is_first_relay_in_chain ? 
1 : 0), + + remote_src_start_phase_addr, + dest_remote_src_start_phase_addr}; +} + +// Receiver stream data mover kernel +std::vector get_receiver_reader_rt_args( + Device* device, + uint32_t num_messages, + uint32_t relay_done_semaphore, + CoreCoord const& relay_core, + uint32_t receiver_noc_id, + stream_config_t const& relay_stream_config, + stream_config_t const& receiver_stream_config, + CoreCoord const& other_relay_core_to_notify_when_done, + uint32_t other_relay_done_semaphore, + CoreCoord const& sender_core, + uint32_t sender_receiver_semaphore, + uint32_t remote_src_start_phase_addr) { + return std::vector{ + static_cast(num_messages), + static_cast(relay_stream_id), + static_cast(receiver_stream_config.buffer_addr), + static_cast(receiver_stream_config.buffer_size), + static_cast(receiver_stream_config.tile_header_buffer_addr), + static_cast(receiver_stream_config.tile_header_num_msgs), + static_cast(device->worker_core_from_logical_core(relay_core).x), + static_cast(device->worker_core_from_logical_core(relay_core).y), + static_cast(relay_stream_id), + static_cast(receiver_noc_id), + static_cast(relay_stream_config.buffer_addr), + static_cast(relay_stream_config.buffer_size), + static_cast(relay_stream_config.tile_header_buffer_addr), + + static_cast(relay_done_semaphore), + static_cast(device->worker_core_from_logical_core(other_relay_core_to_notify_when_done).x), + static_cast(device->worker_core_from_logical_core(other_relay_core_to_notify_when_done).y), + other_relay_done_semaphore, + + static_cast(device->worker_core_from_logical_core(sender_core).x), + static_cast(device->worker_core_from_logical_core(sender_core).y), + sender_receiver_semaphore, + remote_src_start_phase_addr}; +} +std::vector get_receiver_writer_rt_args( + Device* device, uint32_t output_buffer_addr, uint32_t page_size, uint32_t num_messages_to_read) { + return std::vector{output_buffer_addr, page_size, num_messages_to_read}; +} + +// TODO: randomize each noc for testing purposes 
+void build_and_run_autonomous_stream_test( + std::vector& programs, + std::vector const& devices, + std::size_t num_messages, + std::size_t page_size, + uint32_t tile_header_buffer_num_messages, + stream_builder_spec_t const& sender_stream_spec, + stream_builder_spec_t const& relay_stream_spec, + stream_builder_spec_t const& receiver_stream_spec, + bool enable_page_size_variations, + std::array const& sub_sizes, + std::size_t num_loop_iterations) { + TT_ASSERT(programs.size() == 0); + // Make configurable + const uint32_t read_write_cb_num_pages = 8; + const uint32_t page_size_plus_header = page_size + tile_header_size; + + const uint32_t sender_stream_buffer_num_pages = sender_stream_spec.buffer_size_bytes / page_size; + const uint32_t relay_stream_buffer_num_pages = relay_stream_spec.buffer_size_bytes / page_size; + const uint32_t receiver_stream_buffer_num_pages = receiver_stream_spec.buffer_size_bytes / page_size; + + const uint32_t sender_stream_buffer_size_bytes = sender_stream_buffer_num_pages * page_size_plus_header; + const uint32_t relay_stream_buffer_size_bytes = relay_stream_buffer_num_pages * page_size_plus_header; + const uint32_t receiver_stream_buffer_size_bytes = receiver_stream_buffer_num_pages * page_size_plus_header; + uint32_t stream_tile_header_buffer_size_bytes = tile_header_buffer_num_messages * tile_header_size; + uint32_t relay_stream_overlay_blob_size_bytes = 256; + + programs.emplace_back(); + Device* device = devices.at(0); + Program& program = programs.at(0); + log_trace(tt::LogTest, "Device ID: {}", device->id()); + + CoreCoord sender_core = CoreCoord(0, 0); + CoreCoord first_relay_core = CoreCoord(1, 0); + CoreCoord second_relay_core = CoreCoord(2, 0); + CoreCoord receiver_core = CoreCoord(3, 0); + + log_trace( + tt::LogTest, + "sender_core: x={}, y={}", + device->physical_core_from_logical_core(sender_core, CoreType::WORKER).x, + device->physical_core_from_logical_core(sender_core, CoreType::WORKER).y); + log_trace( + tt::LogTest, 
+ "first_relay_core: x={}, y={}", + device->physical_core_from_logical_core(first_relay_core, CoreType::WORKER).x, + device->physical_core_from_logical_core(first_relay_core, CoreType::WORKER).y); + log_trace( + tt::LogTest, + "second_relay_core: x={}, y={}", + device->physical_core_from_logical_core(second_relay_core, CoreType::WORKER).x, + device->physical_core_from_logical_core(second_relay_core, CoreType::WORKER).y); + log_trace( + tt::LogTest, + "receiver_core: x={}, y={}", + device->physical_core_from_logical_core(receiver_core, CoreType::WORKER).x, + device->physical_core_from_logical_core(receiver_core, CoreType::WORKER).y); + + // Input DRAM buffer creation + uint32_t buffer_size_bytes = num_messages * page_size; + auto inputs = test_utils::generate_uniform_random_vector(0, 100, buffer_size_bytes / sizeof(uint32_t)); + std::iota(inputs.begin(), inputs.end(), 1); + // for (auto i = 0; i < inputs.size(); i += page_size) { + // for (auto ii = 0; ii < std::min(page_size, inputs.size() - i); ii++) { + // inputs.at(i + ii) = i + 1; + // } + // } + + auto zeroes_buffer = std::vector(buffer_size_bytes / sizeof(uint32_t), 0); + std::vector outputs(buffer_size_bytes / sizeof(uint32_t), 0); + log_trace(tt::LogTest, "outputs.size(): {}", outputs.size()); + log_trace(tt::LogTest, "inputs.size(): {}", inputs.size()); + auto input_buffer = CreateBuffer( + InterleavedBufferConfig{device, static_cast(num_messages * page_size), page_size, BufferType::DRAM}); + auto output_buffer = CreateBuffer( + InterleavedBufferConfig{device, static_cast(num_messages * page_size), page_size, BufferType::DRAM}); + + tt_metal::EnqueueWriteBuffer(device->command_queue(), input_buffer, inputs, false); + // Explicitly overwrite to 0 in case of left over state from prior run(s) + tt_metal::EnqueueWriteBuffer(device->command_queue(), output_buffer, zeroes_buffer, true); + const uint32_t dram_input_buf_base_addr = input_buffer->address(); + + // For overlay blob on relay core + constexpr uint32_t 
dummy_cb_index3 = CB::c_in3; + auto const& relay_stream_overlay_blob_buffer_cb_config = + tt_metal::CircularBufferConfig( + relay_stream_overlay_blob_size_bytes, {{dummy_cb_index3, tt::DataFormat::Float16_b}}) + .set_page_size(dummy_cb_index3, relay_stream_overlay_blob_size_bytes); + auto first_relay_stream_overlay_blob_cb = + CreateCircularBuffer(program, first_relay_core, relay_stream_overlay_blob_buffer_cb_config); + auto second_relay_stream_overlay_blob_cb = + CreateCircularBuffer(program, second_relay_core, relay_stream_overlay_blob_buffer_cb_config); + + // Sender/Receiver CBs for pulling in/pushing out stimulus data taht we can output compare + constexpr uint32_t cb_index = CB::c_in0; + const uint32_t cb_size = page_size_plus_header * read_write_cb_num_pages; + auto const& cb_config = tt_metal::CircularBufferConfig(cb_size, {{cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(cb_index, page_size_plus_header); + auto sender_cb = CreateCircularBuffer(program, sender_core, cb_config); + auto receiver_cb = CreateCircularBuffer(program, receiver_core, cb_config); + + // Stream Tile Header Buffers + constexpr uint32_t dummy_cb_index2 = CB::c_in2; + auto const& stream_tile_header_buffer_cb_config = + tt_metal::CircularBufferConfig( + stream_tile_header_buffer_size_bytes, {{dummy_cb_index2, tt::DataFormat::Float16_b}}) + .set_page_size(dummy_cb_index2, stream_tile_header_buffer_size_bytes); + auto sender_stream_tile_header_buffer_cb = + CreateCircularBuffer(program, sender_core, stream_tile_header_buffer_cb_config); + auto first_relay_stream_tile_header_buffer_cb = + CreateCircularBuffer(program, first_relay_core, stream_tile_header_buffer_cb_config); + auto second_relay_stream_tile_header_buffer_cb = + CreateCircularBuffer(program, second_relay_core, stream_tile_header_buffer_cb_config); + auto receiver_stream_tile_header_buffer_cb = + CreateCircularBuffer(program, receiver_core, stream_tile_header_buffer_cb_config); + + constexpr uint32_t dummy_cb_index = 
CB::c_in1; + auto const& sender_stream_buffer_cb_config = + tt_metal::CircularBufferConfig(sender_stream_buffer_size_bytes, {{dummy_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(dummy_cb_index, sender_stream_buffer_size_bytes); + auto const& relay_stream_buffer_cb_config = + tt_metal::CircularBufferConfig(relay_stream_buffer_size_bytes, {{dummy_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(dummy_cb_index, relay_stream_buffer_size_bytes); + auto const& receiver_stream_buffer_cb_config = + tt_metal::CircularBufferConfig(receiver_stream_buffer_size_bytes, {{dummy_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(dummy_cb_index, receiver_stream_buffer_size_bytes); + auto sender_stream_buffer_cb = CreateCircularBuffer(program, sender_core, sender_stream_buffer_cb_config); + auto first_relay_stream_buffer_cb = CreateCircularBuffer(program, first_relay_core, relay_stream_buffer_cb_config); + auto second_relay_stream_buffer_cb = + CreateCircularBuffer(program, second_relay_core, relay_stream_buffer_cb_config); + auto receiver_stream_buffer_cb = CreateCircularBuffer(program, receiver_core, receiver_stream_buffer_cb_config); + + program.allocate_circular_buffers(); + + uint32_t sender_stream_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, sender_stream_buffer_cb)->address(); + uint32_t first_relay_stream_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, first_relay_stream_buffer_cb)->address(); + uint32_t second_relay_stream_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, second_relay_stream_buffer_cb)->address(); + uint32_t receiver_stream_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, receiver_stream_buffer_cb)->address(); + uint32_t sender_stream_tile_header_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, sender_stream_tile_header_buffer_cb)->address(); + uint32_t first_relay_stream_tile_header_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, 
first_relay_stream_tile_header_buffer_cb)->address(); + uint32_t second_relay_stream_tile_header_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, second_relay_stream_tile_header_buffer_cb)->address(); + uint32_t receiver_stream_tile_header_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, receiver_stream_tile_header_buffer_cb)->address(); + uint32_t first_relay_stream_overlay_blob_addr = + tt_metal::detail::GetCircularBuffer(program, first_relay_stream_overlay_blob_cb)->address(); + uint32_t second_relay_stream_overlay_blob_addr = + tt_metal::detail::GetCircularBuffer(program, second_relay_stream_overlay_blob_cb)->address(); + + uint32_t receiver_cb_address = tt_metal::detail::GetCircularBuffer(program, receiver_cb)->address(); + log_trace(tt::LogTest, "receiver_cb_address: {}", receiver_cb_address); + + TT_ASSERT(sender_stream_buffer_size_bytes % page_size_plus_header == 0); + TT_ASSERT(relay_stream_buffer_size_bytes % page_size_plus_header == 0); + TT_ASSERT(receiver_stream_buffer_size_bytes % page_size_plus_header == 0); + log_trace( + tt::LogTest, "first_relay_stream_tile_header_buffer_addr: {}", first_relay_stream_tile_header_buffer_addr); + log_trace( + tt::LogTest, "second_relay_stream_tile_header_buffer_addr: {}", second_relay_stream_tile_header_buffer_addr); + stream_config_t sender_stream_config = stream_config_t{ + sender_stream_buffer_addr, + sender_stream_buffer_size_bytes, + sender_stream_tile_header_buffer_addr, + tile_header_buffer_num_messages, + stream_tile_header_buffer_size_bytes}; + stream_config_t first_relay_stream_config = stream_config_t{ + first_relay_stream_buffer_addr, + relay_stream_buffer_size_bytes, + first_relay_stream_tile_header_buffer_addr, + tile_header_buffer_num_messages, + stream_tile_header_buffer_size_bytes}; + stream_config_t second_relay_stream_config = stream_config_t{ + second_relay_stream_buffer_addr, + relay_stream_buffer_size_bytes, + second_relay_stream_tile_header_buffer_addr, + 
tile_header_buffer_num_messages, + stream_tile_header_buffer_size_bytes}; + stream_config_t receiver_stream_config = stream_config_t{ + receiver_stream_buffer_addr, + receiver_stream_buffer_size_bytes, + receiver_stream_tile_header_buffer_addr, + tile_header_buffer_num_messages, + stream_tile_header_buffer_size_bytes}; + + uint32_t sender_receiver_semaphore_sender = CreateSemaphore(program, sender_core, 0, CoreType::WORKER); + uint32_t remote_sender_hang_toggle_addr = CreateSemaphore(program, sender_core, 0, CoreType::WORKER); + uint32_t first_relay_done_semaphore = CreateSemaphore(program, first_relay_core, 0, CoreType::WORKER); + uint32_t second_relay_done_semaphore = CreateSemaphore(program, second_relay_core, 0, CoreType::WORKER); + + uint32_t first_relay_remote_src_start_phase_addr = CreateSemaphore(program, first_relay_core, 0, CoreType::WORKER); + uint32_t second_relay_remote_src_start_phase_addr = + CreateSemaphore(program, second_relay_core, 0, CoreType::WORKER); + uint32_t receiver_remote_src_start_phase_addr = CreateSemaphore(program, receiver_core, 0, CoreType::WORKER); + + auto sender_noc_id = tt_metal::NOC::NOC_0; + auto relay_to_relay_data_noc_id = tt_metal::NOC::NOC_0; + // remote deceiver doesn't handshake properly with noc_1 + auto receiver_noc_id = tt_metal::NOC::NOC_0; + std::vector const& sender_reader_rt_args = + get_sender_reader_rt_args(device, input_buffer->address(), page_size_plus_header, num_messages, sub_sizes); + std::vector const& sender_writer_rt_args = get_sender_writer_rt_args( + device, + num_messages, + first_relay_done_semaphore, + first_relay_core, + sender_noc_id, + sender_stream_config, + first_relay_stream_config, + second_relay_core, + second_relay_done_semaphore, + sender_receiver_semaphore_sender, + first_relay_remote_src_start_phase_addr, + remote_sender_hang_toggle_addr); + + log_trace(tt::LogTest, "first_relay_stream_config"); + log_trace(tt::LogTest, "\tfirst_relay_stream_config.buffer_addr: {}", 
first_relay_stream_config.buffer_addr); + log_trace(tt::LogTest, "\tfirst_relay_stream_config.buffer_size: {}", first_relay_stream_config.buffer_size); + log_trace( + tt::LogTest, + "\tfirst_relay_stream_config.tile_header_buffer_addr: {}", + first_relay_stream_config.tile_header_buffer_addr); + log_trace( + tt::LogTest, + "\tfirst_relay_stream_config.tile_header_num_msgs: {}", + first_relay_stream_config.tile_header_num_msgs); + log_trace( + tt::LogTest, + "\tfirst_relay_stream_config.tile_header_buffer_size: {}", + first_relay_stream_config.tile_header_buffer_size); + log_trace(tt::LogTest, "second_relay_stream_config"); + log_trace(tt::LogTest, "\tsecond_relay_stream_config.buffer_addr: {}", second_relay_stream_config.buffer_addr); + log_trace(tt::LogTest, "\tsecond_relay_stream_config.buffer_size: {}", second_relay_stream_config.buffer_size); + log_trace( + tt::LogTest, + "\tsecond_relay_stream_config.tile_header_buffer_addr: {}", + second_relay_stream_config.tile_header_buffer_addr); + log_trace( + tt::LogTest, + "\tsecond_relay_stream_config.tile_header_num_msgs: {}", + second_relay_stream_config.tile_header_num_msgs); + log_trace( + tt::LogTest, + "\tsecond_relay_stream_config.tile_header_buffer_size: {}", + second_relay_stream_config.tile_header_buffer_size); + + // Need to figure out the noc IDs between the first and second relay. 
Also double check the + std::vector const first_relay_rt_args = get_relay_rt_args( + device, + first_relay_stream_overlay_blob_addr, + first_relay_done_semaphore, + sender_core, + second_relay_core, + sender_noc_id, + relay_to_relay_data_noc_id, + /*sender_stream_config,*/ first_relay_stream_config, + second_relay_stream_config, + first_relay_remote_src_start_phase_addr, + second_relay_remote_src_start_phase_addr, + true); + std::vector const second_relay_rt_args = get_relay_rt_args( + device, + second_relay_stream_overlay_blob_addr, + second_relay_done_semaphore, + first_relay_core, + receiver_core, + relay_to_relay_data_noc_id, + receiver_noc_id, + /*first_relay_stream_config,*/ second_relay_stream_config, + receiver_stream_config, + second_relay_remote_src_start_phase_addr, + receiver_remote_src_start_phase_addr, + false); + + std::vector const& receiver_reader_rt_args = get_receiver_reader_rt_args( + device, + num_messages, + second_relay_done_semaphore, + second_relay_core, + receiver_noc_id, + second_relay_stream_config, + receiver_stream_config, + first_relay_core, + first_relay_done_semaphore, + sender_core, + sender_receiver_semaphore_sender, + receiver_remote_src_start_phase_addr); + std::vector const& receiver_writer_rt_args = + get_receiver_writer_rt_args(device, output_buffer->address(), page_size_plus_header, num_messages); + + auto sender_reader_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender_reader.cpp", + sender_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = tt_metal::NOC::NOC_0, + .compile_args = {tile_header_size, static_cast(enable_page_size_variations ? 
1 : 0)}}); + auto sender_writer_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp", + sender_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::NOC_1, // to keep noc coords simple (no calculating noc1 coords) + .compile_args = {}}); + + auto first_relay_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp", + first_relay_core, + tt_metal::DataMovementConfig{.noc = tt_metal::NOC::NOC_0, .compile_args = {}}); + + auto second_relay_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp", + second_relay_core, + tt_metal::DataMovementConfig{.noc = tt_metal::NOC::NOC_0, .compile_args = {}}); + + auto receiver_reader_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp", + receiver_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::NOC_0, .compile_args = {}}); + auto receiver_writer_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp", + receiver_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::NOC_1, // to keep noc coords simple (no calculating noc1 coords) + .compile_args = {tile_header_size}}); + + log_trace(tt::LogTest, "sender_reader_rt_args: "); + for (auto const& arg : sender_reader_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, sender_reader_kernel, sender_core, sender_reader_rt_args); + + log_trace(tt::LogTest, "sender_writer_rt_args: "); + for (auto const& arg : sender_writer_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + 
tt_metal::SetRuntimeArgs(program, sender_writer_kernel, sender_core, sender_writer_rt_args); + + log_trace(tt::LogTest, "first_relay_rt_args: "); + for (auto const& arg : first_relay_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, first_relay_kernel, first_relay_core, first_relay_rt_args); + + log_trace(tt::LogTest, "second_relay_rt_args: "); + for (auto const& arg : second_relay_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, second_relay_kernel, second_relay_core, second_relay_rt_args); + + log_trace(tt::LogTest, "receiver_reader_rt_args: "); + for (auto const& arg : receiver_reader_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, receiver_reader_kernel, receiver_core, receiver_reader_rt_args); + + log_trace(tt::LogTest, "receiver_writer_rt_args: "); + for (auto const& arg : receiver_writer_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, receiver_writer_kernel, receiver_core, receiver_writer_rt_args); + + tt::tt_metal::detail::CompileProgram(device, program); + for (std::size_t i = 0; i < num_loop_iterations; i++) { + log_debug(tt::LogTest, "Enqueing Program"); + tt_metal::EnqueueProgram(device->command_queue(), program, true); + log_debug(tt::LogTest, "Calling Finish"); + tt_metal::Finish(device->command_queue()); + if (i == 0) { + log_debug(tt::LogTest, "Reading Output Buffer"); + tt_metal::EnqueueReadBuffer(device->command_queue(), output_buffer, outputs, true); + } + } + + log_debug(tt::LogTest, "outputs.size(): {}", outputs.size()); + log_debug(tt::LogTest, "inputs.size(): {}", inputs.size()); + log_debug(tt::LogTest, "Comparing Outputs"); + TT_ASSERT(inputs.size() == outputs.size()); + if (enable_page_size_variations) { + uint32_t page_size_words = page_size / sizeof(uint32_t); + bool matches = true; + std::size_t size = outputs.size(); + uint32_t sub_size_i = 0; + uint32_t page_idx = 0; + for 
(auto i = 0; i < size; i += page_size_words) { + std::size_t n_elems = page_size_words - (sub_sizes.at(sub_size_i) * noc_word_size / sizeof(uint32_t)); + sub_size_i = (sub_size_i + 1) % num_sizes; + bool printed_page_info = false; + for (auto ii = 0; ii < n_elems; ii++) { + bool match = outputs.at(i + ii) == inputs.at(i + ii); + if (!match) { + if (!printed_page_info) { + printed_page_info = true; + log_error(tt::LogTest, "Output Mismatch"); + } + log_trace( + tt::LogTest, + "Mismatch at index {}: {} (expected) != {} (actual)", + i + ii, + inputs.at(i + ii), + outputs.at(i + ii)); + matches = false; + } + } + page_idx++; + } + TT_ASSERT(matches); + } else { + bool matches = true; + bool printed = false; + TT_ASSERT(inputs.size() == outputs.size()); + for (std::size_t i = 0; i < inputs.size(); i++) { + if (inputs.at(i) != outputs.at(i)) { + if (!printed) { + log_error(tt::LogTest, "Output Mismatch"); + printed = true; + } + matches = false; + log_trace( + tt::LogTest, "Mismatch at index {}: {} (expected) != {} (actual)", i, inputs.at(i), outputs.at(i)); + } + } + TT_ASSERT(matches); + } +} + +} // namespace tt_metal + +} // namespace tt + +TEST_F(CommandQueueFixture, TestAutonomousRelayStreams) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 10; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 4096; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + 
tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + std::array sub_sizes = std::array{0, 3, 4, 7, 0, 2, 10, 1}; + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + + return; +} + +TEST_F(CommandQueueFixture, TestAutonomousRelayStreamsSmallPackets) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 10; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 128; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + std::array sub_sizes = std::array{0, 3, 4, 7, 0, 2, 5, 1}; + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + 
enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + + return; +} + +TEST_F(CommandQueueFixture, TestAutonomousRelayStreamsLoopingShort) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 50; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 4096; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + std::array sub_sizes = std::array{0, 3, 4, 7, 0, 2, 10, 1}; + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + + return; +} + +// Too long to run in post commit and these kernels are currently only live in these unit tests anyways +// so we just enable a couple of the unit tests to ensure nobody accidentally introduces compile errors +// or anything like that +TEST_F(CommandQueueFixture, DISABLED_TestAutonomousRelayStreamsLoopingRandomShort) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); 
+ // if (num_devices != 8) { + // log_info(tt::LogTest, "Need at least 2 devices to run this test"); + // return; + // } + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 500; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 4096; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + for (std::size_t i = 0; i < num_loop_iterations; i++) { + std::array sub_sizes = {}; + for (auto i = 0; i < num_sizes; i++) { + sub_sizes.at(i) = std::rand() % (page_size / noc_word_size); + EXPECT_TRUE(sub_sizes.at(i) < (page_size / noc_word_size)); + } + std::vector programs; + log_info(tt::LogTest, "Iteration: {}", i); + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + 1); + } + return; +} + +// Too long to run in post commit and these kernels are currently only live in these unit tests anyways +// so we just enable a couple of the unit tests to ensure nobody accidentally introduces compile errors +// or anything like that +TEST_F(CommandQueueFixture, DISABLED_TestAutonomousRelayStreamsLoopingLong) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = 
tt::tt_metal::GetNumAvailableDevices(); + // if (num_devices != 8) { + // log_info(tt::LogTest, "Need at least 2 devices to run this test"); + // return; + // } + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 1'000; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 4096; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + std::array sub_sizes = std::array{0, 3, 4, 7, 0, 2, 10, 1}; + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + + return; +} + +// Too long to run in post commit and these kernels are currently only live in these unit tests anyways +// so we just enable a couple of the unit tests to ensure nobody accidentally introduces compile errors +// or anything like that +TEST_F(CommandQueueFixture, DISABLED_TestAutonomousRelayStreamsSweep) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + + // Create array of size `num_sizes` of 
random integers using c++ random + std::array sub_sizes_global = {}; + std::srand(0); + for (auto i = 0; i < num_sizes; i++) { + sub_sizes_global.at(i) = std::rand(); + } + + uint32_t num_loop_iterations = 10; + std::vector message_counts = {1'000'000}; + std::vector fw_stream_buffer_sizes = {2 * 1024, 8 * 1024, 16 * 1024, 32 * 1024}; + std::vector relay_stream_buffer_sizes = {8 * 1024, 16 * 1024, 24 * 1024}; + std::vector phase_message_counts = { + // 32, // Hangs on handshake on phase range wrap, or 25th run, whichever comes first + // 64, // Hangs on handshake on phase range wrap, or 25th run, whichever comes first + 128, // works with 16KB buffer + 256, // works with 16KB buffer + 1024 // works with 16KB buffer + }; + // std::vector page_size = {2048, 4096}; + std::vector page_size = {4096}; + for (auto num_messages : message_counts) { + for (auto fw_stream_buffer_size : fw_stream_buffer_sizes) { + for (auto relay_stream_buffer_size : relay_stream_buffer_sizes) { + // auto fw_stream_buffer_size = relay_stream_buffer_size; + for (auto tile_header_buffer_num_messages : phase_message_counts) { + for (auto page_size : page_size) { + if (page_size > fw_stream_buffer_size) { + continue; + } + if (page_size > relay_stream_buffer_size) { + continue; + } + uint32_t enable_variable_sized_messages = 1; + + log_info( + tt::LogTest, + "num_messages: {}, fw_stream_buffer_size: {}, relay_stream_buffer_size: {}, " + "tile_header_buffer_num_messages: {}, page_size: {}, enable_variable_sized_messages: {}", + num_messages, + fw_stream_buffer_size, + relay_stream_buffer_size, + tile_header_buffer_num_messages, + page_size, + enable_variable_sized_messages); + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{fw_stream_buffer_size, tile_header_buffer_num_messages}; + auto relay_stream_spec = tt::tt_metal::stream_builder_spec_t{ + relay_stream_buffer_size, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + 
tt::tt_metal::stream_builder_spec_t{fw_stream_buffer_size, tile_header_buffer_num_messages}; + + std::array sub_sizes = {}; + for (auto i = 0; i < num_sizes; i++) { + sub_sizes.at(i) = sub_sizes_global.at(i) % (page_size / noc_word_size); + EXPECT_TRUE(sub_sizes.at(i) < (page_size / noc_word_size)); + } + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + } + } + } + } + } + + return; +} diff --git a/tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h b/tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h index 3a70066d9af..4ba2e33be1c 100644 --- a/tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h +++ b/tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h @@ -414,6 +414,7 @@ // Set when stream is in data forwarding state. #define MSG_FWD_ONGOING (WAIT_PREV_PHASE_DATA_FLUSH+WAIT_PREV_PHASE_DATA_FLUSH_WIDTH) #define MSG_FWD_ONGOING_WIDTH 1 +// 0 is idle. 1/2 is auto cfg. 3 is waiting for phase advance. 4 is waiting for data send. 5 is phase active #define STREAM_CURR_STATE (MSG_FWD_ONGOING+MSG_FWD_ONGOING_WIDTH) #define STREAM_CURR_STATE_WIDTH 4