diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp
index 0e62d41c47b..2ba2048de9e 100644
--- a/tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp
+++ b/tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp
@@ -1,34 +1,18 @@
 // SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
 // SPDX-License-Identifier: Apache-2.0
-#include
-#include
-
 #include "common_tensor_test_utils.hpp"
 #include "gtest/gtest.h"
+#include "host_api.hpp"
 #include "tt_metal/common/logger.hpp"
+#include "tt_metal/common/work_split.hpp"
+#include "ttnn/async_runtime.hpp"
 #include "ttnn/tensor/layout/tensor_layout.hpp"
+#include "ttnn/tensor/tensor.hpp"
 #include "ttnn/tensor/types.hpp"
+#include "ttnn_test_fixtures.hpp"

 namespace {

-struct Inputs {
-    SimpleShape shape;
-    Size shard_shape;
-    Alignment shard_alignment;
-    std::vector<float> data;
-};
-
-struct Expected {
-    Size physical_shard_shape;
-    Size physical_size;
-    std::vector<float> physical_data;
-};
-
-struct ShardWithAlignmentParams {
-    Inputs inputs;
-    Expected expected;
-};
-
 // Helpers that potentially need to be moved into TensorLayout
 Size flatten_to_2D(const ttnn::SimpleShape& shape) {
     const int rank = static_cast<int>(shape.rank());
@@ -242,6 +226,27 @@ std::vector<float> convert_fp32_physical_data_to_logical_data(

 } // namespace

+namespace {
+struct ShardWithAlignmentInputs {
+    SimpleShape shape;
+    Size shard_shape;
+    Alignment shard_alignment;
+    std::vector<float> data;
+};
+
+struct ShardWithAlignmentExpected {
+    Size physical_shard_shape;
+    Size physical_size;
+    std::vector<float> physical_data;
+};
+
+struct ShardWithAlignmentParams {
+    ShardWithAlignmentInputs inputs;
+    ShardWithAlignmentExpected expected;
+};
+} // namespace
+// namespace
+
 class ShardWithAlignmentTests : public ::testing::TestWithParam<ShardWithAlignmentParams> {};

 TEST_P(ShardWithAlignmentTests, LogicalToPhysical) {
@@ -307,7 +312,7 @@ INSTANTIATE_TEST_SUITE_P(
         // TILE interleaved is equivalent to setting logical shard size to full height and width
         // NOTE: This can also be interpreted as height sharded where we don't break apart height
         ShardWithAlignmentParams{
-            Inputs{
+            ShardWithAlignmentInputs{
                 .shape = SimpleShape{1, 2, 15, 20},
                 .shard_shape = {15, 20},
                 .shard_alignment = Alignment({16, 16}),
@@ -343,7 +348,7 @@ INSTANTIATE_TEST_SUITE_P(
                          560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579,
                          580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599}
             },
-            Expected{
+            ShardWithAlignmentExpected{
                 .physical_shard_shape = {16, 32},
                 .physical_size = {32, 32},
                 .physical_data = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -384,7 +389,7 @@ INSTANTIATE_TEST_SUITE_P(
         // TILE height sharded is equivalent to setting logical shard width to full width
         // NOTE: This also supports logical shard height that breaks the height logically
         ShardWithAlignmentParams{
-            Inputs{
+            ShardWithAlignmentInputs{
                 .shape = SimpleShape{1, 1, 15, 15},
                 .shard_shape = {5, 15},
                 .shard_alignment = Alignment({16, 16}),
@@ -406,7 +411,7 @@ INSTANTIATE_TEST_SUITE_P(
                          195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
                          210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224}
             },
-            Expected{
+            ShardWithAlignmentExpected{
                 .physical_shard_shape = {16, 16},
                 .physical_size = {48, 16},
                 .physical_data = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0,
@@ -463,7 +468,7 @@ INSTANTIATE_TEST_SUITE_P(
         },
         // TILE width sharded is equivalent to setting logical shard height to full flattened tensor height
         ShardWithAlignmentParams{
-            Inputs{
+            ShardWithAlignmentInputs{
                 .shape = SimpleShape{1, 2, 5, 20},
                 .shard_shape = {10, 10},
                 .shard_alignment = Alignment({16, 16}),
@@ -478,7 +483,7 @@ INSTANTIATE_TEST_SUITE_P(
                          160, 161, 162, 163, 164, 165, 166, 167, 168, 169, /**/ 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
                          180, 181, 182, 183, 184, 185, 186, 187, 188, 189, /**/ 190, 191, 192, 193, 194, 195, 196, 197, 198, 199}
             },
-            Expected{
+            ShardWithAlignmentExpected{
                 .physical_shard_shape = {16, 16},
                 .physical_size = {16, 32},
                 .physical_data = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, /**/ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 0, 0, 0, 0, 0,
@@ -501,7 +506,7 @@ INSTANTIATE_TEST_SUITE_P(
         },
         // RM interleaved is equivalent to setting logical shard size to 1 by 1
         ShardWithAlignmentParams{
-            Inputs{
+            ShardWithAlignmentInputs{
                 .shape = SimpleShape{1, 2, 5, 1},
                 .shard_shape = {1, 1},
                 .shard_alignment = Alignment({1, 4}),
@@ -525,7 +530,7 @@ INSTANTIATE_TEST_SUITE_P(
                          9}
             },
-            Expected{
+            ShardWithAlignmentExpected{
                 .physical_shard_shape = {1, 4},
                 .physical_size = {10, 4},
                 .physical_data = { 0, 0, 0, 0,
@@ -551,7 +556,7 @@ INSTANTIATE_TEST_SUITE_P(
         },
         // RM height sharded with padding along width to align shards
         ShardWithAlignmentParams{
-            Inputs{
+            ShardWithAlignmentInputs{
                 .shape = SimpleShape{1, 2, 5, 1},
                 .shard_shape = {3, 1},
                 .shard_alignment = Alignment({1, 4}),
@@ -569,7 +574,7 @@ INSTANTIATE_TEST_SUITE_P(
                          9}
             },
-            Expected{
+            ShardWithAlignmentExpected{
                 .physical_shard_shape = {3, 4},
                 .physical_size = {12, 4},
                 .physical_data = { 0, 0, 0, 0,
@@ -591,7 +596,7 @@ INSTANTIATE_TEST_SUITE_P(
         },
         // RM width sharded with padding along width to align shards
         ShardWithAlignmentParams{
-            Inputs{
+            ShardWithAlignmentInputs{
                 .shape = SimpleShape{1, 2, 5, 10},
                 .shard_shape = {10, 3},
                 .shard_alignment = Alignment({1, 4}),
@@ -606,7 +611,7 @@ INSTANTIATE_TEST_SUITE_P(
                          80, 81, 82, /**/ 83, 84, 85, /**/ 86, 87, 88, /**/ 89,
                          90, 91, 92, /**/ 93, 94, 95, /**/ 96, 97, 98, /**/ 99}
             },
-            Expected{
+            ShardWithAlignmentExpected{
                 .physical_shard_shape = {10, 4},
                 .physical_size = {10, 16},
                 .physical_data = { 0, 1, 2, 0, /**/ 3, 4, 5, 0, /**/ 6, 7, 8, 0, /**/ 9, 0, 0, 0,
@@ -623,7 +628,7 @@ INSTANTIATE_TEST_SUITE_P(
         },
         // Arbitrary logical shard shape and alignment to stress test edges with padding
         ShardWithAlignmentParams{
-            Inputs{
+            ShardWithAlignmentInputs{
                 .shape = SimpleShape{1, 2, 10, 10},
                 .shard_shape = {3, 4},
                 .shard_alignment = Alignment({5, 7}),
@@ -654,7 +659,7 @@ INSTANTIATE_TEST_SUITE_P(
                          180, 181, 182, 183, /**/ 184, 185, 186, 187, /**/ 188, 189,
                          190, 191, 192, 193, /**/ 194, 195, 196, 197, /**/ 198, 199}
             },
-            Expected{
+            ShardWithAlignmentExpected{
                 .physical_shard_shape = {5, 7},
                 .physical_size = {35, 21},
                 .physical_data = { 0, 1, 2, 3, 0, 0, 0, /**/ 4, 5, 6, 7, 0, 0, 0, /**/ 8, 9, 0, 0, 0, 0, 0,
@@ -703,3 +708,263 @@ INSTANTIATE_TEST_SUITE_P(
     ) // Values
     // clang-format on
 );
+
+namespace {
+const CoreCoord grid_size{8, 7};
+
+struct CreateShardedTensorWithAlignmentInputs {
+    SimpleShape shape;
+    std::optional<Size> shard_shape;
+    Alignment shard_alignment;
+    DataType data_type;
+    PageConfig page_config;
+    MemoryConfig memory_config;
+};
+
+struct CreateShardedTensorWithAlignmentExpected {
+    std::optional<Size> logical_shard_shape;
+    Size physical_size;
+};
+
+struct CreateShardedTensorWithAlignmentParams {
+    CreateShardedTensorWithAlignmentInputs inputs;
+    CreateShardedTensorWithAlignmentExpected expected;
+};
+} // namespace
+
+class CreateShardedTensorWithAlignmentTests
+    : public ttnn::TTNNFixtureWithDevice,
+      public ::testing::WithParamInterface<CreateShardedTensorWithAlignmentParams> {};
+
+TEST_P(CreateShardedTensorWithAlignmentTests, Generic) {
+    const auto& params = GetParam();
+    const auto& input_shape = params.inputs.shape;
+
+    TensorLayout layout(
+        params.inputs.data_type,
+        params.inputs.page_config,
+        params.inputs.memory_config,
+        params.inputs.shard_shape,
+        params.inputs.shard_alignment);
+
+    test_utils::test_tensor_on_device(input_shape, layout, device_);
+
+    if (params.expected.logical_shard_shape.has_value()) {
+        ASSERT_TRUE(layout.get_shard_shape().has_value());
+        EXPECT_EQ(layout.get_shard_shape().value(), params.expected.logical_shard_shape.value());
+    }
+    EXPECT_EQ(layout.compute_physical_shape(input_shape), params.expected.physical_size);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    TensorShardingTests,
+    CreateShardedTensorWithAlignmentTests,
+    // clang-format off
+    ::testing::Values(
+        //////////////////////////////////////////////////////////////////////////////////////////
+        // EXAMPLE 1: TILE tensor with different representation for height sharded / interleaved
+        //////////////////////////////////////////////////////////////////////////////////////////
+        // Example 1a: Logical shard shape + alignment after
+        // - Along height: 48 * 56 / 48 is 56 shards; 56 * 64 = 3584
+        CreateShardedTensorWithAlignmentParams{
+            CreateShardedTensorWithAlignmentInputs{
+                .shape = SimpleShape{1, 48, 56, 32},
+                .shard_shape = Size{48, 32},
+                .shard_alignment = Alignment{},
+                .data_type = DataType::BFLOAT16,
+                .page_config = PageConfig(Layout::TILE),
+                .memory_config =
+                    MemoryConfig{
+                        .memory_layout = TensorMemoryLayout::HEIGHT_SHARDED,
+                        .buffer_type = BufferType::L1,
+                        .shard_spec = ShardSpec{
+                            num_cores_to_corerange_set(tt::div_up(48 * 56, 48), grid_size, /*row_wise=*/true),
+                            {64, 32},
+                            ShardOrientation::ROW_MAJOR,
+                            false}
+                    }
+            },
+            CreateShardedTensorWithAlignmentExpected{
+                .logical_shard_shape = Size{48, 32},
+                .physical_size = Size{3584, 32}
+            }
+        },
+        // Example 1b: No logical shard shape, so we treat shard shape in shard spec as logical shard shape
+        // - Along height: 48 * 56 / 64 is 42 shards; 42 * 64 = 2688
+        CreateShardedTensorWithAlignmentParams{
+            CreateShardedTensorWithAlignmentInputs{
+                .shape = SimpleShape{1, 48, 56, 32},
+                .shard_shape = std::nullopt,
+                .shard_alignment = Alignment{},
+                .data_type = DataType::BFLOAT16,
+                .page_config = PageConfig(Layout::TILE),
+                .memory_config =
+                    MemoryConfig{
+                        .memory_layout = TensorMemoryLayout::HEIGHT_SHARDED,
+                        .buffer_type = BufferType::L1,
+                        .shard_spec = ShardSpec{
+                            num_cores_to_corerange_set(tt::div_up(48 * 56, 64), grid_size, /*row_wise=*/true),
+                            {64, 32},
+                            ShardOrientation::ROW_MAJOR,
+                            false}
+                    }
+            },
+            CreateShardedTensorWithAlignmentExpected{
+                .logical_shard_shape = Size{64, 32},
+                .physical_size = Size{2688, 32}
+            }
+        },
+        // Example 1c: For interleaved, we treat entire height/width as "logical shard shape" for calculations
+        // 48 "shards" with 56 aligned to 32 for tile alignment; 48 * 64 = 3072
+        CreateShardedTensorWithAlignmentParams{
+            CreateShardedTensorWithAlignmentInputs{
+                .shape = SimpleShape{1, 48, 56, 32},
+                .shard_shape = std::nullopt,
+                .shard_alignment = Alignment{},
+                .data_type = DataType::BFLOAT16,
+                .page_config = PageConfig(Layout::TILE),
+                .memory_config =
+                    MemoryConfig{
+                        .memory_layout = TensorMemoryLayout::INTERLEAVED,
+                        .buffer_type = BufferType::DRAM,
+                        .shard_spec = std::nullopt
+                    }
+            },
+            CreateShardedTensorWithAlignmentExpected{
+                .logical_shard_shape = std::nullopt,
+                .physical_size = Size{3072, 32}
+            }
+        },
+        /////////////////////////////////////////////////////////////////////////////////////////////////////
+        // EXAMPLE 2: ROW_MAJOR tensor with different representation for width sharded / interleaved
+        // - In this example, (shard) width alignment is 4 because UINT8 = 1 byte and we pack with uint32_t
+        /////////////////////////////////////////////////////////////////////////////////////////////////////
+        // Example 2a: Logical shard shape + alignment after
+        // - Along width: 5 / 1 is 5 shards; 5 * 4 = 20
+        CreateShardedTensorWithAlignmentParams{
+            CreateShardedTensorWithAlignmentInputs{
+                .shape = SimpleShape{1, 2, 10, 5},
+                .shard_shape = Size{20, 1},
+                .shard_alignment = Alignment{},
+                .data_type = DataType::UINT8,
+                .page_config = PageConfig(Layout::ROW_MAJOR),
+                .memory_config =
+                    MemoryConfig{
+                        .memory_layout = TensorMemoryLayout::WIDTH_SHARDED,
+                        .buffer_type = BufferType::L1,
+                        .shard_spec = ShardSpec{
+                            num_cores_to_corerange_set(tt::div_up(5, 1), grid_size, /*row_wise=*/true),
+                            {20, 4},
+                            ShardOrientation::ROW_MAJOR,
+                            false}
+                    }
+            },
+            CreateShardedTensorWithAlignmentExpected{
+                .logical_shard_shape = Size{20, 1},
+                .physical_size = Size{20, 20}
+            }
+        },
+        // Example 2b: No logical shard shape, so we treat shard shape in shard spec as logical shard shape
+        // - Along width: 5 / 4 is 2 shards; 2 * 4 = 8
+        CreateShardedTensorWithAlignmentParams{
+            CreateShardedTensorWithAlignmentInputs{
+                .shape = SimpleShape{1, 2, 10, 5},
+                .shard_shape = std::nullopt,
+                .shard_alignment = Alignment{},
+                .data_type = DataType::UINT8,
+                .page_config = PageConfig(Layout::ROW_MAJOR),
+                .memory_config =
+                    MemoryConfig{
+                        .memory_layout = TensorMemoryLayout::WIDTH_SHARDED,
+                        .buffer_type = BufferType::L1,
+                        .shard_spec = ShardSpec{
+                            num_cores_to_corerange_set(tt::div_up(5, 4), grid_size, /*row_wise=*/true),
+                            {20, 4},
+                            ShardOrientation::ROW_MAJOR,
+                            false}
+                    }
+            },
+            CreateShardedTensorWithAlignmentExpected{
+                .logical_shard_shape = Size{20, 4},
+                .physical_size = Size{20, 8}
+            }
+        },
+        // Example 2c: For interleaved, we treat entire height/width as "logical shard shape" for calculations
+        // 20 "shards" with 5 aligned to 4 for uint32_t alignment
+        CreateShardedTensorWithAlignmentParams{
+            CreateShardedTensorWithAlignmentInputs{
+                .shape = SimpleShape{1, 2, 10, 5},
+                .shard_shape = std::nullopt,
+                .shard_alignment = Alignment{},
+                .data_type = DataType::UINT8,
+                .page_config = PageConfig(Layout::ROW_MAJOR),
+                .memory_config =
+                    MemoryConfig{
+                        .memory_layout = TensorMemoryLayout::INTERLEAVED,
+                        .buffer_type = BufferType::L1,
+                        .shard_spec = std::nullopt
+                    }
+            },
+            CreateShardedTensorWithAlignmentExpected{
+                .logical_shard_shape = std::nullopt,
+                .physical_size = Size{20, 8}
+            }
+        },
+        ////////////////////////////////////////////////////////////////////
+        // EXAMPLE 3: Interesting cases with custom (legal) shard alignment
+        ////////////////////////////////////////////////////////////////////
+        // Example 3: TILE block sharded tensor with shard alignment of 3 * 16 along the width
+        // - Along height: 8 * 36 / 48 is 6 shards; 6 * 64 = 384
+        // - Along width: 32 / 10 is 4 shards; 4 * custom alignment 48 = 192 (48 % 16 == 0, so it is legal)
+        CreateShardedTensorWithAlignmentParams{
+            CreateShardedTensorWithAlignmentInputs{
+                .shape = SimpleShape{1, 8, 36, 32},
+                .shard_shape = Size{48, 10},
+                .shard_alignment = Alignment{32, 48},
+                .data_type = DataType::BFLOAT8_B,
+                .page_config = PageConfig(Layout::TILE, Tile({32, 16})),
+                .memory_config =
+                    MemoryConfig{
+                        .memory_layout = TensorMemoryLayout::BLOCK_SHARDED,
+                        .buffer_type = BufferType::L1,
+                        .shard_spec = ShardSpec{
+                            num_cores_to_corerange_set(tt::div_up(8 * 36, 48) * tt::div_up(32, 10), grid_size, /*row_wise=*/true),
+                            {64, 48},
+                            ShardOrientation::ROW_MAJOR,
+                            false}
+                    }
+            },
+            CreateShardedTensorWithAlignmentExpected{
+                .logical_shard_shape = Size{48, 10},
+                .physical_size = Size{384, 192}
+            }
+        },
+        // Example 3: ROW_MAJOR block sharded tensor with 2 and 1 extra rows and col per shard, respectively
+        // - Along height: 2 * 10 / 5 is 4 shards; 4 * custom alignment 7 = 28 (no restriction on height alignment for ROW_MAJOR)
+        // - Along width: 5 / 2 is 3 shards; 3 * custom alignment 3 = 9 (alignment on width can be arbitrary because UINT32 is already 4-byte aligned)
+        CreateShardedTensorWithAlignmentParams{
+            CreateShardedTensorWithAlignmentInputs{
+                .shape = SimpleShape{1, 2, 10, 5},
+                .shard_shape = Size{5, 2},
+                .shard_alignment = Alignment{7, 3},
+                .data_type = DataType::UINT32,
+                .page_config = PageConfig(Layout::ROW_MAJOR),
+                .memory_config =
+                    MemoryConfig{
+                        .memory_layout = TensorMemoryLayout::BLOCK_SHARDED,
+                        .buffer_type = BufferType::L1,
+                        .shard_spec = ShardSpec{
+                            num_cores_to_corerange_set(tt::div_up(2 * 10, 5) * tt::div_up(5, 2), grid_size, /*row_wise=*/true),
+                            {7, 3},
+                            ShardOrientation::ROW_MAJOR,
+                            false}
+                    }
+            },
+            CreateShardedTensorWithAlignmentExpected{
+                .logical_shard_shape = Size{5, 2},
+                .physical_size = Size{28, 9}
+            }
+        }
+    ) // Values
+    // clang-format on
+);
diff --git a/ttnn/cpp/ttnn/tensor/layout/alignment.hpp b/ttnn/cpp/ttnn/tensor/layout/alignment.hpp
index 23f745b3a95..5f4c8f57922 100644
--- a/ttnn/cpp/ttnn/tensor/layout/alignment.hpp
+++ b/ttnn/cpp/ttnn/tensor/layout/alignment.hpp
@@ -37,4 +37,4 @@ class Alignment final : protected ShapeBase {

 std::ostream &operator<<(std::ostream &os, const tt::tt_metal::Alignment &shape);

-} // namespace ttnn
+} // namespace tt::tt_metal
diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp
index b8af239c7e1..11df2c5954f 100644
--- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp
+++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp
@@ -42,7 +42,31 @@ Alignment legacyShapeToAlignment(const ttnn::Shape& shape) {
 } // namespace CMAKE_UNIQUE_NAMESPACE

 TensorLayout::TensorLayout(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config)
-    : TensorLayout(dtype, page_config, memory_config, {}) {
+    : TensorLayout(dtype, page_config, memory_config, std::nullopt, {}) {
+}
+
+TensorLayout::TensorLayout(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config, const std::optional<Size>& shard_shape, const Alignment& shard_alignment)
+    : dtype_(dtype),
+      page_config_(page_config),
+      memory_config_(memory_config),
+      alignment_(shard_alignment) {
+
+    // TODO: If we remove alignment, we should lower this into validate_alignment
+    TT_FATAL(shard_alignment.size() <= 2, "Shard alignment {} exceeds maximum rank of 2", shard_alignment);
+
+    if (shard_shape.has_value()) {
+        TT_FATAL(memory_config.is_sharded(), "Shard shape should only be used if memory config is sharded");
+        this->shard_shape_ = shard_shape.value();
+    } else if (memory_config.shard_spec.has_value()) {
+        // Treat shard shape from shard spec as physical shard shape and use here as logical shard shape for this->shard_shape_
+        // When computing ShardSpecBuffer, double check that the computed physical shard shape is the same as shard shape from shard spec
+        this->shard_shape_ = Size(memory_config.shard_spec.value().shape);
+    } else {
+        this->shard_shape_ = std::nullopt;
+    }
+
+    initialize_alignment();
+    validate_alignment();
 }

 // Private
@@ -50,7 +74,8 @@ TensorLayout::TensorLayout(DataType dtype, const PageConfig& page_config, const
     : dtype_(dtype),
       page_config_(page_config),
       memory_config_(memory_config),
-      alignment_(alignment) {
+      alignment_(alignment),
+      shard_shape_(std::nullopt) {

     initialize_alignment();
     validate_alignment();
@@ -61,7 +86,7 @@ TensorLayout TensorLayout::fromLegacyPaddedShape(DataType dtype, const PageConfi
 }

 void TensorLayout::initialize_alignment() {
-    if(!alignment_.empty()) {
+    if (!alignment_.empty()) {
         return;
     }

@@ -81,6 +106,10 @@ std::optional<ShardSpecBuffer> TensorLayout::compute_shard_spec_buffer(const ttn
     TT_FATAL(memory_config_.shard_spec.has_value(), "MemoryConfig must have Shard Spec specified for sharded memory layout");
     auto& shard_spec = memory_config_.shard_spec.value();
+
+    const auto& physical_shard_shape = compute_physical_shard_shape(shape);
+    TT_FATAL(physical_shard_shape == (Size) shard_spec.shape, "Shard shape in shard spec {} must be same as physical shard shape {}", shard_spec.shape, physical_shard_shape);
+
     const Size physical_size = compute_physical_shape(shape);
     const Size page_shape = compute_page_shape(physical_size);

@@ -120,29 +149,82 @@ size_t TensorLayout::compute_page_size_bytes(const Size& page_size) const {
     return page_config_.get_page_size_bytes(page_size, dtype_);
 }

+Size TensorLayout::compute_logical_shard_shape(const ttnn::SimpleShape& shape) const {
+    if (shard_shape_.has_value()) {
+        return shard_shape_.value();
+    }
+    const int rank = static_cast<int>(shape.rank());
+
+    size_t shard_width = 1;
+    size_t shard_height = 1;
+
+    // Iterate dims in reverse order
+    // Even tensor of rank 0 or 1
+    for (int i = -1; i >= -2; --i) {
+        if (i >= -rank) {
+            auto& dim = i == -1 ? shard_width : shard_height;
+            dim *= shape[i];
+        }
+    }
+
+    return Size{shard_height, shard_width};
+}
+
+Size TensorLayout::compute_physical_shard_shape(const ttnn::SimpleShape& shape) const {
+    const auto& logical_shard_shape = compute_logical_shard_shape(shape);
+    return Size{round_up(logical_shard_shape.height(), alignment_[-2]), round_up(logical_shard_shape.width(), alignment_[-1])};
+}
+
 Size TensorLayout::compute_physical_shape(const ttnn::SimpleShape& shape) const {
     const int rank = static_cast<int>(shape.rank());
     const int alignment_rank = static_cast<int>(alignment_.size());
-    const int max_rank = std::max(rank, alignment_rank);
+
     size_t width = 1;
     size_t height = 1;

-    // Iterate dims in reverse order and ensure alignment
-    // Even tensor of rank 0 or 1 must be aligned (to Tile / Page / Shard)
-    for (int i = -1; i >= -max_rank; --i) {
-        auto& dim = i == -1 ? width : height;
-        if(i >= -rank) {
-            dim *= shape[i];
+    if (alignment_rank > 2) {
+        const int max_rank = std::max(rank, alignment_rank);
+
+        // Iterate dims in reverse order and ensure alignment
+        // Even tensor of rank 0 or 1 must be aligned (to Tile / Page / Shard)
+        for (int i = -1; i >= -max_rank; --i) {
+            auto& dim = i == -1 ? width : height;
+            if(i >= -rank) {
+                dim *= shape[i];
+            }
+
+            // Align the current dimension if alignment is available
+            if (i >= -alignment_rank) {
+                dim = CMAKE_UNIQUE_NAMESPACE::round_up(dim, alignment_[i]);
+            }
         }

-        // Align the current dimension if alignment is available
-        if (i >= -alignment_rank) {
-            dim = CMAKE_UNIQUE_NAMESPACE::round_up(dim, alignment_[i]);
+        Size size{height, width};
+        return size;
+    } else {
+        // Iterate dims in reverse order
+        // Even tensor of rank 0 or 1
+        for (int i = -1; i >= -rank; --i) {
+            auto& dim = i == -1 ? width : height;
+            dim *= shape[i];
         }
-    }

-    Size size{height, width};
-    return size;
+        auto get_physical_size = [](auto original_size, auto logical_shard_size, auto physical_shard_size) {
+            if (logical_shard_size == 0) {
+                return (uint32_t) 0;
+            }
+            auto num_shards = tt::div_up(original_size, logical_shard_size);
+            return (uint32_t) physical_shard_size * num_shards;
+        };
+
+        const auto& logical_shard_shape = compute_logical_shard_shape(shape);
+        const auto& physical_shard_shape = compute_physical_shard_shape(shape);
+        auto physical_height = get_physical_size(height, logical_shard_shape.height(), physical_shard_shape.height());
+        auto physical_width = get_physical_size(width, logical_shard_shape.width(), physical_shard_shape.width());
+
+        Size size{physical_height, physical_width};
+        return size;
+    }
 }

 Size TensorLayout::compute_page_shape(const Size& physical_size) const {
diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp
index 487534fc125..c0a4736f5b2 100644
--- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp
+++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp
@@ -24,6 +24,8 @@ class TensorLayout {
 public:
     TensorLayout(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config);

+    TensorLayout(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config, const std::optional<Size>& shard_shape, const Alignment& shard_alignment);
+
     // static method makes it easy to find and remove all of its usages in the codebase - thats why it is not a constructor
     [[deprecated("Use of Legacy Padded Shape is deprecated")]]
     static TensorLayout fromLegacyPaddedShape(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config, const ttnn::Shape& legacy_shape);
@@ -33,6 +35,7 @@ class TensorLayout {
     DataType get_data_type() const { return dtype_; }
     const MemoryConfig& get_memory_config() const { return memory_config_; }
     const Alignment& get_alignment() const { return alignment_; }
+    std::optional<Size> get_shard_shape() const { return shard_shape_; }

     Strides compute_strides(const ttnn::SimpleShape& shape) const;

@@ -65,6 +68,10 @@ class TensorLayout {
     PageConfig page_config_;
     MemoryConfig memory_config_;
     Alignment alignment_;
+
+    std::optional<Size> shard_shape_;
+    Size compute_logical_shard_shape(const ttnn::SimpleShape& shape) const;
+    Size compute_physical_shard_shape(const ttnn::SimpleShape& shape) const;
 };

 } // tt::tt_metal
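For readers checking the arithmetic in the new CreateShardedTensorWithAlignmentTests comments (e.g. "48 * 56 / 48 is 56 shards; 56 * 64 = 3584"), the sharded branch of compute_physical_shape can be reproduced with a small standalone sketch. This is illustrative only: the names Size2D, physical_shard_shape and physical_size below are hypothetical and not part of the ttnn API, and div_up/round_up are assumed to behave like tt::div_up/tt::round_up.

// Standalone sketch of the sharded physical-size arithmetic (not the ttnn API).
#include <cassert>
#include <cstddef>

namespace sketch {

constexpr std::size_t div_up(std::size_t a, std::size_t b) { return (a + b - 1) / b; }
constexpr std::size_t round_up(std::size_t a, std::size_t b) { return div_up(a, b) * b; }

struct Size2D {
    std::size_t height;
    std::size_t width;
};

// Logical shard shape padded up to the alignment in each dimension.
constexpr Size2D physical_shard_shape(Size2D logical_shard, Size2D alignment) {
    return {round_up(logical_shard.height, alignment.height),
            round_up(logical_shard.width, alignment.width)};
}

// Physical tensor size: number of shards per dimension times the padded shard shape.
constexpr Size2D physical_size(Size2D flattened, Size2D logical_shard, Size2D alignment) {
    const Size2D padded = physical_shard_shape(logical_shard, alignment);
    return {div_up(flattened.height, logical_shard.height) * padded.height,
            div_up(flattened.width, logical_shard.width) * padded.width};
}

}  // namespace sketch

int main() {
    using sketch::physical_size;

    // Example 1a: {1, 48, 56, 32} flattens to 2688x32; logical shard 48x32 with tile
    // alignment 32x32 pads to 64x32 -> 56 shards * 64 = 3584 rows.
    const auto ex1a = physical_size({48 * 56, 32}, {48, 32}, {32, 32});
    assert(ex1a.height == 3584 && ex1a.width == 32);

    // Example 2a: {1, 2, 10, 5} ROW_MAJOR UINT8 flattens to 20x5; logical shard 20x1 with
    // width alignment 4 -> 5 shards * 4 = 20 columns.
    const auto ex2a = physical_size({20, 5}, {20, 1}, {1, 4});
    assert(ex2a.height == 20 && ex2a.width == 20);

    // Example 3: {1, 8, 36, 32} flattens to 288x32; logical shard 48x10 with custom
    // alignment 32x48 -> 6 * 64 = 384 by 4 * 48 = 192.
    const auto ex3 = physical_size({8 * 36, 32}, {48, 10}, {32, 48});
    assert(ex3.height == 384 && ex3.width == 192);
    return 0;
}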
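The ShardWithAlignmentTests data above also fixes the expected logical-to-physical data layout: each logical shard is copied into a padded slot of the physical buffer and the alignment padding stays zero. A minimal sketch of one mapping that reproduces the padded rows shown in the RM width-sharded test case follows; logical_to_physical and its parameters are illustrative names, not the functions under test.

// Standalone sketch of the logical-to-physical padding (not the ttnn API).
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<int> logical_to_physical(const std::vector<int>& logical,
                                     std::size_t height, std::size_t width,        // flattened logical size
                                     std::size_t shard_h, std::size_t shard_w,     // logical shard shape
                                     std::size_t phys_shard_h, std::size_t phys_shard_w) {
    const std::size_t shards_h = (height + shard_h - 1) / shard_h;
    const std::size_t shards_w = (width + shard_w - 1) / shard_w;
    const std::size_t phys_w = shards_w * phys_shard_w;

    std::vector<int> physical(shards_h * phys_shard_h * phys_w, 0);  // padding stays zero
    for (std::size_t r = 0; r < height; ++r) {
        for (std::size_t c = 0; c < width; ++c) {
            // Which shard the element falls into, and its offset inside that shard.
            const std::size_t pr = (r / shard_h) * phys_shard_h + r % shard_h;
            const std::size_t pc = (c / shard_w) * phys_shard_w + c % shard_w;
            physical[pr * phys_w + pc] = logical[r * width + c];
        }
    }
    return physical;
}

int main() {
    // RM width-sharded case from the test data: 10x10 tensor, logical shard 10x3,
    // physical shard 10x4 -> physical size 10x16.
    std::vector<int> logical(100);
    for (int i = 0; i < 100; ++i) logical[i] = i;
    const auto physical = logical_to_physical(logical, 10, 10, 10, 3, 10, 4);
    assert(physical.size() == 10 * 16);
    // First physical row matches "0, 1, 2, 0, /**/ 3, 4, 5, 0, /**/ 6, 7, 8, 0, /**/ 9, 0, 0, 0".
    const int expected_row0[16] = {0, 1, 2, 0, 3, 4, 5, 0, 6, 7, 8, 0, 9, 0, 0, 0};
    for (int c = 0; c < 16; ++c) assert(physical[c] == expected_row0[c]);
    return 0;
}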