Add scalar support for prelu
mouliraj-mcw committed Nov 12, 2024
1 parent 23e02f3 commit dc0e900
Showing 6 changed files with 50 additions and 16 deletions.
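
For orientation, a minimal usage sketch of what this commit enables — passing a plain float as the PReLU weight instead of a per-channel weight tensor. The shape, slope value, and device setup below are illustrative assumptions, not taken from the diff:

    import torch
    import ttnn

    # Assumed standard ttnn device setup; not part of this commit.
    device = ttnn.open_device(device_id=0)

    torch_input = torch.rand((1, 3, 32, 32), dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

    # With this commit, the second argument may be a scalar negative slope.
    output = ttnn.prelu(input_tensor, -0.25)
    result = ttnn.to_torch(output)

    ttnn.close_device(device)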
23 changes: 23 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -981,3 +981,26 @@ def test_binary_prelu_ttnn(input_shapes, device):

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 3, 32, 32])),
        (torch.Size([1, 6, 32, 32])),
        (torch.Size([1, 7, 320, 384])),
        (torch.Size([1, 4, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    "scalar",
    {-0.25, -2.7, 0.45, 6.4},
)
def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
    output_tensor = ttnn.prelu(input_tensor1, scalar)
    golden_function = ttnn.get_golden_function(ttnn.prelu)
    golden_tensor = golden_function(in_data1, scalar)

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass
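
As a reference for what the new test compares against: with a scalar slope, PReLU reduces elementwise to f(x) = x for x >= 0 and slope * x otherwise. A torch-only sketch of that identity (values illustrative), mirroring how the golden function added in binary.py below wraps the scalar before calling torch.nn.functional.prelu:

    import torch

    x = torch.randn(1, 3, 32, 32)
    slope = -0.25

    # Elementwise definition of PReLU with a single scalar slope.
    expected = torch.where(x >= 0, x, slope * x)

    # torch.nn.functional.prelu expects the weight as a tensor, hence the wrapping.
    via_torch = torch.nn.functional.prelu(x, weight=torch.tensor(slope))

    assert torch.allclose(expected, via_torch)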
6 changes: 6 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -249,6 +249,12 @@ struct ExecutePrelu
        const Tensor& input_tensor_a,
        const Tensor& input_tensor_b,
        const std::optional<MemoryConfig>& memory_config = std::nullopt);

    static Tensor invoke(
        const Tensor& input_tensor,
        float scalar,
        const std::optional<MemoryConfig>& memory_config = std::nullopt);

};

} // namespace binary
13 changes: 6 additions & 7 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -935,13 +935,6 @@ void py_module(py::module& module) {
        R"doc(\mathrm{output\_tensor}_i = \sqrt{(\mathrm{input\_tensor\_a}_i^2 + \mathrm{input\_tensor\_b}_i^2)}
        )doc");

    detail::bind_binary_composite(
        module,
        ttnn::prelu,
        R"doc(Perform an eltwise-prelu operation. Formula : a - a.div(b, rounding_mode=trunc) * b .
        PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
        R"doc(\mathrm{{output\_tensor}} = \verb|PReLU|(\mathrm{{input\_tensor\_a,input\_tensor\_b}}))doc",
        R"doc(BFLOAT16, BFLOAT8_B)doc");

    detail::bind_binary_composite(
        module,
@@ -1053,6 +1046,12 @@ void py_module(py::module& module) {
        ttnn::maximum,
        R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

    detail::bind_binary_composite_overload(
        module,
        ttnn::prelu,
        R"doc(Perform an eltwise-prelu operation.
        PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc");

    detail::bind_binary_composite(
        module,
        ttnn::scatter,
@@ -263,6 +263,9 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti
    Tensor div_result = ttnn::div(input_a, input_b, false, "None", output_mem_config);
    return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result);
}
Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) {
    return ttnn::prelu_sfpu(input, scalar);
}

Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
    const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape();
12 changes: 12 additions & 0 deletions ttnn/ttnn/operations/binary.py
@@ -455,4 +455,16 @@ def _golden_function_lcm(input_tensor_a, input_tensor_b, *args, **kwargs):
ttnn.attach_golden_function(ttnn.lcm, golden_function=_golden_function_lcm)


def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
    import torch

    if not torch.is_tensor(input_tensor_b):
        input_tensor_b = torch.tensor(input_tensor_b, dtype=input_tensor_a.dtype)

    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)


ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)


__all__ = []
9 changes: 0 additions & 9 deletions ttnn/ttnn/operations/unary.py
@@ -237,15 +237,6 @@ def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs):
ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu)


def _golden_function_prelu(input_tensor_a, input_tensor_b, *args, **kwargs):
    import torch

    return torch.nn.functional.prelu(input_tensor_a, weight=input_tensor_b)


ttnn.attach_golden_function(ttnn.prelu, golden_function=_golden_function_prelu)


def _golden_function_hardtanh(input_tensor_a, min_val=-1.0, max_val=1.0, *args, **kwargs):
    import torch

