Add support for new logical sharding + alignment in TensorLayout and tensor creation #14771

Open · wants to merge 2 commits into base: main

Changes from 1 commit
320 changes: 285 additions & 35 deletions tests/ttnn/unit_tests/gtests/tensor/test_sharding_with_alignment.cpp

Large diffs are not rendered by default.

28 changes: 17 additions & 11 deletions tests/ttnn/unit_tests/operations/test_paged_update_cache.py
@@ -41,11 +41,12 @@ def run_test_update_cache_decode(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec)
xt = xt.to(device, input_mem_config)
@@ -151,11 +152,12 @@ def test_update_cache_decode(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec
@@ -234,11 +236,12 @@ def test_update_cache_decode_program_cache(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec
@@ -276,11 +279,12 @@ def run_test_tensor_index_update_cache_decode(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec)
xt = xt.to(device, input_mem_config)
@@ -414,11 +418,12 @@ def run_test_paged_update_cache_decode(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec)
xt = xt.to(device, input_mem_config)
@@ -543,11 +548,12 @@ def test_paged_update_cache_decode_program_caching(
input_shard_spec = ttnn.ShardSpec(
shard_grid,
[
xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores,
xt.shape.with_tile_padding()[-1],
xt.logical_volume() // xt.shape[-1] // num_cores,
xt.shape[-1],
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec
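For reference, a minimal sketch of the updated sharding pattern used throughout these tests: the shard shape is now expressed in logical (unpadded) elements and tagged with `ttnn.ShardMode.LOGICAL`. This assumes `device`, `xt`, `shard_grid`, and `num_cores` are already set up as in the surrounding test code.

```python
import ttnn

# Logical sharding: shard height is derived from the logical volume and the logical
# last dimension, instead of the tile-padded shape used by the old physical ShardSpec.
input_shard_spec = ttnn.ShardSpec(
    shard_grid,
    [
        xt.logical_volume() // xt.shape[-1] // num_cores,  # logical shard height per core
        xt.shape[-1],                                      # logical shard width
    ],
    ttnn.ShardOrientation.ROW_MAJOR,
    False,                                                 # halo
    ttnn.ShardMode.LOGICAL,
)
input_mem_config = ttnn.MemoryConfig(
    ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec
)
xt = xt.to(device, input_mem_config)
```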
3 changes: 2 additions & 1 deletion tt_metal/impl/buffers/buffer.cpp
@@ -526,6 +526,7 @@ tt_metal::ShardSpec from_json_t<tt_metal::ShardSpec>::operator()(const nlohmann:
from_json<CoreRangeSet>(json_object.at("grid")),
from_json<std::array<uint32_t, 2>>(json_object.at("shape")),
from_json<tt_metal::ShardOrientation>(json_object.at("orientation")),
from_json<bool>(json_object.at("halo"))};
from_json<bool>(json_object.at("halo")),
from_json<tt_metal::ShardMode>(json_object.at("mode"))};
}
}
16 changes: 12 additions & 4 deletions tt_metal/impl/buffers/buffer.hpp
@@ -49,12 +49,20 @@ struct ShardSpec {
ShardOrientation orientation = ShardOrientation::ROW_MAJOR;
bool halo = false;

ShardMode mode = ShardMode::PHYSICAL;

ShardSpec(
const CoreRangeSet &core_sets_,
const std::array<uint32_t, 2> &shard_shape_,
const ShardOrientation &shard_orientation_ = ShardOrientation::ROW_MAJOR,
const bool &halo_ = false) :
grid(core_sets_), shape(shard_shape_), orientation(shard_orientation_), halo(halo_) {
const bool &halo_ = false,
const ShardMode &shard_mode_ = ShardMode::PHYSICAL) :
grid(core_sets_), shape(shard_shape_), orientation(shard_orientation_), halo(halo_), mode(shard_mode_) {
if (shard_mode_ == ShardMode::PHYSICAL) {
tt::log_warning(
tt::LogOp,
"ShardMode::PHYSICAL will be deprecated soon! Please switch to equivalent representation with ShardMode::LOGICAL");
}
}

const uint32_t num_cores() const { return this->grid.num_cores(); }
@@ -63,9 +71,9 @@ struct ShardSpec {
bool operator==(const ShardSpec& other) const;
bool operator!=(const ShardSpec& other) const;

static constexpr auto attribute_names = std::forward_as_tuple("grid", "shape", "orientation", "halo");
static constexpr auto attribute_names = std::forward_as_tuple("grid", "shape", "orientation", "halo", "mode");
constexpr auto attribute_values() const {
return std::forward_as_tuple(this->grid, this->shape, this->orientation, this->halo);
return std::forward_as_tuple(this->grid, this->shape, this->orientation, this->halo, this->mode);
}
};

5 changes: 5 additions & 0 deletions tt_metal/impl/buffers/buffer_constants.hpp
@@ -22,6 +22,11 @@ enum class ShardOrientation {
COL_MAJOR,
};

enum class ShardMode {
PHYSICAL, // TODO: Deprecate this option to treat shard shape as physical
LOGICAL,
};

enum class BufferType {
DRAM,
L1,
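As an illustration of the intended semantics (my reading of the enum, not text from the PR): a LOGICAL shard shape is given in unpadded elements, while the legacy PHYSICAL mode expects the tile-padded shape. Assuming the default 32x32 tile and an existing `shard_grid`, the two specs below would describe the same sharding.

```python
import ttnn

logical_shard_shape = [48, 64]          # unpadded elements
tile_h = tile_w = 32                    # default tile size (assumption)
physical_shard_shape = [
    -(-logical_shard_shape[0] // tile_h) * tile_h,  # 48 -> 64 (rounded up to tile height)
    -(-logical_shard_shape[1] // tile_w) * tile_w,  # 64 -> 64
]

logical_spec = ttnn.ShardSpec(
    shard_grid, logical_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False, ttnn.ShardMode.LOGICAL
)
physical_spec = ttnn.ShardSpec(
    shard_grid, physical_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False  # defaults to PHYSICAL
)
```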
9 changes: 9 additions & 0 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -1512,6 +1512,15 @@ void pytensor_module(py::module &m_tensor) {
volume = tt_tensor.volume()
)doc")
.def(
"logical_volume", [](const Tensor &self) { return self.get_logical_volume(); }, R"doc(
Get the logical volume of the tensor.
.. code-block:: python
volume = tt_tensor.get_logical_volume()
)doc")
.def(
"storage_type", [](const Tensor &self) { return self.storage_type(); }, R"doc(
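A small usage sketch of the new binding, contrasting it with the existing padded `volume()`. The workflow and the exact padded numbers are assumptions (they presume `ttnn.from_torch` and the default 32x32 tile).

```python
import torch
import ttnn

# Logical shape [1, 1, 10, 64]; TILE layout pads the height up to 32.
t = ttnn.from_torch(torch.randn(1, 1, 10, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

print(t.volume())          # padded volume: 1 * 1 * 32 * 64 = 2048
print(t.logical_volume())  # logical volume: 1 * 1 * 10 * 64 = 640
```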
9 changes: 8 additions & 1 deletion ttnn/cpp/pybind11/tensor.cpp
@@ -72,6 +72,7 @@ void tensor_mem_config_module_types(py::module& m_tensor) {
export_enum<MathFidelity>(m_tensor);
export_enum<TensorMemoryLayout>(m_tensor);
export_enum<ShardOrientation>(m_tensor);
export_enum<ShardMode>(m_tensor);

py::enum_<tt::tt_metal::BufferType>(m_tensor, "BufferType")
.value("DRAM", BufferType::DRAM)
@@ -266,10 +267,16 @@ void tensor_mem_config_module(py::module& m_tensor) {
.def(py::init<>([](const CoreRangeSet& core_sets,
const std::array<uint32_t, 2>& shard_shape,
const ShardOrientation& shard_orientation,
const bool& halo) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo); }))
const bool& halo) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo, ShardMode::PHYSICAL); }))
.def(py::init<>([](const CoreRangeSet& core_sets,
const std::array<uint32_t, 2>& shard_shape,
const ShardOrientation& shard_orientation,
const bool& halo,
const ShardMode& shard_mode) { return ShardSpec(core_sets, shard_shape, shard_orientation, halo, shard_mode); }))
.def_readwrite("shape", &ShardSpec::shape, "Shape of shard.")
.def_readwrite("grid", &ShardSpec::grid, "Grid to layout shards.")
.def_readwrite("orientation", &ShardSpec::orientation, "Orientation of cores to read shards")
.def_readwrite("mode", &ShardSpec::mode, "Treat shard shape as physical (default) or logical")
.def("num_cores", &ShardSpec::num_cores, "Number of cores")
.def(py::self == py::self)
.def(py::self != py::self);
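A quick sketch of the two constructor overloads exposed above and the new `mode` attribute (assumes `shard_grid` is an existing ttnn.CoreRangeSet).

```python
import ttnn

# 4-argument form: shard shape is still interpreted as physical (and now logs a deprecation warning).
physical_spec = ttnn.ShardSpec(shard_grid, [64, 64], ttnn.ShardOrientation.ROW_MAJOR, False)

# 5-argument form: shard shape is interpreted according to the explicit ShardMode.
logical_spec = ttnn.ShardSpec(shard_grid, [48, 64], ttnn.ShardOrientation.ROW_MAJOR, False, ttnn.ShardMode.LOGICAL)

print(physical_spec.mode)  # ShardMode.PHYSICAL
print(logical_spec.mode)   # ShardMode.LOGICAL (readable and writable via the new `mode` attribute)
```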
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/tensor/layout/alignment.hpp
@@ -37,4 +37,4 @@ class Alignment final : protected ShapeBase {

std::ostream &operator<<(std::ostream &os, const tt::tt_metal::Alignment &shape);

} // namespace ttnn
} // namespace tt::tt_metal
49 changes: 22 additions & 27 deletions ttnn/cpp/ttnn/tensor/layout/page_config.cpp
@@ -52,8 +52,8 @@ void PageConfig::validate_alignment(const Alignment& alignment, DataType dtype,
std::visit([&](const auto& config) constexpr { config.validate_alignment(alignment, dtype, memory_config); }, config_);
}

Size PageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const {
return std::visit([&](const auto& config) constexpr { return config.get_page_shape(physical_size, dtype, memory_config); }, config_);
Size PageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const {
return std::visit([&](const auto& config) constexpr { return config.get_page_shape(physical_size, dtype, memory_config, physical_shard_size); }, config_);
}

size_t PageConfig::get_page_size_bytes(const Size& page_shape, DataType dtype) const {
@@ -92,7 +92,7 @@ void TilePageConfig::validate_alignment(const Alignment& alignment, DataType dty
"Wrong custom Tensor Layout alignment {}. For Tile layout second innermost dimension should be multiple of tile height {}.", alignment, tile_.get_height());
}

Size TilePageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const {
Size TilePageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>&) const {
if(memory_config.memory_layout == TensorMemoryLayout::SINGLE_BANK && physical_size.width() != 0 && physical_size.height() != 0) {
return physical_size;
}
@@ -116,20 +116,17 @@ Alignment RowMajorPageConfig::create_default_alignment(DataType dtype, const Mem
const auto element_size = CMAKE_UNIQUE_NAMESPACE::element_size_bytes(dtype);
auto width_alignment = sizeof(uint32_t) / element_size;

if(memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& shard_spec = memory_config.shard_spec.value();
const auto& shard_shape = shard_spec.shape;
const auto shard_width = shard_shape[1];
if (memory_config.shard_spec.has_value() && memory_config.shard_spec.value().mode == ShardMode::PHYSICAL && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& physical_shard_shape = memory_config.shard_spec.value().shape;
const auto physical_shard_width = physical_shard_shape[1];
TT_FATAL(
(shard_width % width_alignment) == 0,
"Invalid sharding configuration: For Row Major layout with element size of {} bytes, the innermost dimension must align to {} bytes. "
"Buffer data is packed as uint32_t (4 bytes), so the provided shard shape {} does not meet alignment requirements.",
element_size, width_alignment, shard_shape
);
(physical_shard_width % width_alignment) == 0,
"For Row Major layout and shard mode {}, the width of shard shape {} is treated as physical shard width and must be aligned to {} since we pack buffer data as uint32_t.",
memory_config.shard_spec.value().mode, physical_shard_shape, width_alignment
);

width_alignment = shard_width;
width_alignment = physical_shard_width;
}

return Alignment({width_alignment});}
}

@@ -140,21 +137,20 @@ void RowMajorPageConfig::validate_alignment(const Alignment& alignment, DataType
const uint32_t page_alignment = sizeof(uint32_t) / element_size;

TT_FATAL((width_alignment % page_alignment) == 0,
"Incorrect alignment configuration for Row Major layout: alignment {} requires innermost dimension alignment of {} bytes due to uint32_t (4-byte) packing, but the current alignment size is {}.",
alignment, element_size, page_alignment);
"Incorrect alignment configuration for Row Major layout: innermost dimension alignment must be aligned to {} bytes since we pack buffer data as uint32_t. With element size of {} byte(s), alignment {} must be a multiple of alignment {}.",
sizeof(uint32_t), element_size, alignment, page_alignment);

if(memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& shard_spec = memory_config.shard_spec.value();
const auto& shard_shape = shard_spec.shape;
const auto shard_width = shard_shape[1];
if (memory_config.shard_spec.has_value() && memory_config.shard_spec.value().mode == ShardMode::PHYSICAL && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& physical_shard_shape = memory_config.shard_spec.value().shape;
const auto physical_shard_width = physical_shard_shape[1];
TT_FATAL(
width_alignment % shard_width == 0,
"Alignment mismatch for sharded tensor: Expected alignment width of {} to match shard shape {} for Row Major layout.",
width_alignment, shard_shape);
physical_shard_width % width_alignment == 0,
"Alignment mismatch for sharded tensor: Expected physical shard shape {} to be aligned to {} along the width for Row Major layout.",
physical_shard_width, width_alignment);
}
}

Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const {
Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const {
if (physical_size.height() == 0 || physical_size.width() == 0) {
return Size(1, sizeof(uint32_t) / CMAKE_UNIQUE_NAMESPACE::element_size_bytes(dtype));
}
@@ -164,10 +160,9 @@ Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtyp
}

if (memory_config.shard_spec.has_value() && memory_config.memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
const auto& shard_spec = memory_config.shard_spec.value();
const auto& shard_shape = shard_spec.shape;
TT_FATAL(physical_shard_size.has_value(), "For width or block sharded tensors, Row Major page width comes from physical shard size so it must be provided!");

return Size(1, shard_shape[1]);
return Size(1, physical_shard_size.value().width());
}

return Size(1, physical_size.width());
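A worked example of the Row Major width-alignment rule enforced above (illustrative arithmetic only): buffer data is packed as uint32_t, so the physical shard width must be a multiple of 4 / element_size elements.

```python
# For bfloat16 (2-byte elements), the required width alignment is 4 // 2 = 2 elements,
# so a physical shard width of 50 passes while a width of 49 would trip the TT_FATAL.
element_size_bytes = 2
width_alignment = 4 // element_size_bytes   # sizeof(uint32_t) / element_size

assert 50 % width_alignment == 0            # valid physical shard width
assert 49 % width_alignment != 0            # would fail validation
```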
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/tensor/layout/page_config.hpp
@@ -23,7 +23,7 @@ class RowMajorPageConfig {
Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const;
void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const;

Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const;
Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const;
size_t get_page_size_bytes(const Size& page_size, DataType dtype) const;
};

@@ -34,7 +34,7 @@ class TilePageConfig {
Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const;
void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const;

Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const;
Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const;
size_t get_page_size_bytes(const Size& page_size, DataType dtype) const;

const Tile& get_tile() const;
@@ -54,7 +54,7 @@ class PageConfig {
Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const;
void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const;

Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config) const;
Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional<Size>& physical_shard_size) const;
size_t get_page_size_bytes(const Size& page_size, DataType dtype) const;

std::optional<Tile> get_tile() const;
3 changes: 3 additions & 0 deletions ttnn/cpp/ttnn/tensor/layout/size.cpp
@@ -16,6 +16,9 @@ Size::Size(const std::pair<size_t, size_t>& size)
Size::Size(const std::array<size_t, 2>& size)
: Size(size[0], size[1]) {}

Size::Size(const std::array<uint32_t, 2>& size)
: Size(size[0], size[1]) {}

Size Size::operator*(size_t scalar) const {
return Size(height_ * scalar, width_ * scalar);
}
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/tensor/layout/size.hpp
@@ -13,6 +13,7 @@ class Size final {
Size(size_t height, size_t width);
Size(const std::pair<size_t, size_t>& size);
Size(const std::array<size_t, 2>& size);
Size(const std::array<uint32_t, 2>& size);

operator std::pair<size_t, size_t>() const;
operator std::array<size_t, 2>() const;