
#12184: Alignment fix for BH on I2S and S2I (fix after revert) (#15055)
### Ticket
[Link to GitHub Issue](#12184 (comment))

### Problem description
Alignment issues on Blackhole (BH) when moving data from DRAM to L1: much of the alignment handling was hardcoded for the Wormhole (WH) case.
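
For context, a minimal sketch of the mismatch. The 64 B DRAM / 16 B L1 figures for BH come from comments in this diff; the 32 B WH figure is an assumption for illustration:

```python
# Illustrative sketch (not tt-metal code): why WH-tuned padding breaks on BH.
# Assumed alignments: WH DRAM = 32 B (assumption), BH DRAM = 64 B, L1 = 16 B.

def align_up(size: int, alignment: int) -> int:
    """Round size up to the next multiple of alignment."""
    return ((size + alignment - 1) // alignment) * alignment

page_size = 16 * 2  # a 16-element bfloat16 row = 32 bytes

print(align_up(page_size, 32))  # 32 -> a WH-tuned stride happens to fit this page
print(align_up(page_size, 64))  # 64 -> BH DRAM needs a wider, 64 B-aligned staging page
print(align_up(page_size, 16))  # 32 -> the L1 destination only needs 16 B alignment
```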


### What's changed
Added extra logic to handle alignment on the interleaved-to-sharded (I2S) and sharded-to-interleaved (S2I) paths; a sketch of the host-side fallback follows below.
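
A rough sketch of that fallback (Python-flavoured pseudocode of the C++ change in `ttnn/cpp/ttnn/operations/core/core.cpp`; names and signatures are simplified, not the real API):

```python
# Sketch only: on BH, sharded host<->device transfers are staged through an
# interleaved DRAM tensor plus an on-device reshard, instead of a direct
# sharded write/read that assumed WH alignment.

def to_device_bh(tensor, device, mem_config):
    if mem_config.is_sharded():
        interleaved = tensor.to(device, DRAM_MEMORY_CONFIG)      # 1) plain interleaved write
        return interleaved_to_sharded(interleaved, mem_config)   # 2) reshard on device
    return tensor.to(device, mem_config)

def from_device_bh(tensor):
    if tensor.is_sharded():
        interleaved = sharded_to_interleaved(tensor, DRAM_MEMORY_CONFIG)  # mirror image
        return interleaved.cpu()
    return tensor.cpu()
```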


### Checklist
- [x] Post commit CI passes
(https://github.com/tenstorrent/tt-metal/actions/runs/11846373612)
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] New/Existing tests provide coverage for changes
ntarafdar authored Nov 15, 2024
1 parent 1620ba8 commit 4f565bd
Showing 6 changed files with 185 additions and 26 deletions.
@@ -101,6 +101,7 @@ def test_sharded_tile(


# TODO (7735): Switch to new interleaved_to_sharded with sharded_mem_config input and re-enable BLOCK sharded tests
@skip_for_blackhole("WIP")
@pytest.mark.parametrize(
"input_shape, shard_scheme, shard_size, num_cores",
[
@@ -180,7 +181,7 @@ def test_sharded_rm(
assert passing


@skip_for_blackhole("Mismatching on BH, see #12349")
@skip_for_blackhole("BH LLK issue with untilize, #14594")
@pytest.mark.parametrize("H, num_cores", [[100352, 98], [25088, 98]])
@pytest.mark.parametrize("in_sharded", [True, False])
@pytest.mark.parametrize("out_sharded", [True, False])
@@ -256,7 +257,7 @@ def test_sharded_untilize(H, num_cores, in_sharded, out_sharded, dtype, device,
assert passing


@skip_for_blackhole("Mismatching on BH, see #12349")
@skip_for_blackhole("Mismatching on BH, see #14609")
@pytest.mark.parametrize("H, num_cores", [[25088, 98]])
@pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_sharded_tilize(H, num_cores, output_dtype, device, function_level_defaults):
@@ -895,6 +896,7 @@ def test_partial_sharded_op_binary(
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
@pytest.mark.parametrize("in0_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
@pytest.mark.parametrize("in1_sharded", [True, False], ids=["in1_sharded", "in1_unsharded"])
@pytest.mark.parametrize("out_sharded", [True, False], ids=["out_sharded", "out_unsharded"])
@@ -1335,6 +1337,7 @@ def test_sharded_matmul_2d_transposed(
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
def test_resharded_binary_to_matmul(device, function_level_defaults):
grid_size_binary = device.compute_with_storage_grid_size()
num_cores_binary = 98
@@ -1426,6 +1429,7 @@ def test_resharded_binary_to_matmul(device, function_level_defaults):
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
@pytest.mark.parametrize("in_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
@pytest.mark.parametrize("out_sharded", [False], ids=["out_unsharded"])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1501,6 +1505,7 @@ def test_sharded_untilize_padded_shard(in_sharded, out_sharded, dtype, device, f
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
@pytest.mark.parametrize("in_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
@pytest.mark.parametrize("out_sharded", [False], ids=["out_unsharded"])
@pytest.mark.parametrize("activations_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1691,6 +1696,7 @@ def test_block_sharded_untilize_with_unpadding(in_sharded, out_sharded, dtype, d
"unbatched_16_shape_out_interleaved",
],
)
@skip_for_blackhole("BH Issue with untilize LLK, see #14594")
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_width_sharded_untilize_with_unpadding(
shape, output_H, in_sharded, out_sharded, dtype, device, function_level_defaults
@@ -1761,7 +1767,7 @@ def test_width_sharded_untilize_with_unpadding(
assert passing


@skip_for_blackhole("Mismatching on BH, see #12349")
@skip_for_blackhole("BH LLK Issue with tilize, #14609")
@pytest.mark.parametrize("input_shape", [[8, 1, 49, 2048], [1, 1, 8, 2048], [16, 1, 49, 2048], [1, 1, 16, 2048]])
@pytest.mark.parametrize("sharding_config", [(True, True), (False, False)], ids=["both_sharded", "both_interleaved"])
@pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1833,7 +1839,6 @@ def test_sharded_tilize_with_val_padding(input_shape, sharding_config, output_dt
assert passing


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize("N", [8, 16])
@pytest.mark.parametrize("in_sharded", [True], ids=["in0_sharded"])
@pytest.mark.parametrize("out_sharded", [True], ids=["out_sharded"])
@@ -2064,6 +2069,7 @@ def test_sharded_matmul_1d_in1_wormhole(device, function_level_defaults):
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
@pytest.mark.parametrize("in0_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
@pytest.mark.parametrize("in1_sharded", [True, False], ids=["in1_sharded", "in1_unsharded"])
@pytest.mark.parametrize("out_sharded", [True, False], ids=["out_sharded", "out_unsharded"])
94 changes: 94 additions & 0 deletions tests/ttnn/unit_tests/operations/test_core.py
@@ -439,3 +439,97 @@ def test_create_sharded_memory_config(device, shape, strategy, orientation, core

passing = torch.equal(input_data, output_data)
assert passing


@pytest.mark.parametrize(
"shape, shard_shape, strategy, orientation, core_grid",
[
([1, 1, 2, 16], None, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=1, x=1)),
([1, 1, 2, 16], None, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
([1, 1, 32, 16], None, ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
([1, 1, 64, 16], None, ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
(
[1, 1, 2, 16],
[2, 16],
ttnn.ShardStrategy.HEIGHT,
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)),
}
),
),
(
[1, 1, 5280, 16],
[5280, 16],
ttnn.ShardStrategy.HEIGHT,
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)),
}
),
),
# TODO: Add this test back by checking for core grid size and skipping if we can't do it
# (
# [1, 1, 675840, 16],
# [5280, 16],
# ttnn.ShardStrategy.HEIGHT,
# ttnn.ShardOrientation.ROW_MAJOR,
# ttnn.CoreRangeSet(
# {
# ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(11, 9)), # 120
# ttnn.CoreRange(ttnn.CoreCoord(12, 0), ttnn.CoreCoord(12, 7)), # 8
# }
# ),
# ),
],
)
@pytest.mark.parametrize(
"input_buffer_type",
[
ttnn.L1_MEMORY_CONFIG,
ttnn.DRAM_MEMORY_CONFIG,
],
)
@pytest.mark.parametrize(
"output_buffer_type",
[
ttnn.L1_MEMORY_CONFIG,
ttnn.DRAM_MEMORY_CONFIG,
],
)
def test_bh_alignment_i2s(
device, shape, shard_shape, strategy, orientation, core_grid, input_buffer_type, output_buffer_type
):
torch.manual_seed(0)
input_data = torch.randn(shape, dtype=torch.bfloat16)
if shard_shape == None:
shard_config = ttnn.create_sharded_memory_config(
shape=shape,
core_grid=core_grid,
strategy=strategy,
orientation=orientation,
use_height_and_width_as_shard_shape=False,
)
else:
shard_config = ttnn.create_sharded_memory_config(
shape=shard_shape,
core_grid=core_grid,
strategy=strategy,
orientation=orientation,
use_height_and_width_as_shard_shape=True,
)
x_t = ttnn.from_torch(
input_data,
device=device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=input_buffer_type,
dtype=ttnn.bfloat16,
)
x_t_sharded = ttnn.to_memory_config(x_t, shard_config)
x_t = ttnn.to_memory_config(x_t_sharded, output_buffer_type)
output_data = ttnn.from_device(x_t)
output_data = ttnn.to_torch(output_data)
passing = torch.equal(input_data, output_data)
assert passing
39 changes: 36 additions & 3 deletions ttnn/cpp/ttnn/operations/core/core.cpp
@@ -11,6 +11,8 @@
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp"
#include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp"
#include "ttnn/distributed/types.hpp"
#include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp"
#include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp"

namespace ttnn::operations::core {

@@ -54,12 +56,30 @@ ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank) {
}

ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional<MemoryConfig>& memory_config) {
- return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
if(mem_config.is_sharded () and (device->arch() == tt::ARCH::BLACKHOLE)) {
auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG);
return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt);
}
else {
return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
}

}

ttnn::Tensor to_device(
const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional<MemoryConfig>& memory_config) {
- return tensor.to(mesh_device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
// Currently no direct sharded write support in BLACKHOLE due to alignment issue
if(mem_config.is_sharded () and (mesh_device->arch() == tt::ARCH::BLACKHOLE)) {
auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG);
return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt);
}
else {
return tensor.to(mesh_device, mem_config);
}


}

ttnn::Tensor allocate_tensor_on_device(
@@ -86,7 +106,20 @@ void copy_host_to_device_tensor(ttnn::Tensor host_tensor, ttnn::Tensor device_te
tt::tt_metal::write_tensor(host_tensor, device_tensor, cq_id);
}

- ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) { return tensor.cpu(blocking, cq_id); }

ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) {

// Currently no direct sharded read support in BLACKHOLE due to alignment issue
if(tensor.is_sharded () and (tensor.device()->arch() == tt::ARCH::BLACKHOLE)) {
auto interleaved_tensor = ttnn::sharded_to_interleaved(cq_id, tensor, ttnn::DRAM_MEMORY_CONFIG, std::nullopt);
return interleaved_tensor.cpu(blocking, cq_id);
}
else {
return tensor.cpu(blocking, cq_id);

}

}

void deallocate(Tensor& tensor, bool force) { tensor.deallocate(force); }

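
The user-facing calls do not change with this fallback. Continuing from the new test above (a hedged illustration, assuming `ttnn.from_torch` forwards its `memory_config` through `to_device`):

```python
# Illustrative only: same call sites as before the fix. On BH, a sharded
# memory_config is now serviced via interleaved DRAM + interleaved_to_sharded.
x_t = ttnn.from_torch(
    input_data,                  # torch tensor from the test above
    device=device,
    layout=ttnn.ROW_MAJOR_LAYOUT,
    memory_config=shard_config,  # sharded L1 config from the test above
    dtype=ttnn.bfloat16,
)
host_copy = ttnn.to_torch(ttnn.from_device(x_t))  # read path goes via sharded_to_interleaved
```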
@@ -32,6 +32,13 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
bool rm_orientation = shard_spec.orientation == ShardOrientation::ROW_MAJOR;

CoreCoord end_core = (*shard_spec.grid.ranges().rbegin()).end_coord;

bool convert_df = input_cb_data_format != output_cb_data_format;
auto src_buffer = input.buffer();
auto dst_buffer = output.buffer();
bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
bool is_blackhole = (input.device()->arch() == tt::ARCH::BLACKHOLE);

if (input.get_layout() == Layout::TILE) {
num_units = input.volume() / TILE_HW;
input_unit_size = tt::tt_metal::detail::TileSize(input_cb_data_format);
@@ -66,13 +73,6 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
padded_offset_bytes = align(input_unit_size, input.buffer()->alignment());
}

- bool convert_df = input_cb_data_format != output_cb_data_format;

- auto src_buffer = input.buffer();

- auto dst_buffer = output.buffer();

- bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;

auto all_cores = shard_spec.grid;
uint32_t input_cb_index = tt::CB::c_in0;
@@ -94,10 +94,17 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
.set_globally_allocated_address(*output.buffer());
auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, output_cb_out_config);
uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
- if (src_is_dram && input_unit_size % dram_alignment != 0) {
- uint32_t scratch_cb_page_size = align(input_unit_size, dram_alignment);
if (src_is_dram && input_unit_size % dram_alignment != 0 or is_blackhole) {
uint32_t scratch_cb_page_size;
//scratchpad going to be used to align DRAM (64B) to L1 (16B)
if (is_blackhole) {
scratch_cb_page_size = align(input_unit_size, hal.get_alignment(HalMemType::L1));
}
else {
scratch_cb_page_size = align(input_unit_size, dram_alignment);
}
tt::tt_metal::CircularBufferConfig scratch_cb_out_config =
- tt::tt_metal::CircularBufferConfig(1 * scratch_cb_page_size, {{scratch_cb_index, input_cb_data_format}})
tt::tt_metal::CircularBufferConfig(4 * scratch_cb_page_size, {{scratch_cb_index, input_cb_data_format}})
.set_page_size(scratch_cb_index, scratch_cb_page_size);
auto cb_scratch = tt::tt_metal::CreateCircularBuffer(program, all_cores, scratch_cb_out_config);
}
@@ -236,10 +243,23 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
}

uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
- bool aligned = src_is_dram ? curr_idx_w % dram_alignment == 0 : true;
uint32_t l1_alignment = hal.get_alignment(HalMemType::L1);
bool aligned = (src_is_dram ? curr_idx_w % dram_alignment == 0 : true);
aligned = aligned and !(is_blackhole);
uint32_t aligned_width_offset, aligned_shard_width, aligned_offset;
if (!aligned) {
- aligned_width_offset = tt::round_down(curr_idx_w, dram_alignment);
//TODO: is this right, leaving non BH case the same for now, should investigate
if(!is_blackhole) {
aligned_width_offset = tt::round_down(curr_idx_w, dram_alignment);
}
else {
if(src_is_dram) {
aligned_width_offset = tt::round_down(curr_idx_w, dram_alignment);
}
else {
aligned_width_offset = tt::round_down(curr_idx_w, l1_alignment);
}
}
aligned_offset = curr_idx_w - aligned_width_offset;
aligned_shard_width = aligned_offset + shard_width;
} else {
@@ -256,7 +276,7 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
num_units_per_row,
shard_height,
shard_width,
- padded_offset_bytes,
(is_blackhole) ? shard_width : padded_offset_bytes,
static_cast<uint32_t>(aligned),
aligned_width_offset,
aligned_shard_width,
@@ -305,6 +325,4 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
}



}
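
For intuition, a worked example of the width-offset realignment in the hunk above (assumed numbers; `round_down` mirrors `tt::round_down`):

```python
# Worked example: where a BH read of one core's shard must actually start.
def round_down(x: int, a: int) -> int:
    return (x // a) * a

curr_idx_w = 96         # byte offset of this core's shard within a row (assumed)
dram_alignment = 64     # BH DRAM alignment
l1_alignment = 16       # BH L1 alignment

# DRAM source: back up to the previous 64 B boundary, then skip the lead-in.
aligned_width_offset = round_down(curr_idx_w, dram_alignment)  # 64
aligned_offset = curr_idx_w - aligned_width_offset             # 32 bytes to skip
# L1 source: 96 is already 16 B-aligned, so no lead-in is needed.
assert round_down(curr_idx_w, l1_alignment) == curr_idx_w
```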
@@ -20,9 +20,8 @@ void ShardedToInterleavedDeviceOperation::validate(const std::vector<Tensor>& in
TT_FATAL(input_tensor.memory_config().buffer_type == BufferType::L1, "Input tensor must be in L1");
TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Output memory config must be Interleaved");
if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
- uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
uint32_t l1_alignment = hal.get_alignment(HalMemType::L1);
- TT_FATAL((*input_tensor.memory_config().shard_spec).shape[1] * input_tensor.element_size() % (this->output_mem_config.buffer_type == BufferType::DRAM ? dram_alignment : l1_alignment) == 0, "Shard page size must be aligned to {}B for L1 Tensor, or {}B for DRAM tensor", l1_alignment, dram_alignment);
TT_FATAL((*input_tensor.memory_config().shard_spec).shape[1] * input_tensor.element_size() % (l1_alignment) == 0, "Shard page size must be aligned to {}B for L1 Tensor", l1_alignment);
}
if (input_tensor.get_dtype() != this->output_dtype) {
TT_FATAL(input_tensor.get_layout() == Layout::TILE, "If diff output type, tensor must be TILED");
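
The relaxed check in sketch form (assumed 16 B L1 alignment): the shard's row size in bytes now only has to be L1-aligned, which is what lets small row-major shards like `[2, 16]` in bfloat16 pass on BH:

```python
# Sketch of the relaxed validation. The S2I input lives in L1, so only L1
# alignment is required of the shard page; the old check also demanded DRAM
# alignment (64 B on BH) when the output buffer was DRAM, so 32 B pages failed.
shard_width_elems = 16
element_size = 2                                # bfloat16
page_bytes = shard_width_elems * element_size   # 32
assert page_bytes % 16 == 0                     # new check: passes
# old check with a DRAM output on BH: 32 % 64 != 0 -> TT_FATAL
```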
@@ -98,6 +98,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
bool is_blackhole = (input.device()->arch() == tt::ARCH::BLACKHOLE);

tt_metal::KernelHandle unary_writer_kernel_id;
if (input.get_layout() == Layout::TILE) {
@@ -141,7 +142,8 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
uint32_t curr_idx_w = 0;

const auto cores = corerange_to_cores(all_cores, std::nullopt, rm_orientation);
- uint32_t padded_shard_width = align(output_unit_size, dst_buffer->alignment());
uint32_t padded_offset_bytes;

for (const auto& core : cores) {
if (input.get_layout() == Layout::TILE) {
uint32_t shard_height = num_units_per_shard_height;
@@ -217,6 +219,13 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
}
}
}
uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
uint32_t l1_alignment = hal.get_alignment(HalMemType::L1);
uint32_t padded_shard_width = align(output_unit_size, dst_buffer->alignment());
if(is_blackhole) {
if(!dst_is_dram)
padded_shard_width = align(output_unit_size, l1_alignment);
}
tt_metal::SetRuntimeArgs(
program,
unary_writer_kernel_id,
@@ -225,7 +234,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
num_units_per_row,
shard_height,
shard_width,
- padded_shard_width,
(is_blackhole) ? shard_width : padded_shard_width,
curr_idx_w,
curr_idx_h});
curr_idx_w += output_unit_size;
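
And a sketch of the writer-argument change just above (assumed sizes): on BH the writer kernel now receives the true shard width rather than a DRAM-padded stride:

```python
# Sketch only: the stride handed to the writer kernel, before vs. after.
def align_up(size: int, alignment: int) -> int:
    return ((size + alignment - 1) // alignment) * alignment

output_unit_size = 32                                # bytes per shard row (assumed)
padded_shard_width = align_up(output_unit_size, 64)  # 64: old value for a DRAM dest
shard_width = output_unit_size                       # 32: what BH now passes instead
# The BH path keeps the unpadded width and leaves destination alignment to the
# per-arch logic, instead of baking in the WH padding.
```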
