#4904: Add support for 1d width sharded LN
Refactored the code that generates bcast and reduce scaler tiles into common header files.
tt-aho committed Jan 25, 2024
1 parent d34df49 commit 66d2a9b
Showing 30 changed files with 1,873 additions and 1,619 deletions.
243 changes: 211 additions & 32 deletions tests/tt_eager/python_api_testing/unit_testing/test_layernorm_sharded.py
@@ -36,19 +36,8 @@
         ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
     ),
     ids=[
-        "in0_DRAM",
-        "in0_L1",
-    ],
-)
-@pytest.mark.parametrize(
-    "in0_mem_config",
-    (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
-    ),
-    ids=[
-        "in0_DRAM",
-        "in0_L1",
+        "gb_DRAM",
+        "gb_L1",
     ],
 )
 @pytest.mark.parametrize(
@@ -78,10 +67,9 @@
         "LN_GB",
     ],
 )
-def test_layernorm_sharded_rm(
-    test_id, in_dtype, out_dtype, cb_dtype, in0_mem_config, gamma_beta_mem_config, out_mem_config, device
-):
+def test_layernorm_sharded_rm(test_id, in_dtype, out_dtype, cb_dtype, gamma_beta_mem_config, out_mem_config, device):
     torch.manual_seed(1234)
+    in0_mem_config = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM)
 
     grid_size = (12, 8)
     fidelity = ttl.tensor.MathFidelity.HiFi4
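The in0_mem_config parametrization is dropped here and in the test below: the input is now always created interleaved in DRAM, presumably because the tests immediately reshard it on device, so the original buffer placement no longer affects the op under test.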
@@ -147,7 +135,7 @@ def test_layernorm_sharded_rm(
         math_fidelity=fidelity,
         im_data_format=cb_dtype,
         out_data_format=out_dtype,
-        inplace=True,
+        inplace=in_dtype == out_dtype,
     )
 
     if test_id == 0:
@@ -205,19 +193,8 @@ def test_layernorm_sharded_rm(
         ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
     ),
     ids=[
-        "in0_DRAM",
-        "in0_L1",
-    ],
-)
-@pytest.mark.parametrize(
-    "in0_mem_config",
-    (
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
-    ),
-    ids=[
-        "in0_DRAM",
-        "in0_L1",
+        "gb_DRAM",
+        "gb_L1",
     ],
 )
 @pytest.mark.parametrize(
@@ -251,9 +228,10 @@ def test_layernorm_sharded_rm(
     ],
 )
 def test_layernorm_sharded_mix_precision_rm(
-    test_id, in_dtype, out_dtype, cb_dtype, in0_mem_config, gamma_beta_mem_config, out_mem_config, device
+    test_id, in_dtype, out_dtype, cb_dtype, gamma_beta_mem_config, out_mem_config, device
 ):
     torch.manual_seed(1234)
+    in0_mem_config = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM)
 
     grid_size = (12, 8)
     fidelity = ttl.tensor.MathFidelity.HiFi4
@@ -319,7 +297,7 @@ def test_layernorm_sharded_mix_precision_rm(
         math_fidelity=fidelity,
         im_data_format=cb_dtype,
         out_data_format=out_dtype,
-        inplace=True,
+        inplace=in_dtype == out_dtype,
     )
 
     if test_id == 0:
@@ -365,3 +343,204 @@ def test_layernorm_sharded_mix_precision_rm(
    passing, output = comp_pcc(tt_got_back, ref_lnorm, 0.999)
    logger.info(output)
    assert passing


@skip_for_wormhole_b0()
@pytest.mark.parametrize(
    "shard_orientation",
    (ttl.tensor.ShardOrientation.ROW_MAJOR, ttl.tensor.ShardOrientation.COL_MAJOR),
    ids=["RM", "CM"],
)
@pytest.mark.parametrize(
    "out_mem_config",
    (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, ttl.tensor.BufferType.L1),),
    ids=["out_L1"],
)
@pytest.mark.parametrize(
    "gamma_beta_mem_config",
    (
        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
    ),
    ids=[
        "gb_DRAM",
        "gb_L1",
    ],
)
@pytest.mark.parametrize(
    "out_dtype",
    (
        ttl.tensor.DataType.BFLOAT16,
        ttl.tensor.DataType.BFLOAT8_B,
    ),
    ids=["BFLOAT16", "BFLOAT8_B"],
)
@pytest.mark.parametrize(
    "cb_dtype",
    (ttl.tensor.DataType.BFLOAT16,),
    ids=["BFLOAT16"],
)
@pytest.mark.parametrize(
    "in_dtype",
    (
        ttl.tensor.DataType.BFLOAT16,
        ttl.tensor.DataType.BFLOAT8_B,
    ),
    ids=["BFLOAT16", "BFLOAT8_B"],
)
@pytest.mark.parametrize(
    "test_id",
    (0, 1, 2, 3, 4, 5),
    ids=[
        "add_LN",
        "add_LN_G",
        "add_LN_GB",
        "LN",
        "LN_G",
        "LN_GB",
    ],
)
def test_layernorm_1d_sharded_mix_precision_rm(
    test_id, in_dtype, out_dtype, cb_dtype, gamma_beta_mem_config, out_mem_config, shard_orientation, device
):
    torch.manual_seed(1234)
    in0_mem_config = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM)

    grid_size = (8, 8)
    fidelity = ttl.tensor.MathFidelity.HiFi4

    epsf = 1e-2

    in0_shape = torch.Size([1, 1, 32, 8192])
    M = in0_shape.numel() // in0_shape[3]
    K = in0_shape[3]

    in0 = torch.rand(in0_shape) * 2 - 0.95
    in0_t = torch2tt_tensor(in0, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype)
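    # Width is split across all 64 cores of the 8x8 grid: each core's shard is
    # [M, K / 64] = [32, 8192 / 64] = [32, 128], i.e. one tile high by four tiles wide.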
    in0_t_shard = ttl.tensor.interleaved_to_sharded(
        in0_t,
        grid_size,
        [M, K // (grid_size[0] * grid_size[1])],
        ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED,
        shard_orientation,
    )

    if test_id <= 2:
        in1 = torch.rand(in0_shape) * 2 - 0.8
        in1_t = torch2tt_tensor(in1, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype)
        in1_t_shard = ttl.tensor.interleaved_to_sharded(
            in1_t,
            grid_size,
            [M, K // (grid_size[0] * grid_size[1])],
            ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED,
            shard_orientation,
        )

    if test_id == 0 or test_id == 3:
        gamma = torch.ones(in0_shape[3])
        beta = torch.zeros(in0_shape[3])
    if test_id == 1 or test_id == 4:
        gamma = torch.rand(in0_shape[3]) * 2 - 1
        beta = torch.zeros(in0_shape[3])
    if test_id == 2 or test_id == 5:
        gamma = torch.rand(in0_shape[3]) * 2 - 1
        beta = torch.rand(in0_shape[3]) * 2.0 - 1.1

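    # Gamma/beta are reshaped into 32-wide rows (one row per tile of width) so a
    # row-major reader can fetch one tile's worth of values per row.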
    gamma = gamma.reshape(1, 1, -1, 32)
    gamma_t = ttl.tensor.Tensor(
        gamma.reshape(-1).tolist(),
        gamma.shape,
        cb_dtype,
        ttl.tensor.Layout.ROW_MAJOR,
    ).to(device, gamma_beta_mem_config)

    beta = beta.reshape(1, 1, -1, 32)
    beta_t = ttl.tensor.Tensor(
        beta.reshape(-1).tolist(),
        beta.shape,
        cb_dtype,
        ttl.tensor.Layout.ROW_MAJOR,
    ).to(device, gamma_beta_mem_config)

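    # Per-core blocking: block_w = 8192 / 64 / 32 = 4 tiles of width, block_h = 32 / 32 = 1 tile;
    # subblock_w must divide block_w evenly, so subblock_w=4 covers the shard width in one subblock.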
    program_config = ttl.operations.primary.LayerNormShardedMultiCoreProgramConfig(
        compute_with_storage_grid_size=grid_size,
        subblock_w=4,
        block_h=M // 32,
        block_w=K // (grid_size[0] * grid_size[1]) // 32,
        math_fidelity=fidelity,
        im_data_format=cb_dtype,
        out_data_format=out_dtype,
        inplace=in_dtype == out_dtype,
    )

    if test_id == 0:
        logger.info("Running add_LN")
        ttz = ttl.operations.primary.add_layernorm(
            in0_t_shard,
            in1_t_shard,
            epsf,
            output_mem_config=out_mem_config,
            program_config=program_config,
        )
    if test_id == 1:
        logger.info("Running add_LN_G")
        ttz = ttl.operations.primary.add_layernorm(
            in0_t_shard,
            in1_t_shard,
            epsf,
            gamma_t,
            output_mem_config=out_mem_config,
            program_config=program_config,
        )
    if test_id == 2:
        logger.info("Running add_LN_GB")
        ttz = ttl.operations.primary.add_layernorm(
            in0_t_shard,
            in1_t_shard,
            epsf,
            gamma_t,
            beta_t,
            output_mem_config=out_mem_config,
            program_config=program_config,
        )
    if test_id == 3:
        logger.info("Running LN")
        ttz = ttl.operations.primary.layernorm(
            in0_t_shard,
            epsf,
            output_mem_config=out_mem_config,
            program_config=program_config,
        )
    if test_id == 4:
        logger.info("Running LN_G")
        ttz = ttl.operations.primary.layernorm(
            in0_t_shard,
            epsf,
            gamma_t,
            output_mem_config=out_mem_config,
            program_config=program_config,
        )
    if test_id == 5:
        logger.info("Running LN_GB")
        ttz = ttl.operations.primary.layernorm(
            in0_t_shard,
            epsf,
            gamma_t,
            beta_t,
            output_mem_config=out_mem_config,
            program_config=program_config,
        )

    logger.info("Done")

    ttz = ttl.tensor.sharded_to_interleaved(ttz, in0_mem_config)
    t2_data = ttz.cpu().to_torch().float()
    tt_got_back = torch.Tensor(t2_data).reshape(in0_shape)
    tt_got_back = untilize(tt_got_back)

    pt_in = in0 + in1 if test_id <= 2 else in0
    ref_lnorm = torch.nn.functional.layer_norm(pt_in, in0.shape[-1:], gamma.flatten(), beta.flatten(), epsf)

    passing, output = comp_pcc(tt_got_back, ref_lnorm, 0.999)
    logger.info(output)
    assert passing
44 changes: 44 additions & 0 deletions tt_eager/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp
@@ -0,0 +1,44 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "dataflow_api.h"

// W-bcast scalar: a 32x32 tile is stored as four 16x16 faces of 256 datums each.
// Writing datum 0 of every row in faces 0 and 2 (k selects a left face, j steps
// one 16-datum face row) fills the tile's first column with the bfloat16 value
// carried in the upper 16 bits of `scalar`.
FORCE_INLINE void generate_bcast_col_scalar(const uint32_t cb_id, const uint32_t scalar) {
    const uint16_t scalar_val = scalar >> 16;
    cb_reserve_back(cb_id, 1);
    volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(get_write_ptr(cb_id));
    for (int k = 0; k < 4; k += 2) {
        uint32_t idx = k << 8;
        for (int j = 0; j < 256; j += 16) {
            ptr[idx + j] = scalar_val;
        }
    }
    cb_push_back(cb_id, 1);
}

// H-bcast scalar: writes the first row of faces 0 and 1 (the top faces) with
// 32-bit stores; k << 7 is a face offset in uint32 units, and 8 stores span
// the 16 datums of a face row.
FORCE_INLINE void generate_bcast_row_scalar(const uint32_t cb_id, const uint32_t scalar) {
    const uint32_t scalar_val = scalar >> 16;
    cb_reserve_back(cb_id, 1);
    volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_id));
    for (int k = 0; k < 2; ++k) {
        uint32_t idx = k << 7;
        for (int j = 0; j < 8; ++j) {
            ptr[idx + j] = scalar_val;
        }
    }
    cb_push_back(cb_id, 1);
}

// HW-bcast scalar: only datum [0, 0] of the tile is consumed, so one store suffices.
FORCE_INLINE void generate_bcast_unary_scalar(const uint32_t cb_id, const uint32_t scalar) {
    const uint32_t scalar_val = scalar >> 16;
    cb_reserve_back(cb_id, 1);
    volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_id));
    ptr[0] = scalar_val;
    cb_push_back(cb_id, 1);
}
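All three helpers take the scalar as a bfloat16 packed into the upper 16 bits of a uint32, which is why each one unpacks with `scalar >> 16`. A minimal host-side sketch of that packing (the function name is illustrative, not part of this commit):

#include <cstdint>
#include <cstring>

// Pack a float into the "bfloat16 in the upper 16 bits" form the helpers expect.
inline uint32_t pack_bfloat16_scalar(float value) {
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));  // raw IEEE-754 float32 bits
    return bits & 0xFFFF0000u;  // bfloat16 is the top 16 bits of a float32 (truncation, no rounding)
}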
33 changes: 33 additions & 0 deletions tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp
@@ -0,0 +1,33 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "dataflow_api.h"

FORCE_INLINE void generate_reduce_scaler(const uint32_t cb_id, const uint32_t scaler) {
    cb_reserve_back(cb_id, 1);

    constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE;
    uint64_t zeros_noc_addr = get_noc_addr(MEM_ZEROS_BASE);
    uint32_t write_addr = get_write_ptr(cb_id);
    volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(write_addr);

    // Fill the tile with zeros (2048 bytes = one 32x32 bfloat16 tile) by
    // copying from the dedicated zeros region over the NOC
    for (uint32_t i = 0; i < num_zeros_reads; ++i) {
        noc_async_read(zeros_noc_addr, write_addr, MEM_ZEROS_SIZE);
        write_addr += MEM_ZEROS_SIZE;
    }
    noc_async_read_barrier();

    // Write the packed scaler into the first row of each of the four 16x16
    // faces (k << 7 is a face offset in uint32 units; 8 stores span a face row)
    if (scaler != 0) {
        for (int k = 0; k < 4; ++k) {
            uint32_t idx = k << 7;
            for (int j = 0; j < 8; ++j) {
                ptr[idx + j] = scaler;
            }
        }
    }
    cb_push_back(cb_id, 1);
}
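A sketch of how a dataflow kernel might stage both tiles for layernorm, e.g. a 1/W reduce scaler plus an epsilon broadcast tile. The CB indices and runtime-arg order are assumptions for illustration, not taken from this commit; note that generate_reduce_scaler stores the raw 32-bit word, so a caller that wants every datum of each face row set would pack the bfloat16 into both halves.

// Hypothetical reader kernel (CB indices and arg layout are assumptions).
#include "dataflow_api.h"

#include "tt_eager/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp"
#include "tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp"

void kernel_main() {
    constexpr uint32_t cb_reduce_scaler = get_compile_time_arg_val(0);
    constexpr uint32_t cb_eps = get_compile_time_arg_val(1);

    const uint32_t packed_winv = get_arg_val<uint32_t>(0);  // 1/W as bfloat16, packed into both 16-bit halves
    const uint32_t packed_eps = get_arg_val<uint32_t>(1);   // epsilon as bfloat16 in the upper 16 bits

    generate_reduce_scaler(cb_reduce_scaler, packed_winv);  // zeroed tile with 1/W along each face's first row
    generate_bcast_col_scalar(cb_eps, packed_eps);          // epsilon down the tile's first column
}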
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/module.mk
@@ -91,6 +91,7 @@ TT_DNN_SRCS = \
     tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp \
     tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward.cpp \
     tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_op.cpp \
+    tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp \
     tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp \
     tt_eager/tt_dnn/op_library/moreh_matmul/multi_core/moreh_matmul_op_multi_core.cpp \
     tt_eager/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp \
4 changes: 3 additions & 1 deletion tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp
@@ -145,13 +145,15 @@ operation::ProgramWithCallbacks EltwiseBinaryBroadcast::create_program(const std
 const operation::Hash EltwiseBinaryBroadcast::compute_program_hash(
     const std::vector<Tensor> &input_tensors) const {
     auto parallelization_strategy = this->get_parallelization_strategy(input_tensors);
+    bool bcast_scalar = (input_tensors.at(1).shape()[-2] * input_tensors.at(1).shape()[-1] == 1) && this->dim == BcastOpDim::HW;
     return operation::hash_operation<EltwiseBinaryBroadcast>(
         *this,
         parallelization_strategy,
         input_tensors.at(0).memory_config(),
         input_tensors.at(0).dtype(),
         input_tensors.at(1).memory_config(),
-        input_tensors.at(1).dtype());
+        input_tensors.at(1).dtype(),
+        bcast_scalar);
 }

 BcastOpParallelizationStrategy EltwiseBinaryBroadcast::get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const {
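The new bcast_scalar term keeps the program cache honest: an HW-bcast whose second input is a single 1x1 value presumably takes the specialized scalar path (the one fed by generate_bcast_unary_scalar above), so two calls that agree on dtypes and memory configs but differ in scalar-ness must hash to different cached programs.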