From 23e02f3d56ffb05d6bb9c1a62cdbe588204d6867 Mon Sep 17 00:00:00 2001
From: mouliraj-mcw
Date: Mon, 11 Nov 2024 09:43:54 +0000
Subject: [PATCH] Add forward support for prelu

---
 docs/source/ttnn/ttnn/api.rst                 |  1 +
 .../eltwise/test_binary_composite.py          | 22 ++++++++++
 .../metal/llk_api/llk_math_unary_sfpu_api.h   |  1 +
 .../llk_api/llk_sfpu/ckernel_sfpu_prelu.h     | 38 ++++++++++++++++
 .../llk_math_eltwise_unary_sfpu_prelu.h       | 29 +++++++++++++
 .../metal/llk_api/llk_sfpu_types.h            |  1 +
 .../compute_kernel_api/eltwise_unary/prelu.h  | 43 +++++++++++++++++++
 .../eltwise_unary/sfpu_split_includes.h       |  4 ++
 .../eltwise/binary/binary_composite.hpp       | 11 +++++
 .../eltwise/binary/binary_pybind.hpp          | 11 ++++-
 .../binary/device/binary_composite_op.cpp     |  9 ++++
 .../eltwise/unary/common/unary_op_types.hpp   |  3 +-
 .../eltwise/unary/common/unary_op_utils.cpp   | 10 +++++
 .../eltwise/unary/common/unary_op_utils.hpp   |  1 +
 .../ttnn/operations/eltwise/unary/unary.cpp   | 27 ++++++++++++
 .../ttnn/operations/eltwise/unary/unary.hpp   | 17 ++++++++
 ttnn/ttnn/__init__.py                         |  4 --
 ttnn/ttnn/operations/unary.py                 | 11 ++++-
 18 files changed, 235 insertions(+), 8 deletions(-)
 create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h
 create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h
 create mode 100644 tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h

diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst
index 420d901d55f..6912abc3a63 100644
--- a/docs/source/ttnn/ttnn/api.rst
+++ b/docs/source/ttnn/ttnn/api.rst
@@ -141,6 +141,7 @@ Pointwise Unary
    ttnn.isneginf
    ttnn.isposinf
    ttnn.leaky_relu
+   ttnn.prelu
    ttnn.lerp
    ttnn.lgamma
    ttnn.log
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
index 037a6adee65..b540073527e 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -959,3 +959,25 @@ def test_binary_lcm_ttnn(input_shapes, device):
 
     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 3, 32, 32])),
+        (torch.Size([1, 6, 32, 32])),
+        (torch.Size([1, 7, 320, 384])),
+        (torch.Size([1, 4, 320, 384])),
+    ),
+)
+def test_binary_prelu_ttnn(input_shapes, device):
+    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
+    channels = input_shapes[1]
+    in_data2 = torch.rand((channels,), dtype=torch.bfloat16) * 200 - 100
+    input_tensor2 = ttnn.from_torch(in_data2, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn.prelu(input_tensor1, input_tensor2)
+    golden_function = ttnn.get_golden_function(ttnn.prelu)
+    golden_tensor = golden_function(in_data1, in_data2)
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
index f337c6a0c6d..78104b0ce32 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
@@ -36,3 +36,4 @@
 #include "llk_math_eltwise_unary_sfpu_right_shift.h"
 #include "llk_math_eltwise_unary_sfpu_left_shift.h"
 #include "llk_math_eltwise_unary_sfpu_fill.h"
+#include "llk_math_eltwise_unary_sfpu_prelu.h"
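For reference, the golden behavior the new test exercises is channel-wise PReLU: input_tensor2 holds one negative-slope weight per channel of input_tensor1. A minimal PyTorch sketch of that reference semantics (illustrative only, not part of the patch; shapes follow the test above):

    import torch

    x = torch.randn(1, 3, 32, 32, dtype=torch.bfloat16)   # activation, N x C x H x W
    w = torch.rand(3, dtype=torch.bfloat16) * 200 - 100   # one slope per channel, as in the test
    y = torch.nn.functional.prelu(x, weight=w)            # keeps x where x >= 0, scales by w elsewhere
    # equivalent: torch.where(x < 0, x * w.view(1, -1, 1, 1), x)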
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h
new file mode 100644
index 00000000000..a50eec93d28
--- /dev/null
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_prelu.h
@@ -0,0 +1,38 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel.h"
+#include "ckernel_defs.h"
+#include "noc_nonblocking_api.h"
+#include "ckernel_sfpu_converter.h"
+
+
+using namespace sfpi;
+
+namespace ckernel {
+namespace sfpu {
+
+
+template <bool APPROXIMATION_MODE>
+inline void calculate_prelu(uint value) {
+    // SFPU microcode
+    Converter c_value;
+    c_value.u = value;
+    vFloat init = c_value.f;
+
+    for (int d = 0; d < 8; d++)
+    {
+        vFloat a = dst_reg[0];
+        v_if(a < 0.0f) {
+            a = a * init;
+        }
+        v_endif;
+        dst_reg[0] = a;
+        dst_reg++;
+    }
+}
+}  // namespace sfpu
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h
new file mode 100644
index 00000000000..8bedfdbca2f
--- /dev/null
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_prelu.h
@@ -0,0 +1,29 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel_sfpu_prelu.h"
+#include "llk_math_eltwise_unary_sfpu_params.h"
+#include "llk_math_eltwise_unary_sfpu_init.h"
+
+namespace ckernel {
+
+// New LLK SFPU APIs
+
+template <bool APPROXIMATE>
+inline void llk_math_eltwise_unary_sfpu_prelu_init() {
+    llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
+}
+
+template <bool APPROXIMATE>
+inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
+    llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
+        ckernel::sfpu::calculate_prelu<APPROXIMATE>,
+        dst_index,
+        vector_mode,
+        param0);
+}
+
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
index e20018b4102..7b3789c3743 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
@@ -14,6 +14,7 @@ enum SfpuType {
     reciprocal,
     sqrt,
     lrelu,
+    prelu,
     power,
     square,
     tanh_derivative,
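The SFPU kernel above receives the slope as a raw uint32 and bit-casts it back to a float (that is what Converter does), then iterates over eight vector rows of the destination register, scaling only the lanes that are negative. A rough Python model of the per-tile loop (illustrative; Converter, dst_reg, and v_if/v_endif are hardware constructs with no direct Python equivalent):

    import struct

    def calculate_prelu_model(dst_rows, value_bits):
        # bit-cast the uint32 parameter back to the float slope, as Converter does
        slope = struct.unpack("<f", struct.pack("<I", value_bits))[0]
        for row in dst_rows:              # the d < 8 loop over destination rows
            for i, a in enumerate(row):   # v_if(a < 0.0f) { a = a * init; } v_endif
                if a < 0.0:
                    row[i] = a * slope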
diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h b/tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h
new file mode 100644
index 00000000000..f691dfd8ad6
--- /dev/null
+++ b/tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h
@@ -0,0 +1,43 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+
+#include "compute_kernel_api/common_globals.h"
+#ifdef TRISC_MATH
+#include "llk_math_eltwise_unary_sfpu_prelu.h"
+#define MAIN math_main()
+#define MATH(x) x
+#else
+#define MATH(x)
+#endif
+
+
+
+namespace ckernel {
+
+/**
+ * Performs an element-wise PReLU operation on the tile; the slope applied to negative inputs is provided as const param0.
+ * The DST register buffer must be in acquired state via *acquire_dst* call.
+ * This call is blocking and is only available on the compute engine.
+ *
+ * Return value: None
+ *
+ * | Argument        | Description                                                                      | Type     | Valid Range                                            | Required |
+ * |-----------------|----------------------------------------------------------------------------------|----------|--------------------------------------------------------|----------|
+ * | idst            | The index of the tile in DST register buffer to perform the computation on      | uint32_t | Must be less than the size of the DST register buffer | True     |
+ * | param0          | The slope applied where the input is less than 0, as the bit pattern of a float | uint32_t |                                                        | True     |
+ */
+ALWI void prelu_tile(uint32_t idst, uint32_t param0) {
+    MATH((llk_math_eltwise_unary_sfpu_prelu<APPROX>(idst, param0)));
+}
+
+/**
+ * Please refer to documentation for any_init.
+ */
+ALWI void prelu_tile_init() { MATH((llk_math_eltwise_unary_sfpu_prelu_init<APPROX>())); }
+
+
+}  // namespace ckernel
diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
index 0f087ba9bc0..204a1559546 100644
--- a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
+++ b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
@@ -116,6 +116,10 @@
 #include "compute_kernel_api/eltwise_unary/softplus.h"
 #endif
 
+#if SFPU_OP_PRELU_INCLUDE
+#include "compute_kernel_api/eltwise_unary/prelu.h"
+#endif
+
 #if SFPU_OP_DROPOUT_INCLUDE
 #include "compute_kernel_api/eltwise_unary/dropout.h"
 #endif
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
index db1aaa16993..559fd2e445f 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -243,6 +243,14 @@ struct ExecuteMinimum
 };
 
+struct ExecutePrelu
+{
+    static Tensor invoke(
+        const Tensor& input_tensor_a,
+        const Tensor& input_tensor_b,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt);
+};
+
 }  // namespace binary
 }  // namespace operations
@@ -306,5 +314,8 @@ constexpr auto gcd = ttnn::register_operation_with_auto_launch_op<
 constexpr auto lcm = ttnn::register_operation_with_auto_launch_op<
     "ttnn::lcm",
     operations::binary::ExecuteLCM>();
+constexpr auto prelu = ttnn::register_operation_with_auto_launch_op<
+    "ttnn::prelu",
+    operations::binary::ExecutePrelu>();
 
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
index 5033425fd38..8ebaf601a71 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -186,7 +186,7 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati
 }
 
 template <typename binary_operation_t>
-void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& math, const std::string& supported_dtype = "BFLAOT16", const std::string& note="") {
+void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& math, const std::string& supported_dtype = "BFLOAT16", const std::string& note="") {
     auto doc = fmt::format(
         R"doc(
         {2}
@@ -590,6 +590,7 @@ void bind_binary_overload_operation(py::module& module, const binary_operation_t
         operation,
         doc,
+        // tensor and scalar
         ttnn::pybind_overload_t{
             [](const binary_operation_t& self,
@@ -934,6 +935,14 @@ void py_module(py::module& module) {
         R"doc(\mathrm{output\_tensor}_i = \sqrt{(\mathrm{input\_tensor\_a}_i^2 + \mathrm{input\_tensor\_b}_i^2)}
         )doc");
 
+    detail::bind_binary_composite(
+        module,
+        ttnn::prelu,
+        R"doc(Perform an eltwise PReLU operation: a if a >= 0, a * b otherwise.
+        PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
+        R"doc(\mathrm{{output\_tensor}} = \verb|PReLU|(\mathrm{{input\_tensor\_a,input\_tensor\_b}}))doc",
+        R"doc(BFLOAT16, BFLOAT8_B)doc");
+
     detail::bind_binary_composite(
         module,
         ttnn::xlogy,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
index 3821e67304f..750ef0e8060 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -264,6 +264,15 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti
     return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result);
 }
 
+Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
+    const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape();
+    auto volume = input_b.get_logical_volume();
+    // TODO(#14933): support for a single-value (volume == 1) weight tensor is yet to be handled.
+    TT_FATAL(s_a[1] == volume, "Mismatch between the number of PReLU parameters and the input channel size");
+    Tensor b = ttnn::reshape(input_b, ttnn::SimpleShape{std::array<uint32_t, 4>{1, s_a[1], 1, 1}});
+    Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
+    return result;
+}
 // Binary remainder will be overloaded by unary remainder in another PR
 Tensor ExecuteBinaryRemainder::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
     auto arch = input_a.device()->arch();
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
index c133a83c0a8..6bc52c98f92 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
@@ -84,7 +84,8 @@ enum class UnaryOpType {
     REMAINDER,
     FMOD,
     DROPOUT,
-    FILL
+    FILL,
+    PRELU_SFPU,
 };
 
 struct UnaryWithParam {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
index be612eac166..bed48d6a143 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
@@ -62,6 +62,7 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string
+        case UnaryOpType::PRELU_SFPU: defines["SFPU_OP_PRELU_INCLUDE"] = "1"; break;
@@ ... @@ std::pair<std::string, std::string> get_op_init_and_func_parameterized(
                 Converter::to_hex(param1))};
             break;
         }
+        case UnaryOpType::PRELU_SFPU: {
+            op_init_and_name = {
+                "prelu_tile_init();",
+                fmt::format(
+                    "prelu_tile({}, {}u);",
+                    idst,
+                    Converter::to_hex(param0))};
+            break;
+        }
         case UnaryOpType::TYPECAST:
             TT_ASSERT(params.size() == 2, "Expected eltwise_typecast to take 2 parameters");
             op_init_and_name = {
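ExecutePrelu::invoke composes channel-wise PReLU out of existing ops: the 1-D weight is reshaped to {1, C, 1, 1} so that multiply broadcasts it across each channel, and where() picks the scaled value only for negative inputs. A NumPy sketch of the same composition (illustrative; the function name is mine, not part of the patch):

    import numpy as np

    def prelu_composite(x, w):
        # x: (N, C, H, W); w: (C,) -- the channel-count match that TT_FATAL enforces
        assert x.shape[1] == w.size, "Mismatch between the number of PReLU parameters and the input channel size"
        b = w.reshape(1, -1, 1, 1)        # ttnn::reshape to SimpleShape{1, C, 1, 1}
        return np.where(x < 0, x * b, x)  # ttnn::where(ttnn::ltz(x), ttnn::multiply(x, b), x)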
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp
index 9f27ebefac2..6abda5c7178 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp
@@ -60,6 +60,7 @@ bool is_parametrized_type(T val) {
         case UnaryOpType::REMAINDER:
         case UnaryOpType::DROPOUT:
         case UnaryOpType::FILL:
+        case UnaryOpType::PRELU_SFPU:
         case UnaryOpType::FMOD: return true;
         default: return false;
     }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
index f661b1cfedd..9ec79d3e12b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
@@ -280,6 +280,33 @@ Tensor Softplus::invoke(
         optional_output_tensor);
 }
 
+Tensor Prelu::invoke(
+    uint8_t queue_id,
+    const Tensor& input,
+    float value,
+    const std::optional<MemoryConfig>& memory_config,
+    const std::optional<Tensor>& optional_output_tensor) {
+    return detail::unary_impl(
+        queue_id,
+        input,
+        {UnaryWithParam{UnaryOpType::PRELU_SFPU, value}},
+        memory_config,
+        optional_output_tensor);
+}
+
+Tensor Prelu::invoke(
+    const Tensor& input,
+    float value,
+    const std::optional<MemoryConfig>& memory_config,
+    const std::optional<Tensor>& optional_output_tensor) {
+    return detail::unary_impl(
+        DefaultQueueId,
+        input,
+        {UnaryWithParam{UnaryOpType::PRELU_SFPU, value}},
+        memory_config,
+        optional_output_tensor);
+}
+
 Tensor Identity::invoke(
     uint8_t queue_id,
     const Tensor& input_tensor,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
index b75e30ab478..13d40a05628 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
@@ -122,6 +122,21 @@ struct Softplus {
         const std::optional<Tensor>& optional_output_tensor = std::nullopt);
 };
 
+struct Prelu {
+    static Tensor invoke(
+        uint8_t queue_id,
+        const Tensor& input,
+        float value,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        const std::optional<Tensor>& optional_output_tensor = std::nullopt);
+
+    static Tensor invoke(
+        const Tensor& input,
+        float value,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        const std::optional<Tensor>& optional_output_tensor = std::nullopt);
+};
+
 struct Identity {
     static Tensor invoke(
         uint8_t queue_id,
@@ -340,6 +355,8 @@ REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_xor, BITWISE_XOR, int32_
 constexpr auto dropout = ttnn::register_operation_with_auto_launch_op<"ttnn::dropout", ttnn::operations::unary::Dropout>();
 constexpr auto identity = ttnn::register_operation_with_auto_launch_op<"ttnn::identity", ttnn::operations::unary::Identity>();
 constexpr auto softplus = ttnn::register_operation_with_auto_launch_op<"ttnn::softplus", ttnn::operations::unary::Softplus>();
+constexpr auto prelu_sfpu = ttnn::register_operation_with_auto_launch_op<"ttnn::prelu_sfpu", ttnn::operations::unary::Prelu>();
+
 constexpr auto sigmoid_accurate = ttnn::register_operation_with_auto_launch_op<"ttnn::sigmoid_accurate", ttnn::operations::unary::Sigmoid_accurate>();
 constexpr auto unary_chain = ttnn::register_operation_with_auto_launch_op<"ttnn::unary_chain", ttnn::operations::unary::Unary_chain>();
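The unary prelu_sfpu variant registered above takes a single float slope. UnaryWithParam carries it to unary_op_utils, where Converter::to_hex turns it into the uint32 bit pattern that prelu_tile passes to the SFPU kernel. A small Python sketch of the scalar semantics and the bit-pattern round trip (illustrative only):

    import struct

    def prelu_scalar_model(values, slope):
        # per-element result of {UnaryWithParam{UnaryOpType::PRELU_SFPU, value}}
        return [v if v >= 0.0 else v * slope for v in values]

    # the slope travels to the kernel as its float bit pattern (cf. Converter::to_hex)
    slope_bits = struct.unpack("<I", struct.pack("<f", 0.25))[0]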
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index fba363c3971..c5c6f1a63f3 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -243,10 +243,6 @@ def auto_register_ttnn_cpp_operations(module):
 mul_ = ttnn.multiply_
 
 
-def prelu(*args, **kwargs):  # Alias for leaky_relu. TODO(#8544): implement PReLU properly
-    return ttnn.leaky_relu(*args, **kwargs)
-
-
 # TODO: pybind the overloaded operators below
 ttnn.Tensor.__add__ = lambda self, *args, **kwargs: ttnn.add(self, *args, **kwargs)
 ttnn.Tensor.__radd__ = lambda self, *args, **kwargs: ttnn.add(self, *args, **kwargs)
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 1f34e1fec2a..032b7d9308d 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -66,7 +66,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     "gelu": torch.nn.functional.gelu,
     "rsqrt": torch.rsqrt,
     # Unaries with float parameter
-    # "prelu": torch_prelu,  # Alias for leaky_relu. TODO(#8544): implement PReLU properly
     # Other unaries (composite operations)
     "softplus": torch.nn.functional.softplus,
     "sigmoid_accurate": torch.sigmoid,
@@ -148,7 +147,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     ttnn.gelu,
     ttnn.rsqrt,
     # Unaries with float parameter
-    # ttnn.prelu,  # Alias for leaky_relu. TODO(#8544): implement PReLU properly
     # Unaries using op_chain
     ttnn.log_sigmoid,
     ttnn.softplus,
@@ -239,6 +237,15 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs):
 ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu)
 
 
+def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
+    import torch
+
+    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)
+
+
+ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)
+
+
 def _golden_function_hardtanh(input_tensor_a, min_val=-1.0, max_val=1.0, *args, **kwargs):
     import torch
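With this patch, ttnn.prelu is a real channel-wise operation rather than an alias for leaky_relu, and ttnn.prelu_sfpu covers the scalar-slope case. A usage sketch mirroring the new unit test (illustrative; device setup is elided):

    import torch
    import ttnn

    x = ttnn.from_torch(torch.randn(1, 3, 32, 32, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    w = ttnn.from_torch(torch.full((3,), 0.25, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    y = ttnn.prelu(x, w)  # negative entries of x are scaled by the matching channel slope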