diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
index b540073527e..96050f92511 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -981,3 +981,26 @@ def test_binary_prelu_ttnn(input_shapes, device):
 
     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 3, 32, 32])),
+        (torch.Size([1, 6, 32, 32])),
+        (torch.Size([1, 7, 320, 384])),
+        (torch.Size([1, 4, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "scalar",
+    {-0.25, -2.7, 0.45, 6.4},
+)
+def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
+    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
+    output_tensor = ttnn.prelu(input_tensor1, scalar)
+    golden_function = ttnn.get_golden_function(ttnn.prelu)
+    golden_tensor = golden_function(in_data1, scalar)
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
index 559fd2e445f..77d2928b427 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -249,6 +249,12 @@ struct ExecutePrelu
         const Tensor& input_tensor_a,
         const Tensor& input_tensor_b,
         const std::optional<MemoryConfig>& memory_config = std::nullopt);
+
+    static Tensor invoke(
+        const Tensor& input_tensor,
+        float scalar,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt);
+
 };
 
 }  // namespace binary
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
index 8ebaf601a71..19922a39e48 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
+++ 
b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -935,13 +935,6 @@ void py_module(py::module& module) {
         R"doc(\mathrm{output\_tensor}_i = \sqrt{(\mathrm{input\_tensor\_a}_i^2 + \mathrm{input\_tensor\_b}_i^2)}
)doc");
 
-    detail::bind_binary_composite(
-        module,
-        ttnn::prelu,
-        R"doc(Perform an eltwise-prelu operation. Formula : a - a.div(b, rounding_mode=trunc) * b .
-        PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
-        R"doc(\mathrm{{output\_tensor}} = \verb|PReLU|(\mathrm{{input\_tensor\_a,input\_tensor\_b}}))doc",
-        R"doc(BFLOAT16, BFLOAT8_B)doc");
 
     detail::bind_binary_composite(
         module,
@@ -1053,6 +1046,12 @@ void py_module(py::module& module) {
         ttnn::maximum,
         R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
 
+    detail::bind_binary_composite_overload(
+        module,
+        ttnn::prelu,
+        R"doc(Perform an eltwise-prelu operation. 
+        PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc");
+
     detail::bind_binary_composite(
         module,
         ttnn::scatter,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
index 750ef0e8060..74f18d60fda 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -263,6 +263,9 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti
     Tensor div_result = ttnn::div(input_a, input_b, false, "None", output_mem_config);
     return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result);
 }
+Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) {
+    return ttnn::prelu_sfpu(input, scalar);
+}
 
 Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
     const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape();
diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py
index ce8b02488f0..cc60e3a7a63 100644
--- a/ttnn/ttnn/operations/binary.py
+++ b/ttnn/ttnn/operations/binary.py
@@ -455,4 +455,16 @@ def _golden_function_lcm(input_tensor_a, input_tensor_b, *args, **kwargs):
 ttnn.attach_golden_function(ttnn.lcm, golden_function=_golden_function_lcm)
 
 
+def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
+    import torch
+
+    if not torch.is_tensor(input_tensor_b):
+        input_tensor_b = torch.tensor(input_tensor_b, dtype=input_tensor_a.dtype)
+
+    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)
+
+
+ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)
+
+
 __all__ = []
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 032b7d9308d..5a73f32d59f 100644
--- 
a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -237,15 +237,6 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs):
 ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu)
 
 
-def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
-    import torch
-
-    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)
-
-
-ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)
-
-
 def _golden_function_hardtanh(input_tensor_a, min_val=-1.0, max_val=1.0, *args, **kwargs):
     import torch