#8282: Support non-4d tensor and fp32_dest_acc_en for moreh nllloss backward #8966

Merged · 15 commits · Jun 3, 2024
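Context for the diff below: every updated test builds its compute kernel configuration through get_compute_kernel_options(fp32_dest_acc_en), but only the helper's closing return statement is visible in this diff. A minimal sketch of what such a helper typically looks like in these tests — the config classes, fields, and architecture check are assumptions, not code from this PR:

    import tt_lib as ttl
    from models.utility_functions import is_wormhole_b0  # assumed helper location

    # Hypothetical sketch: map the fp32_dest_acc_en flag onto a compute kernel config.
    def get_compute_kernel_options(fp32_dest_acc_en):
        if is_wormhole_b0():
            # Wormhole-class devices expose fp32_dest_acc_en on their config (assumed).
            return ttl.tensor.WormholeComputeKernelConfig(
                math_fidelity=ttl.tensor.MathFidelity.HiFi4,
                math_approx_mode=False,
                fp32_dest_acc_en=fp32_dest_acc_en,
                packer_l1_acc=False,
            )
        # Grayskull has no FP32 dest accumulation, so the flag is not used (assumed).
        return ttl.tensor.GrayskullComputeKernelConfig(
            math_fidelity=ttl.tensor.MathFidelity.HiFi4,
            math_approx_mode=True,
        )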
@@ -40,12 +40,7 @@ def get_compute_kernel_options(fp32_dest_acc_en):
return compute_kernel_config


torch.set_printoptions(threshold=1000000, linewidth=100000000, sci_mode=False)


def get_torch_tensors(shape):
torch.manual_seed(0)

C = shape[1]
target_shape = shape[:1] + shape[2:]

@@ -109,84 +104,6 @@ def get_tt_tensors(torch_input, torch_target, torch_weight, torch_divisor, torch
return tt_input, tt_target, tt_weight, tt_divisor, tt_output


def get_tt_tensors_4d(torch_input, torch_target, torch_weight, torch_divisor, torch_output, device):
torch.manual_seed(0)

N = torch_input.shape[0]
C = torch_input.shape[1]
H = torch_input.shape[2]
W = torch_input.shape[3]

npu_dtype = ttl.tensor.DataType.BFLOAT16
npu_index_dtype = ttl.tensor.DataType.UINT32
npu_layout = ttl.tensor.Layout.TILE
npu_weight_layout = ttl.tensor.Layout.ROW_MAJOR

tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
tt_target = (
ttl.tensor.Tensor(torch_target, npu_index_dtype).reshape(N, 1, H, W).pad_to_tile(C).to(npu_layout).to(device)
)
tt_weight = ttl.tensor.Tensor(torch_weight, npu_dtype).to(npu_weight_layout).to(device)
tt_divisor = (
ttl.tensor.Tensor(torch_divisor, npu_dtype)
.reshape(1, 1, 1, 1)
.pad_to_tile(float("nan"))
.to(npu_layout)
.to(device)
)
tt_output = (
ttl.tensor.Tensor(torch_output, npu_dtype)
.reshape(1, 1, 1, 1)
.pad_to_tile(float("nan"))
.to(npu_layout)
.to(device)
)

return tt_input, tt_target, tt_weight, tt_divisor, tt_output


def get_tt_tensors_2d(torch_input, torch_target, torch_weight, torch_divisor, torch_output, device):
torch.manual_seed(0)

N = torch_input.shape[0]
C = torch_input.shape[1]
H = 1
W = 1

npu_dtype = ttl.tensor.DataType.BFLOAT16
npu_index_dtype = ttl.tensor.DataType.UINT32
npu_layout = ttl.tensor.Layout.TILE
npu_weight_layout = ttl.tensor.Layout.ROW_MAJOR

tt_input = (
ttl.tensor.Tensor(torch_input, npu_dtype)
.reshape(N, C, H, W)
.pad_to_tile(float("nan"))
.to(npu_layout)
.to(device)
)
tt_target = (
ttl.tensor.Tensor(torch_target, npu_index_dtype).reshape(N, 1, H, W).pad_to_tile(C).to(npu_layout).to(device)
)
tt_weight = ttl.tensor.Tensor(torch_weight, npu_dtype).to(npu_weight_layout).to(device)
tt_divisor = (
ttl.tensor.Tensor(torch_divisor, npu_dtype)
.reshape(1, 1, 1, 1)
.pad_to_tile(float("nan"))
.to(npu_layout)
.to(device)
)
tt_output = (
ttl.tensor.Tensor(torch_output, npu_dtype)
.reshape(1, 1, 1, 1)
.pad_to_tile(float("nan"))
.to(npu_layout)
.to(device)
)

return tt_input, tt_target, tt_weight, tt_divisor, tt_output


@pytest.mark.parametrize(
"shape",
[
@@ -197,11 +114,13 @@ def get_tt_tensors_2d(torch_input, torch_target, torch_weight, torch_divisor, to
(5, 100, 2, 7, 50, 70),
],
)
@pytest.mark.parametrize("ignore_index", [-1, 5])
@pytest.mark.parametrize("ignore_index", [1])
@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("none_weight", [True, False])
@pytest.mark.parametrize("fp32_dest_acc_en", fp32_dest_acc_en, ids=fp32_dest_acc_en_ids)
def test_moreh_nll_loss(shape, ignore_index, reduction, none_weight, fp32_dest_acc_en, device, use_program_cache):
torch.manual_seed(0)

compute_kernel_config = get_compute_kernel_options(fp32_dest_acc_en)

(torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
@@ -245,24 +164,24 @@ def test_moreh_nll_loss(shape, ignore_index, reduction, none_weight, fp32_dest_a
(5, 10, 10, 20),
],
)
@pytest.mark.parametrize("ignore_index", [-1])
@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("none_weight", [True, False])
def test_moreh_nll_loss_callback(shape, ignore_index, reduction, none_weight, device, use_program_cache):
(torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
def test_moreh_nll_loss_callback(shape, reduction, none_weight, device, use_program_cache):
torch.manual_seed(0)

if none_weight:
torch_weight = None
ignore_index = 1
reduction_mean = reduction == "mean"

nll_loss = torch.nn.NLLLoss(weight=torch_weight, ignore_index=ignore_index, reduction=reduction)
torch_loss = torch.tensor([nll_loss(torch_input, torch_target)])
# run TT
for _ in range(2):
(torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
if none_weight:
torch_weight = None

(tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
)
(tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
)

reduction_mean = reduction == "mean"
for _ in range(2):
tt_loss = ttl.operations.primary.moreh_nll_loss(
tt_input,
tt_target,
@@ -274,6 +193,12 @@ def test_moreh_nll_loss_callback(shape, ignore_index, reduction, none_weight, de
)

tt_loss_to_cpu = tt_loss.cpu().to(ttl.tensor.Layout.ROW_MAJOR).unpad_from_tile([1, 1]).to_torch().reshape([1])

# run torch
nll_loss = torch.nn.NLLLoss(weight=torch_weight, ignore_index=ignore_index, reduction=reduction)
torch_loss = torch.tensor([nll_loss(torch_input, torch_target)])

# compare result
rtol = atol = 0.05
passing, out = comp_allclose_and_pcc(torch_loss, tt_loss_to_cpu, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"Out passing (param)={passing}")
@@ -284,17 +209,28 @@ def test_moreh_nll_loss_callback(shape, ignore_index, reduction, none_weight, de

@pytest.mark.parametrize(
"shape",
(
[1, 2, 32, 32],
[1, 2, 32, 32],
[3, 4, 32 * 5, 32 * 6],
),
[
(400, 300),
(20, 300, 320),
(3, 4, 32 * 5, 32 * 6),
(5, 2, 5, 40, 70),
],
)
@pytest.mark.parametrize("ignore_index", [0, -1])
@pytest.mark.parametrize("ignore_index", [1])
@pytest.mark.parametrize("reduction_mean", [True, False])
@pytest.mark.parametrize("has_output", [True, False])
def test_moreh_nll_loss_4d_backward(shape, ignore_index, reduction_mean, has_output, device, use_program_cache):
@pytest.mark.parametrize("none_weight", [True, False])
@pytest.mark.parametrize("fp32_dest_acc_en", fp32_dest_acc_en, ids=fp32_dest_acc_en_ids)
def test_moreh_nll_loss_backward(
shape, ignore_index, reduction_mean, none_weight, fp32_dest_acc_en, device, use_program_cache
):
torch.manual_seed(0)

compute_kernel_config = get_compute_kernel_options(fp32_dest_acc_en)

(torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
if none_weight:
torch_weight = None

nll_loss = torch.nn.NLLLoss(
weight=torch_weight, ignore_index=ignore_index, reduction="mean" if reduction_mean else "sum"
)
@@ -306,18 +242,22 @@ def test_moreh_nll_loss_4d_backward(shape, ignore_index, reduction_mean, has_out
if reduction_mean == False:
tt_divisor = None
tt_loss = ttl.operations.primary.moreh_nll_loss(
tt_input, tt_target, tt_weight, tt_divisor, tt_output, ignore_index, reduction_mean
tt_input,
tt_target,
tt_weight,
tt_divisor,
tt_output,
ignore_index,
reduction_mean,
compute_kernel_config=compute_kernel_config,
)

# run backward
(tt_input, tt_target, tt_weight, _, tt_output) = get_tt_tensors_4d(
torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
)
output_grad = torch.randn_like(torch_loss)
torch_loss.backward(output_grad)

tt_output_grad = (
ttl.tensor.Tensor(output_grad.reshape(1, 1, 1, 1), ttl.tensor.DataType.BFLOAT16)
ttl.tensor.Tensor(output_grad.reshape(1, 1), ttl.tensor.DataType.BFLOAT16)
.pad_to_tile(float("nan"))
.to(ttl.tensor.Layout.TILE)
.to(device)
@@ -330,14 +270,14 @@ def test_moreh_nll_loss_4d_backward(shape, ignore_index, reduction_mean, has_out
)

tt_input_grad = ttl.operations.primary.moreh_nll_loss_backward(
tt_input,
tt_target,
tt_weight,
tt_divisor,
tt_output_grad,
tt_input_grad if has_output else None,
tt_input_grad,
ignore_index,
reduction_mean,
compute_kernel_config=compute_kernel_config,
)
tt_input_grad_to_cpu = tt_input_grad.cpu().to(ttl.tensor.Layout.ROW_MAJOR).unpad_from_tile(shape).to_torch()

@@ -350,59 +290,73 @@ def test_moreh_nll_loss_4d_backward(shape, ignore_index, reduction_mean, has_out
assert passing


@pytest.mark.parametrize("shape", ([1, 2], [3, 4], [12, 6]))
@pytest.mark.parametrize("ignore_index", [0, -1])
@pytest.mark.parametrize(
"shape",
[
(2, 3),
(2, 3, 4),
(2, 3, 5, 4),
],
)
@pytest.mark.parametrize("reduction_mean", [True, False])
def test_moreh_nll_loss_2d_backward(shape, ignore_index, reduction_mean, device, use_program_cache):
(torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
nll_loss = torch.nn.NLLLoss(
weight=torch_weight, ignore_index=ignore_index, reduction="mean" if reduction_mean else "sum"
)
torch_loss = nll_loss(torch_input, torch_target)
@pytest.mark.parametrize("none_weight", [True, False])
def test_moreh_nll_loss_backward_test_callback(shape, reduction_mean, none_weight, device, use_program_cache):
torch.manual_seed(0)

(tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
)
ignore_index = 0

if reduction_mean == False:
tt_divisor = None
tt_loss = ttl.operations.primary.moreh_nll_loss(
tt_input, tt_target, tt_weight, tt_divisor, tt_output, ignore_index, reduction_mean
)
tt_loss_to_cpu = tt_loss.cpu().to(ttl.tensor.Layout.ROW_MAJOR).unpad_from_tile([1, 1, 1, 1]).to_torch().reshape([1])
# run TT
for _ in range(2):
(torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
if none_weight:
torch_weight = None

# run backward
(tt_input, tt_target, tt_weight, _, tt_output) = get_tt_tensors_2d(
torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
)
(tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
)
if reduction_mean == False:
tt_divisor = None
tt_loss = ttl.operations.primary.moreh_nll_loss(
tt_input, tt_target, tt_weight, tt_divisor, tt_output, ignore_index, reduction_mean
)

output_grad = torch.randn_like(torch_loss)
torch_loss.backward(output_grad)
output_grad = torch.rand([])

tt_output_grad = (
ttl.tensor.Tensor(output_grad.reshape(1, 1, 1, 1), ttl.tensor.DataType.BFLOAT16)
.pad_to_tile(float("nan"))
.to(ttl.tensor.Layout.TILE)
.to(device)
)
tt_input_grad = (
ttl.tensor.Tensor(torch_input.unsqueeze(-1).unsqueeze(-1), ttl.tensor.DataType.BFLOAT16)
.pad_to_tile(float("nan"))
.to(ttl.tensor.Layout.TILE)
.to(device)
)
tt_output_grad = (
ttl.tensor.Tensor(output_grad.reshape(1, 1), ttl.tensor.DataType.BFLOAT16)
.pad_to_tile(float("nan"))
.to(ttl.tensor.Layout.TILE)
.to(device)
)

ttl.operations.primary.moreh_nll_loss_backward(
tt_input, tt_target, tt_weight, tt_divisor, tt_output_grad, tt_input_grad, ignore_index, reduction_mean
)
tt_input_grad_to_cpu = (
tt_input_grad.cpu()
.to(ttl.tensor.Layout.ROW_MAJOR)
.unpad_from_tile(tt_input_grad.shape_without_padding())
.to_torch()
.reshape(shape)
tt_input_grad = (
ttl.tensor.Tensor(torch_input, ttl.tensor.DataType.BFLOAT16)
.pad_to_tile(float("nan"))
.to(ttl.tensor.Layout.TILE)
.to(device)
)

tt_input_grad = ttl.operations.primary.moreh_nll_loss_backward(
tt_target,
tt_weight,
tt_divisor,
tt_output_grad,
tt_input_grad,
ignore_index,
reduction_mean,
)

tt_input_grad_to_cpu = tt_input_grad.cpu().to(ttl.tensor.Layout.ROW_MAJOR).unpad_from_tile(shape).to_torch()

# run torch
nll_loss = torch.nn.NLLLoss(
weight=torch_weight, ignore_index=ignore_index, reduction="mean" if reduction_mean else "sum"
)
torch_loss = nll_loss(torch_input, torch_target)

torch_loss.backward(output_grad)

# compare result
rtol = atol = 0.05
passing, out = comp_allclose_and_pcc(torch_input.grad, tt_input_grad_to_cpu, pcc=0.999, rtol=rtol, atol=atol)

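For reference before the next file: the updated tests parametrize shapes from 2-D (N, C) up to 6-D, and get_torch_tensors derives the target shape by dropping the class dimension (target_shape = shape[:1] + shape[2:], visible in the first hunk above). A minimal sketch of such a helper under those assumptions — dtypes and the divisor/output placeholders are guesses, not copied from this PR:

    import torch

    # Hypothetical sketch: build reference torch tensors for an arbitrary-rank shape (N, C, d1, ..., dk).
    def get_torch_tensors(shape):
        torch.manual_seed(0)

        C = shape[1]
        target_shape = shape[:1] + shape[2:]  # drop the class (C) dimension

        torch_input = torch.rand(shape, dtype=torch.float32, requires_grad=True)  # stand-in for log-probabilities
        torch_target = torch.randint(0, C, target_shape, dtype=torch.long)        # one class index per element
        torch_weight = torch.rand(C, dtype=torch.float32)                         # per-class weight
        torch_divisor = torch.zeros(1, dtype=torch.float32)                       # scalar divisor placeholder
        torch_output = torch.zeros(1, dtype=torch.float32)                        # scalar loss placeholder

        return torch_input, torch_target, torch_weight, torch_divisor, torch_output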
11 changes: 10 additions & 1 deletion tt_eager/tt_dnn/kernels/compute/moreh_common.hpp
@@ -26,6 +26,8 @@
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/tile_move_copy.h"


// Deprecated
ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
ALWI void REL() { release_dst(tt::DstMode::Half); }

@@ -44,7 +46,7 @@ ALWI void copy_tile_init_with_dt(uint32_t icb)
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format_srca(icb);
#endif
copy_tile_init();
copy_tile_to_dst_init_short(icb);
}

ALWI void add_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
Expand All @@ -68,6 +70,13 @@ ALWI void mul_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
mul_tiles_init(icb0, icb1);
}

ALWI void mul_tiles_bcast_scalar_init_short_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format(icb0, icb1);
#endif
mul_tiles_bcast_scalar_init_short(icb0, icb1);
}

class ArgFetcher {
private:
int arg_idx = 0;
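The new mul_tiles_bcast_scalar_init_short_with_dt wrapper follows the same pattern as the other _with_dt helpers in this header: when FP32_DEST_ACC_EN is defined, it first reconfigures the unpacker data format for the source circular buffers, then runs the usual short init. A minimal usage sketch for a compute kernel — the CB ids, tile loop, and runtime-arg read are illustrative assumptions, not code from this PR:

    // Hypothetical sketch: multiply each input tile by a broadcast scalar tile,
    // honoring fp32_dest_acc_en via the new *_with_dt init wrapper.
    constexpr auto cb_in = tt::CB::c_in0;      // input tiles (assumed CB id)
    constexpr auto cb_scaler = tt::CB::c_in1;  // single scalar tile (assumed CB id)
    constexpr auto cb_out = tt::CB::c_out0;    // output tiles (assumed CB id)

    const uint32_t num_tiles = get_arg_val<uint32_t>(0);  // assumed to come from runtime args

    cb_wait_front(cb_scaler, 1);
    mul_tiles_bcast_scalar_init_short_with_dt(cb_in, cb_scaler);
    for (uint32_t i = 0; i < num_tiles; ++i) {
        cb_wait_front(cb_in, 1);
        cb_reserve_back(cb_out, 1);

        ACQ();  // acquire DST (the deprecated helper defined above)
        mul_tiles_bcast_scalar(cb_in, cb_scaler, 0, 0, 0);  // dst[0] = input_tile * scalar_tile
        pack_tile(0, cb_out);
        REL();  // release DST

        cb_pop_front(cb_in, 1);
        cb_push_back(cb_out, 1);
    }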