Add forward support for prelu
mouliraj-mcw committed Nov 12, 2024
1 parent becbf96 commit 23e02f3
Showing 18 changed files with 235 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -141,6 +141,7 @@ Pointwise Unary
ttnn.isneginf
ttnn.isposinf
ttnn.leaky_relu
ttnn.prelu
ttnn.lerp
ttnn.lgamma
ttnn.log
22 changes: 22 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -959,3 +959,25 @@ def test_binary_lcm_ttnn(input_shapes, device):

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 3, 32, 32])),
(torch.Size([1, 6, 32, 32])),
(torch.Size([1, 7, 320, 384])),
(torch.Size([1, 4, 320, 384])),
),
)
def test_binary_prelu_ttnn(input_shapes, device):
in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
channels = input_shapes[1]
in_data2 = torch.rand((channels,), dtype=torch.bfloat16) * 200 - 100
input_tensor2 = ttnn.from_torch(in_data2, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.prelu(input_tensor1, input_tensor2)
golden_function = ttnn.get_golden_function(ttnn.prelu)
golden_tensor = golden_function(in_data1, in_data2)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass
@@ -36,3 +36,4 @@
#include "llk_math_eltwise_unary_sfpu_right_shift.h"
#include "llk_math_eltwise_unary_sfpu_left_shift.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_prelu.h"
@@ -0,0 +1,38 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {


template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_prelu(uint value) {
    // SFPU microcode: reinterpret the raw bits of `value` as the PReLU slope.
    Converter c_value;
    c_value.u = value;
    vFloat init = c_value.f;

    for (int d = 0; d < ITERATIONS; d++) {
        vFloat a = dst_reg[0];
        // Scale negative elements by the slope; non-negative elements pass through unchanged.
        v_if(a < 0.0f) {
            a = a * init;
        }
        v_endif;
        dst_reg[0] = a;
        dst_reg++;
    }
}
} // namespace sfpu
} // namespace ckernel
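
For reference, a minimal Python sketch (a hypothetical host-side model, not device code) of the per-element math the SFPU loop above applies to each value in the destination register:

# Hypothetical host-side model of the SFPU PReLU loop above (not device code).
def prelu_scalar(x: float, slope: float) -> float:
    # Negative inputs are scaled by the slope; non-negative inputs pass through.
    return x * slope if x < 0.0 else x

assert prelu_scalar(-2.0, 0.25) == -0.5
assert prelu_scalar(3.0, 0.25) == 3.0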
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_prelu.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_prelu<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
@@ -14,6 +14,7 @@ enum SfpuType {
reciprocal,
sqrt,
lrelu,
prelu,
power,
square,
tanh_derivative,
43 changes: 43 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once


#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_prelu.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif



namespace ckernel {

/**
* Performs an element-wise PReLU operation on the tile. The PReLU slope is provided as const param0
* (the slope value encoded as its raw IEEE-754 bit pattern). The DST register buffer must be in
* acquired state via *acquire_dst* call. This call is blocking and is only
* available on the compute engine.
*
* Return value: None
*
* | Argument        | Description                                                                 | Type     | Valid Range                                            | Required |
* |-----------------|-----------------------------------------------------------------------------|----------|--------------------------------------------------------|----------|
* | idst            | The index of the tile in DST register buffer to perform the computation on  | uint32_t | Must be less than the size of the DST register buffer | True     |
* | param0          | The slope multiplied with input values that are less than 0                  | uint32_t |                                                        | True     |
*/
ALWI void prelu_tile(uint32_t idst, uint32_t param0) {
MATH((llk_math_eltwise_unary_sfpu_prelu<APPROX>(idst, param0)));
}

/**
* Please refer to documentation for any_init.
*/
ALWI void prelu_tile_init() { MATH((llk_math_eltwise_unary_sfpu_prelu_init<APPROX>())); }


} // namespace ckernel
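
The param0 argument is a uint32_t: the generated kernel code passes the float slope as its raw IEEE-754 bit pattern (via Converter::to_hex in unary_op_utils.cpp, which calculate_prelu reinterprets back into a float). A hedged Python sketch of that mapping:

# Hedged sketch: how a float slope maps to the uint32_t bit pattern carried by param0.
import struct

def float_bits(slope: float) -> int:
    # Reinterpret the IEEE-754 single-precision encoding of `slope` as an unsigned 32-bit int.
    return struct.unpack("<I", struct.pack("<f", slope))[0]

print(hex(float_bits(0.25)))  # 0x3e800000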
@@ -116,6 +116,10 @@
#include "compute_kernel_api/eltwise_unary/softplus.h"
#endif

#if SFPU_OP_PRELU_INCLUDE
#include "compute_kernel_api/eltwise_unary/prelu.h"
#endif

#if SFPU_OP_DROPOUT_INCLUDE
#include "compute_kernel_api/eltwise_unary/dropout.h"
#endif
11 changes: 11 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -243,6 +243,14 @@ struct ExecuteMinimum

};

struct ExecutePrelu
{
static Tensor invoke(
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<MemoryConfig>& memory_config = std::nullopt);
};

} // namespace binary
} // namespace operations

@@ -306,5 +314,8 @@ constexpr auto gcd = ttnn::register_operation_with_auto_launch_op<
constexpr auto lcm = ttnn::register_operation_with_auto_launch_op<
"ttnn::lcm",
operations::binary::ExecuteLCM>();
constexpr auto prelu = ttnn::register_operation_with_auto_launch_op<
"ttnn::prelu",
operations::binary::ExecutePrelu>();

} // namespace ttnn
11 changes: 10 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -186,7 +186,7 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati
}

template <typename binary_operation_t>
void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& math, const std::string& supported_dtype = "BFLAOT16", const std::string& note="") {
void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& math, const std::string& supported_dtype = "BFLOAT16", const std::string& note="") {
auto doc = fmt::format(
R"doc(
{2}
@@ -590,6 +590,7 @@ void bind_binary_overload_operation(py::module& module, const binary_operation_t
operation,
doc,


//tensor and scalar
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
@@ -934,6 +935,14 @@ void py_module(py::module& module) {
R"doc(\mathrm{output\_tensor}_i = \sqrt{(\mathrm{input\_tensor\_a}_i^2 + \mathrm{input\_tensor\_b}_i^2)}
)doc");

detail::bind_binary_composite(
module,
ttnn::prelu,
R"doc(Perform an eltwise-prelu operation. Formula : a - a.div(b, rounding_mode=trunc) * b .
PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
R"doc(\mathrm{{output\_tensor}} = \verb|PReLU|(\mathrm{{input\_tensor\_a,input\_tensor\_b}}))doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_binary_composite(
module,
ttnn::xlogy,
@@ -264,6 +264,15 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti
return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result);
}

Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape();
auto volume = input_b.get_logical_volume();
// TODO(#14933): support for a single-value (volume == 1) weight tensor is not yet handled.
TT_FATAL(s_a[1] == volume, "Mismatch between the number of PReLU parameters and the input channel size");
Tensor b = ttnn::reshape(input_b, ttnn::SimpleShape{std::array<uint32_t, 4>{1, s_a[1], 1, 1}});
Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
return result;
}
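
For clarity, a hedged host-side PyTorch sketch of the decomposition used above (reshape the per-channel weight, then select between the scaled and original input); this mirrors the composite, not the device kernels:

# Hedged reference for the composite decomposition above (host-side PyTorch, not device code).
import torch

def prelu_reference(input_a: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    channels = input_a.shape[1]
    assert weight.numel() == channels, "Mismatch between the number of parameters and the channel size"
    b = weight.reshape(1, channels, 1, 1)          # broadcastable per-channel slope
    return torch.where(input_a < 0, input_a * b, input_a)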
// Binary remainder will be overloaded by unary remainder in another PR
Tensor ExecuteBinaryRemainder::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
auto arch = input_a.device()->arch();
@@ -84,7 +84,8 @@ enum class UnaryOpType {
REMAINDER,
FMOD,
DROPOUT,
FILL
FILL,
PRELU_SFPU,
};

struct UnaryWithParam {
10 changes: 10 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
@@ -62,6 +62,7 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string
case UnaryOpType::TAN: defines["SFPU_OP_TRIG_FAMILY_INCLUDE"] = "1"; break;
case UnaryOpType::NEG: defines["SFPU_OP_NEG_INCLUDE"] = "1"; break;
case UnaryOpType::SOFTPLUS: defines["SFPU_OP_SOFTPLUS_INCLUDE"] = "1"; break;
case UnaryOpType::PRELU_SFPU: defines["SFPU_OP_PRELU_INCLUDE"] = "1"; break;
case UnaryOpType::TYPECAST: defines["SFPU_OP_TYPECAST_INCLUDE"] = "1"; break;
case UnaryOpType::BITWISE_XOR: defines["SFPU_OP_BITWISE_XOR_INCLUDE"] = "1"; break;
case UnaryOpType::BITWISE_NOT: defines["SFPU_OP_BITWISE_NOT_INCLUDE"] = "1"; break;
@@ -223,6 +224,15 @@ std::pair<std::string, std::string> get_op_init_and_func_parameterized(
Converter::to_hex(param1))};
break;
}
case UnaryOpType::PRELU_SFPU: {
op_init_and_name = {
"prelu_tile_init();",
fmt::format(
"prelu_tile({}, {}u);",
idst,
Converter::to_hex(param0))};
break;
}
case UnaryOpType::TYPECAST:
TT_ASSERT(params.size() == 2, "Expected eltwise_typecast to take 2 parameters");
op_init_and_name = {
@@ -60,6 +60,7 @@ bool is_parametrized_type(T val) {
case UnaryOpType::REMAINDER:
case UnaryOpType::DROPOUT:
case UnaryOpType::FILL:
case UnaryOpType::PRELU_SFPU:
case UnaryOpType::FMOD: return true;
default: return false;
}
27 changes: 27 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
@@ -280,6 +280,33 @@ Tensor Softplus::invoke(
optional_output_tensor);
}

Tensor Prelu::invoke(
uint8_t queue_id,
const Tensor& input,
float value,
const std::optional<MemoryConfig>& memory_config,
const std::optional<Tensor>& optional_output_tensor) {
return detail::unary_impl(
queue_id,
input,
{UnaryWithParam{UnaryOpType::PRELU_SFPU, value}},
memory_config,
optional_output_tensor);
}

Tensor Prelu::invoke(
const Tensor& input,
float value,
const std::optional<MemoryConfig>& memory_config,
const std::optional<Tensor>& optional_output_tensor) {
return detail::unary_impl(
DefaultQueueId,
input,
{UnaryWithParam{UnaryOpType::PRELU_SFPU, value}},
memory_config,
optional_output_tensor);
}
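
Assuming the unary variant is exposed to Python under the ttnn::prelu_sfpu name registered later in this diff, a hedged usage sketch with a single scalar slope:

# Hedged sketch: scalar-slope PReLU via the unary SFPU path (Python exposure assumed).
# `input_tensor` is assumed to be a tiled bfloat16 tensor already on the device.
output_tensor = ttnn.prelu_sfpu(input_tensor, 0.25)  # negative elements are scaled by 0.25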

Tensor Identity::invoke(
uint8_t queue_id,
const Tensor& input_tensor,
17 changes: 17 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
@@ -122,6 +122,21 @@ struct Softplus {
const std::optional<Tensor>& optional_output_tensor = std::nullopt);
};

struct Prelu {
static Tensor invoke(
uint8_t queue_id,
const Tensor& input,
float value,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static Tensor invoke(
const Tensor& input,
float value,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);
};

struct Identity {
static Tensor invoke(
uint8_t queue_id,
@@ -340,6 +355,8 @@ REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_xor, BITWISE_XOR, int32_
constexpr auto dropout = ttnn::register_operation_with_auto_launch_op<"ttnn::dropout", ttnn::operations::unary::Dropout>();
constexpr auto identity = ttnn::register_operation_with_auto_launch_op<"ttnn::identity", ttnn::operations::unary::Identity>();
constexpr auto softplus = ttnn::register_operation_with_auto_launch_op<"ttnn::softplus", ttnn::operations::unary::Softplus>();
constexpr auto prelu_sfpu = ttnn::register_operation_with_auto_launch_op<"ttnn::prelu_sfpu", ttnn::operations::unary::Prelu>();

constexpr auto sigmoid_accurate =
ttnn::register_operation_with_auto_launch_op<"ttnn::sigmoid_accurate", ttnn::operations::unary::Sigmoid_accurate>();
constexpr auto unary_chain = ttnn::register_operation_with_auto_launch_op<"ttnn::unary_chain", ttnn::operations::unary::Unary_chain>();
4 changes: 0 additions & 4 deletions ttnn/ttnn/__init__.py
@@ -243,10 +243,6 @@ def auto_register_ttnn_cpp_operations(module):
mul_ = ttnn.multiply_


def prelu(*args, **kwargs): # Alias for leaky_relu. TODO(#8544): implement PReLU properly
return ttnn.leaky_relu(*args, **kwargs)


# TODO: pybind the overloaded operators below
ttnn.Tensor.__add__ = lambda self, *args, **kwargs: ttnn.add(self, *args, **kwargs)
ttnn.Tensor.__radd__ = lambda self, *args, **kwargs: ttnn.add(self, *args, **kwargs)
11 changes: 9 additions & 2 deletions ttnn/ttnn/operations/unary.py
@@ -66,7 +66,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
"gelu": torch.nn.functional.gelu,
"rsqrt": torch.rsqrt,
# Unaries with float parameter
# "prelu": torch_prelu, # Alias for leaky_relu. TODO(#8544): implement PReLU properly
# Other unaries (composite operations)
"softplus": torch.nn.functional.softplus,
"sigmoid_accurate": torch.sigmoid,
@@ -148,7 +147,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
ttnn.gelu,
ttnn.rsqrt,
# Unaries with float parameter
# ttnn.prelu, # Alias for leaky_relu. TODO(#8544): implement PReLU properly
# Unaries using op_chain
ttnn.log_sigmoid,
ttnn.softplus,
@@ -239,6 +237,15 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs):
ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu)


def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
import torch

return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)


ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)
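
Note that torch.nn.functional.prelu, which the golden function delegates to, expects the weight to be a 1-D tensor with either one element or one element per input channel, matching the channel-count check in ExecutePrelu::invoke. A hedged example:

# Hedged example of the golden path used for comparison in the unit tests.
import torch

x = torch.tensor([[[[-1.0, 2.0]]]])    # shape (1, 1, 1, 2): one channel
w = torch.tensor([0.5])                # one slope for the single channel
torch.nn.functional.prelu(x, w)        # -> [[[[-0.5, 2.0]]]]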


def _golden_function_hardtanh(input_tensor_a, min_val=-1.0, max_val=1.0, *args, **kwargs):
import torch
