#5044: Add optional output to addalpha
KalaivaniMCW committed Jun 2, 2024
1 parent 354370a commit 9d0c6b9
Showing 6 changed files with 936 additions and 369 deletions.
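For context, addalpha computes output = input + alpha * other; the "optional output" added here lets the caller pass a preallocated tensor that the op writes into, instead of allocating a fresh result. A minimal PyTorch sketch of the same contract (purely illustrative; the device-side op is ttl.tensor.addalpha, shown in the diff below):

import torch

x = torch.randn(1, 1, 32, 32)
y = torch.randn(1, 1, 32, 32)
alpha = 5

# Default path: the result tensor is freshly allocated.
out = torch.add(x, y, alpha=alpha)

# Optional-output path: the result is written into a caller-provided buffer.
preallocated = torch.empty_like(x)
torch.add(x, y, alpha=alpha, out=preallocated)
assert torch.equal(out, preallocated)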
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
@@ -516,6 +516,10 @@
        "tt_op": tt_lib_ops.eltwise_addalpha,
        "pytorch_op": pytorch_ops.addalpha,
    },
    "eltwise-addalpha-optional": {
        "tt_op": tt_lib_ops.eltwise_addalpha_optional,
        "pytorch_op": pytorch_ops.addalpha,
    },
    "lamb-optimizer": {
        "tt_op": tt_lib_ops.lamb_optimizer,
        "pytorch_op": pytorch_ops.lamb_optimizer,
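How the new entry gets exercised: the sweep harness resolves the op string to this pair and runs the TT implementation against the PyTorch golden, scoring with the test's comparison function. A hedged sketch of that lookup (the dictionary name op_map is assumed from the module path, not verified):

# Hedged sketch; only the diffed entries are taken from the source.
from tests.tt_eager.python_api_testing.sweep_tests.op_map import op_map

entry = op_map["eltwise-addalpha-optional"]
tt_fn = entry["tt_op"]  # tt_lib_ops.eltwise_addalpha_optional
golden_fn = entry["pytorch_op"]  # pytorch_ops.addalpha (same golden as the plain op)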
@@ -0,0 +1,124 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from functools import partial
import tt_lib as ttl
import numpy as np


from tests.tt_eager.python_api_testing.sweep_tests import (
    comparison_funcs,
    generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
    run_single_pytorch_test,
)
from models.utility_functions import is_wormhole_b0

shapes = [
    [[1, 1, 32, 32], [1, 1, 32, 32]],  # Single core
    [[1, 1, 32, 32], [32, 1, 32, 32]],  # Single core
    [[64, 1, 32, 32], [1, 1, 32, 32]],  # Single core
    [[1, 1, 320, 384], [1, 1, 320, 384]],  # Multi core
    [[1, 3, 320, 384], [1, 3, 320, 384]],  # Multi core
]

input_mem_cfgs = generation_funcs.supported_mem_configs
output_mem_cfgs = generation_funcs.supported_mem_configs

if is_wormhole_b0():
    shapes = [
        shapes[0],
    ]
    input_mem_cfgs = [
        input_mem_cfgs[0],
    ]


@pytest.mark.parametrize(
    "input_shapes",
    shapes,
)
@pytest.mark.parametrize("input_mem_config", input_mem_cfgs)
@pytest.mark.parametrize("output_mem_config", output_mem_cfgs)
@pytest.mark.parametrize("fn_kind", ["addalpha"])
def test_run_addalpha(
    input_shapes,
    fn_kind,
    input_mem_config,
    output_mem_config,
    device,
    function_level_defaults,
):
    datagen_func = [
        generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.float32)
    ] * len(input_shapes)
    test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
    test_args.update(
        {
            "input_mem_config": [input_mem_config, input_mem_config],
            "output_mem_config": output_mem_config,
            "alpha": np.random.randint(1, 100),
        }
    )
    comparison_func = comparison_funcs.comp_pcc
    run_single_pytorch_test(
        f"eltwise-{fn_kind}",
        input_shapes,
        datagen_func,
        comparison_func,
        device,
        test_args,
    )


shapes_w_output = [
    [[1, 1, 32, 32], [1, 1, 32, 32], [1, 1, 32, 32]],  # Single core
    [[1, 1, 32, 32], [32, 1, 32, 32], [32, 1, 32, 32]],  # Single core
    [[64, 1, 32, 32], [1, 1, 32, 32], [64, 1, 32, 32]],  # Single core
    [[1, 1, 320, 384], [1, 1, 320, 384], [1, 1, 320, 384]],  # Multi core
    [[1, 3, 320, 384], [1, 3, 320, 384], [1, 3, 320, 384]],  # Multi core
]
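In each triple the third shape is the preallocated output, sized to the broadcast of the two input shapes. A quick PyTorch check of that invariant against the list above (illustrative only, not part of the diff):

import torch

for a, b, out in shapes_w_output:
    assert list(torch.broadcast_shapes(torch.Size(a), torch.Size(b))) == out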


@pytest.mark.parametrize(
    "input_shapes",
    shapes_w_output,
)
@pytest.mark.parametrize("input_mem_config", input_mem_cfgs)
@pytest.mark.parametrize("output_mem_config", output_mem_cfgs)
@pytest.mark.parametrize("fn_kind", ["addalpha"])
def test_run_addalpha_optional_output(
    input_shapes,
    fn_kind,
    input_mem_config,
    output_mem_config,
    device,
    function_level_defaults,
):
    # One float32 generator per input; the appended bfloat16 generator fills the
    # preallocated output, so len(datagen_func) matches len(input_shapes).
    datagen_func = [
        generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.float32)
    ] * (len(input_shapes) - 1)
    datagen_func.append(
        generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-80, high=80), torch.bfloat16)
    )
    test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
    test_args.update(
        {
            "input_mem_config": [input_mem_config, input_mem_config, input_mem_config],
            "output_mem_config": output_mem_config,
            "alpha": np.random.randint(1, 100),
        }
    )
    comparison_func = comparison_funcs.comp_pcc
    run_single_pytorch_test(
        f"eltwise-{fn_kind}-optional",
        input_shapes,
        datagen_func,
        comparison_func,
        device,
        test_args,
    )
22 changes: 22 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -1026,6 +1026,28 @@ def eltwise_addalpha(
    return tt2torch_tensor(t2)


@setup_host_and_device
def eltwise_addalpha_optional(
    x,
    y,
    z,
    *args,
    alpha,
    device,
    dtype,
    layout,
    input_mem_config,
    output_mem_config,
    **kwargs,
):
    t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])
    # z is the preallocated output; index 2 picks up its own layout/mem config/dtype
    t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2])
    ttl.tensor.addalpha(t0, t1, alpha, output_tensor=t2, output_mem_config=output_mem_config)

    return tt2torch_tensor(t2)


@setup_host_and_device
def eltwise_div(
    x,
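For reference, the golden both sweeps compare against is pytorch_ops.addalpha; its body is not shown in this diff. A hedged sketch, assuming it mirrors torch.add's alpha semantics:

import torch

def addalpha_golden(x, y, *args, alpha, **kwargs):
    # Assumed golden behavior: out = x + alpha * y, per torch.add's alpha keyword.
    return torch.add(x, y, alpha=alpha)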
