Skip to content

Commit

Permalink
#5044: Add optional output tensor and remove autoformat in eltwise bi…
Browse files Browse the repository at this point in the history
…nary ops
  • Loading branch information
KalaivaniMCW committed May 30, 2024
1 parent e75540b commit a9872fc
Show file tree
Hide file tree
Showing 34 changed files with 1,855 additions and 204 deletions.
65 changes: 65 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,71 @@
"tt_op": tt_lib_ops.eltwise_isclose,
"pytorch_op": pytorch_ops.isclose,
},
# Eltwise binary with optional output
"eltwise-ne-optional": {
"tt_op": tt_lib_ops.eltwise_ne_optional,
"pytorch_op": pytorch_ops.ne,
},
"eltwise-bias_gelu-optional": {
"tt_op": tt_lib_ops.eltwise_bias_gelu_optional,
"pytorch_op": pytorch_ops.bias_gelu,
},
"eltwise-eq-optional": {
"tt_op": tt_lib_ops.eltwise_eq_optional,
"pytorch_op": pytorch_ops.eq,
},
"eltwise-lt-optional": {
"tt_op": tt_lib_ops.eltwise_lt_optional,
"pytorch_op": pytorch_ops.lt,
},
"eltwise-gt-optional": {
"tt_op": tt_lib_ops.eltwise_gt_optional,
"pytorch_op": pytorch_ops.gt,
},
"eltwise-gte-optional": {
"tt_op": tt_lib_ops.eltwise_gte_optional,
"pytorch_op": pytorch_ops.gte,
},
"eltwise-lte-optional": {
"tt_op": tt_lib_ops.eltwise_lte_optional,
"pytorch_op": pytorch_ops.lte,
},
"eltwise-add-optional": {
"tt_op": tt_lib_ops.eltwise_add_optional,
"pytorch_op": pytorch_ops.add,
},
"eltwise-sub-optional": {
"tt_op": tt_lib_ops.eltwise_sub_optional,
"pytorch_op": pytorch_ops.sub,
},
"eltwise-mul-optional": {
"tt_op": tt_lib_ops.eltwise_mul_optional,
"pytorch_op": pytorch_ops.mul,
},
"eltwise-squared_difference-optional": {
"tt_op": tt_lib_ops.eltwise_squared_difference_optional,
"pytorch_op": pytorch_ops.squared_difference,
},
"eltwise-ldexp-optional": {
"tt_op": tt_lib_ops.eltwise_ldexp_optional,
"pytorch_op": pytorch_ops.ldexp,
},
"eltwise-logaddexp-optional": {
"tt_op": tt_lib_ops.eltwise_logaddexp_optional,
"pytorch_op": pytorch_ops.logaddexp,
},
"eltwise-logaddexp2-optional": {
"tt_op": tt_lib_ops.eltwise_logaddexp2_optional,
"pytorch_op": pytorch_ops.logaddexp2,
},
"eltwise-logical_or-optional": {
"tt_op": tt_lib_ops.eltwise_logical_or_optional,
"pytorch_op": pytorch_ops.logical_or,
},
"eltwise-logical_and-optional": {
"tt_op": tt_lib_ops.eltwise_logical_and_optional,
"pytorch_op": pytorch_ops.logical_and,
},
# Eltwise ternary
"eltwise-arange": {
"tt_op": tt_lib_ops.arange,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from functools import partial
import tt_lib as ttl

from tests.tt_eager.python_api_testing.sweep_tests import (
comparison_funcs,
generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)
from models.utility_functions import is_wormhole_b0

shapes = [
[[1, 1, 32, 32], [1, 1, 32, 32], [1, 1, 32, 32]], # Single core
[[1, 1, 32, 32], [32, 1, 32, 32], [32, 1, 32, 32]], # Single core
[[64, 1, 32, 32], [1, 1, 32, 32], [64, 1, 32, 32]], # Single core
[[1, 1, 320, 384], [1, 1, 320, 384], [1, 1, 320, 384]], # Multi core
[[1, 3, 320, 384], [1, 3, 320, 384], [1, 3, 320, 384]], # Multi core
]

input_mem_cfgs = generation_funcs.supported_mem_configs

if is_wormhole_b0():
shapes = [
shapes[0],
]
input_mem_cfgs = [
input_mem_cfgs[0],
]


@pytest.mark.parametrize(
"input_shapes",
shapes,
)
@pytest.mark.parametrize("input_mem_config", input_mem_cfgs)
class TestEltwiseBinary:
@pytest.mark.parametrize("fn_kind", ["add", "sub", "mul", "squared_difference"])
@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("in2_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
def test_run_eltwise_binary_ops(
self,
input_shapes,
fn_kind,
in0_dtype,
in1_dtype,
in2_dtype,
input_mem_config,
device,
function_level_defaults,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.float32)
] * (len(input_shapes) - 1)
datagen_func.append(
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=10), torch.bfloat16)
)
test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
test_args.update(
{
"dtype": [in0_dtype, in1_dtype, in2_dtype],
"input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
}
)
comparison_func = comparison_funcs.comp_pcc
run_single_pytorch_test(
f"eltwise-{fn_kind}-optional",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)

@pytest.mark.parametrize(
"fn_kind",
[
"bias_gelu",
],
)
def test_run_eltwise_binary_bias_ops(
self,
input_shapes,
fn_kind,
input_mem_config,
device,
function_level_defaults,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16)
] * (len(input_shapes) - 1)
datagen_func.append(
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=10), torch.bfloat16)
)

test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
test_args.update(
{
"input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
}
)
comparison_func = comparison_funcs.comp_pcc
run_single_pytorch_test(
f"eltwise-{fn_kind}-optional",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)

@pytest.mark.parametrize("cmp_kind", ["lt", "gt", "lte", "gte", "ne", "eq"])
def test_run_eltwise_binary_cmp_ops(
self,
input_shapes,
input_mem_config,
cmp_kind,
device,
function_level_defaults,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16)
] * (len(input_shapes) - 1)
datagen_func.append(
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=10), torch.bfloat16)
)
test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
test_args.update(
{
"input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
}
)
comparison_func = comparison_funcs.comp_equal
run_single_pytorch_test(
f"eltwise-{cmp_kind}-optional",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)

@pytest.mark.parametrize(
"log_kind, input_range",
(
("logaddexp", {"low": -80, "high": 80}),
("ldexp", {"low": -60, "high": 60}),
("logaddexp2", {"low": -60, "high": 100}),
),
)
def test_run_eltwise_binary_log_ops(
self, input_shapes, input_mem_config, log_kind, input_range, device, function_level_defaults
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, **input_range), torch.bfloat16)
] * (len(input_shapes) - 1)
datagen_func.append(
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=10), torch.bfloat16)
)
test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
test_args.update(
{
"input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
}
)
comparison_func = comparison_funcs.comp_pcc
run_single_pytorch_test(
f"eltwise-{log_kind}-optional",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)

@pytest.mark.parametrize("logical_kind", ["logical_and", "logical_or"])
def test_run_eltwise_binary_logical_ops(
self,
input_shapes,
input_mem_config,
logical_kind,
device,
function_level_defaults,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.int32)
] * (len(input_shapes) - 1)
datagen_func.append(
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=10), torch.bfloat16)
)
test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
test_args.update(
{
"input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
}
)
comparison_func = comparison_funcs.comp_equal
run_single_pytorch_test(
f"eltwise-{logical_kind}-optional",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)

@pytest.mark.parametrize(
"log_kind, input_range",
(
("logaddexp", {"low": -80, "high": 80}),
("ldexp", {"low": -60, "high": 60}),
("logaddexp2", {"low": -60, "high": 100}),
),
)
def test_run_eltwise_binary_log_ops(
self, input_shapes, input_mem_config, log_kind, input_range, device, function_level_defaults
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, **input_range), torch.bfloat16)
] * (len(input_shapes) - 1)
datagen_func.append(
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=10), torch.bfloat16)
)
test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
test_args.update(
{
"input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
}
)
comparison_func = comparison_funcs.comp_pcc
run_single_pytorch_test(
f"eltwise-{log_kind}-optional",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)
42 changes: 42 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2379,6 +2379,48 @@ def binary_op(
eltwise_logical_not_unary = make_unary_op(ttl.tensor.logical_not_unary)
eltwise_i0 = make_unary_op(ttl.tensor.i0)


def make_binary_op_optional_output(ttl_tensor_binop):
@setup_host_and_device
def binary_op(
x,
y,
z,
*args,
device,
dtype,
layout,
input_mem_config,
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])
t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2])
ttl_tensor_binop(t0, t1, output_tensor=t2)

return tt2torch_tensor(t2)

return binary_op


eltwise_add_optional = make_binary_op_optional_output(ttl.tensor.add)
eltwise_sub_optional = make_binary_op_optional_output(ttl.tensor.sub)
eltwise_mul_optional = make_binary_op_optional_output(ttl.tensor.mul)
eltwise_bias_gelu_optional = make_binary_op_optional_output(ttl.tensor.bias_gelu)
eltwise_squared_difference_optional = make_binary_op_optional_output(ttl.tensor.squared_difference)
eltwise_ne_optional = make_binary_op_optional_output(ttl.tensor.ne)
eltwise_eq_optional = make_binary_op_optional_output(ttl.tensor.eq)
eltwise_gt_optional = make_binary_op_optional_output(ttl.tensor.gt)
eltwise_lt_optional = make_binary_op_optional_output(ttl.tensor.lt)
eltwise_gte_optional = make_binary_op_optional_output(ttl.tensor.gte)
eltwise_lte_optional = make_binary_op_optional_output(ttl.tensor.lte)
eltwise_ldexp_optional = make_binary_op_optional_output(ttl.tensor.ldexp)
eltwise_logaddexp_optional = make_binary_op_optional_output(ttl.tensor.logaddexp)
eltwise_logaddexp2_optional = make_binary_op_optional_output(ttl.tensor.logaddexp2)
eltwise_logical_and_optional = make_binary_op_optional_output(ttl.tensor.logical_and)
eltwise_logical_or_optional = make_binary_op_optional_output(ttl.tensor.logical_or)


################################################
#################### Tensor ####################
################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
(torch.Size([4, 33, 32, 32])), # 20
(torch.Size([4, 63, 32, 32])), # 21
(torch.Size([4, 64, 32, 32])), # 22
(torch.Size([32, 64, 32, 32])), # 23
(torch.Size([32, 64, 64, 64])), # 23
),
)
@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit a9872fc

Please sign in to comment.