Commit 5ba2908

Add scalar support for prelu

mouliraj-mcw committed Nov 12, 2024
1 parent 97a3037 commit 5ba2908
Showing 6 changed files with 51 additions and 19 deletions.
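For orientation, a minimal usage sketch of what this commit enables, passing a plain float as the PReLU weight (the device setup lines are illustrative and not part of the commit):

import torch
import ttnn

# Illustrative setup: open a device and move a bfloat16 tensor onto it.
device = ttnn.open_device(device_id=0)
torch_input = torch.randn(1, 3, 32, 32, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# New in this commit: the weight may be a scalar instead of a tensor.
output = ttnn.prelu(input_tensor, -0.25)

ttnn.close_device(device)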
23 changes: 23 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -981,3 +981,26 @@ def test_binary_prelu_ttnn(input_shapes, device):

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 3, 32, 32])),
        (torch.Size([1, 6, 32, 32])),
        (torch.Size([1, 7, 320, 384])),
        (torch.Size([1, 4, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    "scalar",
    [-0.25, -2.7, 0.45, 6.4],
)
def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
    output_tensor = ttnn.prelu(input_tensor1, scalar)
    golden_function = ttnn.get_golden_function(ttnn.prelu)
    golden_tensor = golden_function(in_data1, scalar)

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass
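Assuming a standard tt-metal checkout with a device attached, the new case can be run on its own with pytest's -k filter:

pytest tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py -k test_binary_prelu_scalar_ttnn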
6 changes: 6 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -249,6 +249,12 @@ struct ExecutePrelu
        const Tensor& input_tensor_a,
        const Tensor& input_tensor_b,
        const std::optional<MemoryConfig>& memory_config = std::nullopt);

    static Tensor invoke(
        const Tensor& input_tensor,
        float scalar,
        const std::optional<MemoryConfig>& memory_config = std::nullopt);
};

} // namespace binary
14 changes: 5 additions & 9 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -590,7 +590,6 @@ void bind_binary_overload_operation(py::module& module, const binary_operation_t
        operation,
        doc,


        //tensor and scalar
        ttnn::pybind_overload_t{
            [](const binary_operation_t& self,
@@ -935,14 +934,6 @@ void py_module(py::module& module) {
R"doc(\mathrm{output\_tensor}_i = \sqrt{(\mathrm{input\_tensor\_a}_i^2 + \mathrm{input\_tensor\_b}_i^2)}
)doc");

    detail::bind_binary_composite(
        module,
        ttnn::prelu,
        R"doc(Perform an eltwise-prelu operation. Formula : a - a.div(b, rounding_mode=trunc) * b .
        PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
        R"doc(\mathrm{{output\_tensor}} = \verb|PReLU|(\mathrm{{input\_tensor\_a,input\_tensor\_b}}))doc",
        R"doc(BFLOAT16, BFLOAT8_B)doc");

    detail::bind_binary_composite(
        module,
        ttnn::xlogy,
@@ -1053,6 +1044,11 @@ void py_module(py::module& module) {
        ttnn::maximum,
        R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

    detail::bind_binary_composite_overload(
        module,
        ttnn::prelu,
        R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc");

    detail::bind_binary_composite(
        module,
        ttnn::scatter,
6 changes: 5 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -264,10 +264,14 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti
    return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result);
}

Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) {
    return ttnn::prelu_sfpu(input, scalar);
}

Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
    const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape();
    auto volume = input_b.get_logical_volume();
    // If volume = 1 Support for a single-value tensor yet to be handled. #14933
    // If volume = 1 Support for a single-value tensor yet to be handled. TODO(#14933)
    TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size");
    Tensor b = ttnn::reshape(input_b, ttnn::SimpleShape{std::array<uint32_t, 4>{1, s_a[1], 1, 1}});
    Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
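The composite above is the standard PReLU: out equals x where x >= 0 and w_c * x where x < 0, with the per-channel weight reshaped to [1, C, 1, 1] for broadcasting. A plain-torch rendering of the same computation, for reference only (not part of the commit):

import torch

def prelu_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # One slope per channel, broadcast as [1, C, 1, 1], mirroring the
    # ttnn::reshape in ExecutePrelu::invoke above.
    assert w.numel() == x.shape[1], "Mismatch of parameter numbers and input channel size"
    b = w.reshape(1, x.shape[1], 1, 1)
    # where(x < 0, x * b, x), mirroring ttnn::where/ltz/multiply.
    return torch.where(x < 0, x * b, x)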
12 changes: 12 additions & 0 deletions ttnn/ttnn/operations/binary.py
@@ -455,4 +455,16 @@ def _golden_function_lcm(input_tensor_a, input_tensor_b, *args, **kwargs):
ttnn.attach_golden_function(ttnn.lcm, golden_function=_golden_function_lcm)


def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
    import torch

    if not torch.is_tensor(input_tensor_b):
        input_tensor_b = torch.tensor(input_tensor_b, dtype=input_tensor_a.dtype)

    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)


ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)
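Since the scalar is promoted to a tensor before calling torch.nn.functional.prelu, the one golden function now backs both overloads; a quick illustration with arbitrary shapes:

import torch

x = torch.randn(1, 3, 32, 32)
_golden_function_prelu(x, torch.tensor([0.1, 0.2, 0.3]))  # per-channel weight tensor
_golden_function_prelu(x, 0.25)  # scalar weight, promoted inside the golden function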


__all__ = []
9 changes: 0 additions & 9 deletions ttnn/ttnn/operations/unary.py
@@ -237,15 +237,6 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs):
ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu)


def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
    import torch

    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)


ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)


def _golden_function_hardtanh(input_tensor_a, min_val=-1.0, max_val=1.0, *args, **kwargs):
    import torch

