diff --git a/tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py b/tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
index 134825bdf04..b2940fcf817 100644
--- a/tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
+++ b/tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
@@ -14,11 +14,6 @@
 from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
 from models.utility_functions import torch_random
 
-# Override the default timeout in seconds for hang detection.
-TIMEOUT = 30
-
-random.seed(0)
-
 # Parameters provided to the test vector generator are defined here.
 # They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values.
@@ -45,12 +40,6 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
     return False, None
 
 
-def torch_prelu(x, *args, **kwargs):
-    weight = kwargs.pop("scalar")
-    result = torch.nn.functional.prelu(x, torch.tensor(weight, dtype=x.dtype))
-    return result
-
-
 # This is the run instructions for the test, defined by the developer.
 # The run function must take the above-defined parameters as inputs.
 # The runner will call this run function with each test vector, and the returned results from this function will be stored.
@@ -65,14 +54,14 @@ def run(
     *,
     device,
 ) -> list:
-    data_seed = random.randint(0, 20000000)
-    torch.manual_seed(data_seed)
+    torch.manual_seed(0)
 
     torch_input_tensor_a = gen_func_with_cast_tt(
         partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
     )(input_shape)
 
-    torch_output_tensor = torch_prelu(torch_input_tensor_a, scalar=weight)
+    golden_function = ttnn.get_golden_function(ttnn.prelu)
+    torch_output_tensor = golden_function(torch_input_tensor_a, weight)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py
index ea9029a7cf6..1db66f53ced 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py
@@ -871,6 +871,7 @@ def test_run_eltwise_leaky_relu_op(
     )
     @pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
+    @skip_for_grayskull()
     def test_run_eltwise_prelu(
         self,
         input_shapes,
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_activation.py b/tests/ttnn/unit_tests/operations/eltwise/test_activation.py
index 57f8ecf4284..4407dd2a306 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_activation.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_activation.py
@@ -307,6 +307,7 @@ def test_scalarB_leaky_relu(device, h, w, scalar):
     run_activation_test_leaky_relu(device, h, w, scalar, ttnn.leaky_relu)
 
 
+@skip_for_grayskull()
 @pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
index 037a6adee65..497c2d194a6 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -9,9 +9,11 @@
 from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
     data_gen_with_range,
     data_gen_with_range_int,
+    data_gen_with_val,
     compare_pcc,
     compare_equal,
 )
+from tests.ttnn.utils_for_testing import assert_with_pcc
 from models.utility_functions import is_grayskull, skip_for_grayskull
 
@@ -959,3 +961,62 @@ def test_binary_lcm_ttnn(input_shapes, device):
     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 2, 32, 64, 64])),
+        (torch.Size([1, 3, 7, 29, 127])),
+        (torch.Size([1, 3, 2, 32])),
+        (torch.Size([1, 6, 49, 97])),
+        (torch.Size([1, 7, 320])),
+        (torch.Size([1, 49, 321])),
+        (torch.Size([4, 32])),
+        (torch.Size([49, 321])),
+    ),
+)
+def test_binary_prelu_ttnn(input_shapes, device):
+    in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
+    channels = input_shapes[1]
+    in_data2 = torch.rand((channels,), dtype=torch.bfloat16) * 200 - 100
+
+    input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor2 = ttnn.from_torch(in_data2, layout=ttnn.TILE_LAYOUT, device=device)
+
+    output_tensor = ttnn.prelu(input_tensor1, input_tensor2)
+    output_tensor = ttnn.to_torch(output_tensor)
+    golden_function = ttnn.get_golden_function(ttnn.prelu)
+    golden_tensor = golden_function(in_data1, in_data2)
+
+    assert_with_pcc(golden_tensor, output_tensor, 0.999)
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 2, 32, 64, 64])),
+        (torch.Size([1, 3, 7, 29, 127])),
+        (torch.Size([1, 3, 2, 32])),
+        (torch.Size([1, 6, 49, 97])),
+        (torch.Size([1, 7, 320])),
+        (torch.Size([1, 49, 321])),
+        (torch.Size([4, 32])),
+        (torch.Size([49, 321])),
+    ),
+)
+@pytest.mark.parametrize(
+    "scalar",
+    [-0.25, -2.7, 0.45, 6.4],
+)
+@skip_for_grayskull()
+def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
+    in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
+    input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)
+
+    output_tensor = ttnn.prelu(input_tensor1, scalar)
+    output_tensor = ttnn.to_torch(output_tensor)
+    golden_function = ttnn.get_golden_function(ttnn.prelu)
+    golden_tensor = golden_function(in_data1, scalar)
+
+    assert_with_pcc(golden_tensor, output_tensor, 0.999)
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_sfpu_api.h
index 5324157d4a4..91a4c684384 100644
--- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_sfpu_api.h
+++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_sfpu_api.h
@@ -25,3 +25,4 @@
 #include "llk_math_eltwise_unary_sfpu_trigonometry.h"
 #include "llk_math_eltwise_unary_sfpu_unary_comp.h"
 #include "llk_math_eltwise_unary_sfpu_fill.h"
+#include "llk_math_eltwise_unary_sfpu_prelu.h"
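
For reference, the PReLU semantics the new tests exercise: the output equals the input where the input is non-negative, and slope * input elsewhere, with the slope given either as a scalar or as a per-channel vector broadcast over dim 1. A minimal host-only sketch in plain PyTorch (no device required; `prelu_reference` is an illustrative name, not part of this change):

    import torch

    def prelu_reference(x: torch.Tensor, weight) -> torch.Tensor:
        # weight: a Python float, or a 1-D tensor with x.shape[1] elements
        if not torch.is_tensor(weight):
            weight = torch.tensor(weight, dtype=x.dtype)
        return torch.nn.functional.prelu(x, weight.to(x.dtype))
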
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h
new file mode 100644
index 00000000000..dc6e8e1e727
--- /dev/null
+++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h
@@ -0,0 +1,39 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel.h"
+#include "ckernel_defs.h"
+#include "noc_nonblocking_api.h"
+#include "ckernel_sfpu_converter.h"
+
+using namespace sfpi;
+
+namespace ckernel {
+namespace sfpu {
+
+template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
+inline void calculate_prelu(uint value) {
+    // SFPU microcode
+    Converter c_value;
+    c_value.u = value;
+    vFloat init = c_value.f;
+
+    #pragma GCC unroll 8
+    for (int d = 0; d < ITERATIONS; d++)
+    {
+        vFloat a = dst_reg[0];
+        v_if(a < 0.0f) {
+            a = a * init;
+        }
+        v_endif;
+        dst_reg[0] = a;
+        dst_reg++;
+    }
+}
+}  // namespace sfpu
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h
new file mode 100644
index 00000000000..8bedfdbca2f
--- /dev/null
+++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h
@@ -0,0 +1,29 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel_sfpu_prelu.h"
+#include "llk_math_eltwise_unary_sfpu_params.h"
+#include "llk_math_eltwise_unary_sfpu_init.h"
+
+namespace ckernel {
+
+// New LLK SFPU APIs
+
+template <bool APPROXIMATE>
+inline void llk_math_eltwise_unary_sfpu_prelu_init() {
+    llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
+}
+
+template <bool APPROXIMATE>
+inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
+    llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
+        ckernel::sfpu::calculate_prelu<APPROXIMATE>,
+        dst_index,
+        vector_mode,
+        param0);
+}
+
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h
index f55b01c24ab..8a7616784ed 100644
--- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h
+++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h
@@ -87,5 +87,6 @@ enum SfpuType {
     ceil,
     unused,
     cumsum,
-    fill
+    fill,
+    prelu,
 };
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
index f337c6a0c6d..78104b0ce32 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
@@ -36,3 +36,4 @@
 #include "llk_math_eltwise_unary_sfpu_right_shift.h"
 #include "llk_math_eltwise_unary_sfpu_left_shift.h"
 #include "llk_math_eltwise_unary_sfpu_fill.h"
+#include "llk_math_eltwise_unary_sfpu_prelu.h"
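
The SFPU loop above walks ITERATIONS rows of the destination register and multiplies only the negative lanes by the slope. Its per-element effect, as an illustrative scalar sketch:

    def prelu_element(x: float, slope: float) -> float:
        # mirrors: v_if(a < 0.0f) { a = a * init; } v_endif;
        return x * slope if x < 0.0 else x
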
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h
new file mode 100644
index 00000000000..22c28881261
--- /dev/null
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h
@@ -0,0 +1,39 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel.h"
+#include "ckernel_defs.h"
+#include "noc_nonblocking_api.h"
+#include "ckernel_sfpu_converter.h"
+
+using namespace sfpi;
+
+namespace ckernel {
+namespace sfpu {
+
+template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
+inline void calculate_prelu(uint value) {
+    // SFPU microcode
+    Converter c_value;
+    c_value.u = value;
+    vFloat init = c_value.f;
+
+    #pragma GCC unroll 8
+    for (int d = 0; d < ITERATIONS; d++)
+    {
+        vFloat a = dst_reg[0];
+        v_if(a < 0.0f) {
+            a = a * init;
+        }
+        v_endif;
+        dst_reg[0] = a;
+        dst_reg++;
+    }
+}
+}  // namespace sfpu
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h
new file mode 100644
index 00000000000..8bedfdbca2f
--- /dev/null
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h
@@ -0,0 +1,29 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel_sfpu_prelu.h"
+#include "llk_math_eltwise_unary_sfpu_params.h"
+#include "llk_math_eltwise_unary_sfpu_init.h"
+
+namespace ckernel {
+
+// New LLK SFPU APIs
+
+template <bool APPROXIMATE>
+inline void llk_math_eltwise_unary_sfpu_prelu_init() {
+    llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
+}
+
+template <bool APPROXIMATE>
+inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
+    llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
+        ckernel::sfpu::calculate_prelu<APPROXIMATE>,
+        dst_index,
+        vector_mode,
+        param0);
+}
+
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
index e20018b4102..7b3789c3743 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
@@ -14,6 +14,7 @@ enum SfpuType {
     reciprocal,
     sqrt,
     lrelu,
+    prelu,
     power,
     square,
     tanh_derivative,
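
Both kernels receive the slope as a raw uint32 and reinterpret its bits as a float32 through the Converter union. A sketch of that reinterpretation on the host side, using only Python's struct module (helper names are illustrative):

    import struct

    def float_to_u32(f: float) -> int:
        # bit pattern of an IEEE-754 float32, as the kernel's uint param expects
        return struct.unpack("<I", struct.pack("<f", f))[0]

    def u32_to_float(u: int) -> float:
        return struct.unpack("<f", struct.pack("<I", u))[0]
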
diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h b/tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h
new file mode 100644
index 00000000000..f691dfd8ad6
--- /dev/null
+++ b/tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h
@@ -0,0 +1,43 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "compute_kernel_api/common_globals.h"
+#ifdef TRISC_MATH
+#include "llk_math_eltwise_unary_sfpu_prelu.h"
+#define MAIN math_main()
+#define MATH(x) x
+#else
+#define MATH(x)
+#endif
+
+namespace ckernel {
+
+/**
+ * Performs an element-wise PReLU operation on the tile. The slope is provided as param0, a bit-encoded float.
+ * The DST register buffer must be in acquired state via *acquire_dst* call. This call is blocking and is only
+ * available on the compute engine.
+ *
+ * Return value: None
+ *
+ * | Argument | Description                                                                 | Type     | Valid Range                                            | Required |
+ * |----------|-----------------------------------------------------------------------------|----------|--------------------------------------------------------|----------|
+ * | idst     | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True     |
+ * | param0   | The slope applied where the input is less than 0, as a bit-encoded float   | uint32_t |                                                        | True     |
+ */
+ALWI void prelu_tile(uint32_t idst, uint32_t param0) {
+    MATH((llk_math_eltwise_unary_sfpu_prelu<APPROX>(idst, param0)));
+}
+
+/**
+ * Please refer to documentation for any_init.
+ */
+ALWI void prelu_tile_init() { MATH((llk_math_eltwise_unary_sfpu_prelu_init<APPROX>())); }
+
+}  // namespace ckernel
diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
index 0f087ba9bc0..204a1559546 100644
--- a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
+++ b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
@@ -116,6 +116,10 @@
 #include "compute_kernel_api/eltwise_unary/softplus.h"
 #endif
 
+#if SFPU_OP_PRELU_INCLUDE
+#include "compute_kernel_api/eltwise_unary/prelu.h"
+#endif
+
 #if SFPU_OP_DROPOUT_INCLUDE
 #include "compute_kernel_api/eltwise_unary/dropout.h"
 #endif
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
index db1aaa16993..77d2928b427 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -243,6 +243,20 @@ struct ExecuteMinimum
 };
 
+struct ExecutePrelu
+{
+    static Tensor invoke(
+        const Tensor& input_tensor_a,
+        const Tensor& input_tensor_b,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt);
+
+    static Tensor invoke(
+        const Tensor& input_tensor,
+        float scalar,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt);
+};
+
 }  // namespace binary
 }  // namespace operations
 
@@ -306,5 +320,8 @@ constexpr auto gcd = ttnn::register_operation_with_auto_launch_op<
 constexpr auto lcm = ttnn::register_operation_with_auto_launch_op<
     "ttnn::lcm",
     operations::binary::ExecuteLCM>();
+constexpr auto prelu = ttnn::register_operation_with_auto_launch_op<
+    "ttnn::prelu",
+    operations::binary::ExecutePrelu>();
 
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
index 5033425fd38..9ca6a02f9a3 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -186,7 +186,7 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati
 }
 
 template <typename binary_operation_t>
-void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& math, const std::string& supported_dtype = "BFLAOT16", const std::string& note="") {
+void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& math, const std::string& supported_dtype = "BFLOAT16", const std::string& note="") {
     auto doc = fmt::format(
         R"doc(
         {2}
@@ -358,7 +358,7 @@
 
 template <typename binary_operation_t>
-void bind_binary_composite_overload(py::module& module, const binary_operation_t& operation, const std::string& description) {
+void bind_binary_composite_overload(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& supported_dtype="BFLOAT16", const std::string& supported_rank= "2, 3, 4") {
     auto doc = fmt::format(
         R"doc(
         {2}
@@ -373,6 +373,19 @@ void bind_binary_composite_overload(py::module& module, const binary_operation_t
         Keyword Args:
             memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
 
+        Note:
+            Supported dtypes, layouts, and ranks:
+
+            .. list-table::
+               :header-rows: 1
+
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - {3}
+                 - TILE
+                 - {4}
+
         Returns:
             ttnn.Tensor: the output tensor.
 
@@ -384,7 +397,9 @@ void bind_binary_composite_overload(py::module& module, const binary_operation_t
         )doc",
         operation.base_name(),
         operation.python_fully_qualified_name(),
-        description);
+        description,
+        supported_dtype,
+        supported_rank);
 
     bind_registered_operation(
         module,
@@ -1044,6 +1059,13 @@ void py_module(py::module& module) {
         ttnn::maximum,
         R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
 
+    detail::bind_binary_composite_overload(
+        module,
+        ttnn::prelu,
+        R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
+        R"doc(BFLOAT16, BFLOAT8_B)doc",
+        R"doc(2, 3, 4, 5)doc");
+
     detail::bind_binary_composite(
         module,
         ttnn::scatter,
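
A usage sketch of the two overloads bound above, assuming a device handle `device` is already open (shapes are arbitrary):

    import torch
    import ttnn

    x = ttnn.from_torch(torch.randn(1, 3, 32, 32, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    w = ttnn.from_torch(torch.full((3,), 0.25, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

    y_channel = ttnn.prelu(x, w)    # per-channel slopes (composite path)
    y_scalar = ttnn.prelu(x, 0.25)  # scalar slope (SFPU path)
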
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
index 3821e67304f..e46b86fa072 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -264,6 +264,24 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti
     return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result);
 }
 
+Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) {
+    return ttnn::prelu_sfpu(input, scalar);
+}
+
+Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
+    const auto s_a = input_a.get_shape();
+    const auto volume = input_b.get_logical_volume();
+
+    TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size. Found parameter numbers = {} and channel size = {}.", volume, s_a[1]);
+    Tensor b = input_b;
+    if (s_a.rank() > 2) {
+        SmallVector<uint32_t> reshape(s_a.rank(), 1);
+        reshape[1] = s_a[1];
+        b = ttnn::reshape(input_b, ttnn::Shape(reshape));
+    }
+    Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
+    return result;
+}
+
 // Binary remainder will be overloaded by unary remainder in another PR
 Tensor ExecuteBinaryRemainder::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
     auto arch = input_a.device()->arch();
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
index c133a83c0a8..6bc52c98f92 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
@@ -84,7 +84,8 @@ enum class UnaryOpType {
     REMAINDER,
     FMOD,
     DROPOUT,
-    FILL
+    FILL,
+    PRELU_SFPU,
 };
 
 struct UnaryWithParam {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
index be612eac166..bed48d6a143 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
@@ -62,6 +62,7 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string>& defines) {
+        case UnaryOpType::PRELU_SFPU: defines["SFPU_OP_PRELU_INCLUDE"] = "1"; break;
@@ -233,6 +234,15 @@ std::pair<std::string, std::string> get_op_init_and_func_parameterized(
                 Converter::to_hex(param1))};
             break;
         }
+        case UnaryOpType::PRELU_SFPU: {
+            op_init_and_name = {
+                "prelu_tile_init();",
+                fmt::format(
+                    "prelu_tile({}, {}u);",
+                    idst,
+                    Converter::to_hex(param0))};
+            break;
+        }
         case UnaryOpType::TYPECAST:
             TT_ASSERT(params.size() == 2, "Expected eltwise_typecast to take 2 parameters");
             op_init_and_name = {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp
index 9f27ebefac2..6abda5c7178 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp
@@ -60,6 +60,7 @@ bool is_parametrized_type(T val) {
         case UnaryOpType::REMAINDER:
         case UnaryOpType::DROPOUT:
         case UnaryOpType::FILL:
+        case UnaryOpType::PRELU_SFPU:
         case UnaryOpType::FMOD: return true;
         default: return false;
     }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
index e68ec9535d6..20c80ec2217 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
@@ -281,6 +281,33 @@ Tensor Softplus::invoke(
         optional_output_tensor);
 }
 
+Tensor Prelu::invoke(
+    uint8_t queue_id,
+    const Tensor& input,
+    float value,
+    const std::optional<MemoryConfig>& memory_config,
+    const std::optional<Tensor>& optional_output_tensor) {
+    return detail::unary_impl(
+        queue_id,
+        input,
+        {UnaryWithParam{UnaryOpType::PRELU_SFPU, value}},
+        memory_config,
+        optional_output_tensor);
+}
+
+Tensor Prelu::invoke(
+    const Tensor& input,
+    float value,
+    const std::optional<MemoryConfig>& memory_config,
+    const std::optional<Tensor>& optional_output_tensor) {
+    return detail::unary_impl(
+        DefaultQueueId,
+        input,
+        {UnaryWithParam{UnaryOpType::PRELU_SFPU, value}},
+        memory_config,
+        optional_output_tensor);
+}
+
 Tensor Identity::invoke(
     uint8_t queue_id,
     const Tensor& input_tensor,
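
The tensor overload of ExecutePrelu reshapes the 1-D weight to [1, C, 1, ..., 1] so the multiply broadcasts along dim 1. An equivalent torch-level sketch of that where/ltz/multiply composition (reference only, not the device implementation):

    import torch

    def prelu_channel_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # w holds x.shape[1] elements; reshape for broadcasting over the channel dim
        shape = [1] * x.dim()
        shape[1] = x.shape[1]
        b = w.reshape(shape)
        return torch.where(x < 0, x * b, x)
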
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
index b75e30ab478..13d40a05628 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
@@ -122,6 +122,21 @@ struct Softplus {
         const std::optional<Tensor>& optional_output_tensor = std::nullopt);
 };
 
+struct Prelu {
+    static Tensor invoke(
+        uint8_t queue_id,
+        const Tensor& input,
+        float value,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        const std::optional<Tensor>& optional_output_tensor = std::nullopt);
+
+    static Tensor invoke(
+        const Tensor& input,
+        float value,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        const std::optional<Tensor>& optional_output_tensor = std::nullopt);
+};
+
 struct Identity {
     static Tensor invoke(
         uint8_t queue_id,
@@ -340,6 +355,8 @@ REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_xor, BITWISE_XOR, int32_
 constexpr auto dropout = ttnn::register_operation_with_auto_launch_op<"ttnn::dropout", ttnn::operations::unary::Dropout>();
 constexpr auto identity = ttnn::register_operation_with_auto_launch_op<"ttnn::identity", ttnn::operations::unary::Identity>();
 constexpr auto softplus = ttnn::register_operation_with_auto_launch_op<"ttnn::softplus", ttnn::operations::unary::Softplus>();
+constexpr auto prelu_sfpu = ttnn::register_operation_with_auto_launch_op<"ttnn::prelu_sfpu", ttnn::operations::unary::Prelu>();
+
 constexpr auto sigmoid_accurate = ttnn::register_operation_with_auto_launch_op<"ttnn::sigmoid_accurate", ttnn::operations::unary::Sigmoid_accurate>();
 constexpr auto unary_chain = ttnn::register_operation_with_auto_launch_op<"ttnn::unary_chain", ttnn::operations::unary::Unary_chain>();
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index fba363c3971..c5c6f1a63f3 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -243,10 +243,6 @@ def auto_register_ttnn_cpp_operations(module):
 mul_ = ttnn.multiply_
 
 
-def prelu(*args, **kwargs):  # Alias for leaky_relu. TODO(#8544): implement PReLU properly
-    return ttnn.leaky_relu(*args, **kwargs)
-
-
 # TODO: pybind the overloaded operators below
 ttnn.Tensor.__add__ = lambda self, *args, **kwargs: ttnn.add(self, *args, **kwargs)
 ttnn.Tensor.__radd__ = lambda self, *args, **kwargs: ttnn.add(self, *args, **kwargs)
diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py
index ce8b02488f0..cc60e3a7a63 100644
--- a/ttnn/ttnn/operations/binary.py
+++ b/ttnn/ttnn/operations/binary.py
@@ -455,4 +455,16 @@ def _golden_function_lcm(input_tensor_a, input_tensor_b, *args, **kwargs):
 ttnn.attach_golden_function(ttnn.lcm, golden_function=_golden_function_lcm)
 
 
+def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
+    import torch
+
+    if not torch.is_tensor(input_tensor_b):
+        input_tensor_b = torch.tensor(input_tensor_b, dtype=input_tensor_a.dtype)
+
+    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)
+
+
+ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)
+
+
 __all__ = []
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 1f34e1fec2a..5a73f32d59f 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -66,7 +66,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
         "gelu": torch.nn.functional.gelu,
         "rsqrt": torch.rsqrt,
         # Unaries with float parameter
-        # "prelu": torch_prelu,  # Alias for leaky_relu. TODO(#8544): implement PReLU properly
         # Other unaries (composite operations)
         "softplus": torch.nn.functional.softplus,
         "sigmoid_accurate": torch.sigmoid,
@@ -148,7 +147,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
         ttnn.gelu,
         ttnn.rsqrt,
         # Unaries with float parameter
-        # ttnn.prelu,  # Alias for leaky_relu. TODO(#8544): implement PReLU properly
         # Unaries using op_chain
         ttnn.log_sigmoid,
         ttnn.softplus,
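
With the leaky_relu alias removed, ttnn.prelu resolves to the registered C++ operation, so its golden function can be fetched generically. A host-side sketch of the comparison pattern the updated tests rely on:

    import torch
    import ttnn

    golden = ttnn.get_golden_function(ttnn.prelu)
    x = torch.randn(2, 4, 32, 32)
    expected_scalar = golden(x, 0.1)                     # scalar slope
    expected_channel = golden(x, torch.full((4,), 0.1))  # per-channel slopes
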