#8632: Support fp32 dest acc en in moreh_sum and moreh_sum_backward #8724

Merged (8 commits, Jun 1, 2024)
200 changes: 169 additions & 31 deletions tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_sum.py
@@ -7,19 +7,28 @@
from loguru import logger

import tt_lib as ttl
from models.utility_functions import comp_allclose_and_pcc, skip_for_wormhole_b0
from models.utility_functions import comp_allclose_and_pcc
from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
)

TILE_HEIGHT = 32
TILE_WIDTH = 32


def get_tensors(input_shape, output_shape, device, *, with_padding=True):
def get_tensors(input_shape, output_shape, device, *, with_padding=True, use_randint=True):
npu_dtype = ttl.tensor.DataType.BFLOAT16
cpu_dtype = torch.bfloat16
npu_layout = ttl.tensor.Layout.TILE

torch_input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype, requires_grad=True)
torch_output = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)
if use_randint:
torch_input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype, requires_grad=True)
torch_output = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)
else:
torch_input = torch.rand(input_shape, dtype=cpu_dtype, requires_grad=True)
torch_output = torch.rand(output_shape, dtype=cpu_dtype)

if with_padding:
tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
@@ -31,13 +40,17 @@ def get_tensors(input_shape, output_shape, device, *, with_padding=True):
return tt_input, tt_output, torch_input


def get_backward_tensors(output_grad_shape, input_grad_shape, device, *, with_padding=True):
def get_backward_tensors(output_grad_shape, input_grad_shape, device, *, with_padding=True, use_randint=True):
npu_dtype = ttl.tensor.DataType.BFLOAT16
cpu_dtype = torch.bfloat16
npu_layout = ttl.tensor.Layout.TILE

torch_output_grad = torch.randint(-2, 3, output_grad_shape, dtype=cpu_dtype, requires_grad=True)
torch_input_grad = torch.randint(-2, 3, input_grad_shape, dtype=cpu_dtype)
if use_randint:
torch_output_grad = torch.randint(-2, 3, output_grad_shape, dtype=cpu_dtype, requires_grad=True)
torch_input_grad = torch.randint(-2, 3, input_grad_shape, dtype=cpu_dtype)
else:
torch_output_grad = torch.rand(output_grad_shape, dtype=cpu_dtype, requires_grad=True)
torch_input_grad = torch.rand(input_grad_shape, dtype=cpu_dtype)

if with_padding:
tt_output_grad = (
@@ -55,9 +68,9 @@ def get_backward_tensors(output_grad_shape, input_grad_shape, device, *, with_pa

@pytest.mark.parametrize(
"input_shape",
(([4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 12 - 1]),),
(([3, 2, TILE_HEIGHT * 10 - 1, TILE_WIDTH * 10 - 1]),),
ids=[
"4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 12 - 1",
"3, 2, TILE_HEIGHT * 10 - 1, TILE_WIDTH * 10 - 1",
],
)
@pytest.mark.parametrize(
@@ -79,8 +92,9 @@ def get_backward_tensors(output_grad_shape, input_grad_shape, device, *, with_pa
),
ids=["0", "0,1", "0,1,2", "0,1,2,3", "0,1,3", "0,2,3", "1", "1,2", "1,2,3", "1,3", "2", "2,3", "3"],
)
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
@pytest.mark.parametrize("use_provide_output", (True, False), ids=["True", "False"])
def test_moreh_sum(input_shape, dims, use_provide_output, device):
def test_moreh_sum(input_shape, dims, compute_kernel_options, use_provide_output, device):
torch.manual_seed(2023)
output_shape = input_shape.copy()

@@ -94,9 +108,12 @@ def test_moreh_sum(input_shape, dims, use_provide_output, device):

torch_output = torch.sum(torch_input, dims, True)

compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
cpu_layout = ttl.tensor.Layout.ROW_MAJOR
tt_output_cpu = (
ttl.operations.primary.moreh_sum(tt_input, dims=dims, output=tt_output)
ttl.operations.primary.moreh_sum(
tt_input, dims=dims, output=tt_output, compute_kernel_config=compute_kernel_config
)
.cpu()
.to(cpu_layout)
.unpad_from_tile(output_shape)
@@ -123,29 +140,31 @@ def reduce_rows(x, dims):
"input_shape",
(
([TILE_HEIGHT, TILE_WIDTH]),
([TILE_HEIGHT * 3, TILE_WIDTH * 3]),
([4, TILE_HEIGHT * 2, TILE_WIDTH * 2]),
([TILE_HEIGHT - 1, TILE_WIDTH - 1]),
([2, 3, 2, 4, TILE_HEIGHT * 4, TILE_WIDTH * 4]),
([3, 2, 4, TILE_HEIGHT * 4 - 1, TILE_WIDTH * 4 - 1]),
),
ids=[
"TILE_HEIGHT, TILE_WIDTH",
"TILE_HEIGHT * 3, TILE_WIDTH * 3",
"4, TILE_HEIGHT * 2, TILE_WIDTH * 2",
"TILE_HEIGHT - 1, TILE_WIDTH - 1",
"2, 3, 2, 4, TILE_HEIGHT * 4, TILE_WIDTH * 4",
"3, 2, 4, TILE_HEIGHT * 4 - 1, TILE_WIDTH * 4 - 1",
],
)
@pytest.mark.parametrize(
"dims",
(
[0],
[0, 1],
[0, 1, 2],
[0, 2],
[1],
[1, 2],
[2],
[3],
[4],
[5],
),
ids=["0", "0,1", "0,1,2", "0, 2", "1", "1,2", "2"],
ids=["0", "1", "2", "3", "4", "5"],
)
def test_moreh_sum_non_4d(input_shape, dims, device):
@pytest.mark.parametrize("use_provide_output", (True, False), ids=["True", "False"])
def test_moreh_sum_non_4d(input_shape, dims, use_provide_output, device):
torch.manual_seed(2023)
output_shape = input_shape.copy()

@@ -154,13 +173,26 @@ def test_moreh_sum_non_4d(input_shape, dims, device):
if dim >= input_rank:
pytest.skip(f"input dim {dim} exceeds the dims of input tensor {len(input_shape)}.")

(tt_input, _, torch_input) = get_tensors(input_shape, output_shape, device, with_padding=False)
for dim in dims:
output_shape[dim] = 1

(tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device)
if not use_provide_output:
tt_output = None

compute_kernel_config = get_compute_kernel_options(False)

torch_output = torch.sum(torch_input, dims, True)
cpu_layout = ttl.tensor.Layout.ROW_MAJOR
tt_output_cpu = ttl.operations.primary.moreh_sum(tt_input, dims=dims, output=None).cpu().to(cpu_layout).to_torch()

tt_output_cpu = reduce_rows(tt_output_cpu, dims)
tt_output_cpu = (
ttl.operations.primary.moreh_sum(
tt_input, dims=dims, output=tt_output, compute_kernel_config=compute_kernel_config
)
.cpu()
.to(cpu_layout)
.unpad_from_tile(output_shape)
.to_torch()
)
rtol = atol = 0.12
passing, output_pcc = comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol)

@@ -173,12 +205,60 @@ def test_moreh_sum_non_4d(input_shape, dims, device):
@pytest.mark.parametrize(
"input_shape",
(
([1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1]),
([4, 4, TILE_HEIGHT * 20 - 1, TILE_WIDTH * 20 - 1]),
[10, TILE_HEIGHT * 12, TILE_WIDTH * 12],
[10, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 12 - 1],
),
ids=[
"1, 1, TILE_HEIGHT-1,TILE_WIDTH - 1",
"4, 4, TILE_HEIGHT * 20 - 1, TILE_WIDTH * 20 - 1",
"10, TILE_HEIGHT * 12, TILE_WIDTH * 12",
"10, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 12 - 1",
],
)
@pytest.mark.parametrize(
"dims",
([0], [1], [2]),
ids=["dim-n", "dim-h", "dim-w"],
)
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_sum_fp32_dest_acc(input_shape, dims, compute_kernel_options, device):
torch.manual_seed(2023)
output_shape = input_shape.copy()

compute_kernel_config = get_compute_kernel_options(compute_kernel_options)

for dim in dims:
output_shape[dim] = 1

(tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device, use_randint=False)
torch_input = torch_input.float()
torch_output = torch.sum(torch_input, dims, True)

cpu_layout = ttl.tensor.Layout.ROW_MAJOR
tt_output_cpu = (
ttl.operations.primary.moreh_sum(
tt_input, dims=dims, output=tt_output, compute_kernel_config=compute_kernel_config
)
.cpu()
.to(cpu_layout)
.unpad_from_tile(output_shape)
.to_torch()
)

rtol = atol = 0.1
passing, output_pcc = comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"Out passing={passing}")
logger.debug(f"Output pcc={output_pcc}")
logger.debug(f"std={torch.std(torch.abs(torch_output - tt_output_cpu))}")
logger.debug(f"mean={torch.abs(torch_output - tt_output_cpu).mean()}")

# TODO
# assert passing


@pytest.mark.parametrize(
"input_shape",
(([4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 12 - 1]),),
ids=[
"4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 12 - 1",
],
)
@pytest.mark.parametrize(
@@ -200,14 +280,17 @@ def test_moreh_sum_non_4d(input_shape, dims, device):
),
ids=["0", "0,1", "0,1,2", "0,1,2,3", "0,1,3", "0,2,3", "1", "1,2", "1,2,3", "1,3", "2", "2,3", "3"],
)
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
@pytest.mark.parametrize("use_provide_input_grad", (True, False), ids=["True", "False"])
def test_moreh_sum_backward(input_shape, dims, use_provide_input_grad, device):
def test_moreh_sum_backward(input_shape, dims, compute_kernel_options, use_provide_input_grad, device):
torch.manual_seed(2023)
output_shape = input_shape.copy()

for dim in dims:
output_shape[dim] = 1

compute_kernel_config = get_compute_kernel_options(compute_kernel_options)

(tt_input, _, torch_input) = get_tensors(input_shape, output_shape, device)
(tt_output_grad, tt_input_grad, torch_output_grad) = get_backward_tensors(output_shape, input_shape, device)

@@ -219,7 +302,9 @@ def test_moreh_sum_backward(input_shape, dims, use_provide_input_grad, device):

cpu_layout = ttl.tensor.Layout.ROW_MAJOR
tt_input_grad_cpu = (
ttl.operations.primary.moreh_sum_backward(tt_output_grad, tt_input, dims=dims, input_grad=tt_input_grad)
ttl.operations.primary.moreh_sum_backward(
tt_output_grad, tt_input, dims=dims, input_grad=tt_input_grad, compute_kernel_config=compute_kernel_config
)
.cpu()
.to(cpu_layout)
.unpad_from_tile(input_shape)
@@ -234,3 +319,56 @@ def test_moreh_sum_backward(input_shape, dims, use_provide_input_grad, device):
logger.debug(f"Output pcc={output_pcc}")

assert passing


@pytest.mark.parametrize(
"input_shape",
([2, 3, 2, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 12 - 1],),
ids=[
"2, 3, 2, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 12 - 1",
],
)
@pytest.mark.parametrize(
"dims",
([0], [4], [5], [4, 5], [1, 4, 5]),
ids=["dim-n", "dim-h", "dim-w", "dim-hw", "dim-nhw"],
)
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_sum_backward_fp32_dest_acc(input_shape, dims, compute_kernel_options, device):
torch.manual_seed(2023)
output_shape = input_shape.copy()

compute_kernel_config = get_compute_kernel_options(compute_kernel_options)

for dim in dims:
output_shape[dim] = 1

(tt_input, _, torch_input) = get_tensors(input_shape, output_shape, device, use_randint=False)
(tt_output_grad, tt_input_grad, torch_output_grad) = get_backward_tensors(
output_shape, input_shape, device, use_randint=False
)

# convert torch_input to float32 dtype
torch_input = torch_input.detach().clone().to(dtype=torch.float32).requires_grad_(True)
torch_output_grad = torch_output_grad.float()
torch_output = torch.sum(torch_input, dims, True)
torch_output.backward(torch_output_grad)

cpu_layout = ttl.tensor.Layout.ROW_MAJOR
tt_input_grad_cpu = (
ttl.operations.primary.moreh_sum_backward(
tt_output_grad, tt_input, dims=dims, input_grad=tt_input_grad, compute_kernel_config=compute_kernel_config
)
.cpu()
.to(cpu_layout)
.unpad_from_tile(input_shape)
.to_torch()
)

rtol = atol = 0.1
passing, output_pcc = comp_allclose_and_pcc(torch_input.grad, tt_input_grad_cpu, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"Out passing={passing}")
logger.debug(f"Output pcc={output_pcc}")
logger.debug(f"std={torch.std(torch.abs(torch_input.grad - tt_input_grad_cpu))}")
logger.debug(f"mean={torch.abs(torch_input.grad - tt_input_grad_cpu).mean()}")
assert passing
33 changes: 33 additions & 0 deletions tests/tt_eager/python_api_testing/unit_testing/misc/test_utils.py
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import tt_lib as ttl
from models.utility_functions import is_wormhole_b0

compute_kernel_options = [
False, # for grayskull
]
compute_kernel_ids = ["fp32_dest_acc_en=False"]
if is_wormhole_b0():
compute_kernel_options.append(True)
compute_kernel_ids.append("fp32_dest_acc_en=True")


def get_compute_kernel_options(compute_kernel_options):
if is_wormhole_b0():
fp32_dest_acc_en = compute_kernel_options
packer_l1_acc = False
compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
math_fidelity=ttl.tensor.MathFidelity.HiFi4,
math_approx_mode=False,
fp32_dest_acc_en=fp32_dest_acc_en,
packer_l1_acc=packer_l1_acc,
)
else:
# Grayskull doesn't support fp32 dest acc, but passing a Grayskull config in the test is fine
compute_kernel_config = ttl.tensor.GrayskullComputeKernelConfig(
math_fidelity=ttl.tensor.MathFidelity.HiFi4,
math_approx_mode=True,
)
return compute_kernel_config
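
For context, a minimal usage sketch (not part of the diff) of how this helper feeds the new compute_kernel_config argument, mirroring the call pattern in test_moreh_sum.py above. The illustrative shape, the helper name run_moreh_sum_with_config, and importing get_tensors from the test module are assumptions made only for this example.

import torch
import tt_lib as ttl

from tests.tt_eager.python_api_testing.unit_testing.misc.test_moreh_sum import get_tensors
from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import get_compute_kernel_options


def run_moreh_sum_with_config(device, fp32_dest_acc_en=True):
    # Reduce dim 2 of an illustrative, tile-aligned shape (sketch only).
    input_shape, dims = [1, 1, 2 * 32, 2 * 32], [2]
    output_shape = input_shape.copy()
    for dim in dims:
        output_shape[dim] = 1

    # On Wormhole this builds a WormholeComputeKernelConfig with fp32_dest_acc_en set;
    # on Grayskull it falls back to a GrayskullComputeKernelConfig.
    compute_kernel_config = get_compute_kernel_options(fp32_dest_acc_en)

    tt_input, tt_output, torch_input = get_tensors(input_shape, output_shape, device)
    torch_output = torch.sum(torch_input, dims, True)

    tt_output_cpu = (
        ttl.operations.primary.moreh_sum(
            tt_input, dims=dims, output=tt_output, compute_kernel_config=compute_kernel_config
        )
        .cpu()
        .to(ttl.tensor.Layout.ROW_MAJOR)
        .unpad_from_tile(output_shape)
        .to_torch()
    )
    return torch_output, tt_output_cpu

This is the same flow the parametrized tests above exercise for both fp32_dest_acc_en settings.
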
37 changes: 27 additions & 10 deletions tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp
@@ -82,19 +82,36 @@ FORCE_INLINE void generate_bcast_scaler(uint32_t cb_scaler, uint32_t scaler) {
cb_push_back(cb_scaler, 1);
}

template <typename T>
FORCE_INLINE void process_data(int cb_id, uint32_t value, int32_t num_of_elems) {
T* ptr = reinterpret_cast<T*>(get_write_ptr(cb_id));
for (int j = 0; j < num_of_elems; j++)
{
ptr[j] = static_cast<T>(value);
}
}

template <>
FORCE_INLINE void process_data<uint16_t>(int cb_id, uint32_t value, int32_t num_of_elems) {
uint16_t* ptr = reinterpret_cast<uint16_t*>(get_write_ptr(cb_id));
for (int j = 0; j < num_of_elems; j++)
{
ptr[j] = static_cast<uint16_t>(value >> 16);
}
}

FORCE_INLINE void fill_cb_with_value(uint32_t cb_id, uint32_t value, int32_t num_of_elems = 1024) {
cb_reserve_back(cb_id, 1);
#if defined FP32_DEST_ACC_EN
auto ptr = reinterpret_cast<uint32_t *>(get_write_ptr(cb_id));
for (int j = 0; j < 1024; j++) {
ptr[j] = value;
}
#else
auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id));
for (int j = 0; j < 1024; j++) {
ptr[j] = uint16_t(value >> 16);
const DataFormat data_format = get_dataformat(cb_id);
switch((uint)data_format & 0x1F) {
case ((uint8_t)DataFormat::Float32):
process_data<uint32_t>(cb_id, value, num_of_elems);
break;
case ((uint8_t)DataFormat::Float16_b):
default:
process_data<uint16_t>(cb_id, value, num_of_elems);
break;
}
#endif
cb_push_back(cb_id, 1);
}
