#9217: Add cq_id to add_bw, mul_bw and dependency ops
KalaivaniMCW committed Jun 16, 2024
1 parent a3c0d04 commit 43b8166
Showing 13 changed files with 642 additions and 325 deletions.
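The changes below converge on one calling convention: eltwise and backward ops take the command queue as a queue_id keyword argument, alongside a preallocated output_tensor, instead of a leading positional cq_id. A minimal sketch of the forward-op pattern, mirroring the tt_lib_ops.py changes further down; setup_tt_tensor and tt2torch_tensor are the existing sweep-test helpers assumed to be in scope, and the wrapper name here is made up for illustration:

```python
# Sketch only - mirrors the updated eltwise wrappers in tt_lib_ops.py below.
# setup_tt_tensor / tt2torch_tensor are the existing sweep-test helpers and
# are assumed to be importable; the wrapper name is hypothetical.
import tt_lib as ttl


def eltwise_addalpha_sketch(x, y, out, alpha, device, dtype, layout, input_mem_config):
    t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])
    t2 = setup_tt_tensor(out, device, layout[2], input_mem_config[2], dtype[2])

    cq_id = 0  # default command queue
    # New convention: queue_id is a keyword argument and the result is
    # written into the preallocated output tensor t2.
    ttl.tensor.addalpha(t0, t1, alpha, output_tensor=t2, queue_id=cq_id)
    return tt2torch_tensor(t2)
```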
@@ -67,6 +67,7 @@ def test_run_eltwise_binary_ops(
{
"dtype": [in0_dtype, in1_dtype, in2_dtype],
"input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
"queue_id": False,
}
)
comparison_func = comparison_funcs.comp_pcc
@@ -117,14 +118,12 @@ def test_run_eltwise_binary_bias_ops(
)

@pytest.mark.parametrize("cmp_kind", ["lt", "gt", "lte", "gte", "ne", "eq"])
@pytest.mark.parametrize("pass_queue_id", [True, False])
def test_run_eltwise_binary_cmp_ops(
self,
input_shapes,
input_mem_config,
cmp_kind,
device,
- pass_queue_id,
function_level_defaults,
):
datagen_func = [
@@ -137,15 +136,8 @@ def test_run_eltwise_binary_cmp_ops(
test_args.update(
{
"input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
"queue_id": "skip",
}
)
if cmp_kind == "eq":
test_args.update(
{
"queue_id": pass_queue_id,
}
)

comparison_func = comparison_funcs.comp_equal
run_single_pytorch_test(
38 changes: 9 additions & 29 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -1047,7 +1047,7 @@ def eltwise_addalpha_optional(
cq_id = 0

if queue_id:
- ttl.tensor.addalpha(cq_id, t0, t1, alpha, output_tensor=t2)
+ ttl.tensor.addalpha(t0, t1, alpha, output_tensor=t2, queue_id=cq_id)
else:
ttl.tensor.addalpha(t0, t1, alpha, output_tensor=t2)

@@ -1643,7 +1643,8 @@ def where_optional(x, y, z, out, device, dtype, layout, input_mem_config, output
t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])
t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2])
t3 = setup_tt_tensor(out, device, layout[3], input_mem_config[3], dtype[3])
- ttl.tensor.where(t0, t1, t2, output_mem_config=output_mem_config, output_tensor=t3)
+ cq_id = 0
+ ttl.tensor.where(t0, t1, t2, output_tensor=t3, queue_id=cq_id)

return tt2torch_tensor(t3)

@@ -1654,7 +1655,8 @@ def where_scalar_optional(
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t3 = setup_tt_tensor(out, device, layout[1], input_mem_config[1], dtype[1])
- ttl.tensor.where(t0, scalar_true, scalar_false, output_mem_config=output_mem_config, output_tensor=t3)
+ cq_id = 0
+ ttl.tensor.where(t0, scalar_true, scalar_false, output_tensor=t3, queue_id=cq_id)

return tt2torch_tensor(t3)

@@ -2582,7 +2584,9 @@ def binary_op(
t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])
t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2])

- ttl_tensor_binop(t0, t1, output_tensor=t2)
+ cq_id = 0
+
+ ttl_tensor_binop(t0, t1, output_tensor=t2, queue_id=cq_id)

return tt2torch_tensor(t2)

@@ -2595,6 +2599,7 @@ def binary_op(
eltwise_bias_gelu_optional = make_binary_op_optional_output(ttnn.bias_gelu)
eltwise_squared_difference_optional = make_binary_op_optional_output(ttnn.squared_difference)
eltwise_ne_optional = make_binary_op_optional_output(ttnn.ne)
+ eltwise_eq_optional = make_binary_op_optional_output(ttnn.eq)
eltwise_gt_optional = make_binary_op_optional_output(ttnn.gt)
eltwise_lt_optional = make_binary_op_optional_output(ttnn.lt)
eltwise_gte_optional = make_binary_op_optional_output(ttnn.ge)
@@ -2606,31 +2611,6 @@ def binary_op(
eltwise_logical_or_optional = make_binary_op_optional_output(ttnn.logical_or)


- def eltwise_eq_optional(
- x,
- y,
- z,
- *args,
- device,
- dtype,
- layout,
- input_mem_config,
- queue_id,
- **kwargs,
- ):
- cq_id = 0
- t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
- t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])
- t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2])
-
- if queue_id == True:
- ttnn.eq(t0, t1, output_tensor=t2, queue_id=cq_id)
- else:
- ttnn.eq(t0, t1, output_tensor=t2)
-
- return tt2torch_tensor(t2)


################################################
#################### Tensor ####################
################################################
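With eltwise_eq_optional now generated by make_binary_op_optional_output (see the added assignment above), the bespoke definition that branched on a queue_id flag is deleted. A rough reconstruction of the factory is sketched here for orientation; only the binary_op body visible in the hunk above comes from the diff, while the factory signature and the closure's parameter list are assumptions based on the deleted wrapper:

```python
# Hypothetical reconstruction of make_binary_op_optional_output.
# The binary_op body (tensor setup + queue_id call) follows the hunk above;
# the outer signature and the closure's parameter list are assumed.
def make_binary_op_optional_output(ttl_tensor_binop):
    def binary_op(x, y, z, *args, device, dtype, layout, input_mem_config, **kwargs):
        t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
        t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])
        t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2])

        cq_id = 0

        # queue_id is always forwarded now; t2 is the preallocated output.
        ttl_tensor_binop(t0, t1, output_tensor=t2, queue_id=cq_id)

        return tt2torch_tensor(t2)

    return binary_op


# Usage, as in the diff: eq is wired up like every other optional-output binary op.
eltwise_eq_optional = make_binary_op_optional_output(ttnn.eq)
```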
@@ -57,13 +57,16 @@ def test_bw_add_with_opt_output(input_shapes, device, are_required_outputs):
if are_required_outputs[1]:
_, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

+ cq_id = 0

tt_output_tensor_on_device = tt_lib.tensor.add_bw(
grad_tensor,
input_tensor,
other_tensor,
are_required_outputs=are_required_outputs,
- input_grad=input_grad,
- other_grad=other_grad,
+ input_a_grad=input_grad,
+ input_b_grad=other_grad,
+ queue_id=cq_id,
)

in_data.retain_grad()
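The backward-op tests above and below reflect two related API changes: the optional gradient outputs are now passed as input_a_grad / input_b_grad (previously input_grad / other_grad), and the command queue goes in as a queue_id keyword instead of a leading positional argument. A condensed sketch of the new call shape, assuming an already-open device and the data_gen_with_range helper used by these tests:

```python
# Sketch of the updated backward-op call, following the add_bw test above.
# data_gen_with_range comes from the backward-op test utilities and `device`
# is assumed to be an already-open device handle.
import torch
import tt_lib

input_shapes = torch.Size([1, 1, 32, 32])
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
_, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
_, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
_, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

cq_id = 0  # default command queue
tt_output_tensor_on_device = tt_lib.tensor.add_bw(
    grad_tensor,
    input_tensor,
    other_tensor,
    are_required_outputs=[True, True],
    input_a_grad=input_grad,  # renamed from input_grad
    input_b_grad=other_grad,  # renamed from other_grad
    queue_id=cq_id,  # keyword argument, not a leading positional
)
```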
@@ -59,14 +59,16 @@ def test_bw_addalpha_with_opt_output(input_shapes, alpha, device, are_required_o
if are_required_outputs[1]:
_, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

+ cq_id = 0
tt_output_tensor_on_device = tt_lib.tensor.addalpha_bw(
grad_tensor,
input_tensor,
other_tensor,
alpha,
are_required_outputs=are_required_outputs,
- input_grad=input_grad,
- other_grad=other_grad,
+ input_a_grad=input_grad,
+ input_b_grad=other_grad,
+ queue_id=cq_id,
)

in_data.retain_grad()
@@ -85,21 +85,22 @@ def test_bw_binary_eq_opt_output_qid(input_shapes, device, are_required_outputs)
_, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)
input_grad = None
other_grad = None

if are_required_outputs[0]:
_, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
if are_required_outputs[1]:
_, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

- queue_id = 0
+ cq_id = 0

tt_output_tensor_on_device = tt_lib.tensor.binary_eq_bw(
- queue_id,
grad_tensor,
input_tensor,
other_tensor,
are_required_outputs=are_required_outputs,
input_grad=input_grad,
other_grad=other_grad,
+ queue_id=cq_id,
)

in_grad = torch.zeros_like(in_data)
@@ -45,26 +45,40 @@ def test_bw_mul(input_shapes, device):
),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
- def test_bw_mul_opt_output(input_shapes, device, are_required_outputs):
+ @pytest.mark.parametrize("pass_queue_id", [True, False])
+ def test_bw_mul_opt_output(input_shapes, device, are_required_outputs, pass_queue_id):
in_data_a, input_tensor_a = data_gen_with_range(input_shapes, -90, 80, device, True)
in_data_b, input_tensor_b = data_gen_with_range(input_shapes, -70, 90, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -60, 60, device)
input_a_grad = None
input_b_grad = None
+ tt_output_tensor_on_device = None

if are_required_outputs[0]:
_, input_a_grad = data_gen_with_range(input_shapes, -1, 1, device)
if are_required_outputs[1]:
_, input_b_grad = data_gen_with_range(input_shapes, -1, 1, device)

- tt_output_tensor_on_device = tt_lib.tensor.mul_bw(
- grad_tensor,
- input_tensor_a,
- input_tensor_b,
- are_required_outputs=are_required_outputs,
- input_a_grad=input_a_grad,
- input_b_grad=input_b_grad,
- )
+ cq_id = 0
+ if pass_queue_id:
+ tt_output_tensor_on_device = tt_lib.tensor.mul_bw(
+ grad_tensor,
+ input_tensor_a,
+ input_tensor_b,
+ are_required_outputs=are_required_outputs,
+ input_a_grad=input_a_grad,
+ input_b_grad=input_b_grad,
+ queue_id=cq_id,
+ )
+ else:
+ tt_output_tensor_on_device = tt_lib.tensor.mul_bw(
+ grad_tensor,
+ input_tensor_a,
+ input_tensor_b,
+ are_required_outputs=are_required_outputs,
+ input_a_grad=input_a_grad,
+ input_b_grad=input_b_grad,
+ )

in_data_a.retain_grad()
in_data_b.retain_grad()
@@ -41,3 +41,60 @@ def test_bw_where(input_shapes, device):

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_where_output(input_shapes, are_required_outputs, device):
condition_data = torch.zeros(input_shapes, dtype=torch.bool)
condition_data.view(-1)[::2] = True

condition_tensor = (
tt_lib.tensor.Tensor(condition_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
)

in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
other_data, other_tensor = data_gen_with_range(input_shapes, -1, 1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -4, 4, device)
input_grad = None
other_grad = None

if are_required_outputs[0]:
_, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
if are_required_outputs[1]:
_, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

cq_id = 0

tt_output_tensor_on_device = tt_lib.tensor.where_bw(
grad_tensor,
condition_tensor,
input_tensor,
other_tensor,
are_required_outputs=are_required_outputs,
input_a_grad=input_grad,
input_b_grad=other_grad,
queue_id=cq_id,
)

in_data.retain_grad()
other_data.retain_grad()

pyt_y = torch.where(condition_data, in_data, other_data)

pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad, other_data.grad]

status = True
for i in range(len(are_required_outputs)):
if are_required_outputs[i]:
status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
assert status