#6938: Implement softplus as a single kernel #8249

Merged: 1 commit, May 31, 2024
@@ -19,13 +19,12 @@
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)
from models.utility_functions import is_wormhole_b0
from models.utility_functions import is_wormhole_b0, is_grayskull


reference_pcc = defaultdict(lambda: 0.999)
reference_pcc["silu"] = 0.9714
reference_pcc["swish"] = reference_pcc["silu"]
reference_pcc["softplus"] = 0.9984


def custom_compare(*args, **kwargs):
@@ -68,7 +67,6 @@ def custom_compare(*args, **kwargs):
"max",
"swish",
"log1p",
"softplus",
"mish",
"silu",
"polyval",
@@ -157,6 +155,9 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
if is_wormhole_b0():
if fn in ["logit"]:
pytest.skip("does not work for Wormhole -skipping")
if is_grayskull():
if fn in ["mish"]:
pytest.skip("does not work for Grayskull -skipping")
if fn in ["logical_xor", "logical_xori", "logical_ori", "logical_andi"]:
datagen_func = [
generation_funcs.gen_func_with_cast(
@@ -231,13 +232,6 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
"equal_nan": random.choice([False, True]),
}
)
elif fn in ["softplus"]:
test_args.update(
{
"beta": random.choice([0.5, -3, 1, 4]),
"threshold": random.choice([-20, 10, 20, 5]),
}
)
run_single_pytorch_test(
"eltwise-%s" % (fn),
input_shapes,
@@ -16,7 +16,7 @@
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)
from models.utility_functions import is_wormhole_b0
from models.utility_functions import is_wormhole_b0, skip_for_grayskull

shapes = [
[[1, 1, 32, 32]], # Single core
@@ -1100,3 +1100,37 @@ def test_run_eltwise_unary_comp(
device,
test_args,
)

@skip_for_grayskull("Softplus kernel not currently availible for GS")
@pytest.mark.parametrize("beta", [1.0, 5.0])
@pytest.mark.parametrize("threshold", [10.0, 20.0])
def test_run_eltwise_softplus(
self,
input_shapes,
beta,
threshold,
device,
function_level_defaults,
input_mem_config,
output_mem_config,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.bfloat16)
]
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
test_args.update({"beta": beta, "threshold": threshold})
test_args.update(
{
"input_mem_config": [input_mem_config],
"output_mem_config": output_mem_config,
}
)
comparison_func = comparison_funcs.comp_pcc
run_single_pytorch_test(
"eltwise-softplus",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)
3 changes: 3 additions & 0 deletions tests/ttnn/unit_tests/operations/test_activation.py
@@ -11,6 +11,7 @@
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_grayskull


def run_activation_unary_test(device, h, w, ttnn_function, torch_function, pcc=0.99):
@@ -52,6 +53,7 @@ def test_log_sigmoid(device, h, w):
run_activation_unary_test(device, h, w, ttnn.log_sigmoid, F.logsigmoid)


@skip_for_grayskull()
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_mish(device, h, w):
@@ -116,6 +118,7 @@ def run_activation_softplus_test(device, h, w, beta, threshold, ttnn_function, t
assert_with_pcc(torch_output_tensor, output_tensor, pcc)


@skip_for_grayskull()
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
@pytest.mark.parametrize("beta", [-1, 1, 2, 0.5, 10])
19 changes: 0 additions & 19 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -110,25 +110,6 @@ Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _log1p)(x, output_mem_config);
}

// softplus[x] =(1/beta) * log[1 + exp[x * beta]]
// (x*beta) > threshold ==> x
// use transformation y = log[1+exp[x]] by broadcast
Tensor _softplus(const Tensor& x, float beta, float threshold, const MemoryConfig& output_mem_config) {
float oned_beta = (1 / beta);
Tensor x_beta = mul_unary(x, beta, output_mem_config);
Tensor exp_x = exp(x_beta, output_mem_config);
Tensor result_log1p = log1p(exp_x, output_mem_config);
exp_x.deallocate();
Tensor sp_result = mul_unary(result_log1p, oned_beta, output_mem_config);
result_log1p.deallocate();
sp_result = where(gt(x_beta, full_like(x, threshold, output_mem_config), std::nullopt, output_mem_config), x,
where(eqz(full_like(x, beta, output_mem_config), output_mem_config), std::numeric_limits<float>::infinity(), sp_result), output_mem_config);
return sp_result;
}
Tensor softplus(const Tensor& a, float beta, float threshold, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _softplus)(a, beta, threshold, output_mem_config);
}

// tanhshrink(x) = x - tanh(x)
Tensor _tanhshrink(const Tensor& x, const MemoryConfig& output_mem_config) {
Tensor tan_x = tanh(x, output_mem_config);
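The removed composite built softplus out of several device ops (mul_unary, exp, log1p, where). For reference, a scalar C++ sketch of the math it implemented, including its edge cases; std::exp/std::log1p stand in for the device ops, so this is an illustration of the semantics, not code from the PR:

#include <cmath>
#include <limits>

// Scalar reference for the removed composite: (1/beta) * log1p(exp(x * beta)),
// with passthrough above the threshold and beta == 0 mapped to +inf
// (mirroring the where(eqz(beta), inf, ...) branch in the deleted code).
float softplus_composite_ref(float x, float beta, float threshold) {
    float x_beta = x * beta;
    if (x_beta > threshold) return x;                               // gt(x_beta, threshold) ==> x
    if (beta == 0.0f) return std::numeric_limits<float>::infinity(); // eqz(beta) ==> inf
    return (1.0f / beta) * std::log1p(std::exp(x_beta));
}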
4 changes: 0 additions & 4 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -99,10 +99,6 @@ Tensor mac(
// use transformation y = log(1.0 + x) by broadcast
Tensor log1p(const Tensor& x, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// softplus[x] = log[1 + exp[x]]
// use transformation y = log[1+exp[x]] by broadcast
Tensor softplus(const Tensor& x, float beta=1.0, float threshold=20.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// mish[x] = x*tanh[softplus[x]]
// use transformation y = x*tanh[softplus[x]] by broadcast
// Ref: https://krutikabapat.github.io/Swish-Vs-Mish-Latest-Activation-Functions/
18 changes: 16 additions & 2 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
@@ -22,7 +22,7 @@ union Converter {
float f;
uint32_t u;

Converter(float f_) : f(f_) {};
Converter(float f_) : f(f_){};

static std::string to_hex(float f_) {
Converter obj(f_);
@@ -67,14 +67,15 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string
case UnaryOpType::SIN:
case UnaryOpType::TAN: defines["SFPU_OP_TRIG_FAMILY_INCLUDE"] = "1"; break;
case UnaryOpType::NEG: defines["SFPU_OP_NEG_INCLUDE"] = "1"; break;
case UnaryOpType::SOFTPLUS: defines["SFPU_OP_SOFTPLUS_INCLUDE"] = "1"; break;
default: defines["SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"] = "1"; break;
};
}

std::pair<string, string> get_op_init_and_func_parameterized(
UnaryOpType op_type, std::vector<float> params, string idst) {
std::pair<string, string> op_init_and_name;
TT_FATAL(is_parametrized_type(op_type) && "operator should support one parameter");
TT_FATAL(is_parametrized_type(op_type) && "operator should support at least one parameter");
float param0 = params[0];
switch (op_type) {
case UnaryOpType::RELU_MAX:
@@ -162,6 +163,19 @@ std::pair<string, string> get_op_init_and_func_parameterized(
op_init_and_name = {
"unary_lt_tile_init();", fmt::format("unary_lt_tile({}, {}u);", idst, Converter::to_hex(param0))};
break;
case UnaryOpType::SOFTPLUS: {
TT_ASSERT(params.size() == 2, "Expected softplus to take 2 parameters");
float param1 = params[1];
op_init_and_name = {
"softplus_tile_init();",
fmt::format(
"softplus_tile({}, {}u, {}u, {}u);",
idst,
Converter::to_hex(param0),
Converter::to_hex(1.0f / param0), // Pass reciprocal to avoid doing it on device
Converter::to_hex(param1))};
break;
}
default: TT_ASSERT(false && "unexpected parameterized type");
};
return op_init_and_name;
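As the comment in the SOFTPLUS case notes, 1/beta is computed on host so the device never does the reciprocal; all three floats then travel to the kernel as IEEE-754 bit patterns via Converter::to_hex. A standalone sketch of that packing, using std::memcpy instead of the union shown above (illustrative only):

#include <cstdint>
#include <cstring>

// Pack a float into the uint32_t bit pattern the SFPU kernel expects
// (same result as reading Converter::u after writing Converter::f).
uint32_t float_to_u32_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return u;
}

// Host-side packing for softplus: beta, its precomputed reciprocal, threshold.
// uint32_t p0 = float_to_u32_bits(beta);
// uint32_t p1 = float_to_u32_bits(1.0f / beta);
// uint32_t p2 = float_to_u32_bits(threshold);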
14 changes: 14 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
@@ -66,6 +66,7 @@ enum class UnaryOpType {
RSUB,
RDIV,
SILU,
SOFTPLUS,
IDENTITY,
NEG,
ADD_UNARY_SFPU,
@@ -95,6 +96,7 @@ bool is_parametrized_type(T val) {
case UnaryOpType::RSUB:
case UnaryOpType::RDIV:
case UnaryOpType::EXP:
case UnaryOpType::SOFTPLUS:
case UnaryOpType::ADD_UNARY_SFPU:
case UnaryOpType::SUB_UNARY_SFPU:
case UnaryOpType::MUL_UNARY_SFPU:
@@ -154,6 +156,8 @@ inline UnaryWithParam string_to_unary_with_param(const std::string& name) {
return UnaryWithParam(UnaryOpType::SIGN);
else if (name == "square")
return UnaryWithParam(UnaryOpType::SQUARE);
else if (name == "softplus")
return UnaryWithParam(UnaryOpType::SOFTPLUS);
TT_THROW("Unknown unary op: " + name);
}

@@ -423,6 +427,16 @@ inline Tensor sigmoid_accurate(
output_mem_config);
}

inline Tensor softplus(
const Tensor& input_tensor,
float beta,
float threshold,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
TT_ASSERT(input_tensor.device()->arch() != tt::ARCH::GRAYSKULL, "Softplus is not currently supported on Grayskull");
return run_eltwise_unary(
input_tensor, {UnaryWithParam(UnaryOpType::SOFTPLUS, {beta, threshold})}, output_mem_config);
}

inline Tensor unary_chain(
const Tensor& input_tensor,
std::vector<UnaryWithParam> ops_chain,
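A hedged usage sketch of the new host-side entry point (assumes `input` is a Tensor on a non-Grayskull device; the beta and threshold values here mirror torch.nn.functional.softplus defaults and are not defaults of this function):

// Dispatches a single eltwise-unary op carrying {beta, threshold};
// the TT_ASSERT above rejects Grayskull devices.
Tensor out = softplus(input, /*beta=*/1.0f, /*threshold=*/20.0f);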
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "ckernel_sfpu_converter.h"
#include "ckernel_sfpu_exp.h"
#include "ckernel_sfpu_log.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE>
inline void calculate_softplus_body(vFloat beta, vFloat beta_reciprocal, vFloat threshold) {
vFloat a = dst_reg[0];
vFloat a_beta = a * beta;
v_if(a_beta < threshold) {
exp_init<APPROXIMATION_MODE, false>();
a = calculate_exponential_body<APPROXIMATION_MODE>(a_beta) + 1.0f;

log_init<APPROXIMATION_MODE>();
dst_reg[0] = a;
calculate_log_body<false>(0);
a = beta_reciprocal * dst_reg[0];
}
v_endif;
dst_reg[0] = a;
}

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_softplus(uint param0, uint param1, uint param2) {
vFloat beta = Converter::to_float(param0);
vFloat beta_reciprocal = Converter::to_float(param1);
vFloat threshold = Converter::to_float(param2);
for (int d = 0; d < ITERATIONS; d++) {
calculate_softplus_body<APPROXIMATION_MODE>(beta, beta_reciprocal, threshold);
dst_reg++;
}
}

template <bool APPROXIMATION_MODE>
void softplus_init() {}

} // namespace sfpu
} // namespace ckernel
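The v_if/v_endif pair predicates per-lane execution: only lanes where a_beta < threshold are rewritten, while the remaining lanes keep the original input. A scalar sketch of what one lane computes, assuming the exp/log SFPU bodies behave like std::exp/std::log:

#include <cmath>

// Per-lane scalar equivalent of calculate_softplus_body.
float softplus_lane(float x, float beta, float beta_reciprocal, float threshold) {
    float a_beta = x * beta;
    if (a_beta < threshold)
        return beta_reciprocal * std::log(std::exp(a_beta) + 1.0f);
    return x;  // lanes masked off by v_if pass the input through unchanged
}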
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_softplus.h"
#include "llk_math_eltwise_unary_sfpu_3_param.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_softplus_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::softplus, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_softplus(
uint dst_index, uint param0, uint param1, uint param2, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_3_param<APPROXIMATE>(
ckernel::sfpu::calculate_softplus<APPROXIMATE>,
ckernel::sfpu::calculate_softplus<APPROXIMATE>,
dst_index,
vector_mode,
param0, param1, param2);
}

} // namespace ckernel
@@ -73,6 +73,7 @@ enum SfpuType {
unary_ne,
unary_gt,
unary_lt,
softplus,
tiled_prod,
unused,
};
@@ -68,6 +68,10 @@
#include "compute_kernel_api/eltwise_unary/binop_with_scalar.h"
#endif

#if SFPU_OP_SOFTPLUS_INCLUDE
#include "compute_kernel_api/eltwise_unary/softplus.h"
#endif

#if SFPU_OP_COMPUTE_KERNEL_API_INCLUDE
#include "compute_kernel_api.h"
#endif
47 changes: 47 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/softplus.h
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once


#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_softplus.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif



namespace ckernel {

/**
* Performs element-wise computation of softplus (`1/beta * log(1 + exp(beta * x))`) on each element
* of a tile in DST register at index tile_index. Any input for which `beta * x` meets or exceeds the
* provided threshold is returned unchanged. The DST register buffer must be in the acquired state via
* an *acquire_dst* call. This call is blocking and is only available on the compute engine.
*
* Return value: None
*
* | Argument | Description | Type | Valid Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | tile_index | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
* | beta | Beta used in softplus calculation | uint32_t | Greater than 0 | True |
* | beta_reciprocal | Reciprocal of beta (1/beta) used in softplus calculation | uint32_t | Greater than 0 | True |
* | threshold | Threshold used in softplus calculation | uint32_t | Greater than 0 | True |
*/
ALWI void softplus_tile(uint32_t idst, uint32_t beta, uint32_t beta_reciprocal, uint32_t threshold) {
MATH(( llk_math_eltwise_unary_sfpu_softplus<APPROX>(idst, beta, beta_reciprocal, threshold) ));
}

/**
* Please refer to documentation for any_init.
*/
ALWI void softplus_tile_init() {
MATH(( llk_math_eltwise_unary_sfpu_softplus_init<APPROX>() ));
}

} // namespace ckernel
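A sketch of how a compute kernel might call this API, following the init-then-compute pattern the docstring describes. The surrounding acquire/copy/pack boilerplate is assumed rather than part of this diff, and the three uint32_t parameters are the IEEE-754 bit patterns of beta, 1/beta, and threshold (see Converter::to_hex on the host side):

// Inside a compute kernel, with one input tile already copied to DST index 0
// and DST acquired. beta_u, beta_recip_u, threshold_u are hex-encoded floats
// delivered as kernel args; the names are illustrative.
softplus_tile_init();
softplus_tile(/*idst=*/0, beta_u, beta_recip_u, threshold_u);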