Update the logic
mouliraj-mcw committed Nov 14, 2024
1 parent ddc17bd commit 8d1bc93
Showing 3 changed files with 59 additions and 22 deletions.
43 changes: 29 additions & 14 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -9,9 +9,11 @@
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
    data_gen_with_range,
    data_gen_with_range_int,
+   data_gen_with_val,
    compare_pcc,
    compare_equal,
)
+from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import is_grayskull, skip_for_grayskull


@@ -964,32 +966,43 @@ def test_binary_lcm_ttnn(input_shapes, device):
@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 3, 32, 32])),
        (torch.Size([1, 6, 32, 32])),
        (torch.Size([1, 7, 320, 384])),
        (torch.Size([1, 4, 320, 384])),
        (torch.Size([1, 2, 32, 64, 64])),
        (torch.Size([1, 3, 7, 29, 127])),
        (torch.Size([1, 3, 2, 32])),
        (torch.Size([1, 6, 49, 97])),
        (torch.Size([1, 7, 320])),
        (torch.Size([1, 49, 321])),
        (torch.Size([4, 32])),
        (torch.Size([49, 321])),
    ),
)
def test_binary_prelu_ttnn(input_shapes, device):
-   in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
+   in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
    channels = input_shapes[1]
    in_data2 = torch.rand((channels,), dtype=torch.bfloat16) * 200 - 100

+   input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor2 = ttnn.from_torch(in_data2, layout=ttnn.TILE_LAYOUT, device=device)

    output_tensor = ttnn.prelu(input_tensor1, input_tensor2)
    output_tensor = ttnn.to_torch(output_tensor)
    golden_function = ttnn.get_golden_function(ttnn.prelu)
    golden_tensor = golden_function(in_data1, in_data2)

-   comp_pass = compare_pcc([output_tensor], [golden_tensor])
-   assert comp_pass
+   assert_with_pcc(golden_tensor, output_tensor, 0.999)
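For reference, the golden comparison above boils down to per-channel PReLU semantics. A minimal plain-PyTorch sketch (illustrative shapes; not the ttnn golden function itself):

```python
import torch

# Per-channel PReLU: negative inputs are scaled by their channel's slope.
# Illustrative shapes only; dim 1 is the channel dimension, as in the test.
x = torch.rand(1, 3, 32, 32) * 200 - 100
weight = torch.rand(3) * 200 - 100

# Broadcast the (C,) slopes across dim 1: (C,) -> (1, C, 1, 1).
slopes = weight.view(1, -1, 1, 1)
expected = torch.where(x < 0, x * slopes, x)

# torch.nn.functional.prelu computes the same thing for a (C,) weight.
assert torch.equal(expected, torch.nn.functional.prelu(x, weight))
```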


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 3, 32, 32])),
        (torch.Size([1, 6, 32, 32])),
        (torch.Size([1, 7, 320, 384])),
        (torch.Size([1, 4, 320, 384])),
        (torch.Size([1, 2, 32, 64, 64])),
        (torch.Size([1, 3, 7, 29, 127])),
        (torch.Size([1, 3, 2, 32])),
        (torch.Size([1, 6, 49, 97])),
        (torch.Size([1, 7, 320])),
        (torch.Size([1, 49, 321])),
        (torch.Size([4, 32])),
        (torch.Size([49, 321])),
    ),
)
@pytest.mark.parametrize(
@@ -998,10 +1011,12 @@ def test_binary_prelu_ttnn(input_shapes, device):
)
@skip_for_grayskull()
def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
-   in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
+   in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
+   input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)

    output_tensor = ttnn.prelu(input_tensor1, scalar)
    output_tensor = ttnn.to_torch(output_tensor)
    golden_function = ttnn.get_golden_function(ttnn.prelu)
    golden_tensor = golden_function(in_data1, scalar)

-   comp_pass = compare_pcc([output_tensor], [golden_tensor])
-   assert comp_pass
+   assert_with_pcc(golden_tensor, output_tensor, 0.999)
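With a scalar slope, the same identity holds against leaky_relu, which gives a quick sanity check of the scalar overload. A sketch with arbitrary slope values (the test's actual scalar parametrization is collapsed above):

```python
import torch
import torch.nn.functional as F

# With a single scalar slope, PReLU is exactly leaky_relu with that
# negative_slope. Slope values here are arbitrary examples.
x = torch.rand(1, 3, 32, 32) * 200 - 100
for scalar in (-2.5, 0.5, 3.0):
    expected = torch.where(x < 0, x * scalar, x)
    assert torch.equal(expected, F.leaky_relu(x, negative_slope=scalar))
```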
23 changes: 20 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -358,7 +358,7 @@ void bind_binary_composite_with_rtol_atol(py::module& module, const binary_operation_t
}

template <typename binary_operation_t>
-void bind_binary_composite_overload(py::module& module, const binary_operation_t& operation, const std::string& description) {
+void bind_binary_composite_overload(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& supported_dtype = "BFLOAT16", const std::string& supported_rank = "2, 3, 4") {
    auto doc = fmt::format(
        R"doc(
        {2}
@@ -373,6 +373,19 @@ void bind_binary_composite_overload(py::module& module, const binary_operation_t
        Keyword Args:
            memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.

+       Note:
+           Supported dtypes, layouts, and ranks:
+
+           .. list-table::
+              :header-rows: 1
+
+              * - Dtypes
+                - Layouts
+                - Ranks
+              * - {3}
+                - TILE
+                - {4}
+
        Returns:
            ttnn.Tensor: the output tensor.
@@ -384,7 +397,9 @@ void bind_binary_composite_overload(py::module& module, const binary_operation_t
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
description);
description,
supported_dtype,
supported_rank);
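The positional indices in the doc template ({0} through {4}) are filled by these arguments in order. A Python analog of the substitution, just to show the mapping (the template string here is a simplified stand-in for the real one):

```python
# Python analog of the fmt::format call: {2} is the description,
# {3} the supported dtypes, {4} the supported ranks.
template = "{2}\n\nSupported dtypes: {3}\nSupported ranks: {4}\n"
doc = template.format(
    "prelu",                                # {0} operation.base_name()
    "ttnn.prelu",                           # {1} python_fully_qualified_name()
    "Perform an eltwise-prelu operation.",  # {2} description
    "BFLOAT16, BFLOAT8_B",                  # {3} supported_dtype
    "2, 3, 4, 5",                           # {4} supported_rank
)
print(doc)
```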

    bind_registered_operation(
        module,
@@ -1047,7 +1062,9 @@ void py_module(py::module& module) {
    detail::bind_binary_composite_overload(
        module,
        ttnn::prelu,
-       R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc");
+       R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
+       R"doc(BFLOAT16, BFLOAT8_B)doc",
+       R"doc(2, 3, 4, 5)doc");

    detail::bind_binary_composite(
        module,
@@ -269,11 +269,16 @@ Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) {
}

Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
-   const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape();
-   auto volume = input_b.get_logical_volume();
-   // If volume = 1 Support for a single-value tensor yet to be handled. TODO(#14933)
-   TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size");
-   Tensor b = ttnn::reshape(input_b, ttnn::SimpleShape{std::array<uint32_t, 4>{1, s_a[1], 1, 1}});
+   const auto s_a = input_a.get_shape();
+   const auto volume = input_b.get_logical_volume();
+
+   TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size. Found parameter numbers = {} and channel size = {}.", volume, s_a[1]);
+   Tensor b = input_b;
+   if (s_a.rank() > 2) {
+       SmallVector<uint32_t> reshape(s_a.rank(), 1);
+       reshape[1] = s_a[1];
+       b = ttnn::reshape(input_b, ttnn::Shape(reshape));
+   }
    Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
    return result;
}
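The new branch generalizes the old fixed rank-4 reshape to any rank above 2. A plain-PyTorch sketch of the equivalent logic (a reference only, not the device implementation):

```python
import torch

def prelu_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Mirrors ExecutePrelu::invoke: one slope per channel along dim 1.
    assert x.shape[1] == w.numel(), "Mismatch of parameter numbers and channel size"
    if x.dim() > 2:
        # (C,) -> (1, C, 1, ..., 1) so the slopes broadcast over every
        # remaining dimension, like the SmallVector reshape above.
        shape = [1] * x.dim()
        shape[1] = x.shape[1]
        w = w.reshape(shape)
    return torch.where(x < 0, x * w, x)

# Rank-5 input, 3 channels -> slopes broadcast over dims 0, 2, 3, 4.
out = prelu_reference(torch.randn(1, 3, 7, 29, 127), torch.randn(3))
print(out.shape)  # torch.Size([1, 3, 7, 29, 127])
```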