diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py index b540073527e..96050f92511 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py @@ -981,3 +981,26 @@ def test_binary_prelu_ttnn(input_shapes, device): comp_pass = compare_pcc([output_tensor], [golden_tensor]) assert comp_pass + + +@pytest.mark.parametrize( +    "input_shapes", +    ( +        (torch.Size([1, 3, 32, 32])), +        (torch.Size([1, 6, 32, 32])), +        (torch.Size([1, 7, 320, 384])), +        (torch.Size([1, 4, 320, 384])), +    ), +) +@pytest.mark.parametrize( +    "scalar", +    [-0.25, -2.7, 0.45, 6.4], +) +def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device): +    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device) +    output_tensor = ttnn.prelu(input_tensor1, scalar) +    golden_function = ttnn.get_golden_function(ttnn.prelu) +    golden_tensor = golden_function(in_data1, scalar) + +    comp_pass = compare_pcc([output_tensor], [golden_tensor]) +    assert comp_pass diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp index 559fd2e445f..77d2928b427 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp @@ -249,6 +249,12 @@ struct ExecutePrelu const Tensor& input_tensor_a, const Tensor& input_tensor_b, const std::optional<MemoryConfig>& memory_config = std::nullopt); + +    static Tensor invoke( +        const Tensor& input_tensor, +        float scalar, +        const std::optional<MemoryConfig>& memory_config = std::nullopt); + }; } // namespace binary diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp index 8ebaf601a71..8a42c230115 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp +++ 
b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp @@ -590,7 +590,6 @@ void bind_binary_overload_operation(py::module& module, const binary_operation_t operation, doc, - //tensor and scalar ttnn::pybind_overload_t{ [](const binary_operation_t& self, @@ -935,14 +934,6 @@ void py_module(py::module& module) { R"doc(\mathrm{output\_tensor}_i = \sqrt{(\mathrm{input\_tensor\_a}_i^2 + \mathrm{input\_tensor\_b}_i^2)} )doc"); - detail::bind_binary_composite( - module, - ttnn::prelu, - R"doc(Perform an eltwise-prelu operation. Formula : a - a.div(b, rounding_mode=trunc) * b . - PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc", - R"doc(\mathrm{{output\_tensor}} = \verb|PReLU|(\mathrm{{input\_tensor\_a,input\_tensor\_b}}))doc", - R"doc(BFLOAT16, BFLOAT8_B)doc"); - detail::bind_binary_composite( module, ttnn::xlogy, @@ -1053,6 +1044,11 @@ void py_module(py::module& module) { ttnn::maximum, R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc"); + detail::bind_binary_composite_overload( + module, + ttnn::prelu, + R"doc(Perform an eltwise-prelu operation. 
PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc"); + detail::bind_binary_composite( module, ttnn::scatter, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index 750ef0e8060..5933cc63db9 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -264,10 +264,14 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result); } +Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) { +    return ttnn::prelu_sfpu(input, scalar); +} + Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) { const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape(); auto volume = input_b.get_logical_volume(); - // If volume = 1 Support for a single-value tensor yet to be handled. #14933 + // If volume = 1 Support for a single-value tensor yet to be handled. 
TODO(#14933) TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size"); Tensor b = ttnn::reshape(input_b, ttnn::SimpleShape{std::array<uint32_t, 4>{1, s_a[1], 1, 1}}); Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a); diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py index ce8b02488f0..cc60e3a7a63 100644 --- a/ttnn/ttnn/operations/binary.py +++ b/ttnn/ttnn/operations/binary.py @@ -455,4 +455,16 @@ def _golden_function_lcm(input_tensor_a, input_tensor_b, *args, **kwargs): ttnn.attach_golden_function(ttnn.lcm, golden_function=_golden_function_lcm) +def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs): +    import torch + +    if not torch.is_tensor(input_tensor_b): +        input_tensor_b = torch.tensor(input_tensor_b, dtype=input_tensor_a.dtype) + +    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b) + + +ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu) + + __all__ = [] diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py index 032b7d9308d..5a73f32d59f 100644 --- a/ttnn/ttnn/operations/unary.py +++ b/ttnn/ttnn/operations/unary.py @@ -237,15 +237,6 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs): ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu) -def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs): -    import torch - -    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b) - - -ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu) - - def _golden_function_hardtanh(input_tensor_a, min_val=-1.0, max_val=1.0, *args, **kwargs): import torch