Skip to content

Commit

Permalink
#6938: Implement softplus as a single kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
esmalTT committed May 15, 2024
1 parent d09fd7f commit 6c76ca4
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1100,3 +1100,36 @@ def test_run_eltwise_unary_comp(
device,
test_args,
)

@pytest.mark.parametrize("beta", [1.0, 5.0])
@pytest.mark.parametrize("threshold", [10.0, 20.0])
def test_run_eltwise_softplus(
    self,
    input_shapes,
    beta,
    threshold,
    device,
    function_level_defaults,
    input_mem_config,
    output_mem_config,
):
    """Run the "eltwise-softplus" op for each parametrized beta/threshold combination.

    Inputs are generated with gen_rand(low=-100, high=100) cast to bfloat16 and the
    result is checked with comparison_funcs.comp_pcc.
    """
    rand_bf16 = generation_funcs.gen_func_with_cast(
        partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16
    )

    # Start from the default dtype/layout/device arguments, then add the softplus
    # parameters and the memory configs under test.
    test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
    test_args.update(
        {
            "beta": beta,
            "threshold": threshold,
            "input_mem_config": [input_mem_config],
            "output_mem_config": output_mem_config,
        }
    )

    run_single_pytorch_test(
        "eltwise-softplus",
        input_shapes,
        [rand_bf16],
        comparison_funcs.comp_pcc,
        device,
        test_args,
    )
17 changes: 0 additions & 17 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,6 @@ Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _log1p)(x, output_mem_config);
}

// Composite softplus: softplus[x] = (1/beta) * log[1 + exp[x * beta]].
// Elements where (x * beta) > threshold are returned unchanged (result = x).
// Where the beta == 0 predicate holds, the non-thresholded result is replaced by +infinity.
Tensor _softplus(const Tensor& x, float beta, float threshold, const MemoryConfig& output_mem_config) {
    float inv_beta = (1 / beta);

    // (1/beta) * log1p(exp(beta * x))
    Tensor scaled_x = mul_unary(x, beta, output_mem_config);
    Tensor exp_scaled = exp(scaled_x, output_mem_config);
    Tensor log1p_exp = log1p(exp_scaled, output_mem_config);
    Tensor softplus_val = mul_unary(log1p_exp, inv_beta, output_mem_config);

    // beta is a scalar, so this predicate is uniform across the tensor.
    Tensor beta_is_zero = eqz(full_like(x, beta, output_mem_config), output_mem_config);
    Tensor guarded = where(beta_is_zero, std::numeric_limits<float>::infinity(), softplus_val);

    // Above the threshold the function is treated as the identity.
    Tensor above_threshold =
        gt(scaled_x, full_like(x, threshold, output_mem_config), std::nullopt, output_mem_config);
    return where(above_threshold, x, guarded, output_mem_config);
}
// Public composite entry point for softplus: wraps _softplus with
// operation::decorate_as_composite, passing this function's own name (__func__)
// as the first argument, and forwards all arguments unchanged.
Tensor softplus(const Tensor& a, float beta, float threshold, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _softplus)(a, beta, threshold, output_mem_config);
}

// tanhshrink(x) = x - tanh(x)
Tensor _tanhshrink(const Tensor& x, const MemoryConfig& output_mem_config) {
Tensor tan_x = tanh(x, output_mem_config);
Expand Down
4 changes: 0 additions & 4 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ Tensor mac(
// use transformation y = log(1.0 + x) by broadcast
Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// softplus[x] = (1/beta) * log[1 + exp[beta * x]]
// elements where (x * beta) > threshold are returned unchanged (softplus[x] = x)
// use transformation y = log[1+exp[x]] by broadcast
Tensor softplus(const Tensor& x, float beta=1.0, float threshold=20.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// mish[x] = x*tanh[softplus[x]]
// use transformation y = x*tanh[softplus[x]] by broadcast
// Ref: https://krutikabapat.github.io/Swish-Vs-Mish-Latest-Activation-Functions/
Expand Down
13 changes: 13 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string
case UnaryOpType::NEG:
defines["SFPU_OP_NEG_INCLUDE"] = "1";
break;
case UnaryOpType::SOFTPLUS:
defines["SFPU_OP_SOFTPLUS_INCLUDE"] = "1";
break;
default:
defines["SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"]="1";
break;
Expand Down Expand Up @@ -140,6 +143,16 @@ std::pair<string, string> get_op_init_and_func_parameterized(
case UnaryOpType::UNARY_NE: op_init_and_name = {"unary_ne_tile_init();", fmt::format("unary_ne_tile({}, {}u);", idst, Converter::to_hex(param0))}; break;
case UnaryOpType::UNARY_GT: op_init_and_name = {"unary_gt_tile_init();", fmt::format("unary_gt_tile({}, {}u);", idst, Converter::to_hex(param0))}; break;
case UnaryOpType::UNARY_LT: op_init_and_name = {"unary_lt_tile_init();", fmt::format("unary_lt_tile({}, {}u);", idst, Converter::to_hex(param0))}; break;
case UnaryOpType::SOFTPLUS: {
TT_ASSERT(params.size() == 2, "Expected softplus to take 2 parameters");
float param1 = params[1];
op_init_and_name = {"softplus_tile_init();", fmt::format("softplus_tile({}, {}u, {}u, {}u);", idst,
Converter::to_hex(param0),
Converter::to_hex(1.0f / param0), // Pass reciprocal to avoid doing it on device
Converter::to_hex(param1)
)};
break;
}
default:
TT_ASSERT( false && "unexpected parameterized type");
};
Expand Down
13 changes: 13 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ enum class UnaryOpType {
RSUB,
RDIV,
SILU,
SOFTPLUS,
IDENTITY,
NEG,
ADD_UNARY,
Expand Down Expand Up @@ -99,6 +100,7 @@ bool is_parametrized_type(T val) {
case UnaryOpType::RSUB:
case UnaryOpType::RDIV:
case UnaryOpType::EXP:
case UnaryOpType::SOFTPLUS:
case UnaryOpType::ADD_UNARY:
case UnaryOpType::SUB_UNARY:
case UnaryOpType::MUL_UNARY:
Expand Down Expand Up @@ -162,6 +164,8 @@ inline UnaryWithParam string_to_unary_with_param(const std::string& name) {
return UnaryWithParam(UnaryOpType::SIGN);
else if (name == "square")
return UnaryWithParam(UnaryOpType::SQUARE);
else if (name == "softplus")
return UnaryWithParam(UnaryOpType::SOFTPLUS);
TT_THROW("Unknown unary op: " + name);
}

Expand Down Expand Up @@ -433,6 +437,15 @@ inline Tensor sigmoid_accurate(
output_mem_config);
}

// Element-wise softplus, dispatched through run_eltwise_unary as a single
// UnaryOpType::SOFTPLUS op carrying {beta, threshold} as its parameters.
//
//   input_tensor      - tensor to apply softplus to
//   beta              - scale applied to the input inside the exponential
//   threshold         - bound on beta * x beyond which the input is passed through
//   output_mem_config - memory configuration for the result
//
// beta and threshold default to 1.0 and 20.0, the same defaults the composite
// softplus declaration exposed (and the ones torch.nn.Softplus uses), so callers
// that relied on softplus(x) or softplus(x, beta) keep compiling.
inline Tensor softplus(
    const Tensor& input_tensor,
    float beta = 1.0f,
    float threshold = 20.0f,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
    return run_eltwise_unary(
        input_tensor, {UnaryWithParam(UnaryOpType::SOFTPLUS, {beta, threshold})}, output_mem_config);
}

inline Tensor unary_chain(
const Tensor& input_tensor,
std::vector<UnaryWithParam> ops_chain,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "ckernel_sfpu_converter.h"
#include "ckernel_sfpu_exp.h"
#include "ckernel_sfpu_log.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

// Computes softplus in place on dst_reg[0]:
//   beta_reciprocal * log(1 + exp(beta * x))   where beta * x <  threshold
//   x (unchanged)                              otherwise
// beta_reciprocal is supplied by the caller so no division is needed here.
template <bool APPROXIMATION_MODE>
inline void calculate_softplus_body(vFloat beta, vFloat beta_reciprocal, vFloat threshold) {
vFloat a = dst_reg[0];
vFloat a_beta = a * beta;
// Only elements below the threshold take the exp/log path; for the rest `a`
// still holds the original input when it is written back after v_endif.
v_if(a_beta < threshold) {
// NOTE(review): exp_init/log_init are invoked on every call, immediately before the
// helper they belong to; presumably the two helpers need different setup state, so do
// not hoist or reorder these calls without confirming that.
exp_init<APPROXIMATION_MODE, false>();
a = calculate_exponential_body<APPROXIMATION_MODE>(a_beta) + 1.0f;

log_init<APPROXIMATION_MODE>();
// The log helper is handed its input via dst_reg[0] and its result is read back from
// dst_reg[0] below (presumably it operates in place on that register -- confirm).
// NOTE(review): it is instantiated with `false`, not APPROXIMATION_MODE.
dst_reg[0] = a;
calculate_log_body<false>(0);
a = beta_reciprocal * dst_reg[0];
}
v_endif;
dst_reg[0] = a;
}

// Runs calculate_softplus_body ITERATIONS times, advancing dst_reg by one after each call.
// param0 / param1 / param2 carry beta, 1/beta and the threshold; each is decoded
// with Converter::to_float once, before the loop.
template <bool APPROXIMATION_MODE, int ITERATIONS = 4>
inline void calculate_softplus(uint param0, uint param1, uint param2) {
    vFloat v_beta = Converter::to_float(param0);
    vFloat v_inv_beta = Converter::to_float(param1);
    vFloat v_threshold = Converter::to_float(param2);

    for (int iter = 0; iter < ITERATIONS; ++iter) {
        calculate_softplus_body<APPROXIMATION_MODE>(v_beta, v_inv_beta, v_threshold);
        dst_reg++;
    }
}

// Init hook for softplus. Intentionally empty: calculate_softplus_body calls
// exp_init and log_init itself on every invocation.
template <bool APPROXIMATION_MODE>
void softplus_init() {}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_softplus.h"
#include "llk_math_eltwise_unary_sfpu_3_param.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// LLK init entry point for the softplus SFPU op; forwards to the generic
// llk_math_eltwise_unary_sfpu_init instantiated with SfpuType::softplus.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_softplus_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::softplus, APPROXIMATE>();
}

// LLK entry point for softplus on the tile at dst_index.
// param0 / param1 / param2 are forwarded untouched to calculate_softplus, which decodes
// them as beta, 1/beta and threshold. vector_mode defaults to VectorMode::RC.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_softplus(
uint dst_index, uint param0, uint param1, uint param2, int vector_mode = (int)VectorMode::RC) {
// The same calculate_softplus instantiation is supplied for both callable arguments
// of the 3-parameter helper.
// NOTE(review): the distinct roles of the two callables are defined by
// llk_math_eltwise_unary_sfpu_3_param, which is not visible here -- confirm there.
llk_math_eltwise_unary_sfpu_3_param<APPROXIMATE>(
ckernel::sfpu::calculate_softplus<APPROXIMATE>,
ckernel::sfpu::calculate_softplus<APPROXIMATE>,
dst_index,
vector_mode,
param0, param1, param2);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "ckernel_sfpu_converter.h"
#include "ckernel_sfpu_exp.h"
#include "ckernel_sfpu_log.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

// Computes softplus in place on dst_reg[0]:
//   beta_reciprocal * log(1 + exp(beta * x))   where beta * x <  threshold
//   x (unchanged)                              otherwise
// beta_reciprocal is supplied by the caller so no division is needed here.
template <bool APPROXIMATION_MODE>
inline void calculate_softplus_body(vFloat beta, vFloat beta_reciprocal, vFloat threshold) {
vFloat a = dst_reg[0];
vFloat a_beta = a * beta;
// Only elements below the threshold take the exp/log path; for the rest `a`
// still holds the original input when it is written back after v_endif.
v_if(a_beta < threshold) {
// NOTE(review): exp_init/log_init are invoked on every call, immediately before the
// helper they belong to; presumably the two helpers need different setup state, so do
// not hoist or reorder these calls without confirming that.
exp_init<APPROXIMATION_MODE, false>();
a = calculate_exponential_body<APPROXIMATION_MODE>(a_beta) + 1.0f;

log_init<APPROXIMATION_MODE>();
// The log helper is handed its input via dst_reg[0] and its result is read back from
// dst_reg[0] below (presumably it operates in place on that register -- confirm).
// NOTE(review): it is instantiated with `false`, not APPROXIMATION_MODE.
dst_reg[0] = a;
calculate_log_body<false>(0);
a = beta_reciprocal * dst_reg[0];
}
v_endif;
dst_reg[0] = a;
}

// Runs calculate_softplus_body ITERATIONS times, advancing dst_reg by one after each call.
// param0 / param1 / param2 carry beta, 1/beta and the threshold; each is decoded
// with Converter::to_float once, before the loop.
template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_softplus(uint param0, uint param1, uint param2) {
    vFloat v_beta = Converter::to_float(param0);
    vFloat v_inv_beta = Converter::to_float(param1);
    vFloat v_threshold = Converter::to_float(param2);

    for (int iter = 0; iter < ITERATIONS; ++iter) {
        calculate_softplus_body<APPROXIMATION_MODE>(v_beta, v_inv_beta, v_threshold);
        dst_reg++;
    }
}

// Init hook for softplus. Intentionally empty: calculate_softplus_body calls
// exp_init and log_init itself on every invocation.
template <bool APPROXIMATION_MODE>
void softplus_init() {}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_softplus.h"
#include "llk_math_eltwise_unary_sfpu_3_param.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// LLK init entry point for the softplus SFPU op; forwards to the generic
// llk_math_eltwise_unary_sfpu_init instantiated with SfpuType::softplus.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_softplus_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::softplus, APPROXIMATE>();
}

// LLK entry point for softplus on the tile at dst_index.
// param0 / param1 / param2 are forwarded untouched to calculate_softplus, which decodes
// them as beta, 1/beta and threshold. vector_mode defaults to VectorMode::RC.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_softplus(
uint dst_index, uint param0, uint param1, uint param2, int vector_mode = (int)VectorMode::RC) {
// The same calculate_softplus instantiation is supplied for both callable arguments
// of the 3-parameter helper.
// NOTE(review): the distinct roles of the two callables are defined by
// llk_math_eltwise_unary_sfpu_3_param, which is not visible here -- confirm there.
llk_math_eltwise_unary_sfpu_3_param<APPROXIMATE>(
ckernel::sfpu::calculate_softplus<APPROXIMATE>,
ckernel::sfpu::calculate_softplus<APPROXIMATE>,
dst_index,
vector_mode,
param0, param1, param2);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ enum SfpuType {
unary_ne,
unary_gt,
unary_lt,
softplus,
tiled_prod,
unused,
};
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
#include "compute_kernel_api/eltwise_unary/binop_with_scalar.h"
#endif

#if SFPU_OP_SOFTPLUS_INCLUDE
#include "compute_kernel_api/eltwise_unary/softplus.h"
#endif

#if SFPU_OP_COMPUTE_KERNEL_API_INCLUDE
#include "compute_kernel_api.h"
#endif
Expand Down
47 changes: 47 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once


#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_softplus.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif



namespace ckernel {

/**
* Performs element-wise computation of softplus (`1/beta * log(1 + exp(beta * x))`) on each element
* of a tile in DST register at index tile_index. Any element for which `beta * x` is not less than
* the provided threshold is returned unchanged. The DST register buffer must be in acquired state
* via *acquire_dst* call. This call is blocking and is only available on the compute engine.
*
* beta, beta_reciprocal and threshold are floating-point values passed as uint32_t: the host side
* produces them with Converter::to_hex and the SFPU kernel decodes them with Converter::to_float.
*
* Return value: None
*
* | Argument | Description | Type | Valid Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | tile_index | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
* | beta | Beta used in softplus calculation | uint32_t | Greater than 0 | True |
* | beta_reciprocal | Reciprocal of beta (1/beta) used in softplus calculation | uint32_t | Greater than 0 | True |
* | threshold | Threshold used in softplus calculation | uint32_t | Greater than 0 | True |
*/
ALWI void softplus_tile(uint32_t idst, uint32_t beta, uint32_t beta_reciprocal, uint32_t threshold) {
MATH(( llk_math_eltwise_unary_sfpu_softplus<APPROX>(idst, beta, beta_reciprocal, threshold) ));
}

/**
* Init call paired with softplus_tile. Forwards to llk_math_eltwise_unary_sfpu_softplus_init
* through the MATH() macro, which expands to nothing unless TRISC_MATH is defined.
* Please refer to documentation for any_init.
*/
ALWI void softplus_tile_init() {
MATH(( llk_math_eltwise_unary_sfpu_softplus_init<APPROX>() ));
}

} // namespace ckernel

0 comments on commit 6c76ca4

Please sign in to comment.