Commit 00f0bfd

#0: cleanup mm code, fix trid start from 1, rm in1 padding

yugaoTT committed May 30, 2024
1 parent 448f451 commit 00f0bfd
Showing 13 changed files with 62 additions and 217 deletions.
(changed file, path not shown)
@@ -112,10 +112,8 @@ def test_matmul_in1_dram_sharded(

in0_shape = [1, 1, M, K]
in1_shape = [1, 1, K, N]
in1_shape_padded = [1, 1, K, N_padded]
in1_shard_shape = [K, N_padded // num_banks]
bias_shape = [1, 1, 1, N]
bias_shape_padded = [1, 1, 32, N_padded]
bias_shape = [1, 1, 32, N]
bias_shard_shape = [32, N_padded // num_banks]
num_cores = grid_size[0] * grid_size[1]

@@ -152,19 +150,11 @@ def test_matmul_in1_dram_sharded(
logger.debug("in1_shard_grid " + str(in1_shard_grid))

in0 = torch.randn(in0_shape).bfloat16().float()
# step = K // num_cores
# in0 = torch.ones(in0_shape).bfloat16().float()
# for i in range(num_cores): # since 32768 / 16 = 2048
# in0[:, :, :, i * step : (i + 1) * step] = i + 1
in1 = torch.randn(in1_shape).bfloat16().float()
# in1 = torch.ones(in1_shape).bfloat16().float()
bias = torch.randn(bias_shape).bfloat16().float()
# bias = torch.ones(bias_shape).bfloat16().float() * 10

in0_t = torch2tt_tensor(in0, device, tt_memory_config=interleaved_mem_config, tt_dtype=in0_dtype)
in1_t = ttl.tensor.Tensor(in1.flatten().tolist(), in1_shape, in1_dtype, ttl.tensor.Layout.ROW_MAJOR)
in1_t = in1_t.pad(in1_shape_padded, (0, 0, 0, 0), 0).to(ttl.tensor.Layout.TILE)
in1_t = in1_t.to(device, in1_mem_config)
in1_t = torch2tt_tensor(in1, device, tt_memory_config=in1_mem_config, tt_dtype=in1_dtype)

if has_bias:
bias_shard_grid = ttl.tensor.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1)
@@ -175,12 +165,7 @@ def test_matmul_in1_dram_sharded(
bias_mem_config = ttl.tensor.MemoryConfig(
ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.DRAM, bias_shard_spec
)

bias_t = ttl.tensor.Tensor(
bias.flatten().tolist(), bias_shape, ttl.tensor.DataType.BFLOAT16, ttl.tensor.Layout.ROW_MAJOR
)
bias_t = bias_t.pad(bias_shape_padded, (0, 0, 0, 0), 0).to(ttl.tensor.Layout.TILE)
bias_t = bias_t.to(device, bias_mem_config)
bias_t = torch2tt_tensor(bias, device, tt_memory_config=bias_mem_config, tt_dtype=ttl.tensor.DataType.BFLOAT16)

in0_t = ttl.tensor.interleaved_to_sharded(
in0_t,
@@ -198,9 +183,6 @@ def test_matmul_in1_dram_sharded(
per_core_N=out_block_w,
fuse_batch=True,
fused_activation=None,
skip_compute=False,
skip_in0_mcast=False,
skip_write_back=False,
)

if is_grayskull():
@@ -217,7 +199,7 @@ def test_matmul_in1_dram_sharded(
)

if has_bias:
output_t = ttl.operations.primary.matmul_dram_sharded(
output_t = ttl.operations.primary.matmul(
in0_t,
in1_t,
bias=bias_t,
@@ -227,7 +209,7 @@ def test_matmul_in1_dram_sharded(
compute_kernel_config=compute_kernel_config,
)
else:
output_t = ttl.operations.primary.matmul_dram_sharded(
output_t = ttl.operations.primary.matmul(
in0_t,
in1_t,
program_config=program_config,
@@ -243,11 +225,6 @@ def test_matmul_in1_dram_sharded(

tt_out = tt2torch_tensor(output_t)

print(pt_out)
print(tt_out)

pt_out_unpad = pt_out[:, :, :, 0:N]
tt_out_unpad = tt_out[:, :, :, 0:N]
passing, output = comp_pcc(pt_out_unpad, tt_out_unpad)
passing, output = comp_pcc(pt_out, tt_out)
logger.info(output)
assert passing
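For reference, a condensed sketch (not part of the commit) of the in1/bias flow this file ends up with: the padded shapes and the host-side pad/tilize round trip are gone, both operands go straight through torch2tt_tensor into DRAM width-sharded memory configs, and the generic ttl.operations.primary.matmul dispatches on the DRAM-sharded program config. The wrapper function, its parameter names, and the torch2tt_tensor import path are assumptions for illustration; shard specs, program config, and compute kernel config are assumed to be built as in the surrounding test.

```python
import torch
import tt_lib as ttl
from models.utility_functions import torch2tt_tensor  # import path assumed

def run_dram_sharded_matmul(  # hypothetical wrapper, mirrors the test body above
    device, in0_t, in1_shape, bias_shape, in1_dtype,
    in1_shard_spec, bias_mem_config, program_config,
    out_mem_config, out_dtype, compute_kernel_config,
):
    in1 = torch.randn(in1_shape).bfloat16().float()
    bias = torch.randn(bias_shape).bfloat16().float()

    # in1 and bias now go straight to DRAM width-sharded device tensors;
    # no host-side padding of N to N_padded first.
    in1_mem_config = ttl.tensor.MemoryConfig(
        ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.DRAM, in1_shard_spec
    )
    in1_t = torch2tt_tensor(in1, device, tt_memory_config=in1_mem_config, tt_dtype=in1_dtype)
    bias_t = torch2tt_tensor(
        bias, device, tt_memory_config=bias_mem_config, tt_dtype=ttl.tensor.DataType.BFLOAT16
    )

    # matmul_dram_sharded is removed; the generic primary matmul takes the
    # DRAM-sharded program config instead (keyword names for the output
    # memory config and dtype are assumed, they are not visible in this hunk).
    return ttl.operations.primary.matmul(
        in0_t,
        in1_t,
        bias=bias_t,
        program_config=program_config,
        output_mem_config=out_mem_config,
        output_dtype=out_dtype,
        compute_kernel_config=compute_kernel_config,
    )
```

Since N is no longer padded, the torch and device outputs are compared directly, without the former [:, :, :, 0:N] unpad slice.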
34 changes: 9 additions & 25 deletions tests/ttnn/unit_tests/operations/test_experimental.py
@@ -172,8 +172,7 @@ def test_ttnn_experimental_operations_primary_matmul_1d(
@pytest.mark.parametrize("m_size", [32])
@pytest.mark.parametrize("k_size", [8192])
@pytest.mark.parametrize("n_size", [1024])
@pytest.mark.parametrize("n_padded_size", [1152])
def test_ttnn_experimental_operations_primary_matmul_dram_sharded(device, m_size, k_size, n_size, n_padded_size):
def test_ttnn_experimental_operations_primary_matmul_dram_sharded(device, m_size, k_size, n_size):
torch.manual_seed(0)

grid_size = ttnn.CoreGrid(y=1, x=8)
@@ -200,21 +199,6 @@ def test_ttnn_experimental_operations_primary_matmul_dram_sharded(device, m_size
)
input_tensor_in0 = ttnn.to_memory_config(input_tensor_in0, sharded_mem_config)

# in1 ttnn tensor, for now need to bring tensor to device to pad first, then bring back, and send sharded tensor to dram!
input_tensor_in1 = ttnn.from_torch(torch_input_tensor_in1, layout=ttnn.TILE_LAYOUT)
input_tensor_in1 = ttnn.to_device(input_tensor_in1, device)
input_tensor_in1 = ttnn.pad(input_tensor_in1, ((0, 0), (0, 0), (0, 0), (0, n_padded_size - n_size)), 0)
input_tensor_in1 = ttnn.from_device(input_tensor_in1)
input_tensor_in1 = ttnn.to_torch(input_tensor_in1)

# in1 host padding cause seg faults!
# input_tensor_in1 = ttnn.Tensor(torch_input_tensor_in1.flatten().tolist(), in1_shape, ttnn.bfloat8_b, ttnn.ROW_MAJOR_LAYOUT)
# input_tensor_in1 = input_tensor_in1.pad([1, 1, k_size, n_padded_size], (0, 0, 0, 0), 0)
# in1_shape = (1, 1, m_size, k_size)
# in1 = torch.randn(in1_shape).bfloat16().float()
# in1_t = ttnn.experimental.tensor.Tensor(in1.flatten().tolist(), in1_shape, ttnn.experimental.tensor.DataType.BFLOAT16, ttnn.experimental.tensor.Layout.ROW_MAJOR)
# in1_t = in1_t.pad([1, 1, k_size, n_padded_size], (0, 0, 0, 0), 0).to(ttnn.experimental.tensor.Layout.TILE)

# in1 shard config
in1_shard_grid = ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1)
in1_shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), in1_shard_grid)})
@@ -224,7 +208,11 @@ def test_ttnn_experimental_operations_primary_matmul_dram_sharded(device, m_size
ttnn.types.TensorMemoryLayout.WIDTH_SHARDED, ttnn.types.BufferType.DRAM, in1_shard_spec
)
input_tensor_in1 = ttnn.from_torch(
input_tensor_in1, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat8_b, memory_config=in1_mem_config
torch_input_tensor_in1,
layout=ttnn.TILE_LAYOUT,
device=device,
dtype=ttnn.bfloat8_b,
memory_config=in1_mem_config,
)

program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
@@ -235,9 +223,6 @@ def test_ttnn_experimental_operations_primary_matmul_dram_sharded(device, m_size
per_core_N=4,
fuse_batch=True,
fused_activation=None,
skip_compute=False,
skip_in0_mcast=False,
skip_write_back=False,
)

compute_kernel_config = ttnn.WormholeComputeKernelConfig(
@@ -247,19 +232,18 @@ def test_ttnn_experimental_operations_primary_matmul_dram_sharded(device, m_size
packer_l1_acc=True,
)

output_tensor = ttnn.experimental.operations.primary.matmul_dram_sharded(
output_tensor = ttnn.matmul(
input_tensor_in0,
input_tensor_in1,
program_config=program_config,
output_mem_config=sharded_mem_config,
output_dtype=ttnn.bfloat16,
memory_config=sharded_mem_config,
dtype=ttnn.bfloat16,
compute_kernel_config=compute_kernel_config,
)

output_tensor = ttnn.to_memory_config(output_tensor, ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)
output_tensor = output_tensor[:, :, :, 0:n_size]
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.9999)
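The ttnn-level test follows the same pattern. As a hedged consolidation of the new state of this test (the wrapper function and the assert_with_pcc import path are assumptions; everything else mirrors the added lines above): in1 goes straight from torch to a DRAM width-sharded device tensor, ttnn.matmul with memory_config/dtype replaces the removed experimental matmul_dram_sharded, and the output no longer needs the [:, :, :, 0:n_size] unpad slice.

```python
import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc  # import path assumed

def run_ttnn_dram_sharded_matmul(  # hypothetical wrapper around the test body above
    device, torch_input_tensor_in1, torch_output_tensor, input_tensor_in0,
    in1_mem_config, sharded_mem_config, program_config, compute_kernel_config,
):
    # in1: single hop from torch to a DRAM WIDTH_SHARDED tensor, no pad round trip
    input_tensor_in1 = ttnn.from_torch(
        torch_input_tensor_in1,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        dtype=ttnn.bfloat8_b,
        memory_config=in1_mem_config,
    )

    # generic ttnn.matmul dispatches on the DRAM-sharded program config
    output_tensor = ttnn.matmul(
        input_tensor_in0,
        input_tensor_in1,
        program_config=program_config,
        memory_config=sharded_mem_config,
        dtype=ttnn.bfloat16,
        compute_kernel_config=compute_kernel_config,
    )

    output_tensor = ttnn.to_memory_config(output_tensor, ttnn.L1_MEMORY_CONFIG)
    output_tensor = ttnn.to_torch(ttnn.from_device(output_tensor))
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.9999)
```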


66 changes: 5 additions & 61 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
@@ -1039,7 +1039,6 @@ void Matmul::validate(
TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT));
TT_FATAL(K % program_config.in0_block_w == 0);
TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0);
// TT_FATAL(div_up(N, per_core_N) == input_tensor_a.shard_spec().value().grid.num_cores());

// subbblock constraint
TT_FATAL(program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1);
@@ -1195,22 +1194,6 @@ std::vector<Shape> Matmul::compute_output_shapes(const std::vector<Tensor>& inpu
return {Shape(output_shape, padding)};
}

std::vector<Shape> Matmul::compute_output_shapes_dram_sharded(
const std::vector<Tensor>& input_tensors, uint32_t N_unpadded) const {
const auto input_shape_a = input_tensors.at(0).get_legacy_shape();
const auto input_shape_b = input_tensors.at(1).get_legacy_shape();

auto output_shape = input_shape_a;
output_shape[-1] = N_unpadded;
auto dimensions_pads = std::vector<Padding::PadDimension>();
for (auto index = 0; index < input_shape_a.rank() - 1; index++) {
dimensions_pads.push_back(input_shape_a.padding()[index]);
}
dimensions_pads.push_back(input_shape_b.padding()[input_shape_b.rank() - 1]);
const auto padding = Padding(dimensions_pads, Padding::PadValue::Any);
return {Shape(output_shape, padding)};
};

std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor_a = input_tensors.at(0);
const auto& input_tensor_b = input_tensors.at(1);
@@ -1250,12 +1233,9 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>) {
uint32_t M =
(program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1]
: input_tensor_a.get_legacy_shape()[-2]) /
TILE_HEIGHT;
: input_tensor_a.get_legacy_shape()[-2]) / TILE_HEIGHT;
uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH;
auto input_tensor_b_shape = input_tensor_b.get_legacy_shape();
uint32_t N_unpaddded = input_tensor_b.get_legacy_shape()[-1] -
input_tensor_b_shape.padding()[input_tensor_b_shape.rank() - 1].back;

uint32_t per_core_M = program_config.per_core_M;
uint32_t per_core_N = program_config.per_core_N;
@@ -1266,7 +1246,7 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
return {create_device_tensor(
this->compute_output_shapes_dram_sharded(input_tensors, N_unpaddded).at(0),
this->compute_output_shapes(input_tensors).at(0),
this->output_dtype,
output_layout,
input_tensor_a.device(),
@@ -1454,9 +1434,9 @@ operation::ProgramWithCallbacks Matmul::create_program(
program_config.fuse_batch,
program_config.fused_activation,
this->untilize_out,
program_config.skip_compute,
program_config.skip_in0_mcast,
program_config.skip_write_back);
false,
false,
false);
} else if constexpr (std::is_same_v<ProgramConfigType, MatmulMultiCoreNonOptimizedReuseProgramConfig>) {
TT_FATAL(!bias.has_value(), "Bias is not supported for matmul multi core non-optimized reuse");
return matmul_multi_core_reuse(input_tensor_a, input_tensor_b, output_tensor, broadcast_batch);
@@ -1517,42 +1497,6 @@ Tensor matmul_1d(
return output_tensors.at(0);
}

Tensor matmul_dram_sharded(
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
std::optional<const Tensor> bias,
std::optional<MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig> program_config,
const MemoryConfig& mem_config,
std::optional<const DataType> output_dtype,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config,
bool untilize_out) {
std::vector<Tensor> output_tensors = {
Tensor(operation::get_workers_for_op_output({input_tensor_a, input_tensor_b}, {bias}))};
operation::launch_op(
[program_config, mem_config, output_dtype, compute_kernel_config, untilize_out](
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
const auto& input_tensor_a = input_tensors.at(0);
const auto& input_tensor_b = input_tensors.at(1);

auto kernel_config_val =
init_device_compute_kernel_config(input_tensor_a.device()->arch(), compute_kernel_config);
return {operations::primary::matmul(
input_tensor_a,
input_tensor_b,
optional_input_tensors.at(0),
program_config.value(),
mem_config,
output_dtype,
kernel_config_val,
untilize_out)};
},
{input_tensor_a, input_tensor_b},
output_tensors,
{bias});
return output_tensors.at(0);
}

operation::OpPerformanceModel Matmul::create_op_performance_model(
const std::vector<Tensor>& input_tensors,
14 changes: 2 additions & 12 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp
@@ -224,9 +224,6 @@ struct MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig {
std::size_t per_core_N;
bool fuse_batch;
std::optional<UnaryWithParam> fused_activation;
bool skip_compute;
bool skip_in0_mcast;
bool skip_write_back;

static constexpr auto attribute_names = std::make_tuple(
"in0_block_w",
@@ -235,10 +232,7 @@ struct MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig {
"per_core_M",
"per_core_N",
"fuse_batch",
"fused_activation",
"skip_compute",
"skip_in0_mcast",
"skip_write_back");
"fused_activation");
const auto attribute_values() const {
return std::make_tuple(
std::cref(this->in0_block_w),
@@ -247,10 +241,7 @@ struct MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig {
std::cref(this->per_core_M),
std::cref(this->per_core_N),
std::cref(this->fuse_batch),
std::cref(this->fused_activation),
std::cref(this->skip_compute),
std::cref(this->skip_in0_mcast),
std::cref(this->skip_write_back));
std::cref(this->fused_activation));
}
};

@@ -401,7 +392,6 @@ inline Tensor matmul(
}

Tensor matmul_1d(const Tensor &input_tensor_a, const Tensor &input_tensor_b, std::optional<const Tensor> bias, std::optional<MatmulMultiCoreReuseMultiCast1DProgramConfig> program_config = std::nullopt, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional<const DataType> output_dtype=std::nullopt, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt, bool untilize_out = false);
Tensor matmul_dram_sharded(const Tensor &input_tensor_a, const Tensor &input_tensor_b, std::optional<const Tensor> bias, std::optional<MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig> program_config = std::nullopt, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional<const DataType> output_dtype=std::nullopt, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt, bool untilize_out = false);

MatmulProgramConfig generate_matmul_program_config(const Tensor &input_tensor_a, const Tensor &input_tensor_b, const MemoryConfig &mem_config, const std::optional<const DeviceComputeKernelConfig> compute_kernel_config, const std::optional<const CoreCoord> user_core_coord, const std::optional<const UnaryWithParam> user_fused_activation, const std::optional<const bool> user_run_batched);

(changed file, path not shown)
@@ -15,8 +15,6 @@

#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"

// #include "debug/dprint.h"


namespace NAMESPACE {

(changed file, path not shown)
@@ -6,7 +6,6 @@
#include "dataflow_api.h"
#include "hostdevcommon/common_values.hpp"

// #include "debug/dprint.h"

void kernel_main() {
// COMPILE TIME ARGS
(changed file, path not shown)
@@ -6,7 +6,6 @@
#include "dataflow_api.h"
#include "hostdevcommon/common_values.hpp"

// #include "debug/dprint.h"

void kernel_main() {
// COMPILE TIME ARGS
(changed file, path not shown)
@@ -6,8 +6,6 @@
#include "dataflow_api.h"
#include "hostdevcommon/common_values.hpp"

// #include "debug/dprint.h"


void kernel_main() {
// RUNTIME ARGS
@@ -76,9 +74,10 @@ void kernel_main() {
}
#else
constexpr uint32_t total_num_blocks_in_buffer = 3;
constexpr uint32_t total_num_trid = 4;
uint32_t num_free_blocks_in_buffer = total_num_blocks_in_buffer;
uint32_t curr_block_trid = 0;
uint32_t block_trid_to_wait = 0;
uint32_t curr_block_trid = 1;
uint32_t block_trid_to_wait = 1;

cb_reserve_back(cb_id_in1, in1_block_num_tiles);
uint32_t l1_write_addr_in1_offset = 0;
Expand All @@ -88,7 +87,7 @@ void kernel_main() {
noc_async_read_tile_dram_sharded_set_trid(curr_block_trid);

for(uint32_t h = 0; h < in1_num_pages; ++h) {
noc_async_read_tile_dram_sharded_with_state_with_trid(in1_base_addr, l1_read_addr_in1, l1_write_addr_in1);
noc_async_read_tile_dram_sharded_with_state_with_trid(in1_base_addr, l1_read_addr_in1, l1_write_addr_in1, curr_block_trid);
l1_read_addr_in1 += in1_page_size;
l1_write_addr_in1 += in1_page_size;
}
Expand All @@ -97,16 +96,16 @@ void kernel_main() {
noc_async_read_barrier_with_trid(block_trid_to_wait);
cb_push_back(cb_id_in1, in1_block_num_tiles);
// wait for next block trid
block_trid_to_wait = (block_trid_to_wait + 1) % total_num_blocks_in_buffer;
block_trid_to_wait = block_trid_to_wait == 3 ? 1 : (block_trid_to_wait + 1);
// reserve for next block
cb_reserve_back(cb_id_in1, in1_block_num_tiles * 2);
} else {
num_free_blocks_in_buffer -= 1;
}

if (curr_block_trid == total_num_blocks_in_buffer - 1) {
if (curr_block_trid == total_num_blocks_in_buffer) {
l1_write_addr_in1_offset = 0;
curr_block_trid = 0;
curr_block_trid = 1;
} else {
l1_write_addr_in1_offset += in1_block_size_bytes;
curr_block_trid += 1;
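To make the transaction-ID change concrete, here is a small host-side Python illustration (not device code) of the rotation this reader kernel now uses: block trids cycle 1 → 2 → 3 → 1 instead of 0 → 1 → 2 → 0, and once the 3-block buffer is down to its last free slot, each new block first waits on the oldest outstanding trid. This is a plain simulation of the control flow visible in this hunk, not the kernel itself.

```python
TOTAL_NUM_BLOCKS_IN_BUFFER = 3  # matches the kernel constant

def simulate_trid_rotation(num_blocks: int) -> list[str]:
    """Trace the issue/wait pattern of the DRAM-sharded in1 reader."""
    num_free_blocks_in_buffer = TOTAL_NUM_BLOCKS_IN_BUFFER
    curr_block_trid = 1      # was 0 before this commit
    block_trid_to_wait = 1   # was 0 before this commit
    trace = []
    for block in range(num_blocks):
        trace.append(f"issue block {block} with trid {curr_block_trid}")
        if num_free_blocks_in_buffer == 1:
            # buffer full: wait on the oldest outstanding block before reusing its slot
            trace.append(f"wait on trid {block_trid_to_wait}")
            block_trid_to_wait = 1 if block_trid_to_wait == 3 else block_trid_to_wait + 1
        else:
            num_free_blocks_in_buffer -= 1
        # advance the issuing trid over 1..3, never touching trid 0
        if curr_block_trid == TOTAL_NUM_BLOCKS_IN_BUFFER:
            curr_block_trid = 1
        else:
            curr_block_trid += 1
    return trace

if __name__ == "__main__":
    for line in simulate_trid_rotation(6):
        print(line)
```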
(additional changed files not shown)
