Revert "#12184: Alignment fix for BH in I2S and S2I"
This reverts commit cf1c75e.
ttmchiou committed Nov 14, 2024
1 parent ce6ff4c commit 7f6ab69
Showing 11 changed files with 29 additions and 243 deletions.
@@ -101,7 +101,6 @@ def test_sharded_tile(


# TODO (7735): Switch to new interleaved_to_sharded with sharded_mem_config input and re-enable BLOCK sharded tests
@skip_for_blackhole("WIP")
@pytest.mark.parametrize(
"input_shape, shard_scheme, shard_size, num_cores",
[
@@ -181,7 +180,7 @@ def test_sharded_rm(
assert passing


@skip_for_blackhole("BH LLK issue with untilize, #14594")
@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize("H, num_cores", [[100352, 98], [25088, 98]])
@pytest.mark.parametrize("in_sharded", [True, False])
@pytest.mark.parametrize("out_sharded", [True, False])
@@ -257,7 +256,7 @@ def test_sharded_untilize(H, num_cores, in_sharded, out_sharded, dtype, device,
assert passing


@skip_for_blackhole("Mismatching on BH, see #14609")
@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize("H, num_cores", [[25088, 98]])
@pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_sharded_tilize(H, num_cores, output_dtype, device, function_level_defaults):
@@ -896,7 +895,6 @@ def test_partial_sharded_op_binary(
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
@pytest.mark.parametrize("in0_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
@pytest.mark.parametrize("in1_sharded", [True, False], ids=["in1_sharded", "in1_unsharded"])
@pytest.mark.parametrize("out_sharded", [True, False], ids=["out_sharded", "out_unsharded"])
@@ -1337,7 +1335,6 @@ def test_sharded_matmul_2d_transposed(
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
def test_resharded_binary_to_matmul(device, function_level_defaults):
grid_size_binary = device.compute_with_storage_grid_size()
num_cores_binary = 98
@@ -1429,7 +1426,6 @@ def test_resharded_binary_to_matmul(device, function_level_defaults):
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
@pytest.mark.parametrize("in_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
@pytest.mark.parametrize("out_sharded", [False], ids=["out_unsharded"])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1505,7 +1501,6 @@ def test_sharded_untilize_padded_shard(in_sharded, out_sharded, dtype, device, f
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
@pytest.mark.parametrize("in_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
@pytest.mark.parametrize("out_sharded", [False], ids=["out_unsharded"])
@pytest.mark.parametrize("activations_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1696,7 +1691,6 @@ def test_block_sharded_untilize_with_unpadding(in_sharded, out_sharded, dtype, d
"unbatched_16_shape_out_interleaved",
],
)
@skip_for_blackhole("BH Issue with untilize LLK, see #14594")
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_width_sharded_untilize_with_unpadding(
shape, output_H, in_sharded, out_sharded, dtype, device, function_level_defaults
@@ -1767,7 +1761,7 @@ def test_width_sharded_untilize_with_unpadding(
assert passing


@skip_for_blackhole("BH LLK Issue with tilize, #14609")
@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize("input_shape", [[8, 1, 49, 2048], [1, 1, 8, 2048], [16, 1, 49, 2048], [1, 1, 16, 2048]])
@pytest.mark.parametrize("sharding_config", [(True, True), (False, False)], ids=["both_sharded", "both_interleaved"])
@pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -1839,6 +1833,7 @@ def test_sharded_tilize_with_val_padding(input_shape, sharding_config, output_dt
assert passing


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize("N", [8, 16])
@pytest.mark.parametrize("in_sharded", [True], ids=["in0_sharded"])
@pytest.mark.parametrize("out_sharded", [True], ids=["out_sharded"])
@@ -2069,7 +2064,6 @@ def test_sharded_matmul_1d_in1_wormhole(device, function_level_defaults):
assert passing


@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745")
@pytest.mark.parametrize("in0_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"])
@pytest.mark.parametrize("in1_sharded", [True, False], ids=["in1_sharded", "in1_unsharded"])
@pytest.mark.parametrize("out_sharded", [True, False], ids=["out_sharded", "out_unsharded"])
94 changes: 0 additions & 94 deletions tests/ttnn/unit_tests/operations/test_core.py
@@ -439,97 +439,3 @@ def test_create_sharded_memory_config(device, shape, strategy, orientation, core

passing = torch.equal(input_data, output_data)
assert passing


@pytest.mark.parametrize(
"shape, shard_shape, strategy, orientation, core_grid",
[
([1, 1, 2, 16], None, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=1, x=1)),
([1, 1, 2, 16], None, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
([1, 1, 32, 16], None, ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
([1, 1, 64, 16], None, ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)),
(
[1, 1, 2, 16],
[2, 16],
ttnn.ShardStrategy.HEIGHT,
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)),
}
),
),
(
[1, 1, 5280, 16],
[5280, 16],
ttnn.ShardStrategy.HEIGHT,
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)),
}
),
),
# TODO: Add this test back by checking for core grid size and skipping if we can't do it
# (
# [1, 1, 675840, 16],
# [5280, 16],
# ttnn.ShardStrategy.HEIGHT,
# ttnn.ShardOrientation.ROW_MAJOR,
# ttnn.CoreRangeSet(
# {
# ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(11, 9)), # 120
# ttnn.CoreRange(ttnn.CoreCoord(12, 0), ttnn.CoreCoord(12, 7)), # 8
# }
# ),
# ),
],
)
@pytest.mark.parametrize(
"input_buffer_type",
[
ttnn.L1_MEMORY_CONFIG,
ttnn.DRAM_MEMORY_CONFIG,
],
)
@pytest.mark.parametrize(
"output_buffer_type",
[
ttnn.L1_MEMORY_CONFIG,
ttnn.DRAM_MEMORY_CONFIG,
],
)
def test_bh_alignment_i2s(
device, shape, shard_shape, strategy, orientation, core_grid, input_buffer_type, output_buffer_type
):
torch.manual_seed(0)
input_data = torch.randn(shape, dtype=torch.bfloat16)
if shard_shape == None:
shard_config = ttnn.create_sharded_memory_config(
shape=shape,
core_grid=core_grid,
strategy=strategy,
orientation=orientation,
use_height_and_width_as_shard_shape=False,
)
else:
shard_config = ttnn.create_sharded_memory_config(
shape=shard_shape,
core_grid=core_grid,
strategy=strategy,
orientation=orientation,
use_height_and_width_as_shard_shape=True,
)
x_t = ttnn.from_torch(
input_data,
device=device,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=input_buffer_type,
dtype=ttnn.bfloat16,
)
x_t_sharded = ttnn.to_memory_config(x_t, shard_config)
x_t = ttnn.to_memory_config(x_t_sharded, output_buffer_type)
output_data = ttnn.from_device(x_t)
output_data = ttnn.to_torch(output_data)
passing = torch.equal(input_data, output_data)
assert passing
3 changes: 0 additions & 3 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -183,9 +183,6 @@ def run_max_pool(
output_host = output.cpu()
output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host))
output_pytorch = output_pytorch_padded[:, :, :, :in_c]
torch.set_printoptions(profile="full")
print("output_pytorch" + str(output_pytorch))
torch.set_printoptions(profile="default") # reset

## reference
golden_pytorch = torch.nn.MaxPool2d(
37 changes: 3 additions & 34 deletions ttnn/cpp/ttnn/operations/core/core.cpp
@@ -11,8 +11,6 @@
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp"
#include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp"
#include "ttnn/distributed/types.hpp"
#include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp"
#include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp"

namespace ttnn::operations::core {

@@ -56,29 +54,12 @@ ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank) {
}

ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional<MemoryConfig>& memory_config) {
auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
if(mem_config.is_sharded () and (device->arch() == tt::ARCH::BLACKHOLE)) {
auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG);
return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt);
}
else {
return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
}
return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
}

ttnn::Tensor to_device(
const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional<MemoryConfig>& memory_config) {

auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
// Currently no direct sharded write support in BLACKHOLE due to alignment issue
if(mem_config.is_sharded () and (mesh_device->arch() == tt::ARCH::BLACKHOLE)) {
auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG);
return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt);
}
else {
return tensor.to(mesh_device, mem_config);
}

return tensor.to(mesh_device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
}

ttnn::Tensor allocate_tensor_on_device(
@@ -105,19 +86,7 @@ void copy_host_to_device_tensor(ttnn::Tensor host_tensor, ttnn::Tensor device_te
tt::tt_metal::write_tensor(host_tensor, device_tensor, cq_id);
}

ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) {

// Currently no direct sharded read support in BLACKHOLE due to alignment issue
if(tensor.is_sharded () and (tensor.device()->arch() == tt::ARCH::BLACKHOLE)) {
auto interleaved_tensor = ttnn::sharded_to_interleaved(cq_id, tensor, ttnn::DRAM_MEMORY_CONFIG, std::nullopt);
return interleaved_tensor.cpu(blocking, cq_id);
}
else {
return tensor.cpu(blocking, cq_id);

}

}
ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) { return tensor.cpu(blocking, cq_id); }

void deallocate(Tensor& tensor, bool force) { tensor.deallocate(force); }

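For context, the logic removed from core.cpp above staged Blackhole sharded transfers through DRAM interleaved memory instead of writing or reading sharded buffers directly. Below is a condensed sketch of that removed pattern, reconstructed only from the deleted lines in this file; the *_with_bh_workaround wrapper names are illustrative, not part of the codebase, and the snippet is not a verified build.

// Sketch of the reverted Blackhole workaround, copied from the removed diff lines above.
ttnn::Tensor to_device_with_bh_workaround(
    const ttnn::Tensor& tensor, Device* device, const std::optional<MemoryConfig>& memory_config) {
    auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
    if (mem_config.is_sharded() and device->arch() == tt::ARCH::BLACKHOLE) {
        // Stage the write through DRAM interleaved memory, then reshard on device.
        auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG);
        return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt);
    }
    return tensor.to(device, mem_config);
}

ttnn::Tensor from_device_with_bh_workaround(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) {
    if (tensor.is_sharded() and tensor.device()->arch() == tt::ARCH::BLACKHOLE) {
        // Read back via an interleaved copy rather than a direct sharded read.
        auto interleaved_tensor = ttnn::sharded_to_interleaved(cq_id, tensor, ttnn::DRAM_MEMORY_CONFIG, std::nullopt);
        return interleaved_tensor.cpu(blocking, cq_id);
    }
    return tensor.cpu(blocking, cq_id);
}

After the revert, both paths reduce to the direct tensor.to(device, ...) and tensor.cpu(blocking, cq_id) calls shown in the retained lines.
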
20 changes: 0 additions & 20 deletions ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp

This file was deleted.

@@ -5,12 +5,6 @@
#include <stdint.h>
#include "dataflow_api.h"

//#define DEBUG

#ifdef DEBUG
#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp"
#endif

void kernel_main() {

const uint32_t src_addr = get_arg_val<uint32_t>(0);
@@ -44,36 +38,26 @@ void kernel_main() {
uint32_t stick_id = start_id;
cb_reserve_back(cb_id_in0, block_height);
uint32_t l1_write_addr = get_write_ptr(cb_id_in0);
uint32_t l1_write_addr_base = l1_write_addr;
if (aligned) {
for (uint32_t h = 0; h < block_height; ++h) {
uint64_t src_noc_addr = get_noc_addr(stick_id, s0);
noc_async_read(src_noc_addr, l1_write_addr, block_width_bytes);
stick_id++;
#ifdef DEBUG
noc_async_read_barrier();
tt::data_movement::common::print_pages(l1_write_addr, block_width_bytes >> 1, 1);
#endif
l1_write_addr += padded_block_width_bytes;
}
} else {
cb_reserve_back(cb_id_in1, 4);
cb_reserve_back(cb_id_in1, 1);
uint32_t scratch_l1_write_addr = get_write_ptr(cb_id_in1);
uint64_t scratch_l1_noc_read_addr = get_noc_addr(scratch_l1_write_addr + aligned_offset);
for (uint32_t h = 0; h < block_height; ++h) {
uint64_t src_noc_addr = get_noc_addr(stick_id, s0);
noc_async_read(src_noc_addr, scratch_l1_write_addr, aligned_block_width_bytes);
noc_async_read_barrier();
noc_async_read(scratch_l1_noc_read_addr, l1_write_addr, block_width_bytes);
#ifdef DEBUG
noc_async_read_barrier();
tt::data_movement::common::print_pages(l1_write_addr, block_width_bytes >> 1, 1);
#endif
stick_id++;
l1_write_addr += padded_block_width_bytes;
}
}

noc_async_read_barrier();
cb_push_back(cb_id_in0, block_height);
}
@@ -5,12 +5,6 @@
#include <stdint.h>
#include "dataflow_api.h"

//#define DEBUG

#ifdef DEBUG
#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp"
#endif

void kernel_main() {

const uint32_t dst_addr = get_arg_val<uint32_t>(0);
@@ -40,15 +34,9 @@ void kernel_main() {
uint32_t stick_id = start_id;
cb_wait_front(cb_id_out0, block_height);
uint32_t l1_read_addr = get_read_ptr(cb_id_out0);


for (uint32_t h = 0; h < block_height; ++h) {
uint64_t dst_noc_addr = get_noc_addr(stick_id, s0);
noc_async_write(l1_read_addr, dst_noc_addr, block_width_bytes);
#ifdef DEBUG
noc_async_read_barrier();
tt::data_movement::common::print_pages(l1_read_addr, block_width_bytes >> 1, 1);
#endif
stick_id++;
l1_read_addr += padded_block_width_bytes;
noc_async_write_barrier();
@@ -36,19 +36,8 @@ std::vector<tt::tt_metal::LegacyShape> InterleavedToShardedDeviceOperation::comp

std::vector<Tensor> InterleavedToShardedDeviceOperation::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
//return operation::generic_create_output_tensors(
// *this, input_tensors, this->output_dtype, input_tensor.get_layout(), this->output_mem_config);


auto mem_config = this->output_mem_config;

return {create_device_tensor(
this->compute_output_shapes(input_tensors).at(0),
input_tensor.get_dtype(),
input_tensor.get_layout(),
input_tensor.device(),
mem_config
)};
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, input_tensor.get_layout(), this->output_mem_config);
}

operation::ProgramWithCallbacks InterleavedToShardedDeviceOperation::create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const {
