From 7cbf6e8e9da6a5553590def093a17499fc719aa4 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Thu, 9 May 2024 20:43:21 +0000 Subject: [PATCH] #6938: Implement softplus as a single kernel --- .../pytests/tt_dnn/test_composite.py | 14 ++---- .../pytests/tt_dnn/test_eltwise_unary.py | 36 ++++++++++++- .../unit_tests/operations/test_activation.py | 3 ++ .../op_library/composite/composite_ops.cpp | 17 ------- .../op_library/composite/composite_ops.hpp | 4 -- .../eltwise_unary/eltwise_unary_op.cpp | 18 ++++++- .../eltwise_unary/eltwise_unary_op.hpp | 14 ++++++ .../llk_api/llk_sfpu/ckernel_sfpu_softplus.h | 50 +++++++++++++++++++ .../llk_math_eltwise_unary_sfpu_softplus.h | 29 +++++++++++ .../metal/llk_api/llk_sfpu_types.h | 1 + .../eltwise_unary/sfpu_split_includes.h | 4 ++ .../eltwise_unary/softplus.h | 47 +++++++++++++++++ ttnn/cpp/ttnn/operations/unary.hpp | 17 ++----- 13 files changed, 208 insertions(+), 46 deletions(-) create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h create mode 100644 tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py index 5ed3ae29279..292c1ce69bc 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py @@ -19,13 +19,12 @@ from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( run_single_pytorch_test, ) -from models.utility_functions import is_wormhole_b0 +from models.utility_functions import is_wormhole_b0, is_grayskull reference_pcc = defaultdict(lambda: 0.999) reference_pcc["silu"] = 0.9714 reference_pcc["swish"] = reference_pcc["silu"] 
-reference_pcc["softplus"] = 0.9984 def custom_compare(*args, **kwargs): @@ -68,7 +67,6 @@ def custom_compare(*args, **kwargs): "max", "swish", "log1p", - "softplus", "mish", "silu", "polyval", @@ -157,6 +155,9 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def if is_wormhole_b0(): if fn in ["logit"]: pytest.skip("does not work for Wormhole -skipping") + if is_grayskull(): + if fn in ["mish"]: + pytest.skip("does not work for Grayskull -skipping") if fn in ["logical_xor", "logical_xori", "logical_ori", "logical_andi"]: datagen_func = [ generation_funcs.gen_func_with_cast( @@ -231,13 +232,6 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def "equal_nan": random.choice([False, True]), } ) - elif fn in ["softplus"]: - test_args.update( - { - "beta": random.choice([0.5, -3, 1, 4]), - "threshold": random.choice([-20, 10, 20, 5]), - } - ) run_single_pytorch_test( "eltwise-%s" % (fn), input_shapes, diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py index 5a91864d38a..9f705c709f3 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py @@ -16,7 +16,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( run_single_pytorch_test, ) -from models.utility_functions import is_wormhole_b0 +from models.utility_functions import is_wormhole_b0, skip_for_grayskull shapes = [ [[1, 1, 32, 32]], # Single core @@ -1100,3 +1100,37 @@ def test_run_eltwise_unary_comp( device, test_args, ) + + @skip_for_grayskull("Softplus kernel not currently availible for GS") + @pytest.mark.parametrize("beta", [1.0, 5.0]) + @pytest.mark.parametrize("threshold", [10.0, 20.0]) + def test_run_eltwise_softplus( + self, + input_shapes, + beta, + threshold, + 
device, + function_level_defaults, + input_mem_config, + output_mem_config, + ): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16) + ] + test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0] + test_args.update({"beta": beta, "threshold": threshold}) + test_args.update( + { + "input_mem_config": [input_mem_config], + "output_mem_config": output_mem_config, + } + ) + comparison_func = comparison_funcs.comp_pcc + run_single_pytorch_test( + "eltwise-softplus", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/ttnn/unit_tests/operations/test_activation.py b/tests/ttnn/unit_tests/operations/test_activation.py index 779607c4bfa..27b15b4b1fd 100644 --- a/tests/ttnn/unit_tests/operations/test_activation.py +++ b/tests/ttnn/unit_tests/operations/test_activation.py @@ -11,6 +11,7 @@ import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc +from models.utility_functions import skip_for_grayskull def run_activation_unary_test(device, h, w, ttnn_function, torch_function, pcc=0.99): @@ -52,6 +53,7 @@ def test_log_sigmoid(device, h, w): run_activation_unary_test(device, h, w, ttnn.log_sigmoid, F.logsigmoid) +@skip_for_grayskull() @pytest.mark.parametrize("h", [64]) @pytest.mark.parametrize("w", [128]) def test_mish(device, h, w): @@ -116,6 +118,7 @@ def run_activation_softplus_test(device, h, w, beta, threshold, ttnn_function, t assert_with_pcc(torch_output_tensor, output_tensor, pcc) +@skip_for_grayskull() @pytest.mark.parametrize("h", [64]) @pytest.mark.parametrize("w", [128]) @pytest.mark.parametrize("beta", [-1, 1, 2, 0.5, 10]) diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index 47cd6b74086..18218b8a3fe 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -110,23 
+110,6 @@ Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _log1p)(x, output_mem_config); } -// softplus[x] =(1/beta) * log[1 + exp[x * beta]] -// (x*beta) > threshold ==> x -// use transformation y = log[1+exp[x]] by broadcast -Tensor _softplus(const Tensor& x, float beta, float threshold, const MemoryConfig& output_mem_config) { - float oned_beta = (1 / beta); - Tensor x_beta = mul_unary(x, beta, output_mem_config); - Tensor exp_x = exp(x_beta, output_mem_config); - Tensor result_log1p = log1p(exp_x, output_mem_config); - Tensor sp_result = mul_unary(result_log1p, oned_beta, output_mem_config); - sp_result = where(gt(x_beta, full_like(x, threshold, output_mem_config), std::nullopt, output_mem_config), x, - where(eqz(full_like(x, beta, output_mem_config), output_mem_config), std::numeric_limits::infinity(), sp_result), output_mem_config); - return sp_result; -} -Tensor softplus(const Tensor& a, float beta, float threshold, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _softplus)(a, beta, threshold, output_mem_config); -} - // tanhshrink(x) = x - tanh(x) Tensor _tanhshrink(const Tensor& x, const MemoryConfig& output_mem_config) { Tensor tan_x = tanh(x, output_mem_config); diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index 79520529028..03d56f7cfff 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -81,10 +81,6 @@ Tensor mac( // use transformation y = log(1.0 + x) by broadcast Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); -// softplus[x] = log[1 + exp[x]] -// use transformation y = log[1+exp[x]] by broadcast -Tensor softplus(const Tensor& x, float beta=1.0, float threshold=20.0, const MemoryConfig& output_mem_config = 
operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - // mish[x] = x*tanh[softplus[x]] // use transformation y = x*tanh[softplus[x]] by broadcast // Ref: https://krutikabapat.github.io/Swish-Vs-Mish-Latest-Activation-Functions/ diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index 4be53be5b00..433e531e917 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -22,7 +22,7 @@ union Converter { float f; uint32_t u; - Converter(float f_) : f(f_) {}; + Converter(float f_) : f(f_){}; static std::string to_hex(float f_) { Converter obj(f_); @@ -67,6 +67,7 @@ void update_macro_defines(UnaryOpType op_type, std::map get_op_init_and_func_parameterized( UnaryOpType op_type, std::vector params, string idst) { std::pair op_init_and_name; - TT_FATAL(is_parametrized_type(op_type) && "operator should support one parameter"); + TT_FATAL(is_parametrized_type(op_type) && "operator should support at least one parameter"); float param0 = params[0]; switch (op_type) { case UnaryOpType::RELU_MAX: @@ -162,6 +163,19 @@ std::pair get_op_init_and_func_parameterized( op_init_and_name = { "unary_lt_tile_init();", fmt::format("unary_lt_tile({}, {}u);", idst, Converter::to_hex(param0))}; break; + case UnaryOpType::SOFTPLUS: { + TT_ASSERT(params.size() == 2, "Expected softplus to take 2 parameters"); + float param1 = params[1]; + op_init_and_name = { + "softplus_tile_init();", + fmt::format( + "softplus_tile({}, {}u, {}u, {}u);", + idst, + Converter::to_hex(param0), + Converter::to_hex(1.0f / param0), // Pass reciprocal to avoid doing it on device + Converter::to_hex(param1))}; + break; + } default: TT_ASSERT(false && "unexpected parameterized type"); }; return op_init_and_name; diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index 
cabb085ecac..f891485a5fa 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -66,6 +66,7 @@ enum class UnaryOpType { RSUB, RDIV, SILU, + SOFTPLUS, IDENTITY, NEG, ADD_UNARY_SFPU, @@ -95,6 +96,7 @@ bool is_parametrized_type(T val) { case UnaryOpType::RSUB: case UnaryOpType::RDIV: case UnaryOpType::EXP: + case UnaryOpType::SOFTPLUS: case UnaryOpType::ADD_UNARY_SFPU: case UnaryOpType::SUB_UNARY_SFPU: case UnaryOpType::MUL_UNARY_SFPU: @@ -154,6 +156,8 @@ inline UnaryWithParam string_to_unary_with_param(const std::string& name) { return UnaryWithParam(UnaryOpType::SIGN); else if (name == "square") return UnaryWithParam(UnaryOpType::SQUARE); + else if (name == "softplus") + return UnaryWithParam(UnaryOpType::SOFTPLUS); TT_THROW("Unknown unary op: " + name); } @@ -423,6 +427,16 @@ inline Tensor sigmoid_accurate( output_mem_config); } +inline Tensor softplus( + const Tensor& input_tensor, + float beta, + float threshold, + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { + TT_ASSERT(input_tensor.device()->arch() != tt::ARCH::GRAYSKULL, "Softplus is not currently supported on Grayskull"); + return run_eltwise_unary( + input_tensor, {UnaryWithParam(UnaryOpType::SOFTPLUS, {beta, threshold})}, output_mem_config); +} + inline Tensor unary_chain( const Tensor& input_tensor, std::vector ops_chain, diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h new file mode 100644 index 00000000000..d9024fdee4b --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "ckernel_sfpu_converter.h" +#include "ckernel_sfpu_exp.h" +#include "ckernel_sfpu_log.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_softplus_body(vFloat beta, vFloat beta_reciprocal, vFloat threshold) { + vFloat a = dst_reg[0]; + vFloat a_beta = a * beta; + v_if(a_beta < threshold) { + exp_init(); + a = calculate_exponential_body(a_beta) + 1.0f; + + log_init(); + dst_reg[0] = a; + calculate_log_body(0); + a = beta_reciprocal * dst_reg[0]; + } + v_endif; + dst_reg[0] = a; +} + +template +inline void calculate_softplus(uint param0, uint param1, uint param2) { + vFloat beta = Converter::to_float(param0); + vFloat beta_reciprocal = Converter::to_float(param1); + vFloat threshold = Converter::to_float(param2); + for (int d = 0; d < ITERATIONS; d++) { + calculate_softplus_body(beta, beta_reciprocal, threshold); + dst_reg++; + } +} + +template +void softplus_init() {} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h new file mode 100644 index 00000000000..6d01ffc9924 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel_sfpu_softplus.h" +#include "llk_math_eltwise_unary_sfpu_3_param.h" +#include "llk_math_eltwise_unary_sfpu_init.h" + +namespace ckernel { + +template +inline void llk_math_eltwise_unary_sfpu_softplus_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_softplus( + uint dst_index, uint param0, uint param1, uint param2, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_3_param( + ckernel::sfpu::calculate_softplus, + ckernel::sfpu::calculate_softplus, + dst_index, + vector_mode, + param0, param1, param2); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h index a8cc39cea63..6e3051cdab6 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h @@ -73,6 +73,7 @@ enum SfpuType { unary_ne, unary_gt, unary_lt, + softplus, tiled_prod, unused, }; diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h index ea67ef1480c..9c9f9ec41d7 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h @@ -68,6 +68,10 @@ #include "compute_kernel_api/eltwise_unary/binop_with_scalar.h" #endif +#if SFPU_OP_SOFTPLUS_INCLUDE +#include "compute_kernel_api/eltwise_unary/softplus.h" +#endif + #if SFPU_OP_COMPUTE_KERNEL_API_INCLUDE #include "compute_kernel_api.h" #endif diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h b/tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h new file mode 100644 index 00000000000..8a62a5ddfee --- /dev/null +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h @@ 
-0,0 +1,47 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_unary_sfpu_softplus.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + + + +namespace ckernel { + +/** + * Performs element-wise computation of softplus (`1/beta * log(1 + exp(beta * x))`) on each element + * of a tile in DST register at index tile_index. Any input for which `beta * x` is not below the provided threshold + * is returned unchanged. The DST register buffer must be in acquired state via *acquire_dst* call. This + * call is blocking and is only available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | Required | + * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | tile_index | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True | + * | beta | Beta used in softplus calculation | uint32_t | Greater than 0 | True | + * | beta_reciprocal | Reciprocal of beta (1/beta) used in softplus calculation | uint32_t | Greater than 0 | True | + * | threshold | Threshold used in softplus calculation | uint32_t | Greater than 0 | True | + */ +ALWI void softplus_tile(uint32_t idst, uint32_t beta, uint32_t beta_reciprocal, uint32_t threshold) { + MATH(( llk_math_eltwise_unary_sfpu_softplus(idst, beta, beta_reciprocal, threshold) )); +} + +/** + * Please refer to documentation for any_init.
+ */ +ALWI void softplus_tile_init() { + MATH(( llk_math_eltwise_unary_sfpu_softplus_init() )); +} + +} // namespace ckernel diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp index a5a21b4e539..bc2f561b17a 100644 --- a/ttnn/cpp/ttnn/operations/unary.hpp +++ b/ttnn/cpp/ttnn/operations/unary.hpp @@ -60,7 +60,6 @@ struct ExecuteUnary { static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { return detail::input_tensors_to_validate(input_tensor, std::forward(args)...); } - static Tensor execute_on_worker_thread( const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { return detail::execute_on_worker_thread(input_tensor, {UnaryWithParam{unary_op_types}...}, memory_config); @@ -115,16 +114,10 @@ struct Softplus { const Tensor& input, const float beta, const float threshold, - const std::optional& memory_config_arg = std::nullopt) { - auto original_input_shape = input.get_shape(); - auto input_4D = ttnn::unsqueeze_to_4D(input); - - auto memory_config = memory_config_arg.value_or(input_4D.memory_config()); - auto result = tt::tt_metal::softplus(input_4D, beta, threshold, memory_config); - - result = ttnn::reshape(result, original_input_shape); - - return result; + const std::optional& memory_config = std::nullopt) { + TT_ASSERT(input.device()->arch() != tt::ARCH::GRAYSKULL, "Softplus is not currently supported on Grayskull"); + return detail::execute_on_worker_thread( + input, {UnaryWithParam{ttnn::operations::unary::UnaryOpType::SOFTPLUS, {beta, threshold}}}, memory_config); } }; } // namespace unary @@ -199,7 +192,7 @@ REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(heaviside, HEAVISIDE); REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(leaky_relu, LEAKY_RELU); auto prelu = leaky_relu; // Alias for leaky_relu. 
TODO(#8544): implement PReLU properly -// Other unaries (composite operations) +// Other unaries constexpr auto softplus = ttnn::register_operation("ttnn::softplus"); } // namespace ttnn