#5562: add initial reduce scatter implementation (experimental)
TODO: add a beautiful description covering
- What is reduce scatter (a reference sketch follows below)
- What are the known issues/limitations (in other words, what is currently supported)
- New features (EDM)
- New infrastructure created for this op (especially infra portable to
allgather and other CCL ops)
- Description of t-streaming
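
For reference, a minimal host-side sketch of the reduce-scatter semantics this op targets, assuming a sum reduction and equally sized shards: every rank contributes a full-size input, the inputs are reduced elementwise, and rank i keeps only the i-th of the N equal shards of the result. The helper name and the host-only setting are illustrative, not this commit's API.

#include <cstddef>
#include <vector>

// Host-only reference: rank i receives the i-th shard of the elementwise sum
// of all ranks' inputs (illustrative sketch, not the device implementation).
std::vector<std::vector<float>> reduce_scatter_reference(
    std::vector<std::vector<float>> const& inputs) {
    std::size_t num_ranks = inputs.size();
    std::size_t shard_size = inputs.at(0).size() / num_ranks;
    std::vector<std::vector<float>> outputs(num_ranks, std::vector<float>(shard_size, 0.0f));
    for (std::size_t rank = 0; rank < num_ranks; ++rank) {
        for (std::size_t i = 0; i < shard_size; ++i) {
            for (auto const& input : inputs) {
                outputs[rank][i] += input[rank * shard_size + i];
            }
        }
    }
    return outputs;
}

With four ranks holding 16-element inputs, for example, each rank ends up with a 4-element shard whose entries are the sums of the corresponding positions across all four inputs.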
SeanNijjar committed May 23, 2024
1 parent 354a834 commit d4f80be
Showing 24 changed files with 3,369 additions and 557 deletions.
6 changes: 5 additions & 1 deletion tests/scripts/common.py
@@ -14,7 +14,7 @@

from loguru import logger

from models.utility_functions import is_wormhole_b0
from models.utility_functions import is_wormhole_b0, is_grayskull


class TestSuiteType(Enum):
@@ -33,6 +33,10 @@ def void_for_whb0(x):
return (not is_wormhole_b0()) and x or None


def void_for_gs(x):
return (not is_grayskull()) and x or None


def filter_empty(fn):
@wraps(fn)
def __filter_empty():
12 changes: 8 additions & 4 deletions tests/scripts/run_tt_eager.py
@@ -21,18 +21,22 @@
get_git_home_dir_str,
filter_empty,
void_for_whb0,
void_for_gs,
)
from tests.scripts.cmdline_args import (
get_tt_metal_arguments_from_cmdline_args,
get_cmdline_args,
)

TT_EAGER_COMMON_TEST_ENTRIES = (
TestEntry("tt_eager/tests/ops/ccl/test_all_gather_utils", "ops/ccl/test_all_gather_utils"),
TestEntry(
"tt_eager/tests/ops/ccl/test_all_gather_sharded_indexing_helpers",
"ops/ccl/test_all_gather_sharded_indexing_helpers",
void_for_gs(TestEntry("tt_eager/tests/ops/ccl/test_all_gather_utils", "ops/ccl/test_all_gather_utils")),
void_for_gs(
TestEntry(
"tt_eager/tests/ops/ccl/test_all_gather_sharded_indexing_helpers",
"ops/ccl/test_all_gather_sharded_indexing_helpers",
)
),
void_for_gs(TestEntry("tt_eager/tests/ops/ccl/test_ccl_helpers", "ops/ccl/test_ccl_helpers")),
TestEntry("tt_eager/tests/ops/test_eltwise_binary_op", "ops/test_eltwise_binary_op"),
TestEntry("tt_eager/tests/ops/test_bcast_op", "ops/test_bcast_op"),
TestEntry("tt_eager/tests/ops/test_reduce_op", "ops/test_reduce_op"),
1 change: 1 addition & 0 deletions tests/tt_eager/CMakeLists.txt
@@ -5,6 +5,7 @@ target_link_libraries(test_eager_common_libs INTERFACE tt_eager test_common_libs
set(TT_EAGER_TESTS_OPS
ops/ccl/test_all_gather_utils
ops/ccl/test_all_gather_sharded_indexing_helpers
ops/ccl/test_ccl_helpers
ops/test_average_pool
ops/test_eltwise_binary_op
ops/test_eltwise_unary_op
1 change: 1 addition & 0 deletions tests/tt_eager/module.mk
@@ -2,6 +2,7 @@
TT_EAGER_TESTS += \
tests/tt_eager/ops/ccl/test_all_gather_utils \
tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers \
tests/tt_eager/ops/ccl/test_ccl_helpers \
tests/tt_eager/ops/test_average_pool \
tests/tt_eager/ops/test_eltwise_binary_op \
tests/tt_eager/ops/test_eltwise_unary_op \
291 changes: 291 additions & 0 deletions tests/tt_eager/ops/ccl/test_ccl_helpers.cpp
@@ -0,0 +1,291 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "device/tt_xy_pair.h"
#include "gtest/gtest.h"
#include "tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp"
#include "tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp"
#include "tt_eager/tt_dnn/op_library/ccl/ccl_common.hpp"


TEST(CclHelpers, CreateEriscDatamoverBuilder_Chan4_PageSize2048_RRBufferSharingMode) {
std::size_t num_channels = 4;
uint32_t page_size = 2048;
ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN;
ccl::EriscDataMoverTerminationMode termination_mode = ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED;

auto edm_builder = create_erisc_datamover_builder(num_channels, page_size, buffer_sharing_mode, termination_mode);
std::vector<uint32_t> worker_semaphore_addresses = {
0x1000,
0x1010,
0x1020,
0x1030,
};
std::vector<uint32_t> message_counts = {
256,
512,
24,
1
};
std::vector<std::vector<ccl::WorkerXY>> const& worker_coords = {
{ccl::WorkerXY{1,1}, ccl::WorkerXY{2,1}},
{ccl::WorkerXY{3,1}},
{ccl::WorkerXY{4,1}, ccl::WorkerXY{5,1}, ccl::WorkerXY{6,1}},
{ccl::WorkerXY{1,2}},
};
std::vector<bool> is_sender_channel {
true,
false,
true,
false
};

std::vector<ccl::EriscDatamoverBuilder::ChannelBufferInterface> channel_buffer_interfaces;
channel_buffer_interfaces.reserve(num_channels);
for (std::size_t i = 0; i < num_channels; i++) {
ccl::EriscDatamoverBuilder::ChannelBufferInterface const& channel_buffer_interface =
(is_sender_channel[i])
? edm_builder.add_sender_channel(worker_semaphore_addresses[i], message_counts[i], worker_coords[i])
: edm_builder.add_receiver_channel(worker_semaphore_addresses[i], message_counts[i], worker_coords[i]);
channel_buffer_interfaces.push_back(channel_buffer_interface);
ASSERT_TRUE(channel_buffer_interface.eth_buffer_l1_address > 0);
ASSERT_TRUE(channel_buffer_interface.eth_semaphore_l1_address > 0);
}

auto const& active_channels = edm_builder.get_active_channels();
ASSERT_EQ(active_channels.size(), num_channels);
for (std::size_t i = 0; i < active_channels.size(); ++i) {
ASSERT_EQ(active_channels[i].channel, i);
ASSERT_EQ(active_channels[i].is_sender, is_sender_channel.at(i));
ASSERT_EQ(active_channels[i].worker_coords, worker_coords.at(i));
ASSERT_TRUE(active_channels[i].worker_semaphore_address == worker_semaphore_addresses.at(i));
ASSERT_TRUE(active_channels[i].num_eth_messages_to_forward > message_counts.at(i));
}
}

TEST(CclHelpers, EriscDatamoverConfig_GetEdmHandshakeAddress_GT_0) {
for (std::size_t i = 0; i < 8; i++) {
ASSERT_TRUE(ccl::EriscDatamoverConfig::get_edm_handshake_address() > 0);
}
}
TEST(CclHelpers, EriscDatamoverConfig_GetSemaphoresBaseAddress_GT_0) {
for (std::size_t i = 0; i < 8; i++) {
ASSERT_TRUE(ccl::EriscDatamoverConfig::get_semaphores_base_address(i) >= (ccl::EriscDatamoverConfig::get_edm_handshake_address() + ccl::EriscDatamoverConfig::handshake_location_size + ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size));
}
}

TEST(CclHelpers, EriscDatamoverConfig_GetBuffersBaseAddress_GT_0) {
for (std::size_t i = 0; i < 8; i++) {
ASSERT_TRUE(ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= (ccl::EriscDatamoverConfig::get_edm_handshake_address() + ccl::EriscDatamoverConfig::handshake_location_size + ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size));
}
}

TEST(CclHelpers, EriscDatamoverConfig_ComputeBufferSize_GT_0) {
for (std::size_t i = 0; i < 8; i++) {
ASSERT_TRUE(ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= (ccl::EriscDatamoverConfig::get_edm_handshake_address() + ccl::EriscDatamoverConfig::handshake_location_size + ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size));
}
}
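
// Taken together, the address checks above encode the expected L1 ordering at
// the EDM core: both the semaphore base address and the buffer base address
// must sit at or above the handshake word plus the receiver's first-level-ack
// source word, i.e. that handshake/ack prefix is reserved ahead of them.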


/////////////////////////////////////////
// TEST AdvanceSliceRowMajor
/////////////////////////////////////////
// Arguments below: inner offset {x, y}, inner shape {x, y}, outer shape {x, y}, number of active slices
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) {
const auto expected = tt::tt_metal::ccl::coord_t(1, 0);
auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 1);
ASSERT_EQ(result.x, expected.x);
ASSERT_EQ(result.y, expected.y);
}
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) {
const auto expected = tt::tt_metal::ccl::coord_t(0, 1);
auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 0}, {1, 1}, {2, 2}, 1);
ASSERT_EQ(result.x, expected.x);
ASSERT_EQ(result.y, expected.y);
}
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) {
const auto expected = tt::tt_metal::ccl::coord_t(1, 1);
auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 1}, {1, 1}, {2, 2}, 1);
ASSERT_EQ(result.x, expected.x);
ASSERT_EQ(result.y, expected.y);
}
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) {
const auto expected = tt::tt_metal::ccl::coord_t(0, 1);
auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 2);
ASSERT_EQ(result.x, expected.x);
ASSERT_EQ(result.y, expected.y);
}
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) {
const auto expected = tt::tt_metal::ccl::coord_t(1, 1);
auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 0}, {1, 1}, {2, 2}, 2);
ASSERT_EQ(result.x, expected.x);
ASSERT_EQ(result.y, expected.y);
}

// Test that we successfully go out of bounds on the last iteration
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) {
auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 1}, {1, 1}, {2, 2}, 2);
ASSERT_TRUE(result.x >= 2 || result.y >= 2);
}
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) {
auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 1}, {1, 1}, {2, 2}, 2);
ASSERT_TRUE(result.x >= 2 || result.y >= 2);
}

TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_3) {
const auto expected = tt::tt_metal::ccl::coord_t(1, 1);
auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 3);
ASSERT_EQ(result.x, expected.x);
ASSERT_EQ(result.y, expected.y);
}
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_3) {
const auto expected = tt::tt_metal::ccl::coord_t(1, 1);
const auto outer_shape = tt::tt_metal::ccl::coord_t(2, 2);
const auto inner_offset = tt::tt_metal::ccl::coord_t(1, 1);
const auto inner_shape = tt::tt_metal::ccl::coord_t(1, 1);
const uint32_t num_parallel_workers = 3;
auto const& result = tt::tt_metal::ccl::advance_slice_row_major(inner_offset, inner_shape, outer_shape, num_parallel_workers);
ASSERT_TRUE(result.x >= outer_shape.x || result.y >= outer_shape.y);
}
TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_24_0__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) {
const auto expected = tt::tt_metal::ccl::coord_t(24, 2);
const auto outer_shape = tt::tt_metal::ccl::coord_t(32, 4);
const auto inner_offset = tt::tt_metal::ccl::coord_t(24, 0);
const auto inner_shape = tt::tt_metal::ccl::coord_t(24, 1);
const uint32_t num_parallel_workers = 4;
auto const& result = tt::tt_metal::ccl::advance_slice_row_major(inner_offset, inner_shape, outer_shape, num_parallel_workers);
ASSERT_EQ(result.x, expected.x);
ASSERT_EQ(result.y, expected.y);
}
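
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this commit's sources): the expectations
// above are consistent with a stepping rule that advances the slice through
// the outer shape in row-major order, taking one inner-shape stride per
// active slice and wrapping to x = 0, y += inner_shape.y whenever x runs off
// the end of the row. A standalone version of that rule, using a local pair
// type so the snippet stands on its own:
namespace ccl_sketch {
struct pair2d { uint32_t x; uint32_t y; };

inline pair2d advance_slice_row_major_reference(
    pair2d offset, pair2d inner_shape, pair2d outer_shape, uint32_t num_active_slices) {
    for (uint32_t step = 0; step < num_active_slices; ++step) {
        offset.x += inner_shape.x;
        if (offset.x >= outer_shape.x) {
            offset.x = 0;
            offset.y += inner_shape.y;
        }
    }
    return offset;  // May land out of bounds; callers treat that as "done".
}
}  // namespace ccl_sketch
// For example, advancing from {24, 0} by four {24, 1} strides inside a
// {32, 4} outer shape steps through (0,1), (24,1), (0,2) and lands on (24,2),
// matching the last test above.
// ---------------------------------------------------------------------------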

/////////////////////////////////////////
// Test InterleavedRingReduceScatterTensorSlicer
/////////////////////////////////////////
TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_AllWorkersSameRow) {
auto worker_slice_shapes = std::vector<tt_xy_pair>(4, {2, 2});
tt_xy_pair tensor_slice_shape = {8, 4};
auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets(worker_slice_shapes, tensor_slice_shape);
ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0,0));
ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2,0));
ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4,0));
ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(6,0));
}
TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_1WorkerWrapToNextRowAligned) {
auto worker_slice_shapes = std::vector<tt_xy_pair>(4, {2, 2});
tt_xy_pair tensor_slice_shape = {6, 4};
auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets(worker_slice_shapes, tensor_slice_shape);
ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0,0));
ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2,0));
ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4,0));
ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(0,2));
}
TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_1WorkerWrapToNextRowMisaligned) {
{
auto worker_slice_shapes = std::vector<tt_xy_pair>(4, {2, 2});
tt_xy_pair tensor_slice_shape = {5, 4};
auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets(worker_slice_shapes, tensor_slice_shape);
ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0,0));
ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2,0));
ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4,0));
ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(0,2));
}
}
TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_MultipleWorkersWrapToNextRowAligned) {
auto worker_slice_shapes = std::vector<tt_xy_pair>(8, {2, 2});
tt_xy_pair tensor_slice_shape = {10, 4};
auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets(worker_slice_shapes, tensor_slice_shape);
ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0,0));
ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2,0));
ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4,0));
ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(6,0));
ASSERT_EQ(worker_slice_offsets.at(4), tt_xy_pair(8,0));
ASSERT_EQ(worker_slice_offsets.at(5), tt_xy_pair(0,2));
ASSERT_EQ(worker_slice_offsets.at(6), tt_xy_pair(2,2));
ASSERT_EQ(worker_slice_offsets.at(7), tt_xy_pair(4,2));
}

TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_MultipleWorkersWrapToNextRowMisaligned) {
auto worker_slice_shapes = std::vector<tt_xy_pair>(8, {2, 2});
tt_xy_pair tensor_slice_shape = {9, 4};
auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets(worker_slice_shapes, tensor_slice_shape);
ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0,0));
ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2,0));
ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(4,0));
ASSERT_EQ(worker_slice_offsets.at(3), tt_xy_pair(6,0));
ASSERT_EQ(worker_slice_offsets.at(4), tt_xy_pair(8,0));
ASSERT_EQ(worker_slice_offsets.at(5), tt_xy_pair(0,2));
ASSERT_EQ(worker_slice_offsets.at(6), tt_xy_pair(2,2));
ASSERT_EQ(worker_slice_offsets.at(7), tt_xy_pair(4,2));
}

TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_NMinus1WorkersWrapToNextRowAligned) {
auto worker_slice_shapes = std::vector<tt_xy_pair>(3, {4, 4});
tt_xy_pair tensor_slice_shape = {4, 12};
auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets(worker_slice_shapes, tensor_slice_shape);
ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0,0));
ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(0,4));
ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(0,8));
}

TEST(Ccl_InterleavedRingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_NMinus1WorkersWrapToNextRowMisaligned) {
auto worker_slice_shapes = std::vector<tt_xy_pair>(3, {4, 3});
tt_xy_pair tensor_slice_shape = {3, 12};
auto const& worker_slice_offsets = ccl::InterleavedRingReduceScatterTensorSlicer::compute_worker_slice_offsets(worker_slice_shapes, tensor_slice_shape);
ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0,0));
ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(0,3));
ASSERT_EQ(worker_slice_offsets.at(2), tt_xy_pair(0,6));
}
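
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this commit's sources): the offsets
// asserted above follow the same row-major placement with wrap-around.
// Workers are packed left to right; once the next slice would start at or
// past the end of the row, placement wraps to x = 0 on the row one
// worker-slice height further down:
namespace ccl_sketch {
inline std::vector<tt_xy_pair> compute_worker_slice_offsets_reference(
    std::vector<tt_xy_pair> const& worker_slice_shapes, tt_xy_pair tensor_slice_shape) {
    std::vector<tt_xy_pair> offsets;
    offsets.reserve(worker_slice_shapes.size());
    uint32_t x = 0;
    uint32_t y = 0;
    for (auto const& shape : worker_slice_shapes) {
        offsets.push_back(tt_xy_pair(x, y));
        x += shape.x;
        if (x >= tensor_slice_shape.x) {  // Next slice would start past the row end.
            x = 0;
            y += shape.y;
        }
    }
    return offsets;
}
}  // namespace ccl_sketch
// With eight {2, 2} worker slices in a {9, 4} tensor slice this yields
// (0,0) (2,0) (4,0) (6,0) (8,0) (0,2) (2,2) (4,2), matching the misaligned
// multi-worker case directly above.
// ---------------------------------------------------------------------------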


TEST(Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, InnerOffset_0_0__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) {
auto worker_slice = ccl::InterleavedTensorWorkerSlice(
tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result
tt_xy_pair(32, 4),
tt_xy_pair(24, 1),
tt_xy_pair(0, 0)
);
uint32_t num_workers = 4;
auto num_iterations = worker_slice.compute_num_worker_slice_iterations(num_workers);
auto expected = 2;
ASSERT_EQ(num_iterations, expected);
}

TEST(Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, InnerOffset_24_0__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) {
auto worker_slice = ccl::InterleavedTensorWorkerSlice(
tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result
tt_xy_pair(32, 4),
tt_xy_pair(24, 1),
tt_xy_pair(24, 0)
);
uint32_t num_workers = 4;
auto num_iterations = worker_slice.compute_num_worker_slice_iterations(num_workers);
auto expected = 2;
ASSERT_EQ(num_iterations, expected);
}

TEST(Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, InnerOffset_0_1__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) {
auto worker_slice = ccl::InterleavedTensorWorkerSlice(
tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result
tt_xy_pair(32, 4),
tt_xy_pair(24, 1),
tt_xy_pair(0, 1)
);
uint32_t num_workers = 4;
auto num_iterations = worker_slice.compute_num_worker_slice_iterations(num_workers);
auto expected = 2;
ASSERT_EQ(num_iterations, expected);
}

TEST(Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, InnerOffset_24_1__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) {
auto worker_slice = ccl::InterleavedTensorWorkerSlice(
tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result
tt_xy_pair(32, 4),
tt_xy_pair(24, 1),
tt_xy_pair(24, 1)
);
uint32_t num_workers = 4;
auto num_iterations = worker_slice.compute_num_worker_slice_iterations(num_workers);
auto expected = 2;
ASSERT_EQ(num_iterations, expected);
}
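
// Illustrative note: under the row-major stepping rule sketched earlier, a
// {32, 4} tensor slice holds eight {24, 1} worker-slice positions
// ((0,0), (24,0), (0,1), ..., (24,3)). With four active workers each worker
// visits every fourth position, i.e. two of the eight, which is why all four
// cases above expect two iterations.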