From f7354bfa31cabf59df6b599eaaef3a6d42408148 Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Mon, 4 Nov 2024 18:30:28 +0000 Subject: [PATCH 1/2] #13127: Add support for new logical sharding + alignment in TensorLayout and tensor creation - Add ShardMode enum to specify shard shape in shard spec as either physical or logical * ShardMode::PHYSICAL: This is current behaviour that we will deprecate! ** It is less expressive than using shard shape as logical (ie. must be tile aligned for TILE layout etc...) ** It fundamentally operates on padded shape and is confusing and incompatible with logical shape * ShardMode::LOGICAL: Shard shape cuts 2D logical shape and each shard is aligned after ** Without alignment restrictions, you can cut 2D logical shape more arbitrarily ** Existing sharding can be switched over to this entirely (just need codeowners to help out and flip...) * Default everywhere will be ShardMode::PHYSICAL with a warning message - Switch tests/ttnn/unit_tests/operations/test_paged_update_cache.py to use logical shard shape as an example * Introduce tensor.logical_volume() (as opposed to tensor.volume() which returns physical volume based on padded shape) * TODO: Rename volume() -> physical_volume() and logical_volume() -> volume() - Add new c++ tests to test tensor creation with logical shard shape + alignment * IMPORTANT: Need to update host data manipulation to be aware of new logical sharding for use from python! To support these changes, some changes to TensorLayout: - Make private TensorLayout constructor with alignment public with these changes: * legacyShapeToAlignment will try to return 2D alignment if possible (ie. 
only padding on height/width) ** Goal is to transition alignment to be 2D only if we remove poor use cases of padding on non-height/width dims * legacyShapeToAlignment is only expected to be used for ShardMode::PHYSICAL and uses default alignment for sharded tensors ** Before interleaved or sharded will just use padded shape for alignment ** One exception is for row major sharded tensors where we use shard width if shape is padded; otherwise, we only take shard width for BLOCK/WIDTH sharded cases and original physical shape for HEIGHT sharded * legacyShapeToAlignment (and alignment in general) will work iff there is only padding on height and/or width ** IMPORTANT: This means we are expecting tensors with arbitrary padding along non-height/width to be interleaved only! - If ShardMode::LOGICAL: * In TensorLayout::compute_shard_spec_buffer, calculate physical shard shape based on shard shape + alignment * In TensorLayout::compute_physical_shape, calculate physical shape based on number of logical shards - Clean up handling of sharded tensors and error messages in ttnn/cpp/ttnn/tensor/layout/page_config.cpp - Add Size constructor for std::array --- .../tensor/test_sharding_with_alignment.cpp | 320 ++++++++++++++++-- .../operations/test_paged_update_cache.py | 28 +- tt_metal/impl/buffers/buffer.cpp | 3 +- tt_metal/impl/buffers/buffer.hpp | 16 +- tt_metal/impl/buffers/buffer_constants.hpp | 5 + ttnn/cpp/pybind11/pytensor.cpp | 9 + ttnn/cpp/pybind11/tensor.cpp | 9 +- ttnn/cpp/ttnn/tensor/layout/alignment.hpp | 2 +- ttnn/cpp/ttnn/tensor/layout/page_config.cpp | 49 ++- ttnn/cpp/ttnn/tensor/layout/page_config.hpp | 6 +- ttnn/cpp/ttnn/tensor/layout/size.cpp | 3 + ttnn/cpp/ttnn/tensor/layout/size.hpp | 1 + ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp | 170 ++++++++-- ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp | 7 +- ttnn/ttnn/__init__.py | 1 + ttnn/ttnn/types.py | 1 + 16 files changed, 513 insertions(+), 117 deletions(-) diff --git 
a/tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp index 0e62d41c47b..145e46035c9 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp @@ -1,34 +1,18 @@ // SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. // SPDX-License-Identifier: Apache-2.0 -#include -#include - #include "common_tensor_test_utils.hpp" #include "gtest/gtest.h" +#include "host_api.hpp" #include "tt_metal/common/logger.hpp" +#include "tt_metal/common/work_split.hpp" +#include "ttnn/async_runtime.hpp" #include "ttnn/tensor/layout/tensor_layout.hpp" +#include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" +#include "ttnn_test_fixtures.hpp" namespace { -struct Inputs { - SimpleShape shape; - Size shard_shape; - Alignment shard_alignment; - std::vector data; -}; - -struct Expected { - Size physical_shard_shape; - Size physical_size; - std::vector physical_data; -}; - -struct ShardWithAlignmentParams { - Inputs inputs; - Expected expected; -}; - // Helpers that potentially need to be moved into TensorLayout Size flatten_to_2D(const ttnn::SimpleShape& shape) { const int rank = static_cast(shape.rank()); @@ -242,6 +226,27 @@ std::vector convert_fp32_physical_data_to_logical_data( } // namespace +namespace { +struct ShardWithAlignmentInputs { + SimpleShape shape; + Size shard_shape; + Alignment shard_alignment; + std::vector data; +}; + +struct ShardWithAlignmentExpected { + Size physical_shard_shape; + Size physical_size; + std::vector physical_data; +}; + +struct ShardWithAlignmentParams { + ShardWithAlignmentInputs inputs; + ShardWithAlignmentExpected expected; +}; +} // namespace +// namespace + class ShardWithAlignmentTests : public ::testing::TestWithParam {}; TEST_P(ShardWithAlignmentTests, LogicalToPhysical) { @@ -307,7 +312,7 @@ INSTANTIATE_TEST_SUITE_P( // TILE interleaved is equivalent to 
setting logical shard size to full height and width // NOTE: This can also be interpreted as height sharded where we don't break apart height ShardWithAlignmentParams{ - Inputs{ + ShardWithAlignmentInputs{ .shape = SimpleShape{1, 2, 15, 20}, .shard_shape = {15, 20}, .shard_alignment = Alignment({16, 16}), @@ -343,7 +348,7 @@ INSTANTIATE_TEST_SUITE_P( 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599} }, - Expected{ + ShardWithAlignmentExpected{ .physical_shard_shape = {16, 32}, .physical_size = {32, 32}, .physical_data = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -384,7 +389,7 @@ INSTANTIATE_TEST_SUITE_P( // TILE height sharded is equivalent to setting logical shard width to full width // NOTE: This also supports logical shard height that breaks the height logically ShardWithAlignmentParams{ - Inputs{ + ShardWithAlignmentInputs{ .shape = SimpleShape{1, 1, 15, 15}, .shard_shape = {5, 15}, .shard_alignment = Alignment({16, 16}), @@ -406,7 +411,7 @@ INSTANTIATE_TEST_SUITE_P( 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224} }, - Expected{ + ShardWithAlignmentExpected{ .physical_shard_shape = {16, 16}, .physical_size = {48, 16}, .physical_data = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, @@ -463,7 +468,7 @@ INSTANTIATE_TEST_SUITE_P( }, // TILE width sharded is equivalent to setting logical shard height to full flattened tensor height ShardWithAlignmentParams{ - Inputs{ + ShardWithAlignmentInputs{ .shape = SimpleShape{1, 2, 5, 20}, .shard_shape = {10, 10}, .shard_alignment = Alignment({16, 16}), @@ -478,7 +483,7 @@ INSTANTIATE_TEST_SUITE_P( 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, /**/ 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 
180, 181, 182, 183, 184, 185, 186, 187, 188, 189, /**/ 190, 191, 192, 193, 194, 195, 196, 197, 198, 199} }, - Expected{ + ShardWithAlignmentExpected{ .physical_shard_shape = {16, 16}, .physical_size = {16, 32}, .physical_data = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, /**/ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 0, 0, 0, 0, 0, @@ -501,7 +506,7 @@ INSTANTIATE_TEST_SUITE_P( }, // RM interleaved is equivalent to setting logical shard size to 1 by 1 ShardWithAlignmentParams{ - Inputs{ + ShardWithAlignmentInputs{ .shape = SimpleShape{1, 2, 5, 1}, .shard_shape = {1, 1}, .shard_alignment = Alignment({1, 4}), @@ -525,7 +530,7 @@ INSTANTIATE_TEST_SUITE_P( 9} }, - Expected{ + ShardWithAlignmentExpected{ .physical_shard_shape = {1, 4}, .physical_size = {10, 4}, .physical_data = { 0, 0, 0, 0, @@ -551,7 +556,7 @@ INSTANTIATE_TEST_SUITE_P( }, // RM height sharded with padding along width to align shards ShardWithAlignmentParams{ - Inputs{ + ShardWithAlignmentInputs{ .shape = SimpleShape{1, 2, 5, 1}, .shard_shape = {3, 1}, .shard_alignment = Alignment({1, 4}), @@ -569,7 +574,7 @@ INSTANTIATE_TEST_SUITE_P( 9} }, - Expected{ + ShardWithAlignmentExpected{ .physical_shard_shape = {3, 4}, .physical_size = {12, 4}, .physical_data = { 0, 0, 0, 0, @@ -591,7 +596,7 @@ INSTANTIATE_TEST_SUITE_P( }, // RM width sharded with padding along width to align shards ShardWithAlignmentParams{ - Inputs{ + ShardWithAlignmentInputs{ .shape = SimpleShape{1, 2, 5, 10}, .shard_shape = {10, 3}, .shard_alignment = Alignment({1, 4}), @@ -606,7 +611,7 @@ INSTANTIATE_TEST_SUITE_P( 80, 81, 82, /**/ 83, 84, 85, /**/ 86, 87, 88, /**/ 89, 90, 91, 92, /**/ 93, 94, 95, /**/ 96, 97, 98, /**/ 99} }, - Expected{ + ShardWithAlignmentExpected{ .physical_shard_shape = {10, 4}, .physical_size = {10, 16}, .physical_data = { 0, 1, 2, 0, /**/ 3, 4, 5, 0, /**/ 6, 7, 8, 0, /**/ 9, 0, 0, 0, @@ -623,7 +628,7 @@ INSTANTIATE_TEST_SUITE_P( }, // Arbitrary logical shard shape and alignment to stress test edges with 
padding ShardWithAlignmentParams{ - Inputs{ + ShardWithAlignmentInputs{ .shape = SimpleShape{1, 2, 10, 10}, .shard_shape = {3, 4}, .shard_alignment = Alignment({5, 7}), @@ -654,7 +659,7 @@ INSTANTIATE_TEST_SUITE_P( 180, 181, 182, 183, /**/ 184, 185, 186, 187, /**/ 188, 189, 190, 191, 192, 193, /**/ 194, 195, 196, 197, /**/ 198, 199} }, - Expected{ + ShardWithAlignmentExpected{ .physical_shard_shape = {5, 7}, .physical_size = {35, 21}, .physical_data = { 0, 1, 2, 3, 0, 0, 0, /**/ 4, 5, 6, 7, 0, 0, 0, /**/ 8, 9, 0, 0, 0, 0, 0, @@ -703,3 +708,248 @@ INSTANTIATE_TEST_SUITE_P( ) // Values // clang-format on ); + +namespace { +const CoreCoord grid_size{8, 7}; + +struct CreateShardedTensorWithAlignmentInputs { + SimpleShape shape; + Alignment shard_alignment; + DataType data_type; + PageConfig page_config; + MemoryConfig memory_config; +}; + +struct CreateShardedTensorWithAlignmentExpected { + Size physical_size; +}; + +struct CreateShardedTensorWithAlignmentParams { + CreateShardedTensorWithAlignmentInputs inputs; + CreateShardedTensorWithAlignmentExpected expected; +}; +} // namespace + +class CreateShardedTensorWithAlignmentTests + : public ttnn::TTNNFixtureWithDevice, + public ::testing::WithParamInterface {}; + +TEST_P(CreateShardedTensorWithAlignmentTests, Generic) { + const auto& params = GetParam(); + const auto& input_shape = params.inputs.shape; + + TensorLayout layout( + params.inputs.data_type, + params.inputs.page_config, + params.inputs.memory_config, + params.inputs.shard_alignment); + + test_utils::test_tensor_on_device(input_shape, layout, device_); + + EXPECT_EQ(layout.compute_physical_shape(input_shape), params.expected.physical_size); +} + +INSTANTIATE_TEST_SUITE_P( + TensorShardingTests, + CreateShardedTensorWithAlignmentTests, + // clang-format off + ::testing::Values( + ////////////////////////////////////////////////////////////////////////////////////////// + // EXAMPLE 1: TILE tensor with different representation for height sharded / interleaved 
+ ////////////////////////////////////////////////////////////////////////////////////////// + // Example 1a: Logical shard shape + alignment after + // - Along height: 48 * 56 / 48 is 56 shards; 56 * 64 = 3584 + CreateShardedTensorWithAlignmentParams{ + CreateShardedTensorWithAlignmentInputs{ + .shape = SimpleShape{1, 48, 56, 32}, + .shard_alignment = Alignment{}, + .data_type = DataType::BFLOAT16, + .page_config = PageConfig(Layout::TILE), + .memory_config = + MemoryConfig{ + .memory_layout = TensorMemoryLayout::HEIGHT_SHARDED, + .buffer_type = BufferType::L1, + .shard_spec = ShardSpec{ + num_cores_to_corerangeset(tt::div_up(48 * 56, 48), grid_size, /*row_wise=*/true), + {48, 32}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL} + } + }, + CreateShardedTensorWithAlignmentExpected{ + .physical_size = Size{3584, 32} + } + }, + // Example 1b: Logical shard shape that is already aligned + // NOTE: If ShardMode::PHYSICAL, it expects height 56 to be padded up to 64 + // - Along height: 48 * 56 / 64 is 42 shards; 42 * 64 = 2688 + CreateShardedTensorWithAlignmentParams{ + CreateShardedTensorWithAlignmentInputs{ + .shape = SimpleShape{1, 48, 56, 32}, + .shard_alignment = Alignment{}, + .data_type = DataType::BFLOAT16, + .page_config = PageConfig(Layout::TILE), + .memory_config = + MemoryConfig{ + .memory_layout = TensorMemoryLayout::HEIGHT_SHARDED, + .buffer_type = BufferType::L1, + .shard_spec = ShardSpec{ + num_cores_to_corerangeset(tt::div_up(48 * 56, 64), grid_size, /*row_wise=*/true), + {64, 32}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL} + } + }, + CreateShardedTensorWithAlignmentExpected{ + .physical_size = Size{2688, 32} + } + }, + // Example 1c: For interleaved, we treat entire height/width as "logical shard shape" for calculations + // 48 "shards" with 56 aligned to 32 for tile alignment; 48 * 64 = 3072 + CreateShardedTensorWithAlignmentParams{ + CreateShardedTensorWithAlignmentInputs{ + .shape = SimpleShape{1, 48, 56, 32}, + 
.shard_alignment = Alignment{}, + .data_type = DataType::BFLOAT16, + .page_config = PageConfig(Layout::TILE), + .memory_config = + MemoryConfig{ + .memory_layout = TensorMemoryLayout::INTERLEAVED, + .buffer_type = BufferType::DRAM, + .shard_spec = std::nullopt + } + }, + CreateShardedTensorWithAlignmentExpected{ + .physical_size = Size{3072, 32} + } + }, + ///////////////////////////////////////////////////////////////////////////////////////////////////// + // EXAMPLE 2: ROW_MAJOR tensor with different representation for width sharded / interleaved + // - In this example, (shard) width alignment is 4 because UINT8 = 1 bytes and we pack with uint32_t + ///////////////////////////////////////////////////////////////////////////////////////////////////// + // Example 2a: Logical shard shape + alignment after + // - Along width: 5 / 1 is 5 shards; 5 * 4 = 20 + CreateShardedTensorWithAlignmentParams{ + CreateShardedTensorWithAlignmentInputs{ + .shape = SimpleShape{1, 2, 10, 5}, + .shard_alignment = Alignment{}, + .data_type = DataType::UINT8, + .page_config = PageConfig(Layout::ROW_MAJOR), + .memory_config = + MemoryConfig{ + .memory_layout = TensorMemoryLayout::WIDTH_SHARDED, + .buffer_type = BufferType::L1, + .shard_spec = ShardSpec{ + num_cores_to_corerangeset(tt::div_up(5, 1), grid_size, /*row_wise=*/true), + {20, 1}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL} + } + }, + CreateShardedTensorWithAlignmentExpected{ + .physical_size = Size{20, 20} + } + }, + // Example 2b: Logical shard shape that is already aligned + // NOTE: ShardMode::PHYSICAL is equivalent in this case + // - Along width: 5 / 4 is 2 shards; 2 * 4 = 8 + CreateShardedTensorWithAlignmentParams{ + CreateShardedTensorWithAlignmentInputs{ + .shape = SimpleShape{1, 2, 10, 5}, + .shard_alignment = Alignment{}, + .data_type = DataType::UINT8, + .page_config = PageConfig(Layout::ROW_MAJOR), + .memory_config = + MemoryConfig{ + .memory_layout = TensorMemoryLayout::WIDTH_SHARDED, + 
.buffer_type = BufferType::L1, + .shard_spec = ShardSpec{ + num_cores_to_corerangeset(tt::div_up(5, 4), grid_size, /*row_wise=*/true), + {20, 4}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL} + } + }, + CreateShardedTensorWithAlignmentExpected{ + .physical_size = Size{20, 8} + } + }, + // Example 2c: For interleaved, we treat entire height/width as "logical shard shape" for calculations + // 20 "shards" with 5 aligned to 4 for uint32_t alignment + CreateShardedTensorWithAlignmentParams{ + CreateShardedTensorWithAlignmentInputs{ + .shape = SimpleShape{1, 2, 10, 5}, + .shard_alignment = Alignment{}, + .data_type = DataType::UINT8, + .page_config = PageConfig(Layout::ROW_MAJOR), + .memory_config = + MemoryConfig{ + .memory_layout = TensorMemoryLayout::INTERLEAVED, + .buffer_type = BufferType::L1, + .shard_spec = std::nullopt + } + }, + CreateShardedTensorWithAlignmentExpected{ + .physical_size = Size{20, 8} + } + }, + //////////////////////////////////////////////////////////////////// + // EXAMPLE 3: Interesting cases with custom (legal) shard alignment + //////////////////////////////////////////////////////////////////// + // Example 3a: TILE block sharded tensor with shard alignment of 3 * 16 along the width + // - Along height: 8 * 36 / 48 is 6 shards; 6 * 64 = 384 + // - Along width: 32 / 10 is 4 shards; 4 * custom alignment 48 = 192 (48 % 16 == 0, so it is legal) + CreateShardedTensorWithAlignmentParams{ + CreateShardedTensorWithAlignmentInputs{ + .shape = SimpleShape{1, 8, 36, 32}, + .shard_alignment = Alignment{32, 48}, + .data_type = DataType::BFLOAT8_B, + .page_config = PageConfig(Layout::TILE, Tile({32, 16})), + .memory_config = + MemoryConfig{ + .memory_layout = TensorMemoryLayout::BLOCK_SHARDED, + .buffer_type = BufferType::L1, + .shard_spec = ShardSpec{ + num_cores_to_corerangeset(tt::div_up(8 * 36, 48) * tt::div_up(32, 10), grid_size, /*row_wise=*/true), + {48, 10}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL} + } + }, 
+ CreateShardedTensorWithAlignmentExpected{ + .physical_size = Size{384, 192} + } + }, + // Example 3b: ROW_MAJOR block sharded tensor with 2 and 1 extra rows and col per shard, respectively + // - Along height: 2 * 10 / 5 is 4 shards; 4 * custom alignment 7 = 28 (no restriction on height alignment for ROW_MAJOR) + // - Along width: 5 / 2 is 3 shards; 3 * custom alignment 3 = 9 (alignment on width can be arbitrary because UINT32 is already 4-byte aligned) + CreateShardedTensorWithAlignmentParams{ + CreateShardedTensorWithAlignmentInputs{ + .shape = SimpleShape{1, 2, 10, 5}, + .shard_alignment = Alignment{7, 3}, + .data_type = DataType::UINT32, + .page_config = PageConfig(Layout::ROW_MAJOR), + .memory_config = + MemoryConfig{ + .memory_layout = TensorMemoryLayout::BLOCK_SHARDED, + .buffer_type = BufferType::L1, + .shard_spec = ShardSpec{ + num_cores_to_corerangeset(tt::div_up(2 * 10, 5) * tt::div_up(5, 2), grid_size, /*row_wise=*/true), + {5, 2}, + ShardOrientation::ROW_MAJOR, + false, + ShardMode::LOGICAL} + } + }, + CreateShardedTensorWithAlignmentExpected{ + .physical_size = Size{28, 9} + } + } + ) // Values + // clang-format on +); diff --git a/tests/ttnn/unit_tests/operations/test_paged_update_cache.py b/tests/ttnn/unit_tests/operations/test_paged_update_cache.py index 90f4f9f2798..bfeccc29cfc 100644 --- a/tests/ttnn/unit_tests/operations/test_paged_update_cache.py +++ b/tests/ttnn/unit_tests/operations/test_paged_update_cache.py @@ -41,11 +41,12 @@ def run_test_update_cache_decode( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, + xt.logical_volume() // xt.shape[-1] // num_cores, xt.shape.with_tile_padding()[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, + ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec) xt = xt.to(device, input_mem_config) @@ -151,11 +152,12 @@ def test_update_cache_decode( 
input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, - xt.shape.with_tile_padding()[-1], + xt.logical_volume() // xt.shape[-1] // num_cores, + xt.shape[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, + ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec @@ -234,11 +236,12 @@ def test_update_cache_decode_program_cache( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, - xt.shape.with_tile_padding()[-1], + xt.logical_volume() // xt.shape[-1] // num_cores, + xt.shape[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, + ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec @@ -276,11 +279,12 @@ def run_test_tensor_index_update_cache_decode( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, - xt.shape.with_tile_padding()[-1], + xt.logical_volume() // xt.shape[-1] // num_cores, + xt.shape[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, + ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec) xt = xt.to(device, input_mem_config) @@ -414,11 +418,12 @@ def run_test_paged_update_cache_decode( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, - xt.shape.with_tile_padding()[-1], + xt.logical_volume() // xt.shape[-1] // num_cores, + xt.shape[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, + ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec) xt = xt.to(device, input_mem_config) @@ -543,11 +548,12 @@ def test_paged_update_cache_decode_program_caching( input_shard_spec = ttnn.ShardSpec( 
shard_grid, [ - xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, - xt.shape.with_tile_padding()[-1], + xt.logical_volume() // xt.shape[-1] // num_cores, + xt.shape[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, + ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec diff --git a/tt_metal/impl/buffers/buffer.cpp b/tt_metal/impl/buffers/buffer.cpp index fdfa57a79a3..62ba46bb665 100644 --- a/tt_metal/impl/buffers/buffer.cpp +++ b/tt_metal/impl/buffers/buffer.cpp @@ -526,6 +526,7 @@ tt_metal::ShardSpec from_json_t::operator()(const nlohmann: from_json(json_object.at("grid")), from_json>(json_object.at("shape")), from_json(json_object.at("orientation")), - from_json(json_object.at("halo"))}; + from_json(json_object.at("halo")), + from_json(json_object.at("mode"))}; } } diff --git a/tt_metal/impl/buffers/buffer.hpp b/tt_metal/impl/buffers/buffer.hpp index 31c1e3b73d2..2275ee1d47d 100644 --- a/tt_metal/impl/buffers/buffer.hpp +++ b/tt_metal/impl/buffers/buffer.hpp @@ -49,12 +49,20 @@ struct ShardSpec { ShardOrientation orientation = ShardOrientation::ROW_MAJOR; bool halo = false; + ShardMode mode = ShardMode::PHYSICAL; + ShardSpec( const CoreRangeSet &core_sets_, const std::array &shard_shape_, const ShardOrientation &shard_orientation_ = ShardOrientation::ROW_MAJOR, - const bool &halo_ = false) : - grid(core_sets_), shape(shard_shape_), orientation(shard_orientation_), halo(halo_) { + const bool &halo_ = false, + const ShardMode &shard_mode_ = ShardMode::PHYSICAL) : + grid(core_sets_), shape(shard_shape_), orientation(shard_orientation_), halo(halo_), mode(shard_mode_) { + if (shard_mode_ == ShardMode::PHYSICAL) { + tt::log_warning( + tt::LogOp, + "ShardMode::PHYSICAL will be deprecated soon! 
Please switch to equivalent representation with ShardMode::LOGICAL"); + } } const uint32_t num_cores() const { return this->grid.num_cores(); } @@ -63,9 +71,9 @@ struct ShardSpec { bool operator==(const ShardSpec& other) const; bool operator!=(const ShardSpec& other) const; - static constexpr auto attribute_names = std::forward_as_tuple("grid", "shape", "orientation", "halo"); + static constexpr auto attribute_names = std::forward_as_tuple("grid", "shape", "orientation", "halo", "mode"); constexpr auto attribute_values() const { - return std::forward_as_tuple(this->grid, this->shape, this->orientation, this->halo); + return std::forward_as_tuple(this->grid, this->shape, this->orientation, this->halo, this->mode); } }; diff --git a/tt_metal/impl/buffers/buffer_constants.hpp b/tt_metal/impl/buffers/buffer_constants.hpp index 115bd9d8517..7d1c70c6457 100644 --- a/tt_metal/impl/buffers/buffer_constants.hpp +++ b/tt_metal/impl/buffers/buffer_constants.hpp @@ -22,6 +22,11 @@ enum class ShardOrientation { COL_MAJOR, }; +enum class ShardMode { + PHYSICAL, // TODO: Deprecate this option to treat shard shape as physical + LOGICAL, +}; + enum class BufferType { DRAM, L1, diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 487869e15ef..9d8f8c6099d 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -1512,6 +1512,15 @@ void pytensor_module(py::module &m_tensor) { volume = tt_tensor.volume() + )doc") + .def( + "logical_volume", [](const Tensor &self) { return self.get_logical_volume(); }, R"doc( + Get the logical volume of the tensor. + + .. 
code-block:: python + + logical_volume = tt_tensor.logical_volume() + )doc") .def( "storage_type", [](const Tensor &self) { return self.storage_type(); }, R"doc( diff --git a/ttnn/cpp/pybind11/tensor.cpp b/ttnn/cpp/pybind11/tensor.cpp index 8624e002513..467a83bc030 100644 --- a/ttnn/cpp/pybind11/tensor.cpp +++ b/ttnn/cpp/pybind11/tensor.cpp @@ -72,6 +72,7 @@ void tensor_mem_config_module_types(py::module& m_tensor) { export_enum(m_tensor); export_enum(m_tensor); export_enum(m_tensor); + export_enum(m_tensor); py::enum_(m_tensor, "BufferType") .value("DRAM", BufferType::DRAM) @@ -266,10 +267,16 @@ void tensor_mem_config_module(py::module& m_tensor) { .def(py::init<>([](const CoreRangeSet& core_sets, const std::array& shard_shape, const ShardOrientation& shard_orientation, - const bool& halo) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo); })) + const bool& halo) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo, ShardMode::PHYSICAL); })) + .def(py::init<>([](const CoreRangeSet& core_sets, + const std::array& shard_shape, + const ShardOrientation& shard_orientation, + const bool& halo, + const ShardMode& shard_mode) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo, shard_mode); })) .def_readwrite("shape", &ShardSpec::shape, "Shape of shard.") .def_readwrite("grid", &ShardSpec::grid, "Grid to layout shards.") .def_readwrite("orientation", &ShardSpec::orientation, "Orientation of cores to read shards") + .def_readwrite("mode", &ShardSpec::mode, "Treat shard shape as physical (default) or logical") .def("num_cores", &ShardSpec::num_cores, "Number of cores") .def(py::self == py::self) .def(py::self != py::self); diff --git a/ttnn/cpp/ttnn/tensor/layout/alignment.hpp b/ttnn/cpp/ttnn/tensor/layout/alignment.hpp index 23f745b3a95..5f4c8f57922 100644 --- a/ttnn/cpp/ttnn/tensor/layout/alignment.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/alignment.hpp @@ -37,4 +37,4 @@ class Alignment final : protected ShapeBase { std::ostream
&operator<<(std::ostream &os, const tt::tt_metal::Alignment &shape); -} // namespace ttnn +} // namespace tt::tt_metal diff --git a/ttnn/cpp/ttnn/tensor/layout/page_config.cpp b/ttnn/cpp/ttnn/tensor/layout/page_config.cpp index b66a4ca8af1..67042d76e63 100644 --- a/ttnn/cpp/ttnn/tensor/layout/page_config.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/page_config.cpp @@ -52,8 +52,8 @@ void PageConfig::validate_alignment(const Alignment& alignment, DataType dtype, std::visit([&](const auto& config) constexpr { config.validate_alignment(alignment, dtype, memory_config); }, config_); } -Size PageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const { - return std::visit([&](const auto& config) constexpr { return config.get_page_shape(physical_size, dtype, memory_config); }, config_); +Size PageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional& physical_shard_size) const { + return std::visit([&](const auto& config) constexpr { return config.get_page_shape(physical_size, dtype, memory_config, physical_shard_size); }, config_); } size_t PageConfig::get_page_size_bytes(const Size& page_shape, DataType dtype) const { @@ -92,7 +92,7 @@ void TilePageConfig::validate_alignment(const Alignment& alignment, DataType dty "Wrong custom Tensor Layout alignment {}. 
For Tile layout second innermost dimension should be multiple of tile height {}.", alignment, tile_.get_height()); } -Size TilePageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const { +Size TilePageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional&) const { if(memory_config.memory_layout == TensorMemoryLayout::SINGLE_BANK && physical_size.width() != 0 && physical_size.height() != 0) { return physical_size; } @@ -116,20 +116,17 @@ Alignment RowMajorPageConfig::create_default_alignment(DataType dtype, const Mem const auto element_size = CMAKE_UNIQUE_NAMESPACE::element_size_bytes(dtype); auto width_alignment = sizeof(uint32_t) / element_size; - if(memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { - const auto& shard_spec = memory_config.shard_spec.value(); - const auto& shard_shape = shard_spec.shape; - const auto shard_width = shard_shape[1]; + if (memory_config.shard_spec.has_value() && memory_config.shard_spec.value().mode == ShardMode::PHYSICAL && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { + const auto& physical_shard_shape = memory_config.shard_spec.value().shape; + const auto physical_shard_width = physical_shard_shape[1]; TT_FATAL( - (shard_width % width_alignment) == 0, - "Invalid sharding configuration: For Row Major layout with element size of {} bytes, the innermost dimension must align to {} bytes. 
" - "Buffer data is packed as uint32_t (4 bytes), so the provided shard shape {} does not meet alignment requirements.", - element_size, width_alignment, shard_shape - ); + (physical_shard_width % width_alignment) == 0, + "For Row Major layout and shard mode {}, the width of shard shape {} is treated as physical shard width and must be aligned to {} since we pack buffer data as uint32_t.", + memory_config.shard_spec.value().mode, physical_shard_shape, width_alignment + ); - width_alignment = shard_width; + width_alignment = physical_shard_width; } - return Alignment({width_alignment});} } @@ -140,21 +137,20 @@ void RowMajorPageConfig::validate_alignment(const Alignment& alignment, DataType const uint32_t page_alignment = sizeof(uint32_t) / element_size; TT_FATAL((width_alignment % page_alignment) == 0, - "Incorrect alignment configuration for Row Major layout: alignment {} requires innermost dimension alignment of {} bytes due to uint32_t (4-byte) packing, but the current alignment size is {}.", - alignment, element_size, page_alignment); + "Incorrect alignment configuration for Row Major layout: innermost dimension alignment must be aligned to {} bytes since we pack buffer data as uint32_t. 
With element size of {} byte(s), alignment {} must be a multiple of alignment {}.", + sizeof(uint32_t), element_size, alignment, page_alignment); - if(memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { - const auto& shard_spec = memory_config.shard_spec.value(); - const auto& shard_shape = shard_spec.shape; - const auto shard_width = shard_shape[1]; + if (memory_config.shard_spec.has_value() && memory_config.shard_spec.value().mode == ShardMode::PHYSICAL && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { + const auto& physical_shard_shape = memory_config.shard_spec.value().shape; + const auto physical_shard_width = physical_shard_shape[1]; TT_FATAL( - width_alignment % shard_width == 0, - "Alignment mismatch for sharded tensor: Expected alignment width of {} to match shard shape {} for Row Major layout.", - width_alignment, shard_shape); + physical_shard_width % width_alignment == 0, + "Alignment mismatch for sharded tensor: Expected physical shard shape {} to be aligned to {} along the width for Row Major layout.", + physical_shard_width, width_alignment); } } -Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const { +Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional& physical_shard_size) const { if (physical_size.height() == 0 || physical_size.width() == 0) { return Size(1, sizeof(uint32_t) / CMAKE_UNIQUE_NAMESPACE::element_size_bytes(dtype)); } @@ -164,10 +160,9 @@ Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtyp } if (memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { - const auto& shard_spec = memory_config.shard_spec.value(); - const auto& shard_shape = shard_spec.shape; + TT_FATAL(physical_shard_size.has_value(), "For width or block 
sharded tensors, Row Major page width comes from physical shard size so it must be provided!"); - return Size(1, shard_shape[1]); + return Size(1, physical_shard_size.value().width()); } return Size(1, physical_size.width()); diff --git a/ttnn/cpp/ttnn/tensor/layout/page_config.hpp b/ttnn/cpp/ttnn/tensor/layout/page_config.hpp index 615e4eef96a..0338722bc03 100644 --- a/ttnn/cpp/ttnn/tensor/layout/page_config.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/page_config.hpp @@ -23,7 +23,7 @@ class RowMajorPageConfig { Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const; void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const; - Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const; + Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional& physical_shard_size) const; size_t get_page_size_bytes(const Size& page_size, DataType dtype) const; }; @@ -34,7 +34,7 @@ class TilePageConfig { Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const; void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const; - Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const; + Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional& physical_shard_size) const; size_t get_page_size_bytes(const Size& page_size, DataType dtype) const; const Tile& get_tile() const; @@ -54,7 +54,7 @@ class PageConfig { Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const; void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const; - Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const; + 
Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional& physical_shard_size) const; size_t get_page_size_bytes(const Size& page_size, DataType dtype) const; std::optional get_tile() const; diff --git a/ttnn/cpp/ttnn/tensor/layout/size.cpp b/ttnn/cpp/ttnn/tensor/layout/size.cpp index d1939eba5e9..ff26c5f68a1 100644 --- a/ttnn/cpp/ttnn/tensor/layout/size.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/size.cpp @@ -16,6 +16,9 @@ Size::Size(const std::pair& size) Size::Size(const std::array& size) : Size(size[0], size[1]) {} +Size::Size(const std::array& size) + : Size(size[0], size[1]) {} + Size Size::operator*(size_t scalar) const { return Size(height_ * scalar, width_ * scalar); } diff --git a/ttnn/cpp/ttnn/tensor/layout/size.hpp b/ttnn/cpp/ttnn/tensor/layout/size.hpp index ed3c0590888..622a270752c 100644 --- a/ttnn/cpp/ttnn/tensor/layout/size.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/size.hpp @@ -13,6 +13,7 @@ class Size final { Size(size_t height, size_t width); Size(const std::pair& size); Size(const std::array& size); + Size(const std::array& size); operator std::pair() const; operator std::array() const; diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp index fb007da544c..7b8dffb2f75 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp @@ -17,29 +17,56 @@ size_t round_up(size_t value, size_t multiple) { return ((value + multiple - 1) / multiple) * multiple; }; -Alignment legacyShapeToAlignment(const ttnn::Shape& shape) { - auto legacy_padded_shape = shape.padded_shape(); - if (shape.logical_shape() == legacy_padded_shape) { +Alignment legacyShapeToAlignment(const ttnn::Shape& shape, const PageConfig& page_config, const MemoryConfig& memory_config) { + const auto& logical_shape = shape.logical_shape(); + const auto& legacy_padded_shape = shape.padded_shape(); + if (logical_shape == 
legacy_padded_shape) { return Alignment{}; } const auto rank = legacy_padded_shape.rank(); - ttnn::SmallVector values(rank); - - if(rank >= 1) { - values[rank - 1] = legacy_padded_shape[rank - 1]; - } - if(rank >= 2) { - values[rank - 2] = legacy_padded_shape[rank - 2]; - } + bool alignment_can_be_2D = true; for (int i = rank - 3; i >= 0; i--) { - values[i] = legacy_padded_shape[i] * values[i + 1]; + alignment_can_be_2D &= logical_shape[i] == legacy_padded_shape[i]; } - Alignment result(std::move(values)); - return result; + if (memory_config.shard_spec.has_value()) { + TT_FATAL(alignment_can_be_2D, "Tensor with shape {} cannot be sharded because alignment will have rank greater than 2!", shape); + const auto& shard_spec = memory_config.shard_spec.value(); + TT_FATAL(shard_spec.mode == ShardMode::PHYSICAL, "Shard mode {} has to be ShardMode::PHYSICAL for legacyShapeToAlignment!", shard_spec.mode); + if (page_config.is_row_major()) { + return Alignment{shard_spec.shape[1]}; + } + return Alignment{}; + } else { + if (alignment_can_be_2D) { + ttnn::SmallVector values(std::min((int) rank, 2)); + const auto alignment_size = values.size(); + if (alignment_size >= 1) { + values[alignment_size - 1] = legacy_padded_shape[-1]; + } + if (alignment_size == 2) { + values[alignment_size - 2] = legacy_padded_shape[-2]; + } + Alignment result(std::move(values)); + return result; + } else { + // NOTE: Rank > 2 is guaranteed in this case + ttnn::SmallVector values(rank); + values[rank - 1] = legacy_padded_shape[-1]; + values[rank - 2] = legacy_padded_shape[-2]; + + for (int i = rank - 3; i >= 0; i--) { + values[i] = legacy_padded_shape[i] * values[i + 1]; + } + + Alignment result(std::move(values)); + return result; + } + } } + } // namespace CMAKE_UNIQUE_NAMESPACE } @@ -47,7 +74,6 @@ TensorLayout::TensorLayout(DataType dtype, const PageConfig& page_config, const : TensorLayout(dtype, page_config, memory_config, {}) { } -// Private TensorLayout::TensorLayout(DataType dtype, const 
PageConfig& page_config, const MemoryConfig& memory_config, const Alignment& alignment) : dtype_(dtype), page_config_(page_config), @@ -59,11 +85,11 @@ TensorLayout::TensorLayout(DataType dtype, const PageConfig& page_config, const } TensorLayout TensorLayout::fromLegacyPaddedShape(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config, const ttnn::Shape& legacy_shape) { - return TensorLayout(dtype, page_config, memory_config, CMAKE_UNIQUE_NAMESPACE::legacyShapeToAlignment(legacy_shape)); + return TensorLayout(dtype, page_config, memory_config, CMAKE_UNIQUE_NAMESPACE::legacyShapeToAlignment(legacy_shape, page_config, memory_config)); } void TensorLayout::initialize_alignment() { - if(!alignment_.empty()) { + if (!alignment_.empty()) { return; } @@ -72,6 +98,7 @@ void TensorLayout::initialize_alignment() { void TensorLayout::validate_alignment() const { + TT_FATAL(alignment_.size() <= 2 || !memory_config_.is_sharded(), "Tensor must be interleaved if alignment has rank greater than 2!"); return page_config_.validate_alignment(alignment_, dtype_, memory_config_); } @@ -82,7 +109,17 @@ std::optional TensorLayout::compute_shard_spec_buffer(const ttn TT_FATAL(memory_config_.shard_spec.has_value(), "MemoryConfig must have Shard Spec specified for sharded memory layout"); - auto& shard_spec = memory_config_.shard_spec.value(); + auto shard_spec = memory_config_.shard_spec.value(); + + const auto& physical_shard_shape = compute_physical_shard_shape(shard_spec.shape); + if (shard_spec.mode == ShardMode::PHYSICAL) { + TT_FATAL(shard_spec.shape == physical_shard_shape, "In shard mode {}, shard shape {} is not compatible with alignment {}!", shard_spec.mode, shard_spec.shape, alignment_); + } else if (shard_spec.mode == ShardMode::LOGICAL) { + shard_spec.shape = physical_shard_shape; + } else { + TT_THROW("Unsupported shard mode {} in compute_shard_spec_buffer!", shard_spec.mode); + } + const Size physical_size = compute_physical_shape(shape); const 
Size page_shape = compute_page_shape(physical_size); @@ -122,33 +159,104 @@ size_t TensorLayout::compute_page_size_bytes(const Size& page_size) const { return page_config_.get_page_size_bytes(page_size, dtype_); } +Size TensorLayout::compute_physical_shard_shape(const Size& logical_shard_shape) const { + // TODO: Alignment before sharding can describe padding along non-height or width dims + // However, this is not compatible with logical sharding and we just leave it as nullopt + const int alignment_rank = static_cast(alignment_.size()); + TT_FATAL(alignment_rank <= 2, "Alignment {} must be rank 2 or less to compute physical shard shape", alignment_); + + auto physical_shard_height = CMAKE_UNIQUE_NAMESPACE::round_up(logical_shard_shape.height(), alignment_[-2]); + auto physical_shard_width = CMAKE_UNIQUE_NAMESPACE::round_up(logical_shard_shape.width(), alignment_[-1]); + return Size{physical_shard_height, physical_shard_width}; +} + Size TensorLayout::compute_physical_shape(const ttnn::SimpleShape& shape) const { const int rank = static_cast(shape.rank()); const int alignment_rank = static_cast(alignment_.size()); - const int max_rank = std::max(rank, alignment_rank); + size_t width = 1; size_t height = 1; - // Iterate dims in reverse order and ensure alignment - // Even tensor of rank 0 or 1 must be aligned (to Tile / Page / Shard) - for (int i = -1; i >= -max_rank; --i) { - auto& dim = i == -1 ? width : height; - if(i >= -rank) { + if (memory_config_.shard_spec.has_value() and memory_config_.shard_spec.value().mode == ShardMode::LOGICAL) { + // Iterate dims in reverse order + for (int i = -1; i >= -rank; --i) { + auto& dim = i == -1 ? 
width : height; dim *= shape[i]; } - // Align the current dimension if alignment is available - if (i >= -alignment_rank) { - dim = CMAKE_UNIQUE_NAMESPACE::round_up(dim, alignment_[i]); + auto compute_default_logical_shard_shape = [&shape, &rank]() { + size_t shard_width = 1; + size_t shard_height = 1; + + if (rank >= 1) { + shard_width = shape[-1]; + } + if (rank >= 2) { + shard_height = shape[-2]; + } + + return Size{shard_height, shard_width}; + }; + + const auto& logical_shard_shape = memory_config_.shard_spec.has_value() ? Size(memory_config_.shard_spec.value().shape) : compute_default_logical_shard_shape(); + const auto& physical_shard_shape = compute_physical_shard_shape(logical_shard_shape); + + auto get_physical_size = [](auto original_size, auto logical_shard_size, auto physical_shard_size, auto alignment) -> uint32_t { + if (logical_shard_size == 0) { + return 0; + } + // If we always pad to full shards, then return: + // auto num_shards = tt::div_up(original_size, logical_shard_size); + // return (uint32_t) physical_shard_size * num_shards; + + // But host physical data is only padded up to nearest alignment (and no padding in between shards) + // So last shard is only up to nearest alignment + auto num_full_shards = original_size / logical_shard_size; + auto last_physical_shard_size = CMAKE_UNIQUE_NAMESPACE::round_up(original_size % logical_shard_size, alignment); + return (physical_shard_size * num_full_shards + last_physical_shard_size); + }; + + auto physical_height = get_physical_size(height, logical_shard_shape.height(), physical_shard_shape.height(), alignment_[-2]); + auto physical_width = get_physical_size(width, logical_shard_shape.width(), physical_shard_shape.width(), alignment_[-1]); + + Size size{physical_height, physical_width}; + return size; + } else { + const int max_rank = std::max(rank, alignment_rank); + + // Iterate dims in reverse order and ensure alignment + // Even tensor of rank 0 or 1 must be aligned (to Tile / Page / Shard) + 
for (int i = -1; i >= -max_rank; --i) { + auto& dim = i == -1 ? width : height; + if(i >= -rank) { + dim *= shape[i]; + } + + // Align the current dimension if alignment is available + if (i >= -alignment_rank) { + dim = CMAKE_UNIQUE_NAMESPACE::round_up(dim, alignment_[i]); + } } - } - Size size{height, width}; - return size; + Size size{height, width}; + return size; + } } Size TensorLayout::compute_page_shape(const Size& physical_size) const { - return page_config_.get_page_shape(physical_size, dtype_, memory_config_); + std::optional physical_shard_shape = std::nullopt; + if (memory_config_.shard_spec.has_value()) { + const auto& shard_spec = memory_config_.shard_spec.value(); + if (shard_spec.mode == ShardMode::PHYSICAL) { + physical_shard_shape = shard_spec.shape; + } else if (shard_spec.mode == ShardMode::LOGICAL) { + physical_shard_shape = compute_physical_shard_shape(shard_spec.shape); + } else { + TT_THROW("Unsupported shard mode {} in compute_shard_spec_buffer!", shard_spec.mode); + } + } + + return page_config_.get_page_shape(physical_size, dtype_, memory_config_, physical_shard_shape); } Strides TensorLayout::compute_strides(const ttnn::SimpleShape& shape) const { diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp index 487534fc125..6560ac96470 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp @@ -24,6 +24,8 @@ class TensorLayout { public: TensorLayout(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config); + TensorLayout(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config, const Alignment& alignment); + // static method makes it easy to find and remove all of its usages in the codebase - thats why it is not a constructor [[deprecated("Use of Legacy Padded Shape is deprecated")]] static TensorLayout fromLegacyPaddedShape(DataType dtype, const PageConfig& page_config, const 
MemoryConfig& memory_config, const ttnn::Shape& legacy_shape); @@ -52,9 +54,6 @@ class TensorLayout { Size compute_physical_shape(const ttnn::SimpleShape& shape) const; private: - // Private to not expose alignment parameter to the public API - TensorLayout(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config, const Alignment& alignment); - void initialize_alignment(); void validate_alignment() const; @@ -65,6 +64,8 @@ class TensorLayout { PageConfig page_config_; MemoryConfig memory_config_; Alignment alignment_; + + Size compute_physical_shard_shape(const Size& logical_shard_shape) const; }; } // tt::tt_metal diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index fba363c3971..ec6b174739b 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -124,6 +124,7 @@ def manage_config(name, value): L1_WIDTH_SHARDED_MEMORY_CONFIG, ShardStrategy, ShardOrientation, + ShardMode, ShardSpec, CoreRangeSet, CoreRange, diff --git a/ttnn/ttnn/types.py b/ttnn/ttnn/types.py index 0c95d333381..0fd3f775313 100644 --- a/ttnn/ttnn/types.py +++ b/ttnn/ttnn/types.py @@ -78,6 +78,7 @@ class ShardStrategy(Enum): ShardOrientation = ttnn._ttnn.tensor.ShardOrientation +ShardMode = ttnn._ttnn.tensor.ShardMode ShardSpec = ttnn._ttnn.tensor.ShardSpec CoreRangeSet = ttnn._ttnn.tensor.CoreRangeSet CoreRange = ttnn._ttnn.tensor.CoreRange From 17af0d5815523044d724e61177e40d5af16f9c25 Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Thu, 14 Nov 2024 21:01:28 +0000 Subject: [PATCH 2/2] #0: clean up --- tt_metal/impl/buffers/buffer.hpp | 5 - ttnn/cpp/pybind11/pytensor.cpp | 4 +- ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp | 135 +++++++++--------- ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp | 3 +- 4 files changed, 69 insertions(+), 78 deletions(-) diff --git a/tt_metal/impl/buffers/buffer.hpp b/tt_metal/impl/buffers/buffer.hpp index 2275ee1d47d..5f4aefc03ce 100644 --- a/tt_metal/impl/buffers/buffer.hpp +++ b/tt_metal/impl/buffers/buffer.hpp @@ -58,11 
+58,6 @@ struct ShardSpec { const bool &halo_ = false, const ShardMode &shard_mode_ = ShardMode::PHYSICAL) : grid(core_sets_), shape(shard_shape_), orientation(shard_orientation_), halo(halo_), mode(shard_mode_) { - if (shard_mode_ == ShardMode::PHYSICAL) { - tt::log_warning( - tt::LogOp, - "ShardMode::PHYSICAL will be deprecated soon! Please switch to equivalent representation with ShardMode::LOGICAL"); - } } const uint32_t num_cores() const { return this->grid.num_cores(); } diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 9d8f8c6099d..1bd5e3ea56e 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -1505,6 +1505,7 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( + // TODO: Rename to physical_volume "volume", [](const Tensor &self) { return self.volume(); }, R"doc( Get the volume of the tensor. @@ -1514,12 +1515,13 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( + // TODO: Rename to volume "logical_volume", [](const Tensor &self) { return self.get_logical_volume(); }, R"doc( Get the logical volume of the tensor. .. 
code-block:: python - volume = tt_tensor.get_logical_volume() + volume = tt_tensor.logical_volume() )doc") .def( diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp index 7b8dffb2f75..b0eda548cbd 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp @@ -30,40 +30,41 @@ Alignment legacyShapeToAlignment(const ttnn::Shape& shape, const PageConfig& pag alignment_can_be_2D &= logical_shape[i] == legacy_padded_shape[i]; } + // SHARDED if (memory_config.shard_spec.has_value()) { TT_FATAL(alignment_can_be_2D, "Tensor with shape {} cannot be sharded because alignment will have rank greater than 2!", shape); - const auto& shard_spec = memory_config.shard_spec.value(); - TT_FATAL(shard_spec.mode == ShardMode::PHYSICAL, "Shard mode {} has to be ShardMode::PHYSICAL for legacyShapeToAlignment!", shard_spec.mode); if (page_config.is_row_major()) { - return Alignment{shard_spec.shape[1]}; + return Alignment{memory_config.shard_spec.value().shape[1]}; } return Alignment{}; - } else { - if (alignment_can_be_2D) { - ttnn::SmallVector values(std::min((int) rank, 2)); - const auto alignment_size = values.size(); - if (alignment_size >= 1) { - values[alignment_size - 1] = legacy_padded_shape[-1]; - } - if (alignment_size == 2) { - values[alignment_size - 2] = legacy_padded_shape[-2]; - } - Alignment result(std::move(values)); - return result; - } else { - // NOTE: Rank > 2 is guaranteed in this case - ttnn::SmallVector values(rank); - values[rank - 1] = legacy_padded_shape[-1]; - values[rank - 2] = legacy_padded_shape[-2]; - - for (int i = rank - 3; i >= 0; i--) { - values[i] = legacy_padded_shape[i] * values[i + 1]; - } + } - Alignment result(std::move(values)); - return result; + // INTERLEAVED with only height/width padding + if (alignment_can_be_2D) { + ttnn::SmallVector values(std::min((int) rank, 2)); + const auto alignment_size = values.size(); + if (alignment_size >= 
1) { + values[alignment_size - 1] = legacy_padded_shape[-1]; + } + if (alignment_size == 2) { + values[alignment_size - 2] = legacy_padded_shape[-2]; } + Alignment result(std::move(values)); + return result; } + + // INTERLEAVED with (deprecated) non-height/width padding + // NOTE: Rank > 2 is guaranteed in this case + ttnn::SmallVector values(rank); + values[rank - 1] = legacy_padded_shape[-1]; + values[rank - 2] = legacy_padded_shape[-2]; + + for (int i = rank - 3; i >= 0; i--) { + values[i] = legacy_padded_shape[i] * values[i + 1]; + } + + Alignment result(std::move(values)); + return result; } @@ -112,12 +113,15 @@ std::optional TensorLayout::compute_shard_spec_buffer(const ttn auto shard_spec = memory_config_.shard_spec.value(); const auto& physical_shard_shape = compute_physical_shard_shape(shard_spec.shape); - if (shard_spec.mode == ShardMode::PHYSICAL) { - TT_FATAL(shard_spec.shape == physical_shard_shape, "In shard mode {}, shard shape {} is not compatible with alignment {}!", shard_spec.mode, shard_spec.shape, alignment_); - } else if (shard_spec.mode == ShardMode::LOGICAL) { - shard_spec.shape = physical_shard_shape; - } else { - TT_THROW("Unsupported shard mode {} in compute_shard_spec_buffer!", shard_spec.mode); + switch (shard_spec.mode) { + case ShardMode::PHYSICAL: + TT_FATAL(shard_spec.shape == physical_shard_shape, "In shard mode {}, shard shape {} is not compatible with alignment {}!", shard_spec.mode, shard_spec.shape, alignment_); + break; + case ShardMode::LOGICAL: + shard_spec.shape = physical_shard_shape; + break; + default: + TT_THROW("Unsupported shard mode {} in compute_shard_spec_buffer!", shard_spec.mode); } const Size physical_size = compute_physical_shape(shape); @@ -177,6 +181,7 @@ Size TensorLayout::compute_physical_shape(const ttnn::SimpleShape& shape) const size_t width = 1; size_t height = 1; + // LOGICAL SHARDING if (memory_config_.shard_spec.has_value() and memory_config_.shard_spec.value().mode == ShardMode::LOGICAL) { // 
Iterate dims in reverse order for (int i = -1; i >= -rank; --i) { @@ -184,21 +189,7 @@ Size TensorLayout::compute_physical_shape(const ttnn::SimpleShape& shape) const dim *= shape[i]; } - auto compute_default_logical_shard_shape = [&shape, &rank]() { - size_t shard_width = 1; - size_t shard_height = 1; - - if (rank >= 1) { - shard_width = shape[-1]; - } - if (rank >= 2) { - shard_height = shape[-2]; - } - - return Size{shard_height, shard_width}; - }; - - const auto& logical_shard_shape = memory_config_.shard_spec.has_value() ? Size(memory_config_.shard_spec.value().shape) : compute_default_logical_shard_shape(); + const auto& logical_shard_shape = Size(memory_config_.shard_spec.value().shape); const auto& physical_shard_shape = compute_physical_shard_shape(logical_shard_shape); auto get_physical_size = [](auto original_size, auto logical_shard_size, auto physical_shard_size, auto alignment) -> uint32_t { @@ -209,8 +200,8 @@ Size TensorLayout::compute_physical_shape(const ttnn::SimpleShape& shape) const // auto num_shards = tt::div_up(original_size, logical_shard_size); // return (uint32_t) physical_shard_size * num_shards; - // But host physical data is only padded up to nearest alignment (and no padding in between shards) - // So last shard is only up to nearest alignment + // If we pad all shards except last shard up to physical size and last one only up to nearest alignment, then return this: + // NOTE: This matches existing physical sharding where physical host data can be sharded with partial shards auto num_full_shards = original_size / logical_shard_size; auto last_physical_shard_size = CMAKE_UNIQUE_NAMESPACE::round_up(original_size % logical_shard_size, alignment); return (physical_shard_size * num_full_shards + last_physical_shard_size); @@ -221,38 +212,42 @@ Size TensorLayout::compute_physical_shape(const ttnn::SimpleShape& shape) const Size size{physical_height, physical_width}; return size; - } else { - const int max_rank = std::max(rank, 
alignment_rank); + } - // Iterate dims in reverse order and ensure alignment - // Even tensor of rank 0 or 1 must be aligned (to Tile / Page / Shard) - for (int i = -1; i >= -max_rank; --i) { - auto& dim = i == -1 ? width : height; - if(i >= -rank) { - dim *= shape[i]; - } + // INTERLEAVED or deprecated PHYSICAL SHARDING + const int max_rank = std::max(rank, alignment_rank); - // Align the current dimension if alignment is available - if (i >= -alignment_rank) { - dim = CMAKE_UNIQUE_NAMESPACE::round_up(dim, alignment_[i]); - } + // Iterate dims in reverse order and ensure alignment + // Even tensor of rank 0 or 1 must be aligned (to Tile / Page / Shard) + for (int i = -1; i >= -max_rank; --i) { + auto& dim = i == -1 ? width : height; + if(i >= -rank) { + dim *= shape[i]; } - Size size{height, width}; - return size; + // Align the current dimension if alignment is available + if (i >= -alignment_rank) { + dim = CMAKE_UNIQUE_NAMESPACE::round_up(dim, alignment_[i]); + } } + + Size size{height, width}; + return size; } Size TensorLayout::compute_page_shape(const Size& physical_size) const { std::optional physical_shard_shape = std::nullopt; if (memory_config_.shard_spec.has_value()) { const auto& shard_spec = memory_config_.shard_spec.value(); - if (shard_spec.mode == ShardMode::PHYSICAL) { - physical_shard_shape = shard_spec.shape; - } else if (shard_spec.mode == ShardMode::LOGICAL) { - physical_shard_shape = compute_physical_shard_shape(shard_spec.shape); - } else { - TT_THROW("Unsupported shard mode {} in compute_shard_spec_buffer!", shard_spec.mode); + switch (shard_spec.mode) { + case ShardMode::PHYSICAL: + physical_shard_shape = shard_spec.shape; + break; + case ShardMode::LOGICAL: + physical_shard_shape = compute_physical_shard_shape(shard_spec.shape); + break; + default: + TT_THROW("Unsupported shard mode {} in compute_page_shape!", shard_spec.mode); } } diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp
b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp index 6560ac96470..5eb1e9d3148 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp @@ -59,13 +59,12 @@ class TensorLayout { Size compute_page_shape(const Size& physical_size) const; size_t compute_page_size_bytes(const Size& page_size) const; + Size compute_physical_shard_shape(const Size& logical_shard_shape) const; DataType dtype_ = DataType::BFLOAT16; PageConfig page_config_; MemoryConfig memory_config_; Alignment alignment_; - - Size compute_physical_shard_shape(const Size& logical_shard_shape) const; }; } // tt::tt_metal