diff --git a/tests/sweep_framework/sweeps/eltwise/unary/hardtanh/hardtanh_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/hardtanh/hardtanh_pytorch2.py
index 2f4735e0944..9446ba5e1be 100644
--- a/tests/sweep_framework/sweeps/eltwise/unary/hardtanh/hardtanh_pytorch2.py
+++ b/tests/sweep_framework/sweeps/eltwise/unary/hardtanh/hardtanh_pytorch2.py
@@ -6,7 +6,6 @@ from functools import partial
 
 import torch
-import random
 import ttnn
 from tests.sweep_framework.sweep_utils.utils import gen_shapes, gen_low_high_scalars
 from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
@@ -14,10 +13,6 @@ from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
 from models.utility_functions import torch_random
 
-# Override the default timeout in seconds for hang detection.
-TIMEOUT = 30
-
-random.seed(0)
 
 
 parameters = {
     "nightly": {
@@ -152,8 +147,7 @@ def run(
     *,
     device,
 ) -> list:
-    data_seed = random.randint(0, 20000000)
-    torch.manual_seed(data_seed)
+    torch.manual_seed(0)
 
     torch_input_tensor_a = gen_func_with_cast_tt(
         partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
@@ -163,7 +157,7 @@ def run(
     max_val = input_specs.get("max_val")
 
     golden_function = ttnn.get_golden_function(ttnn.hardtanh)
-    torch_output_tensor = golden_function(torch_input_tensor_a, min=min_val, max=max_val)
+    torch_output_tensor = golden_function(torch_input_tensor_a, min_val=min_val, max_val=max_val)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
@@ -174,7 +168,7 @@ def run(
     )
 
     start_time = start_measuring_time()
-    result = ttnn.hardtanh(input_tensor_a, min=min_val, max=max_val, memory_config=output_memory_config)
+    result = ttnn.hardtanh(input_tensor_a, min_val=min_val, max_val=max_val, memory_config=output_memory_config)
     output_tensor = ttnn.to_torch(result)
     e2e_perf = stop_measuring_time(start_time)
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
index 4f6bc6ce792..b7341753cc4 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -235,7 +235,7 @@ def eltwise_hardtanh(
     **kwargs,
 ):
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttnn.hardtanh(t0, min=low, max=high, memory_config=output_mem_config)
+    t1 = ttnn.hardtanh(t0, min_val=low, max_val=high, memory_config=output_mem_config)
     return tt2torch_tensor(t1)
 
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
index 188845637f5..04011c73f1d 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1845,8 +1845,8 @@ void py_module(py::module& module) {
     detail::bind_unary_composite_floats_with_default(
         module,
         ttnn::hardtanh,
-        "min", "min value", -1.0f,
-        "max", "max value", 1.0f);
+        "min_val", "min value", -1.0f,
+        "max_val", "max value", 1.0f);
 
     detail::bind_unary_composite_optional_floats_with_default(
         module,
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index ef032bb695c..b6cecf2b1e1 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -78,7 +78,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     "digamma": torch.digamma,
     "hardswish": torch.nn.functional.hardswish,
     "hardsigmoid": torch.nn.functional.hardsigmoid,
-    "hardtanh": torch.nn.functional.hardtanh,
     "lgamma": torch.lgamma,
     "log1p": torch.log1p,
     "mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)),
@@ -163,7 +162,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     ttnn.digamma,
     ttnn.hardswish,
     ttnn.hardsigmoid,
-    ttnn.hardtanh,
     ttnn.lgamma,
     ttnn.log1p,
     ttnn.mish,
@@ -241,6 +239,15 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs):
 ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu)
 
 
+def _golden_function_hardtanh(input_tensor_a, *args, min_val=-1.0, max_val=1.0, **kwargs):
+    import torch
+
+    return torch.nn.functional.hardtanh(input_tensor_a, min_val=min_val, max_val=max_val)
+
+
+ttnn.attach_golden_function(ttnn.hardtanh, golden_function=_golden_function_hardtanh)
+
+
 def _golden_function_leaky_relu(input_tensor_a, *args, negative_slope=0.01, **kwargs):
     import torch
 