From 453df35932ce2079194a1e1051dc5ca5199333c1 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Thu, 9 May 2024 20:43:21 +0000 Subject: [PATCH] #0: Add fused softplus op --- .../pytests/tt_dnn/test_eltwise_unary.py | 33 +++++++++++++ .../op_library/composite/composite_ops.cpp | 17 ------- .../op_library/composite/composite_ops.hpp | 4 -- .../eltwise_unary/eltwise_unary_op.cpp | 13 +++++ .../eltwise_unary/eltwise_unary_op.hpp | 13 +++++ .../llk_api/llk_sfpu/ckernel_sfpu_softplus.h | 49 +++++++++++++++++++ .../llk_math_eltwise_unary_sfpu_softplus.h | 29 +++++++++++ .../llk_api/llk_sfpu/ckernel_sfpu_softplus.h | 49 +++++++++++++++++++ .../llk_math_eltwise_unary_sfpu_softplus.h | 29 +++++++++++ .../metal/llk_api/llk_sfpu_types.h | 1 + .../eltwise_unary/sfpu_split_includes.h | 4 ++ .../eltwise_unary/softplus.h | 35 +++++++++++++ 12 files changed, 255 insertions(+), 21 deletions(-) create mode 100644 tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h create mode 100644 tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h create mode 100644 tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py index 5a91864d38a..6acb742bbcb 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py @@ -1100,3 +1100,36 @@ def test_run_eltwise_unary_comp( device, test_args, ) + + @pytest.mark.parametrize("beta", [1.0]) + @pytest.mark.parametrize("threshold", [20.0]) + def test_run_eltwise_softplus( + self, + 
input_shapes, + beta, + threshold, + device, + function_level_defaults, + input_mem_config, + output_mem_config, + ): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16) + ] + test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0] + test_args.update({"beta": beta, "threshold": threshold}) + test_args.update( + { + "input_mem_config": [input_mem_config], + "output_mem_config": output_mem_config, + } + ) + comparison_func = comparison_funcs.comp_equal + run_single_pytorch_test( + "eltwise-softplus", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index d9b8cec9818..82f49f38292 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -110,23 +110,6 @@ Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _log1p)(x, output_mem_config); } -// softplus[x] =(1/beta) * log[1 + exp[x * beta]] -// (x*beta) > threshold ==> x -// use transformation y = log[1+exp[x]] by broadcast -Tensor _softplus(const Tensor& x, float beta, float threshold, const MemoryConfig& output_mem_config) { - float oned_beta = (1 / beta); - Tensor x_beta = mul_unary(x, beta, output_mem_config); - Tensor exp_x = exp(x_beta, output_mem_config); - Tensor result_log1p = log1p(exp_x, output_mem_config); - Tensor sp_result = mul_unary(result_log1p, oned_beta, output_mem_config); - sp_result = where(gt(x_beta, full_like(x, threshold, output_mem_config), std::nullopt, output_mem_config), x, - where(eqz(full_like(x, beta, output_mem_config), output_mem_config), std::numeric_limits::infinity(), sp_result), output_mem_config); - return sp_result; -} -Tensor softplus(const Tensor& a, float beta, float threshold, const 
MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _softplus)(a, beta, threshold, output_mem_config); -} - // tanhshrink(x) = x - tanh(x) Tensor _tanhshrink(const Tensor& x, const MemoryConfig& output_mem_config) { Tensor tan_x = tanh(x, output_mem_config); diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index b71219b61b6..6aa8f26e702 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -81,10 +81,6 @@ Tensor mac( // use transformation y = log(1.0 + x) by broadcast Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); -// softplus[x] = log[1 + exp[x]] -// use transformation y = log[1+exp[x]] by broadcast -Tensor softplus(const Tensor& x, float beta=1.0, float threshold=20.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - // mish[x] = x*tanh[softplus[x]] // use transformation y = x*tanh[softplus[x]] by broadcast // Ref: https://krutikabapat.github.io/Swish-Vs-Mish-Latest-Activation-Functions/ diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index 5fc6c892ae7..d23bed2d148 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -71,6 +71,9 @@ void update_macro_defines(UnaryOpType op_type, std::map get_op_init_and_func_parameterized( case UnaryOpType::UNARY_NE: op_init_and_name = {"unary_ne_tile_init();", fmt::format("unary_ne_tile({}, {}u);", idst, Converter::to_hex(param0))}; break; case UnaryOpType::UNARY_GT: op_init_and_name = {"unary_gt_tile_init();", fmt::format("unary_gt_tile({}, {}u);", idst, Converter::to_hex(param0))}; break; case UnaryOpType::UNARY_LT: op_init_and_name = {"unary_lt_tile_init();", 
fmt::format("unary_lt_tile({}, {}u);", idst, Converter::to_hex(param0))}; break; + case UnaryOpType::SOFTPLUS: { + TT_ASSERT(params.size() == 2, "Expected softplus to take 2 parameters"); + float param1 = params[1]; + op_init_and_name = {"softplus_tile_init();", fmt::format("softplus_tile({}, {}u, {}u, {}u);", idst, + Converter::to_hex(param0), + Converter::to_hex(1.0f / param0), // Pass reciprocal to avoid doing it on device + Converter::to_hex(param1) + )}; + break; + } default: TT_ASSERT( false && "unexpected parameterized type"); }; diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index f0ff9b15d30..8de45737686 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -66,6 +66,7 @@ enum class UnaryOpType { RSUB, RDIV, SILU, + SOFTPLUS, IDENTITY, NEG, ADD_UNARY, @@ -99,6 +100,7 @@ bool is_parametrized_type(T val) { case UnaryOpType::RSUB: case UnaryOpType::RDIV: case UnaryOpType::EXP: + case UnaryOpType::SOFTPLUS: case UnaryOpType::ADD_UNARY: case UnaryOpType::SUB_UNARY: case UnaryOpType::MUL_UNARY: @@ -162,6 +164,8 @@ inline UnaryWithParam string_to_unary_with_param(const std::string& name) { return UnaryWithParam(UnaryOpType::SIGN); else if (name == "square") return UnaryWithParam(UnaryOpType::SQUARE); + else if (name == "softplus") + return UnaryWithParam(UnaryOpType::SOFTPLUS); TT_THROW("Unknown unary op: " + name); } @@ -433,6 +437,15 @@ inline Tensor sigmoid_accurate( output_mem_config); } +inline Tensor softplus( + const Tensor& input_tensor, + float beta = 1.0f, // same default as the composite softplus declaration this op replaces + float threshold = 20.0f, // same default as the composite softplus declaration this op replaces + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { + return run_eltwise_unary( + input_tensor, {UnaryWithParam(UnaryOpType::SOFTPLUS, {beta, threshold})}, output_mem_config); +} + inline Tensor unary_chain( const Tensor& input_tensor, std::vector ops_chain, diff --git 
a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h new file mode 100644 index 00000000000..3c544f5b684 --- /dev/null +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "ckernel_sfpu_converter.h" +#include "ckernel_sfpu_exp.h" +#include "ckernel_sfpu_log.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { +// softplus(x) = beta_reciprocal * log(1 + exp(beta * x)), applied to dst_reg[0]; lanes where beta * x >= threshold keep x unchanged. +template +inline void calculate_softplus_body(vFloat beta, vFloat beta_reciprocal, vFloat threshold) { + vFloat a = dst_reg[0]; + v_if(a * beta < threshold) { // compare the scaled input, as the replaced composite op did with (x*beta) > threshold + exp_init(); + a = calculate_exponential_body(a * beta) + 1.00f; + + log_init(); + dst_reg[0] = a; + calculate_log_body(0); + a = beta_reciprocal * dst_reg[0]; + } + v_endif; + dst_reg[0] = a; +} + +template +inline void calculate_softplus(uint param0, uint param1, uint param2) { + vFloat beta = Converter::to_float(param0); + vFloat beta_reciprocal = Converter::to_float(param1); + vFloat threshold = Converter::to_float(param2); + for (int d = 0; d < ITERATIONS; d++) { + calculate_softplus_body(beta, beta_reciprocal, threshold); + dst_reg++; + } +} + +template +void softplus_init() {} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h new file mode 100644 index 00000000000..67baa99f5cc --- /dev/null +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel_sfpu_softplus.h" +#include "llk_math_eltwise_unary_sfpu_3_param.h" +#include "llk_math_eltwise_unary_sfpu_init.h" + +namespace ckernel { + +template +inline void llk_math_eltwise_unary_sfpu_softplus_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_softplus( + uint dst_index, uint param0, uint param1, uint param2, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_3_param( + ckernel::sfpu::calculate_softplus, + ckernel::sfpu::calculate_softplus, + dst_index, + vector_mode, + param0, param1, param2); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h new file mode 100644 index 00000000000..4d524456f7f --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_softplus.h @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "ckernel_sfpu_converter.h" +#include "ckernel_sfpu_exp.h" +#include "ckernel_sfpu_log.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { +// softplus(x) = beta_reciprocal * log(1 + exp(beta * x)), applied to dst_reg[0]; lanes where beta * x >= threshold keep x unchanged. +template +inline void calculate_softplus_body(vFloat beta, vFloat beta_reciprocal, vFloat threshold) { + vFloat a = dst_reg[0]; + v_if(a * beta < threshold) { // compare the scaled input, as the replaced composite op did with (x*beta) > threshold + exp_init(); + a = calculate_exponential_body(a * beta) + 1.00f; + + log_init(); + dst_reg[0] = a; + calculate_log_body(0); + a = beta_reciprocal * dst_reg[0]; + } + v_endif; + dst_reg[0] = a; +} + +template +inline void calculate_softplus(uint param0, uint param1, uint param2) { + vFloat beta = Converter::to_float(param0); + vFloat beta_reciprocal = Converter::to_float(param1); + vFloat threshold = Converter::to_float(param2); + for (int d = 0; d < ITERATIONS; d++) { + calculate_softplus_body(beta, beta_reciprocal, threshold); + dst_reg++; + } +} + +template +void softplus_init() {} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h new file mode 100644 index 00000000000..6d01ffc9924 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_softplus.h @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel_sfpu_softplus.h" +#include "llk_math_eltwise_unary_sfpu_3_param.h" +#include "llk_math_eltwise_unary_sfpu_init.h" + +namespace ckernel { + +template +inline void llk_math_eltwise_unary_sfpu_softplus_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_softplus( + uint dst_index, uint param0, uint param1, uint param2, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_3_param( + ckernel::sfpu::calculate_softplus, + ckernel::sfpu::calculate_softplus, + dst_index, + vector_mode, + param0, param1, param2); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h index a8cc39cea63..6e3051cdab6 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h @@ -73,6 +73,7 @@ enum SfpuType { unary_ne, unary_gt, unary_lt, + softplus, tiled_prod, unused, }; diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h index ea67ef1480c..9c9f9ec41d7 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h @@ -68,6 +68,10 @@ #include "compute_kernel_api/eltwise_unary/binop_with_scalar.h" #endif +#if SFPU_OP_SOFTPLUS_INCLUDE +#include "compute_kernel_api/eltwise_unary/softplus.h" +#endif + #if SFPU_OP_COMPUTE_KERNEL_API_INCLUDE #include "compute_kernel_api.h" #endif diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h b/tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h new file mode 100644 index 00000000000..52c2ebecfa2 --- /dev/null +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h @@ 
-0,0 +1,35 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_unary_sfpu_softplus.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + + + +namespace ckernel { + +/** + * Runs the softplus SFPU kernel on DST tile idst. param0, param1 and param2 carry beta, 1/beta and threshold, each passed as a uint32_t that the kernel converts back with Converter::to_float. + */ +ALWI void softplus_tile(uint32_t idst, uint32_t param0, uint32_t param1, uint32_t param2) { + MATH(( llk_math_eltwise_unary_sfpu_softplus(idst, param0, param1, param2) )); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void softplus_tile_init() { + MATH(( llk_math_eltwise_unary_sfpu_softplus_init() )); +} + +} // namespace ckernel