From 00cca1f3553dbcd6696da0cc25d966b2d90e4dc9 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Thu, 23 May 2024 22:49:23 +0000 Subject: [PATCH] #5562: add initial reduce scatter implementation (experimental) TODO: add beautiful description about - What is reduce scatter - What are known issues/limitations (what's supported in other words) - New features (EDM) - New infrastructure created for this op (especially infra portable to allgather and other CCL ops) - Description of t-streaming --- tests/scripts/common.py | 6 +- tests/scripts/run_tt_eager.py | 12 +- tests/tt_eager/CMakeLists.txt | 1 + tests/tt_eager/module.mk | 1 + tests/tt_eager/ops/ccl/test_ccl_helpers.cpp | 303 ++++++ .../unit_testing/misc/test_reduce_scatter.py | 193 ++++ .../test_ethernet_hop_latencies_no_edm.cpp | 4 +- tt_eager/tt_dnn/op_library/CMakeLists.txt | 2 + .../op_library/all_gather/all_gather_op.hpp | 92 +- .../dataflow/worker_ring_gather_utils.hpp | 237 ++++- .../multi_core/all_gather_op_multi_core.cpp | 39 +- tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp | 99 +- tt_eager/tt_dnn/op_library/ccl/ccl_common.hpp | 563 +++++++++- .../ccl/ccl_host_datastructures.hpp | 282 ++--- .../ccl/edm/erisc_async_datamover.hpp | 202 ++-- .../op_library/ccl/edm/erisc_datamover.cpp | 40 +- .../ccl/kernel_common/worker_edm_utils.hpp | 22 +- .../host/reduce_scatter_full_worker_grid.cpp | 982 ++++++++++++++++++ ...interleaved_ring_reduce_scatter_reader.cpp | 346 ++++++ ...interleaved_ring_reduce_scatter_sender.cpp | 150 +++ .../ccl/reduce_scatter/reduce_scatter_op.cpp | 133 +++ .../ccl/reduce_scatter/reduce_scatter_op.hpp | 59 ++ .../hetergeneous_data_structs.hpp | 148 +-- .../csrc/tt_lib_bindings_tensor_dm_ops.cpp | 26 +- 24 files changed, 3385 insertions(+), 557 deletions(-) create mode 100644 tests/tt_eager/ops/ccl/test_ccl_helpers.cpp create mode 100644 tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter.py create mode 100644 tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp create mode 100644 tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp create mode 100644 tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp create mode 100644 tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp create mode 100644 tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp diff --git a/tests/scripts/common.py b/tests/scripts/common.py index 151bce24b2b..e5165275e47 100644 --- a/tests/scripts/common.py +++ b/tests/scripts/common.py @@ -14,7 +14,7 @@ from loguru import logger -from models.utility_functions import is_wormhole_b0 +from models.utility_functions import is_wormhole_b0, is_grayskull class TestSuiteType(Enum): @@ -33,6 +33,10 @@ def void_for_whb0(x): return (not is_wormhole_b0()) and x or None +def void_for_gs(x): + return (not is_grayskull()) and x or None + + def filter_empty(fn): @wraps(fn) def __filter_empty(): diff --git a/tests/scripts/run_tt_eager.py b/tests/scripts/run_tt_eager.py index f9a58d5646a..b3968d39eef 100644 --- a/tests/scripts/run_tt_eager.py +++ b/tests/scripts/run_tt_eager.py @@ -21,6 +21,7 @@ get_git_home_dir_str, filter_empty, void_for_whb0, + void_for_gs, ) from tests.scripts.cmdline_args import ( get_tt_metal_arguments_from_cmdline_args, @@ -28,11 +29,14 @@ ) TT_EAGER_COMMON_TEST_ENTRIES = ( - TestEntry("tt_eager/tests/ops/ccl/test_all_gather_utils", "ops/ccl/test_all_gather_utils"), - TestEntry( - 
"tt_eager/tests/ops/ccl/test_all_gather_sharded_indexing_helpers", - "ops/ccl/test_all_gather_sharded_indexing_helpers", + void_for_gs(TestEntry("tt_eager/tests/ops/ccl/test_all_gather_utils", "ops/ccl/test_all_gather_utils")), + void_for_gs( + TestEntry( + "tt_eager/tests/ops/ccl/test_all_gather_sharded_indexing_helpers", + "ops/ccl/test_all_gather_sharded_indexing_helpers", + ) ), + void_for_gs(TestEntry("tt_eager/tests/ops/ccl/test_ccl_helpers", "ops/ccl/test_ccl_helpers")), TestEntry("tt_eager/tests/ops/test_eltwise_binary_op", "ops/test_eltwise_binary_op"), TestEntry("tt_eager/tests/ops/test_bcast_op", "ops/test_bcast_op"), TestEntry("tt_eager/tests/ops/test_reduce_op", "ops/test_reduce_op"), diff --git a/tests/tt_eager/CMakeLists.txt b/tests/tt_eager/CMakeLists.txt index 6b98d658818..67fd3adf161 100644 --- a/tests/tt_eager/CMakeLists.txt +++ b/tests/tt_eager/CMakeLists.txt @@ -5,6 +5,7 @@ target_link_libraries(test_eager_common_libs INTERFACE tt_eager test_common_libs set(TT_EAGER_TESTS_OPS ops/ccl/test_all_gather_utils ops/ccl/test_all_gather_sharded_indexing_helpers + ops/ccl/test_ccl_helpers ops/test_average_pool ops/test_eltwise_binary_op ops/test_eltwise_unary_op diff --git a/tests/tt_eager/module.mk b/tests/tt_eager/module.mk index efc23ea28e8..00111eae1bd 100644 --- a/tests/tt_eager/module.mk +++ b/tests/tt_eager/module.mk @@ -2,6 +2,7 @@ TT_EAGER_TESTS += \ tests/tt_eager/ops/ccl/test_all_gather_utils \ tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers \ + tests/tt_eager/ops/ccl/test_ccl_helpers \ tests/tt_eager/ops/test_average_pool \ tests/tt_eager/ops/test_eltwise_binary_op \ tests/tt_eager/ops/test_eltwise_unary_op \ diff --git a/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp b/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp new file mode 100644 index 00000000000..a520543553b --- /dev/null +++ b/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp @@ -0,0 +1,303 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "device/tt_xy_pair.h" +#include "gtest/gtest.h" +#include "tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/ccl_common.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" + +TEST(CclHelpers, CreateEriscDatamoverBuilder_Chan4_PageSize2048_RRBufferSharingMode) { + std::size_t num_channels = 4; + uint32_t page_size = 2048; + ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; + ccl::EriscDataMoverTerminationMode termination_mode = ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto edm_builder = create_erisc_datamover_builder(num_channels, page_size, buffer_sharing_mode, termination_mode); + std::vector worker_semaphore_addresses = { + 0x1000, + 0x1010, + 0x1020, + 0x1030, + }; + std::vector message_counts = {256, 512, 24, 1}; + std::vector> const& worker_coords = { + {ccl::WorkerXY{1, 1}, ccl::WorkerXY{2, 1}}, + {ccl::WorkerXY{3, 1}}, + {ccl::WorkerXY{4, 1}, ccl::WorkerXY{5, 1}, ccl::WorkerXY{6, 1}}, + {ccl::WorkerXY{1, 2}}, + }; + std::vector is_sender_channel{true, false, true, false}; + + std::vector channel_buffer_interfaces; + channel_buffer_interfaces.reserve(num_channels); + for (std::size_t i = 0; i < num_channels; i++) { + ccl::EriscDatamoverBuilder::ChannelBufferInterface const& channel_buffer_interface = + (is_sender_channel[i]) + ? 
edm_builder.add_sender_channel(worker_semaphore_addresses[i], message_counts[i], worker_coords[i]) + : edm_builder.add_receiver_channel(worker_semaphore_addresses[i], message_counts[i], worker_coords[i]); + channel_buffer_interfaces.push_back(channel_buffer_interface); + ASSERT_TRUE(channel_buffer_interface.eth_buffer_l1_address > 0); + ASSERT_TRUE(channel_buffer_interface.eth_semaphore_l1_address > 0); + } + + auto const& active_channels = edm_builder.get_active_channels(); + ASSERT_EQ(active_channels.size(), num_channels); + for (std::size_t i = 0; i < active_channels.size(); ++i) { + ASSERT_EQ(active_channels[i].channel, i); + ASSERT_EQ(active_channels[i].is_sender, is_sender_channel.at(i)); + ASSERT_EQ(active_channels[i].worker_coords, worker_coords.at(i)); + ASSERT_TRUE(active_channels[i].worker_semaphore_address == worker_semaphore_addresses.at(i)); + ASSERT_TRUE(active_channels[i].num_eth_messages_to_forward == message_counts.at(i)); + } +} + +TEST(CclHelpers, EriscDatamoverConfig_GetEdmHandshakeAddress_GT_0) { + for (std::size_t i = 0; i < 8; i++) { + ASSERT_TRUE(ccl::EriscDatamoverConfig::get_edm_handshake_address() > 0); + } +} +TEST(CclHelpers, EriscDatamoverConfig_GetSemaphoresBaseAddress_GT_0) { + for (std::size_t i = 0; i < 8; i++) { + ASSERT_TRUE( + ccl::EriscDatamoverConfig::get_semaphores_base_address(i) >= + (ccl::EriscDatamoverConfig::get_edm_handshake_address() + + ccl::EriscDatamoverConfig::handshake_location_size + + ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + } +} + +TEST(CclHelpers, EriscDatamoverConfig_GetBuffersBaseAddress_GT_0) { + for (std::size_t i = 0; i < 8; i++) { + ASSERT_TRUE( + ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= + (ccl::EriscDatamoverConfig::get_edm_handshake_address() + + ccl::EriscDatamoverConfig::handshake_location_size + + ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + } +} + +TEST(CclHelpers, EriscDatamoverConfig_ComputeBufferSize_GT_0) { + for (std::size_t i = 0; i < 8; i++) { + ASSERT_TRUE( + ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= + (ccl::EriscDatamoverConfig::get_edm_handshake_address() + + ccl::EriscDatamoverConfig::handshake_location_size + + ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + } +} + +///////////////////////////////////////// +// TEST AdvanceSliceRowMajor +///////////////////////////////////////// +// x_y x_y x_y +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) { + const auto expected = tt::tt_metal::ccl::coord_t(1, 0); + auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 1); + ASSERT_EQ(result.x, expected.x); + ASSERT_EQ(result.y, expected.y); +} +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) { + const auto expected = tt::tt_metal::ccl::coord_t(0, 1); + auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 0}, {1, 1}, {2, 2}, 1); + ASSERT_EQ(result.x, expected.x); + ASSERT_EQ(result.y, expected.y); +} +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) { + const auto expected = tt::tt_metal::ccl::coord_t(1, 1); + auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 1}, {1, 1}, {2, 2}, 1); + ASSERT_EQ(result.x, expected.x); + ASSERT_EQ(result.y, expected.y); +} +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) { 
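+    // Hedged sketch of the row-major advance exercised here (inferred from the expectations in
+    // these tests, not from the implementation): with unit worker-slice shapes,
+    // advance_slice_row_major behaves like a flat row-major index step of num_active_slices:
+    //   flat = offset.y * outer_shape.x + offset.x + num_active_slices;
+    //   next = {flat % outer_shape.x, flat / outer_shape.x};  // may land out of bounds on the last step
+    // For this test: flat = 0*2 + 0 + 2 = 2 -> (x=0, y=1), which matches `expected` below.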
+ const auto expected = tt::tt_metal::ccl::coord_t(0, 1); + auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 2); + ASSERT_EQ(result.x, expected.x); + ASSERT_EQ(result.y, expected.y); +} +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) { + const auto expected = tt::tt_metal::ccl::coord_t(1, 1); + auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 0}, {1, 1}, {2, 2}, 2); + ASSERT_EQ(result.x, expected.x); + ASSERT_EQ(result.y, expected.y); +} + +// Test that we successfully go out of bounds on the last iteration +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) { + auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 1}, {1, 1}, {2, 2}, 2); + ASSERT_TRUE(result.x >= 2 || result.y >= 2); +} +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) { + auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 1}, {1, 1}, {2, 2}, 2); + ASSERT_TRUE(result.x >= 2 || result.y >= 2); +} + +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_3) { + const auto expected = tt::tt_metal::ccl::coord_t(1, 1); + auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 3); + ASSERT_EQ(result.x, expected.x); + ASSERT_EQ(result.y, expected.y); +} +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_3) { + const auto expected = tt::tt_metal::ccl::coord_t(1, 1); + const auto outer_shape = tt::tt_metal::ccl::coord_t(2, 2); + const auto inner_offset = tt::tt_metal::ccl::coord_t(1, 1); + const auto inner_shape = tt::tt_metal::ccl::coord_t(1, 1); + const uint32_t num_parallel_workers = 3; + auto const& result = + tt::tt_metal::ccl::advance_slice_row_major(inner_offset, inner_shape, outer_shape, num_parallel_workers); + ASSERT_TRUE(result.x >= outer_shape.x || result.y >= outer_shape.y); +} +TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_24_0__InnerShape_24_0__OuterShape_32_4__NumActiveSlices_4) { + const auto expected = tt::tt_metal::ccl::coord_t(24, 2); + const auto outer_shape = tt::tt_metal::ccl::coord_t(32, 4); + const auto inner_offset = tt::tt_metal::ccl::coord_t(24, 0); + const auto inner_shape = tt::tt_metal::ccl::coord_t(24, 1); + const uint32_t num_parallel_workers = 4; + auto const& result = + tt::tt_metal::ccl::advance_slice_row_major(inner_offset, inner_shape, outer_shape, num_parallel_workers); + ASSERT_EQ(result.x, expected.x); + ASSERT_EQ(result.y, expected.y); +} + +///////////////////////////////////////// +// Test InterleavedRingReduceScatterTensorSlicer +///////////////////////////////////////// +TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_AllWorkersSameRow) { + auto worker_slice_shapes = std::vector(4, {2, 2}); + tt_xy_pair tensor_slice_shape = {8, 4}; + auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets( + worker_slice_shapes, tensor_slice_shape); + ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); + ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); + ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4, 0)); + ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(6, 0)); +} +TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_1WorkerWrapToNextRowAligned) { + auto worker_slice_shapes = std::vector(4, 
{2, 2}); + tt_xy_pair tensor_slice_shape = {6, 4}; + auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets( + worker_slice_shapes, tensor_slice_shape); + ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); + ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); + ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4, 0)); + ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(0, 2)); +} +TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_1WorkerWrapToNextRowMisaligned) { + { + auto worker_slice_shapes = std::vector(4, {2, 2}); + tt_xy_pair tensor_slice_shape = {5, 4}; + auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets( + worker_slice_shapes, tensor_slice_shape); + ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); + ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); + ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4, 0)); + ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(0, 2)); + } +} +TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_MultipleWorkersWrapToNextRowAligned) { + auto worker_slice_shapes = std::vector(8, {2, 2}); + tt_xy_pair tensor_slice_shape = {10, 4}; + auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets( + worker_slice_shapes, tensor_slice_shape); + ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); + ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); + ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4, 0)); + ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(6, 0)); + ASSERT_EQ(worker_slice_offsets.at(4), tt_xy_pair(8, 0)); + ASSERT_EQ(worker_slice_offsets.at(5), tt_xy_pair(0, 2)); + ASSERT_EQ(worker_slice_offsets.at(6), tt_xy_pair(2, 2)); + ASSERT_EQ(worker_slice_offsets.at(7), tt_xy_pair(4, 2)); +} + +TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_MultipleWorkersWrapToNextRowMisaligned) { + auto worker_slice_shapes = std::vector(8, {2, 2}); + tt_xy_pair tensor_slice_shape = {9, 4}; + auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets( + worker_slice_shapes, tensor_slice_shape); + ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); + ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); + ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4, 0)); + ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(6, 0)); + ASSERT_EQ(worker_slice_offsets.at(4), tt_xy_pair(8, 0)); + ASSERT_EQ(worker_slice_offsets.at(5), tt_xy_pair(0, 2)); + ASSERT_EQ(worker_slice_offsets.at(6), tt_xy_pair(2, 2)); + ASSERT_EQ(worker_slice_offsets.at(7), tt_xy_pair(4, 2)); +} + +TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_NMinus1WorkersWrapToNextRowAligned) { + auto worker_slice_shapes = std::vector(3, {4, 4}); + tt_xy_pair tensor_slice_shape = {4, 12}; + auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets( + worker_slice_shapes, tensor_slice_shape); + ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); + ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(0, 4)); + ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(0, 8)); +} + +TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_NMinus1WorkersWrapToNextRowMisaligned) { + auto worker_slice_shapes = std::vector(3, {4, 3}); + tt_xy_pair tensor_slice_shape = {3, 12}; + auto const& 
worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets( + worker_slice_shapes, tensor_slice_shape); + ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); + ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(0, 3)); + ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(0, 6)); +} + +TEST( + Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, + InnerOffset_0_0__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) { + auto worker_slice = ccl::InterleavedTensorWorkerSlice( + tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result + tt_xy_pair(32, 4), + tt_xy_pair(24, 1), + tt_xy_pair(0, 0)); + uint32_t num_workers = 4; + auto num_iterations = worker_slice.compute_num_worker_slice_iterations(num_workers); + auto expected = 2; + ASSERT_EQ(num_iterations, expected); +} + +TEST( + Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, + InnerOffset_24_0__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) { + auto worker_slice = ccl::InterleavedTensorWorkerSlice( + tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result + tt_xy_pair(32, 4), + tt_xy_pair(24, 1), + tt_xy_pair(24, 0)); + uint32_t num_workers = 4; + auto num_iterations = worker_slice.compute_num_worker_slice_iterations(num_workers); + auto expected = 2; + ASSERT_EQ(num_iterations, expected); +} + +TEST( + Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, + InnerOffset_0_1__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) { + auto worker_slice = ccl::InterleavedTensorWorkerSlice( + tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result + tt_xy_pair(32, 4), + tt_xy_pair(24, 1), + tt_xy_pair(0, 1)); + uint32_t num_workers = 4; + auto num_iterations = worker_slice.compute_num_worker_slice_iterations(num_workers); + auto expected = 2; + ASSERT_EQ(num_iterations, expected); +} + +TEST( + Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, + InnerOffset_24_1__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) { + auto worker_slice = ccl::InterleavedTensorWorkerSlice( + tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result + tt_xy_pair(32, 4), + tt_xy_pair(24, 1), + tt_xy_pair(24, 0)); + uint32_t num_workers = 4; + auto num_iterations = worker_slice.compute_num_worker_slice_iterations(num_workers); + auto expected = 2; + ASSERT_EQ(num_iterations, expected); +} diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter.py new file mode 100644 index 00000000000..0ee9cceb667 --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +from loguru import logger +import tt_lib as ttl +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc +from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 +import itertools + + +def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout): + if scatter_dim != 3: + return True, "Only support for scatter_dim=3 is tested so far" + + return False, "" + + +def run_reduce_scatter_test( + all_devices, + num_devices, + per_chip_output_shape, + scatter_dim, + num_links, + math_op, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + num_iters=1, +): + if len(all_devices) != 8: + pytest.skip("Not T3000!") + + # if num_devices != 4: + # pytest.skip("Only testing for 4 devices") + + debug = False + logger.info(f"num_devices: {num_devices}") + logger.info(f"per_chip_output_shape: {per_chip_output_shape}") + logger.info(f"scatter_dim: {scatter_dim}") + logger.info(f"num_links: {num_links}") + logger.info(f"math_op: {math_op}") + logger.info(f"input_dtype: {input_dtype}") + logger.info(f"layout: {layout}") + logger.info(f"mem_config: {mem_config}") + + (is_known_failure, message) = is_unsupported_case( + per_chip_output_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout + ) + if is_known_failure: + pytest.skip(f"Skipping unsupported case {message}.") + devices = get_devices_for_t3000(all_devices, num_devices) + + # Generate input tensors + canonical_input_shape = per_chip_output_shape.copy() + canonical_input_shape[scatter_dim] *= num_devices + logger.info(f"per_chip_output_shape: {per_chip_output_shape}") + logger.info(f"canonical_input_tensor_shape: {canonical_input_shape}") + tt_input_tensors = [] + + numel = canonical_input_shape[0] * canonical_input_shape[1] * canonical_input_shape[2] * canonical_input_shape[3] + input_tensors = [ + # torch.rand(canonical_input_shape).bfloat16() if not debug else torch.arange(numel).reshape(canonical_input_shape).bfloat16() + torch.rand(canonical_input_shape).bfloat16() if not debug else torch.ones(canonical_input_shape).bfloat16() + for _ in range(num_devices) + ] + if debug: + input_tensors[-1] = torch.arange(numel).reshape(canonical_input_shape).bfloat16() + for i, canonical_input_tensor in enumerate(input_tensors): + logger.info(f"input_tensor[{i}].shape: {canonical_input_tensor.data.shape}") + logger.info(f"input_tensor[{i}]: {canonical_input_tensor.data}") + tt_input_tensors.append( + ttl.tensor.Tensor(canonical_input_tensor, input_dtype).to(layout).to(devices[i], mem_config) + ) + + # Run the op + # for i in range(num_iters): + tt_out_tensors = ttl.tensor.reduce_scatter( + tt_input_tensors, + scatter_split_dim=scatter_dim, + reduce_op=math_op, + num_links=num_links, + output_mem_config=mem_config, + ) + + for d in devices: + ttl.device.Synchronize(d) + logger.info(f"Done iteration {i}") + + # Compute golden + # TODO: Make it model how reduce scatter actually works for numerical correctness/ordering + golden_canonical_out_tensor = torch.zeros(canonical_input_shape).bfloat16() + logger.info(f"golden_canonical_out_tensor shape: {golden_canonical_out_tensor.shape}") + logger.info(f"canonical_input_shape: {canonical_input_shape}") + for i, t in enumerate(input_tensors): + logger.info(f"t shape: {t.shape}") + golden_canonical_out_tensor = torch.add(golden_canonical_out_tensor, t).bfloat16() + 
logger.info(f"golden_canonical_out_tensor[{i}]: {golden_canonical_out_tensor.data}") + + logger.info(f"golden_canonical_out_tensor: {golden_canonical_out_tensor.data}") + golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, scatter_dim) + + logger.info(f"Compare") + # Compare + assert len(golden_output_tensors) == len(tt_out_tensors) + mismatch = False + for i, t in enumerate(tt_out_tensors): + tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + logger.info(f"golden_output_tensors[i].shape: {golden_output_tensors[i].shape}") + logger.info(f"tt_output_tensor.shape: {tt_output_tensor.shape}") + eq, output = comp_pcc(tt_output_tensor, golden_output_tensors[i]) + mismatch = mismatch or not eq + if not eq: + logger.error(f"output mismatch for tensor {i}") + else: + logger.info(f"output match for tensor {i}") + assert not mismatch, f"{i} FAILED: {output}" + + +@pytest.mark.timeout(30) +@pytest.mark.parametrize( + "num_devices, num_links", + [ + (4, 1), + (8, 1), + ], +) +@pytest.mark.parametrize( + "per_chip_output_shape, scatter_dim, layout", + [ + ([1, 1, 32, 32], 3, ttl.tensor.Layout.TILE), + ([1, 1, 32, 64], 3, ttl.tensor.Layout.TILE), + ([1, 1, 64, 64], 3, ttl.tensor.Layout.TILE), + ([1, 1, 32, 128], 3, ttl.tensor.Layout.TILE), + ([1, 1, 32, 256], 3, ttl.tensor.Layout.TILE), + # Hangs... for whatever reason. Seems like a noc sem inc from sender -> EDM gets lost + # somehow at some point + # ([1, 1, 32, 512], 3, ttl.tensor.Layout.TILE), + ([1, 1, 32, 1024], 3, ttl.tensor.Layout.TILE), + ([1, 1, 32, 2048], 3, ttl.tensor.Layout.TILE), + ([1, 1, 128, 1024], 3, ttl.tensor.Layout.TILE), + ([1, 1, 128, 8192], 3, ttl.tensor.Layout.TILE), + ([1, 1, 2048, 1024], 3, ttl.tensor.Layout.TILE), + ([1, 1, 2048, 8192], 3, ttl.tensor.Layout.TILE), + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + ttl.tensor.DataType.BFLOAT16, + ttl.tensor.DataType.BFLOAT8_B, + ], +) +@pytest.mark.parametrize( + "mem_config", + [ + ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.DRAM), + ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.L1), + ], +) +@pytest.mark.parametrize("math_op", [ttl.tensor.ReduceOpMath.SUM]) +def test_reduce_scatter_post_commit( + all_devices, + num_devices, + per_chip_output_shape, + scatter_dim, + num_links, + math_op, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + num_iters=1, +): + run_reduce_scatter_test( + all_devices, + num_devices, + per_chip_output_shape, + scatter_dim, + num_links, + math_op, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + num_iters, + ) diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp index 36ce4e7079a..667c95e22a5 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp @@ -439,10 +439,10 @@ int main (int argc, char** argv) { return std::vector{all_devices[0], all_devices[1], all_devices[2], all_devices[3]}; case 8: - return std::vector{all_devices[0], all_devices[7], all_devices[6], all_devices[1], all_devices[2], all_devices[5], all_devices[4], all_devices[3]}; + return std::vector{all_devices[0], all_devices[4], all_devices[5], all_devices[1], all_devices[2], all_devices[6], all_devices[7], all_devices[3]}; case 12: // Does an 
extra loop through the inner ring - return std::vector{all_devices[0], all_devices[7], all_devices[6], all_devices[1], all_devices[2], all_devices[3], all_devices[0], all_devices[1], all_devices[2], all_devices[5], all_devices[4], all_devices[3]}; + return std::vector{all_devices[0], all_devices[4], all_devices[5], all_devices[1], all_devices[2], all_devices[3], all_devices[0], all_devices[1], all_devices[2], all_devices[6], all_devices[7], all_devices[3]}; default: TT_ASSERT("Unsupported hop_count"); diff --git a/tt_eager/tt_dnn/op_library/CMakeLists.txt b/tt_eager/tt_dnn/op_library/CMakeLists.txt index 2b22e61f40f..bfa0bcd559c 100644 --- a/tt_eager/tt_dnn/op_library/CMakeLists.txt +++ b/tt_eager/tt_dnn/op_library/CMakeLists.txt @@ -6,6 +6,8 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/layout_conversion/layout_conversion_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/all_gather_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/multi_core/all_gather_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter/reduce_scatter_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/ccl_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sharded/sharded_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sharded/multi_core/sharded_op_multi_core.cpp diff --git a/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp b/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp index 77cf69c9b96..32debd44f72 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp +++ b/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp @@ -12,6 +12,7 @@ #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" +#include "tt_dnn/op_library/ccl/ccl_common.hpp" #include "tt_dnn/op_library/run_operation.hpp" @@ -318,93 +319,8 @@ struct ShardedAllGatherConfig { }; -class AllGatherOpTensorConfig { - public: - static std::unique_ptr build_all_gather_tensor_config(Tensor const& tensor); - // static AllGatherOpTensorConfig *build_all_gather_tensor_config(Tensor const& tensor); - - AllGatherOpTensorConfig( - Tensor const& tensor - ) : - buffer_start_address(tensor.buffer()->address()), - df(tt_metal::datatype_to_dataformat_converter(tensor.get_dtype())) - {} - - virtual uint32_t get_page_size() const = 0; - virtual uint32_t get_unit_size() const = 0; - - uint32_t get_buffer_start_address() const { - return this->buffer_start_address; - } - - virtual ~AllGatherOpTensorConfig() {}; - - protected: - uint32_t buffer_start_address; - DataFormat df; -}; - -class AllGatherOpInterleavedTensorConfig final : public virtual AllGatherOpTensorConfig { - public: - AllGatherOpInterleavedTensorConfig(Tensor const& input_tensor) : - AllGatherOpTensorConfig(input_tensor) { - if (input_tensor.get_layout() == Layout::TILE) { - this->page_size = tt_metal::detail::TileSize(this->df); - } else { - this->page_size = input_tensor.buffer()->page_size(); - } - } - virtual uint32_t get_page_size() const override { - return this->page_size; - } - virtual uint32_t get_unit_size() const override { - return this->page_size; - } - - - private: - uint32_t page_size; - -}; - -class AllGatherOpShardedTensorConfig final : public virtual AllGatherOpTensorConfig { - public: - AllGatherOpShardedTensorConfig(Tensor const& tensor) : - AllGatherOpTensorConfig(tensor), - shard_spec(tensor.shard_spec().value()) { - if (tensor.get_layout() == Layout::TILE) { - this->page_size = tt_metal::detail::TileSize(this->df); - 
TT_ASSERT(this->shard_spec.shape.at(0) * this->shard_spec.shape.at(1) % (TILE_HEIGHT * TILE_WIDTH) == 0); - this->unit_size = (this->shard_spec.shape.at(0) * this->shard_spec.shape.at(1) / (TILE_HEIGHT * TILE_WIDTH)) * this->page_size; - } else { - this->page_size = tensor.get_legacy_shape()[-1] * tensor.element_size(); - this->unit_size = (this->page_size * this->shard_spec.shape.at(0) * this->shard_spec.shape.at(1)) / tensor.shard_spec()->num_cores(); - } - } - - virtual uint32_t get_page_size() const override { - return this->page_size; - } - virtual uint32_t get_unit_size() const override { - return this->unit_size; - } - - uint32_t get_shard_size_in_bytes() const { - return this->get_unit_size(); - } - - ShardSpec const& get_shard_spec() const { - return this->shard_spec; - } - - private: - uint32_t page_size; - uint32_t unit_size; - ShardSpec const shard_spec; -}; struct ShardAddrGenArgGenerator { - using shard_cores_t = CoreRangeSet; ShardAddrGenArgGenerator(ccl::ShardAddrGenArgs const& args_struct) : @@ -497,7 +413,7 @@ struct InputTensorShardAddrGenArgGenerator final : public ShardAddrGenArgGenerat } InputTensorShardAddrGenArgGenerator( Device const* device, - AllGatherOpShardedTensorConfig *input_tensor_config, + ccl::CclOpShardedTensorConfig *input_tensor_config, uint32_t ring_index, uint32_t ring_size, uint32_t num_workers, @@ -693,8 +609,8 @@ struct OutputTensorShardAddrGenArgGenerator final : ShardAddrGenArgGenerator { OutputTensorShardAddrGenArgGenerator( AllGatherConfig const& all_gather_config, Device const* device, - AllGatherOpShardedTensorConfig *input_tensor_config, - AllGatherOpShardedTensorConfig *output_tensor_config, + ccl::CclOpShardedTensorConfig *input_tensor_config, + ccl::CclOpShardedTensorConfig *output_tensor_config, uint32_t ring_index, uint32_t ring_size, uint32_t num_workers, diff --git a/tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp b/tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp index 8f241fdeee7..52a8124d82c 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp +++ b/tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp @@ -4,19 +4,20 @@ #include "dataflow_api.h" #include "debug/assert.h" -#include "tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "tt_eager/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" using tt::tt_metal::ccl::ShardType; -using tt::tt_metal::ccl::WorkerXY; -using tt::tt_metal::ccl::UNINITIALIZED_VALUE_U32; using tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16; +using tt::tt_metal::ccl::UNINITIALIZED_VALUE_U32; +using tt::tt_metal::ccl::WorkerXY; // Only workers on local worker core, hence no uint64_t noc addresses template struct FullWorkerGridShardAddrGen { - FullWorkerGridShardAddrGen()=default; - FORCE_INLINE static void build_with_placement_new(FullWorkerGridShardAddrGen* placement_new_address, const uint32_t arg_index) { + FullWorkerGridShardAddrGen() = default; + FORCE_INLINE static void build_with_placement_new( + FullWorkerGridShardAddrGen* placement_new_address, const uint32_t arg_index) { tt::tt_metal::ccl::FullWorkerGridShardAddrGenArgs input_args; uint32_t curr_arg_index = arg_index; @@ -53,8 +54,7 @@ struct FullWorkerGridShardAddrGen { } FullWorkerGridShardAddrGen( - uint8_t num_args_consumed, - 
tt::tt_metal::ccl::FullWorkerGridShardAddrGenArgs const& input_args) : + uint8_t num_args_consumed, tt::tt_metal::ccl::FullWorkerGridShardAddrGenArgs const& input_args) : dest_cores(input_args.dest_cores), tile_size_in_bytes(input_args.tile_size_in_bytes), shards_start_address(input_args.shards_start_address), @@ -68,8 +68,7 @@ struct FullWorkerGridShardAddrGen { input_shard_num_tiles_y(input_args.input_shard_num_tiles_y), total_shards_x(input_args.total_shards_x), num_args_consumed(num_args_consumed), - is_clockwise(input_args.is_clockwise) - { + is_clockwise(input_args.is_clockwise) { ASSERT(input_shard_num_tiles_x > 0); ASSERT(input_shard_num_tiles_y > 0); ASSERT(total_shards_x > 0); @@ -77,13 +76,14 @@ struct FullWorkerGridShardAddrGen { ASSERT(curr_core_index < total_num_cores); if constexpr (SHARD_TYPE == ShardType::Width) { ASSERT(curr_shard < total_shards_x); - ASSERT(curr_tile_index = curr_shard_tile_x * input_shard_num_tiles_x + (curr_shard_tile_y * total_shards_x * input_shard_num_tiles_x)); + ASSERT( + curr_tile_index = curr_shard_tile_x * input_shard_num_tiles_x + + (curr_shard_tile_y * total_shards_x * input_shard_num_tiles_x)); } else { - ASSERT(false); // Not implemented yet + ASSERT(false); // Not implemented yet } } - [[nodiscard]] FORCE_INLINE WorkerXY get_next_noc_xy_core() const { ASSERT(this->curr_core_index < this->total_num_cores); return this->dest_cores[this->curr_core_index]; @@ -101,7 +101,7 @@ struct FullWorkerGridShardAddrGen { } FORCE_INLINE void advance() { - tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance ( + tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance( this->curr_shard_tile_x, this->curr_shard_tile_y, this->curr_tile_index, @@ -135,9 +135,7 @@ struct FullWorkerGridShardAddrGen { } } - [[nodiscard]] FORCE_INLINE uint32_t get_tile_size_in_bytes() const { - return this->tile_size_in_bytes; - } + [[nodiscard]] FORCE_INLINE uint32_t get_tile_size_in_bytes() const { return this->tile_size_in_bytes; } [[nodiscard]] FORCE_INLINE uint32_t get_shard_tile_row_size_in_bytes() const { return this->input_shard_num_tiles_x * this->tile_size_in_bytes; @@ -161,11 +159,9 @@ struct FullWorkerGridShardAddrGen { bool is_clockwise; }; - - template struct ShardAddrGen final { - ShardAddrGen()=default; + ShardAddrGen() = default; FORCE_INLINE static void build_with_placement_new(ShardAddrGen* placement_new_address, const uint32_t arg_index) { tt::tt_metal::ccl::ShardAddrGenArgs input_args; @@ -194,17 +190,13 @@ struct ShardAddrGen final { ASSERT(curr_arg_index - arg_index == input_args.get_expected_num_args()); - new (placement_new_address) ShardAddrGen( - curr_arg_index - arg_index, - input_args); + new (placement_new_address) ShardAddrGen(curr_arg_index - arg_index, input_args); } // This addr gen will dump all tiles from an input shard contiguously, and dump the // next input shard contiguously after it. 
This approach depends on a follow up // - ShardAddrGen( - uint8_t num_args_consumed, - tt::tt_metal::ccl::ShardAddrGenArgs const& input_args) : + ShardAddrGen(uint8_t num_args_consumed, tt::tt_metal::ccl::ShardAddrGenArgs const& input_args) : dest_cores(input_args.dest_cores), shards_start_address(input_args.shards_start_address), shard_size_in_bytes(input_args.shard_size_in_bytes), @@ -218,14 +210,14 @@ struct ShardAddrGen final { num_dest_cores(input_args.num_dest_cores), num_args_consumed(num_args_consumed), - is_clockwise(input_args.is_clockwise) - { - ASSERT(this->contiguous_chunks_before_stride >= 1); - ASSERT(this->intra_core_stride_in_shards >= 1); - ASSERT(input_args.starting_chunk_into_shard <= this->total_chunks_per_core); - }; + is_clockwise(input_args.is_clockwise) { + ASSERT(this->contiguous_chunks_before_stride >= 1); + ASSERT(this->intra_core_stride_in_shards >= 1); + ASSERT(input_args.starting_chunk_into_shard <= this->total_chunks_per_core); + }; - static_assert(TYPE == ShardType::Width || TYPE == ShardType::Height || TYPE == ShardType::Block, "Invalid ShardType"); + static_assert( + TYPE == ShardType::Width || TYPE == ShardType::Height || TYPE == ShardType::Block, "Invalid ShardType"); // Clockwise vs counter clockwise only affects worker core traversal order (relative to canonical order). Since the // dest core list is a configurable list, we will, for now, require the host side kernel config code to produce the @@ -240,8 +232,7 @@ struct ShardAddrGen final { this->num_dest_cores, this->intra_core_stride_in_shards, this->contiguous_chunks_before_stride, - this->is_clockwise - ); + this->is_clockwise); } else { // Unsupported ASSERT(false); @@ -263,7 +254,8 @@ struct ShardAddrGen final { [[nodiscard]] FORCE_INLINE uint64_t get_next_noc_addr_and_advance() { if constexpr (TYPE == ShardType::Width) { WorkerXY dest_worker = this->get_next_noc_xy_core(); - uint32_t curr_address = this->shards_start_address + this->curr_core_chunk_index * this->shard_size_in_bytes; + uint32_t curr_address = + this->shards_start_address + this->curr_core_chunk_index * this->shard_size_in_bytes; ASSERT(this->shards_start_address <= curr_address); this->advance(); return get_noc_addr(dest_worker.x, dest_worker.y, curr_address); @@ -277,10 +269,8 @@ struct ShardAddrGen final { [[nodiscard]] FORCE_INLINE uint32_t get_shard_size_in_bytes() const { return this->shard_size_in_bytes; } [[nodiscard]] FORCE_INLINE uint32_t get_num_dest_cores() const { return this->num_dest_cores; } - [[nodiscard]] FORCE_INLINE uint32_t get_total_chunks_per_core() const { - return this->total_chunks_per_core; - } - [[nodiscard]] FORCE_INLINE uint32_t get_num_args_consumed() const { return this->num_args_consumed;} + [[nodiscard]] FORCE_INLINE uint32_t get_total_chunks_per_core() const { return this->total_chunks_per_core; } + [[nodiscard]] FORCE_INLINE uint32_t get_num_args_consumed() const { return this->num_args_consumed; } WorkerXY* dest_cores; uint32_t shards_start_address; @@ -298,7 +288,11 @@ struct ShardAddrGen final { template FORCE_INLINE void write_and_send_chunk_sharded( - const uint32_t& cb_id, ShardAddrGen& addr_gen, uint32_t const num_pages, uint64_t remote_eth_l1_write_addr, uint64_t eth_l1_sender_semaphore_addr) { + const uint32_t& cb_id, + ShardAddrGen& addr_gen, + uint32_t const num_pages, + uint64_t remote_eth_l1_write_addr, + uint64_t eth_l1_sender_semaphore_addr) { cb_wait_front(cb_id, num_pages); uint32_t l1_read_addr = get_read_ptr(cb_id); uint32_t num_pages_remaining = num_pages; @@ -306,7 +300,8 
@@ FORCE_INLINE void write_and_send_chunk_sharded( noc_semaphore_inc(eth_l1_sender_semaphore_addr, 1); while (num_pages_remaining > 0) { uint64_t dest_worker_noc_addr = addr_gen.get_next_noc_addr(); - uint32_t num_shards_to_write = std::min(num_pages_remaining, addr_gen.contiguous_chunks_before_stride); + uint32_t num_shards_to_write = + std::min(num_pages_remaining, addr_gen.contiguous_chunks_before_stride); noc_async_write(l1_read_addr, dest_worker_noc_addr, num_shards_to_write * addr_gen.get_shard_size_in_bytes()); for (uint32_t i = 0; i < num_shards_to_write; i++) { addr_gen.advance(); @@ -318,7 +313,20 @@ FORCE_INLINE void write_and_send_chunk_sharded( cb_pop_front(cb_id, num_pages); } template -FORCE_INLINE void write_and_send_chunk(uint32_t& output_page_idx, uint32_t& col_idx, uint32_t& row_idx, const uint32_t& cb_id, const AddrGen& d, const uint32_t num_cols, const uint32_t num_rows, const uint32_t& col_offset, const uint32_t& row_offset, const uint32_t& num_pages, const uint32_t& page_size, uint64_t remote_l1_write_addr, uint64_t eth_l1_sender_semaphore_addr) { +FORCE_INLINE void write_and_send_chunk( + uint32_t& output_page_idx, + uint32_t& col_idx, + uint32_t& row_idx, + const uint32_t& cb_id, + const AddrGen& d, + const uint32_t num_cols, + const uint32_t num_rows, + const uint32_t& col_offset, + const uint32_t& row_offset, + const uint32_t& num_pages, + const uint32_t& page_size, + uint64_t remote_l1_write_addr, + uint64_t eth_l1_sender_semaphore_addr) { cb_wait_front(cb_id, num_pages); uint32_t l1_read_addr = get_read_ptr(cb_id); noc_async_write(l1_read_addr, remote_l1_write_addr, page_size * num_pages); @@ -374,11 +382,22 @@ FORCE_INLINE void write_chunk_sharded(const uint32_t& cb_id, ShardAddrGen& ad cb_pop_front(cb_id, num_pages); } template -FORCE_INLINE void write_chunk(uint32_t& output_page_idx, uint32_t& col_idx, uint32_t& row_idx, const uint32_t& cb_id, const AddrGen& d, const uint32_t& num_cols, const uint32_t& num_rows, const uint32_t& col_offset, const uint32_t& row_offset, const uint32_t& num_pages, const uint32_t& page_size) { +FORCE_INLINE void write_chunk( + uint32_t& output_page_idx, + uint32_t& col_idx, + uint32_t& row_idx, + const uint32_t& cb_id, + const AddrGen& d, + const uint32_t& num_cols, + const uint32_t& num_rows, + const uint32_t& col_offset, + const uint32_t& row_offset, + const uint32_t& num_pages, + const uint32_t& page_size) { cb_wait_front(cb_id, num_pages); uint32_t l1_read_addr = get_read_ptr(cb_id); for (uint32_t i = 0; i < num_pages; ++i) { - #ifdef RM_INTERLEAVED +#ifdef RM_INTERLEAVED uint64_t dst_noc_addr = get_noc_addr(output_page_idx, d); noc_async_write(l1_read_addr, dst_noc_addr, page_size); output_page_idx++; @@ -387,7 +406,7 @@ FORCE_INLINE void write_chunk(uint32_t& output_page_idx, uint32_t& col_idx, uint row_idx = 0; output_page_idx += row_offset; } - #elif defined TILE_INTERLEAVED +#elif defined TILE_INTERLEAVED noc_async_write_tile(output_page_idx, d, l1_read_addr); output_page_idx++; col_idx++; @@ -400,7 +419,7 @@ FORCE_INLINE void write_chunk(uint32_t& output_page_idx, uint32_t& col_idx, uint output_page_idx += row_offset; } } - #endif +#endif l1_read_addr += page_size; } noc_async_write_barrier(); @@ -422,17 +441,22 @@ FORCE_INLINE void read_shard_from_input_tensor_sharded( } // read chunk from input tensor (local chip) template -FORCE_INLINE void read_chunk_from_input_tensor(uint32_t& input_page_idx, const uint32_t& cb_id, const AddrGen& s, const uint32_t& num_pages, const uint32_t& page_size) { +FORCE_INLINE void 
read_chunk_from_input_tensor( + uint32_t& input_page_idx, + const uint32_t& cb_id, + const AddrGen& s, + const uint32_t& num_pages, + const uint32_t& page_size) { const uint32_t end_read_idx = input_page_idx + num_pages; cb_reserve_back(cb_id, num_pages); uint32_t local_l1_read_addr = get_write_ptr(cb_id); for (; input_page_idx < end_read_idx; ++input_page_idx) { - #ifdef RM_INTERLEAVED +#ifdef RM_INTERLEAVED uint64_t src_noc_addr = get_noc_addr(input_page_idx, s); noc_async_read(src_noc_addr, local_l1_read_addr, page_size); - #elif defined TILE_INTERLEAVED +#elif defined TILE_INTERLEAVED noc_async_read_tile(input_page_idx, s, local_l1_read_addr); - #endif +#endif local_l1_read_addr += page_size; } noc_async_read_barrier(); @@ -461,11 +485,22 @@ FORCE_INLINE void read_chunk_from_output_tensor_sharded( } // read chunk from output tensor (local chip) template -FORCE_INLINE void read_chunk_from_output_tensor(uint32_t& input_page_idx, uint32_t& col_idx, uint32_t& row_idx, const uint32_t& cb_id, const AddrGen& s, const uint32_t& num_cols, const uint32_t& num_rows, const uint32_t& col_offset, const uint32_t& row_offset, const uint32_t& num_pages, const uint32_t& page_size) { +FORCE_INLINE void read_chunk_from_output_tensor( + uint32_t& input_page_idx, + uint32_t& col_idx, + uint32_t& row_idx, + const uint32_t& cb_id, + const AddrGen& s, + const uint32_t& num_cols, + const uint32_t& num_rows, + const uint32_t& col_offset, + const uint32_t& row_offset, + const uint32_t& num_pages, + const uint32_t& page_size) { cb_reserve_back(cb_id, num_pages); uint32_t local_l1_read_addr = get_write_ptr(cb_id); for (uint32_t i = 0; i < num_pages; ++i) { - #ifdef RM_INTERLEAVED +#ifdef RM_INTERLEAVED uint64_t src_noc_addr = get_noc_addr(input_page_idx, s); noc_async_read(src_noc_addr, local_l1_read_addr, page_size); input_page_idx++; @@ -474,7 +509,7 @@ FORCE_INLINE void read_chunk_from_output_tensor(uint32_t& input_page_idx, uint32 row_idx = 0; input_page_idx += row_offset; } - #elif defined TILE_INTERLEAVED +#elif defined TILE_INTERLEAVED noc_async_read_tile(input_page_idx, s, local_l1_read_addr); input_page_idx++; col_idx++; @@ -487,9 +522,103 @@ FORCE_INLINE void read_chunk_from_output_tensor(uint32_t& input_page_idx, uint32 input_page_idx += row_offset; } } - #endif +#endif + local_l1_read_addr += page_size; + } + noc_async_read_barrier(); + cb_push_back(cb_id, num_pages); +} + +template +FORCE_INLINE void read_chunk_from_output_tensor_v2( + uint32_t& curr_page_idx, + tt::tt_metal::ccl::coord_t& offset_into_worker_slice, + const tt::tt_metal::ccl::coord_t& worker_slice_shape, + + // In tiles for tile layout + const tt::tt_metal::ccl::coord_t& tensor_shape, + const uint32_t cb_id, + const AddrGen& s, + const uint32_t num_pages, + const uint32_t page_size, + bool& last_page_of_worker) { + // we expected caller to reset this and the last curr_page_idx when we set it true + ASSERT(last_page_of_worker == false); + cb_reserve_back(cb_id, num_pages); + uint32_t local_l1_read_addr = get_write_ptr(cb_id); + for (uint32_t i = 0; i < num_pages; ++i) { +#ifdef RM_INTERLEAVED + uint64_t src_noc_addr = get_noc_addr(curr_page_idx, s); + noc_async_read(src_noc_addr, local_l1_read_addr, page_size); + ASSERT(false); // unimplemented + +#elif defined TILE_INTERLEAVED + + noc_async_read_tile(curr_page_idx, s, local_l1_read_addr); + // common with `write_chunk_v2` + offset_into_worker_slice.x++; + bool end_of_worker_slice_row = offset_into_worker_slice.x == worker_slice_shape.x; + if (end_of_worker_slice_row) { + 
offset_into_worker_slice.x = 0; + offset_into_worker_slice.y++; + bool end_of_worker_slice = offset_into_worker_slice.y == worker_slice_shape.y; + if (end_of_worker_slice) { + offset_into_worker_slice.y = 0; + last_page_of_worker = true; + } else { + curr_page_idx += tensor_shape.x - worker_slice_shape.x; + } + } else { + curr_page_idx++; + } +#endif local_l1_read_addr += page_size; } noc_async_read_barrier(); cb_push_back(cb_id, num_pages); } + +template +FORCE_INLINE void write_chunk_v2( + uint32_t& curr_page_idx, + tt::tt_metal::ccl::coord_t& offset_into_worker_slice, + const tt::tt_metal::ccl::coord_t& worker_slice_shape, + + // In tiles for tile layout + const tt::tt_metal::ccl::coord_t& tensor_shape, + uint32_t cb_id, + const AddrGen& d, + const uint32_t num_pages, + const uint32_t page_size, + bool& last_page_of_worker) { + cb_wait_front(cb_id, num_pages); + uint32_t l1_read_addr = get_read_ptr(cb_id); + for (uint32_t i = 0; i < num_pages; ++i) { +#ifdef RM_INTERLEAVED + uint64_t dst_noc_addr = get_noc_addr(curr_page_idx, d); + noc_async_write(l1_read_addr, dst_noc_addr, page_size); + ASSERT(false); // unimplemented +#elif defined TILE_INTERLEAVED + noc_async_write_tile(curr_page_idx, d, l1_read_addr); + // Common with `read_chunk_from_output_tensor_v2` + offset_into_worker_slice.x++; + bool end_of_worker_slice_row = offset_into_worker_slice.x == worker_slice_shape.x; + if (end_of_worker_slice_row) { + offset_into_worker_slice.x = 0; + offset_into_worker_slice.y++; + bool end_of_worker_slice = offset_into_worker_slice.y == worker_slice_shape.y; + if (end_of_worker_slice) { + offset_into_worker_slice.y = 0; + last_page_of_worker = true; + } else { + curr_page_idx += tensor_shape.x - worker_slice_shape.x; + } + } else { + curr_page_idx++; + } +#endif + l1_read_addr += page_size; + } + noc_async_write_barrier(); + cb_pop_front(cb_id, num_pages); +} diff --git a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp index 9690f1f52f4..2c7f486cd81 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp @@ -26,8 +26,9 @@ namespace tt { namespace tt_metal { +using namespace ccl; -std::tuple select_worker_cores(AllGatherConfig const& all_gather_config, uint32_t num_links, uint32_t link, uint32_t full_send_direction) { +static std::tuple select_worker_cores(AllGatherConfig const& all_gather_config, uint32_t num_links, uint32_t link, uint32_t full_send_direction) { constexpr uint32_t worker_grid_width = 8; const bool fit_sender_and_receiver_workers_on_same_row = (worker_grid_width / 2) >= all_gather_config.get_num_eth_buffers_per_edm(); std::set receiver_worker_cores = {}; @@ -55,13 +56,7 @@ std::tuple select_worker_cores(AllGatherConfig const& return {CoreRangeSet(receiver_worker_cores), CoreRangeSet(sender_worker_cores)}; } -std::unique_ptr AllGatherOpTensorConfig::build_all_gather_tensor_config(Tensor const& tensor) { - if (tensor.is_sharded()) { - return std::make_unique(tensor); - } else { - return std::make_unique(tensor); - } -} + // For ring all-gather, we can send sub-sections of input tensor in opposite directions @@ -71,8 +66,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& TT_FATAL(!(receiver_device_id == std::nullopt && sender_device_id == std::nullopt), "At least one of receiver_device_id or sender_device_id must be 
specified"); bool is_linear = topology == all_gather_op::Topology::Linear; - std::unique_ptr input_tensor_config = AllGatherOpTensorConfig::build_all_gather_tensor_config(input_tensor); - std::unique_ptr output_tensor_config = AllGatherOpTensorConfig::build_all_gather_tensor_config(output_tensor); + std::unique_ptr input_tensor_config = CclOpTensorConfig::build_all_gather_tensor_config(input_tensor); + std::unique_ptr output_tensor_config = CclOpTensorConfig::build_all_gather_tensor_config(output_tensor); tt_metal::Program program{}; const auto& device = input_tensor.device(); @@ -175,9 +170,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& } clockwise_edm_builders.emplace_back( - all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ccl::EriscDataMoverBufferSharingMode::NOT_SHARED); + all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); counter_clockwise_edm_builders.emplace_back( - all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ccl::EriscDataMoverBufferSharingMode::NOT_SHARED); + all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); } for (uint32_t direction = 0; direction < num_full_send_directions; direction++) { @@ -347,7 +342,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator( device, - dynamic_cast(input_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -538,7 +533,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator( device, - dynamic_cast(input_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -561,8 +556,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& OutputTensorShardAddrGenArgGenerator( all_gather_config, device, - dynamic_cast(input_tensor_config.get()), - dynamic_cast(output_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), + dynamic_cast(output_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -705,7 +700,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& global_worker_index); auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator( device, - dynamic_cast(input_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -720,8 +715,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& OutputTensorShardAddrGenArgGenerator( all_gather_config, device, - dynamic_cast(input_tensor_config.get()), - dynamic_cast(output_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), + 
dynamic_cast(output_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -904,7 +899,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& CoreCoord const& worker_eth_receiver_core = is_clockwise_direction ? eth_receiver_cores.at(i) : eth_sender_cores.at(i); auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator( device, - dynamic_cast(input_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -1059,8 +1054,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& OutputTensorShardAddrGenArgGenerator output_tensor_shard_arg_generator( all_gather_config, device, - dynamic_cast(input_tensor_config.get()), - dynamic_cast(output_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), + dynamic_cast(output_tensor_config.get()), ring_index, ring_size, global_num_workers, diff --git a/tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp b/tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp index 7aa4e304035..e2bdc79833c 100644 --- a/tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp +++ b/tt_eager/tt_dnn/op_library/ccl/ccl_common.cpp @@ -2,25 +2,32 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ccl_common.hpp" #include -#include "ccl_common.hpp" + #include "ccl_host_datastructures.hpp" namespace tt { namespace tt_metal { namespace ccl { +std::unique_ptr CclOpTensorConfig::build_all_gather_tensor_config(Tensor const& tensor) { + if (tensor.is_sharded()) { + return std::make_unique(tensor); + } else { + return std::make_unique(tensor); + } +} + void generate_edm_kernels_for_ring_or_linear_topology( - tt_metal::Program &program, + tt_metal::Program& program, Device const* device, RingTopology const& topology_config, std::vector const& clockwise_edm_builders, std::vector const& counter_clockwise_edm_builders, std::optional receiver_device_id, - std::optional sender_device_id - ) { - + std::optional sender_device_id) { auto sender_noc = detail::GetPreferredNOCForDRAMRead(tt::Cluster::instance().arch()); auto receiver_noc = detail::GetPreferredNOCForDRAMWrite(tt::Cluster::instance().arch()); uint32_t sender_socket_idx = 0; @@ -33,17 +40,20 @@ void generate_edm_kernels_for_ring_or_linear_topology( } } for (uint32_t i = 0; i < topology_config.num_links; ++i) { - bool is_clockwise_direction_edm_enabled = !topology_config.is_linear || topology_config.ring_index != topology_config.ring_size - 1; + bool is_clockwise_direction_edm_enabled = + !topology_config.is_linear || topology_config.ring_index != topology_config.ring_size - 1; if (is_clockwise_direction_edm_enabled) { auto eth_sender_core = topology_config.eth_sender_cores.at(i); log_trace(tt::LogOp, "EDM CLOCKWISE KERNEL RT ARGS: "); - auto eth_sender_kernel = ccl::generate_edm_kernel( - program, - device, - clockwise_edm_builders.at(i), - eth_sender_core, - sender_noc); - log_trace(tt::LogOp, "RingIndex: {}. Link {}. Clockwise EDM Core (x={},y={})", topology_config.ring_index, i, eth_sender_core.x, eth_sender_core.y); + auto eth_sender_kernel = + ccl::generate_edm_kernel(program, device, clockwise_edm_builders.at(i), eth_sender_core, sender_noc); + log_trace( + tt::LogOp, + "RingIndex: {}. Link {}. 
Clockwise EDM Core (x={},y={})", + topology_config.ring_index, + i, + eth_sender_core.x, + eth_sender_core.y); } bool is_counter_clockwise_direction_edm_enabled = !topology_config.is_linear || topology_config.ring_index != 0; @@ -51,19 +61,20 @@ void generate_edm_kernels_for_ring_or_linear_topology( log_trace(tt::LogOp, "EDM COUNTER CLOCKWISE KERNEL RT ARGS: "); auto eth_receiver_core = topology_config.eth_receiver_cores.at(i); auto eth_receiver_kernel = ccl::generate_edm_kernel( - program, - device, - counter_clockwise_edm_builders.at(i), - eth_receiver_core, - receiver_noc); - log_trace(tt::LogOp, "RingIndex: {}. Link {}. Counter-clockwise EDM Core (x={},y={})", topology_config.ring_index, i, eth_receiver_core.x, eth_receiver_core.y); + program, device, counter_clockwise_edm_builders.at(i), eth_receiver_core, receiver_noc); + log_trace( + tt::LogOp, + "RingIndex: {}. Link {}. Counter-clockwise EDM Core (x={},y={})", + topology_config.ring_index, + i, + eth_receiver_core.x, + eth_receiver_core.y); } } - } KernelHandle generate_edm_kernel( - tt_metal::Program &program, + tt_metal::Program& program, Device const* device, ccl::EriscDatamoverBuilder const& edm_builder, CoreCoord const& eth_core, @@ -71,27 +82,21 @@ KernelHandle generate_edm_kernel( log_trace(tt::LogOp, "EDM CLOCKWISE KERNEL RT ARGS: "); edm_builder.dump_to_log(); - // auto eth_sender_core = device->get_ethernet_sockets(receiver_device_id.value()).at(sender_socket_idx); - std::vector const& edm_clockwise_kernel_rt_args = edm_builder.emit_runtime_args(); // Ethernet Kernels std::vector eth_sender_ct_args = edm_builder.emit_compile_time_args(); + log_trace(tt::LogOp, "CT ARGS:"); + for (auto const& s : eth_sender_ct_args) { + log_trace(tt::LogOp, "\t{}", s); + } auto eth_sender_kernel = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp", eth_core, - tt_metal::EthernetConfig{.noc=noc_id, .compile_args=eth_sender_ct_args}); - - - tt_metal::SetRuntimeArgs( - program, - eth_sender_kernel, - eth_core, - edm_clockwise_kernel_rt_args); + tt_metal::EthernetConfig{.noc = noc_id, .compile_args = eth_sender_ct_args}); - // eth_sender_kernels.push_back(eth_sender_kernel); - // log_trace(tt::LogOp, "RingIndex: {}. Link {}. 
Clockwise EDM Core (x={},y={})", ring_index, i, eth_sender_core.x, eth_sender_core.y); + tt_metal::SetRuntimeArgs(program, eth_sender_kernel, eth_core, edm_clockwise_kernel_rt_args); std::stringstream ss; ss << "EDM ARGS:\n"; @@ -103,28 +108,38 @@ KernelHandle generate_edm_kernel( return eth_sender_kernel; } - -ccl::EriscDatamoverBuilder create_erisc_datamover_builder(std::size_t num_channels, uint32_t page_size, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode) { - +ccl::EriscDatamoverBuilder create_erisc_datamover_builder( + std::size_t num_channels, + uint32_t page_size, + ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + ccl::EriscDataMoverTerminationMode termination_mode) { + TT_ASSERT(num_channels > 0); std::vector edm_sem_addresses(num_channels, 0); std::vector edm_buffer_addresses(num_channels, 0); uint32_t edm_sem_addr = ccl::EriscDatamoverConfig::get_semaphores_base_address(num_channels); uint32_t edm_buffer_addr = ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); + TT_ASSERT(edm_sem_addr > 0); + TT_ASSERT(edm_buffer_addr > 0); const uint32_t buffer_size = ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, page_size); for (std::size_t c = 0; c < num_channels; ++c) { - edm_sem_addresses.push_back(edm_sem_addr); + edm_sem_addresses.at(c) = edm_sem_addr; edm_sem_addr += ccl::EriscDatamoverConfig::semaphore_size; - edm_buffer_addresses.push_back(edm_buffer_addr); + edm_buffer_addresses.at(c) = edm_buffer_addr; edm_buffer_addr += buffer_size; TT_ASSERT((c == 0) || (edm_buffer_addresses.back() != edm_buffer_addresses.front())); TT_ASSERT((c == 0) || (edm_sem_addresses.back() != edm_sem_addresses.front())); } return ccl::EriscDatamoverBuilder( - buffer_size, ccl::EriscDatamoverConfig::get_edm_handshake_address(), edm_sem_addresses, edm_buffer_addresses, buffer_sharing_mode); + buffer_size, + ccl::EriscDatamoverConfig::get_edm_handshake_address(), + edm_sem_addresses, + edm_buffer_addresses, + buffer_sharing_mode, + termination_mode); } -} // namespace ccl -} // namespace tt_metal -} // namespace tt +} // namespace ccl +} // namespace tt_metal +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/ccl/ccl_common.hpp b/tt_eager/tt_dnn/op_library/ccl/ccl_common.hpp index eb755a5eee9..df99433a9e7 100644 --- a/tt_eager/tt_dnn/op_library/ccl/ccl_common.hpp +++ b/tt_eager/tt_dnn/op_library/ccl/ccl_common.hpp @@ -4,13 +4,14 @@ #pragma once -#include "tt_metal/host_api.hpp" -#include "tt_metal/impl/program/program.hpp" -#include "tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" -#include "tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include #include +#include "common/constants.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/program/program.hpp" namespace tt { namespace tt_metal { @@ -44,9 +45,8 @@ struct RingTopology { // Get the cores for the sender and receiver worker cores if (!is_linear || ring_index != ring_size - 1) { uint32_t receiver_device = receiver_device_id.value(); - auto const &sockets = device->get_ethernet_sockets(receiver_device); - auto eth_sender_core = - sockets.at(sender_socket_idx); + auto const& sockets = device->get_ethernet_sockets(receiver_device); + auto eth_sender_core = sockets.at(sender_socket_idx); eth_sender_cores.push_back(eth_sender_core); log_trace( tt::LogOp, "\teth_sender_core on link {}: 
(x={},y={})", l, eth_sender_core.x, eth_sender_core.y); @@ -54,8 +54,7 @@ struct RingTopology { if (!is_linear || ring_index != 0) { uint32_t sender_device = sender_device_id.value(); auto const& sockets = device->get_ethernet_sockets(sender_device); - auto eth_receiver_core = - sockets.at(receiver_socket_idx); + auto eth_receiver_core = sockets.at(receiver_socket_idx); eth_receiver_cores.push_back(eth_receiver_core); log_trace( tt::LogOp, @@ -84,6 +83,75 @@ struct RingTopology { bool is_linear; }; +class CclOpTensorConfig { + public: + static std::unique_ptr build_all_gather_tensor_config(Tensor const& tensor); + + CclOpTensorConfig(Tensor const& tensor) : + buffer_start_address(tensor.buffer()->address()), + df(tt_metal::datatype_to_dataformat_converter(tensor.get_dtype())) {} + + virtual uint32_t get_page_size() const = 0; + virtual uint32_t get_unit_size() const = 0; + + uint32_t get_buffer_start_address() const { return this->buffer_start_address; } + + virtual ~CclOpTensorConfig() {}; + + protected: + uint32_t buffer_start_address; + DataFormat df; +}; + +class CclOpInterleavedTensorConfig final : public virtual CclOpTensorConfig { + public: + CclOpInterleavedTensorConfig(Tensor const& input_tensor) : CclOpTensorConfig(input_tensor) { + if (input_tensor.get_layout() == Layout::TILE) { + this->page_size = tt_metal::detail::TileSize(this->df); + } else { + this->page_size = input_tensor.buffer()->page_size(); + } + } + virtual uint32_t get_page_size() const override { return this->page_size; } + virtual uint32_t get_unit_size() const override { return this->page_size; } + + private: + uint32_t page_size; +}; + +class CclOpShardedTensorConfig final : public virtual CclOpTensorConfig { + public: + CclOpShardedTensorConfig(Tensor const& tensor) : + CclOpTensorConfig(tensor), shard_spec(tensor.shard_spec().value()) { + if (tensor.get_layout() == Layout::TILE) { + this->page_size = tt_metal::detail::TileSize(this->df); + TT_ASSERT( + this->shard_spec.shape.at(0) * this->shard_spec.shape.at(1) % + (constants::TILE_HEIGHT * constants::TILE_WIDTH) == + 0); + this->unit_size = (this->shard_spec.shape.at(0) * this->shard_spec.shape.at(1) / + (constants::TILE_HEIGHT * constants::TILE_WIDTH)) * + this->page_size; + } else { + this->page_size = tensor.get_legacy_shape()[-1] * tensor.element_size(); + this->unit_size = (this->page_size * this->shard_spec.shape.at(0) * this->shard_spec.shape.at(1)) / + tensor.shard_spec()->num_cores(); + } + } + + virtual uint32_t get_page_size() const override { return this->page_size; } + virtual uint32_t get_unit_size() const override { return this->unit_size; } + + uint32_t get_shard_size_in_bytes() const { return this->get_unit_size(); } + + ShardSpec const& get_shard_spec() const { return this->shard_spec; } + + private: + uint32_t page_size; + uint32_t unit_size; + ShardSpec const shard_spec; +}; + struct CclTensorSlicer { CclTensorSlicer( Shape tensor_shape, @@ -91,28 +159,24 @@ struct CclTensorSlicer { // Shape page_shape, std::size_t num_pages, std::size_t elem_size, - std::size_t page_size_in_bytes - ) : + std::size_t page_size_in_bytes) : tensor_shape(tensor_shape), dim_slice_factors_per_rank(dim_slice_factors), // page_shape(page_shape), num_pages(num_pages), page_size_in_bytes(page_size_in_bytes), - elem_size(elem_size) - { - TT_ASSERT(tensor_shape.rank() == dim_slice_factors.rank(), - "Tensor shape and dim slice factors must have the same size"); - TT_ASSERT(std::all_of(dim_slice_factors.begin(), dim_slice_factors.end(), [](uint32_t factor) { return 
factor > 0; }), - "All factors must be greater than 0"); - // TT_ASSERT(page_shape.rank() == 2 || page_shape.rank() == tensor_shape.rank(), - // "Page shape must have rank 2 or the same rank as the tensor shape"); - - // // TODO(snijjar) - // rank_slice_shape + elem_size(elem_size) { + TT_ASSERT( + tensor_shape.rank() == dim_slice_factors.rank(), + "Tensor shape and dim slice factors must have the same size"); + TT_ASSERT( + std::all_of(dim_slice_factors.begin(), dim_slice_factors.end(), [](uint32_t factor) { return factor > 0; }), + "All factors must be greater than 0"); } std::size_t get_num_pages_per_slice() const { - std::size_t n = std::accumulate(dim_slice_factors_per_rank.begin(), dim_slice_factors_per_rank.end(), 1, std::multiplies()); + std::size_t n = std::accumulate( + dim_slice_factors_per_rank.begin(), dim_slice_factors_per_rank.end(), 1, std::multiplies()); for (uint32_t i = 0; i < (tensor_shape.rank() - dim_slice_factors_per_rank.rank()); ++i) { n *= tensor_shape[i]; } @@ -130,7 +194,6 @@ struct CclTensorSlicer { std::size_t const elem_size; }; - // To be replaced by the CclTensorSlicer class, which should be reusable between sharded and interleaved // specs and also provides a simpler interface to reason about struct LegacyCclTensorSlicer { @@ -206,36 +269,432 @@ struct LegacyCclTensorSlicer { bool is_sharded; }; -class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer { +// Uniform Tensor Worker Slice +struct InterleavedTensorWorkerSlice { + InterleavedTensorWorkerSlice( + tt_xy_pair const& tensor_shape, // Don't _really_ need this + tt_xy_pair const& tensor_slice_shape, + tt_xy_pair const& worker_slice_shape, + tt_xy_pair const& worker_slice_offset) : + tensor_shape(tensor_shape), + tensor_slice_shape(tensor_slice_shape), + worker_slice_shape(worker_slice_shape), + worker_slice_offset(worker_slice_offset) {} + + // Could probably be solved in some closed form + std::size_t compute_num_worker_slice_iterations(std::size_t num_workers) const { + auto slice_offset = coord_t(worker_slice_offset.x, worker_slice_offset.y); + auto const& slice_shape = coord_t(worker_slice_shape.x, worker_slice_shape.y); + auto const& outer_slice_shape = coord_t(tensor_slice_shape.x, tensor_slice_shape.y); + uint32_t num_iterations = 0; + while (slice_offset.y < tensor_slice_shape.y && slice_offset.x < tensor_slice_shape.x) { + slice_offset = + tt::tt_metal::ccl::advance_slice_row_major(slice_offset, slice_shape, outer_slice_shape, num_workers); + num_iterations++; + } + + return num_iterations; + } + + tt_xy_pair tensor_shape; + tt_xy_pair tensor_slice_shape; + tt_xy_pair worker_slice_shape; + tt_xy_pair worker_slice_offset; +}; + +class InterleavedRingReduceScatterTensorSlicer : public LegacyCclTensorSlicer { public: - InterleavedRingAllGatherTensorSlicer ( + InterleavedRingReduceScatterTensorSlicer( Tensor const& input_tensor, Tensor const& output_tensor, int slice_dim, - uint32_t slice_idx - ) : LegacyCclTensorSlicer() { + uint32_t ring_index, + uint32_t ring_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes) : + LegacyCclTensorSlicer() { + TT_ASSERT(max_slice_size_in_bytes > 0); + this->row_major = input_tensor.get_layout() == Layout::ROW_MAJOR; + this->slice_dim_is_width = input_tensor.get_legacy_shape().rank() - 1 == slice_dim; + this->is_sharded = input_tensor.is_sharded(); + + int32_t shard_size_in_bytes = + is_sharded ? 
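+    // Illustrative walkthrough of compute_num_worker_slice_iterations above (hypothetical
+    // numbers, and assuming advance_slice_row_major steps a worker's offset forward by
+    // num_workers worker-slice strides in row-major order over the tensor slice):
+    //   tensor_slice_shape = {x=8, y=4} tiles, worker_slice_shape = {x=2, y=1}, num_workers = 4
+    //   => the 4 workers collectively cover 4 * 2 = 8 tiles (one full row) per iteration,
+    //   => so each worker needs (8 * 4) / 8 = 4 iterations to walk its share of the slice.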
(input_tensor.buffer()->page_size() * input_tensor.buffer()->shard_spec().tensor2d_shape[0] * + input_tensor.buffer()->shard_spec().tensor2d_shape[1]) / + input_tensor.shard_spec()->num_cores() + : -1; + this->input_page_size = is_sharded ? shard_size_in_bytes : input_tensor.buffer()->page_size(); + ; + if (row_major) { + this->num_cols = input_tensor.get_legacy_shape()[-1]; + auto input_shape = input_tensor.get_legacy_shape(); + auto output_shape = output_tensor.get_legacy_shape(); + this->num_rows = + std::accumulate(input_shape.begin() + slice_dim, input_shape.end() - 1, 1, std::multiplies()); + this->row_offset = + std::accumulate( + output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) - + num_rows; + } else { + const uint32_t num_tiles_x = input_tensor.get_legacy_shape()[-1] / tt::constants::TILE_WIDTH; + const uint32_t num_tiles_y = input_tensor.get_legacy_shape()[-2] / tt::constants::TILE_HEIGHT; + TT_ASSERT(num_tiles_x >= ring_size); + this->tensor_slice_shape.x = slice_dim == 3 ? (num_tiles_x / ring_size) : num_tiles_x; + this->tensor_slice_shape.y = slice_dim != 3 ? num_tiles_y / ring_size : num_tiles_y; + } + + // Create the worker schedule + + // The `output_page_offset` will be the starting page offset for this slice index (corresponds to ) + // ring index). Each worker will operate out of that slice and then advance to the next slice for + // for the next ring index/timestep + uint32_t slice_size_in_bytes = std::numeric_limits::max(); + if (row_major) { + if (slice_dim_is_width) { + TT_FATAL(false, "Reduce scatter row-major interleaved does not yet support a width dim"); + this->output_addr_offset = input_page_size; + } else { + this->output_page_offset = num_rows; + } + this->worker_slice_shapes = create_worker_slice_shapes_for_row_major_layout( + this->tensor_slice_shape, total_num_workers, max_slice_size_in_bytes); + } else { + this->worker_slice_shapes = create_worker_slice_shapes_for_tile_layout( + this->tensor_slice_shape, total_num_workers, max_slice_size_in_bytes / input_page_size); + } + + if (row_major) { + this->flattened_tensor_shape = tt_xy_pair{ + input_tensor.get_legacy_shape()[3], + input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * + input_tensor.get_legacy_shape()[2]}; + } else { + this->flattened_tensor_shape = tt_xy_pair{ + input_tensor.get_legacy_shape()[3] / constants::TILE_WIDTH, + (input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * + input_tensor.get_legacy_shape()[2]) / + constants::TILE_HEIGHT}; + } + this->worker_slice_offsets = compute_worker_slice_offsets(this->worker_slice_shapes, this->tensor_slice_shape); + TT_ASSERT(this->worker_slice_offsets.size() == this->worker_slice_shapes.size()); + } + + ccl::InterleavedTensorWorkerSlice get_worker_slice(std::size_t global_worker_index) { + return ccl::InterleavedTensorWorkerSlice( + this->flattened_tensor_shape, + this->tensor_slice_shape, + this->worker_slice_shapes.at(global_worker_index), + this->worker_slice_offsets.at(global_worker_index)); + } + + [[deprecated("deprecated code path for reduce scatter. 
Use nerw get_worker_slice API instead")]] + virtual void increment(uint32_t num_pages) override { + TT_FATAL(false, "deprecated code path for "); + } + + public: + static std::vector compute_worker_slice_offsets( + std::vector const& worker_slice_shapes, tt_xy_pair const& tensor_slice_shape) { + std::vector worker_slice_offsets; + worker_slice_offsets.reserve(worker_slice_shapes.size()); + + std::size_t offset_x = 0; + std::size_t offset_y = 0; + std::size_t last_worker_size_y = worker_slice_shapes.at(0).y; // for validation + bool first_in_row = true; + for (tt_xy_pair const& worker_slice_shape : worker_slice_shapes) { + worker_slice_offsets.emplace_back(offset_x, offset_y); + + TT_ASSERT(offset_y < tensor_slice_shape.y); + offset_x += worker_slice_shape.x; + if (offset_x < tensor_slice_shape.x) { + first_in_row = false; + } else { + offset_x = 0; + first_in_row = true; + offset_y += worker_slice_shape.y; + } + TT_ASSERT(first_in_row || last_worker_size_y == worker_slice_shape.y); + last_worker_size_y = worker_slice_shape.y; + } + + TT_ASSERT(worker_slice_offsets.size() == worker_slice_shapes.size()); + return worker_slice_offsets; + } + + static std::vector create_worker_slice_shapes_for_row_major_layout( + tt_xy_pair const& tensor_slice_shape_in_elems, uint32_t num_workers, uint32_t max_slice_size_in_elements) { + std::vector worker_slice_shapes; + worker_slice_shapes.reserve(num_workers); + if (num_workers > tensor_slice_shape_in_elems.y) { + log_warning( + tt::LogOp, + "Reduce Scatter more workers instantiated than is work to be done. Some workers will be idle and do " + "nothing"); + num_workers = tensor_slice_shape_in_elems.y; + for (uint32_t w = 0; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(tensor_slice_shape_in_elems.x, 1); + } + for (uint32_t w = num_workers; w < tensor_slice_shape_in_elems.x; ++w) { + worker_slice_shapes.emplace_back(0, 0); + } + return worker_slice_shapes; + } + + uint32_t num_elems_accounted_for = 0; + // For now we don't support row splitting but we will in the future + const uint32_t min_rows_per_worker = tensor_slice_shape_in_elems.y / num_workers; + const uint32_t num_workers_with_max_rows = tensor_slice_shape_in_elems.y % num_workers; + const uint32_t max_rows_per_worker = + num_workers_with_max_rows != 0 ? 
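+    // Worked example of compute_worker_slice_offsets above (values are illustrative only):
+    // with tensor_slice_shape = {x=4, y=2} and four uniform worker_slice_shapes of {x=2, y=1},
+    // the loop walks the slice row-major and produces offsets (0,0), (2,0), (0,1), (2,1):
+    // offset_x wraps back to 0 and offset_y advances by the worker-slice height each time a
+    // row of the tensor slice is filled.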
min_rows_per_worker + 1 : min_rows_per_worker; + for (uint32_t w = 0; w < num_workers_with_max_rows; w++) { + worker_slice_shapes.emplace_back(tensor_slice_shape_in_elems.x, max_rows_per_worker); + num_elems_accounted_for += tensor_slice_shape_in_elems.x * max_rows_per_worker; + } + for (uint32_t w = num_workers_with_max_rows; w < num_workers; w++) { + worker_slice_shapes.emplace_back(tensor_slice_shape_in_elems.x, min_rows_per_worker); + num_elems_accounted_for += tensor_slice_shape_in_elems.x * min_rows_per_worker; + } + + TT_ASSERT(num_elems_accounted_for == tensor_slice_shape_in_elems.x * tensor_slice_shape_in_elems.y); + for (auto& worker_slice_shape : worker_slice_shapes) { + TT_ASSERT(max_slice_size_in_elements >= worker_slice_shape.x * worker_slice_shape.y); + TT_ASSERT(worker_slice_shape.x * worker_slice_shape.y > 0); + } + return worker_slice_shapes; + } + + static std::vector create_worker_slice_shapes_for_tile_layout( + tt_xy_pair const& tensor_slice_shape_in_tiles, uint32_t num_workers, uint32_t max_slice_size_in_pages) { + TT_ASSERT(max_slice_size_in_pages > 0); + std::vector worker_slice_shapes; + worker_slice_shapes.reserve(num_workers); + const uint32_t total_num_tiles = tensor_slice_shape_in_tiles.x * tensor_slice_shape_in_tiles.y; + if (num_workers > total_num_tiles) { + log_warning( + tt::LogOp, + "Reduce Scatter more workers instantiated than is work to be done. Some workers will be idle and do " + "nothing"); + num_workers = total_num_tiles; + for (uint32_t w = 0; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(1, 1); + } + for (uint32_t w = num_workers; w < total_num_tiles; ++w) { + worker_slice_shapes.emplace_back(0, 0); + } + return worker_slice_shapes; + } + + std::size_t max_slice_size_in_tiles = max_slice_size_in_pages; + TT_ASSERT(max_slice_size_in_tiles > 0); + std::size_t max_width_in_tiles = std::min(max_slice_size_in_tiles, tensor_slice_shape_in_tiles.x); + std::size_t max_height_in_tiles = std::min(max_slice_size_in_tiles, tensor_slice_shape_in_tiles.y); + + uint32_t num_tiles_accounted_for = 0; // for validation + if (tensor_slice_shape_in_tiles.y >= num_workers) { + // slice into rows + const uint32_t min_rows_per_worker = tensor_slice_shape_in_tiles.y / num_workers; + const uint32_t num_workers_with_max_rows = tensor_slice_shape_in_tiles.y % num_workers; + const uint32_t max_rows_per_worker = + num_workers_with_max_rows != 0 ? min_rows_per_worker + 1 : min_rows_per_worker; + for (uint32_t w = 0; w < num_workers_with_max_rows; w++) { + worker_slice_shapes.emplace_back(tensor_slice_shape_in_tiles.x, max_rows_per_worker); + num_tiles_accounted_for += tensor_slice_shape_in_tiles.x * max_rows_per_worker; + } + for (uint32_t w = num_workers_with_max_rows; w < num_workers; w++) { + worker_slice_shapes.emplace_back(tensor_slice_shape_in_tiles.x, min_rows_per_worker); + num_tiles_accounted_for += tensor_slice_shape_in_tiles.x * min_rows_per_worker; + } + } else if (tensor_slice_shape_in_tiles.x >= num_workers) { + // slice into columns + const uint32_t min_cols_per_worker = tensor_slice_shape_in_tiles.x / num_workers; + const uint32_t num_workers_with_max_cols = tensor_slice_shape_in_tiles.x % num_workers; + const uint32_t max_cols_per_worker = + num_workers_with_max_cols != 0 ? 
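+            // Worked example for the column split (illustrative numbers): with
+            // tensor_slice_shape_in_tiles = {x=8, y=2} and num_workers = 3,
+            //   min_cols_per_worker = 8 / 3 = 2, num_workers_with_max_cols = 8 % 3 = 2,
+            //   max_cols_per_worker = 3,
+            // so the worker slice shapes come out as {3x2, 3x2, 2x2}, accounting for
+            // 6 + 6 + 4 = 16 tiles, i.e. the full 8 x 2 tensor slice.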
min_cols_per_worker + 1 : min_cols_per_worker; + for (uint32_t w = 0; w < num_workers_with_max_cols; w++) { + worker_slice_shapes.emplace_back(max_cols_per_worker, tensor_slice_shape_in_tiles.y); + num_tiles_accounted_for += max_cols_per_worker * tensor_slice_shape_in_tiles.y; + } + for (uint32_t w = num_workers_with_max_cols; w < num_workers; w++) { + worker_slice_shapes.emplace_back(min_cols_per_worker, tensor_slice_shape_in_tiles.y); + num_tiles_accounted_for += min_cols_per_worker * tensor_slice_shape_in_tiles.y; + } + + } else { + const uint32_t min_num_workers_per_row = num_workers / tensor_slice_shape_in_tiles.y; + const uint32_t num_rows_with_max_workers = tensor_slice_shape_in_tiles.y % num_workers; + const uint32_t max_num_workers_per_row = + num_rows_with_max_workers != 0 ? min_num_workers_per_row + 1 : min_num_workers_per_row; + + // 4 "quadrants" to the worker slicing: + // 1. Row with max num workers and max columns wide per worker (first part of rows with max num workers) + // 2. Row with max num workers and min columns wide per worker (second part of rows with max num workers) + // 3. Row with min num workers and max columns wide per worker (first part of rows with min num workers) + // 4. Row with min num workers and min columns wide per worker (second part of rows with min num workers) + // Depending on specific numbers, some of the above "quadrants" might be 0 sized + const uint32_t max_workers_row_min_cols_per_worker = + tensor_slice_shape_in_tiles.x / max_num_workers_per_row; + const uint32_t max_workers_row_max_col_worker_count = + tensor_slice_shape_in_tiles.x % max_num_workers_per_row; + const uint32_t max_workers_row_max_cols_per_worker = max_workers_row_max_col_worker_count != 0 + ? max_workers_row_min_cols_per_worker + 1 + : max_workers_row_min_cols_per_worker; + TT_ASSERT(max_workers_row_min_cols_per_worker > 0); + TT_ASSERT(max_workers_row_max_cols_per_worker >= max_workers_row_min_cols_per_worker); + for (uint32_t w_r = 0; w_r < num_rows_with_max_workers; w_r++) { + for (uint32_t w_c = 0; w_c < max_workers_row_max_cols_per_worker; w_c++) { + worker_slice_shapes.emplace_back(max_workers_row_max_cols_per_worker, 1); + num_tiles_accounted_for += max_workers_row_max_cols_per_worker; + } + for (uint32_t w_c = max_workers_row_max_col_worker_count; w_c < max_num_workers_per_row; w_c++) { + worker_slice_shapes.emplace_back(max_workers_row_min_cols_per_worker, 1); + num_tiles_accounted_for += max_workers_row_min_cols_per_worker; + } + } + + const uint32_t min_workers_row_min_cols_per_worker = + tensor_slice_shape_in_tiles.x / min_num_workers_per_row; + const uint32_t min_workers_row_max_col_worker_count = + tensor_slice_shape_in_tiles.x % min_num_workers_per_row; + const uint32_t min_workers_row_max_cols_per_worker = min_workers_row_max_col_worker_count != 0 + ? 
min_workers_row_min_cols_per_worker + 1 + : min_workers_row_min_cols_per_worker; + + for (uint32_t w_r = num_rows_with_max_workers; w_r < tensor_slice_shape_in_tiles.y; w_r++) { + for (uint32_t w_c = 0; w_c < min_workers_row_max_cols_per_worker; w_c++) { + worker_slice_shapes.emplace_back(min_workers_row_max_cols_per_worker, 1); + num_tiles_accounted_for += min_workers_row_max_cols_per_worker; + } + for (uint32_t w_c = min_workers_row_max_col_worker_count; w_c < min_num_workers_per_row; w_c++) { + worker_slice_shapes.emplace_back(min_workers_row_min_cols_per_worker, 1); + num_tiles_accounted_for += min_workers_row_max_cols_per_worker; + } + } + } + + // For now we do something a little naive - since this becomes an optimization problem otherwise, and the + // benefits to nailing it are marginal we expect uniform chunk sizes and just truncate the largest chunk to fit + // the max size and then apply that shape to all workers slice shapes + tt_xy_pair largest_worker_slice_shape = {0, 0}; + for (auto const& worker_slice_shape : worker_slice_shapes) { + if (largest_worker_slice_shape.x * largest_worker_slice_shape.y < + worker_slice_shape.x * worker_slice_shape.y) { + largest_worker_slice_shape = worker_slice_shape; + } + } + bool do_truncation = largest_worker_slice_shape.x * largest_worker_slice_shape.y > max_slice_size_in_tiles; + if (do_truncation) { + log_trace(tt::LogOp, "Truncating worker slice shapes to fit max slice size in tiles"); + } + log_trace( + tt::LogOp, + "largest_worker_slice_shape: x={}, y={}", + largest_worker_slice_shape.x, + largest_worker_slice_shape.y); + log_trace(tt::LogOp, "max_slice_size_in_tiles={}", max_slice_size_in_tiles); + while (largest_worker_slice_shape.x * largest_worker_slice_shape.y > max_slice_size_in_tiles) { + log_trace(tt::LogOp, "Loop Head"); + // truncate the largest dim first + uint32_t delta = (largest_worker_slice_shape.x * largest_worker_slice_shape.y) - max_slice_size_in_tiles; + log_trace(tt::LogOp, "-- delta: {}", delta); + uint32_t cols_removed_if_x_truncated = std::max(1, largest_worker_slice_shape.x / delta); + uint32_t tiles_removed_if_x_truncated = cols_removed_if_x_truncated * largest_worker_slice_shape.y; + uint32_t rows_removed_if_y_truncated = std::max(1, largest_worker_slice_shape.y / delta); + uint32_t tiles_removed_if_y_truncated = rows_removed_if_y_truncated * largest_worker_slice_shape.x; + uint32_t difference_x = tiles_removed_if_x_truncated > delta ? tiles_removed_if_x_truncated - delta + : delta - tiles_removed_if_x_truncated; + uint32_t difference_y = tiles_removed_if_y_truncated > delta ? 
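+            // Illustrative trace of one truncation step (hypothetical numbers): with
+            // largest_worker_slice_shape = {x=8, y=4} (32 tiles) and max_slice_size_in_tiles = 24,
+            //   delta = 8, cols_removed_if_x_truncated = max(1, 8/8) = 1 -> 4 tiles removed,
+            //   rows_removed_if_y_truncated = max(1, 4/8) = 1 -> 8 tiles removed,
+            //   difference_x = |4 - 8| = 4, difference_y = |8 - 8| = 0,
+            // so the y dimension is truncated, the shape becomes {8, 3} = 24 tiles, the loop
+            // exits, and every worker slice is clamped to that shape.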
tiles_removed_if_y_truncated - delta + : delta - tiles_removed_if_y_truncated; + log_trace(tt::LogOp, "-- cols_removed_if_x_truncated: {}", cols_removed_if_x_truncated); + log_trace(tt::LogOp, "-- tiles_removed_if_x_truncated: {}", tiles_removed_if_x_truncated); + log_trace(tt::LogOp, "-- rows_removed_if_y_truncated: {}", rows_removed_if_y_truncated); + log_trace(tt::LogOp, "-- tiles_removed_if_y_truncated: {}", tiles_removed_if_y_truncated); + log_trace(tt::LogOp, "-- difference_x: {}", difference_x); + log_trace(tt::LogOp, "-- difference_y: {}", difference_y); + if (difference_x < difference_y) { + largest_worker_slice_shape.x -= cols_removed_if_x_truncated; + } else { + largest_worker_slice_shape.y -= rows_removed_if_y_truncated; + } + log_trace( + tt::LogOp, + "-- new largest_worker_slice_shape: x={}, y={}", + largest_worker_slice_shape.x, + largest_worker_slice_shape.y); + } + if (do_truncation) { + log_trace( + tt::LogOp, + "Truncated worker slice shape to fit max slice size in tiles: ({},{})", + largest_worker_slice_shape.x, + largest_worker_slice_shape.y); + TT_ASSERT(largest_worker_slice_shape.x * largest_worker_slice_shape.y > 0); + for (auto& worker_slice_shape : worker_slice_shapes) { + worker_slice_shape = largest_worker_slice_shape; + } + } + + TT_ASSERT( + num_tiles_accounted_for == total_num_tiles, "All tiles must be accounted for in the worker slice shapes"); + TT_ASSERT(worker_slice_shapes.size() == num_workers, "Worker slice shapes must match the number of workers"); + return worker_slice_shapes; + } + + void create_worker_slice_shape_for_row_major_layout(tt_xy_pair const& tensor_slice_shape, uint32_t num_workers) { + TT_FATAL("Row major interleaved not supported by Reduce Scatter"); + } + + protected: + tt_xy_pair flattened_tensor_shape; + tt_xy_pair tensor_slice_shape; + std::vector worker_slice_shapes; + // For RowMajor - offset is in elements + // For Tile - offset is in tiles + std::vector worker_slice_offsets; +}; + +class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer { + public: + InterleavedRingAllGatherTensorSlicer( + Tensor const& input_tensor, Tensor const& output_tensor, int slice_dim, uint32_t slice_idx) : + LegacyCclTensorSlicer() { this->row_major = input_tensor.get_layout() == Layout::ROW_MAJOR; this->slice_dim_is_width = input_tensor.get_legacy_shape().rank() - 1 == slice_dim; this->is_sharded = input_tensor.is_sharded(); - int32_t shard_size_in_bytes = is_sharded ? - (input_tensor.buffer()->page_size() * input_tensor.buffer()->shard_spec().tensor2d_shape[0] * input_tensor.buffer()->shard_spec().tensor2d_shape[1]) / input_tensor.shard_spec()->num_cores() : - -1; - this->input_page_size = is_sharded ? shard_size_in_bytes : input_tensor.buffer()->page_size();; + int32_t shard_size_in_bytes = + is_sharded ? (input_tensor.buffer()->page_size() * input_tensor.buffer()->shard_spec().tensor2d_shape[0] * + input_tensor.buffer()->shard_spec().tensor2d_shape[1]) / + input_tensor.shard_spec()->num_cores() + : -1; + this->input_page_size = is_sharded ? 
shard_size_in_bytes : input_tensor.buffer()->page_size(); + ; if (row_major) { this->num_cols = input_tensor.get_legacy_shape()[-1]; auto input_shape = input_tensor.get_legacy_shape(); auto output_shape = output_tensor.get_legacy_shape(); - this->num_rows = std::accumulate(input_shape.begin() + slice_dim, input_shape.end() - 1, 1, std::multiplies()); - this->row_offset = std::accumulate(output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) - num_rows; + this->num_rows = + std::accumulate(input_shape.begin() + slice_dim, input_shape.end() - 1, 1, std::multiplies()); + this->row_offset = + std::accumulate( + output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) - + num_rows; } else { this->num_cols = input_tensor.get_legacy_shape()[-1] / tt::constants::TILE_WIDTH; auto input_shape = input_tensor.get_legacy_shape(); auto output_shape = output_tensor.get_legacy_shape(); uint32_t num_output_cols = output_tensor.get_legacy_shape()[-1] / tt::constants::TILE_WIDTH; - this->num_rows = std::accumulate(input_shape.begin() + slice_dim, input_shape.end() - 1, 1, std::multiplies()) / tt::constants::TILE_HEIGHT; - this->row_offset = (std::accumulate(output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) / tt::constants::TILE_HEIGHT - num_rows) * num_output_cols; + this->num_rows = + std::accumulate( + input_shape.begin() + slice_dim, input_shape.end() - 1, 1, std::multiplies()) / + tt::constants::TILE_HEIGHT; + this->row_offset = + (std::accumulate( + output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) / + tt::constants::TILE_HEIGHT - + num_rows) * + num_output_cols; this->col_offset = num_output_cols - num_cols; this->num_tiles = num_rows * num_cols; } @@ -253,46 +712,45 @@ class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer { this->output_page_offset = num_tiles; } } - this->output_start_page_idx = slice_idx/*ring_index*/ * output_page_offset; - this->output_start_addr_offset = slice_idx/*ring_index*/ * output_addr_offset; + this->output_start_page_idx = slice_idx /*ring_index*/ * output_page_offset; + this->output_start_addr_offset = slice_idx /*ring_index*/ * output_addr_offset; } virtual void increment(uint32_t num_pages) override { - // uint32_t pages_per_worker = num_full_chunks_per_worker.at(b) * pages_per_chunk + rem_pages_per_worker.at(b); if (is_sharded) { // nothing to do here - is handled by } else { // Only for interleaved - if (num_pages/*pages_per_worker*/ > 0) { + if (num_pages /*pages_per_worker*/ > 0) { if (row_major) { - uint32_t num_rows_shifted = row_idx + num_pages/*pages_per_worker*/; + uint32_t num_rows_shifted = row_idx + num_pages /*pages_per_worker*/; uint32_t num_blocks_shifted = slice_dim_is_width ? 0 : num_rows_shifted / num_rows; - this->output_start_page_idx += num_pages/*pages_per_worker*/ + num_blocks_shifted * row_offset; + this->output_start_page_idx += num_pages /*pages_per_worker*/ + num_blocks_shifted * row_offset; this->row_idx = slice_dim_is_width ? 0 : num_rows_shifted % num_rows; } else { - uint32_t num_cols_shifted = col_idx + num_pages/*pages_per_worker*/; + uint32_t num_cols_shifted = col_idx + num_pages /*pages_per_worker*/; uint32_t num_rows_shifted = num_cols_shifted / num_cols; uint32_t num_blocks_shifted = slice_dim_is_width ? 
0 : num_rows_shifted / num_rows; - this->output_start_page_idx += num_pages/*pages_per_worker*/ + num_rows_shifted * col_offset + num_blocks_shifted * row_offset; + this->output_start_page_idx += num_pages /*pages_per_worker*/ + num_rows_shifted * col_offset + + num_blocks_shifted * row_offset; this->col_idx = num_cols_shifted % num_cols; this->row_idx = slice_dim_is_width ? 0 : num_rows_shifted % num_rows; } } - this->input_start_page_idx += num_pages/*pages_per_worker*/; + this->input_start_page_idx += num_pages /*pages_per_worker*/; } } }; - KernelHandle generate_edm_kernel( - tt_metal::Program &program, + tt_metal::Program& program, Device const* device, ccl::EriscDatamoverBuilder const& edm_builder, CoreCoord const& eth_core, NOC noc_id); void generate_edm_kernels_for_ring_or_linear_topology( - tt_metal::Program &program, + tt_metal::Program& program, Device const* device, RingTopology const& topology_config, std::vector const& clockwise_edm_builders, @@ -303,8 +761,9 @@ void generate_edm_kernels_for_ring_or_linear_topology( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, - ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode); + ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + EriscDataMoverTerminationMode termination_mode); -} // namespace ccl -} // namespace tt_metal -} // namespace tt +} // namespace ccl +} // namespace tt_metal +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp b/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp index d0bd6858db2..193046b8c54 100644 --- a/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp +++ b/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp @@ -12,61 +12,78 @@ namespace tt { namespace tt_metal { namespace ccl { -enum Topology { - Ring = 0, - Linear = 1, - Meash = 2 -}; - +enum Topology { Ring = 0, Linear = 1, Meash = 2 }; struct EriscDatamoverConfig { - static constexpr std::size_t total_l1_buffer_space = eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + static constexpr std::size_t total_l1_buffer_space = + eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; static constexpr std::size_t usable_l1_base_address = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; - static constexpr std::size_t semaphore_size = 4; - static constexpr std::size_t handshake_location_size = 16; // ethernet word size + static constexpr std::size_t semaphore_size = 32; + static constexpr std::size_t handshake_location_size = 16; // ethernet word size + // The EDM uses this fixed address as a source for a first level ack sent from receiver -> sender + // side. We have this dedicated source address to avoid a race between first and second level ack + // where second level ack overwrites the first level ack in L1 before the first one is sent out. + // The memory contents in L1 will be {1, 1, x, x}. 
By having this dedicated source memory, we + // avoid the race + static constexpr std::size_t edm_receiver_first_level_ack_source_word_size = 16; // ethernet word size static constexpr std::size_t eth_word_size_bytes = 16; - static uint32_t get_edm_handshake_address() { - return usable_l1_base_address; - } + static uint32_t get_edm_handshake_address() { return usable_l1_base_address; } static uint32_t get_semaphores_base_address(std::size_t num_edm_channels) { - return usable_l1_base_address + handshake_location_size; + return usable_l1_base_address + handshake_location_size + edm_receiver_first_level_ack_source_word_size; } static uint32_t get_buffers_base_address(std::size_t num_edm_channels) { - uint32_t base_address = round_up(get_semaphores_base_address(num_edm_channels) + num_edm_channels * semaphore_size, eth_word_size_bytes); + uint32_t base_address = round_up( + get_semaphores_base_address(num_edm_channels) + num_edm_channels * semaphore_size, eth_word_size_bytes); TT_ASSERT(base_address % eth_word_size_bytes == 0); return base_address; } static uint32_t compute_buffer_size(std::size_t num_edm_channels, uint32_t page_size = eth_word_size_bytes) { page_size = std::max(page_size, eth_word_size_bytes); - uint32_t buffer_size = round_down((total_l1_buffer_space - get_buffers_base_address(num_edm_channels)) / (num_edm_channels), page_size); + TT_ASSERT(num_edm_channels > 0); + uint32_t buffer_size = round_down( + (total_l1_buffer_space - get_buffers_base_address(num_edm_channels)) / (num_edm_channels), page_size); + log_trace(tt::LogOp, "total_l1_buffer_space: {}", total_l1_buffer_space); + log_trace( + tt::LogOp, "get_buffers_base_address(num_edm_channels): {}", get_buffers_base_address(num_edm_channels)); + log_trace( + tt::LogOp, "usable buffer space: {}", total_l1_buffer_space - get_buffers_base_address(num_edm_channels)); + log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); + log_trace(tt::LogOp, "page_size: {}", page_size); + + log_trace(tt::LogOp, "Buffer size: {}", buffer_size); + TT_ASSERT(buffer_size > 0 && buffer_size % page_size == 0); return buffer_size; } }; - - struct CCLOpConfig { public: - CCLOpConfig(const std::vector& input_tensors, const std::vector& output_tensors, Topology topology) : + CCLOpConfig( + const std::vector& input_tensors, const std::vector& output_tensors, Topology topology) : input_tensors(&input_tensors), output_tensors(&output_tensors), input_sharded(input_tensors.at(0).is_sharded()), output_sharded(output_tensors.at(0).is_sharded()), page_size(input_tensors.at(0).buffer()->page_size()), input_shard_size_bytes( - input_tensors.at(0).is_sharded() ? - static_cast>((input_tensors.at(0).buffer()->page_size() * input_tensors.at(0).buffer()->shard_spec().tensor2d_shape[0] * input_tensors.at(0).buffer()->shard_spec().tensor2d_shape[1]) / input_tensors.at(0).shard_spec()->num_cores()) : - std::nullopt), + input_tensors.at(0).is_sharded() ? static_cast>( + (input_tensors.at(0).buffer()->page_size() * + input_tensors.at(0).buffer()->shard_spec().tensor2d_shape[0] * + input_tensors.at(0).buffer()->shard_spec().tensor2d_shape[1]) / + input_tensors.at(0).shard_spec()->num_cores()) + : std::nullopt), output_shard_size_bytes( - output_tensors.at(0).is_sharded() ? 
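+    // Worked EDM L1 layout example for EriscDatamoverConfig above (B = ERISC_L1_UNRESERVED_BASE,
+    // N = num_edm_channels = 8; the numbers follow directly from the constants defined above,
+    // assuming B is 16-byte aligned):
+    //   handshake word                : [B, B + 16)
+    //   receiver first-level ack word : [B + 16, B + 32)
+    //   channel i semaphore           : B + 32 + i * 32, for i in [0, 8)
+    //   buffers base                  : round_up(B + 32 + 8 * 32, 16) = B + 288
+    //   per-channel buffer size       : round_down((total_l1_buffer_space - buffers base) / 8, page_size)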
- static_cast>((output_tensors.at(0).buffer()->page_size() * output_tensors.at(0).buffer()->shard_spec().tensor2d_shape[0] * output_tensors.at(0).buffer()->shard_spec().tensor2d_shape[1]) / input_tensors.at(0).shard_spec()->num_cores()) : - std::nullopt), + output_tensors.at(0).is_sharded() ? static_cast>( + (output_tensors.at(0).buffer()->page_size() * + output_tensors.at(0).buffer()->shard_spec().tensor2d_shape[0] * + output_tensors.at(0).buffer()->shard_spec().tensor2d_shape[1]) / + input_tensors.at(0).shard_spec()->num_cores()) + : std::nullopt), shard_grid_size(output_tensors.at(0).is_sharded() ? input_tensors.at(0).shard_spec()->num_cores() : 0), - topology(topology) - { + topology(topology) { TT_ASSERT(!this->is_input_sharded() || input_shard_size_bytes.has_value()); TT_ASSERT(!this->is_output_sharded() || output_shard_size_bytes.has_value()); } @@ -79,33 +96,17 @@ struct CCLOpConfig { TT_ASSERT(output_shard_size_bytes.has_value()); return output_shard_size_bytes.value(); } - uint32_t get_page_size() const { - return this->page_size; - } - Topology get_topology() const { - return this->topology; - } - bool is_input_sharded() const { - return this->input_sharded; - } - bool is_output_sharded() const { - return this->output_sharded; - } - bool get_shard_grid_size() const { - return this->shard_grid_size; - } - Tensor const& get_input_tensor(std::size_t i) const { - return input_tensors->at(i); - } - Tensor const& get_output_tensor(std::size_t i) const { - return output_tensors->at(i); - } + uint32_t get_page_size() const { return this->page_size; } + Topology get_topology() const { return this->topology; } + bool is_input_sharded() const { return this->input_sharded; } + bool is_output_sharded() const { return this->output_sharded; } + bool get_shard_grid_size() const { return this->shard_grid_size; } + Tensor const& get_input_tensor(std::size_t i) const { return input_tensors->at(i); } + Tensor const& get_output_tensor(std::size_t i) const { return output_tensors->at(i); } private: - - - std::optional input_shard_size_bytes; // TODO: split off into CCL op input config () - std::optional output_shard_size_bytes; // TODO: split off into CCL op input config () + std::optional input_shard_size_bytes; + std::optional output_shard_size_bytes; uint32_t page_size; uint32_t shard_grid_size; Topology topology; @@ -117,65 +118,104 @@ struct CCLOpConfig { }; class EriscDatamoverBuilder { + private: + struct ChannelBufferSpec { + ChannelBufferSpec( + bool is_sender, + uint32_t worker_semaphore_address, + uint32_t num_eth_messages_to_forward, + uint32_t channel, + std::vector const& worker_coords) : + worker_coords(worker_coords), + worker_semaphore_address(worker_semaphore_address), + num_eth_messages_to_forward(num_eth_messages_to_forward), + channel(channel), + is_sender(is_sender) {} + + std::vector const worker_coords; + uint32_t worker_semaphore_address; + uint32_t num_eth_messages_to_forward; + uint32_t channel; + bool is_sender; + }; + + void push_back_channel_args(std::vector& args, ChannelBufferSpec const& channel) const { + args.push_back(this->local_buffer_addresses.at(channel.channel)); + args.push_back(channel.num_eth_messages_to_forward); + args.push_back(this->eth_buffer_size_bytes); + args.push_back(this->local_semaphore_addresses.at(channel.channel)); + args.push_back(channel.worker_semaphore_address); + args.push_back(channel.worker_coords.size()); + for (auto const& worker_coord : channel.worker_coords) { + args.push_back(worker_coord.to_uint32()); + } + } + + std::vector 
active_channels; + std::vector const local_semaphore_addresses; + std::vector const local_buffer_addresses; + uint32_t eth_buffer_size_bytes; + uint32_t handshake_addr; + uint32_t const num_channel_buffers; + ccl::EriscDataMoverBufferSharingMode const buffer_sharing_mode; + ccl::EriscDataMoverTerminationMode const termination_mode; + uint32_t num_senders; + uint32_t num_receivers; + + bool enable_sender; + bool enable_receiver; + public: struct ChannelBufferInterface { uint32_t eth_buffer_l1_address; uint32_t eth_semaphore_l1_address; }; - EriscDatamoverBuilder(uint32_t eth_buffer_size, uint32_t handshake_addr, std::vector const& local_semaphore_addresses, std::vector const& local_buffer_addresses, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode) : + EriscDatamoverBuilder( + uint32_t eth_buffer_size, + uint32_t handshake_addr, + std::vector const& local_semaphore_addresses, + std::vector const& local_buffer_addresses, + ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, + ccl::EriscDataMoverTerminationMode termination_mode = + ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) : local_semaphore_addresses(local_semaphore_addresses), local_buffer_addresses(local_buffer_addresses), eth_buffer_size_bytes(eth_buffer_size), handshake_addr(handshake_addr), num_channel_buffers(local_buffer_addresses.size()), buffer_sharing_mode(buffer_sharing_mode), + termination_mode(termination_mode), enable_sender(false), enable_receiver(false), num_senders(0), - num_receivers(0) - { - TT_ASSERT(local_buffer_addresses.size() == local_semaphore_addresses.size()); - active_channels.reserve(num_channel_buffers); - TT_ASSERT(eth_buffer_size_bytes < 163000); - log_trace(tt::LogOp, "EriscDatamoverBuilder:"); - for (auto const& addr : local_semaphore_addresses) { - log_trace(tt::LogOp, "\tsemaphore_address: {}", addr); - } - for (auto const& addr : local_buffer_addresses) { - log_trace(tt::LogOp, "\tbuffer_address: {}", addr); - } + num_receivers(0) { + TT_ASSERT(local_buffer_addresses.size() == local_semaphore_addresses.size()); + active_channels.reserve(num_channel_buffers); + TT_ASSERT(eth_buffer_size_bytes < 163000); + log_trace(tt::LogOp, "EriscDatamoverBuilder:"); + for (auto const& addr : local_semaphore_addresses) { + TT_ASSERT(addr > 0); + TT_ASSERT(addr % 16 == 0); + log_trace(tt::LogOp, "\tsemaphore_address: {}", addr); } - - // EriscDatamoverBuilder(AllGatherConfig const& all_gather_config, std::vector const& local_semaphore_addresses, std::vector const& local_buffer_addresses, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode) : - // local_semaphore_addresses(local_semaphore_addresses), - // local_buffer_addresses(local_buffer_addresses), - // eth_buffer_size_bytes(all_gather_config.get_eth_buffer_size()), - // handshake_addr(all_gather_config.get_erisc_handshake_address()), - // num_channel_buffers(all_gather_config.get_num_eth_buffers_per_edm()), - // buffer_sharing_mode(buffer_sharing_mode), - // enable_sender(false), - // enable_receiver(false), - // num_senders(0), - // num_receivers(0) - // { - // active_channels.reserve(num_channel_buffers); - // TT_ASSERT(eth_buffer_size_bytes < 163000); - // log_trace(tt::LogOp, "EriscDatamoverBuilder:"); - // for (auto const& addr : local_semaphore_addresses) { - // log_trace(tt::LogOp, "\tsemaphore_address: {}", addr); - // } - // for (auto const& addr : local_buffer_addresses) { - // log_trace(tt::LogOp, "\tbuffer_address: {}", addr); - // } - // } + for (auto const& addr : local_buffer_addresses) { + TT_ASSERT(addr > 0); + 
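+        // Minimal host-side usage sketch of this builder (illustrative only; program, device,
+        // eth_core, sender_noc, the worker semaphore addresses and message counts are
+        // placeholders, not values from this op):
+        //   auto edm_builder = ccl::create_erisc_datamover_builder(
+        //       /*num_channels=*/2, /*page_size=*/2048,
+        //       ccl::EriscDataMoverBufferSharingMode::NOT_SHARED,
+        //       ccl::EriscDataMoverTerminationMode::WORKER_INITIATED);
+        //   auto sender_iface = edm_builder.add_sender_channel(
+        //       sender_worker_sem_addr, num_messages, {ccl::WorkerXY{1, 1}});
+        //   auto receiver_iface = edm_builder.add_receiver_channel(
+        //       receiver_worker_sem_addr, num_messages, {ccl::WorkerXY{2, 1}});
+        //   ccl::generate_edm_kernel(program, device, edm_builder, eth_core, sender_noc);
+        // The returned ChannelBufferInterface gives each worker the channel's eth buffer and
+        // semaphore L1 addresses to target.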
TT_ASSERT(addr % 16 == 0); + log_trace(tt::LogOp, "\tbuffer_address: {}", addr); + } + } [[nodiscard]] - ChannelBufferInterface add_sender_channel(uint32_t worker_semaphore_address, uint32_t num_eth_messages_to_forward, std::vector const& worker_coords) { + ChannelBufferInterface add_sender_channel( + uint32_t worker_semaphore_address, + uint32_t num_eth_messages_to_forward, + std::vector const& worker_coords) { this->enable_sender = true; this->num_senders++; auto channel = active_channels.size(); - active_channels.emplace_back(true, worker_semaphore_address, num_eth_messages_to_forward, channel, worker_coords); + active_channels.emplace_back( + true, worker_semaphore_address, num_eth_messages_to_forward, channel, worker_coords); log_trace(tt::LogOp, "Adding sender channel:"); log_trace(tt::LogOp, "\tworker_semaphore_address: {}", active_channels.back().worker_semaphore_address); log_trace(tt::LogOp, "\tnum_eth_messages_to_forward: {}", active_channels.back().num_eth_messages_to_forward); @@ -187,11 +227,15 @@ class EriscDatamoverBuilder { return ChannelBufferInterface{local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; } [[nodiscard]] - ChannelBufferInterface add_receiver_channel(uint32_t worker_semaphore_address, uint32_t num_eth_messages_to_forward, std::vector const& worker_coords) { + ChannelBufferInterface add_receiver_channel( + uint32_t worker_semaphore_address, + uint32_t num_eth_messages_to_forward, + std::vector const& worker_coords) { this->enable_receiver = true; this->num_receivers++; auto channel = active_channels.size(); - active_channels.emplace_back(false, worker_semaphore_address, num_eth_messages_to_forward, channel, worker_coords); + active_channels.emplace_back( + false, worker_semaphore_address, num_eth_messages_to_forward, channel, worker_coords); log_trace(tt::LogOp, "Adding receiver channel:"); log_trace(tt::LogOp, "\tworker_semaphore_address: {}", active_channels.back().worker_semaphore_address); log_trace(tt::LogOp, "\tnum_eth_messages_to_forward: {}", active_channels.back().num_eth_messages_to_forward); @@ -207,7 +251,8 @@ class EriscDatamoverBuilder { static_cast(this->enable_receiver ? 
1 : 0), this->num_senders, this->num_receivers, - this->buffer_sharing_mode}; + this->buffer_sharing_mode, + this->termination_mode}; } [[nodiscard]] @@ -260,54 +305,9 @@ class EriscDatamoverBuilder { return this->eth_buffer_size_bytes; } - private: - struct ChannelBufferSpec { - ChannelBufferSpec( - bool is_sender, - uint32_t worker_semaphore_address, - uint32_t num_eth_messages_to_forward, - uint32_t channel, - std::vector const& worker_coords - ) : - worker_coords(worker_coords), - worker_semaphore_address(worker_semaphore_address), - num_eth_messages_to_forward(num_eth_messages_to_forward), - channel(channel), - is_sender(is_sender) {} - - std::vector const worker_coords; - uint32_t worker_semaphore_address; - uint32_t num_eth_messages_to_forward; - uint32_t channel; - bool is_sender; - }; - - void push_back_channel_args (std::vector &args, ChannelBufferSpec const& channel) const { - args.push_back(this->local_buffer_addresses.at(channel.channel)); - args.push_back(channel.num_eth_messages_to_forward); - args.push_back(this->eth_buffer_size_bytes); - args.push_back(this->local_semaphore_addresses.at(channel.channel)); - args.push_back(channel.worker_semaphore_address); - args.push_back(channel.worker_coords.size()); - for (auto const& worker_coord : channel.worker_coords) { - args.push_back(worker_coord.to_uint32()); - } - } - - std::vector active_channels; - std::vector const local_semaphore_addresses; - std::vector const local_buffer_addresses; - uint32_t eth_buffer_size_bytes; - uint32_t handshake_addr; - uint32_t const num_channel_buffers; - ccl::EriscDataMoverBufferSharingMode const buffer_sharing_mode; - uint32_t num_senders; - uint32_t num_receivers; - - bool enable_sender; - bool enable_receiver; + std::vector const& get_active_channels() const { return this->active_channels; } }; -}; // namespace ccl -}; // namespace tt_metal -}; // namespace tt +}; // namespace ccl +}; // namespace tt_metal +}; // namespace tt diff --git a/tt_eager/tt_dnn/op_library/ccl/edm/erisc_async_datamover.hpp b/tt_eager/tt_dnn/op_library/ccl/edm/erisc_async_datamover.hpp index 589324a5fad..be93a6531a0 100644 --- a/tt_eager/tt_dnn/op_library/ccl/edm/erisc_async_datamover.hpp +++ b/tt_eager/tt_dnn/op_library/ccl/edm/erisc_async_datamover.hpp @@ -14,14 +14,21 @@ #include "tt_metal/hw/inc/wormhole/noc/noc.h" using tt::tt_metal::ccl::EriscDataMoverBufferSharingMode; +using tt::tt_metal::ccl::EriscDataMoverTerminationMode; +using tt::tt_metal::ccl::EriscDataMoverWorkerSignal; namespace erisc { namespace datamover { -template -struct edm_worker_index { +template +struct EriscDatamoverConfig { + static constexpr EriscDataMoverBufferSharingMode BUFFER_SHARING_MODE = buffer_sharing_mode; + static constexpr EriscDataMoverTerminationMode TERMINATION_MODE = termination_mode; }; +template +struct edm_worker_index {}; + template <> struct edm_worker_index { uint16_t worker_index = 0; @@ -35,12 +42,16 @@ using tt::tt_metal::ccl::WorkerXY; * state for the transaction channel, holds information such as buffer and semaphore addresses, and has helper * functions to more easily check semaphore and ack statuses and to send/receive data and/or semaphore updates. 
*/ -template +// template +template class ChannelBuffer final { + static constexpr EriscDataMoverBufferSharingMode BUFFER_SHARING_MODE = EDM_CONFIG::BUFFER_SHARING_MODE; + static constexpr EriscDataMoverTerminationMode TERMINATION_MODE = EDM_CONFIG::TERMINATION_MODE; static_assert( BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::NOT_SHARED || - BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::ROUND_ROBIN, + BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::ROUND_ROBIN, "The only BufferSharding modes supported are NOT_SHARED and ROUND_ROBIN"); + public: enum STATE : uint8_t { DONE = 0, @@ -100,11 +111,12 @@ class ChannelBuffer final { channel_bytes_acked_address(&erisc_info->channels[eth_transaction_channel].receiver_ack), total_num_messages_to_move(total_num_messages_to_move), state(is_sender_side ? STATE::WAITING_FOR_WORKER : STATE::WAITING_FOR_ETH), - is_sender_completion_pending(false) { - + is_sender_completion_pending(false), + is_sender_side(is_sender_side) { clear_local_semaphore(); - if (total_num_messages_to_move != 0) { + if (total_num_messages_to_move != 0 || + TERMINATION_MODE != EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { if (is_sender_side) { // Tell the sender side workers that we're ready to accept data on this channel increment_worker_semaphores(); @@ -130,21 +142,24 @@ class ChannelBuffer final { noc_semaphore_inc(worker_semaphore_address, 1); } } else if (BUFFER_SHARING_MODE == EriscDataMoverBufferSharingMode::ROUND_ROBIN) { - WorkerXY worker_xy = this->worker_coords[this->worker_index.worker_index]; - uint64_t worker_semaphore_address = - get_noc_addr((uint32_t)worker_xy.x, (uint32_t)worker_xy.y, this->worker_semaphore_l1_address); - - noc_semaphore_inc(worker_semaphore_address, 1); - this->worker_index.worker_index++; - if (this->worker_index.worker_index >= this->num_workers) { - this->worker_index.worker_index = 0; - } + WorkerXY worker_xy = this->worker_coords[this->worker_index.worker_index]; + uint64_t worker_semaphore_address = + get_noc_addr((uint32_t)worker_xy.x, (uint32_t)worker_xy.y, this->worker_semaphore_l1_address); + + noc_semaphore_inc(worker_semaphore_address, 1); + this->worker_index.worker_index++; + if (this->worker_index.worker_index >= this->num_workers) { + this->worker_index.worker_index = 0; + } } else { - ASSERT(false); // Not implemented + ASSERT(false); // Not implemented } } [[nodiscard]] FORCE_INLINE bool is_local_semaphore_full() const { + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { + ASSERT(*(this->local_semaphore_address) <= this->num_workers); + } return *(this->local_semaphore_address) == this->num_workers; } @@ -155,6 +170,7 @@ class ChannelBuffer final { [[nodiscard]] STATE get_state() const { return this->state; } FORCE_INLINE void goto_state(STATE s) { this->state = s; } + [[nodiscard]] FORCE_INLINE bool is_waiting_for_workers_core() const { return this->state == STATE::WAITING_FOR_WORKER; } @@ -173,15 +189,13 @@ class ChannelBuffer final { ASSERT(this->eth_transaction_channel < eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS); return this->eth_transaction_channel; } - [[nodiscard]] FORCE_INLINE std::size_t get_remote_eth_buffer_address() const { - return this->address; - } + [[nodiscard]] FORCE_INLINE std::size_t get_remote_eth_buffer_address() const { return this->address; } [[nodiscard]] FORCE_INLINE std::size_t get_size_in_bytes() const { return this->size_in_bytes; } [[nodiscard]] FORCE_INLINE std::size_t 
get_current_payload_size() const { return this->get_size_in_bytes(); } [[nodiscard]] FORCE_INLINE std::size_t get_buffer_address() const { return this->address; } - FORCE_INLINE uint32_t get_messages_moved() { return this->num_messages_moved; } + FORCE_INLINE uint32_t get_messages_moved() { return this->num_messages_moved; } FORCE_INLINE void increment_messages_moved() { this->num_messages_moved++; } [[nodiscard]] FORCE_INLINE bool all_messages_moved() { @@ -191,9 +205,9 @@ class ChannelBuffer final { FORCE_INLINE void set_send_completion_pending(bool value) { this->is_sender_completion_pending = value; } [[nodiscard]] FORCE_INLINE bool is_send_completion_pending() const { return this->is_sender_completion_pending; } - FORCE_INLINE bool eth_is_receiver_channel_send_done() const { return *this->channel_bytes_sent_address == 0;} - FORCE_INLINE bool eth_bytes_are_available_on_channel() const { return *this->channel_bytes_sent_address != 0;} - FORCE_INLINE bool eth_is_receiver_channel_send_acked() const { return *this->channel_bytes_acked_address != 0;} + FORCE_INLINE bool eth_is_receiver_channel_send_done() const { return *this->channel_bytes_sent_address == 0; } + FORCE_INLINE bool eth_bytes_are_available_on_channel() const { return *this->channel_bytes_sent_address != 0; } + FORCE_INLINE bool eth_is_receiver_channel_send_acked() const { return *this->channel_bytes_acked_address != 0; } volatile tt_l1_ptr uint32_t *const get_channel_bytes_sent_address() { return this->channel_bytes_sent_address; } volatile tt_l1_ptr uint32_t *const get_channel_bytes_acked_address() { return this->channel_bytes_acked_address; } @@ -213,6 +227,7 @@ class ChannelBuffer final { STATE state; edm_worker_index worker_index; bool is_sender_completion_pending; + bool is_sender_side; }; template @@ -267,22 +282,21 @@ class QueueIndexPointer { }; FORCE_INLINE void eth_setup_handshake(std::uint32_t handshake_register_address, bool is_sender) { - - reinterpret_cast(handshake_register_address)[4] = 1; - reinterpret_cast(handshake_register_address)[5] = 1; - reinterpret_cast(handshake_register_address)[6] = 0x1c0ffee1; - reinterpret_cast(handshake_register_address)[7] = 0x1c0ffee2; - - reinterpret_cast(handshake_register_address)[8] = 0; - reinterpret_cast(handshake_register_address)[9] = 0; - reinterpret_cast(handshake_register_address)[10] = 0x4c0ffee1; - reinterpret_cast(handshake_register_address)[11] = 0x5c0ffee2; + reinterpret_cast(handshake_register_address)[4] = 1; + reinterpret_cast(handshake_register_address)[5] = 1; + reinterpret_cast(handshake_register_address)[6] = 0x1c0ffee1; + reinterpret_cast(handshake_register_address)[7] = 0x1c0ffee2; + + reinterpret_cast(handshake_register_address)[8] = 0; + reinterpret_cast(handshake_register_address)[9] = 0; + reinterpret_cast(handshake_register_address)[10] = 0x4c0ffee1; + reinterpret_cast(handshake_register_address)[11] = 0x5c0ffee2; erisc_info->channels[0].receiver_ack = 0; for (uint32_t i = 1; i < eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS; i++) { erisc_info->channels[i].bytes_sent = 0; erisc_info->channels[i].receiver_ack = 0; } - *(volatile tt_l1_ptr uint32_t*)handshake_register_address = 0; + *(volatile tt_l1_ptr uint32_t *)handshake_register_address = 0; if (is_sender) { eth_wait_receiver_done(); eth_send_bytes(handshake_register_address, handshake_register_address, 16); @@ -310,8 +324,8 @@ FORCE_INLINE void initialize_transaction_buffer_addresses( // SENDER SIDE HELPERS ///////////////////////////////////////////// -template -FORCE_INLINE bool 
sender_eth_send_data_sequence(ChannelBuffer &sender_buffer_channel) { +template +FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer &sender_buffer_channel) { bool did_something = false; if (sender_buffer_channel.eth_is_receiver_channel_send_done()) { bool need_to_send_completion = sender_buffer_channel.is_send_completion_pending(); @@ -331,9 +345,10 @@ FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer::WAITING_FOR_ETH); + sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); did_something = true; } } @@ -341,33 +356,42 @@ FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer +template FORCE_INLINE bool sender_notify_workers_if_buffer_available_sequence( - ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { + ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { + bool channel_done = false; + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { + channel_done = sender_buffer_channel.all_messages_moved(); + } else if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + // Nothing to do here because in this termination mode, we must check the signal in a different state + } else { + ASSERT(false); + } + sender_buffer_channel.clear_local_semaphore(); sender_buffer_channel.increment_worker_semaphores(); - if (!sender_buffer_channel.all_messages_moved()) { - sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); + if (!channel_done) { + sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); } else { - sender_buffer_channel.goto_state(ChannelBuffer::DONE); + sender_buffer_channel.goto_state(ChannelBuffer::DONE); num_senders_complete++; } return true; } -template -FORCE_INLINE bool sender_eth_check_receiver_ack_sequence(ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { +template +FORCE_INLINE bool sender_eth_check_receiver_ack_sequence( + ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { bool did_something = false; - bool transimission_acked_by_receiver = - sender_buffer_channel.eth_is_receiver_channel_send_acked() || - sender_buffer_channel.eth_is_receiver_channel_send_done(); + bool transimission_acked_by_receiver = sender_buffer_channel.eth_is_receiver_channel_send_acked() || + sender_buffer_channel.eth_is_receiver_channel_send_done(); if (transimission_acked_by_receiver) { eth_clear_sender_channel_ack(sender_buffer_channel.get_eth_transaction_channel()); sender_buffer_channel.increment_messages_moved(); - sender_buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); + sender_buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); sender_notify_workers_if_buffer_available_sequence(sender_buffer_channel, num_senders_complete); did_something = true; } @@ -378,15 +402,25 @@ FORCE_INLINE bool sender_eth_check_receiver_ack_sequence(ChannelBuffer -FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence(ChannelBuffer &sender_channel_buffer) { +template +FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( + ChannelBuffer &sender_channel_buffer, uint32_t &num_senders_complete) { bool did_something = false; + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + if (*sender_channel_buffer.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY) { + sender_channel_buffer.clear_local_semaphore(); + sender_channel_buffer.goto_state(ChannelBuffer::DONE); + 
num_senders_complete++; + return true; + } + } + bool read_finished = sender_channel_buffer.is_local_semaphore_full(); if (read_finished) { // We can clear the semaphore, and wait for space on receiver - sender_channel_buffer.clear_local_semaphore(); - sender_channel_buffer.goto_state(ChannelBuffer::READY_FOR_ETH_TRANSFER); + // sender_channel_buffer.clear_local_semaphore(); + sender_channel_buffer.goto_state(ChannelBuffer::READY_FOR_ETH_TRANSFER); did_something = true; erisc::datamover::sender_eth_send_data_sequence(sender_channel_buffer); @@ -399,32 +433,42 @@ FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence(ChannelBuffer -FORCE_INLINE bool receiver_eth_notify_workers_payload_available_sequence(ChannelBuffer &buffer_channel) { +template +FORCE_INLINE bool receiver_eth_notify_workers_payload_available_sequence(ChannelBuffer &buffer_channel) { + buffer_channel.clear_local_semaphore(); buffer_channel.increment_worker_semaphores(); - buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); + buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); return true; } - - /* * If payload received, notify (send ack to) sender so sender knows it can free up its local buffer * */ -template -FORCE_INLINE bool receiver_eth_accept_payload_sequence(ChannelBuffer &buffer_channel, uint32_t eth_transaction_ack_word_addr) { +template +FORCE_INLINE bool receiver_eth_accept_payload_sequence( + ChannelBuffer &buffer_channel, + uint32_t &num_receivers_complete, + uint32_t eth_transaction_ack_word_addr) { bool did_something = false; + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + if (*buffer_channel.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY) { + buffer_channel.clear_local_semaphore(); + buffer_channel.goto_state(ChannelBuffer::DONE); + num_receivers_complete++; + return true; + } + } + if (buffer_channel.eth_bytes_are_available_on_channel()) { if (!eth_txq_is_busy()) { eth_receiver_channel_ack(buffer_channel.get_eth_transaction_channel(), eth_transaction_ack_word_addr); - buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); + buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); did_something = true; // FIXME: Decouple these so we can still signal workers even if eth command queue is busy @@ -443,30 +487,46 @@ FORCE_INLINE bool receiver_eth_accept_payload_sequence(ChannelBuffer +template FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( - ChannelBuffer &buffer_channel, uint32_t &num_receivers_complete, uint32_t eth_transaction_complete_addr) { + ChannelBuffer &buffer_channel, + uint32_t &num_receivers_complete, + uint32_t eth_transaction_complete_addr) { bool did_something = false; bool workers_are_finished_reading = buffer_channel.is_local_semaphore_full(); + + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + // May have already gotten final termination signal by this point so check for that too + workers_are_finished_reading = + workers_are_finished_reading || + (*buffer_channel.local_semaphore_address == EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + } + bool can_notify_sender_of_buffer_available = workers_are_finished_reading; if (can_notify_sender_of_buffer_available) { - if (!eth_txq_is_busy()) { eth_receiver_channel_done(buffer_channel.get_eth_transaction_channel()); buffer_channel.increment_messages_moved(); - buffer_channel.clear_local_semaphore(); - if (!buffer_channel.all_messages_moved()) { - 
buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); + bool channel_done = false; + if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { + channel_done = buffer_channel.all_messages_moved(); + } else if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::WORKER_INITIATED) { + // Do nothing } else { - buffer_channel.goto_state(ChannelBuffer::DONE); + ASSERT(false); + } + + if (!channel_done) { + buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); + } else { + buffer_channel.goto_state(ChannelBuffer::DONE); num_receivers_complete++; } did_something = true; } - } return did_something; diff --git a/tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp b/tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp index 6df809496e7..89e8731aa2f 100644 --- a/tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp +++ b/tt_eager/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp @@ -116,9 +116,17 @@ void kernel_main() { constexpr uint32_t num_senders = get_compile_time_arg_val(2); constexpr uint32_t num_receivers = get_compile_time_arg_val(3); - constexpr tt::tt_metal::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = static_cast(get_compile_time_arg_val(4)); - std::array, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS> buffer_channels; + constexpr tt::tt_metal::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = + static_cast(get_compile_time_arg_val(4)); + + constexpr tt::tt_metal::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = + static_cast(get_compile_time_arg_val(5)); + + constexpr auto EDM_CONFIG = erisc::datamover::EriscDatamoverConfig(); + using EDM_CONFIG_T = decltype(EDM_CONFIG); + using ChannelBufferT = erisc::datamover::ChannelBuffer; + std::array buffer_channels; // std::array printed_receiver_done; @@ -143,7 +151,7 @@ void kernel_main() { const uint32_t sender_num_workers = get_arg_val(args_offset++); const uint32_t workers_xy_list_addr = get_arg_addr(args_offset); args_offset += sender_num_workers; - new (&buffer_channels[sender_channels_start + channel]) erisc::datamover::ChannelBuffer( + new (&buffer_channels[sender_channels_start + channel]) ChannelBufferT( sender_channels_start + channel, sender_buffer_address, sender_channel_size, @@ -173,7 +181,7 @@ void kernel_main() { uint32_t const receiver_num_workers = get_arg_val(args_offset++); const uint32_t workers_xy_list_addr = get_arg_addr(args_offset); args_offset += receiver_num_workers; - new (&buffer_channels[receiver_channels_start + channel]) erisc::datamover::ChannelBuffer( + new (&buffer_channels[receiver_channels_start + channel]) ChannelBufferT( receiver_channels_start + channel, receiver_buffers_base_address, receiver_channel_size, @@ -217,24 +225,25 @@ void kernel_main() { ////////////////////////////////////// // SENDER if constexpr (enable_sender_side) { - erisc::datamover::ChannelBuffer ¤t_sender = buffer_channels[send_recv_index.real_index.sender]; + ChannelBufferT ¤t_sender = buffer_channels[send_recv_index.real_index.sender]; switch (current_sender.get_state()) { - case erisc::datamover::ChannelBuffer::STATE::WAITING_FOR_WORKER: + case ChannelBufferT::STATE::WAITING_FOR_WORKER: did_something_sender = - erisc::datamover::sender_noc_receive_payload_ack_check_sequence(current_sender); + erisc::datamover::sender_noc_receive_payload_ack_check_sequence(current_sender, num_senders_complete); + senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; break; - case 
erisc::datamover::ChannelBuffer::STATE::READY_FOR_ETH_TRANSFER: + case ChannelBufferT::STATE::READY_FOR_ETH_TRANSFER: did_something_sender = erisc::datamover::sender_eth_send_data_sequence(current_sender); break; - case erisc::datamover::ChannelBuffer::STATE::SIGNALING_WORKER: + case ChannelBufferT::STATE::SIGNALING_WORKER: did_something_sender = erisc::datamover::sender_notify_workers_if_buffer_available_sequence( current_sender, num_senders_complete); senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; break; - case erisc::datamover::ChannelBuffer::STATE::WAITING_FOR_ETH: + case ChannelBufferT::STATE::WAITING_FOR_ETH: did_something_sender = erisc::datamover::sender_eth_check_receiver_ack_sequence(current_sender, num_senders_complete); senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; @@ -248,19 +257,20 @@ void kernel_main() { ////////////////////////////////////// // RECEIVER if constexpr (enable_receiver_side) { - erisc::datamover::ChannelBuffer ¤t_receiver = buffer_channels[send_recv_index.real_index.receiver]; + ChannelBufferT ¤t_receiver = buffer_channels[send_recv_index.real_index.receiver]; switch (current_receiver.get_state()) { - case erisc::datamover::ChannelBuffer::STATE::WAITING_FOR_ETH: - did_something_receiver = erisc::datamover::receiver_eth_accept_payload_sequence(current_receiver, eth_transaction_ack_word_addr); + case ChannelBufferT::STATE::WAITING_FOR_ETH: + did_something_receiver = erisc::datamover::receiver_eth_accept_payload_sequence(current_receiver, num_receivers_complete, eth_transaction_ack_word_addr); + receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; break; - case erisc::datamover::ChannelBuffer::STATE::SIGNALING_WORKER: + case ChannelBufferT::STATE::SIGNALING_WORKER: did_something_receiver = erisc::datamover::receiver_eth_notify_workers_payload_available_sequence(current_receiver); break; - case erisc::datamover::ChannelBuffer::STATE::WAITING_FOR_WORKER: + case ChannelBufferT::STATE::WAITING_FOR_WORKER: did_something_receiver = erisc::datamover::receiver_noc_read_worker_completion_check_sequence( current_receiver, num_receivers_complete, eth_transaction_complete_addr); receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; diff --git a/tt_eager/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp b/tt_eager/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp index 155794cbee5..dd5316ffcda 100644 --- a/tt_eager/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp +++ b/tt_eager/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp @@ -6,10 +6,25 @@ #include "dataflow_api.h" #include "debug/assert.h" +#include "debug/dprint.h" #include "tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" using tt::tt_metal::ccl::ShardType; using tt::tt_metal::ccl::WorkerXY; +// using tt::tt_metal::ccl::coord_t; + +namespace tt { +namespace tt_metal { +namespace ccl { +static FORCE_INLINE coord_t coord_from_args(uint32_t& arg_idx) { + uint32_t x = get_arg_val(arg_idx++); + uint32_t y = get_arg_val(arg_idx++); + return coord_t(x, y); +} + +} // namespace ccl +} // namespace tt_metal +} // namespace tt FORCE_INLINE void push_filler_pages_to_cb(const uint32_t& cb_id, uint32_t num_pages) { ASSERT(num_pages < cb_interface[cb_id].fifo_num_pages); @@ -22,7 +37,6 @@ FORCE_INLINE void pop_filler_pages_from_cb(const uint32_t& cb_id, uint32_t num_p cb_pop_front(cb_id, num_pages); } 
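+// The filler-page helpers above let a worker pad a partial transfer up to a full chunk so that the
+// reader, compute, and writer kernels agree on how many CB pages move per chunk. A hypothetical
+// usage for a transfer that carries fewer pages than a full chunk (names below are illustrative only,
+// not the actual kernel variables):
+//
+//   uint32_t num_filler_pages = full_chunk_num_pages - rem_num_pages;
+//   push_filler_pages_to_cb(cb_id, num_filler_pages);   // producer side: publish empty pages to keep counts aligned
+//   ...
+//   pop_filler_pages_from_cb(cb_id, num_filler_pages);  // consumer side: discard the padding pages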
-
 FORCE_INLINE void fetch_chunk(
     const uint32_t& cb_id, const uint32_t& num_pages, const uint32_t& page_size, uint64_t remote_l1_read_addr) {
     cb_reserve_back(cb_id, num_pages);
@@ -49,7 +63,11 @@ FORCE_INLINE void send_chunk(
     cb_pop_front(cb_id, num_pages);
 }
 FORCE_INLINE void send_chunk_sharded(
-    const uint32_t& cb_id, const uint32_t& num_pages, const uint32_t& page_size, uint64_t remote_l1_write_addr, uint64_t eth_l1_sender_semaphore_addr) {
+    const uint32_t& cb_id,
+    const uint32_t& num_pages,
+    const uint32_t& page_size,
+    uint64_t remote_l1_write_addr,
+    uint64_t eth_l1_sender_semaphore_addr) {
     cb_wait_front(cb_id, num_pages);
     uint32_t l1_read_addr = get_read_ptr(cb_id);
     noc_async_write(l1_read_addr, remote_l1_write_addr, page_size * num_pages);
diff --git a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp
new file mode 100644
index 00000000000..49ca82b47a4
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp
@@ -0,0 +1,982 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+///
+
+#include "common/core_coord.h"
+#include "eth_l1_address_map.h"
+#include "impl/buffers/buffer.hpp"
+#include "impl/kernels/data_types.hpp"
+#include "tensor/tensor_impl.hpp"
+#include "tt_dnn/op_library/ccl/ccl_common.hpp"
+#include "tt_dnn/op_library/ccl/ccl_host_datastructures.hpp"
+#include "tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp"
+#include "tt_metal/common/constants.hpp"
+#include "tt_metal/host_api.hpp"
+#include "tt_metal/impl/buffers/circular_buffer_types.hpp"
+
+// Includes that need to be moved to CCL datastructures header
+#include
+
+using namespace tt::constants;
+
+// Notes on abbreviations:
+// cw = clockwise
+// ccw = counter-clockwise
+// edm = erisc data mover
+
+// How this reduce_scatter op works:
+// Each chip owns an element range of the input tensor shape that will eventually be scattered
+// out to it. For any chunk outside that range, the chip forwards the chunk to the next chip in the ring.
+// While forwarding, the chip also reduces the received chunk with the corresponding chunk of its local
+// input tensor, so what it forwards is the partially reduced chunk.
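+//
+// A rough sketch of the per-chip dataflow implied by the description above (illustrative
+// pseudocode only; the names below are not the actual kernel variables or helpers):
+//
+//   for (uint32_t step = 0; step < ring_size - 1; ++step) {
+//       chunk = receive_chunk_from_upstream_neighbor();            // arrives via the EDM
+//       chunk = eltwise_reduce(chunk, matching_local_input_chunk);  // math kernel performs the reduction
+//       if (chunk_is_destined_for_this_chip) {
+//           write_chunk_to_output_tensor(chunk);                    // fully reduced scatter slice
+//       } else {
+//           forward_chunk_to_downstream_neighbor(chunk);            // partially reduced, via the EDM
+//       }
+//   }
+//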
+// Reduces along rank + +namespace tt { + +namespace tt_metal { + +namespace ccl { +namespace reduce_scatter_detail { +struct WorkerTransferInfo { + WorkerTransferInfo( + std::vector pages_per_full_chunk_per_worker, + std::vector num_messages_per_worker, + std::vector remaining_num_pages_per_worker, + uint32_t num_links, + uint32_t num_workers) : + pages_per_full_chunk_per_worker(pages_per_full_chunk_per_worker), + num_messages_per_worker(num_messages_per_worker), + remaining_num_pages_per_worker(remaining_num_pages_per_worker), + num_links(num_links), + num_workers(num_workers) {} + + uint32_t get_num_pages_per_full_chunk(uint32_t link, uint32_t worker_idx) const { + return pages_per_full_chunk_per_worker.at(link * num_workers + worker_idx); + } + uint32_t get_num_remaining_pages(uint32_t link, uint32_t worker_idx) const { + return remaining_num_pages_per_worker.at(link * num_workers + worker_idx); + } + uint32_t get_num_full_chunks_per_transfer(uint32_t link, uint32_t worker_idx) const { + return num_messages_per_worker.at(link * num_workers + worker_idx) - + (get_num_remaining_pages(link, worker_idx) > 0 ? 1 : 0); + } + uint32_t get_num_pages_per_ring_index(uint32_t link, uint32_t worker_idx) const { + return get_num_full_chunks_per_transfer(link, worker_idx) * get_num_pages_per_full_chunk(link, worker_idx) + + get_num_remaining_pages(link, worker_idx); + } + + std::vector pages_per_full_chunk_per_worker; + std::vector num_messages_per_worker; + std::vector remaining_num_pages_per_worker; + uint32_t num_links; + uint32_t num_workers; +}; + +static std::size_t decide_number_of_edm_channels( + ccl::CCLOpConfig const& ccl_op_config, std::size_t max_num_workers, bool enable_bidirectional) { + return ccl_op_config.is_input_sharded() ? std::min( + ccl_op_config.get_shard_grid_size(), + std::min(max_num_workers, enable_bidirectional ? 8 : 4)) + : std::min(max_num_workers, enable_bidirectional ? 
8 : 4); +} + +struct ReduceScatterWorkerArgBuilder { + ReduceScatterWorkerArgBuilder( + ccl::CCLOpConfig const& op_config, + ccl::RingTopology const& topology_config, + ccl::InterleavedTensorWorkerSlice const& worker_input_slice, + WorkerTransferInfo const& worker_transfer_info, + uint32_t worker_idx, + uint32_t cb_num_pages_per_packet, + uint32_t worker_receiver_semaphore_address, + uint32_t worker_sender_semaphore_address) : + op_config(op_config), + topology_config(topology_config), + worker_input_slice(worker_input_slice), + worker_transfer_info(worker_transfer_info), + cb_num_pages_per_packet(cb_num_pages_per_packet), + worker_receiver_semaphore_address(worker_receiver_semaphore_address), + worker_sender_semaphore_address(worker_sender_semaphore_address) {} + + std::vector generate_reduce_op_kernel_ct_args() const { + log_trace(tt::LogOp, "Reduce Scatter Worker CT Args: None"); + return {}; + } + + std::vector generate_reduce_op_kernel_rt_args( + uint32_t link, uint32_t worker_index, uint32_t ring_size) const { + uint32_t num_pages_per_ring_index_slice = + this->worker_transfer_info.get_num_pages_per_ring_index(link, worker_index); + if (this->worker_transfer_info.get_num_remaining_pages(link, worker_index) > 0) { + // Add the filler pages + uint32_t num_padded_pages = this->worker_transfer_info.get_num_pages_per_full_chunk(link, worker_index) - + this->worker_transfer_info.get_num_remaining_pages(link, worker_index); + num_pages_per_ring_index_slice += num_padded_pages; + } + + auto num_iterations = + this->worker_input_slice.compute_num_worker_slice_iterations(this->worker_transfer_info.num_workers); + auto const& args = std::vector{ + static_cast(num_pages_per_ring_index_slice * (ring_size - 1) * num_iterations), + // TODO: update to half-cb size + 1}; // this field is supposed to be # pages from looking at the kernel code// + // this->tensor_slicer.input_page_size}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Worker RT Args:"); + log_trace(tt::LogOp, "\tnum_pages: {}", args.at(i++)); + log_trace(tt::LogOp, "\tpage_size: {}", args.at(i++)); + + return args; + } + + std::vector generate_receiver_kernel_ct_args() const { + auto const& args = std::vector{ + static_cast(this->op_config.is_input_sharded() ? 1 : 0), + static_cast( + this->op_config.get_input_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 1 : 0)}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Receiver Worker CT Args:"); + log_trace(tt::LogOp, "\tis_sharded: {}", args.at(i++)); + log_trace(tt::LogOp, "\tsrc_is_dram: {}", args.at(i++)); + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; + } + + std::vector generate_receiver_kernel_rt_args( + ccl::WorkerXY edm_core, + uint32_t edm_core_semaphore_address, + uint32_t edm_core_buffer_address, + uint32_t link, + uint32_t worker_index, + bool is_in_clockwise_direction) const { + TT_ASSERT(edm_core_semaphore_address > 0); + TT_ASSERT(edm_core_buffer_address > 0); + auto const& local_input_tensor = this->op_config.get_input_tensor(0); + uint32_t starting_ring_index = + is_in_clockwise_direction ? (this->topology_config.ring_index == 0 ? this->topology_config.ring_size - 1 + : this->topology_config.ring_index - 1) + : (this->topology_config.ring_index == this->topology_config.ring_size - 1 + ? 
0 + : this->topology_config.ring_index + 1); + auto args = std::vector{ + static_cast(local_input_tensor.buffer()->address()), + static_cast(this->topology_config.ring_size), // num_transfers + static_cast(this->worker_transfer_info.get_num_pages_per_full_chunk(link, worker_index)), + static_cast(this->op_config.get_page_size()), + static_cast(starting_ring_index), + static_cast(this->topology_config.ring_size), + static_cast(this->worker_receiver_semaphore_address), + static_cast(is_in_clockwise_direction ? 1 : 0), + static_cast(this->cb_num_pages_per_packet), + static_cast(edm_core.x), + static_cast(edm_core.y), + static_cast(edm_core_semaphore_address), + static_cast(edm_core_buffer_address), + + static_cast(worker_transfer_info.num_workers), + + static_cast(this->worker_input_slice.tensor_shape.x), + static_cast(this->worker_input_slice.tensor_shape.y), + + static_cast(this->worker_input_slice.tensor_slice_shape.x), + static_cast(this->worker_input_slice.tensor_slice_shape.y), + + static_cast(this->worker_input_slice.worker_slice_shape.x), + static_cast(this->worker_input_slice.worker_slice_shape.y), + + static_cast(this->worker_input_slice.worker_slice_offset.x), + static_cast(this->worker_input_slice.worker_slice_offset.y), + + // How many messages does the eltwise kernel expect? Use this as a kludge for now until we can + // elegently compute exactly how many tiles the math kernel will need + generate_reduce_op_kernel_rt_args(link, worker_index, this->topology_config.ring_size).at(0)}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Receiver Worker RT Args:"); + log_trace(tt::LogOp, "\tsrc_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\tnum_transfers: {}", args.at(i++)); + log_trace(tt::LogOp, "\tfull_chunk_num_pages: {}", args.at(i++)); + log_trace(tt::LogOp, "\tpage_size: {}", args.at(i++)); + log_trace(tt::LogOp, "\tmy_ring_idx: {}", args.at(i++)); + log_trace(tt::LogOp, "\tring_size: {}", args.at(i++)); + log_trace(tt::LogOp, "\tsem_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\tis_clockwise_direction: {}", args.at(i++)); + log_trace(tt::LogOp, "\thalf_cb_n_pages: {}", args.at(i++)); + + log_trace(tt::LogOp, "\tedm_core_noc0_core_x: {}", args.at(i++)); + log_trace(tt::LogOp, "\tedm_core_noc0_core_y: {}", args.at(i++)); + log_trace(tt::LogOp, "\tedm_core_semaphore_address: {}", args.at(i++)); + log_trace(tt::LogOp, "\tedm_core_buffer_address: {}", args.at(i++)); + log_trace(tt::LogOp, "\tnum_concurrent_workers: {}", args.at(i++)); + + log_trace(tt::LogOp, "\tinput_tensor_shape.x={}", args.at(i++)); + log_trace(tt::LogOp, "\tinput_tensor_shape.y={}", args.at(i++)); + log_trace(tt::LogOp, "\ttensor_slice_shape.x={}", args.at(i++)); + log_trace(tt::LogOp, "\ttensor_slice_shape.y={}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_shape.x={}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_shape.y={}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_offset.x={}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_offset.y={}", args.at(i++)); + + log_trace(tt::LogOp, "\ttotal_eltwise_kernel_num_pages={}", args.at(i++)); + + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; + } + + std::vector generate_sender_kernel_ct_args() const { + auto const& args = std::vector{ + static_cast(this->op_config.is_input_sharded() ? 1 : 0), + static_cast( + this->op_config.get_output_tensor(0).memory_config().buffer_type == BufferType::DRAM ? 
1 : 0)}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Sender Worker CT Args:"); + log_trace(tt::LogOp, "\tis_sharded: {}", args.at(i++)); + log_trace(tt::LogOp, "\tdst_is_dram: {}", args.at(i++)); + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; + } + + std::vector generate_sender_kernel_rt_args( + ccl::WorkerXY edm_core, + uint32_t edm_core_semaphore_address, + uint32_t edm_core_buffer_address, + uint32_t link, + uint32_t worker_index, + bool is_clockwise) const { + TT_ASSERT(edm_core_semaphore_address > 0); + TT_ASSERT(edm_core_buffer_address > 0); + auto const& local_output_tensor = this->op_config.get_output_tensor(0); + auto const& args = std::vector{ + static_cast(local_output_tensor.buffer()->address()), + static_cast(edm_core_buffer_address), + static_cast(edm_core_semaphore_address), + static_cast(edm_core.x), + static_cast(edm_core.y), + static_cast(this->topology_config.ring_size - 1), // num_transfers), + + static_cast(this->op_config.get_page_size()), + static_cast(this->worker_transfer_info.get_num_pages_per_full_chunk(link, worker_index)), + + static_cast(this->worker_sender_semaphore_address), + static_cast(this->cb_num_pages_per_packet), + + static_cast(worker_transfer_info.num_workers), + + // For sender side, all worker slice info is the same except for the tensor shape + // and for sender side specifically, there is only one tensor_slice_shape for the output + // tensor (as opposed to `ring_size` tensor_slice_shapes for the input tensor), so we can + // directly use it as the output tensor shape + static_cast(this->worker_input_slice.tensor_slice_shape.x), + static_cast(this->worker_input_slice.tensor_slice_shape.y), + static_cast(this->worker_input_slice.worker_slice_shape.x), + static_cast(this->worker_input_slice.worker_slice_shape.y), + static_cast(this->worker_input_slice.worker_slice_offset.x), + static_cast(this->worker_input_slice.worker_slice_offset.y), + + // How many messages does the eltwise kernel expect? 
Use this as a kludge for now until we can + // elegently compute exactly how many tiles the math kernel will need + generate_reduce_op_kernel_rt_args(link, worker_index, this->topology_config.ring_size).at(0)}; + + std::size_t i = 0; + log_trace(tt::LogOp, "Reduce Scatter Sender Worker RT Args:"); + log_trace(tt::LogOp, "\tdst_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\teth_sender_l1_base_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\teth_sender_l1_sem_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\teth_sender_noc_x: {}", args.at(i++)); + log_trace(tt::LogOp, "\teth_sender_noc_y: {}", args.at(i++)); + log_trace(tt::LogOp, "\tnum_transfers: {}", args.at(i++)); + log_trace(tt::LogOp, "\tpage_size: {}", args.at(i++)); + log_trace(tt::LogOp, "\tfull_chunk_num_pages: {}", args.at(i++)); + log_trace(tt::LogOp, "\twriter_send_sem_addr: {}", args.at(i++)); + log_trace(tt::LogOp, "\thalf_cb_n_pages: {}", args.at(i++)); + log_trace(tt::LogOp, "\tnum_concurrent_workers: {}", args.at(i++)); + + log_trace(tt::LogOp, "\toutput_tensor_shape.x: {}", args.at(i++)); + log_trace(tt::LogOp, "\toutput_tensor_shape.y: {}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_shape.x: {}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_shape.y: {}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_offset.x: {}", args.at(i++)); + log_trace(tt::LogOp, "\tworker_slice_offset.y: {}", args.at(i++)); + + log_trace(tt::LogOp, "\ttotal_eltwise_kernel_num_pages={}", args.at(i++)); + + TT_ASSERT(args.size() == i, "Missed some args"); + + return args; + } + + ccl::RingTopology const topology_config; + ccl::CCLOpConfig const op_config; + ccl::InterleavedTensorWorkerSlice const worker_input_slice; + WorkerTransferInfo const worker_transfer_info; + uint32_t cb_num_pages_per_packet; + uint32_t worker_receiver_semaphore_address; + uint32_t worker_sender_semaphore_address; + bool src_is_dram; + bool dst_is_dram; +}; + +struct EdmInterfaceAddresses { + std::unordered_map worker_sender_edm_semaphore_addresses; + std::unordered_map worker_sender_edm_buffer_addresses; + std::unordered_map worker_receiver_edm_semaphore_addresses; + std::unordered_map worker_receiver_edm_buffer_addresses; +}; + +// Future work: split this up further: +// 1) assign workers to EDM channel (with buffer sharing mode specified too) +// 2) Compute the semaphore and buffer addresses (for each EDM channel and worker) +// For now - the mapping between workers and EDM channels is 1:1 +static void add_worker_config_to_edm_builders( + Device* device, + ccl::CCLOpConfig const& op_config, + std::vector const& worker_cores, + uint32_t num_channels_per_edm, + + std::vector& clockwise_edm_builders, + std::vector& counter_clockwise_edm_builders, + + std::vector const& cw_edm_channel_num_messages_to_send_per_transfer, + std::vector const& ccw_edm_channel_num_messages_to_send_per_transfer, + + uint32_t worker_sender_semaphore_address, + uint32_t worker_receiver_semaphore_address, + uint32_t link, + uint32_t ring_size, + std::function is_buffer_in_clockwise_direction_fn, + + EdmInterfaceAddresses& edm_interface_addresses) { + for (uint32_t c = 0; c < num_channels_per_edm; ++c) { + uint32_t global_worker_idx = c + num_channels_per_edm * link; + uint32_t num_workers_per_eth_buffer = 1; // std::min(workers_per_link, num_channels_per_edm ); + + std::vector sender_worker_coords; + std::vector receiver_worker_coords; + for (uint32_t w = c * num_workers_per_eth_buffer; w < (c + 1) * num_workers_per_eth_buffer; ++w) { + 
sender_worker_coords.push_back(ccl::WorkerXY( + device->worker_core_from_logical_core(worker_cores.at(w)).x, + device->worker_core_from_logical_core(worker_cores.at(w)).y)); + receiver_worker_coords.push_back(ccl::WorkerXY( + device->worker_core_from_logical_core(worker_cores.at(w)).x, + device->worker_core_from_logical_core(worker_cores.at(w)).y)); + } + + bool sender_enabled = true; // (!is_linear || !is_last_chip_in_chain); // update for linear + if (sender_enabled) { + auto& sender_edm_builder = is_buffer_in_clockwise_direction_fn(c) ? clockwise_edm_builders.at(link) + : counter_clockwise_edm_builders.at(link); + log_trace(tt::LogOp, "Adding sender EDM channel"); + ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = + sender_edm_builder.add_sender_channel( + worker_sender_semaphore_address, + cw_edm_channel_num_messages_to_send_per_transfer.at(c) * (ring_size - 1), + sender_worker_coords); + edm_interface_addresses.worker_sender_edm_semaphore_addresses[global_worker_idx] = + sender_channel_buffer_info.eth_semaphore_l1_address; + edm_interface_addresses.worker_sender_edm_buffer_addresses[global_worker_idx] = + sender_channel_buffer_info.eth_buffer_l1_address; + } + + bool receiver_enabled = true; //(!is_linear || !is_first_chip_in_chain); + if (receiver_enabled) { + auto& receiver_edm_builder = is_buffer_in_clockwise_direction_fn(c) + ? counter_clockwise_edm_builders.at(link) + : clockwise_edm_builders.at(link); + log_trace(tt::LogOp, "Adding receiver EDM channel"); + ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = + receiver_edm_builder.add_receiver_channel( + worker_receiver_semaphore_address, + ccw_edm_channel_num_messages_to_send_per_transfer.at(c) * (ring_size - 1), + receiver_worker_coords); + edm_interface_addresses.worker_receiver_edm_semaphore_addresses[global_worker_idx] = + receiver_channel_buffer_info.eth_semaphore_l1_address; + edm_interface_addresses.worker_receiver_edm_buffer_addresses[global_worker_idx] = + receiver_channel_buffer_info.eth_buffer_l1_address; + } + } +} + +static std::tuple build_reduce_scatter_worker( + tt_metal::Program& program, + Device const* device, + ccl::RingTopology const& topology_config, + ccl::CCLOpConfig const& op_config, + ReduceScatterWorkerArgBuilder const& worker_arg_builder, + std::vector& cw_edm_builders, + std::vector& ccw_edm_builders, + EdmInterfaceAddresses const& edm_interface_addresses, + CoreCoord const& worker_core, + uint32_t num_edm_channels, + uint32_t link, + uint32_t ring_size, + uint32_t worker_index, + std::map const& worker_defines, + BinaryOpType binary_math_op) { + TT_ASSERT(worker_defines.size() > 0); + for (auto const& [key, value] : worker_defines) { + log_trace(tt::LogOp, "Worker Define: {} = {}", key, value); + } + static std::string const& receiver_kernel_path = + "tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp"; + static std::string const& sender_kernel_path = + "tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp"; + + // This will be configurable by sharded/non-sharded but present the same arg builder + KernelHandle worker_receiver_kernel_id, worker_sender_kernel_id; + + bool is_in_clockwise_direction = true; + uint32_t global_worker_index = link * num_edm_channels + worker_index; + { + CoreCoord const& receiver_edm = is_in_clockwise_direction ? 
topology_config.eth_receiver_cores.at(link) + : topology_config.eth_sender_cores.at(link); + ccl::WorkerXY receiver_edm_noc_coord = ccl::WorkerXY( + device->ethernet_core_from_logical_core(receiver_edm).x, + device->ethernet_core_from_logical_core(receiver_edm).y); + const uint32_t edm_core_semaphore_address = + is_in_clockwise_direction + ? edm_interface_addresses.worker_receiver_edm_semaphore_addresses.at(global_worker_index) + : edm_interface_addresses.worker_sender_edm_semaphore_addresses.at(global_worker_index); + const uint32_t edm_core_buffer_address = + is_in_clockwise_direction + ? edm_interface_addresses.worker_receiver_edm_buffer_addresses.at(global_worker_index) + : edm_interface_addresses.worker_sender_edm_buffer_addresses.at(global_worker_index); + worker_receiver_kernel_id = tt_metal::CreateKernel( + program, + receiver_kernel_path, + worker_core, + tt_metal::ReaderDataMovementConfig(worker_arg_builder.generate_receiver_kernel_ct_args(), worker_defines)); + + tt_metal::SetRuntimeArgs( + program, + worker_receiver_kernel_id, + worker_core, + worker_arg_builder.generate_receiver_kernel_rt_args( + receiver_edm_noc_coord, + edm_core_semaphore_address, + edm_core_buffer_address, + link, + worker_index, + is_in_clockwise_direction)); + } + + { + vector compute_kernel_args = {}; + constexpr bool fp32_dest_acc_en = false; + constexpr bool math_approx_mode = false; + std::map eltwise_defines = eltwise_binary_op_utils::get_defines(binary_math_op, std::nullopt); + KernelHandle worker_reduce_kernel_id = tt_metal::CreateKernel( + program, + "tt_metal/kernels/compute/eltwise_binary.cpp", + worker_core, + tt_metal::ComputeConfig{ + .math_fidelity = MathFidelity::HiFi4, + .fp32_dest_acc_en = fp32_dest_acc_en, + .math_approx_mode = math_approx_mode, + .compile_args = compute_kernel_args, + .defines = eltwise_defines}); + + tt_metal::SetRuntimeArgs( + program, + worker_reduce_kernel_id, + worker_core, + worker_arg_builder.generate_reduce_op_kernel_rt_args(link, worker_index, ring_size)); + } + + { + CoreCoord sender_edm = is_in_clockwise_direction ? topology_config.eth_sender_cores.at(link) + : topology_config.eth_receiver_cores.at(link); + ccl::WorkerXY const sender_edm_noc_coord = ccl::WorkerXY( + device->ethernet_core_from_logical_core(sender_edm).x, + device->ethernet_core_from_logical_core(sender_edm).y); + TT_ASSERT(sender_edm_noc_coord.y == 0 || sender_edm_noc_coord.y == 6); + const uint32_t edm_core_semaphore_address = + is_in_clockwise_direction + ? edm_interface_addresses.worker_sender_edm_semaphore_addresses.at(global_worker_index) + : edm_interface_addresses.worker_receiver_edm_semaphore_addresses.at(global_worker_index); + const uint32_t edm_core_buffer_address = + is_in_clockwise_direction + ? 
edm_interface_addresses.worker_sender_edm_buffer_addresses.at(global_worker_index) + : edm_interface_addresses.worker_receiver_edm_buffer_addresses.at(global_worker_index); + worker_sender_kernel_id = tt_metal::CreateKernel( + program, + sender_kernel_path, + worker_core, + tt_metal::WriterDataMovementConfig(worker_arg_builder.generate_sender_kernel_ct_args(), worker_defines)); + + tt_metal::SetRuntimeArgs( + program, + worker_sender_kernel_id, + worker_core, + worker_arg_builder.generate_sender_kernel_rt_args( + sender_edm_noc_coord, + edm_core_semaphore_address, + edm_core_buffer_address, + link, + worker_index, + is_in_clockwise_direction)); + } + + return {worker_receiver_kernel_id, worker_sender_kernel_id}; +} + +static CoreRangeSet select_worker_cores( + ccl::CCLOpConfig const& op_config, std::size_t num_links, std::size_t num_edm_channels) { + switch (op_config.get_topology()) { + case tt::tt_metal::ccl::Topology::Linear: + return CoreRangeSet({CoreRange(CoreCoord(0, 0), CoreCoord(num_edm_channels - 1, num_links - 1))}); + case tt::tt_metal::ccl::Topology::Ring: + return CoreRangeSet({CoreRange(CoreCoord(0, 0), CoreCoord(num_edm_channels - 1, num_links - 1))}); + default: TT_ASSERT(false, "Unsupported topology"); return CoreRangeSet({}); + }; +} + +// map: (CW) link -> (CW) edm num messages to send per channel +// map: (CCW) link -> (CCW) edm num messages to send per channel +// There's a bit of a mutual dependence here between the number of workers and the number of channels, +// and the number of channels and the channel buffer size and the buffer size and the number of transfers +static WorkerTransferInfo compute_num_edm_messages_per_channel( + ccl::CCLOpConfig const& op_config, + uint32_t const page_size_in_bytes, + uint32_t const pages_per_slice, + + std::vector const& cw_per_link_edm_builders, + std::vector const& ccw_per_link_edm_builders, + std::size_t const num_edm_channels, + std::size_t const num_links, + std::size_t const ring_size) { + TT_ASSERT(num_edm_channels > 0); + TT_ASSERT(num_links > 0); + TT_ASSERT(page_size_in_bytes > 0); + TT_ASSERT(pages_per_slice > 0); + log_trace(tt::LogOp, "WorkerTransferInfo"); + + auto get_iter_begin = [num_edm_channels]( + std::vector& vec, std::size_t link) -> std::vector::iterator { + return vec.begin() + (link * num_edm_channels); + }; + + auto get_iter_end = [num_edm_channels, num_links]( + std::vector& vec, std::size_t link) -> std::vector::iterator { + bool last_link = link == num_links - 1; + TT_ASSERT( + (!last_link && ((link + 1) * num_edm_channels < vec.size())) || + (last_link && ((link + 1) * num_edm_channels == vec.size()))); + return last_link ? 
vec.end() : vec.begin() + ((link + 1) * num_edm_channels); + }; + + std::unordered_map> cw_edm_channel_num_messages_to_send; + std::unordered_map> ccw_edm_channel_num_messages_to_send; + + std::size_t const total_num_pages = pages_per_slice; + std::vector pages_per_link(num_links, total_num_pages / num_links); + for (std::size_t i = 0; i < total_num_pages % num_links; i++) { + pages_per_link.at(i)++; + } + log_trace(tt::LogOp, "pages_per_link"); + for (std::size_t i = 0; i < num_links; i++) { + log_trace(tt::LogOp, "\tpages_per_link[{}]: {}", i, pages_per_link.at(i)); + } + + // Pages per EDM channel + std::size_t total_num_edm_channels = num_links * num_edm_channels; + log_trace(tt::LogOp, "total_num_edm_channels: {}", total_num_edm_channels); + std::vector num_pages_per_edm_channel(total_num_edm_channels, 0); + + for (std::size_t link = 0; link < num_links; link++) { + std::fill( + get_iter_begin(num_pages_per_edm_channel, link), + get_iter_end(num_pages_per_edm_channel, link), + pages_per_link.at(link) / num_edm_channels); + for (std::size_t i = 0; i < pages_per_link.at(link) % num_edm_channels; i++) { + num_pages_per_edm_channel.at(link * num_edm_channels + i)++; + } + } + + std::vector num_messages_per_edm_channel; + std::vector num_pages_per_full_chunk(num_pages_per_edm_channel.size(), 0); + std::vector remaining_num_pages_per_edm_channel; + num_messages_per_edm_channel.reserve(num_pages_per_edm_channel.size()); + remaining_num_pages_per_edm_channel.reserve(num_pages_per_edm_channel.size()); + for (std::size_t link = 0; link < num_links; link++) { + std::size_t edm_channel_size_in_bytes = cw_per_link_edm_builders.at(link).get_eth_buffer_size_bytes(); + std::size_t num_pages_per_edm_buffer = edm_channel_size_in_bytes / page_size_in_bytes; + log_trace( + tt::LogOp, + "link {}, edm_channel_size_in_bytes: {}, page_size_in_bytes: {}, num_pages_per_edm_buffer: {}", + link, + edm_channel_size_in_bytes, + page_size_in_bytes, + num_pages_per_edm_buffer); + + std::transform( + get_iter_begin(num_pages_per_edm_channel, link), + get_iter_end(num_pages_per_edm_channel, link), + std::back_inserter(num_messages_per_edm_channel), + [num_pages_per_edm_buffer](uint32_t num_pages) { + return (((num_pages - 1) / num_pages_per_edm_buffer) + 1); + }); + std::transform( + get_iter_begin(num_pages_per_edm_channel, link), + get_iter_end(num_pages_per_edm_channel, link), + std::back_inserter(remaining_num_pages_per_edm_channel), + [num_pages_per_edm_buffer](uint32_t num_pages) { return num_pages % num_pages_per_edm_buffer; }); + std::fill( + get_iter_begin(num_pages_per_full_chunk, link), + get_iter_end(num_pages_per_full_chunk, link), + num_pages_per_edm_buffer); + } + + log_trace(tt::LogOp, "-- num_pages_per_edm_channel:"); + for (std::size_t link = 0; link < num_links; link++) { + for (std::size_t c = 0; c < num_edm_channels; c++) { + log_trace( + tt::LogOp, + "-- num pages for link: {}, channel: {}: {}", + link, + c, + num_pages_per_edm_channel.at(link * num_edm_channels + c)); + } + } + + log_trace(tt::LogOp, "-- num_pages_per_full_chunk:"); + for (std::size_t l = 0; l < num_links; l++) { + for (std::size_t w = 0; w < num_edm_channels; w++) { + log_trace( + tt::LogOp, "\t\t(link={},worker={}): {}", l, w, num_pages_per_full_chunk.at(l * num_edm_channels + w)); + } + } + log_trace(tt::LogOp, "-- num_messages_per_edm_channel:"); + for (std::size_t l = 0; l < num_links; l++) { + for (std::size_t w = 0; w < num_edm_channels; w++) { + log_trace( + tt::LogOp, + "\t\t(link={},worker={}): {}", + l, + w, + 
num_messages_per_edm_channel.at(l * num_edm_channels + w)); + } + } + log_trace(tt::LogOp, "-- remaining_num_pages_per_edm_channel:"); + for (std::size_t l = 0; l < num_links; l++) { + for (std::size_t w = 0; w < num_edm_channels; w++) { + log_trace( + tt::LogOp, + "\t\t(link={},worker={}): {}", + l, + w, + remaining_num_pages_per_edm_channel.at(l * num_edm_channels + w)); + } + } + + return WorkerTransferInfo( + num_pages_per_full_chunk, + num_messages_per_edm_channel, + remaining_num_pages_per_edm_channel, + num_links, + num_edm_channels); +} + +static uint32_t compute_maximum_worker_slice_in_bytes( + uint32_t cb_src0_size_pages, uint32_t cb_dst0_size_pages, std::size_t edm_channel_buffer_size, uint32_t page_size) { + return (cb_src0_size_pages + cb_dst0_size_pages) * page_size + edm_channel_buffer_size; +} + +static bool is_cb_buffering_sufficient_to_avoid_deadlock( + ccl::InterleavedTensorWorkerSlice const& worker_slice, + uint32_t cb_src0_size_pages, + uint32_t cb_dst0_size_pages, + std::size_t edm_channel_buffer_size, + uint32_t page_size) { + uint32_t worker_size_pages_rounded_up = + round_up(worker_slice.worker_slice_shape.x * worker_slice.worker_slice_shape.y, cb_src0_size_pages); + uint32_t worker_slice_size_bytes = worker_size_pages_rounded_up * page_size; + uint32_t available_buffering_capacity = compute_maximum_worker_slice_in_bytes( + cb_src0_size_pages, cb_dst0_size_pages, edm_channel_buffer_size, page_size); + log_trace(tt::LogOp, "worker_slice.worker_slice_shape.x: {}", worker_slice.worker_slice_shape.x); + log_trace(tt::LogOp, "worker_slice.worker_slice_shape.y: {}", worker_slice.worker_slice_shape.y); + log_trace(tt::LogOp, "worker_slice_size_bytes: {}", worker_slice_size_bytes); + log_trace(tt::LogOp, "worker_size_pages_rounded_up: {}", worker_size_pages_rounded_up); + log_trace(tt::LogOp, "cb_src0_size_pages: {}", cb_src0_size_pages); + log_trace(tt::LogOp, "cb_dst0_size_pages: {}", cb_dst0_size_pages); + log_trace(tt::LogOp, "page_size: {}", page_size); + log_trace(tt::LogOp, "edm_channel_buffer_size: {}", edm_channel_buffer_size); + log_trace(tt::LogOp, "available_buffering_capacity: {}", available_buffering_capacity); + + return available_buffering_capacity >= worker_slice_size_bytes; +} + +static std::tuple create_worker_circular_buffers( + Tensor const& input_tensor, + ccl::CCLOpConfig const& op_config, + CoreRangeSet const& worker_core_range, + uint32_t worker_pages_per_transfer, + tt_metal::Program& program) { + tt::DataFormat df = tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + uint32_t page_size_bytes = op_config.get_page_size(); + + // Input 0 CB + uint32_t src0_cb_index = CB::c_in0; + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, page_size_bytes); + CBHandle cb_src0_workers = CreateCircularBuffer(program, worker_core_range, cb_src0_config); + + // Input 1 CB + uint32_t src1_cb_index = CB::c_in1; + tt_metal::CircularBufferConfig cb_src1_config = + tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{src1_cb_index, df}}) + .set_page_size(src1_cb_index, page_size_bytes); + CBHandle cb_src1_workers = CreateCircularBuffer(program, worker_core_range, cb_src1_config); + + // Dataflow Writer Kernel input CB + uint32_t cb_dst0_index = CB::c_out0; + tt_metal::CircularBufferConfig cb_dst0_config = + tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, 
{{cb_dst0_index, df}}) + .set_page_size(cb_dst0_index, page_size_bytes); + CBHandle cb_dst0_sender_workers = CreateCircularBuffer(program, worker_core_range, cb_dst0_config); + + // From reader -> writer kernel (I think I need this because sharing the cb_dst0_sender_workers as output + // of reader kernel (first output) and math kernel (all subsequent outputs) doesn't seem to work because + // it seems like the math kernels hold some of the CB state in local variables) + uint32_t cb_short_circuit_index = CB::c_out1; + tt_metal::CircularBufferConfig cb_short_circuit_config = + tt_metal::CircularBufferConfig(worker_pages_per_transfer * page_size_bytes, {{cb_short_circuit_index, df}}) + .set_page_size(cb_short_circuit_index, page_size_bytes); + CBHandle cb_short_circuit_sender_workers = + CreateCircularBuffer(program, worker_core_range, cb_short_circuit_config); + + return {cb_src0_workers, cb_src1_workers, cb_dst0_sender_workers, cb_short_circuit_sender_workers}; +} + +operation::ProgramWithCallbacks reduce_scatter_with_workers( + const std::vector& input_tensors, + const std::vector& output_tensors, + BinaryOpType reduce_op, + const uint32_t scatter_split_dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + const std::optional receiver_device_id, + const std::optional sender_device_id, + ccl::Topology topology) { + log_trace(tt::LogOp, "reduce_scatter_with_workers entry"); + TT_ASSERT( + input_tensors.at(0).get_legacy_shape()[scatter_split_dim] == + output_tensors.at(0).get_legacy_shape()[scatter_split_dim] * ring_size, + "Input and output tensor shapes must match"); + TT_ASSERT( + input_tensors.at(0).buffer()->num_pages() % ring_size == 0, + "Reduce scatter current only supports even divisibility of input tensor(s) across ranks"); + + /////////////// Constants/Configuration + /// Constants/Configuration + ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; + auto const& op_config = ccl::CCLOpConfig(input_tensors, output_tensors, topology); + std::unique_ptr input_tensor_config = + CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); + std::unique_ptr output_tensor_config = + CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); + uint32_t per_step_dim_size = input_tensors.at(0).get_legacy_shape()[scatter_split_dim] / ring_size; + uint32_t input_tensor_num_units_per_scatter_dim = + per_step_dim_size / constants::TILE_WIDTH; // TODO: find the divisibility based on layout + TT_ASSERT(input_tensor_num_units_per_scatter_dim > 0); + uint32_t max_num_workers = std::min(8, input_tensor_num_units_per_scatter_dim); + auto num_edm_channels = decide_number_of_edm_channels(op_config, max_num_workers, false); + log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); + auto edm_termination_mode = ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + auto const& edm_builder = create_erisc_datamover_builder( + num_edm_channels, op_config.get_page_size(), buffer_sharing_mode, edm_termination_mode); + TT_ASSERT(num_edm_channels > 0); + + Tensor const& local_chip_tensor = input_tensors.at(0); + Tensor const& local_chip_output_tensor = output_tensors.at(0); + + std::map worker_defines; + std::vector worker_receiver_kernels; + std::vector worker_sender_kernels; + std::vector cw_per_link_edm_builders(num_links, edm_builder); + std::vector ccw_per_link_edm_builders(num_links, edm_builder); + + bool rm = local_chip_tensor.get_layout() == Layout::ROW_MAJOR; + if (rm) { + 
worker_defines["RM_INTERLEAVED"] = "1"; + } else { + worker_defines["TILE_INTERLEAVED"] = "1"; + } + + ////////////////// + tt_metal::Program program{}; + const auto& device = local_chip_tensor.device(); + + auto const& topology_config = + ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); + + auto dim_slice_factors = Shape(std::vector(local_chip_tensor.get_legacy_shape().rank(), 1)); + dim_slice_factors[-1] = ring_size; + + // Not per buffer because the buffer sharing mode may cause some buffers to share EDM transfers + WorkerTransferInfo const& worker_transfer_info = compute_num_edm_messages_per_channel( + op_config, + op_config.get_page_size(), + local_chip_tensor.buffer()->num_pages() / ring_size, // pages_per_slice, + cw_per_link_edm_builders, + ccw_per_link_edm_builders, + num_edm_channels, + num_links, + ring_size); + + CoreRangeSet const& worker_core_range = select_worker_cores(op_config, num_links, num_edm_channels); + auto const& worker_cores = corerange_to_cores(worker_core_range, std::nullopt, true); + + // Semaphores && CBs + auto worker_receiver_semaphore_address = tt_metal::CreateSemaphore(program, worker_core_range, 0); + auto worker_sender_semaphore_address = tt_metal::CreateSemaphore(program, worker_core_range, 0); + + uint32_t cb_num_pages = + (cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes() / op_config.get_page_size()) * 2; + uint32_t cb_num_pages_per_packet = cb_num_pages / 2; + log_trace(tt::LogOp, "cb_num_pages: {}", cb_num_pages); + auto const& [cb_src0_workers, cb_src1_workers, cb_dst0_sender_workers, cb_short_circuit_sender_workers] = + create_worker_circular_buffers(local_chip_tensor, op_config, worker_core_range, cb_num_pages, program); + + uint32_t max_worker_slice_in_bytes = compute_maximum_worker_slice_in_bytes( + cb_num_pages, + cb_num_pages, + cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes(), + op_config.get_page_size()); + auto tensor_slicer = ccl::InterleavedRingReduceScatterTensorSlicer( + local_chip_tensor, + local_chip_output_tensor, + scatter_split_dim, + ring_index, + ring_size, + num_edm_channels * num_links, + cb_num_pages * 2 * op_config.get_page_size()); + + // Configure the EDM builders + EdmInterfaceAddresses edm_interface_addresses; + for (std::size_t link = 0; link < num_links; link++) { + TT_ASSERT(((link + 1) * num_edm_channels) <= worker_transfer_info.num_messages_per_worker.size()); + add_worker_config_to_edm_builders( + device, + op_config, + worker_cores, + num_edm_channels, + + cw_per_link_edm_builders, + ccw_per_link_edm_builders, + + std::vector( + worker_transfer_info.num_messages_per_worker.begin() + link * num_edm_channels, + worker_transfer_info.num_messages_per_worker.begin() + (link + 1) * num_edm_channels), + std::vector( + worker_transfer_info.num_messages_per_worker.begin() + link * num_edm_channels, + worker_transfer_info.num_messages_per_worker.begin() + (link + 1) * num_edm_channels), + + worker_sender_semaphore_address, + worker_receiver_semaphore_address, + link, + ring_size, + [](uint32_t x) { return true; }, // std::function is_buffer_in_clockwise_direction_fn + + edm_interface_addresses); + } + + // build the worker kernels + tt_metal::ComputeConfig compute_config; + for (std::size_t link = 0; link < num_links; link++) { + uint32_t global_worker_index = link * num_edm_channels; + log_trace(tt::LogOp, "=============================================="); + log_trace(tt::LogOp, "------------------ Link: {} ------------------", link); + for 
(std::size_t worker = 0; worker < num_edm_channels; worker++) { + std::size_t global_worker_index = worker + link * num_edm_channels; + log_trace(tt::LogOp, "------ Worker: {} (global ID={})", worker, global_worker_index); + // This will be configurable by sharded/non-sharded but present the same arg builder + auto const& worker_slice = tensor_slicer.get_worker_slice(global_worker_index); + auto worker_arg_builder = ReduceScatterWorkerArgBuilder( + op_config, + topology_config, + worker_slice, + worker_transfer_info, + worker, + cb_num_pages_per_packet, + worker_receiver_semaphore_address, + worker_sender_semaphore_address); + + log_trace(tt::LogOp, "worker_cores.at(global_worker_index): {}", worker_cores.at(global_worker_index)); + auto [receiver_kernel_id, sender_kernel_id] = build_reduce_scatter_worker( + program, + device, + topology_config, + op_config, + worker_arg_builder, + cw_per_link_edm_builders, + ccw_per_link_edm_builders, + edm_interface_addresses, + worker_cores.at(global_worker_index), + num_edm_channels, + link, + ring_size, + worker, + worker_defines, + reduce_op); + worker_receiver_kernels.push_back(receiver_kernel_id); + worker_sender_kernels.push_back(sender_kernel_id); + + TT_ASSERT(is_cb_buffering_sufficient_to_avoid_deadlock( + worker_slice, + cb_num_pages, + cb_num_pages, + cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes(), + op_config.get_page_size())); + } + } + + // Generate the EDM kernels + ccl::generate_edm_kernels_for_ring_or_linear_topology( + program, + device, + topology_config, + cw_per_link_edm_builders, + ccw_per_link_edm_builders, + receiver_device_id, + sender_device_id); + + uint32_t total_num_workers = worker_cores.size(); + auto override_runtime_arguments_callback = + [topology_config, worker_receiver_kernels, worker_sender_kernels, worker_cores, total_num_workers, ring_index]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + const auto& input = input_tensors.at(0); + const auto& output = output_tensors.at(0); + TT_ASSERT(worker_sender_kernels.size() == worker_receiver_kernels.size()); + for (uint32_t i = 0; i < worker_sender_kernels.size(); ++i) { + auto& worker_receiver_runtime_args = + GetRuntimeArgs(program, worker_receiver_kernels.at(i), worker_cores.at(i)); + worker_receiver_runtime_args.at(0) = input.buffer()->address(); + + auto& worker_sender_runtime_args = + GetRuntimeArgs(program, worker_sender_kernels.at(i), worker_cores.at(i)); + worker_sender_runtime_args.at(0) = output.buffer()->address(); + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace reduce_scatter_detail +} // namespace ccl +} // namespace tt_metal +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp new file mode 100644 index 00000000000..3587e453553 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp @@ -0,0 +1,346 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "dataflow_api.h" +#include "debug/assert.h" +#include "tensix_types.h" +#include "tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" + +using tt::tt_metal::ccl::coord_t; +using tt::tt_metal::ccl::WorkerXY; + +struct reduce_scatter_reader_common_args_t { + reduce_scatter_reader_common_args_t(uint32_t& arg_idx) : + src_addr(get_arg_val(arg_idx++)), + num_transfers(get_arg_val(arg_idx++)), + full_chunk_num_pages(get_arg_val(arg_idx++)), + page_size(get_arg_val(arg_idx++)), + + my_ring_idx(get_arg_val(arg_idx++)), + ring_size(get_arg_val(arg_idx++)), + sem_addr(get_arg_val(arg_idx++)), + + is_clockwise_direction(get_arg_val(arg_idx++) == 1), + half_cb_n_pages(get_arg_val(arg_idx++)), + edm_core_noc0_core_x(get_arg_val(arg_idx++)), + edm_core_noc0_core_y(get_arg_val(arg_idx++)), + edm_core_semaphore_address(get_arg_val(arg_idx++)), + edm_core_buffer_address(get_arg_val(arg_idx++)), + num_concurrent_workers(get_arg_val(arg_idx++)), + + input_tensor_shape(tt::tt_metal::ccl::coord_from_args(arg_idx)), + tensor_slice_shape(tt::tt_metal::ccl::coord_from_args(arg_idx)), + worker_slice_shape(tt::tt_metal::ccl::coord_from_args(arg_idx)), + worker_slice_offset(tt::tt_metal::ccl::coord_from_args(arg_idx)) { + ASSERT(full_chunk_num_pages > 0); + ASSERT(page_size > 0); + ASSERT(ring_size > 0); + ASSERT(half_cb_n_pages > 0); + } + + const uint32_t src_addr; + const uint32_t num_transfers; + const uint32_t full_chunk_num_pages; + const uint32_t page_size; + uint32_t my_ring_idx; + const uint32_t ring_size; + const uint32_t sem_addr; + + const bool is_clockwise_direction; + + const uint32_t half_cb_n_pages; + const uint32_t edm_core_noc0_core_x; + const uint32_t edm_core_noc0_core_y; + const uint32_t edm_core_semaphore_address; + const uint32_t edm_core_buffer_address; + const uint32_t num_concurrent_workers; + + coord_t input_tensor_shape; + coord_t tensor_slice_shape; + coord_t worker_slice_shape; + coord_t worker_slice_offset; +}; +#ifdef RM_INTERLEAVED +constexpr bool rm_interleaved_addr_gen_mode = true; +#else +constexpr bool rm_interleaved_addr_gen_mode = false; +#endif + +template +struct interleaved_addr_gen_t { + using type = InterleavedAddrGen; +}; +template <> +struct interleaved_addr_gen_t { + using type = InterleavedAddrGen; +}; +template <> +struct interleaved_addr_gen_t { + using type = InterleavedAddrGen; +}; +template <> +struct interleaved_addr_gen_t { + using type = InterleavedAddrGenFast; +}; +template <> +struct interleaved_addr_gen_t { + using type = InterleavedAddrGenFast; +}; + +template +struct reduce_scatter_reader_unique_args_t : public reduce_scatter_reader_common_args_t { + using src_addr_gen_t = typename interleaved_addr_gen_t::type; + + reduce_scatter_reader_unique_args_t(uint32_t& arg_idx, const DataFormat in0_df) : + reduce_scatter_reader_common_args_t(arg_idx) { + this->s = { + .bank_base_address = this->src_addr, + .page_size = page_size +#if defined TILE_INTERLEAVED + , + .data_format = in0_df +#endif + }; + } + + src_addr_gen_t s; + + void dprint() const { + DPRINT << "RSR args:" + << "\n\tsrc_addr=" << src_addr << "\n\tnum_transfers=" << num_transfers << "\n\tpage_size=" << page_size + << "\n\tfull_chunk_num_pages=" << full_chunk_num_pages << "\n\tmy_ring_idx=" << my_ring_idx + << "\n\tsem_addr=" 
<< sem_addr << "\n\tis_clockwise_direction=" << (uint32_t)is_clockwise_direction + << "\n\thalf_cb_n_pages=" << half_cb_n_pages << "\n\tring_size=" << ring_size + << "\n\tedm_core_noc0_core_x=" << edm_core_noc0_core_x + << "\n\tedm_core_noc0_core_y=" << edm_core_noc0_core_y + << "\n\tedm_core_semaphore_address=" << edm_core_semaphore_address + << "\n\tedm_core_buffer_address=" << edm_core_buffer_address << "\n"; + } +}; + +template +struct reduce_scatter_reader_unique_args_t : public reduce_scatter_reader_common_args_t { + reduce_scatter_reader_unique_args_t(uint32_t& arg_idx, const DataFormat in0_df) : + reduce_scatter_reader_common_args_t(arg_idx), + shard_num_pages(get_arg_val(arg_idx++)), + num_l1_cores(get_arg_val(arg_idx++)), + l1_cores_ptr(reinterpret_cast(get_arg_addr(arg_idx))) { + arg_idx += this->num_l1_cores; + } + + const uint32_t shard_num_pages; + const uint32_t num_l1_cores; + const WorkerXY* const l1_cores_ptr; + + void dprint() const {} +}; + +using advance_to_next_transfer_slice_result_t = std::tuple< + uint32_t, // ring_index + uint32_t // slice_base_page_offset + >; +template +advance_to_next_transfer_slice_result_t advance_to_next_transfer_slice( + uint32_t const ring_size, + uint32_t const curr_ring_idx, + uint32_t const slice_base_page_offset, + uint32_t const bank_base_address, + coord_t const& input_tensor_shape, + coord_t const& tensor_slice_shape, + bool const is_clockwise_direction) { + bool const sliced_on_width = tensor_slice_shape.x < input_tensor_shape.x; + uint32_t single_ring_idx_stride = + sliced_on_width ? tensor_slice_shape.x : tensor_slice_shape.y * input_tensor_shape.x; + uint32_t n_minus_one_ring_indices_stride = sliced_on_width + ? tensor_slice_shape.x * (ring_size - 1) + : tensor_slice_shape.y * input_tensor_shape.x * (ring_size - 1); + + if constexpr (!is_sharded) { + if (is_clockwise_direction) { + if (curr_ring_idx == 0) { + return advance_to_next_transfer_slice_result_t{ + ring_size - 1, + slice_base_page_offset + n_minus_one_ring_indices_stride, + }; + } else { + return advance_to_next_transfer_slice_result_t{ + curr_ring_idx - 1, + slice_base_page_offset - single_ring_idx_stride, + }; + } + } else { + if (curr_ring_idx == ring_size - 1) { + return advance_to_next_transfer_slice_result_t{ + 0, + slice_base_page_offset - n_minus_one_ring_indices_stride, + }; + } else { + return advance_to_next_transfer_slice_result_t{ + curr_ring_idx + 1, + slice_base_page_offset + single_ring_idx_stride, + }; + } + } + } +} + +void kernel_main() { + constexpr bool is_sharded = get_compile_time_arg_val(0) == 1; + + // Currently meaningless when `is_sharded=true` + constexpr bool src_is_dram = get_compile_time_arg_val(1) == 1; + + uint32_t arg_idx = 0; + + constexpr uint32_t to_dm_sender_short_circuit_cb = tt::CB::c_out1; + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + constexpr uint32_t cb_id_in1 = tt::CB::c_in1; + const DataFormat in0_df = get_dataformat(cb_id_in0); + auto args = reduce_scatter_reader_unique_args_t(arg_idx, in0_df); + uint32_t total_eltwise_kernel_num_pages = get_arg_val(arg_idx++); + + ASSERT(args.half_cb_n_pages >= args.full_chunk_num_pages); + + bool width_sliced = args.tensor_slice_shape.x <= args.input_tensor_shape.x; + + volatile tt_l1_ptr uint32_t* receiver_read_semaphore_addr_ptr = + reinterpret_cast(args.sem_addr); + const uint64_t eth_receiver_l1_base_noc_addr = + get_noc_addr(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y, args.edm_core_buffer_address); + const uint64_t eth_receiver_l1_semaphore_noc_addr = + 
get_noc_addr(args.edm_core_noc0_core_x, args.edm_core_noc0_core_y, args.edm_core_semaphore_address); + + uint32_t total_cb_pages_pushed = 0; + uint32_t total_cb_pages_pushed_to_math = 0; + + // For the first timestep, there is no other input to reduce with, so we just send it straight to the input CB + // of the output data movement kernel - short-circuiting past the (reducer) math kernel + // For tile => shape in tiles + // For RM => shape in elements + while (args.worker_slice_offset.x < args.tensor_slice_shape.x && + args.worker_slice_offset.y < args.tensor_slice_shape.y) { + uint32_t curr_ring_slice_start_page_offset = + width_sliced ? args.tensor_slice_shape.x * args.my_ring_idx + : args.tensor_slice_shape.y * args.my_ring_idx * args.input_tensor_shape.x; + + const uint32_t worker_relative_start_offset_into_slice = + args.worker_slice_offset.x + (args.worker_slice_offset.y * args.input_tensor_shape.x); + const uint32_t starting_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; + uint32_t curr_tile_id = starting_tile_id; + + coord_t valid_worker_slice_shape = coord_t( + std::min(args.worker_slice_shape.x, args.tensor_slice_shape.x - args.worker_slice_offset.x), + std::min(args.worker_slice_shape.y, args.tensor_slice_shape.y - args.worker_slice_offset.y)); + + bool last_page_of_worker = false; + uint32_t const worker_slice_n_pages = valid_worker_slice_shape.x * valid_worker_slice_shape.y; + ASSERT( + (args.num_transfers - 1) * worker_slice_n_pages + total_cb_pages_pushed_to_math <= + total_eltwise_kernel_num_pages); + { + coord_t offset_into_worker_slice = {0, 0}; + for (uint32_t p = 0; p < worker_slice_n_pages; p += args.full_chunk_num_pages) { + uint32_t n_pages = std::min(args.full_chunk_num_pages, worker_slice_n_pages - p); + ASSERT(!last_page_of_worker); + read_chunk_from_output_tensor_v2( + curr_tile_id, + offset_into_worker_slice, + valid_worker_slice_shape, + // In tiles for tile layout + args.input_tensor_shape, + to_dm_sender_short_circuit_cb, + args.s, + n_pages, + args.page_size, + last_page_of_worker); + total_cb_pages_pushed += n_pages; + if (n_pages < args.half_cb_n_pages) { + push_filler_pages_to_cb(to_dm_sender_short_circuit_cb, args.half_cb_n_pages - n_pages); + ASSERT(args.half_cb_n_pages > n_pages); + ASSERT(p + n_pages == worker_slice_n_pages); + total_cb_pages_pushed += (args.half_cb_n_pages - n_pages); + } + } + } + + for (uint32_t i = 1; i < args.num_transfers; ++i) { + coord_t offset_into_worker_slice = {0, 0}; + std::tie(args.my_ring_idx, curr_ring_slice_start_page_offset) = advance_to_next_transfer_slice( + args.ring_size, + args.my_ring_idx, + curr_ring_slice_start_page_offset, + args.s.bank_base_address, + args.input_tensor_shape, + args.tensor_slice_shape, + args.is_clockwise_direction); + ASSERT(last_page_of_worker); + last_page_of_worker = false; + curr_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; + + for (uint32_t p = 0; p < worker_slice_n_pages; p += args.full_chunk_num_pages) { + uint32_t n_pages = std::min(args.full_chunk_num_pages, worker_slice_n_pages - p); + ASSERT(n_pages > 0); + // Fetch from input tensor + read_chunk_from_output_tensor_v2( + curr_tile_id, + offset_into_worker_slice, + valid_worker_slice_shape, + // In tiles for tile layout + args.input_tensor_shape, + cb_id_in1, + args.s, + n_pages, + args.page_size, + last_page_of_worker); + uint64_t eth_receiver_l1_curr_noc_addr = eth_receiver_l1_base_noc_addr; + + // Fetch from EDM + 
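+ // Handshake with the EDM: block until the EDM raises this worker's receive semaphore (a chunk has
+ // landed in the EDM's L1 buffer), clear the semaphore, copy the chunk into cb_id_in0 for the math
+ // kernel, then increment the EDM core's semaphore by NEXT_MESSAGE_AVAILABLE so the buffer can be
+ // reused for the next message.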
noc_semaphore_wait(receiver_read_semaphore_addr_ptr, 1);
+ noc_semaphore_set(receiver_read_semaphore_addr_ptr, 0);
+ fetch_chunk(cb_id_in0, n_pages, args.page_size, eth_receiver_l1_base_noc_addr);
+ total_cb_pages_pushed_to_math += n_pages;
+ total_cb_pages_pushed += n_pages;
+ noc_semaphore_inc(
+ eth_receiver_l1_semaphore_noc_addr,
+ tt::tt_metal::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE);
+ if (n_pages < args.half_cb_n_pages) {
+ uint32_t num_filler_pages = args.half_cb_n_pages - n_pages;
+ push_filler_pages_to_cb(cb_id_in0, num_filler_pages);
+ push_filler_pages_to_cb(cb_id_in1, num_filler_pages);
+ total_cb_pages_pushed_to_math += num_filler_pages;
+ total_cb_pages_pushed += num_filler_pages;
+ }
+ }
+ ASSERT(last_page_of_worker);
+ }
+
+ args.worker_slice_offset = advance_slice_row_major(
+ args.worker_slice_offset, args.worker_slice_shape, args.tensor_slice_shape, args.num_concurrent_workers);
+ }
+
+ ASSERT(total_eltwise_kernel_num_pages >= total_cb_pages_pushed_to_math);
+ DEBUG_STATUS("DRN1");
+ // The host code currently doesn't know how to accurately count the exact number of pages pushed through the
+ // math reduce op, so it instead provides a conservative estimate that may be larger than what the op actually
+ // requires. It passes this number to the sender and receiver, which push/pop junk pages to/from the math op to
+ // ensure it completes.
+ for (; total_cb_pages_pushed_to_math < total_eltwise_kernel_num_pages; total_cb_pages_pushed_to_math++) {
+ push_filler_pages_to_cb(cb_id_in0, 1);
+ push_filler_pages_to_cb(cb_id_in1, 1);
+ }
+
+ static_assert(
+ tt::tt_metal::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY >
+ tt::tt_metal::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE);
+ noc_semaphore_inc(
+ eth_receiver_l1_semaphore_noc_addr,
+ tt::tt_metal::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY -
+ tt::tt_metal::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE);
+ DEBUG_STATUS("DONE");
+}
diff --git a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp
new file mode 100644
index 00000000000..0ac160a6be6
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp
@@ -0,0 +1,150 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "tt_eager/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp" + +using tt::tt_metal::ccl::coord_t; + +void kernel_main() { + constexpr bool is_sharded = get_compile_time_arg_val(0) == 1; + constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; + + uint32_t arg_idx = 0; + uint32_t const dst_addr = get_arg_val(arg_idx++); + uint32_t const eth_sender_l1_base_addr = get_arg_val(arg_idx++); + uint32_t const eth_sender_l1_sem_addr = get_arg_val(arg_idx++); + uint32_t const eth_sender_noc_x = get_arg_val(arg_idx++); + uint32_t const eth_sender_noc_y = get_arg_val(arg_idx++); + uint32_t const num_transfers = get_arg_val(arg_idx++); + uint32_t const page_size = get_arg_val(arg_idx++); + uint32_t const full_chunk_num_pages = get_arg_val(arg_idx++); + uint32_t const writer_send_sem_addr = get_arg_val(arg_idx++); + uint32_t const half_cb_n_pages = get_arg_val(arg_idx++); + uint32_t const num_concurrent_workers = get_arg_val(arg_idx++); + + coord_t const& output_tensor_shape = tt::tt_metal::ccl::coord_from_args(arg_idx); + coord_t const& worker_slice_shape = tt::tt_metal::ccl::coord_from_args(arg_idx); + coord_t worker_slice_base_offset = tt::tt_metal::ccl::coord_from_args(arg_idx); + + uint32_t total_eltwise_kernel_num_pages = get_arg_val(arg_idx++); + + // Argument validation + ASSERT(half_cb_n_pages >= full_chunk_num_pages); + ASSERT(full_chunk_num_pages > 0); + ASSERT(page_size > 0); + ASSERT(half_cb_n_pages > 0); + + constexpr uint32_t cb_id_in0 = tt::CB::c_out0; + constexpr uint32_t cb_id_in_short_circuit = tt::CB::c_out1; + const DataFormat in0_df = get_dataformat(cb_id_in0); +#ifdef RM_INTERLEAVED + InterleavedAddrGen d = { + .bank_base_address = dst_addr + output_start_addr_offset, .page_size = page_size}; +#elif defined TILE_INTERLEAVED + + InterleavedAddrGenFast d = { + .bank_base_address = dst_addr, .page_size = page_size, .data_format = in0_df}; +#endif + + // Used to wait until eth sender has space available + volatile tt_l1_ptr uint32_t* writer_send_semaphore_addr_ptr = + reinterpret_cast(writer_send_sem_addr); + // This is different per writer core + const uint64_t eth_l1_sender_base_noc_addr = + get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_base_addr); + // Used to signal eth sender that data is available. This is different per writer core + const uint64_t eth_l1_sender_semaphore_addr = + get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_sem_addr); + + uint32_t total_lifetime_cb_pages_popped_from_math = 0; + uint32_t total_cb_pages_popped = 0; // DEBUG ONLY + while (worker_slice_base_offset.x < output_tensor_shape.x && worker_slice_base_offset.y < output_tensor_shape.y) { + // First phase - we only forward messages to EDM + coord_t valid_worker_slice_shape = coord_t( + std::min(worker_slice_shape.x, output_tensor_shape.x - worker_slice_base_offset.x), + std::min(worker_slice_shape.y, output_tensor_shape.y - worker_slice_base_offset.y)); + uint32_t const num_pages_to_write = valid_worker_slice_shape.x * valid_worker_slice_shape.y; + + ASSERT(total_lifetime_cb_pages_popped_from_math + num_pages_to_write <= total_eltwise_kernel_num_pages); + for (uint32_t i = 0; i < num_transfers; ++i) { + const uint32_t cb_in = i == 0 ? 
cb_id_in_short_circuit : cb_id_in0; + for (uint32_t p = 0; p < num_pages_to_write; p += full_chunk_num_pages) { + uint32_t n_pages = std::min(full_chunk_num_pages, num_pages_to_write - p); + ASSERT(n_pages > 0); + noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); + noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); + send_chunk(cb_in, n_pages, page_size, eth_l1_sender_base_noc_addr); + total_cb_pages_popped += n_pages; // DEBUG ONLY + noc_semaphore_inc( + eth_l1_sender_semaphore_addr, + tt::tt_metal::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); + if (i != 0) { + total_lifetime_cb_pages_popped_from_math += n_pages; + } + if (n_pages < half_cb_n_pages) { + uint32_t num_filler_pages = half_cb_n_pages - n_pages; + + ASSERT(p + n_pages == num_pages_to_write); + pop_filler_pages_from_cb(cb_in, num_filler_pages); + total_cb_pages_popped += num_filler_pages; // DEBUG ONLY + if (i != 0) { + total_lifetime_cb_pages_popped_from_math += num_filler_pages; + } + } + } + } + + // write the final reduced chunk for this chip out to the output tensor + // Second phase - Dump the local output to the output tensor + uint32_t curr_ring_slice_start_page_offset = 0; + const uint32_t worker_relative_start_offset_into_slice = + worker_slice_base_offset.x + (worker_slice_base_offset.y * output_tensor_shape.x); + auto current_worker_slice_offset = worker_slice_base_offset; + const uint32_t starting_tile_id = curr_ring_slice_start_page_offset + worker_relative_start_offset_into_slice; + uint32_t curr_tile_id = starting_tile_id; + + bool last_page_of_worker = false; + for (uint32_t p = 0; p < num_pages_to_write; p += full_chunk_num_pages) { + ASSERT(curr_tile_id < output_tensor_shape.x * output_tensor_shape.y); + ASSERT(!last_page_of_worker); + uint32_t n_pages = std::min(full_chunk_num_pages, num_pages_to_write - p); + ASSERT(n_pages <= half_cb_n_pages); + ASSERT(full_chunk_num_pages <= half_cb_n_pages); + write_chunk_v2( + curr_tile_id, + current_worker_slice_offset, + valid_worker_slice_shape, + output_tensor_shape, // In tiles for tile layout + cb_id_in0, + d, + n_pages, + page_size, + last_page_of_worker); + total_lifetime_cb_pages_popped_from_math += n_pages; + if (n_pages < half_cb_n_pages) { + uint32_t num_filler_pages = half_cb_n_pages - n_pages; + ASSERT(p + n_pages == num_pages_to_write); + pop_filler_pages_from_cb(cb_id_in0, num_filler_pages); + total_lifetime_cb_pages_popped_from_math += num_filler_pages; + } + } + + worker_slice_base_offset = advance_slice_row_major( + worker_slice_base_offset, worker_slice_shape, output_tensor_shape, num_concurrent_workers); + } + + for (; total_lifetime_cb_pages_popped_from_math < total_eltwise_kernel_num_pages; + total_lifetime_cb_pages_popped_from_math++) { + pop_filler_pages_from_cb(cb_id_in0, 1); + } + + noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); + noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); + noc_semaphore_inc( + eth_l1_sender_semaphore_addr, tt::tt_metal::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); +} diff --git a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp new file mode 100644 index 00000000000..7a3d39df2f3 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" + +#include "tt_dnn/op_library/reduce/reduce_op.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" +#include "tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" +#include "tt_metal/host_api.hpp" + +namespace tt { +namespace tt_metal { + +void ReduceScatter::validate(const std::vector& input_tensors) const { + for (auto const& t : input_tensors) { + TT_FATAL( + t.get_legacy_shape()[this->scatter_dim] / this->ring_size > 0, + "Reduce scatter input tensor shape on dim {} must be divisible by ring size"); + TT_FATAL( + t.get_legacy_shape()[this->scatter_dim] % this->ring_size == 0, + "Reduce scatter input tensor shape on dim {} must be divisible by ring size"); + } +} + +std::vector ReduceScatter::compute_output_shapes(const std::vector& input_tensors) const { + auto shape = input_tensors[0].get_legacy_shape(); + TT_ASSERT( + shape[this->scatter_dim] % this->ring_size == 0, + "The size of the scatter dimension must be a multiple of the ring size"); + shape[this->scatter_dim] /= this->ring_size; + return std::vector(input_tensors.size(), shape); +} + +std::vector ReduceScatter::create_output_tensors(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + if (this->output_mem_config.is_sharded()) { + TT_FATAL(false, "Sharded output is not supported for ReduceScatter"); + } else { + return operation::generic_create_output_tensors( + *this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config); + } +} + +tt::stl::reflection::Attributes ReduceScatter::attributes() const { + return { + {"scatter_dim", this->scatter_dim}, + {"num_links", this->num_links}, + {"ring_size", this->ring_size}, + {"ring_index", this->ring_index}, + {"receiver_device_id", this->receiver_device_id}, + {"sender_device_id", this->sender_device_id}, + {"output_mem_config", this->output_mem_config}, + }; +} + +operation::ProgramWithCallbacks ReduceScatter::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + return ccl::reduce_scatter_detail::reduce_scatter_with_workers( + input_tensors, + output_tensors, + this->binary_op_type, + this->scatter_dim, + this->num_links, + this->ring_size, + this->ring_index, + this->receiver_device_id, + this->sender_device_id, + this->topology); +} + +std::vector reduce_scatter_impl( + const std::vector& input_tensors, + const BinaryOpType binary_op_type, + const uint32_t scatter_dim, + const uint32_t num_links, + const MemoryConfig& output_mem_config, + const ccl::Topology topology) { + TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); + + std::vector output_tensors; + output_tensors.reserve(input_tensors.size()); + std::vector ops; + ops.reserve(input_tensors.size()); + bool is_ring = topology == ccl::Topology::Ring; + for (uint32_t i = 0; i < input_tensors.size(); ++i) { + bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (input_tensors.size() - 1); + bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; + + std::optional receiver_device_id = + is_last_chip_in_clockwise_direction + ? std::nullopt + : std::optional(input_tensors[(i + 1) % input_tensors.size()].device()->id()); + std::optional sender_device_id = + is_last_chip_in_counter_clockwise_direction + ? std::nullopt + : std::optional(input_tensors[i == 0 ? 
input_tensors.size() - 1 : i - 1].device()->id()); + ops.emplace_back(ReduceScatter{ + binary_op_type, + scatter_dim, + num_links, + static_cast(input_tensors.size()), + i, + receiver_device_id, + sender_device_id, + output_mem_config, + topology}); + output_tensors.push_back(operation::run(ops[i], {input_tensors.at(i)}).at(0)); + } + return output_tensors; +} + +static BinaryOpType convert_reduce_type_to_eltwise_type(ReduceOpMath reduce_op) { + switch (reduce_op) { + case ReduceOpMath::SUM: return BinaryOpType::ADD; + + default: TT_FATAL("Reduce scatter only support reduce_op_type SUM"); return BinaryOpType::ADD; + } +} + +std::vector reduce_scatter( + const std::vector& input_tensors, + const uint32_t scatter_dim, + ReduceOpMath math_op, + const uint32_t num_links, + const MemoryConfig& output_mem_config) { + BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op); + return reduce_scatter_impl( + input_tensors, binary_op_type, scatter_dim, num_links, output_mem_config, ccl::Topology::Ring); +} + +}; // namespace tt_metal +}; // namespace tt diff --git a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp new file mode 100644 index 00000000000..cc27cd65a88 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/ccl_common.hpp" +#include "tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" +#include "tt_eager/tt_dnn/op_library/reduce/reduce_op.hpp" +#include "tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" + +namespace tt { +namespace tt_metal { + +struct ReduceScatter { + const BinaryOpType binary_op_type; + const uint32_t scatter_dim; + const uint32_t num_links; + const uint32_t ring_size; + const uint32_t ring_index; + const std::optional receiver_device_id; + const std::optional sender_device_id; + const MemoryConfig output_mem_config; + const ccl::Topology topology; + + void validate(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; + tt::stl::reflection::Attributes attributes() const; +}; + +std::vector reduce_scatter( + const std::vector &input_tensors, + const uint32_t scatter_split_dim, + ReduceOpMath reduce_op = ReduceOpMath::SUM, + const uint32_t num_links = 1, + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + +namespace ccl { +namespace reduce_scatter_detail { +operation::ProgramWithCallbacks reduce_scatter_with_workers( + const std::vector& input_tensors, + const std::vector& output_tensors, + BinaryOpType reduce_op, + const uint32_t scatter_split_dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + const std::optional receiver_device_id, + const std::optional sender_device_id, + ccl::Topology topology); +} +}; // namespace ccl + +}; // namespace tt_metal +}; // namespace tt diff --git a/tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp b/tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp index 
d1affcef787..87799e8f4a9 100644 --- a/tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp +++ b/tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp @@ -6,20 +6,28 @@ #include // #include -#include #include +#include namespace tt { namespace tt_metal { namespace ccl { -enum EriscDataMoverBufferSharingMode: uint32_t { +enum EriscDataMoverBufferSharingMode : uint32_t { NOT_SHARED = 0, ROUND_ROBIN = 1, SHARED = 2, ROUND_ROBIN_AND_SHARED = 3 }; +enum EriscDataMoverTerminationMode : uint32_t { MESSAGE_COUNT_REACHED = 0, WORKER_INITIATED = 1 }; + +enum EriscDataMoverWorkerSignal : uint32_t { + NEXT_MESSAGE_AVAILABLE = 1, + NEXT_MESSAGE_IS_LAST = 2, + TERMINATE_IMMEDIATELY = 3, +}; + // TODO: let the kernel runtime args enum ShardType : uint8_t { Width = 0, Height = 1, Block = 2 }; @@ -33,17 +41,49 @@ struct WorkerXY { WorkerXY(uint16_t x, uint16_t y) : x(x), y(y) {} - uint32_t to_uint32() const { - return (y << 16) | x; - } + uint32_t to_uint32() const { return (y << 16) | x; } - bool operator==(const WorkerXY& rhs) const { - return x == rhs.x && y == rhs.y; + bool operator==(const WorkerXY &rhs) const { return x == rhs.x && y == rhs.y; } + bool operator!=(const WorkerXY &rhs) const { return !(*this == rhs); } +}; + +struct coord_t { + coord_t(uint32_t x, uint32_t y) : x(x), y(y) {} + uint32_t x; + uint32_t y; +}; + +// Advances relative to immediate outer slice. There is no notion of global offset here and the caller would be expected +// to add any additional offsets required. Consider templatizing this to conditionally implement the divide as a shift +inline coord_t advance_slice_row_major( + coord_t const &inner_slice_offset, + coord_t const &inner_slice_shape, + coord_t const &outer_slice_shape, + uint32_t num_active_slices) { + auto slice_mod_x = outer_slice_shape.x % inner_slice_shape.x; + bool needs_padding = slice_mod_x != 0; + coord_t padded_outer_slice_shape = + needs_padding ? 
coord_t(outer_slice_shape.x + (inner_slice_shape.x - slice_mod_x), outer_slice_shape.y) + : outer_slice_shape; + uint32_t advance_x = inner_slice_shape.x * num_active_slices; + uint32_t next_offset_x = inner_slice_offset.x + advance_x; + if (next_offset_x < padded_outer_slice_shape.x) { + return coord_t(next_offset_x, inner_slice_offset.y); } - bool operator!=(const WorkerXY& rhs) const { - return !(*this == rhs); + + uint32_t advance_x_from_0_offset_x = next_offset_x - padded_outer_slice_shape.x; + uint32_t next_offset_y = inner_slice_offset.y + inner_slice_shape.y; + // Measure perf impact of early exit vs the division + if (advance_x_from_0_offset_x < padded_outer_slice_shape.x) { + return coord_t(advance_x_from_0_offset_x, inner_slice_offset.y + inner_slice_shape.y); } -}; + + uint32_t slice_rows_advanced = advance_x_from_0_offset_x / padded_outer_slice_shape.x; + next_offset_x = advance_x_from_0_offset_x - (slice_rows_advanced * padded_outer_slice_shape.x); + next_offset_y += slice_rows_advanced * inner_slice_shape.y; + + return coord_t(next_offset_x, next_offset_y); +} static constexpr uint32_t UNINITIALIZED_VALUE_U32 = std::numeric_limits::max(); static constexpr uint16_t UNINITIALIZED_VALUE_U16 = std::numeric_limits::max(); @@ -59,11 +99,10 @@ struct ArchDependentTypes { template <> struct ArchDependentTypes { - using workers_list_t = ccl::WorkerXY*; + using workers_list_t = ccl::WorkerXY *; static const workers_list_t WORKERS_LIST_UNINITIALIZED_VALUE; }; - template struct FullWorkerGridShardAddrGenArgs final { typename ArchDependentTypes::workers_list_t dest_cores; @@ -91,7 +130,6 @@ struct FullWorkerGridShardAddrGenArgs final { template struct ShardAddrGenArgs final { - uint32_t shards_start_address = UNINITIALIZED_VALUE_U32; uint32_t shard_size_in_bytes = UNINITIALIZED_VALUE_U32; uint16_t total_chunks_per_core = UNINITIALIZED_VALUE_U16; @@ -116,58 +154,36 @@ struct ShardAddrGenArgs final { } }; -// uint16_t &curr_shard_tile_x, -// uint16_t &curr_shard_tile_y, -// uint16_t &curr_tile_index, -// uint16_t &curr_shard, -// uint16_t const input_shard_num_tiles_x, -// uint16_t const input_shard_num_tiles_y, -// uint16_t const total_shards_x, -// bool is_clockwise) { - namespace all_gather { inline void addr_gen_advance_width_sharded( - // uint16_t& curr_core_chunk_index, - // uint16_t& curr_worker_index, - // uint16_t& contiguous_chunk_count, - // // uint16_t& current_core_chunks_visited, - // const uint16_t& total_chunks_per_core, - // const uint16_t& num_dest_cores, - // const uint16_t& intra_core_stride_in_shards, - // const uint16_t& contiguous_chunks_before_stride, - // bool is_clockwise - uint16_t& curr_core_tile_index, - uint16_t& curr_worker_index, - uint16_t& contiguous_tile_count, + uint16_t &curr_core_tile_index, + uint16_t &curr_worker_index, + uint16_t &contiguous_tile_count, // uint16_t& current_core_chunks_visited, - const uint16_t& total_chunks_per_core, - const uint16_t& num_dest_cores, - const uint16_t& intra_core_stride_in_shards, - const uint16_t& contiguous_chunks_before_stride, - bool is_clockwise -) { + const uint16_t &total_chunks_per_core, + const uint16_t &num_dest_cores, + const uint16_t &intra_core_stride_in_shards, + const uint16_t &contiguous_chunks_before_stride, + bool is_clockwise) { if (is_clockwise) { bool do_stride = contiguous_tile_count == contiguous_chunks_before_stride; - bool stride_induced_chunk_wraparound = (do_stride && curr_core_tile_index < (intra_core_stride_in_shards + contiguous_chunks_before_stride - 1)); + bool 
stride_induced_chunk_wraparound = + (do_stride && curr_core_tile_index < (intra_core_stride_in_shards + contiguous_chunks_before_stride - 1)); bool do_chunk_wrap = curr_core_tile_index >= total_chunks_per_core || stride_induced_chunk_wraparound; - // current_core_chunks_visited++; if (do_chunk_wrap) { bool do_core_wrap = curr_worker_index == 0; - uint32_t past_end_index = (total_chunks_per_core + curr_core_tile_index + 1 - contiguous_chunks_before_stride); + uint32_t past_end_index = + (total_chunks_per_core + curr_core_tile_index + 1 - contiguous_chunks_before_stride); uint32_t backward_step_amount = (intra_core_stride_in_shards + contiguous_chunks_before_stride - 1); - // ASSERT(past_end_index >= backward_step_amount); curr_core_tile_index = past_end_index - backward_step_amount; - // curr_core_tile_index = (total_chunks_per_core + curr_core_tile_index - contiguous_chunks_before_stride) - (intra_core_stride_in_shards + contiguous_chunks_before_stride); contiguous_tile_count = 1; if (do_core_wrap) { curr_worker_index = num_dest_cores - 1; - // current_core_chunks_visited=0; } else { curr_worker_index--; } } else { - if (do_stride) { contiguous_tile_count = 1; curr_core_tile_index -= (intra_core_stride_in_shards + contiguous_chunks_before_stride - 1); @@ -178,10 +194,8 @@ inline void addr_gen_advance_width_sharded( } } else { - // current_core_chunks_visited++; if (contiguous_tile_count == contiguous_chunks_before_stride) { contiguous_tile_count = 1; - // TT_ASSERT(curr_core_chunk_index >= intra_core_stride_in_shards); curr_core_tile_index += intra_core_stride_in_shards; } else { contiguous_tile_count++; @@ -190,7 +204,6 @@ inline void addr_gen_advance_width_sharded( bool do_chunk_wrap = curr_core_tile_index >= total_chunks_per_core; if (do_chunk_wrap) { - // current_core_chunks_visited = 0; curr_core_tile_index = curr_core_tile_index - total_chunks_per_core; curr_worker_index++; bool do_core_wrap = curr_worker_index == num_dest_cores; @@ -210,12 +223,10 @@ inline void full_worker_grid_addr_gen_width_sharded_advance_shard_impl( uint16_t const input_shard_num_tiles_x, uint16_t const total_shards_x, uint16_t const shard_offset, - bool is_clockwise -) { + bool is_clockwise) { bool wrap_around = is_clockwise ? curr_core_index == 0 : curr_core_index == total_num_cores - 1; - curr_core_index = wrap_around ? - (is_clockwise ? total_num_cores - 1 : 0) : - (is_clockwise ? curr_core_index - 1 : curr_core_index + 1); + curr_core_index = wrap_around ? (is_clockwise ? total_num_cores - 1 : 0) + : (is_clockwise ? curr_core_index - 1 : curr_core_index + 1); curr_tile_index = input_shard_num_tiles_x * shard_offset; curr_shard_tile_x = 0; curr_shard_tile_y = 0; @@ -233,12 +244,19 @@ inline void full_worker_grid_addr_gen_width_sharded_advance_full_tile_row( uint16_t const total_shards_x, uint16_t const shard_offset, bool is_clockwise) { - // Keep it verbose for now. 
we can reduce to a flat index later bool is_last_row = curr_shard_tile_y == input_shard_num_tiles_y - 1; if (is_last_row) { full_worker_grid_addr_gen_width_sharded_advance_shard_impl( - curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, total_shards_x, shard_offset, is_clockwise); + curr_shard_tile_x, + curr_shard_tile_y, + curr_tile_index, + curr_core_index, + total_num_cores, + input_shard_num_tiles_x, + total_shards_x, + shard_offset, + is_clockwise); } else { curr_tile_index += total_shards_x * input_shard_num_tiles_x - curr_shard_tile_x; @@ -247,7 +265,7 @@ inline void full_worker_grid_addr_gen_width_sharded_advance_full_tile_row( } } -inline void full_worker_grid_addr_gen_width_sharded_advance ( +inline void full_worker_grid_addr_gen_width_sharded_advance( uint16_t &curr_shard_tile_x, uint16_t &curr_shard_tile_y, uint16_t &curr_tile_index, @@ -258,13 +276,20 @@ inline void full_worker_grid_addr_gen_width_sharded_advance ( uint16_t const total_shards_x, uint16_t const shard_offset, bool is_clockwise) { - // Keep it verbose for now. we can reduce to a flat index later bool last_tile_in_row = curr_shard_tile_x == input_shard_num_tiles_x - 1; bool last_tile_in_col = curr_shard_tile_y == input_shard_num_tiles_y - 1; if (last_tile_in_row && last_tile_in_col) { full_worker_grid_addr_gen_width_sharded_advance_shard_impl( - curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, total_shards_x, shard_offset, is_clockwise); + curr_shard_tile_x, + curr_shard_tile_y, + curr_tile_index, + curr_core_index, + total_num_cores, + input_shard_num_tiles_x, + total_shards_x, + shard_offset, + is_clockwise); } else if (last_tile_in_row) { curr_tile_index += total_shards_x * input_shard_num_tiles_x - curr_shard_tile_x; @@ -276,8 +301,7 @@ inline void full_worker_grid_addr_gen_width_sharded_advance ( } } - -}; // namespace all_gather +}; // namespace all_gather } // namespace ccl } // namespace tt_metal diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp index dea7965862c..c4857479370 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp @@ -24,6 +24,7 @@ #include "tt_dnn/op_library/sharded/sharded_op.hpp" #include "tt_dnn/op_library/sharded_partial/sharded_op_partial.hpp" #include "tt_dnn/op_library/all_gather/all_gather_op.hpp" +#include "tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" namespace tt::tt_metal::detail{ @@ -525,7 +526,8 @@ namespace tt::tt_metal::detail{ R"doc(Converts a partial tensor from sharded_to_interleaved memory layout)doc" ); - // Multi-Device ops + // ---------- Multi-Device ops ---------- + // All Gather m_tensor.def("all_gather", &all_gather, py::arg("input_tensors"), py::arg("dim"), py::arg("num_links") = 1, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(Performs all gather on a list of tensors that form one tensor that is distributed across devices. The output is a list of a tensor which has been duplciated across the input devices.)doc" @@ -534,6 +536,28 @@ namespace tt::tt_metal::detail{ py::arg("input_tensors"), py::arg("dim"), py::arg("num_links") = 1, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(Performs all gather on a list of tensors that form one tensor that is distributed across devices. 
The output is a list of a tensor which has been duplciated across the input devices.)doc"
    );
+
+    // Reduce Scatter
+    m_tensor.def("reduce_scatter", &reduce_scatter,
+        // py::arg("input_tensors"), py::arg("scatter_split_dim"), py::arg("reduce_op") = ReduceOpMath::SUM, py::arg("num_links") = 1, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+        py::arg("input_tensors"), py::arg("scatter_split_dim"), py::arg("reduce_op"), py::arg("num_links") = 1, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+        R"doc(
+        Performs reduce scatter across chips, where the input tensors are sliced along the scatter dim and pairwise reduced as they propagate through the cluster.
+
+        For example, a reduce scatter on a ring of 8 ranks with a per-rank input tensor shape of [1,1,1024,8192] and scatter_split_dim=3 will split each input
+        tensor along its width into 8 parts of size [1,1,1024,1024]. Each of those parts is reduced with the corresponding chunk from the other ranks: every chip
+        reduces the first incoming [1,1,1024,1024] chunk with its local first [1,1,1024,1024] chunk and forwards the result, reduces the second incoming chunk
+        with its second local chunk and forwards that, and so on. Each rank starts at a different offset into the tensor, so by the end every rank holds a
+        different reduced chunk of the original tensor.
+
+        .. csv-table::
+            :header: "Argument", "Description", "Data type", "Valid range", "Required"
+
+            "scatter_split_dim", "Dimension to evenly slice input tensor along for each rank", "int", "0..3", "Yes"
+            "reduce_op", "Reduction math operation", "ReduceOpMath", "SUM", "No"
+            "num_links", "Number of ethernet links to allow the op to use to send data chip to chip for the operation. Default=1", "int", "1..max_num_links", "No"
+            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+        )doc");
 }
 }
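
For reference, the following is a minimal usage sketch of the new reduce_scatter binding; it is not part of the diff above. It assumes a ring of devices has already been opened as `devices` and follows the usual tt_lib tensor setup used in the unit tests; the shapes, dtype, and enum paths shown are illustrative rather than prescriptive.

# Minimal usage sketch for the new reduce_scatter binding (illustrative only).
# Assumes `devices` holds one opened device handle per ring position.
import torch
import tt_lib as ttl

scatter_split_dim = 3
ring_size = len(devices)
per_chip_shape = [1, 1, 32, 128 * ring_size]  # scatter dim must divide evenly by ring_size

torch_inputs = [torch.rand(per_chip_shape, dtype=torch.bfloat16) for _ in devices]
tt_inputs = [
    ttl.tensor.Tensor(t, ttl.tensor.DataType.BFLOAT16)
    .to(ttl.tensor.Layout.TILE)
    .to(device)
    for t, device in zip(torch_inputs, devices)
]

# Each rank gets back its reduced slice of the scatter dim (width 128 here).
tt_outputs = ttl.tensor.reduce_scatter(
    tt_inputs,
    scatter_split_dim,
    ttl.tensor.ReduceOpMath.SUM,
    num_links=1,
)

# Golden reference: elementwise sum across ranks, then rank i keeps slice i.
golden = torch.sum(torch.stack(torch_inputs), dim=0)
golden_slices = torch.chunk(golden, ring_size, dim=scatter_split_dim)
# After reading tt_outputs[i] back to host, it should match golden_slices[i].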