#14099: ttnn.clip, clamp interface to follow Pytorch (#14127)
VirdhatchaniKN authored Oct 25, 2024
1 parent 0194005 commit df85fd3
Showing 5 changed files with 12 additions and 12 deletions.
@@ -1288,7 +1288,7 @@ def clip(
     **kwargs,
 ):
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttnn.clip(t0, min=low, max=high, memory_config=output_mem_config)
+    t1 = ttnn.clip(t0, low, high, memory_config=output_mem_config)
 
     return tt2torch_tensor(t1)
 
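For reference, a minimal sketch of the call style this change enables (assumes ttnn is installed and a device is available; shapes and bounds are illustrative, not from the diff):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    x = ttnn.from_torch(
        torch.randn(32, 32, dtype=torch.bfloat16),
        layout=ttnn.TILE_LAYOUT,
        device=device,
    )

    # min/max may now be passed positionally, matching torch.clip ...
    y = ttnn.clip(x, -1.0, 1.0)
    # ... and the keyword form still works.
    y = ttnn.clip(x, min=-1.0, max=1.0)

    ttnn.close_device(device)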
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/operations/eltwise/test_activation.py
@@ -349,11 +349,11 @@ def run_activation_test_scalarBC_key(device, h, w, scalar1, scalar2, ttnn_function
     torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16)
     golden_function = ttnn.get_golden_function(ttnn_function)
 
-    torch_output_tensor = golden_function(torch_input_tensor_a, min=scalar1, max=scalar2)
+    torch_output_tensor = golden_function(torch_input_tensor_a, scalar1, scalar2)
 
     input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
 
-    output_tensor = ttnn_function(input_tensor_a, min=scalar1, max=scalar2)
+    output_tensor = ttnn_function(input_tensor_a, scalar1, scalar2)
     output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
     output_tensor = ttnn.from_device(output_tensor)
     output_tensor = ttnn.to_torch(output_tensor)
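The golden-function lookup used in this test, in a self-contained form (assumes ttnn is installed; no device is needed for the golden path):

    import torch
    import ttnn

    # get_golden_function returns the torch-side reference for a ttnn op;
    # after this change it takes the same positional min/max as the op itself.
    golden = ttnn.get_golden_function(ttnn.clip)
    ref = golden(torch.randn(32, 32), -1.0, 1.0)
    print(ref.min().item(), ref.max().item())  # both bounded by [-1.0, 1.0]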
12 changes: 6 additions & 6 deletions tests/ttnn/unit_tests/operations/eltwise/test_composite.py
@@ -112,12 +112,12 @@ def test_unary_composite_clamp_ttnn(input_shapes, min, max, device):
     in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
     if min is None and max is None:
         with pytest.raises(RuntimeError, match="Only one of 'min' or 'max' can be None. Please provide one value"):
-            ttnn.clamp(input_tensor1, min=min, max=max)
+            ttnn.clamp(input_tensor1, min, max)
         assert True
     else:
-        output_tensor = ttnn.clamp(input_tensor1, min=min, max=max)
+        output_tensor = ttnn.clamp(input_tensor1, min, max)
         golden_function = ttnn.get_golden_function(ttnn.clamp)
-        golden_tensor = golden_function(in_data1, min=min, max=max)
+        golden_tensor = golden_function(in_data1, min, max)
         comp_pass = compare_pcc([output_tensor], [golden_tensor])
         assert comp_pass
 
@@ -149,12 +149,12 @@ def test_unary_composite_clip_ttnn(input_shapes, min, max, device):
     in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
     if min is None and max is None:
         with pytest.raises(RuntimeError, match="Only one of 'min' or 'max' can be None. Please provide one value"):
-            ttnn.clip(input_tensor1, min=min, max=max)
+            ttnn.clip(input_tensor1, min, max)
         assert True
     else:
-        output_tensor = ttnn.clip(input_tensor1, min=min, max=max)
+        output_tensor = ttnn.clip(input_tensor1, min, max)
         golden_function = ttnn.get_golden_function(ttnn.clip)
-        golden_tensor = golden_function(in_data1, min=min, max=max)
+        golden_tensor = golden_function(in_data1, min, max)
         comp_pass = compare_pcc([output_tensor], [golden_tensor])
         assert comp_pass
 
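The both-None error path exercised above has a direct PyTorch analogue; a standalone sketch (PyTorch's error message differs from ttnn's, so no match= pattern is used here):

    import pytest
    import torch

    def test_clamp_both_none_raises():
        x = torch.ones(2, 2)
        # PyTorch also rejects clamp when neither bound is provided.
        with pytest.raises(RuntimeError):
            torch.clamp(x, None, None)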
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -67,9 +67,9 @@ void bind_unary_composite_optional_floats_with_default(py::module& module, const
                 return self(input_tensor, parameter_a, parameter_b, memory_config);
             },
             py::arg("input_tensor"),
-            py::kw_only(),
             py::arg(parameter_name_a.c_str()) = parameter_a_value,
             py::arg(parameter_name_b.c_str()) = parameter_b_value,
+            py::kw_only(),
             py::arg("memory_config") = std::nullopt});
 }
 
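Moving py::kw_only() past the two parameters is the pybind11 analogue of moving the bare * marker in a Python signature; a self-contained sketch with hypothetical stand-in functions (not ttnn APIs):

    def clip_before(input_tensor, *, min=None, max=None, memory_config=None):
        # Old binding: kw_only preceded min/max, so both were keyword-only.
        return (input_tensor, min, max, memory_config)

    def clip_after(input_tensor, min=None, max=None, *, memory_config=None):
        # New binding: min/max are positional-or-keyword; memory_config
        # stays keyword-only because the marker still precedes it.
        return (input_tensor, min, max, memory_config)

    clip_after("t", -1.0, 1.0)          # accepted now, mirroring torch.clip
    clip_after("t", min=-1.0, max=1.0)  # keyword form still accepted

    try:
        clip_before("t", -1.0, 1.0)     # the old interface rejected this
    except TypeError as err:
        print(err)  # ... takes 1 positional argument but 3 were given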
4 changes: 2 additions & 2 deletions ttnn/ttnn/operations/unary.py
@@ -289,7 +289,7 @@ def _golden_function_polygamma(input_tensor_a, k, *args, **kwargs):
 def _golden_function_clamp(input_tensor_a, min=None, max=None, *args, **kwargs):
     import torch
 
-    return torch.clamp(input=input_tensor_a, min=min, max=max)
+    return torch.clamp(input_tensor_a, min, max)
 
 
 ttnn.attach_golden_function(ttnn.clamp, golden_function=_golden_function_clamp)
@@ -298,7 +298,7 @@ def _golden_function_clamp(input_tensor_a, min=None, max=None, *args, **kwargs):
 def _golden_function_clip(input_tensor_a, min=None, max=None, *args, **kwargs):
     import torch
 
-    return torch.clip(input=input_tensor_a, min=min, max=max)
+    return torch.clip(input_tensor_a, min, max)
 
 
 ttnn.attach_golden_function(ttnn.clip, golden_function=_golden_function_clip)
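A quick standalone check of what the simplified golden functions rely on: positional and keyword forms of torch.clamp agree, torch.clip is an alias of torch.clamp, and a single bound may be None:

    import torch

    x = torch.tensor([-2.0, -0.5, 0.5, 2.0])

    # Positional and keyword bounds are interchangeable in PyTorch.
    assert torch.equal(torch.clamp(x, -1.0, 1.0), torch.clamp(x, min=-1.0, max=1.0))

    # torch.clip is an alias of torch.clamp.
    assert torch.equal(torch.clip(x, -1.0, 1.0), torch.clamp(x, -1.0, 1.0))

    # One bound may be None, positionally or by keyword.
    print(torch.clamp(x, None, 1.0))  # tensor([-2.0000, -0.5000, 0.5000, 1.0000])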
