#8282: modified the callback test to receive random input every time
hschoi4448 committed Jun 3, 2024
1 parent a02c2a2 commit 2cffe57
Showing 1 changed file with 57 additions and 42 deletions.
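The diff below implements this by seeding the RNG once at the top of each test and drawing fresh random tensors inside the program-cache loop, instead of reseeding inside get_torch_tensors (which made every call return identical data, so the cached program was never exercised with new inputs). A minimal sketch of the pattern, where make_inputs and the loop body are simplified stand-ins rather than tt-metal APIs:

    import torch

    def make_inputs(shape):
        # No torch.manual_seed() inside the helper: each call yields new random data.
        return torch.randn(shape)

    def test_op_callback():
        torch.manual_seed(0)  # seed once per test so runs stay reproducible
        seen = []
        for _ in range(2):  # the second iteration exercises the cached program
            x = make_inputs((5, 10, 10))
            seen.append(x)
            # run the device op on x here and compare against the torch reference
        # inputs differ across iterations, so a cache hit with stale data would fail
        assert not torch.equal(seen[0], seen[1])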
@@ -41,8 +41,6 @@ def get_compute_kernel_options(fp32_dest_acc_en):
 
 
 def get_torch_tensors(shape):
-    torch.manual_seed(0)
-
     C = shape[1]
     target_shape = shape[:1] + shape[2:]
 
@@ -121,6 +119,8 @@ def get_tt_tensors(torch_input, torch_target, torch_weight, torch_divisor, torch
 @pytest.mark.parametrize("none_weight", [True, False])
 @pytest.mark.parametrize("fp32_dest_acc_en", fp32_dest_acc_en, ids=fp32_dest_acc_en_ids)
 def test_moreh_nll_loss(shape, ignore_index, reduction, none_weight, fp32_dest_acc_en, device, use_program_cache):
+    torch.manual_seed(0)
+
     compute_kernel_config = get_compute_kernel_options(fp32_dest_acc_en)
 
     (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
@@ -167,17 +167,17 @@ def test_moreh_nll_loss(shape, ignore_index, reduction, none_weight, fp32_dest_a
 @pytest.mark.parametrize("reduction", ["mean", "sum"])
 @pytest.mark.parametrize("none_weight", [True, False])
 def test_moreh_nll_loss_callback(shape, reduction, none_weight, device, use_program_cache):
-    ignore_index = 1
-    (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
-
-    if none_weight:
-        torch_weight = None
-
-    nll_loss = torch.nn.NLLLoss(weight=torch_weight, ignore_index=ignore_index, reduction=reduction)
-    torch_loss = torch.tensor([nll_loss(torch_input, torch_target)])
+    torch.manual_seed(0)
+
+    ignore_index = 1
     reduction_mean = reduction == "mean"
 
+    # run TT
     for _ in range(2):
+        (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
+        if none_weight:
+            torch_weight = None
+
         (tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
             torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
         )
@@ -193,6 +193,12 @@ def test_moreh_nll_loss_callback(shape, reduction, none_weight, device, use_prog
         )
 
         tt_loss_to_cpu = tt_loss.cpu().to(ttl.tensor.Layout.ROW_MAJOR).unpad_from_tile([1, 1]).to_torch().reshape([1])
+
+        # run torch
+        nll_loss = torch.nn.NLLLoss(weight=torch_weight, ignore_index=ignore_index, reduction=reduction)
+        torch_loss = torch.tensor([nll_loss(torch_input, torch_target)])
+
+        # compare result
         rtol = atol = 0.05
         passing, out = comp_allclose_and_pcc(torch_loss, tt_loss_to_cpu, pcc=0.999, rtol=rtol, atol=atol)
         logger.debug(f"Out passing (param)={passing}")
@@ -217,6 +223,8 @@
 def test_moreh_nll_loss_backward(
     shape, ignore_index, reduction_mean, none_weight, fp32_dest_acc_en, device, use_program_cache
 ):
+    torch.manual_seed(0)
+
     compute_kernel_config = get_compute_kernel_options(fp32_dest_acc_en)
 
     (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
@@ -249,7 +257,7 @@ def test_moreh_nll_loss_backward(
     torch_loss.backward(output_grad)
 
     tt_output_grad = (
-        ttl.tensor.Tensor(output_grad.reshape(1, 1, 1, 1), ttl.tensor.DataType.BFLOAT16)
+        ttl.tensor.Tensor(output_grad.reshape(1, 1), ttl.tensor.DataType.BFLOAT16)
         .pad_to_tile(float("nan"))
         .to(ttl.tensor.Layout.TILE)
         .to(device)
@@ -293,44 +301,41 @@ def test_moreh_nll_loss_backward(
 @pytest.mark.parametrize("reduction_mean", [True, False])
 @pytest.mark.parametrize("none_weight", [True, False])
 def test_moreh_nll_loss_backward_test_callback(shape, reduction_mean, none_weight, device, use_program_cache):
+    torch.manual_seed(0)
+
     ignore_index = 0
 
-    (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
-    if none_weight:
-        torch_weight = None
+    # run TT
+    for _ in range(2):
+        (torch_input, torch_target, torch_weight, torch_divisor, torch_output) = get_torch_tensors(shape)
+        if none_weight:
+            torch_weight = None
 
-    nll_loss = torch.nn.NLLLoss(
-        weight=torch_weight, ignore_index=ignore_index, reduction="mean" if reduction_mean else "sum"
-    )
-    torch_loss = nll_loss(torch_input, torch_target)
+        (tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
+            torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
+        )
+        if reduction_mean == False:
+            tt_divisor = None
+        tt_loss = ttl.operations.primary.moreh_nll_loss(
+            tt_input, tt_target, tt_weight, tt_divisor, tt_output, ignore_index, reduction_mean
+        )
 
-    (tt_input, tt_target, tt_weight, tt_divisor, tt_output) = get_tt_tensors(
-        torch_input, torch_target, torch_weight, torch_divisor, torch_output, device
-    )
-    if reduction_mean == False:
-        tt_divisor = None
-    tt_loss = ttl.operations.primary.moreh_nll_loss(
-        tt_input, tt_target, tt_weight, tt_divisor, tt_output, ignore_index, reduction_mean
-    )
+        output_grad = torch.rand([])
 
-    # run backward
-    output_grad = torch.randn_like(torch_loss)
-    torch_loss.backward(output_grad)
+        tt_output_grad = (
+            ttl.tensor.Tensor(output_grad.reshape(1, 1), ttl.tensor.DataType.BFLOAT16)
+            .pad_to_tile(float("nan"))
+            .to(ttl.tensor.Layout.TILE)
+            .to(device)
+        )
 
-    tt_output_grad = (
-        ttl.tensor.Tensor(output_grad.reshape(1, 1, 1, 1), ttl.tensor.DataType.BFLOAT16)
-        .pad_to_tile(float("nan"))
-        .to(ttl.tensor.Layout.TILE)
-        .to(device)
-    )
-    tt_input_grad = (
-        ttl.tensor.Tensor(torch_input, ttl.tensor.DataType.BFLOAT16)
-        .pad_to_tile(float("nan"))
-        .to(ttl.tensor.Layout.TILE)
-        .to(device)
-    )
+        tt_input_grad = (
+            ttl.tensor.Tensor(torch_input, ttl.tensor.DataType.BFLOAT16)
+            .pad_to_tile(float("nan"))
+            .to(ttl.tensor.Layout.TILE)
+            .to(device)
+        )
 
-    for _ in range(2):
         tt_input_grad = ttl.operations.primary.moreh_nll_loss_backward(
             tt_target,
             tt_weight,
@@ -340,8 +345,18 @@ def test_moreh_nll_loss_backward_test_callback(shape, reduction_mean, none_weigh
             ignore_index,
             reduction_mean,
         )
 
         tt_input_grad_to_cpu = tt_input_grad.cpu().to(ttl.tensor.Layout.ROW_MAJOR).unpad_from_tile(shape).to_torch()
+
+        # run torch
+        nll_loss = torch.nn.NLLLoss(
+            weight=torch_weight, ignore_index=ignore_index, reduction="mean" if reduction_mean else "sum"
+        )
+        torch_loss = nll_loss(torch_input, torch_target)
+
+        torch_loss.backward(output_grad)
+
+        # compare result
         rtol = atol = 0.05
         passing, out = comp_allclose_and_pcc(torch_input.grad, tt_input_grad_to_cpu, pcc=0.999, rtol=rtol, atol=atol)
 