diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst index f3c4c3ef2df..c1364361ff5 100644 --- a/docs/source/ttnn/ttnn/api.rst +++ b/docs/source/ttnn/ttnn/api.rst @@ -358,7 +358,6 @@ Pointwise Binary ttnn.rsub_bw ttnn.min_bw ttnn.max_bw - ttnn.lerp_bw Pointwise Ternary ================= @@ -376,6 +375,7 @@ Pointwise Ternary ttnn.addcmul_bw ttnn.addcdiv_bw ttnn.where_bw + ttnn.lerp_bw Losses ====== diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp index f17839c8461..e31b7116706 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp @@ -20,7 +20,7 @@ namespace ternary_backward { namespace detail { template -void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") { +void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype ="BFLOAT16") { auto doc = fmt::format( R"doc( @@ -30,17 +30,27 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_ grad_tensor (ttnn.Tensor): the input gradient tensor. input_tensor_a (ttnn.Tensor): the input tensor. input_tensor_b (ttnn.Tensor): the input tensor. - input_tensor_c (ttnn.Tensor or Number): the input tensor. + input_tensor_c (ttnn.Tensor): the input tensor. alpha (float): the alpha value. Keyword args: - memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. + memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`. Returns: List of ttnn.Tensor: the output tensor. Note: - {3} + Supported dtypes, layouts, and ranks: + + .. list-table:: + :header-rows: 1 + + * - Dtypes + - Layouts + - Ranks + * - {3} + - TILE + - 2, 3, 4 Example: @@ -81,21 +91,21 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_ } template -void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") { +void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype = "BFLOAT16") { auto doc = fmt::format( R"doc( {2} Args: - grad_tensor (ttnn.Tensor): the input tensor. + grad_tensor (ttnn.Tensor): the input gradient tensor. input_tensor_a (ttnn.Tensor): the input tensor. input_tensor_b (ttnn.Tensor): the input tensor. input_tensor_c (ttnn.Tensor or Number): the input tensor. Keyword args: - memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. + memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`. Returns: @@ -103,10 +113,33 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati Note: - {3} + Supported dtypes, layouts, and ranks: + + For Inputs : :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` + .. list-table:: + :header-rows: 1 - Note : bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT + * - Dtypes + - Layouts + - Ranks + * - BFLOAT16 + - TILE + - 2, 3, 4 + + For Inputs : :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`scalar` + + .. list-table:: + :header-rows: 1 + + * - Dtypes + - Layouts + - Ranks + * - {3} + - TILE + - 2, 3, 4 + + bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT Example: @@ -162,22 +195,26 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati } template -void bind_ternary_backward_optional_output(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") { +void bind_ternary_backward_optional_output( + py::module& module, + const ternary_backward_operation_t& operation, + const std::string_view description, + const std::string& supported_dtype ="BFLOAT16") { auto doc = fmt::format( R"doc( {2} Args: - grad_tensor (ttnn.Tensor): the input tensor. + grad_tensor (ttnn.Tensor): the input gradient tensor. input_tensor_a (ttnn.Tensor): the input tensor. input_tensor_b (ttnn.Tensor): the input tensor. input_tensor_c (ttnn.Tensor): the input tensor. Keyword args: - are_required_outputs (List[bool], optional): List of required outputs. Defaults to `[True, True]`. - memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. - output_tensor (ttnn.Tensor, optional): Preallocated output tensor. Defaults to `None`. + are_required_outputs (List[bool], optional): list of required outputs. Defaults to `[True, True]`. + memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`. + output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`. queue_id (int, optional): command queue id. Defaults to `0`. Returns: @@ -185,7 +222,18 @@ void bind_ternary_backward_optional_output(py::module& module, const ternary_bac Note: - {3} + Supported dtypes, layouts, and ranks: + + .. list-table:: + :header-rows: 1 + + * - Dtypes + - Layouts + - Ranks + * - {3} + - TILE + - 2, 3, 4 + Example: @@ -236,66 +284,25 @@ void py_module(py::module& module) { detail::bind_ternary_backward( module, ttnn::addcmul_bw, - R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc", - R"doc(Supported dtypes, layouts, and ranks: - - +----------------------------+---------------------------------+-------------------+ - | Dtypes | Layouts | Ranks | - +----------------------------+---------------------------------+-------------------+ - | BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 | - +----------------------------+---------------------------------+-------------------+ - - )doc"); + R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc", + R"doc(BFLOAT16, BFLOAT8_B)doc"); detail::bind_ternary_backward( module, ttnn::addcdiv_bw, - R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc", - R"doc(Supported dtypes, layouts, and ranks: - - +----------------------------+---------------------------------+-------------------+ - | Dtypes | Layouts | Ranks | - +----------------------------+---------------------------------+-------------------+ - | BFLOAT16 | TILE | 2, 3, 4 | - +----------------------------+---------------------------------+-------------------+ - - )doc"); + R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc"); detail::bind_ternary_backward_optional_output( module, ttnn::where_bw, - R"doc(Performs backward operations for where of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc", - R"doc(Supported dtypes, layouts, and ranks: - - +----------------------------+---------------------------------+-------------------+ - | Dtypes | Layouts | Ranks | - +----------------------------+---------------------------------+-------------------+ - | BFLOAT16 | TILE | 2, 3, 4 | - +----------------------------+---------------------------------+-------------------+ - - )doc"); + R"doc(Performs backward operations for where of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc", + R"doc(BFLOAT16, BFLOAT8_B)doc"); detail::bind_ternary_backward_op( module, ttnn::lerp_bw, - R"doc(Performs backward operations for lerp of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc", - R"doc(Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` - - +----------------------------+---------------------------------+-------------------+ - | Dtypes | Layouts | Ranks | - +----------------------------+---------------------------------+-------------------+ - | BFLOAT16 | TILE | 2, 3, 4 | - +----------------------------+---------------------------------+-------------------+ - - Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`scalar` - - +----------------------------+---------------------------------+-------------------+ - | Dtypes | Layouts | Ranks | - +----------------------------+---------------------------------+-------------------+ - | BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 | - +----------------------------+---------------------------------+-------------------+ - - )doc"); + R"doc(Performs backward operations for lerp of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc", + R"doc(BFLOAT16, BFLOAT8_B)doc"); }