#6938: Implement softplus as a single kernel
esmalTT committed May 31, 2024
1 parent 74d79a8 commit 7cbf6e8
Showing 13 changed files with 208 additions and 46 deletions.
@@ -19,13 +19,12 @@
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)
from models.utility_functions import is_wormhole_b0
from models.utility_functions import is_wormhole_b0, is_grayskull


reference_pcc = defaultdict(lambda: 0.999)
reference_pcc["silu"] = 0.9714
reference_pcc["swish"] = reference_pcc["silu"]
reference_pcc["softplus"] = 0.9984


def custom_compare(*args, **kwargs):
@@ -68,7 +67,6 @@ def custom_compare(*args, **kwargs):
"max",
"swish",
"log1p",
"softplus",
"mish",
"silu",
"polyval",
@@ -157,6 +155,9 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
if is_wormhole_b0():
if fn in ["logit"]:
pytest.skip("does not work for Wormhole -skipping")
if is_grayskull():
if fn in ["mish"]:
pytest.skip("does not work for Grayskull -skipping")
if fn in ["logical_xor", "logical_xori", "logical_ori", "logical_andi"]:
datagen_func = [
generation_funcs.gen_func_with_cast(
@@ -231,13 +232,6 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
"equal_nan": random.choice([False, True]),
}
)
elif fn in ["softplus"]:
test_args.update(
{
"beta": random.choice([0.5, -3, 1, 4]),
"threshold": random.choice([-20, 10, 20, 5]),
}
)
run_single_pytorch_test(
"eltwise-%s" % (fn),
input_shapes,

@@ -16,7 +16,7 @@
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)
from models.utility_functions import is_wormhole_b0
from models.utility_functions import is_wormhole_b0, skip_for_grayskull

shapes = [
[[1, 1, 32, 32]], # Single core
@@ -1100,3 +1100,37 @@ def test_run_eltwise_unary_comp(
device,
test_args,
)

@skip_for_grayskull("Softplus kernel not currently available for GS")
@pytest.mark.parametrize("beta", [1.0, 5.0])
@pytest.mark.parametrize("threshold", [10.0, 20.0])
def test_run_eltwise_softplus(
self,
input_shapes,
beta,
threshold,
device,
function_level_defaults,
input_mem_config,
output_mem_config,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16)
]
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
test_args.update({"beta": beta, "threshold": threshold})
test_args.update(
{
"input_mem_config": [input_mem_config],
"output_mem_config": output_mem_config,
}
)
comparison_func = comparison_funcs.comp_pcc
run_single_pytorch_test(
"eltwise-softplus",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)
3 changes: 3 additions & 0 deletions tests/ttnn/unit_tests/operations/test_activation.py
@@ -11,6 +11,7 @@
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_grayskull


def run_activation_unary_test(device, h, w, ttnn_function, torch_function, pcc=0.99):
@@ -52,6 +53,7 @@ def test_log_sigmoid(device, h, w):
run_activation_unary_test(device, h, w, ttnn.log_sigmoid, F.logsigmoid)


@skip_for_grayskull()
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_mish(device, h, w):
@@ -116,6 +118,7 @@ def run_activation_softplus_test(device, h, w, beta, threshold, ttnn_function, t
assert_with_pcc(torch_output_tensor, output_tensor, pcc)


@skip_for_grayskull()
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
@pytest.mark.parametrize("beta", [-1, 1, 2, 0.5, 10])
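For reference, a minimal sketch of what these unit tests exercise end to end, assuming the ttnn.softplus binding and the beta/threshold keyword names used in the tests above; this is illustrative only (not part of the commit) and needs a Tenstorrent device to run:

# Illustrative sketch, not from the commit; assumes ttnn.softplus(beta=, threshold=)
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn(1, 1, 64, 128, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Dispatches the new fused SFPU kernel instead of the old composite op chain
output_tensor = ttnn.softplus(input_tensor, beta=1.0, threshold=20.0)

torch_expected = torch.nn.functional.softplus(torch_input.float(), beta=1.0, threshold=20.0)
output = ttnn.to_torch(output_tensor)
ttnn.close_device(device)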
17 changes: 0 additions & 17 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -110,23 +110,6 @@ Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _log1p)(x, output_mem_config);
}

// softplus[x] =(1/beta) * log[1 + exp[x * beta]]
// (x*beta) > threshold ==> x
// use transformation y = log[1+exp[x]] by broadcast
Tensor _softplus(const Tensor& x, float beta, float threshold, const MemoryConfig& output_mem_config) {
float oned_beta = (1 / beta);
Tensor x_beta = mul_unary(x, beta, output_mem_config);
Tensor exp_x = exp(x_beta, output_mem_config);
Tensor result_log1p = log1p(exp_x, output_mem_config);
Tensor sp_result = mul_unary(result_log1p, oned_beta, output_mem_config);
sp_result = where(gt(x_beta, full_like(x, threshold, output_mem_config), std::nullopt, output_mem_config), x,
where(eqz(full_like(x, beta, output_mem_config), output_mem_config), std::numeric_limits<float>::infinity(), sp_result), output_mem_config);
return sp_result;
}
Tensor softplus(const Tensor& a, float beta, float threshold, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _softplus)(a, beta, threshold, output_mem_config);
}

// tanhshrink(x) = x - tanh(x)
Tensor _tanhshrink(const Tensor& x, const MemoryConfig& output_mem_config) {
Tensor tan_x = tanh(x, output_mem_config);
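The deleted _softplus above chained five device ops (mul_unary, exp, log1p, mul_unary, where). A plain-torch sketch of the same math (not part of the commit; the beta == 0 infinity special case is omitted for brevity) shows the op chain the single kernel replaces:

# Reference sketch of the removed composite implementation, in plain torch
import torch

def composite_softplus(x, beta=1.0, threshold=20.0):
    # mul_unary -> exp -> log1p -> mul_unary(1/beta), as in the removed code
    x_beta = x * beta
    sp = torch.log1p(torch.exp(x_beta)) * (1.0 / beta)
    # where(x * beta > threshold, x, sp): large inputs pass through unchanged
    return torch.where(x_beta > threshold, x, sp)

x = torch.linspace(-10.0, 30.0, steps=9)
assert torch.allclose(composite_softplus(x), torch.nn.functional.softplus(x), atol=1e-5)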
4 changes: 0 additions & 4 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -81,10 +81,6 @@ Tensor mac(
// use transformation y = log(1.0 + x) by broadcast
Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// softplus[x] = log[1 + exp[x]]
// use transformation y = log[1+exp[x]] by broadcast
Tensor softplus(const Tensor& x, float beta=1.0, float threshold=20.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// mish[x] = x*tanh[softplus[x]]
// use transformation y = x*tanh[softplus[x]] by broadcast
// Ref: https://krutikabapat.github.io/Swish-Vs-Mish-Latest-Activation-Functions/
18 changes: 16 additions & 2 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
@@ -22,7 +22,7 @@ union Converter {
float f;
uint32_t u;

Converter(float f_) : f(f_) {};
Converter(float f_) : f(f_){};

static std::string to_hex(float f_) {
Converter obj(f_);
@@ -67,14 +67,15 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string
case UnaryOpType::SIN:
case UnaryOpType::TAN: defines["SFPU_OP_TRIG_FAMILY_INCLUDE"] = "1"; break;
case UnaryOpType::NEG: defines["SFPU_OP_NEG_INCLUDE"] = "1"; break;
case UnaryOpType::SOFTPLUS: defines["SFPU_OP_SOFTPLUS_INCLUDE"] = "1"; break;
default: defines["SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"] = "1"; break;
};
}

std::pair<string, string> get_op_init_and_func_parameterized(
UnaryOpType op_type, std::vector<float> params, string idst) {
std::pair<string, string> op_init_and_name;
TT_FATAL(is_parametrized_type(op_type) && "operator should support one parameter");
TT_FATAL(is_parametrized_type(op_type) && "operator should support at least one parameter");
float param0 = params[0];
switch (op_type) {
case UnaryOpType::RELU_MAX:
@@ -162,6 +163,19 @@ std::pair<string, string> get_op_init_and_func_parameterized(
op_init_and_name = {
"unary_lt_tile_init();", fmt::format("unary_lt_tile({}, {}u);", idst, Converter::to_hex(param0))};
break;
case UnaryOpType::SOFTPLUS: {
TT_ASSERT(params.size() == 2, "Expected softplus to take 2 parameters");
float param1 = params[1];
op_init_and_name = {
"softplus_tile_init();",
fmt::format(
"softplus_tile({}, {}u, {}u, {}u);",
idst,
Converter::to_hex(param0),
Converter::to_hex(1.0f / param0), // Pass reciprocal to avoid doing it on device
Converter::to_hex(param1))};
break;
}
default: TT_ASSERT(false && "unexpected parameterized type");
};
return op_init_and_name;
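The new SOFTPLUS case packs each float parameter as its raw IEEE-754 bit pattern via Converter::to_hex, and precomputes 1/beta on the host so the device kernel never divides. A small Python sketch of the equivalent packing (illustrative, not from the commit):

# Sketch of the host-side packing that Converter::to_hex performs
import struct

def to_hex(f: float) -> str:
    # Reinterpret the float's IEEE-754 bits as a uint32, like the C++ union
    (bits,) = struct.unpack("<I", struct.pack("<f", f))
    return f"0x{bits:08x}"

beta, threshold = 2.0, 20.0
params = [to_hex(beta), to_hex(1.0 / beta), to_hex(threshold)]
print(params)  # ['0x40000000', '0x3f000000', '0x41a00000']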
14 changes: 14 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
@@ -66,6 +66,7 @@ enum class UnaryOpType {
RSUB,
RDIV,
SILU,
SOFTPLUS,
IDENTITY,
NEG,
ADD_UNARY_SFPU,
@@ -95,6 +96,7 @@ bool is_parametrized_type(T val) {
case UnaryOpType::RSUB:
case UnaryOpType::RDIV:
case UnaryOpType::EXP:
case UnaryOpType::SOFTPLUS:
case UnaryOpType::ADD_UNARY_SFPU:
case UnaryOpType::SUB_UNARY_SFPU:
case UnaryOpType::MUL_UNARY_SFPU:
@@ -154,6 +156,8 @@ inline UnaryWithParam string_to_unary_with_param(const std::string& name) {
return UnaryWithParam(UnaryOpType::SIGN);
else if (name == "square")
return UnaryWithParam(UnaryOpType::SQUARE);
else if (name == "softplus")
return UnaryWithParam(UnaryOpType::SOFTPLUS);
TT_THROW("Unknown unary op: " + name);
}

@@ -423,6 +427,16 @@ inline Tensor sigmoid_accurate(
output_mem_config);
}

inline Tensor softplus(
const Tensor& input_tensor,
float beta,
float threshold,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
TT_ASSERT(input_tensor.device()->arch() != tt::ARCH::GRAYSKULL, "Softplus is not currently supported on Grayskull");
return run_eltwise_unary(
input_tensor, {UnaryWithParam(UnaryOpType::SOFTPLUS, {beta, threshold})}, output_mem_config);
}

inline Tensor unary_chain(
const Tensor& input_tensor,
std::vector<UnaryWithParam> ops_chain,
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "ckernel_sfpu_converter.h"
#include "ckernel_sfpu_exp.h"
#include "ckernel_sfpu_log.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE>
inline void calculate_softplus_body(vFloat beta, vFloat beta_reciprocal, vFloat threshold) {
vFloat a = dst_reg[0];
vFloat a_beta = a * beta;
v_if(a_beta < threshold) {
exp_init<APPROXIMATION_MODE, false>();
a = calculate_exponential_body<APPROXIMATION_MODE>(a_beta) + 1.0f;

log_init<APPROXIMATION_MODE>();
dst_reg[0] = a;
calculate_log_body<false>(0);
a = beta_reciprocal * dst_reg[0];
}
v_endif;
dst_reg[0] = a;
}

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_softplus(uint param0, uint param1, uint param2) {
vFloat beta = Converter::to_float(param0);
vFloat beta_reciprocal = Converter::to_float(param1);
vFloat threshold = Converter::to_float(param2);
for (int d = 0; d < ITERATIONS; d++) {
calculate_softplus_body<APPROXIMATION_MODE>(beta, beta_reciprocal, threshold);
dst_reg++;
}
}

template <bool APPROXIMATION_MODE>
void softplus_init() {}

} // namespace sfpu
} // namespace ckernel
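A scalar Python model of calculate_softplus_body (an assumption-laden sketch: it ignores SFPU vectorization, predication, and the approximate exp/log modes) makes the piecewise behaviour explicit:

# Scalar model of the kernel body, for checking piecewise behaviour only
import math

def softplus_body(x: float, beta: float, beta_reciprocal: float, threshold: float) -> float:
    a_beta = x * beta
    if a_beta < threshold:          # mirrors v_if(a_beta < threshold)
        # exp -> +1.0 -> log -> scale by the host-precomputed 1/beta
        return beta_reciprocal * math.log(math.exp(a_beta) + 1.0)
    return x                        # lanes past the threshold keep the input

for x in (-5.0, 0.0, 5.0, 25.0):
    print(x, softplus_body(x, beta=1.0, beta_reciprocal=1.0, threshold=20.0))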
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_softplus.h"
#include "llk_math_eltwise_unary_sfpu_3_param.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_softplus_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::softplus, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_softplus(
uint dst_index, uint param0, uint param1, uint param2, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_3_param<APPROXIMATE>(
ckernel::sfpu::calculate_softplus<APPROXIMATE>,
ckernel::sfpu::calculate_softplus<APPROXIMATE>,
dst_index,
vector_mode,
param0, param1, param2);
}

} // namespace ckernel
@@ -73,6 +73,7 @@ enum SfpuType {
unary_ne,
unary_gt,
unary_lt,
softplus,
tiled_prod,
unused,
};
@@ -68,6 +68,10 @@
#include "compute_kernel_api/eltwise_unary/binop_with_scalar.h"
#endif

#if SFPU_OP_SOFTPLUS_INCLUDE
#include "compute_kernel_api/eltwise_unary/softplus.h"
#endif

#if SFPU_OP_COMPUTE_KERNEL_API_INCLUDE
#include "compute_kernel_api.h"
#endif
47 changes: 47 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once


#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_softplus.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif



namespace ckernel {

/**
* Performs element-wise computation of softplus (`1/beta * log(1 + exp(beta * x))`) on each element
* of a tile in DST register at index tile_index. Any input for which `beta * x` exceeds the
* provided threshold returns the input unchanged. The DST register buffer must be in the
* acquired state via an *acquire_dst* call. This call is blocking and is only available on
* the compute engine.
*
* Return value: None
*
* | Argument | Description | Type | Valid Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | tile_index | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
* | beta            | Beta used in softplus calculation, passed as raw IEEE-754 float bits      | uint32_t | Greater than 0                                         | True     |
* | beta_reciprocal | Reciprocal of beta (1/beta), passed as raw IEEE-754 float bits             | uint32_t | Greater than 0                                         | True     |
* | threshold       | Threshold used in softplus calculation, passed as raw IEEE-754 float bits | uint32_t | Greater than 0                                         | True     |
*/
ALWI void softplus_tile(uint32_t idst, uint32_t beta, uint32_t beta_reciprocal, uint32_t threshold) {
MATH(( llk_math_eltwise_unary_sfpu_softplus<APPROX>(idst, beta, beta_reciprocal, threshold) ));
}

/**
* Please refer to documentation for any_init.
*/
ALWI void softplus_tile_init() {
MATH(( llk_math_eltwise_unary_sfpu_softplus_init<APPROX>() ));
}

} // namespace ckernel
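Since the beta, beta_reciprocal, and threshold arguments above are uint32 bit patterns rather than numeric magnitudes, the device side decodes them back to floats with Converter::to_float before the SFPU math runs. An illustrative round trip in Python (a sketch, not from the commit):

# Inverse of the host-side packing: uint32 bits -> IEEE-754 float
import struct

def to_float(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_float(0x40000000), to_float(0x3F000000), to_float(0x41A00000))  # 2.0 0.5 20.0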