#5044: add optional output to BW ops EQ, add, addalpha, mul
KalaivaniMCW committed Jun 2, 2024
1 parent 354370a commit 6f8479b
Showing 7 changed files with 2,216 additions and 957 deletions.
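
All four ops gain the same optional-output interface, which the new tests below exercise: are_required_outputs selects which of the two gradients to compute, and the caller may hand in preallocated device tensors to receive them. A minimal sketch of the call shape, using add_bw as it appears in the tests (the tensors are tt_lib device tensors produced by the data_gen_with_range test helper):

    # Sketch only: request just the first gradient and pass a
    # preallocated buffer for it; the skipped output stays None.
    tt_grads = tt_lib.tensor.add_bw(
        grad_tensor,                         # upstream gradient
        input_tensor,                        # first operand
        other_tensor,                        # second operand
        are_required_outputs=[True, False],  # compute input grad only
        input_grad=input_grad,               # optional preallocated output
        other_grad=None,
    )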
@@ -34,3 +34,49 @@ def test_bw_add(input_shapes, device):

    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
    assert status


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_add_with_opt_output(input_shapes, device, are_required_outputs):
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device)
    input_grad = None
    other_grad = None

    if are_required_outputs[0]:
        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
    if are_required_outputs[1]:
        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

    tt_output_tensor_on_device = tt_lib.tensor.add_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        are_required_outputs=are_required_outputs,
        input_grad=input_grad,
        other_grad=other_grad,
    )

    in_data.retain_grad()
    other_data.retain_grad()

    pyt_y = torch.add(in_data, other_data)

    pyt_y.backward(gradient=grad_data)

    golden_tensor = [in_data.grad, other_data.grad]

    status = True
    for i in range(len(are_required_outputs)):
        if are_required_outputs[i]:
            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
    assert status
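
For reference, compare_pcc checks each device result against its torch golden via the Pearson correlation coefficient (PCC). A rough standalone sketch of the idea, assuming (not confirmed here) that the real helper in the repo's test utilities converts device tensors to torch first and special-cases constant tensors, for which plain corrcoef would yield NaN:

    import torch

    def pcc_check(a: torch.Tensor, b: torch.Tensor, threshold: float = 0.99) -> bool:
        # Flatten both tensors and compare their Pearson correlation
        # coefficient against the pass threshold.
        a, b = a.flatten().float(), b.flatten().float()
        pcc = torch.corrcoef(torch.stack([a, b]))[0, 1]
        return bool(pcc >= threshold)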
@@ -35,3 +35,51 @@ def test_bw_addalpha(input_shapes, alpha, device):

    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
    assert status


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize("alpha", [0.05, 2.0, 1.5, 0.12])
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_addalpha_with_opt_output(input_shapes, alpha, device, are_required_outputs):
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device)
    input_grad = None
    other_grad = None

    if are_required_outputs[0]:
        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
    if are_required_outputs[1]:
        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

    tt_output_tensor_on_device = tt_lib.tensor.addalpha_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        alpha,
        are_required_outputs=are_required_outputs,
        input_grad=input_grad,
        other_grad=other_grad,
    )

    in_data.retain_grad()
    other_data.retain_grad()

    pyt_y = torch.add(in_data, other_data, alpha=alpha)

    pyt_y.backward(gradient=grad_data)

    golden_tensor = [in_data.grad, other_data.grad]

    status = True
    for i in range(len(are_required_outputs)):
        if are_required_outputs[i]:
            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
    assert status
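
torch.add(input, other, alpha=alpha) computes input + alpha * other, so the backward pass sends the upstream gradient through unchanged to input and scaled by alpha to other; that is the identity the golden tensors encode. A standalone torch check (illustrative names, independent of tt_lib):

    import torch

    a = torch.randn(4, 4, requires_grad=True)
    b = torch.randn(4, 4, requires_grad=True)
    g = torch.randn(4, 4)
    torch.add(a, b, alpha=1.5).backward(gradient=g)
    assert torch.allclose(a.grad, g) and torch.allclose(b.grad, 1.5 * g)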
@@ -18,10 +18,54 @@
 )
 def test_bw_binary_eq(input_shapes, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
-    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
+    _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
 
-    tt_output_tensor_on_device = tt_lib.tensor.binary_eq_bw(grad_tensor, input_tensor)
-    pt_y = torch.zeros_like(grad_data)
-    golden_tensor = [pt_y, pt_y]
+    tt_output_tensor_on_device = tt_lib.tensor.binary_eq_bw(grad_tensor, input_tensor, other_tensor)
+    in_grad = torch.zeros_like(in_data)
+    other_grad = torch.zeros_like(other_data)
+
+    golden_tensor = [in_grad, other_grad]
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
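
The all-zero goldens are deliberate: an equality comparison is piecewise constant, so its gradient with respect to either operand is zero everywhere. torch itself marks eq as non-differentiable, which is why the expected gradients are built with torch.zeros_like instead of autograd.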


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs):
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
    _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
    input_grad = None
    other_grad = None
    if are_required_outputs[0]:
        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
    if are_required_outputs[1]:
        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

    tt_output_tensor_on_device = tt_lib.tensor.binary_eq_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        are_required_outputs=are_required_outputs,
        input_grad=input_grad,
        other_grad=other_grad,
    )

    in_grad = torch.zeros_like(in_data)
    other_grad = torch.zeros_like(other_data)

    golden_tensor = [in_grad, other_grad]

    status = True
    for i in range(len(are_required_outputs)):
        if are_required_outputs[i]:
            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
    assert status
@@ -34,3 +34,49 @@ def test_bw_mul(input_shapes, device):

    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
    assert status


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_mul_opt_output(input_shapes, device, are_required_outputs):
    in_data_a, input_tensor_a = data_gen_with_range(input_shapes, -90, 80, device, True)
    in_data_b, input_tensor_b = data_gen_with_range(input_shapes, -70, 90, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -60, 60, device)
    input_a_grad = None
    input_b_grad = None

    if are_required_outputs[0]:
        _, input_a_grad = data_gen_with_range(input_shapes, -1, 1, device)
    if are_required_outputs[1]:
        _, input_b_grad = data_gen_with_range(input_shapes, -1, 1, device)

    tt_output_tensor_on_device = tt_lib.tensor.mul_bw(
        grad_tensor,
        input_tensor_a,
        input_tensor_b,
        are_required_outputs=are_required_outputs,
        input_a_grad=input_a_grad,
        input_b_grad=input_b_grad,
    )

    in_data_a.retain_grad()
    in_data_b.retain_grad()

    pyt_y = torch.mul(in_data_a, in_data_b)

    pyt_y.backward(gradient=grad_data)

    golden_tensor = [in_data_a.grad, in_data_b.grad]

    status = True
    for i in range(len(are_required_outputs)):
        if are_required_outputs[i]:
            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
    assert status
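
Here autograd applies the product rule: the gradient for the first operand is grad * b and for the second is grad * a, which is what the golden tensors hold. A standalone sanity check of that rule (illustrative names, plain torch):

    import torch

    a = torch.randn(4, 4, requires_grad=True)
    b = torch.randn(4, 4, requires_grad=True)
    g = torch.randn(4, 4)
    torch.mul(a, b).backward(gradient=g)
    assert torch.allclose(a.grad, g * b) and torch.allclose(b.grad, g * a)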