diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp new file mode 100644 index 00000000000..0df88172c89 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "dataflow_api.h" +#include "stream_interface.h" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +struct stream_state_t { + const uint32_t local_data_buffer_base_address; + const uint32_t local_msg_info_ptr_base_address; + + uint32_t local_phase_id; + uint32_t messages_per_phase; + uint32_t msg_info_wrptr_addr; + + uint32_t num_tiles_sent; + uint32_t tile_header_num_msgs; + + uint32_t local_buffer_base_addr; + uint32_t local_buffer_size; + uint32_t local_msg_info_ptr; + uint32_t local_buffer_read_offset; + + uint32_t remote_buffer_base_addr; + uint32_t remote_buffer_size; + uint32_t remote_msg_info_ptr; + uint32_t remote_buffer_write_offset; + + uint32_t remote_phase_id; + + uint32_t get_current_local_buffer_address() const { + return local_data_buffer_base_address + local_buffer_read_offset; + } +}; + +struct phase_iterator_t { + phase_iterator_t(uint32_t start_phase, uint32_t max_phase) : + phase_id(start_phase), max_phase(max_phase), start_phase(start_phase) {} + uint32_t phase_id; + uint32_t max_phase; + uint32_t start_phase; + + FORCE_INLINE uint32_t get() const { return phase_id; } + + FORCE_INLINE void increment() { phase_id = phase_id == max_phase ? 
start_phase : phase_id + 1; } +}; + +struct noc_endpoint_info_t { + uint32_t data_noc_id; + uint32_t update_noc_id; + uint32_t noc_x; + uint32_t noc_y; +}; + +#define STREAM_CFG(field, val) ((val) << (field)) + +#define AUTO_CFG_HEADER(next_phase_num_cfg_reg_writes, curr_phase_num_msgs, phase_num_incr) \ + ((uint32_t)(((next_phase_num_cfg_reg_writes) << 24) | ((curr_phase_num_msgs) << 12) | (phase_num_incr))) + +#define STREAM_REMOTE_DEST(dest_x, dest_y, dest_stream_id) \ + (((dest_x) << STREAM_REMOTE_DEST_X) | ((dest_y) << STREAM_REMOTE_DEST_Y) | \ + ((dest_stream_id) << STREAM_REMOTE_DEST_STREAM_ID)) + +#define STREAM_REMOTE_SRC(src_x, src_y, src_stream_id) \ + (((src_x) << STREAM_REMOTE_SRC_X) | ((src_y) << STREAM_REMOTE_SRC_Y) | ((src_stream_id) << REMOTE_SRC_STREAM_ID)) + +FORCE_INLINE uint32_t +blob_header_dw(uint32_t next_phase_num_cfg_reg_writes, uint32_t curr_phase_num_msgs, uint32_t phase_num_incr) { + return (next_phase_num_cfg_reg_writes << 24) | (curr_phase_num_msgs << 12) | phase_num_incr; +} + +FORCE_INLINE void stream_phase_blob_run( + uint32_t stream_id, volatile uint32_t *blob_start_addr, uint32_t start_phase_num_cfg_regs) { + NOC_STREAM_WRITE_REG(stream_id, STREAM_PHASE_AUTO_CFG_PTR_REG_INDEX, reinterpret_cast(blob_start_addr)); + NOC_STREAM_WRITE_REG( + stream_id, STREAM_PHASE_AUTO_CFG_HEADER_REG_INDEX, start_phase_num_cfg_regs << NEXT_PHASE_NUM_CFG_REG_WRITES); + NOC_STREAM_WRITE_REG( + stream_id, + STREAM_MISC_CFG_REG_INDEX, + (0x1 << PHASE_AUTO_CONFIG) | (1 << NEXT_PHASE_SRC_CHANGE) | (1 << NEXT_PHASE_DEST_CHANGE)); +} +FORCE_INLINE void stream_phase_blob_run( + uint32_t stream_id, + volatile uint32_t *blob_start_addr, + uint32_t num_messages_per_phase, + uint32_t start_phase_num_cfg_regs) { + NOC_STREAM_WRITE_REG(stream_id, STREAM_PHASE_AUTO_CFG_PTR_REG_INDEX, reinterpret_cast(blob_start_addr)); + + NOC_STREAM_WRITE_REG( + stream_id, + STREAM_PHASE_AUTO_CFG_HEADER_REG_INDEX, + blob_header_dw(start_phase_num_cfg_regs, num_messages_per_phase, 
1)); + NOC_STREAM_WRITE_REG( + stream_id, + STREAM_MISC_CFG_REG_INDEX, + (0x1 << PHASE_AUTO_ADVANCE) | (0x1 << PHASE_AUTO_CONFIG) | (1 << NEXT_PHASE_SRC_CHANGE) | + (1 << NEXT_PHASE_DEST_CHANGE)); + NOC_STREAM_WRITE_REG(stream_id, STREAM_PHASE_ADVANCE_REG_INDEX, 1); +} + +FORCE_INLINE uint32_t blob_cfg_dw(uint32_t reg_index, uint32_t reg_val) { return (reg_val << 8) | reg_index; } + +FORCE_INLINE uint32_t set_blob_reg_field(uint32_t blob_dw, uint32_t field_width, uint32_t field_offset, uint32_t val) { + uint32_t mask = ((1 << field_width) - 1) << field_offset; + return (blob_dw & ~mask) | ((val << field_offset) & mask); +} + +FORCE_INLINE uint32_t get_first_available_phase_out_of_reset(uint32_t stream_id) { + uint32_t stream_phase_coming_out_of_reset = stream_get_curr_phase(stream_id); + return ( + stream_phase_coming_out_of_reset < 4096 ? 4096 : 1); +} + +FORCE_INLINE uint32_t notify_remote_receiver_of_starting_phase( + uint32_t stream_id, uint32_t local_buffer_addr, uint64_t remote_receiver_noc_addr) { + uint32_t starting_phase = get_first_available_phase_out_of_reset(stream_id); + ASSERT(starting_phase > 0); + *reinterpret_cast(local_buffer_addr) = starting_phase; + noc_async_write(local_buffer_addr, remote_receiver_noc_addr, sizeof(uint32_t)); + // noc_semaphore_set_remote(local_buffer_addr, remote_receiver_noc_addr); + noc_async_writes_flushed(); + return starting_phase; +} + +FORCE_INLINE uint32_t wait_for_remote_source_starting_phase(volatile uint32_t *addr) { + while (*addr == 0) { + asm volatile("nop"); + } + return *addr; +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp new file mode 100644 index 00000000000..e6e23c33fa7 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp @@ -0,0 +1,325 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "stream_interface.h" +#include "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +void kernel_main() { + // Work to do before productizable: + // - Test phase advance + // - test > 2k messages (and > 4k messages) + // - Test variable sized messages + // - Test rerun after test completion (without reset) + // - Currently a bug where the phase ID persists from prior run + // + + uint32_t arg_idx = 0; + + uint32_t relay_stream_overlay_blob_addr = get_arg_val(arg_idx++); + uint32_t stream_id = get_arg_val(arg_idx++); + uint32_t stream_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_buffer_size = get_arg_val(arg_idx++); + uint32_t stream_tile_header_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_tile_header_max_num_messages = get_arg_val(arg_idx++); + + uint32_t remote_src_noc_x = get_arg_val(arg_idx++); + uint32_t remote_src_noc_y = get_arg_val(arg_idx++); + uint32_t remote_src_stream_id = get_arg_val(arg_idx++); + uint32_t remote_src_noc_id = get_arg_val(arg_idx++); + + uint32_t remote_dest_noc_x = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_y = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_stream_id = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_id = get_arg_val(arg_idx++); + uint32_t remote_dest_buf_addr = get_arg_val(arg_idx++); + uint32_t remote_dest_buf_size_4B_words = get_arg_val(arg_idx++); + uint32_t remote_dest_tile_header_buffer_addr = get_arg_val(arg_idx++); + volatile uint32_t* tx_rx_done_semaphore_addr = + reinterpret_cast(get_arg_val(arg_idx++)); + bool is_first_relay_stream_in_chain = get_arg_val(arg_idx++) == 1; + + uint32_t remote_src_start_phase_addr = get_arg_val(arg_idx++); + uint32_t dest_remote_src_start_phase_addr = get_arg_val(arg_idx++); + + *tx_rx_done_semaphore_addr = 0; // should already be set to 0, but why 
not... + // use stream_buffer_addr as temporary storage just for this initial setup + + const uint32_t local_first_phase = notify_remote_receiver_of_starting_phase( + stream_id, + stream_buffer_addr + 16, // local storage to hold the phase while async send in progress, 16B for noc alignment + get_noc_addr(remote_dest_noc_x, remote_dest_noc_y, dest_remote_src_start_phase_addr)); + const uint32_t local_second_phase = local_first_phase + 1; + + // If first relay, we'd expect this to be stream_tile_header_max_num_messages + STARTING_PHASE because the + // remote_sender (FW managed) is programmed as one phase per message and there are + // `stream_tile_header_max_num_messages` messages in this stream's phase. If second relay, we'd expect this to be + // SECOND_PHASE + const uint32_t first_phase_remote_src_phase = + wait_for_remote_source_starting_phase(reinterpret_cast(remote_src_start_phase_addr)); + const uint32_t second_phase_remote_src_phase = + is_first_relay_stream_in_chain ? stream_tile_header_max_num_messages + first_phase_remote_src_phase + : first_phase_remote_src_phase + 1; + + // Setup the stream phases + volatile uint32_t* stream_phases_start = reinterpret_cast(relay_stream_overlay_blob_addr); + + // + // phase 1 + // + + const uint32_t stream_phase_1_start = reinterpret_cast(stream_phases_start); + volatile uint32_t* stream_phase_1_reg_addr = reinterpret_cast(stream_phase_1_start) + 1; + + // Local stream buffer address register + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_BUF_START_REG_INDEX, stream_buffer_addr >> 4); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // Local stream buffer size register + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_BUF_SIZE_REG_INDEX, stream_buffer_size >> 4); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // msg info rdptr + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_MSG_INFO_PTR_REG_INDEX, stream_tile_header_buffer_addr >> 4); + stream_phase_1_reg_addr++; + 
*stream_phase_1_reg_addr = 0; + + // msg info wrptr + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_MSG_INFO_WR_PTR_REG_INDEX, stream_tile_header_buffer_addr >> 4); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // Local stream buffer size register + *stream_phase_1_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX, remote_dest_tile_header_buffer_addr >> 4); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // STREAM_MISC_CFG_REG_INDEX + const uint32_t remote_src_update_noc_id = 1 - remote_src_noc_id; + uint32_t stream_msc_cfg_reg = 0; + stream_msc_cfg_reg = + set_blob_reg_field(stream_msc_cfg_reg, INCOMING_DATA_NOC_WIDTH, INCOMING_DATA_NOC, remote_src_noc_id); + stream_msc_cfg_reg = + set_blob_reg_field(stream_msc_cfg_reg, OUTGOING_DATA_NOC_WIDTH, OUTGOING_DATA_NOC, remote_dest_noc_id); + stream_msc_cfg_reg = set_blob_reg_field( + stream_msc_cfg_reg, REMOTE_SRC_UPDATE_NOC_WIDTH, REMOTE_SRC_UPDATE_NOC, remote_src_update_noc_id); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, REMOTE_SOURCE_WIDTH, REMOTE_SOURCE, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, REMOTE_RECEIVER_WIDTH, REMOTE_RECEIVER, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, PHASE_AUTO_CONFIG_WIDTH, PHASE_AUTO_CONFIG, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, PHASE_AUTO_ADVANCE_WIDTH, PHASE_AUTO_ADVANCE, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, DATA_AUTO_SEND_WIDTH, DATA_AUTO_SEND, 1); + stream_msc_cfg_reg = + set_blob_reg_field(stream_msc_cfg_reg, NEXT_PHASE_DEST_CHANGE_WIDTH, NEXT_PHASE_DEST_CHANGE, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, NEXT_PHASE_SRC_CHANGE_WIDTH, NEXT_PHASE_SRC_CHANGE, 1); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, UNICAST_VC_REG_WIDTH, UNICAST_VC_REG, 0); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, REG_UPDATE_VC_REG_WIDTH, REG_UPDATE_VC_REG, 1); + 
stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, DATA_BUF_NO_FLOW_CTRL_WIDTH, DATA_BUF_NO_FLOW_CTRL, 0); + stream_msc_cfg_reg = + set_blob_reg_field(stream_msc_cfg_reg, DEST_DATA_BUF_NO_FLOW_CTRL_WIDTH, DEST_DATA_BUF_NO_FLOW_CTRL, 0); + stream_msc_cfg_reg = set_blob_reg_field(stream_msc_cfg_reg, REMOTE_SRC_IS_MCAST_WIDTH, REMOTE_SRC_IS_MCAST, 0); + stream_msc_cfg_reg = set_blob_reg_field( + stream_msc_cfg_reg, NO_PREV_PHASE_OUTGOING_DATA_FLUSH_WIDTH, NO_PREV_PHASE_OUTGOING_DATA_FLUSH, 0); + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_MISC_CFG_REG_INDEX, stream_msc_cfg_reg); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // remote src + // Remote src noc x/y is based on the update noc (because it sends updates, NOT data, to src, so it needs update + // noc) + uint32_t stream_remote_src_reg = 0; + uint32_t data_noc_in_src_noc_x = + remote_src_update_noc_id == 0 ? remote_src_noc_x : noc_size_x - 1 - remote_src_noc_x; + uint32_t data_noc_in_src_noc_y = + remote_src_update_noc_id == 0 ? remote_src_noc_y : noc_size_y - 1 - remote_src_noc_y; + stream_remote_src_reg = set_blob_reg_field( + stream_remote_src_reg, STREAM_REMOTE_SRC_X_WIDTH, STREAM_REMOTE_SRC_X, data_noc_in_src_noc_x); + stream_remote_src_reg = set_blob_reg_field( + stream_remote_src_reg, STREAM_REMOTE_SRC_Y_WIDTH, STREAM_REMOTE_SRC_Y, data_noc_in_src_noc_y); + stream_remote_src_reg = set_blob_reg_field( + stream_remote_src_reg, REMOTE_SRC_STREAM_ID_WIDTH, REMOTE_SRC_STREAM_ID, remote_src_stream_id); + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_REMOTE_SRC_REG_INDEX, stream_remote_src_reg); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // remote dest + // Remote dest noc x/y is NOT based on the update noc (because it is sending data to the dest, so it needs data noc) + uint32_t stream_remote_dest_reg = 0; + uint32_t data_noc_out_dest_noc_x = remote_dest_noc_id == 0 ? 
remote_dest_noc_x : noc_size_x - 1 - remote_dest_noc_x; + uint32_t data_noc_out_dest_noc_y = remote_dest_noc_id == 0 ? remote_dest_noc_y : noc_size_y - 1 - remote_dest_noc_y; + stream_remote_dest_reg = set_blob_reg_field( + stream_remote_dest_reg, STREAM_REMOTE_DEST_X_WIDTH, STREAM_REMOTE_DEST_X, data_noc_out_dest_noc_x); + stream_remote_dest_reg = set_blob_reg_field( + stream_remote_dest_reg, STREAM_REMOTE_DEST_Y_WIDTH, STREAM_REMOTE_DEST_Y, data_noc_out_dest_noc_y); + stream_remote_dest_reg = set_blob_reg_field( + stream_remote_dest_reg, + STREAM_REMOTE_DEST_STREAM_ID_WIDTH, + STREAM_REMOTE_DEST_STREAM_ID, + remote_dest_noc_stream_id); + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_REMOTE_DEST_REG_INDEX, stream_remote_dest_reg); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // remote_dest buf start + uint32_t stream_remote_dest_buf_start_reg_val = 0; + stream_remote_dest_buf_start_reg_val = set_blob_reg_field( + stream_remote_dest_buf_start_reg_val, + DRAM_WRITES__SCRATCH_1_PTR_LO_WIDTH, + DRAM_WRITES__SCRATCH_1_PTR_LO, + remote_dest_buf_addr >> 4); + *stream_phase_1_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_BUF_START_REG_INDEX, stream_remote_dest_buf_start_reg_val); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // remote_dest buf size + uint32_t stream_remote_dest_buf_size_reg = 0; + stream_remote_dest_buf_size_reg = set_blob_reg_field( + stream_remote_dest_buf_size_reg, + REMOTE_DEST_BUF_SIZE_WORDS_WIDTH, + REMOTE_DEST_BUF_SIZE_WORDS, + remote_dest_buf_size_4B_words >> 4); + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX, stream_remote_dest_buf_size_reg); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_CURR_PHASE_BASE_REG_INDEX, 0); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_REMOTE_SRC_PHASE_REG_INDEX, first_phase_remote_src_phase); + stream_phase_1_reg_addr++; + 
*stream_phase_1_reg_addr = 0; + + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_CURR_PHASE_REG_INDEX, local_first_phase); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + *stream_phase_1_reg_addr = blob_cfg_dw(STREAM_MEM_BUF_SPACE_AVAILABLE_ACK_THRESHOLD_REG_INDEX, 0); + stream_phase_1_reg_addr++; + *stream_phase_1_reg_addr = 0; + + // + // phase 2 - we're unrolling one iteration of the first phase, so the second phase is mostly identical + // + volatile uint32_t* const stream_phase_2_start = stream_phase_1_reg_addr; + volatile uint32_t* stream_phase_2_stream_reg_addr = stream_phase_2_start + 1; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_BUF_START_REG_INDEX, stream_buffer_addr >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + // Local stream buffer size register + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_BUF_SIZE_REG_INDEX, stream_buffer_size >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + // msg info rdptr + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_MSG_INFO_PTR_REG_INDEX, stream_tile_header_buffer_addr >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + // msg info wrptr + *stream_phase_2_stream_reg_addr = + blob_cfg_dw(STREAM_MSG_INFO_WR_PTR_REG_INDEX, stream_tile_header_buffer_addr >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX, remote_dest_tile_header_buffer_addr >> 4); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_MISC_CFG_REG_INDEX, stream_msc_cfg_reg); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_REMOTE_SRC_REG_INDEX, stream_remote_src_reg); + stream_phase_2_stream_reg_addr++; + 
*stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_REMOTE_DEST_REG_INDEX, stream_remote_dest_reg); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_BUF_START_REG_INDEX, stream_remote_dest_buf_start_reg_val); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = + blob_cfg_dw(STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX, stream_remote_dest_buf_size_reg); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_CURR_PHASE_BASE_REG_INDEX, 0); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_CURR_PHASE_REG_INDEX, local_second_phase); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_MEM_BUF_SPACE_AVAILABLE_ACK_THRESHOLD_REG_INDEX, 0); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_PHASE_AUTO_CFG_PTR_BASE_REG_INDEX, 0); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_REMOTE_SRC_PHASE_REG_INDEX, second_phase_remote_src_phase); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + *stream_phase_2_stream_reg_addr = blob_cfg_dw(STREAM_PHASE_AUTO_CFG_PTR_REG_INDEX, stream_phase_1_start); + stream_phase_2_stream_reg_addr++; + *stream_phase_2_stream_reg_addr = 0; + + const uint32_t phase_1_num_cfg_regs = + ((reinterpret_cast(stream_phase_1_reg_addr) >> 2) - (stream_phase_1_start >> 2)) - 1; + uint32_t phase_2_num_cfg_regs = ((reinterpret_cast(stream_phase_2_stream_reg_addr) >> 2) - + (reinterpret_cast(stream_phase_2_start) >> 2)) - + 1; + + // We're supposed to 
put the **next** phase num config registers in the **current** phase's blob header. This means + // we need to flip the register counts between the two phases for their headers So in a sequence of 3 phases, the + // header blob on phase 1 would need the #cfg regs for phase 2. Phase 2's cfg header blob would need the #cfg regs + // for phase 3 and for phase 3, the #cfg regs in the header blob would be 0 (since no phase follows it) In our case, + // we just need to point to the opposite phase's #cfg regs + *reinterpret_cast(stream_phase_1_start) = + blob_header_dw(phase_2_num_cfg_regs, stream_tile_header_max_num_messages, 1); + *stream_phase_2_start = blob_header_dw(phase_1_num_cfg_regs, stream_tile_header_max_num_messages, 1); + + // Now kick off the stream + stream_phase_blob_run( + stream_id, + reinterpret_cast(stream_phase_1_start), + stream_tile_header_max_num_messages, + phase_1_num_cfg_regs); + + // Wait for sender and receiver to signal completion + while (*tx_rx_done_semaphore_addr != 2) { + asm volatile("nop"); + } + + // Now teardown the stream + // Unknown if it's safe to reset the stream while it's in a state before active + while ((NOC_STREAM_READ_REG(stream_id, STREAM_DEBUG_STATUS_REG_INDEX + 9) >> MEM_WORD_ADDR_WIDTH) != 0 || + !stream_phase_is_active(stream_id)) { + asm volatile("nop"); + } + + stream_reset(stream_id); + ASSERT(!assert_check(stream_id, false)); + for (auto ptr = reinterpret_cast(stream_phase_1_start); ptr != stream_phase_2_stream_reg_addr; + ptr++) { + *ptr = 0; + } +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp new file mode 100644 index 00000000000..a21474a048a --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp @@ -0,0 +1,272 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include + +#include "dataflow_api.h" +#include "stream_interface.h" +#include "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +// THESE TWO FUNCTIONS WERE ONLY VALID FOR WORMHOLE_B0 AND MAY NOT WORK WITH BLACKHOLE!!! +// STREAM_RECEIVER_ENDPOINT_MULTI_TILE_CLEAR_REG_INDEX is aliased to STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX for +// whb0 +inline bool is_stream_receiver_endpoint_tile_clearing_finished(uint32_t stream_id) { + return (NOC_STREAM_READ_REG(stream_id, STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX) == 0); +} +inline void stream_receiver_endpoint_tiles_clear_b0(uint32_t stream_id, uint32_t num_tiles) { + uint32_t clr_val = num_tiles; + clr_val *= 2; + clr_val = (~clr_val) + 1; + NOC_STREAM_WRITE_REG(stream_id, STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX, clr_val); +} +////////////////////////////////////////////////////////////////////////////////////////// + +uint32_t get_receiver_stream_config_reg(uint32_t data_noc_id, uint32_t update_noc, bool drain_after_phase_send) { + uint32_t stream_cfg_reg = 0; + bool next_phase_src_dest_change = drain_after_phase_send ? 
1 : 0; + stream_cfg_reg |= STREAM_CFG(INCOMING_DATA_NOC, data_noc_id) | STREAM_CFG(REMOTE_SRC_UPDATE_NOC, update_noc) | + STREAM_CFG(RECEIVER_ENDPOINT, 1) | STREAM_CFG(REMOTE_SOURCE, 1) | + STREAM_CFG(NEXT_PHASE_SRC_CHANGE, next_phase_src_dest_change) | + STREAM_CFG(NEXT_PHASE_DEST_CHANGE, next_phase_src_dest_change) | + STREAM_CFG(PHASE_AUTO_ADVANCE, 0) | STREAM_CFG(DATA_AUTO_SEND, 0) | + STREAM_CFG(REG_UPDATE_VC_REG, 1); + + return stream_cfg_reg; +} + +FORCE_INLINE bool messages_are_available(uint32_t stream_id, stream_state_t &stream_state) { + uint32_t wrptr = NOC_STREAM_READ_REG(stream_id, STREAM_MSG_INFO_WR_PTR_REG_INDEX); + uint32_t rdptr = NOC_STREAM_READ_REG(stream_id, STREAM_MSG_INFO_PTR_REG_INDEX); + uint32_t internal_rdptr = stream_state.local_msg_info_ptr >> 4; + bool messages_available = internal_rdptr < wrptr; + return messages_available; +} + +FORCE_INLINE void flush_message_from_stream_buffer( + uint32_t stream_id, stream_state_t &stream_state, uint32_t msg_size_bytes) { + stream_receiver_endpoint_tiles_clear_b0(stream_id, 1); + while (!is_stream_receiver_endpoint_tile_clearing_finished(stream_id)) { + asm volatile(""); + } +} + +FORCE_INLINE uint32_t +get_next_available_stream_message_size_in_bytes(stream_state_t &stream_state, uint32_t stream_id) { + uint32_t msg_info_byte_ptr = stream_state.local_msg_info_ptr; + uint32_t msg_size_bytes = *reinterpret_cast(msg_info_byte_ptr) << 4; + ASSERT(msg_size_bytes > 0); + return msg_size_bytes; +} + +FORCE_INLINE std::tuple get_next_message_info(uint32_t stream_id, stream_state_t &stream_state) { + uint32_t rdptr_offset = NOC_STREAM_READ_REG(stream_id, STREAM_RD_PTR_REG_INDEX) << 4; + uint32_t addr = rdptr_offset + stream_state.local_data_buffer_base_address; + ASSERT((rdptr_offset & 0xF) == 0); + ASSERT((addr & 0xF) == 0); + return {addr, get_next_available_stream_message_size_in_bytes(stream_state, stream_id)}; +} + +FORCE_INLINE void advance_stream_state_struct( + uint32_t stream_id, stream_state_t 
&stream_state, uint32_t msg_size_bytes) { + uint32_t next_offset = stream_state.local_buffer_read_offset + msg_size_bytes; + if (next_offset >= stream_state.local_buffer_size) { + next_offset -= stream_state.local_buffer_size; + } + stream_state.local_buffer_read_offset = next_offset; + stream_state.local_msg_info_ptr += (1 << 4); +} + +FORCE_INLINE void advance_phase( + noc_endpoint_info_t const &remote_endpoint_info, stream_state_t &state, uint32_t stream_id) { + // This is remote receiver, so it sends messages (updates) to remote source, NOT data, so it uses + // the update noc to communicate to remote src instead of the data noc. Therefore, we need to set remote + // src x/y based on the update noc. + uint32_t translated_remote_noc_x = remote_endpoint_info.update_noc_id == 0 + ? remote_endpoint_info.noc_x + : noc_size_x - 1 - remote_endpoint_info.noc_x; + uint32_t translated_remote_noc_y = remote_endpoint_info.update_noc_id == 0 + ? remote_endpoint_info.noc_y + : noc_size_y - 1 - remote_endpoint_info.noc_y; + + NOC_STREAM_WRITE_REG(stream_id, STREAM_CURR_PHASE_BASE_REG_INDEX, 0); + NOC_STREAM_WRITE_REG(stream_id, STREAM_CURR_PHASE_REG_INDEX, ((uint32_t)state.local_phase_id)); + NOC_STREAM_WRITE_REG(stream_id, STREAM_BUF_START_REG_INDEX, ((uint32_t)state.local_buffer_base_addr) >> 4); + NOC_STREAM_WRITE_REG(stream_id, STREAM_BUF_SIZE_REG_INDEX, state.local_buffer_size >> 4); + NOC_STREAM_WRITE_REG( + stream_id, + STREAM_REMOTE_SRC_REG_INDEX, + STREAM_REMOTE_SRC(translated_remote_noc_x, translated_remote_noc_y, stream_id)); + NOC_STREAM_WRITE_REG(stream_id, STREAM_REMOTE_SRC_PHASE_REG_INDEX, ((uint32_t)state.remote_phase_id)); + + NOC_STREAM_WRITE_REG(stream_id, STREAM_MEM_BUF_SPACE_AVAILABLE_ACK_THRESHOLD_REG_INDEX, 0); + NOC_STREAM_WRITE_REG(stream_id, STREAM_MSG_INFO_PTR_REG_INDEX, ((uint32_t)state.local_msg_info_ptr) >> 4); + NOC_STREAM_WRITE_REG(stream_id, STREAM_MSG_INFO_WR_PTR_REG_INDEX, ((uint32_t)state.local_msg_info_ptr) >> 4); + + NOC_STREAM_WRITE_REG( 
+ stream_id, + STREAM_MISC_CFG_REG_INDEX, + get_receiver_stream_config_reg(remote_endpoint_info.data_noc_id, remote_endpoint_info.update_noc_id, true)); + + NOC_STREAM_WRITE_REG( + stream_id, STREAM_PHASE_AUTO_CFG_HEADER_REG_INDEX, AUTO_CFG_HEADER(0, state.messages_per_phase, 0)); + NOC_STREAM_WRITE_REG(stream_id, STREAM_PHASE_ADVANCE_REG_INDEX, 0x1); +} + +FORCE_INLINE void advance_stream_to_next_message( + noc_endpoint_info_t const &remote_endpoint_info, + stream_state_t &state, + uint32_t stream_id, + uint32_t msg_size_bytes, + phase_iterator_t &local_phase_iterator, + phase_iterator_t &remote_phase_iterator) { + advance_stream_state_struct(stream_id, state, msg_size_bytes); + flush_message_from_stream_buffer(stream_id, state, msg_size_bytes); + + if (state.num_tiles_sent == state.tile_header_num_msgs - 1) { + remote_phase_iterator.increment(); + state.remote_phase_id = remote_phase_iterator.get(); + local_phase_iterator.increment(); + state.local_phase_id = local_phase_iterator.get(); + state.num_tiles_sent = 0; + state.local_msg_info_ptr = state.local_msg_info_ptr_base_address; + + advance_phase(remote_endpoint_info, state, stream_id); + state.local_buffer_read_offset = 0; + } else { + state.num_tiles_sent++; + } +} + +FORCE_INLINE void copy_message_to_cb_blocking( + uint32_t cb, uint32_t msg_addr, uint32_t msg_size_bytes, stream_state_t &stream_state) { + uint32_t cb_write_addr = get_write_ptr(cb); + uint64_t dest_noc_addr = get_noc_addr(cb_write_addr); + ASSERT((dest_noc_addr & 0xF) == 0); + ASSERT((msg_addr & 0xF) == 0); + uint32_t distance_until_end = + stream_state.local_buffer_size - (msg_addr - stream_state.local_data_buffer_base_address); + uint32_t bytes_to_copy = std::min(distance_until_end, msg_size_bytes); + + noc_async_write(msg_addr, dest_noc_addr, bytes_to_copy); + if (bytes_to_copy < msg_size_bytes) { + uint32_t bytes_to_copy_second = msg_size_bytes - bytes_to_copy; + noc_async_write( + stream_state.local_data_buffer_base_address, dest_noc_addr 
+ bytes_to_copy, bytes_to_copy_second); + uint32_t num_words = bytes_to_copy_second >> 2; + } + noc_async_write_barrier(); +} + +void kernel_main() { + uint32_t arg_idx = 0; + + uint32_t num_messages_to_forward = get_arg_val(arg_idx++); + + uint32_t stream_id = get_arg_val(arg_idx++); + uint32_t stream_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_buffer_size = get_arg_val(arg_idx++); + uint32_t stream_tile_header_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_tile_header_max_num_messages = get_arg_val(arg_idx++); + + uint32_t remote_src_noc_x = get_arg_val(arg_idx++); + uint32_t remote_src_noc_y = get_arg_val(arg_idx++); + uint32_t remote_src_noc_stream_id = get_arg_val(arg_idx++); + uint32_t remote_src_data_noc_id = get_arg_val(arg_idx++); + uint32_t remote_src_buffer_addr = get_arg_val(arg_idx++); + uint32_t remote_src_buffer_size_4B_words = get_arg_val(arg_idx++); + uint32_t remote_src_tile_header_buffer_addr = get_arg_val(arg_idx++); + + uint32_t relay_done_semaphore_addr = get_arg_val(arg_idx++); + uint32_t other_relay_core_to_signal_x = get_arg_val(arg_idx++); + uint32_t other_relay_core_to_signal_y = get_arg_val(arg_idx++); + uint32_t other_relay_done_semaphore = get_arg_val(arg_idx++); + + uint32_t sender_noc_x = get_arg_val(arg_idx++); + uint32_t sender_noc_y = get_arg_val(arg_idx++); + uint32_t sender_wait_finish_semaphore = get_arg_val(arg_idx++); + uint32_t remote_src_start_phase_addr = get_arg_val(arg_idx++); + + const uint32_t first_phase_remote_src_phase = + wait_for_remote_source_starting_phase(reinterpret_cast(remote_src_start_phase_addr)); + const uint32_t second_phase_remote_src_phase = first_phase_remote_src_phase + 1; + const uint32_t local_first_phase = get_first_available_phase_out_of_reset(stream_id); + const uint32_t local_second_phase = local_first_phase; + + auto local_phase_iterator = phase_iterator_t(local_first_phase, local_second_phase); + auto remote_phase_iterator = phase_iterator_t(first_phase_remote_src_phase, 
second_phase_remote_src_phase); + + stream_state_t stream_state{ + stream_buffer_addr, + stream_tile_header_buffer_addr, + + local_phase_iterator.get(), // phase_id + stream_tile_header_max_num_messages, + + stream_tile_header_buffer_addr, // msg_info_wrptr_addr; + + 0, // num_tiles_sent; + stream_tile_header_max_num_messages, // tile_header_num_msgs; + + stream_buffer_addr, // dest_buffer_base_addr; + stream_buffer_size, // dest_buffer_size; + stream_tile_header_buffer_addr, // dest_msg_info_ptr; + + 0, // src_buffer_read_offset; + + remote_src_buffer_addr, // src_buffer_base_addr; + remote_src_buffer_size_4B_words, // src_buffer_size; + remote_src_tile_header_buffer_addr, // src_msg_info_ptr; + + 0, // dest_buffer_write_offset; + remote_phase_iterator.get(), // receiver start phase + }; + + ASSERT((stream_state.local_data_buffer_base_address & 0xf) == 0); + + auto remote_noc_info_desc = + noc_endpoint_info_t{remote_src_data_noc_id, 1 - remote_src_data_noc_id, remote_src_noc_x, remote_src_noc_y}; + + advance_phase(remote_noc_info_desc, stream_state, stream_id); + + auto cb = tt::CB::c_in0; + stream_state.local_buffer_base_addr = stream_buffer_addr; + + for (uint32_t i = 0; i < num_messages_to_forward; i++) { + cb_reserve_back(cb, 1); + + while (!messages_are_available(stream_id, stream_state)) { + asm volatile("nop"); + } + + auto const &[msg_addr, msg_size_bytes] = get_next_message_info(stream_id, stream_state); + ASSERT(msg_size_bytes > 0); + ASSERT(msg_size_bytes <= stream_state.local_buffer_size); + + copy_message_to_cb_blocking(cb, msg_addr, msg_size_bytes, stream_state); + + cb_push_back(cb, 1); + + stream_relay_tiles(stream_id, 1, msg_size_bytes >> 4); + advance_stream_to_next_message( + remote_noc_info_desc, stream_state, stream_id, msg_size_bytes, local_phase_iterator, remote_phase_iterator); + } + + noc_semaphore_inc(get_noc_addr(sender_noc_x, sender_noc_y, sender_wait_finish_semaphore), 1); + + while ((NOC_STREAM_READ_REG(stream_id, 
STREAM_DEBUG_STATUS_REG_INDEX + 9) >> MEM_WORD_ADDR_WIDTH) != 0) { + asm volatile("nop"); + } + + stream_reset(stream_id); + + noc_semaphore_inc( + get_noc_addr(remote_noc_info_desc.noc_x, remote_noc_info_desc.noc_y, relay_done_semaphore_addr), 1); + noc_semaphore_inc( + get_noc_addr(other_relay_core_to_signal_x, other_relay_core_to_signal_y, other_relay_done_semaphore), 1); + + ASSERT(!assert_check(stream_id, false)); +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp new file mode 100644 index 00000000000..470ef6a4264 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +void kernel_main() { + uint32_t arg_idx = 0; + + constexpr uint32_t msg_hdr_size = get_compile_time_arg_val(0); + + uint32_t output_buffer_addr = get_arg_val(arg_idx++); + uint32_t cb_page_size = get_arg_val(arg_idx++); + uint32_t num_pages = get_arg_val(arg_idx++); + + uint32_t write_page_size = cb_page_size - msg_hdr_size; + const InterleavedAddrGen dest_addr_gen = { + .bank_base_address = output_buffer_addr, .page_size = write_page_size}; + + auto cb = tt::CB::c_in0; + for (uint32_t i = 0; i < num_pages; i++) { + cb_wait_front(cb, 1); + // NOTE THAT msg_hdr_size is doubled on host side to maintain alignment for DRAM reads/writes in THIS TEST ONLY + uint32_t src_start = get_read_ptr(cb) + msg_hdr_size; + + uint64_t dst_noc_addr = get_noc_addr(i, dest_addr_gen); + noc_async_write(src_start, dst_noc_addr, write_page_size); + + noc_async_write_barrier(); + cb_pop_front(cb, 1); + } + +} diff --git 
a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp new file mode 100644 index 00000000000..606930d73ff --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp @@ -0,0 +1,364 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "stream_interface.h" +#include "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_io_kernel_helpers.hpp" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +////////// +/// FUTURE OPTIMIZATIONS +/////////// +// 1) Don't update message info rd/wrptrs. Instead, just write message size into the next corresponding message info +// buffer entry 2) Use stream registers to track # messages sent 3) For contiguous messages, use a single stream phase +// to send them back to back then only do one wait for flush at the end + +////////// +// Q/A W/ Djordje + Extra Notes +// +// 1) DON'T set any of the STREAM_REMOTE_DEST_* registers if NEXT_PHASE_SRC_CHANGE is false +// 2) stream_phase_advance_wait can be used to wait for the current phase to complete +// -> in the scheme for this producer, it'll end up waiting until the message is sent out of L1 +// 3) How does initial stream handshake happen? +// -> Stream has hidden registers: curr_phase_src/dest_change. When coming out of reset, these are set true +// This value is sticky and the next_phase_src/dest_change will override it for the next phase +/////// + +uint32_t get_sender_stream_config_reg(uint32_t tx_noc_id, uint32_t rx_src_update_noc, bool drain_after_phase_send) { +    uint32_t stream_cfg_reg = 0; +    bool next_phase_src_dest_change = drain_after_phase_send ? 
1 : 0; + stream_cfg_reg |= STREAM_CFG(OUTGOING_DATA_NOC, tx_noc_id) | STREAM_CFG(REMOTE_SRC_UPDATE_NOC, rx_src_update_noc) | + STREAM_CFG(SOURCE_ENDPOINT, 1) | STREAM_CFG(REMOTE_RECEIVER, 1) | + STREAM_CFG(NEXT_PHASE_SRC_CHANGE, next_phase_src_dest_change) | + STREAM_CFG(NEXT_PHASE_DEST_CHANGE, next_phase_src_dest_change) | + STREAM_CFG(PHASE_AUTO_ADVANCE, 0) | STREAM_CFG(DATA_AUTO_SEND, 0) | + STREAM_CFG(REG_UPDATE_VC_REG, 1); + + return stream_cfg_reg; +} + +FORCE_INLINE void write_message_size_to_message_info_buffer( + stream_state_t const &stream_state, uint32_t message_size_noc_words) { + ASSERT((message_size_noc_words << 4) <= stream_state.local_buffer_size); + if (!((message_size_noc_words << 4) <= stream_state.local_buffer_size)) { + DPRINT << "YIKES\n"; + } + *reinterpret_cast(stream_state.local_msg_info_ptr) = message_size_noc_words; +} + +FORCE_INLINE void reset_stream_message_info_buffer_rdptr(stream_state_t &stream_state, uint32_t stream_id) { + stream_state.local_msg_info_ptr = stream_state.local_msg_info_ptr_base_address; + NOC_STREAM_WRITE_REG( + stream_id, STREAM_MSG_INFO_PTR_REG_INDEX, ((uint32_t)(stream_state.local_msg_info_ptr_base_address >> 4))); + NOC_STREAM_WRITE_REG( + stream_id, STREAM_MSG_INFO_WR_PTR_REG_INDEX, (((uint32_t)stream_state.local_msg_info_ptr_base_address >> 4))); +} +FORCE_INLINE void advance_stream_message_info_buffer_wrptr( + stream_state_t &stream_state, uint32_t stream_id, uint32_t message_size) { + stream_state.local_msg_info_ptr += (1 << 4); + stream_state.local_buffer_read_offset += message_size; + if (stream_state.local_buffer_read_offset >= stream_state.local_buffer_size) { + stream_state.local_buffer_read_offset -= stream_state.local_buffer_size; + } +} + +FORCE_INLINE void wait_for_stream_write_complete(uint32_t sender_stream_id) { + while (!stream_phase_advance_wait(sender_stream_id)) { + asm volatile("nop"); + } +} + +FORCE_INLINE void copy_from_cb_to_stream_buffer( + stream_state_t &stream_state, uint32_t 
message_base, uint32_t message_size_noc_words) { + ASSERT((message_size_noc_words << 4) <= stream_state.local_buffer_size); + if (!((message_size_noc_words << 4) <= stream_state.local_buffer_size)) { + DPRINT << "YIKES2\n"; + } + uint32_t message_size_size_in_bytes = message_size_noc_words << 4; + uint32_t bytes_to_copy = + std::min(stream_state.local_buffer_size - stream_state.local_buffer_read_offset, message_size_size_in_bytes); + noc_async_write(message_base, get_noc_addr(stream_state.get_current_local_buffer_address()), bytes_to_copy); + ASSERT(stream_state.local_buffer_size + stream_state.local_buffer_read_offset >= bytes_to_copy); + if (!(stream_state.local_buffer_size + stream_state.local_buffer_read_offset >= bytes_to_copy)) { + DPRINT << "YIKES3\n"; + } + + if (bytes_to_copy < message_size_size_in_bytes) { + uint32_t second_bytes_to_copy = message_size_size_in_bytes - bytes_to_copy; + noc_async_write( + message_base + bytes_to_copy, get_noc_addr(stream_state.local_buffer_base_addr), second_bytes_to_copy); + } + noc_async_write_barrier(); +} + +FORCE_INLINE void hang_toggle(volatile uint32_t *hang_toggle_semaphore) { + return; + while (*hang_toggle_semaphore == 0) { + asm volatile(""); + } + *hang_toggle_semaphore = 0; +} + +FORCE_INLINE void stream_noc_write( + stream_state_t &stream_state, + uint32_t message_base, + uint32_t sender_stream_id, + uint32_t dest_addr, + uint32_t remote_noc_x, + uint32_t remote_noc_y, + uint32_t dest_noc_id, + uint32_t dest_tile_header_buffer_addr, + uint32_t local_start_phase, + bool very_first_message, + volatile uint32_t *hang_toggle_semaphore, + uint32_t message_id) { + const uint32_t tiles_per_phase = stream_state.messages_per_phase; + + uint32_t message_size_noc_words = *reinterpret_cast(message_base); + + uint32_t dest_noc_reg = 0; + uint32_t num_tiles = stream_state.num_tiles_sent; + const bool send_last_message_and_drain = num_tiles == (stream_state.tile_header_num_msgs - 1); + + bool first_message = num_tiles == 0; 
+ + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_CURR_PHASE_BASE_REG_INDEX, 0); + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_CURR_PHASE_REG_INDEX, stream_state.local_phase_id); + + if (first_message) { + reset_stream_message_info_buffer_rdptr(stream_state, sender_stream_id); + stream_state.local_buffer_read_offset = 0; + } + copy_from_cb_to_stream_buffer(stream_state, message_base, message_size_noc_words); + + if (message_id < 10) { + hang_toggle(hang_toggle_semaphore); + } + + uint32_t rx_src_update_noc = 1 - dest_noc_id; + if (send_last_message_and_drain) { + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_MISC_CFG_REG_INDEX, + get_sender_stream_config_reg(dest_noc_id, rx_src_update_noc, true)); + + } else if (first_message) { + // ASSERT(stream_state.remote_buffer_base_addr + stream_state.local_buffer_size <= + // stream_state.remote_buffer_size || + // stream_state.remote_buffer_size + (stream_state.tile_header_num_msgs << 4) <= + // stream_state.remote_buffer_base_addr); + + uint32_t rx_src_update_noc = 1 - dest_noc_id; + uint32_t translated_remote_noc_x = dest_noc_id == 0 ? remote_noc_x : noc_size_x - 1 - remote_noc_x; + uint32_t translated_remote_noc_y = dest_noc_id == 0 ? 
remote_noc_y : noc_size_y - 1 - remote_noc_y; + uint32_t dest_stream_id = sender_stream_id; + + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_BUF_START_REG_INDEX, + ((uint32_t)stream_state.get_current_local_buffer_address()) >> 4); + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_BUF_SIZE_REG_INDEX, stream_state.local_buffer_size >> 4); + + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_REMOTE_DEST_REG_INDEX, + STREAM_REMOTE_DEST(translated_remote_noc_x, translated_remote_noc_y, dest_stream_id)); + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_HI_REG_INDEX, 0); + NOC_STREAM_WRITE_REG( + sender_stream_id, STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX, stream_state.remote_msg_info_ptr >> 4); + + // DPRINT << "STREAM_REMOTE_DEST_MSG_INFO_WR_PTR_REG_INDEX: " << (uint32_t)(stream_state.remote_msg_info_ptr >> + // 4) << "\n"; + NOC_STREAM_WRITE_REG( + sender_stream_id, STREAM_REMOTE_DEST_BUF_START_REG_INDEX, stream_state.remote_buffer_base_addr >> 4); + // Inserting an assert here causes test to pass + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_REMOTE_DEST_BUF_START_HI_REG_INDEX, + (stream_state.remote_buffer_base_addr / MEM_WORD_WIDTH) >> MEM_WORD_ADDR_WIDTH); + NOC_STREAM_WRITE_REG_FIELD( + sender_stream_id, + STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX, + REMOTE_DEST_BUF_SIZE_WORDS, + stream_state.remote_buffer_size >> 4); + + NOC_STREAM_WRITE_REG( + sender_stream_id, + STREAM_MISC_CFG_REG_INDEX, + get_sender_stream_config_reg(dest_noc_id, rx_src_update_noc, false)); + } + + if (first_message) { + // DPRINT << "Msg info ptr: " << (uint32_t)stream_state.local_msg_info_ptr << "\n"; + } + if (very_first_message) { + hang_toggle(hang_toggle_semaphore); + } + + write_message_size_to_message_info_buffer(stream_state, message_size_noc_words); + advance_stream_message_info_buffer_wrptr(stream_state, sender_stream_id, message_size_noc_words << 4); + + NOC_STREAM_WRITE_REG( + sender_stream_id, STREAM_PHASE_AUTO_CFG_HEADER_REG_INDEX, 
AUTO_CFG_HEADER(0, 1 /*tiles_per_phase*/, 1)); + NOC_STREAM_WRITE_REG(sender_stream_id, STREAM_PHASE_ADVANCE_REG_INDEX, 0x1); + + if (first_message) { + // wait for handshake to complete + while (!stream_phase_is_active(sender_stream_id)) { + asm volatile(""); + } + } + + if (very_first_message) { + hang_toggle(hang_toggle_semaphore); + } + + if (send_last_message_and_drain) { + // We only wrap around to 0 when the remote receiver relay stream has finished its second phase. We need to do + // this to avoid any handshake bugs we might hit if the second phase of relay must sync with phase 1 of the + // producer (this) since the relay will handshake with phase 1 of the producer (this) stream for relay stream's + // first phase too + num_tiles = 0; + stream_state.remote_phase_id = 3 - stream_state.remote_phase_id; // will alternate between 1 and 2 + // Remote phase was already updated so the condition is inverted + stream_state.local_phase_id = + (stream_state.remote_phase_id == 1) ? local_start_phase : stream_state.local_phase_id + 1; + } else { + num_tiles++; + stream_state.local_phase_id++; + } + + stream_relay_tiles(sender_stream_id, 1, message_size_noc_words); + wait_for_stream_write_complete(sender_stream_id); + + if (very_first_message) { + hang_toggle(hang_toggle_semaphore); + } + + stream_state.num_tiles_sent = num_tiles; +} + +void kernel_main() { + uint32_t arg_idx = 0; + + uint32_t num_messages_to_forward = get_arg_val(arg_idx++); + + uint32_t stream_id = get_arg_val(arg_idx++); + uint32_t stream_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_buffer_size = get_arg_val(arg_idx++); + uint32_t stream_tile_header_buffer_addr = get_arg_val(arg_idx++); + uint32_t stream_tile_header_max_num_messages = get_arg_val(arg_idx++); + + uint32_t remote_dest_noc_x = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_y = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_stream_id = get_arg_val(arg_idx++); + uint32_t remote_dest_noc_id = get_arg_val(arg_idx++); + 
uint32_t remote_dest_buffer_addr = get_arg_val(arg_idx++); + uint32_t remote_dest_buffer_size_4B_words = get_arg_val(arg_idx++); + uint32_t remote_dest_tile_header_buffer_addr = get_arg_val(arg_idx++); + + uint32_t relay_done_semaphore_addr = get_arg_val(arg_idx++); + uint32_t other_relay_core_to_signal_x = get_arg_val(arg_idx++); + uint32_t other_relay_core_to_signal_y = get_arg_val(arg_idx++); + uint32_t other_relay_done_semaphore = get_arg_val(arg_idx++); + + uint32_t wait_receiver_semaphore = get_arg_val(arg_idx++); + *reinterpret_cast(wait_receiver_semaphore) = 0; + + uint32_t first_relay_remote_src_start_phase_addr = get_arg_val(arg_idx++); + volatile uint32_t *hang_toggle_semaphore = reinterpret_cast(get_arg_val(arg_idx++)); + + uint32_t local_starting_phase = + notify_remote_receiver_of_starting_phase( + stream_id, + stream_buffer_addr, + get_noc_addr(remote_dest_noc_x, remote_dest_noc_y, first_relay_remote_src_start_phase_addr)) - + 1; + + // clear the buffers + for (uint32_t i = 0; i < stream_buffer_size / sizeof(uint32_t); i++) { + reinterpret_cast(stream_buffer_addr)[i] = 0; + } + for (uint32_t i = 0; i < stream_tile_header_max_num_messages * 4; i++) { + reinterpret_cast(stream_tile_header_buffer_addr)[i] = 0; + } + + stream_state_t stream_state{ + stream_buffer_addr, + stream_tile_header_buffer_addr, + + local_starting_phase, // phase_id + stream_tile_header_max_num_messages, // messages_per_phase + + stream_tile_header_buffer_addr, // msg_info_wrptr_addr; + + 0, // num_tiles_sent; + stream_tile_header_max_num_messages, // tile_header_num_msgs; + + stream_buffer_addr, // src_buffer_base_addr; + stream_buffer_size, // src_buffer_size; + stream_tile_header_buffer_addr, // src_msg_info_ptr; + 0, // src_buffer_read_offset; + + remote_dest_buffer_addr, // dest_buffer_base_addr; + remote_dest_buffer_size_4B_words, // dest_buffer_size; + remote_dest_tile_header_buffer_addr, // dest_msg_info_ptr; + 0, // dest_buffer_write_offset; + + 1, // receiver_phase; // 
receiver start phase // don't need the true value + }; + + DPRINT << "hang_toggle_semaphore: " << (uint32_t)hang_toggle_semaphore << "\n"; + + hang_toggle(hang_toggle_semaphore); + + auto cb = tt::CB::c_in0; + bool very_first_message = true; + + uint32_t message_id = 0; + uint32_t count = 0; + for (uint32_t i = 0; i < num_messages_to_forward; i++) { + cb_wait_front(cb, 1); + uint32_t src_addr = get_read_ptr(cb); + stream_noc_write( + stream_state, + src_addr, + stream_id, + stream_state.remote_buffer_base_addr, + remote_dest_noc_x, + remote_dest_noc_y, + remote_dest_noc_id, + remote_dest_tile_header_buffer_addr, + local_starting_phase, + very_first_message, + hang_toggle_semaphore, + message_id); + + cb_pop_front(cb, 1); + // if (count == 1000) { + // DPRINT << "Sent " << i << " messages\n"; + // count = 0; + // } else { + // count++; + // } + very_first_message = false; + message_id++; + } + + // Reset sequence is that both the remote sender and remote receiver streams of the relay + // should reset first so that no data is in flight. 
Sender and receiver must ensure that no + // payloads are in flight to the relay stream(s) before sending the reset signal to the relay + // core + noc_semaphore_wait(reinterpret_cast(wait_receiver_semaphore), 1); + + stream_reset(stream_id); + + noc_semaphore_inc( + get_noc_addr(other_relay_core_to_signal_x, other_relay_core_to_signal_y, other_relay_done_semaphore), 1); + noc_semaphore_inc(get_noc_addr(remote_dest_noc_x, remote_dest_noc_y, relay_done_semaphore_addr), 1); + + ASSERT(!assert_check(stream_id, false)); +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender_reader.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender_reader.cpp new file mode 100644 index 00000000000..2127013baac --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender_reader.cpp @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include + +#include "dataflow_api.h" +#include "debug/dprint.h" +#include "tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h" + +void kernel_main() { + uint32_t arg_idx = 0; + + constexpr uint32_t msg_hdr_size = get_compile_time_arg_val(0); + constexpr bool enable_page_size_variations = get_compile_time_arg_val(1) == 1; + + const uint32_t input_buffer_addr = get_arg_val(arg_idx++); + const uint32_t cb_page_size = get_arg_val(arg_idx++); + const uint32_t num_pages = get_arg_val(arg_idx++); + + constexpr uint32_t num_sizes = 8; + std::array sub_sizes = {}; + for (uint32_t i = 0; i < num_sizes; i++) { + sub_sizes[i] = get_arg_val(arg_idx++); + } + + const uint32_t read_page_size = cb_page_size - msg_hdr_size; + const InterleavedAddrGen src_addr_gen = {.bank_base_address = input_buffer_addr, .page_size = read_page_size}; + + auto cb = tt::CB::c_in0; + + uint32_t sub_index = 0; + + for (uint32_t i = 0; i < num_pages; i++) { + cb_reserve_back(cb, 1); + volatile uint32_t 
*page_header_addr = reinterpret_cast(get_write_ptr(cb)); + // NOTE THAT msg_hdr_size is doubled on host side to maintain alignment for the DRAM reads in THIS TEST ONLY + uint32_t data_out_start = reinterpret_cast(page_header_addr) + msg_hdr_size; + uint64_t src_noc_addr = get_noc_addr(i, src_addr_gen); + uint32_t message_header_size = + (read_page_size >> 4) + 2; // one for header one for padding to maintain noc word alignment + if (enable_page_size_variations) { + if (message_header_size < sub_sizes[sub_index] || sub_index >= 8) { + DPRINT << "REMOTE SENDER READER ERROR!\n"; + } + message_header_size -= sub_sizes[sub_index]; + sub_index = sub_index == num_sizes - 1 ? 0 : sub_index + 1; + } + page_header_addr[0] = message_header_size; + page_header_addr[1] = 0; + page_header_addr[2] = 0; + page_header_addr[3] = 0; + page_header_addr[4] = 0; + page_header_addr[5] = 0; + page_header_addr[6] = 0; + page_header_addr[7] = 0; + + noc_async_read(src_noc_addr, data_out_start, read_page_size); + + // TODO: upgrade to look at the writes acked counter instead + noc_async_read_barrier(); + cb_push_back(cb, 1); + } +} diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt index 1dd900cc7d4..dd25ee844ad 100644 --- a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt @@ -10,6 +10,7 @@ set(UNIT_TESTS_FD_SRC ${CMAKE_CURRENT_SOURCE_DIR}/multichip/test_eth_EnqueueProgram.cpp ${CMAKE_CURRENT_SOURCE_DIR}/multichip/test_eth_ring_gather_EnqueueProgram.cpp ${CMAKE_CURRENT_SOURCE_DIR}/pipelining/basic_pipeline.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/streams/test_autonomous_relay_streams.cpp ) add_executable(unit_tests_fast_dispatch ${UNIT_TESTS_FD_SRC} $) diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp 
b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp new file mode 100644 index 00000000000..2c963a0796d --- /dev/null +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp @@ -0,0 +1,973 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include +#include +#include + +#include "device/tt_arch_types.h" +#include "gtest/gtest.h" +#include "tests/tt_metal/tt_metal/unit_tests_fast_dispatch/common/command_queue_fixture.hpp" +#include "tt_metal/common/logger.hpp" +// #include "impl/device/device.hpp" +#include "impl/kernels/data_types.hpp" +#include "impl/kernels/kernel_types.hpp" +#include "tt_metal/common/core_coord.h" +#include "tt_metal/common/math.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/kernels/kernel.hpp" +#include "tt_metal/test_utils/comparison.hpp" +#include "tt_metal/test_utils/df/df.hpp" +#include "tt_metal/test_utils/env_vars.hpp" +// #include "tt_metal/test_utils/print_helpers.hpp" +#include "tt_metal/detail/persistent_kernel_cache.hpp" +#include "tt_metal/test_utils/stimulus.hpp" + +using tt::tt_metal::Device; + +constexpr uint32_t num_sizes = 8; +namespace tt { + +namespace tt_metal { + +struct hop_eth_sockets { + chip_id_t receiver_device_id; + CoreCoord receiver_core; + chip_id_t sender_device_id; + CoreCoord sender_core; +}; + +struct stream_config_t { + uint32_t buffer_addr; + uint32_t buffer_size; // in bytes + uint32_t tile_header_buffer_addr; + uint32_t tile_header_num_msgs; + uint32_t tile_header_buffer_size; // in bytes +}; + +struct stream_builder_spec_t { + uint32_t buffer_size_bytes; + uint32_t tile_header_buffer_size_bytes; +}; + +constexpr uint32_t relay_stream_id = 32; +constexpr uint32_t tile_header_size = 32; // needs to provide noc word alignment +// constexpr uint32_t tile_header_size = 16; +constexpr 
uint32_t noc_word_size = 16; + +// Reads data from input +std::vector get_sender_reader_rt_args( + Device* device, + uint32_t input_buffer_addr, + uint32_t page_size_plus_header, + uint32_t num_messages_to_read, + std::array const& sub_sizes) { + auto args = std::vector{input_buffer_addr, page_size_plus_header, num_messages_to_read}; + for (auto const& sub_size : sub_sizes) { + args.push_back(sub_size); + } + return args; +} +// sender stream data mover kernel +std::vector get_sender_writer_rt_args( + Device* device, + uint32_t num_messages, + uint32_t relay_done_semaphore, + CoreCoord const& relay_core, + uint32_t sender_noc_id, + stream_config_t const& sender_stream_config, + stream_config_t const& relay_stream_config, + CoreCoord const& other_relay_to_notify_when_done, + uint32_t other_relay_done_semaphore, + uint32_t sender_wait_for_receiver_semaphore, + uint32_t first_relay_remote_src_start_phase_addr, + uint32_t hang_toggle_addr) { + return std::vector{ + num_messages, + + relay_stream_id, + sender_stream_config.buffer_addr, + sender_stream_config.buffer_size, + sender_stream_config.tile_header_buffer_addr, + relay_stream_config.tile_header_num_msgs, + + static_cast(device->worker_core_from_logical_core(relay_core).x), + static_cast(device->worker_core_from_logical_core(relay_core).y), + relay_stream_id, + sender_noc_id, + + relay_stream_config.buffer_addr, + relay_stream_config.buffer_size, + relay_stream_config.tile_header_buffer_addr, + + relay_done_semaphore, + static_cast(device->worker_core_from_logical_core(other_relay_to_notify_when_done).x), + static_cast(device->worker_core_from_logical_core(other_relay_to_notify_when_done).y), + other_relay_done_semaphore, + + static_cast(sender_wait_for_receiver_semaphore), + first_relay_remote_src_start_phase_addr, + hang_toggle_addr}; +} + +std::vector get_relay_rt_args( + Device* device, + uint32_t relay_stream_overlay_blob_addr, + uint32_t relay_done_semaphore, + CoreCoord const& sender_core, + CoreCoord 
const& receiver_core, + uint32_t sender_noc_id, + uint32_t receiver_noc_id, + // stream_config_t const& sender_stream_config, + stream_config_t const& relay_stream_config, + stream_config_t const& receiver_stream_config, + uint32_t remote_src_start_phase_addr, + uint32_t dest_remote_src_start_phase_addr, + bool is_first_relay_in_chain) { + return std::vector{ + static_cast(relay_stream_overlay_blob_addr), + static_cast(relay_stream_id), + static_cast(relay_stream_config.buffer_addr), + static_cast(relay_stream_config.buffer_size), + static_cast(relay_stream_config.tile_header_buffer_addr), + static_cast(relay_stream_config.tile_header_num_msgs), + + // noc0 address + static_cast(device->worker_core_from_logical_core(sender_core).x), + static_cast(device->worker_core_from_logical_core(sender_core).y), + static_cast(relay_stream_id), + static_cast(sender_noc_id), + + static_cast(device->worker_core_from_logical_core(receiver_core).x), + static_cast(device->worker_core_from_logical_core(receiver_core).y), + static_cast(relay_stream_id), + static_cast(receiver_noc_id), + static_cast(receiver_stream_config.buffer_addr), + static_cast(receiver_stream_config.buffer_size), + static_cast(receiver_stream_config.tile_header_buffer_addr), + + static_cast(relay_done_semaphore), + static_cast(is_first_relay_in_chain ? 
1 : 0), + + remote_src_start_phase_addr, + dest_remote_src_start_phase_addr}; +} + +// Receiver stream data mover kernel +std::vector get_receiver_reader_rt_args( + Device* device, + uint32_t num_messages, + uint32_t relay_done_semaphore, + CoreCoord const& relay_core, + uint32_t receiver_noc_id, + stream_config_t const& relay_stream_config, + stream_config_t const& receiver_stream_config, + CoreCoord const& other_relay_core_to_notify_when_done, + uint32_t other_relay_done_semaphore, + CoreCoord const& sender_core, + uint32_t sender_receiver_semaphore, + uint32_t remote_src_start_phase_addr) { + return std::vector{ + static_cast(num_messages), + static_cast(relay_stream_id), + static_cast(receiver_stream_config.buffer_addr), + static_cast(receiver_stream_config.buffer_size), + static_cast(receiver_stream_config.tile_header_buffer_addr), + static_cast(receiver_stream_config.tile_header_num_msgs), + static_cast(device->worker_core_from_logical_core(relay_core).x), + static_cast(device->worker_core_from_logical_core(relay_core).y), + static_cast(relay_stream_id), + static_cast(receiver_noc_id), + static_cast(relay_stream_config.buffer_addr), + static_cast(relay_stream_config.buffer_size), + static_cast(relay_stream_config.tile_header_buffer_addr), + + static_cast(relay_done_semaphore), + static_cast(device->worker_core_from_logical_core(other_relay_core_to_notify_when_done).x), + static_cast(device->worker_core_from_logical_core(other_relay_core_to_notify_when_done).y), + other_relay_done_semaphore, + + static_cast(device->worker_core_from_logical_core(sender_core).x), + static_cast(device->worker_core_from_logical_core(sender_core).y), + sender_receiver_semaphore, + remote_src_start_phase_addr}; +} +std::vector get_receiver_writer_rt_args( + Device* device, uint32_t output_buffer_addr, uint32_t page_size, uint32_t num_messages_to_read) { + return std::vector{output_buffer_addr, page_size, num_messages_to_read}; +} + +// TODO: randomize each noc for testing purposes 
+void build_and_run_autonomous_stream_test( + std::vector& programs, + std::vector const& devices, + std::size_t num_messages, + std::size_t page_size, + uint32_t tile_header_buffer_num_messages, + stream_builder_spec_t const& sender_stream_spec, + stream_builder_spec_t const& relay_stream_spec, + stream_builder_spec_t const& receiver_stream_spec, + bool enable_page_size_variations, + std::array const& sub_sizes, + std::size_t num_loop_iterations) { + TT_ASSERT(programs.size() == 0); + // Make configurable + const uint32_t read_write_cb_num_pages = 8; + const uint32_t page_size_plus_header = page_size + tile_header_size; + + const uint32_t sender_stream_buffer_num_pages = sender_stream_spec.buffer_size_bytes / page_size; + const uint32_t relay_stream_buffer_num_pages = relay_stream_spec.buffer_size_bytes / page_size; + const uint32_t receiver_stream_buffer_num_pages = receiver_stream_spec.buffer_size_bytes / page_size; + + const uint32_t sender_stream_buffer_size_bytes = sender_stream_buffer_num_pages * page_size_plus_header; + const uint32_t relay_stream_buffer_size_bytes = relay_stream_buffer_num_pages * page_size_plus_header; + const uint32_t receiver_stream_buffer_size_bytes = receiver_stream_buffer_num_pages * page_size_plus_header; + uint32_t stream_tile_header_buffer_size_bytes = tile_header_buffer_num_messages * tile_header_size; + uint32_t relay_stream_overlay_blob_size_bytes = 256; + + programs.emplace_back(); + Device* device = devices.at(0); + Program& program = programs.at(0); + log_trace(tt::LogTest, "Device ID: {}", device->id()); + + CoreCoord sender_core = CoreCoord(0, 0); + CoreCoord first_relay_core = CoreCoord(1, 0); + CoreCoord second_relay_core = CoreCoord(2, 0); + CoreCoord receiver_core = CoreCoord(3, 0); + + log_trace( + tt::LogTest, + "sender_core: x={}, y={}", + device->physical_core_from_logical_core(sender_core, CoreType::WORKER).x, + device->physical_core_from_logical_core(sender_core, CoreType::WORKER).y); + log_trace( + tt::LogTest, 
+ "first_relay_core: x={}, y={}", + device->physical_core_from_logical_core(first_relay_core, CoreType::WORKER).x, + device->physical_core_from_logical_core(first_relay_core, CoreType::WORKER).y); + log_trace( + tt::LogTest, + "second_relay_core: x={}, y={}", + device->physical_core_from_logical_core(second_relay_core, CoreType::WORKER).x, + device->physical_core_from_logical_core(second_relay_core, CoreType::WORKER).y); + log_trace( + tt::LogTest, + "receiver_core: x={}, y={}", + device->physical_core_from_logical_core(receiver_core, CoreType::WORKER).x, + device->physical_core_from_logical_core(receiver_core, CoreType::WORKER).y); + + // Input DRAM buffer creation + uint32_t buffer_size_bytes = num_messages * page_size; + auto inputs = test_utils::generate_uniform_random_vector(0, 100, buffer_size_bytes / sizeof(uint32_t)); + std::iota(inputs.begin(), inputs.end(), 1); + // for (auto i = 0; i < inputs.size(); i += page_size) { + // for (auto ii = 0; ii < std::min(page_size, inputs.size() - i); ii++) { + // inputs.at(i + ii) = i + 1; + // } + // } + + auto zeroes_buffer = std::vector(buffer_size_bytes / sizeof(uint32_t), 0); + std::vector outputs(buffer_size_bytes / sizeof(uint32_t), 0); + log_trace(tt::LogTest, "outputs.size(): {}", outputs.size()); + log_trace(tt::LogTest, "inputs.size(): {}", inputs.size()); + auto input_buffer = CreateBuffer( + InterleavedBufferConfig{device, static_cast(num_messages * page_size), page_size, BufferType::DRAM}); + auto output_buffer = CreateBuffer( + InterleavedBufferConfig{device, static_cast(num_messages * page_size), page_size, BufferType::DRAM}); + + tt_metal::EnqueueWriteBuffer(device->command_queue(), input_buffer, inputs, false); + // Explicitly overwrite to 0 in case of left over state from prior run(s) + tt_metal::EnqueueWriteBuffer(device->command_queue(), output_buffer, zeroes_buffer, true); + const uint32_t dram_input_buf_base_addr = input_buffer->address(); + + // For overlay blob on relay core + constexpr uint32_t 
dummy_cb_index3 = CB::c_in3; +    auto const& relay_stream_overlay_blob_buffer_cb_config = +        tt_metal::CircularBufferConfig( +            relay_stream_overlay_blob_size_bytes, {{dummy_cb_index3, tt::DataFormat::Float16_b}}) +            .set_page_size(dummy_cb_index3, relay_stream_overlay_blob_size_bytes); +    auto first_relay_stream_overlay_blob_cb = +        CreateCircularBuffer(program, first_relay_core, relay_stream_overlay_blob_buffer_cb_config); +    auto second_relay_stream_overlay_blob_cb = +        CreateCircularBuffer(program, second_relay_core, relay_stream_overlay_blob_buffer_cb_config); + +    // Sender/Receiver CBs for pulling in/pushing out stimulus data that we can output compare +    constexpr uint32_t cb_index = CB::c_in0; +    const uint32_t cb_size = page_size_plus_header * read_write_cb_num_pages; +    auto const& cb_config = tt_metal::CircularBufferConfig(cb_size, {{cb_index, tt::DataFormat::Float16_b}}) +                                .set_page_size(cb_index, page_size_plus_header); +    auto sender_cb = CreateCircularBuffer(program, sender_core, cb_config); +    auto receiver_cb = CreateCircularBuffer(program, receiver_core, cb_config); + +    // Stream Tile Header Buffers +    constexpr uint32_t dummy_cb_index2 = CB::c_in2; +    auto const& stream_tile_header_buffer_cb_config = +        tt_metal::CircularBufferConfig( +            stream_tile_header_buffer_size_bytes, {{dummy_cb_index2, tt::DataFormat::Float16_b}}) +            .set_page_size(dummy_cb_index2, stream_tile_header_buffer_size_bytes); +    auto sender_stream_tile_header_buffer_cb = +        CreateCircularBuffer(program, sender_core, stream_tile_header_buffer_cb_config); +    auto first_relay_stream_tile_header_buffer_cb = +        CreateCircularBuffer(program, first_relay_core, stream_tile_header_buffer_cb_config); +    auto second_relay_stream_tile_header_buffer_cb = +        CreateCircularBuffer(program, second_relay_core, stream_tile_header_buffer_cb_config); +    auto receiver_stream_tile_header_buffer_cb = +        CreateCircularBuffer(program, receiver_core, stream_tile_header_buffer_cb_config); + +    constexpr uint32_t dummy_cb_index = 
CB::c_in1; + auto const& sender_stream_buffer_cb_config = + tt_metal::CircularBufferConfig(sender_stream_buffer_size_bytes, {{dummy_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(dummy_cb_index, sender_stream_buffer_size_bytes); + auto const& relay_stream_buffer_cb_config = + tt_metal::CircularBufferConfig(relay_stream_buffer_size_bytes, {{dummy_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(dummy_cb_index, relay_stream_buffer_size_bytes); + auto const& receiver_stream_buffer_cb_config = + tt_metal::CircularBufferConfig(receiver_stream_buffer_size_bytes, {{dummy_cb_index, tt::DataFormat::Float16_b}}) + .set_page_size(dummy_cb_index, receiver_stream_buffer_size_bytes); + auto sender_stream_buffer_cb = CreateCircularBuffer(program, sender_core, sender_stream_buffer_cb_config); + auto first_relay_stream_buffer_cb = CreateCircularBuffer(program, first_relay_core, relay_stream_buffer_cb_config); + auto second_relay_stream_buffer_cb = + CreateCircularBuffer(program, second_relay_core, relay_stream_buffer_cb_config); + auto receiver_stream_buffer_cb = CreateCircularBuffer(program, receiver_core, receiver_stream_buffer_cb_config); + + program.allocate_circular_buffers(); + + uint32_t sender_stream_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, sender_stream_buffer_cb)->address(); + uint32_t first_relay_stream_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, first_relay_stream_buffer_cb)->address(); + uint32_t second_relay_stream_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, second_relay_stream_buffer_cb)->address(); + uint32_t receiver_stream_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, receiver_stream_buffer_cb)->address(); + uint32_t sender_stream_tile_header_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, sender_stream_tile_header_buffer_cb)->address(); + uint32_t first_relay_stream_tile_header_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, 
first_relay_stream_tile_header_buffer_cb)->address(); + uint32_t second_relay_stream_tile_header_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, second_relay_stream_tile_header_buffer_cb)->address(); + uint32_t receiver_stream_tile_header_buffer_addr = + tt_metal::detail::GetCircularBuffer(program, receiver_stream_tile_header_buffer_cb)->address(); + uint32_t first_relay_stream_overlay_blob_addr = + tt_metal::detail::GetCircularBuffer(program, first_relay_stream_overlay_blob_cb)->address(); + uint32_t second_relay_stream_overlay_blob_addr = + tt_metal::detail::GetCircularBuffer(program, second_relay_stream_overlay_blob_cb)->address(); + + uint32_t receiver_cb_address = tt_metal::detail::GetCircularBuffer(program, receiver_cb)->address(); + log_trace(tt::LogTest, "receiver_cb_address: {}", receiver_cb_address); + + TT_ASSERT(sender_stream_buffer_size_bytes % page_size_plus_header == 0); + TT_ASSERT(relay_stream_buffer_size_bytes % page_size_plus_header == 0); + TT_ASSERT(receiver_stream_buffer_size_bytes % page_size_plus_header == 0); + log_trace( + tt::LogTest, "first_relay_stream_tile_header_buffer_addr: {}", first_relay_stream_tile_header_buffer_addr); + log_trace( + tt::LogTest, "second_relay_stream_tile_header_buffer_addr: {}", second_relay_stream_tile_header_buffer_addr); + stream_config_t sender_stream_config = stream_config_t{ + sender_stream_buffer_addr, + sender_stream_buffer_size_bytes, + sender_stream_tile_header_buffer_addr, + tile_header_buffer_num_messages, + stream_tile_header_buffer_size_bytes}; + stream_config_t first_relay_stream_config = stream_config_t{ + first_relay_stream_buffer_addr, + relay_stream_buffer_size_bytes, + first_relay_stream_tile_header_buffer_addr, + tile_header_buffer_num_messages, + stream_tile_header_buffer_size_bytes}; + stream_config_t second_relay_stream_config = stream_config_t{ + second_relay_stream_buffer_addr, + relay_stream_buffer_size_bytes, + second_relay_stream_tile_header_buffer_addr, + 
// remote receiver doesn't handshake properly with noc_1
first_relay_stream_config.buffer_addr); + log_trace(tt::LogTest, "\tfirst_relay_stream_config.buffer_size: {}", first_relay_stream_config.buffer_size); + log_trace( + tt::LogTest, + "\tfirst_relay_stream_config.tile_header_buffer_addr: {}", + first_relay_stream_config.tile_header_buffer_addr); + log_trace( + tt::LogTest, + "\tfirst_relay_stream_config.tile_header_num_msgs: {}", + first_relay_stream_config.tile_header_num_msgs); + log_trace( + tt::LogTest, + "\tfirst_relay_stream_config.tile_header_buffer_size: {}", + first_relay_stream_config.tile_header_buffer_size); + log_trace(tt::LogTest, "second_relay_stream_config"); + log_trace(tt::LogTest, "\tsecond_relay_stream_config.buffer_addr: {}", second_relay_stream_config.buffer_addr); + log_trace(tt::LogTest, "\tsecond_relay_stream_config.buffer_size: {}", second_relay_stream_config.buffer_size); + log_trace( + tt::LogTest, + "\tsecond_relay_stream_config.tile_header_buffer_addr: {}", + second_relay_stream_config.tile_header_buffer_addr); + log_trace( + tt::LogTest, + "\tsecond_relay_stream_config.tile_header_num_msgs: {}", + second_relay_stream_config.tile_header_num_msgs); + log_trace( + tt::LogTest, + "\tsecond_relay_stream_config.tile_header_buffer_size: {}", + second_relay_stream_config.tile_header_buffer_size); + + // Need to figure out the noc IDs between the first and second relay. 
Also double-check the relay-to-relay noc ID assignment.
1 : 0)}}); + auto sender_writer_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_sender.cpp", + sender_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::NOC_1, // to keep noc coords simple (no calculating noc1 coords) + .compile_args = {}}); + + auto first_relay_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp", + first_relay_core, + tt_metal::DataMovementConfig{.noc = tt_metal::NOC::NOC_0, .compile_args = {}}); + + auto second_relay_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay.cpp", + second_relay_core, + tt_metal::DataMovementConfig{.noc = tt_metal::NOC::NOC_0, .compile_args = {}}); + + auto receiver_reader_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver.cpp", + receiver_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::NOC_0, .compile_args = {}}); + auto receiver_writer_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/dataflow/streams/stream_relay_remote_receiver_writer.cpp", + receiver_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::NOC_1, // to keep noc coords simple (no calculating noc1 coords) + .compile_args = {tile_header_size}}); + + log_trace(tt::LogTest, "sender_reader_rt_args: "); + for (auto const& arg : sender_reader_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, sender_reader_kernel, sender_core, sender_reader_rt_args); + + log_trace(tt::LogTest, "sender_writer_rt_args: "); + for (auto const& arg : sender_writer_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + 
tt_metal::SetRuntimeArgs(program, sender_writer_kernel, sender_core, sender_writer_rt_args); + + log_trace(tt::LogTest, "first_relay_rt_args: "); + for (auto const& arg : first_relay_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, first_relay_kernel, first_relay_core, first_relay_rt_args); + + log_trace(tt::LogTest, "second_relay_rt_args: "); + for (auto const& arg : second_relay_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, second_relay_kernel, second_relay_core, second_relay_rt_args); + + log_trace(tt::LogTest, "receiver_reader_rt_args: "); + for (auto const& arg : receiver_reader_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, receiver_reader_kernel, receiver_core, receiver_reader_rt_args); + + log_trace(tt::LogTest, "receiver_writer_rt_args: "); + for (auto const& arg : receiver_writer_rt_args) { + log_trace(tt::LogTest, "\t{}", arg); + } + tt_metal::SetRuntimeArgs(program, receiver_writer_kernel, receiver_core, receiver_writer_rt_args); + + tt::tt_metal::detail::CompileProgram(device, program); + for (std::size_t i = 0; i < num_loop_iterations; i++) { + log_debug(tt::LogTest, "Enqueing Program"); + tt_metal::EnqueueProgram(device->command_queue(), program, true); + log_debug(tt::LogTest, "Calling Finish"); + tt_metal::Finish(device->command_queue()); + if (i == 0) { + log_debug(tt::LogTest, "Reading Output Buffer"); + tt_metal::EnqueueReadBuffer(device->command_queue(), output_buffer, outputs, true); + } + } + + log_debug(tt::LogTest, "outputs.size(): {}", outputs.size()); + log_debug(tt::LogTest, "inputs.size(): {}", inputs.size()); + log_debug(tt::LogTest, "Comparing Outputs"); + TT_ASSERT(inputs.size() == outputs.size()); + if (enable_page_size_variations) { + uint32_t page_size_words = page_size / sizeof(uint32_t); + bool matches = true; + std::size_t size = outputs.size(); + uint32_t sub_size_i = 0; + uint32_t page_idx = 0; + for 
(auto i = 0; i < size; i += page_size_words) { + std::size_t n_elems = page_size_words - (sub_sizes.at(sub_size_i) * noc_word_size / sizeof(uint32_t)); + sub_size_i = (sub_size_i + 1) % num_sizes; + bool printed_page_info = false; + for (auto ii = 0; ii < n_elems; ii++) { + bool match = outputs.at(i + ii) == inputs.at(i + ii); + if (!match) { + if (!printed_page_info) { + printed_page_info = true; + log_error(tt::LogTest, "Output Mismatch"); + } + log_trace( + tt::LogTest, + "Mismatch at index {}: {} (expected) != {} (actual)", + i + ii, + inputs.at(i + ii), + outputs.at(i + ii)); + matches = false; + } + } + page_idx++; + } + TT_ASSERT(matches); + } else { + bool matches = true; + bool printed = false; + TT_ASSERT(inputs.size() == outputs.size()); + for (std::size_t i = 0; i < inputs.size(); i++) { + if (inputs.at(i) != outputs.at(i)) { + if (!printed) { + log_error(tt::LogTest, "Output Mismatch"); + printed = true; + } + matches = false; + log_trace( + tt::LogTest, "Mismatch at index {}: {} (expected) != {} (actual)", i, inputs.at(i), outputs.at(i)); + } + } + TT_ASSERT(matches); + } +} + +} // namespace tt_metal + +} // namespace tt + +TEST_F(CommandQueueFixture, TestAutonomousRelayStreams) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 10; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 4096; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + 
tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + std::array sub_sizes = std::array{0, 3, 4, 7, 0, 2, 10, 1}; + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + + return; +} + +TEST_F(CommandQueueFixture, TestAutonomousRelayStreamsSmallPackets) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 10; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 128; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + std::array sub_sizes = std::array{0, 3, 4, 7, 0, 2, 5, 1}; + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + 
enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + + return; +} + +TEST_F(CommandQueueFixture, TestAutonomousRelayStreamsLoopingShort) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 50; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 4096; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + std::array sub_sizes = std::array{0, 3, 4, 7, 0, 2, 10, 1}; + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + + return; +} + +// Too long to run in post commit and these kernels are currently only live in these unit tests anyways +// so we just enable a couple of the unit tests to ensure nobody accidentally introduces compile errors +// or anything like that +TEST_F(CommandQueueFixture, DISABLED_TestAutonomousRelayStreamsLoopingRandomShort) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); 
+ // if (num_devices != 8) { + // log_info(tt::LogTest, "Need at least 2 devices to run this test"); + // return; + // } + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 500; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 4096; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + for (std::size_t i = 0; i < num_loop_iterations; i++) { + std::array sub_sizes = {}; + for (auto i = 0; i < num_sizes; i++) { + sub_sizes.at(i) = std::rand() % (page_size / noc_word_size); + EXPECT_TRUE(sub_sizes.at(i) < (page_size / noc_word_size)); + } + std::vector programs; + log_info(tt::LogTest, "Iteration: {}", i); + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + 1); + } + return; +} + +// Too long to run in post commit and these kernels are currently only live in these unit tests anyways +// so we just enable a couple of the unit tests to ensure nobody accidentally introduces compile errors +// or anything like that +TEST_F(CommandQueueFixture, DISABLED_TestAutonomousRelayStreamsLoopingLong) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = 
tt::tt_metal::GetNumAvailableDevices(); + // if (num_devices != 8) { + // log_info(tt::LogTest, "Need at least 2 devices to run this test"); + // return; + // } + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + std::srand(0); + + uint32_t num_loop_iterations = 1'000; + uint32_t num_messages_to_send = 1'000'000; + uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; + uint32_t relay_stream_buffer_size_bytes = 16 * 1024; + uint32_t tile_header_buffer_num_messages = 1024; + uint32_t page_size = 4096; + uint32_t enable_variable_sized_messages = 1; + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto relay_stream_spec = + tt::tt_metal::stream_builder_spec_t{relay_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + tt::tt_metal::stream_builder_spec_t{tx_rx_stream_buffer_size_bytes, tile_header_buffer_num_messages}; + + std::array sub_sizes = std::array{0, 3, 4, 7, 0, 2, 10, 1}; + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages_to_send, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + + return; +} + +// Too long to run in post commit and these kernels are currently only live in these unit tests anyways +// so we just enable a couple of the unit tests to ensure nobody accidentally introduces compile errors +// or anything like that +TEST_F(CommandQueueFixture, DISABLED_TestAutonomousRelayStreamsSweep) { + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (arch == tt::ARCH::GRAYSKULL) { + log_info(tt::LogTest, "Test must be run on WH"); + return; + } + + // Create array of size `num_sizes` of 
random integers using c++ random + std::array sub_sizes_global = {}; + std::srand(0); + for (auto i = 0; i < num_sizes; i++) { + sub_sizes_global.at(i) = std::rand(); + } + + uint32_t num_loop_iterations = 10; + std::vector message_counts = {1'000'000}; + std::vector fw_stream_buffer_sizes = {2 * 1024, 8 * 1024, 16 * 1024, 32 * 1024}; + std::vector relay_stream_buffer_sizes = {8 * 1024, 16 * 1024, 24 * 1024}; + std::vector phase_message_counts = { + // 32, // Hangs on handshake on phase range wrap, or 25th run, whichever comes first + // 64, // Hangs on handshake on phase range wrap, or 25th run, whichever comes first + 128, // works with 16KB buffer + 256, // works with 16KB buffer + 1024 // works with 16KB buffer + }; + // std::vector page_size = {2048, 4096}; + std::vector page_size = {4096}; + for (auto num_messages : message_counts) { + for (auto fw_stream_buffer_size : fw_stream_buffer_sizes) { + for (auto relay_stream_buffer_size : relay_stream_buffer_sizes) { + // auto fw_stream_buffer_size = relay_stream_buffer_size; + for (auto tile_header_buffer_num_messages : phase_message_counts) { + for (auto page_size : page_size) { + if (page_size > fw_stream_buffer_size) { + continue; + } + if (page_size > relay_stream_buffer_size) { + continue; + } + uint32_t enable_variable_sized_messages = 1; + + log_info( + tt::LogTest, + "num_messages: {}, fw_stream_buffer_size: {}, relay_stream_buffer_size: {}, " + "tile_header_buffer_num_messages: {}, page_size: {}, enable_variable_sized_messages: {}", + num_messages, + fw_stream_buffer_size, + relay_stream_buffer_size, + tile_header_buffer_num_messages, + page_size, + enable_variable_sized_messages); + + auto sender_stream_spec = + tt::tt_metal::stream_builder_spec_t{fw_stream_buffer_size, tile_header_buffer_num_messages}; + auto relay_stream_spec = tt::tt_metal::stream_builder_spec_t{ + relay_stream_buffer_size, tile_header_buffer_num_messages}; + auto receiver_stream_spec = + 
tt::tt_metal::stream_builder_spec_t{fw_stream_buffer_size, tile_header_buffer_num_messages}; + + std::array sub_sizes = {}; + for (auto i = 0; i < num_sizes; i++) { + sub_sizes.at(i) = sub_sizes_global.at(i) % (page_size / noc_word_size); + EXPECT_TRUE(sub_sizes.at(i) < (page_size / noc_word_size)); + } + + std::vector programs; + tt::tt_metal::build_and_run_autonomous_stream_test( + programs, + {device_}, + num_messages, + page_size, + tile_header_buffer_num_messages, + sender_stream_spec, + relay_stream_spec, + receiver_stream_spec, + enable_variable_sized_messages == 1, + sub_sizes, + num_loop_iterations); + } + } + } + } + } + + return; +} diff --git a/tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h b/tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h index 3a70066d9af..4ba2e33be1c 100644 --- a/tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h +++ b/tt_metal/hw/inc/wormhole/noc/noc_overlay_parameters.h @@ -414,6 +414,7 @@ // Set when stream is in data forwarding state. #define MSG_FWD_ONGOING (WAIT_PREV_PHASE_DATA_FLUSH+WAIT_PREV_PHASE_DATA_FLUSH_WIDTH) #define MSG_FWD_ONGOING_WIDTH 1 +// 0 is idle. 1/2 is auto cfg. 3 is waiting for phase advance. 4 is waiting for data send. 5 is phase active #define STREAM_CURR_STATE (MSG_FWD_ONGOING+MSG_FWD_ONGOING_WIDTH) #define STREAM_CURR_STATE_WIDTH 4