#14734: fix hardtanh golden function (#14736)
KalaivaniMCW authored Nov 6, 2024
1 parent d400890 commit caaf3b9
Showing 4 changed files with 15 additions and 14 deletions.
@@ -6,18 +6,13 @@
 from functools import partial

 import torch
-import random
 import ttnn
 from tests.sweep_framework.sweep_utils.utils import gen_shapes, gen_low_high_scalars
 from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt

 from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
 from models.utility_functions import torch_random

 # Override the default timeout in seconds for hang detection.
 TIMEOUT = 30

-random.seed(0)
-
 parameters = {
     "nightly": {
@@ -152,8 +147,7 @@ def run(
     *,
     device,
 ) -> list:
-    data_seed = random.randint(0, 20000000)
-    torch.manual_seed(data_seed)
+    torch.manual_seed(0)

     torch_input_tensor_a = gen_func_with_cast_tt(
         partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
@@ -163,7 +157,7 @@
     max_val = input_specs.get("max_val")

     golden_function = ttnn.get_golden_function(ttnn.hardtanh)
-    torch_output_tensor = golden_function(torch_input_tensor_a, min=min_val, max=max_val)
+    torch_output_tensor = golden_function(torch_input_tensor_a, min_val=min_val, max_val=max_val)

     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
@@ -174,7 +168,7 @@
     )

     start_time = start_measuring_time()
-    result = ttnn.hardtanh(input_tensor_a, min=min_val, max=max_val, memory_config=output_memory_config)
+    result = ttnn.hardtanh(input_tensor_a, min_val=min_val, max_val=max_val, memory_config=output_memory_config)
     output_tensor = ttnn.to_torch(result)
     e2e_perf = stop_measuring_time(start_time)
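The root cause the sweep changes work around: torch's reference op takes min_val/max_val, not min/max, and the old table-driven _golden_function(input_tensor, **_) in ttnn/ttnn/operations/unary.py (see its diff below) swallows all keyword arguments, so custom bounds never reached torch. A minimal standalone sketch of the intended reference behavior, assuming only PyTorch:

import torch

x = torch.tensor([-2.0, 0.5, 3.0])

# torch's hardtanh clamps to [min_val, max_val]; the keyword names are
# min_val/max_val, which is why the golden call sites were renamed.
y = torch.nn.functional.hardtanh(x, min_val=-1.0, max_val=1.0)
print(y)  # tensor([-1.0000, 0.5000, 1.0000])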
@@ -235,7 +235,7 @@ def eltwise_hardtanh(
     **kwargs,
 ):
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttnn.hardtanh(t0, min=low, max=high, memory_config=output_mem_config)
+    t1 = ttnn.hardtanh(t0, min_val=low, max_val=high, memory_config=output_mem_config)

     return tt2torch_tensor(t1)

4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1845,8 +1845,8 @@ void py_module(py::module& module) {
     detail::bind_unary_composite_floats_with_default(
         module,
         ttnn::hardtanh,
-        "min", "min value", -1.0f,
-        "max", "max value", 1.0f);
+        "min_val", "min value", -1.0f,
+        "max_val", "max value", 1.0f);

     detail::bind_unary_composite_optional_floats_with_default(
         module,
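With the pybind registration renamed, the Python-side kwargs of ttnn.hardtanh match torch's. A hedged usage sketch, not taken from this commit; it assumes a Tenstorrent device 0 is available and the shape/dtype are illustrative:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_x = torch.randn(32, 32)
x = ttnn.from_torch(torch_x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# min_val/max_val mirror torch.nn.functional.hardtanh's keyword names.
y = ttnn.hardtanh(x, min_val=-2.0, max_val=2.0)
print(ttnn.to_torch(y))

ttnn.close_device(device)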
11 changes: 9 additions & 2 deletions ttnn/ttnn/operations/unary.py
@@ -78,7 +78,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     "digamma": torch.digamma,
     "hardswish": torch.nn.functional.hardswish,
     "hardsigmoid": torch.nn.functional.hardsigmoid,
-    "hardtanh": torch.nn.functional.hardtanh,
     "lgamma": torch.lgamma,
     "log1p": torch.log1p,
     "mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)),
@@ -163,7 +162,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     ttnn.digamma,
     ttnn.hardswish,
     ttnn.hardsigmoid,
-    ttnn.hardtanh,
     ttnn.lgamma,
     ttnn.log1p,
     ttnn.mish,
@@ -241,6 +239,15 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs):
 ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu)


+def _golden_function_hardtanh(input_tensor_a, *args, min_val=-1.0, max_val=1.0, **kwargs):
+    import torch
+
+    return torch.nn.functional.hardtanh(input_tensor_a, min_val=min_val, max_val=max_val)
+
+
+ttnn.attach_golden_function(ttnn.hardtanh, golden_function=_golden_function_hardtanh)
+
+
 def _golden_function_leaky_relu(input_tensor_a, *args, negative_slope=0.01, **kwargs):
     import torch

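After this change, ttnn.get_golden_function(ttnn.hardtanh) returns the dedicated wrapper above, which forwards the bounds instead of the generic _golden_function(input_tensor, **_) that silently dropped them. A small CPU-only sketch of the round trip (assumes the ttnn package is importable; no device required):

import torch
import ttnn

golden = ttnn.get_golden_function(ttnn.hardtanh)

x = torch.linspace(-3.0, 3.0, steps=7)
ref = golden(x, min_val=-2.0, max_val=2.0)

# The golden output now matches calling torch directly with the same bounds.
expected = torch.nn.functional.hardtanh(x, min_val=-2.0, max_val=2.0)
assert torch.equal(ref, expected)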
