diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst
index 420d901d55f..f3c4c3ef2df 100644
--- a/docs/source/ttnn/ttnn/api.rst
+++ b/docs/source/ttnn/ttnn/api.rst
@@ -141,7 +141,6 @@ Pointwise Unary
     ttnn.isneginf
     ttnn.isposinf
     ttnn.leaky_relu
-    ttnn.lerp
     ttnn.lgamma
     ttnn.log
     ttnn.log10
@@ -373,6 +372,7 @@ Pointwise Ternary
     ttnn.addcmul
     ttnn.mac
     ttnn.where
+    ttnn.lerp
     ttnn.addcmul_bw
     ttnn.addcdiv_bw
     ttnn.where_bw
diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp
index 6b2215f17d3..03efd382b98 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp
@@ -22,7 +22,7 @@ namespace ternary {
 namespace detail {
 
 template <typename ternary_operation_t>
-void bind_ternary_composite_float(py::module& module, const ternary_operation_t& operation, const std::string& description) {
+void bind_ternary_composite_float(py::module& module, const ternary_operation_t& operation, const std::string& description, const std::string& supported_dtype = "BFLOAT16") {
     auto doc = fmt::format(
         R"doc(
         {2}
@@ -34,21 +34,26 @@ void bind_ternary_composite_float(py::module& module, const ternary_operation_t&
 
         Keyword Args:
-            value (float, optional): Float value. Defaults to `1`.
+            value (float): Scalar multiplier.
             memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
 
         Returns:
             ttnn.Tensor: the output tensor.
 
-        Supported dtypes and layouts:
+        Note:
+            Supported dtypes, layouts, and ranks:
 
-        +----------------------------+---------------------------------+-------------------+
-        |     Dtypes                 |         Layouts                 |     Ranks         |
-        +----------------------------+---------------------------------+-------------------+
-        |    BFLOAT16                |          TILE                   |      2, 3, 4      |
-        +----------------------------+---------------------------------+-------------------+
+            .. list-table::
+               :header-rows: 1
 
-        Note : bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - {3}
+                 - TILE
+                 - 2, 3, 4
+
+            bfloat8_b/bfloat4_b is supported only on TILE_LAYOUT
 
         Example:
             >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
@@ -58,7 +63,8 @@ void bind_ternary_composite_float(py::module& module, const ternary_operation_t&
         )doc",
         operation.base_name(),
         operation.python_fully_qualified_name(),
-        description);
+        description,
+        supported_dtype);
 
     bind_registered_operation(
         module,
@@ -77,7 +83,7 @@ void bind_ternary_composite_float(py::module& module, const ternary_operation_t&
             py::arg("input_tensor_b"),
             py::arg("input_tensor_c"),
             py::kw_only(),
-            py::arg("value") = 1.0f,
+            py::arg("value"),
             py::arg("memory_config") = std::nullopt});
 }
 
@@ -89,7 +95,7 @@ void bind_ternary_where(py::module& module, const ternary_operation_t& operation
         Args:
             input_tensor_a (ttnn.Tensor): the input tensor.
 
-            input_tensor_b (ttnn.Tensor or number): the input tensor.
+            input_tensor_b (ttnn.Tensor or Number): the input tensor.
 
             input_tensor_c (ttnn.Tensor or Number): the input tensor.
 
@@ -99,15 +105,20 @@ void bind_ternary_where(py::module& module, const ternary_operation_t& operation
             queue_id (int, optional): command queue id. Defaults to `0`.
 
-        Supported dtypes and layouts:
+        Note:
+            Supported dtypes, layouts, and ranks:
 
-        +----------------------------+---------------------------------+-------------------+
-        |     Dtypes                 |         Layouts                 |     Ranks         |
-        +----------------------------+---------------------------------+-------------------+
-        |    BFLOAT16                |          TILE                   |      2, 3, 4      |
-        +----------------------------+---------------------------------+-------------------+
+            .. list-table::
+               :header-rows: 1
 
-        Note : bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - BFLOAT16, BFLOAT8_B
+                 - TILE
+                 - 2, 3, 4
+
+            bfloat8_b/bfloat4_b is supported only on TILE_LAYOUT
 
 
         Example:
@@ -267,7 +278,7 @@ void bind_ternary_mac(py::module& module, const ternary_operation_t& operation,
         Args:
             input_tensor_a (ttnn.Tensor): the input tensor.
-            input_tensor_b (ttnn.Tensor or number): the input tensor.
+            input_tensor_b (ttnn.Tensor or Number): the input tensor.
             input_tensor_c (ttnn.Tensor or Number): the input tensor.
 
         Keyword Args:
@@ -276,15 +287,20 @@ void bind_ternary_mac(py::module& module, const ternary_operation_t& operation,
             memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
 
         Returns:
             ttnn.Tensor: the output tensor.
 
-        Supported dtypes and layouts:
+        Note:
+            Supported dtypes, layouts, and ranks:
 
-        +----------------------------+---------------------------------+-------------------+
-        |     Dtypes                 |         Layouts                 |     Ranks         |
-        +----------------------------+---------------------------------+-------------------+
-        |    BFLOAT16                |          TILE                   |      2, 3, 4      |
-        +----------------------------+---------------------------------+-------------------+
+            .. list-table::
+               :header-rows: 1
 
-        Note : bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - BFLOAT16, BFLOAT8_B
+                 - TILE
+                 - 2, 3, 4
+
+            bfloat8_b/bfloat4_b is supported only on TILE_LAYOUT
 
         Example:
             >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
@@ -336,28 +352,29 @@ void py_module(py::module& module) {
     detail::bind_ternary_composite_float(
         module,
         ttnn::addcmul,
-        R"doc(compute Addcmul :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
+        R"doc(Computes Addcmul on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
+        R"doc(BFLOAT16, BFLOAT8_B)doc");
 
     detail::bind_ternary_composite_float(
         module,
         ttnn::addcdiv,
-        R"doc(compute Addcdiv :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
+        R"doc(Computes Addcdiv on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
 
     detail::bind_ternary_where(
         module,
         ttnn::where,
-        R"doc(compute Addcdiv :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
+        R"doc(Computes Where on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
 
     detail::bind_ternary_lerp(
         module,
         ttnn::lerp,
-        R"doc(compute Lerp :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`
+        R"doc(Computes Lerp on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`
 
-        .. math:: \mathrm{{input\_tensor\_a}}_i || \mathrm{{input\_tensor\_b}}_i)doc");
+        .. math:: \mathrm{{output}}_i = \mathrm{{input\_tensor\_a}}_i + \mathrm{{input\_tensor\_c}}_i \cdot (\mathrm{{input\_tensor\_b}}_i - \mathrm{{input\_tensor\_a}}_i))doc");
 
     detail::bind_ternary_mac(
         module,
         ttnn::mac,
-        R"doc(compute Mac :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
+        R"doc(Computes Mac on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
 }
 
 }  // namespace ternary
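
For reference, a minimal sketch of how the rebound ternary ops read from Python after this patch. This is not part of the diff: the device id, the 32x32 tensor shapes, and the `to_dev` helper are illustrative assumptions; only `ttnn.open_device`, `ttnn.from_torch`, `ttnn.to_device`, `ttnn.lerp`, `ttnn.addcmul`, and `ttnn.close_device` are existing ttnn APIs.

    import torch
    import ttnn

    # Assumed setup: a single device at index 0.
    device = ttnn.open_device(device_id=0)

    def to_dev(t):
        # TILE layout, per the dtype/layout tables introduced above.
        return ttnn.to_device(ttnn.from_torch(t, layout=ttnn.TILE_LAYOUT), device)

    a = to_dev(torch.rand(32, 32, dtype=torch.bfloat16))
    b = to_dev(torch.rand(32, 32, dtype=torch.bfloat16))
    c = to_dev(torch.rand(32, 32, dtype=torch.bfloat16))

    # lerp is now listed under Pointwise Ternary: out = a + c * (b - a)
    out_lerp = ttnn.lerp(a, b, c)

    # `value` is now a required keyword argument (the 1.0f default was removed)
    out_addcmul = ttnn.addcmul(a, b, c, value=2.0)

    ttnn.close_device(device)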