#14863: Align ttnn.hardtanh arguments with pytorch (#14889)
### Ticket
Link to GitHub issue #14863

### Problem description
Align `ttnn.hardtanh` arguments with PyTorch.

### What's changed
Changed `min_val` and `max_val` from keyword-only arguments to positional arguments in the pybind binding and in the golden function, matching `torch.nn.functional.hardtanh`; a before/after sketch is shown below.
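
A minimal sketch of the resulting call-site change (illustrative only; it assumes an open `device` handle and an arbitrary shape):

```python
import torch
import ttnn

torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Before this change, the bounds were keyword-only:
#   output = ttnn.hardtanh(tensor, min_val=-2.0, max_val=2.0)
# After it, they are positional with the same defaults (-1.0, 1.0),
# mirroring torch.nn.functional.hardtanh(input, min_val, max_val):
output = ttnn.hardtanh(tensor, -2.0, 2.0)  # clamps elements to [-2.0, 2.0]
```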

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/11744299952
https://github.com/tenstorrent/tt-metal/actions/runs/11767253102
- [ ] Nightly fd
https://github.com/tenstorrent/tt-metal/actions/runs/11767253759
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
KalaivaniMCW authored Nov 10, 2024
1 parent 486862f commit c8f7883
Showing 5 changed files with 147 additions and 8 deletions.
@@ -157,7 +157,7 @@ def run(
max_val = input_specs.get("max_val")

golden_function = ttnn.get_golden_function(ttnn.hardtanh)
-torch_output_tensor = golden_function(torch_input_tensor_a, min_val=min_val, max_val=max_val)
+torch_output_tensor = golden_function(torch_input_tensor_a, min_val, max_val)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
@@ -168,7 +168,7 @@
)

start_time = start_measuring_time()
-result = ttnn.hardtanh(input_tensor_a, min_val=min_val, max_val=max_val, memory_config=output_memory_config)
+result = ttnn.hardtanh(input_tensor_a, min_val, max_val, memory_config=output_memory_config)
output_tensor = ttnn.to_torch(result)
e2e_perf = stop_measuring_time(start_time)

@@ -235,7 +235,7 @@ def eltwise_hardtanh(
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-t1 = ttnn.hardtanh(t0, min_val=low, max_val=high, memory_config=output_mem_config)
+t1 = ttnn.hardtanh(t0, low, high, memory_config=output_mem_config)

return tt2torch_tensor(t1)

62 changes: 62 additions & 0 deletions tests/ttnn/unit_tests/operations/test_hardtanh.py
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize(
"shapes",
[[1, 1, 32, 32], [64, 64], [2, 2, 3, 256, 256]],
)
def test_hardtanh_default(device, shapes):
torch.manual_seed(0)

torch_input_tensor_a = torch.randn(shapes, dtype=torch.bfloat16) * 10

golden_fn = ttnn.get_golden_function(ttnn.hardtanh)
torch_output_tensor = golden_fn(torch_input_tensor_a)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)

output_tensor = ttnn.hardtanh(input_tensor_a)
output_tensor = ttnn.to_torch(output_tensor)

assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.9999


@pytest.mark.parametrize(
"shapes",
[[1, 1, 32, 32], [64, 64], [2, 2, 3, 256, 256]],
)
@pytest.mark.parametrize(
"min",
[0.25, 0.5, 0.66, -1],
)
@pytest.mark.parametrize(
"max",
[1, 2.5, 3, 6.6],
)
def test_hardtanh_args(device, shapes, min, max):
torch.manual_seed(0)

torch_input_tensor_a = torch.randn(shapes, dtype=torch.bfloat16) * 10

golden_fn = ttnn.get_golden_function(ttnn.hardtanh)
torch_output_tensor = golden_fn(torch_input_tensor_a, min, max)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)

output_tensor = ttnn.hardtanh(input_tensor_a, min, max)
output_tensor = ttnn.to_torch(output_tensor)

assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.9999
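
The new tests can be run on their own with pytest, e.g. `pytest tests/ttnn/unit_tests/operations/test_hardtanh.py`, assuming a device is available for the `device` fixture.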
83 changes: 80 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1166,6 +1166,83 @@ void bind_unary_composite_floats_with_default(
py::arg(parameter_name_b.c_str()) = parameter_b_value,
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_hardtanh(
    py::module& module,
    const unary_operation_t& operation,
    const std::string& parameter_name_a,
    const std::string& parameter_a_doc,
    float parameter_a_value,
    const std::string& parameter_name_b,
    const std::string& parameter_b_doc,
    float parameter_b_value,
    const std::string& supported_dtype = "BFLOAT16",
    const std::string& info_doc = "") {
    auto doc = fmt::format(
        R"doc(
        Performs {0} function on :attr:`input_tensor`, :attr:`{2}`, :attr:`{5}`.

        Args:
            input_tensor (ttnn.Tensor): the input tensor.
            {2} (float): {3}. Defaults to `{4}`.
            {5} (float): {6}. Defaults to `{7}`.

        Keyword args:
            memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.

        Returns:
            ttnn.Tensor: the output tensor.

        Note:
            Supported dtypes, layouts, and ranks:

            .. list-table::
                :header-rows: 1

                * - Dtypes
                  - Layouts
                  - Ranks
                * - {8}
                  - TILE
                  - 2, 3, 4

            {9}

        Example:
            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
            >>> output = {1}(tensor, {2} = {4}, {5} = {7})
        )doc",
        operation.base_name(),
        operation.python_fully_qualified_name(),
        parameter_name_a,
        parameter_a_doc,
        parameter_a_value,
        parameter_name_b,
        parameter_b_doc,
        parameter_b_value,
        supported_dtype,
        info_doc);

    bind_registered_operation(
        module,
        operation,
        doc,
        ttnn::pybind_overload_t{
            [](const unary_operation_t& self,
               const ttnn::Tensor& input_tensor,
               float parameter_a,
               float parameter_b,
               const std::optional<MemoryConfig>& memory_config) {
                return self(input_tensor, parameter_a, parameter_b, memory_config);
            },
            py::arg("input_tensor"),
            py::arg(parameter_name_a.c_str()) = parameter_a_value,
            py::arg(parameter_name_b.c_str()) = parameter_b_value,
            py::kw_only(),
            py::arg("memory_config") = std::nullopt});
}
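
Because `min_val` and `max_val` are declared before `py::kw_only()`, Python callers can now pass them positionally or by keyword, while `memory_config` remains keyword-only. A short sketch of the accepted call styles (assuming `tensor` is a `ttnn.Tensor` already placed on a device):

```python
# Positional and keyword forms are both accepted by the new binding:
out_positional = ttnn.hardtanh(tensor, -1.0, 1.0)
out_keyword = ttnn.hardtanh(tensor, min_val=-1.0, max_val=1.0)

# memory_config stays keyword-only:
out_l1 = ttnn.hardtanh(tensor, -1.0, 1.0, memory_config=ttnn.L1_MEMORY_CONFIG)
```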

// OpHandler_two_float_with_default
template <typename unary_operation_t>
void bind_unary_composite_int(py::module& module, const unary_operation_t& operation, const std::string& parameter_name_a, const std::string& parameter_a_doc, const std::string& description) {
@@ -1226,10 +1303,10 @@ void bind_unary_composite_floats(
Args:
input_tensor (ttnn.Tensor): the input tensor.
-Keyword args:
{2} (float): {3}.
{4} (float): {5}.
+Keyword args:
memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
Returns:
@@ -1846,7 +1923,7 @@ void py_module(py::module& module) {
"scale", "Scale value", 1.0f/6.0f,
"shift", "Shift value", 0.5f);

-detail::bind_unary_composite_floats_with_default(
+detail::bind_hardtanh(
module,
ttnn::hardtanh,
"min_val", "min value", -1.0f,
4 changes: 2 additions & 2 deletions ttnn/ttnn/operations/unary.py
@@ -239,10 +239,10 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs):
ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu)


-def _golden_function_hardtanh(input_tensor_a, *args, min_val=-1.0, max_val=1.0, **kwargs):
+def _golden_function_hardtanh(input_tensor_a, min_val=-1.0, max_val=1.0, *args, **kwargs):
import torch

-return torch.nn.functional.hardtanh(input_tensor_a, min_val=min_val, max_val=max_val)
+return torch.nn.functional.hardtanh(input_tensor_a, min_val, max_val)


ttnn.attach_golden_function(ttnn.hardtanh, golden_function=_golden_function_hardtanh)
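
As a sanity check on the golden function's semantics, `hardtanh` is an elementwise clamp; a minimal pure-torch sketch, no device required:

```python
import torch

# hardtanh(x, min_val, max_val) clamps each element to [min_val, max_val]:
#   out[i] = max(min_val, min(x[i], max_val))
x = torch.tensor([-3.0, -0.5, 0.0, 0.5, 3.0])
assert torch.equal(
    torch.nn.functional.hardtanh(x, -1.0, 1.0),
    torch.clamp(x, min=-1.0, max=1.0),
)
```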