diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
index 6f5ff325837..497c2d194a6 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -9,9 +9,11 @@
 from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
     data_gen_with_range,
     data_gen_with_range_int,
+    data_gen_with_val,
     compare_pcc,
     compare_equal,
 )
+from tests.ttnn.utils_for_testing import assert_with_pcc
 from models.utility_functions import is_grayskull, skip_for_grayskull
@@ -964,32 +966,43 @@ def test_binary_lcm_ttnn(input_shapes, device):
 @pytest.mark.parametrize(
     "input_shapes",
     (
-        (torch.Size([1, 3, 32, 32])),
-        (torch.Size([1, 6, 32, 32])),
-        (torch.Size([1, 7, 320, 384])),
-        (torch.Size([1, 4, 320, 384])),
+        (torch.Size([1, 2, 32, 64, 64])),
+        (torch.Size([1, 3, 7, 29, 127])),
+        (torch.Size([1, 3, 2, 32])),
+        (torch.Size([1, 6, 49, 97])),
+        (torch.Size([1, 7, 320])),
+        (torch.Size([1, 49, 321])),
+        (torch.Size([4, 32])),
+        (torch.Size([49, 321])),
     ),
 )
 def test_binary_prelu_ttnn(input_shapes, device):
-    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
+    in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
     channels = input_shapes[1]
     in_data2 = torch.rand((channels,), dtype=torch.bfloat16) * 200 - 100
+
+    input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)
     input_tensor2 = ttnn.from_torch(in_data2, layout=ttnn.TILE_LAYOUT, device=device)
+
     output_tensor = ttnn.prelu(input_tensor1, input_tensor2)
+    output_tensor = ttnn.to_torch(output_tensor)
     golden_function = ttnn.get_golden_function(ttnn.prelu)
     golden_tensor = golden_function(in_data1, in_data2)
-    comp_pass = compare_pcc([output_tensor], [golden_tensor])
-    assert comp_pass
+    assert_with_pcc(golden_tensor, output_tensor, 0.999)
 
 
 @pytest.mark.parametrize(
     "input_shapes",
     (
-        (torch.Size([1, 3, 32, 32])),
-        (torch.Size([1, 6, 32, 32])),
-        (torch.Size([1, 7, 320, 384])),
-        (torch.Size([1, 4, 320, 384])),
+        (torch.Size([1, 2, 32, 64, 64])),
+        (torch.Size([1, 3, 7, 29, 127])),
+        (torch.Size([1, 3, 2, 32])),
+        (torch.Size([1, 6, 49, 97])),
+        (torch.Size([1, 7, 320])),
+        (torch.Size([1, 49, 321])),
+        (torch.Size([4, 32])),
+        (torch.Size([49, 321])),
     ),
 )
 @pytest.mark.parametrize(
@@ -998,10 +1011,12 @@ def test_binary_prelu_ttnn(input_shapes, device):
 )
 @skip_for_grayskull()
 def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
-    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
+    in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
+    input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)
+
     output_tensor = ttnn.prelu(input_tensor1, scalar)
+    output_tensor = ttnn.to_torch(output_tensor)
     golden_function = ttnn.get_golden_function(ttnn.prelu)
     golden_tensor = golden_function(in_data1, scalar)
-    comp_pass = compare_pcc([output_tensor], [golden_tensor])
-    assert comp_pass
+    assert_with_pcc(golden_tensor, output_tensor, 0.999)
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
index 8a42c230115..9ca6a02f9a3 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -358,7 +358,7 @@ void bind_binary_composite_with_rtol_atol(py::module& module, const binary_opera
 }
 
 template <typename binary_operation_t>
-void bind_binary_composite_overload(py::module& module, const binary_operation_t& operation, const std::string& description) {
+void bind_binary_composite_overload(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& supported_dtype="BFLOAT16", const std::string& supported_rank= "2, 3, 4") {
     auto doc = fmt::format(
         R"doc(
         {2}
@@ -373,6 +373,19 @@ void bind_binary_composite_overload(py::module& module, const binary_operation_t
         Keyword Args:
             memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
 
+        Note:
+            Supported dtypes, layouts, and ranks:
+
+            .. list-table::
+               :header-rows: 1
+
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - {3}
+                 - TILE
+                 - {4}
+
         Returns:
             ttnn.Tensor: the output tensor.
 
@@ -384,7 +397,9 @@ void bind_binary_composite_overload(py::module& module, const binary_operation_t
         )doc",
         operation.base_name(),
         operation.python_fully_qualified_name(),
-        description);
+        description,
+        supported_dtype,
+        supported_rank);
 
     bind_registered_operation(
         module,
@@ -1047,7 +1062,9 @@ void py_module(py::module& module) {
     detail::bind_binary_composite_overload(
         module,
         ttnn::prelu,
-        R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc");
+        R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
+        R"doc(BFLOAT16, BFLOAT8_B)doc",
+        R"doc(2, 3, 4, 5)doc");
 
     detail::bind_binary_composite(
         module,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
index 5933cc63db9..e46b86fa072 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -269,11 +269,16 @@ Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::option
 }
 
 Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
-    const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape();
-    auto volume = input_b.get_logical_volume();
-    // If volume = 1 Support for a single-value tensor yet to be handled. TODO(#14933)
-    TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size");
-    Tensor b = ttnn::reshape(input_b, ttnn::SimpleShape{std::array<uint32_t, 4>{1, s_a[1], 1, 1}});
+    const auto s_a = input_a.get_shape();
+    const auto volume = input_b.get_logical_volume();
+
+    TT_FATAL(s_a[1] == volume, "Mismatch of parameter numbers and input channel size. Found parameter numbers = {} and channel size = {}.", volume, s_a[1]);
+    Tensor b = input_b;
+    if(s_a.rank()>2){
+        SmallVector<uint32_t> reshape(s_a.rank(), 1);
+        reshape[1] = s_a[1];
+        b = ttnn::reshape(input_b, ttnn::Shape(reshape));
+    }
     Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
     return result;
 }
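
For context, the channel-broadcast rule that the reworked ExecutePrelu::invoke implements can be sketched in plain PyTorch. This is a minimal illustration only, not part of the patch: prelu_reference is a hypothetical helper name, and the shapes are taken from the new test parametrization.

import torch

def prelu_reference(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # One parameter per channel (dim 1 of x), matching the TT_FATAL check:
    # input_b's logical volume must equal input_a's channel size.
    assert weight.numel() == x.shape[1], "Mismatch of parameter numbers and input channel size"
    # Mirror the C++ change: for rank > 2, reshape the parameters to
    # [1, C, 1, ..., 1] so they broadcast along the channel dimension.
    # For rank <= 2, a (C,)-shaped weight already broadcasts correctly.
    if x.dim() > 2:
        shape = [1] * x.dim()
        shape[1] = x.shape[1]
        weight = weight.reshape(shape)
    # Same formula as ttnn::where(ttnn::ltz(a), a * b, a).
    return torch.where(x < 0, x * weight, x)

# Rank-5 case from the new tests: channels = shape[1] = 3.
x = torch.rand(1, 3, 7, 29, 127, dtype=torch.bfloat16) * 200 - 100
w = torch.rand(3, dtype=torch.bfloat16) * 200 - 100
out = prelu_reference(x, w)

The rank > 2 guard is what lets the same invoke path serve the new rank-2 test shapes ([4, 32], [49, 321]) without a reshape, while rank-3 through rank-5 inputs get the [1, C, 1, ..., 1] parameter view.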