Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add forward support for PReLU #14940

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30

random.seed(0)


# Parameters provided to the test vector generator are defined here.
# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values.
Expand All @@ -45,12 +40,6 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
return False, None


def torch_prelu(x, *args, **kwargs):
weight = kwargs.pop("scalar")
result = torch.nn.functional.prelu(x, torch.tensor(weight, dtype=x.dtype))
return result


# This is the run instructions for the test, defined by the developer.
# The run function must take the above-defined parameters as inputs.
# The runner will call this run function with each test vector, and the returned results from this function will be stored.
Expand All @@ -65,14 +54,14 @@ def run(
*,
device,
) -> list:
data_seed = random.randint(0, 20000000)
torch.manual_seed(data_seed)
torch.manual_seed(0)

torch_input_tensor_a = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(input_shape)

torch_output_tensor = torch_prelu(torch_input_tensor_a, scalar=weight)
golden_function = ttnn.get_golden_function(ttnn.prelu)
torch_output_tensor = golden_function(torch_input_tensor_a, weight)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,7 @@ def test_run_eltwise_leaky_relu_op(
)

@pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
@skip_for_grayskull()
def test_run_eltwise_prelu(
self,
input_shapes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def test_scalarB_leaky_relu(device, h, w, scalar):
run_activation_test_leaky_relu(device, h, w, scalar, ttnn.leaky_relu)


@skip_for_grayskull()
@pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range,
data_gen_with_range_int,
data_gen_with_val,
compare_pcc,
compare_equal,
)
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import is_grayskull, skip_for_grayskull


Expand Down Expand Up @@ -959,3 +961,62 @@ def test_binary_lcm_ttnn(input_shapes, device):

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 2, 32, 64, 64])),
(torch.Size([1, 3, 7, 29, 127])),
(torch.Size([1, 3, 2, 32])),
(torch.Size([1, 6, 49, 97])),
(torch.Size([1, 7, 320])),
(torch.Size([1, 49, 321])),
(torch.Size([4, 32])),
(torch.Size([49, 321])),
),
)
def test_binary_prelu_ttnn(input_shapes, device):
in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
channels = input_shapes[1]
in_data2 = torch.rand((channels,), dtype=torch.bfloat16) * 200 - 100

input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor2 = ttnn.from_torch(in_data2, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.prelu(input_tensor1, input_tensor2)
output_tensor = ttnn.to_torch(output_tensor)
golden_function = ttnn.get_golden_function(ttnn.prelu)
golden_tensor = golden_function(in_data1, in_data2)

assert_with_pcc(golden_tensor, output_tensor, 0.999)


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 2, 32, 64, 64])),
(torch.Size([1, 3, 7, 29, 127])),
(torch.Size([1, 3, 2, 32])),
(torch.Size([1, 6, 49, 97])),
(torch.Size([1, 7, 320])),
(torch.Size([1, 49, 321])),
(torch.Size([4, 32])),
(torch.Size([49, 321])),
),
)
@pytest.mark.parametrize(
"scalar",
{-0.25, -2.7, 0.45, 6.4},
)
@skip_for_grayskull()
def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.prelu(input_tensor1, scalar)
output_tensor = ttnn.to_torch(output_tensor)
golden_function = ttnn.get_golden_function(ttnn.prelu)
golden_tensor = golden_function(in_data1, scalar)

assert_with_pcc(golden_tensor, output_tensor, 0.999)
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_prelu.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_fill(const uint value) {

// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat fill_val = c_value.f;

#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++)
{
vFloat a = dst_reg[0];
v_if(a < 0.0f) {
a = a * init;
}
v_endif;
dst_reg[0] = a;
dst_reg++;
}
}
} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_prelu.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_prelu<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,6 @@ enum SfpuType {
ceil,
unused,
cumsum,
fill
fill,
prelu,
};
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@
#include "llk_math_eltwise_unary_sfpu_right_shift.h"
#include "llk_math_eltwise_unary_sfpu_left_shift.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_prelu.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {


template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_prelu(uint value) {
// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat init = c_value.f;

#pragma GCC unroll 8
for (int d = 0; d < ITERATIONS; d++)
{
vFloat a = dst_reg[0];
v_if(a < 0.0f) {
a = a * init;
}
v_endif;
dst_reg[0] = a;
dst_reg++;
}
}
} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_prelu.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_prelu<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ enum SfpuType {
reciprocal,
sqrt,
lrelu,
prelu,
power,
square,
tanh_derivative,
Expand Down
43 changes: 43 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once


#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_prelu.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif



namespace ckernel {

/**
* Performs element-wise prelu operation. The value to be prelued in the tile is provided as const param0. The DST register buffer must be in
* acquired state via *acquire_dst* call. This call is blocking and is only
* available on the compute engine.
*
* Return value: None
*
* | Argument | Description | Type | Valid Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
* | param0 | The value the output is if the input is greater than 0 | uint32_t | | True |
*/
ALWI void prelu_tile(uint32_t idst, uint32_t param0) {
MATH((llk_math_eltwise_unary_sfpu_prelu<APPROX>(idst, param0)));
}

/**
* Please refer to documentation for any_init.
*/
ALWI void prelu_tile_init() { MATH((llk_math_eltwise_unary_sfpu_prelu_init<APPROX>())); }


} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@
#include "compute_kernel_api/eltwise_unary/softplus.h"
#endif

#if SFPU_OP_PRELU_INCLUDE
#include "compute_kernel_api/eltwise_unary/prelu.h"
#endif

#if SFPU_OP_DROPOUT_INCLUDE
#include "compute_kernel_api/eltwise_unary/dropout.h"
#endif
Expand Down
17 changes: 17 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,20 @@ struct ExecuteMinimum

};

struct ExecutePrelu
{
static Tensor invoke(
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<MemoryConfig>& memory_config = std::nullopt);

static Tensor invoke(
const Tensor& input_tensor,
float scalar,
const std::optional<MemoryConfig>& memory_config = std::nullopt);

};

} // namespace binary
} // namespace operations

Expand Down Expand Up @@ -306,5 +320,8 @@ constexpr auto gcd = ttnn::register_operation_with_auto_launch_op<
constexpr auto lcm = ttnn::register_operation_with_auto_launch_op<
"ttnn::lcm",
operations::binary::ExecuteLCM>();
constexpr auto prelu = ttnn::register_operation_with_auto_launch_op<
"ttnn::prelu",
operations::binary::ExecutePrelu>();

} // namespace ttnn
Loading
Loading