From 510479216f58fb05378123b1b9abd79c11933277 Mon Sep 17 00:00:00 2001
From: mcw-anasuya
Date: Wed, 13 Nov 2024 09:39:56 +0000
Subject: [PATCH] Restructure supported params table for ternary bw ops

---
 docs/source/ttnn/ttnn/api.rst   |   2 +-
 .../ternary_backward_pybind.hpp | 119 +++++++++---------
 2 files changed, 64 insertions(+), 57 deletions(-)

diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst
index 420d901d55f..10f5937450d 100644
--- a/docs/source/ttnn/ttnn/api.rst
+++ b/docs/source/ttnn/ttnn/api.rst
@@ -359,7 +359,6 @@ Pointwise Binary
    ttnn.rsub_bw
    ttnn.min_bw
    ttnn.max_bw
-   ttnn.lerp_bw
 
 Pointwise Ternary
 =================
@@ -376,6 +375,7 @@ Pointwise Ternary
    ttnn.addcmul_bw
    ttnn.addcdiv_bw
    ttnn.where_bw
+   ttnn.lerp_bw
 
 Losses
 ======
diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp
index f17839c8461..dc8288d06fb 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp
@@ -20,7 +20,7 @@ namespace ternary_backward {
 namespace detail {
 
 template <typename ternary_backward_operation_t>
-void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
+void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype = "BFLOAT16") {
     auto doc = fmt::format(
         R"doc(
 
@@ -30,7 +30,7 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
             grad_tensor (ttnn.Tensor): the input gradient tensor.
             input_tensor_a (ttnn.Tensor): the input tensor.
             input_tensor_b (ttnn.Tensor): the input tensor.
-            input_tensor_c (ttnn.Tensor or Number): the input tensor.
+            input_tensor_c (ttnn.Tensor): the input tensor.
             alpha (float): the alpha value.
 
         Keyword args:
@@ -40,7 +40,17 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
             List of ttnn.Tensor: the output tensor.
 
         Note:
-            {3}
+            Supported dtypes, layouts, and ranks:
+
+            .. list-table::
+               :header-rows: 1
+
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - {3}
+                 - TILE
+                 - 2, 3, 4
 
         Example:
 
@@ -81,7 +91,7 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
 }
 
 template <typename ternary_backward_operation_t>
-void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
+void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype = "BFLOAT16") {
     auto doc = fmt::format(
         R"doc(
         {2}
@@ -103,10 +113,33 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati
 
         Note:
 
-            {3}
+            Supported dtypes, layouts, and ranks:
+
+            For Inputs : :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c`
+
+            .. list-table::
+               :header-rows: 1
 
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - BFLOAT16
+                 - TILE
+                 - 2, 3, 4
 
-            Note : bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
+            For Inputs : :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`scalar`
+
+            .. list-table::
+               :header-rows: 1
+
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - {3}
+                 - TILE
+                 - 2, 3, 4
+
+            bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
 
         Example:
 
@@ -162,7 +195,11 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati
 }
 
 template <typename ternary_backward_operation_t>
-void bind_ternary_backward_optional_output(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
+void bind_ternary_backward_optional_output(
+    py::module& module,
+    const ternary_backward_operation_t& operation,
+    const std::string_view description,
+    const std::string& supported_dtype = "BFLOAT16") {
     auto doc = fmt::format(
         R"doc(
 
@@ -185,7 +222,18 @@ void bind_ternary_backward_optional_output(py::module& module, const ternary_bac
 
         Note:
 
-            {3}
+            Supported dtypes, layouts, and ranks:
+
+            .. list-table::
+               :header-rows: 1
+
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - {3}
+                 - TILE
+                 - 2, 3, 4
+
 
         Example:
 
@@ -236,66 +284,25 @@ void py_module(py::module& module) {
     detail::bind_ternary_backward(
         module,
         ttnn::addcmul_bw,
-        R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
-        R"doc(Supported dtypes, layouts, and ranks:
-
-           +----------------------------+---------------------------------+-------------------+
-           |     Dtypes                 |         Layouts                 |     Ranks         |
-           +----------------------------+---------------------------------+-------------------+
-           |    BFLOAT16, BFLOAT8_B     |       TILE                      |      2, 3, 4      |
-           +----------------------------+---------------------------------+-------------------+
-
-        )doc");
+        R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
+        R"doc(BFLOAT16, BFLOAT8_B)doc");
 
     detail::bind_ternary_backward(
         module,
         ttnn::addcdiv_bw,
-        R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
-        R"doc(Supported dtypes, layouts, and ranks:
-
-           +----------------------------+---------------------------------+-------------------+
-           |     Dtypes                 |         Layouts                 |     Ranks         |
-           +----------------------------+---------------------------------+-------------------+
-           |    BFLOAT16                |       TILE                      |      2, 3, 4      |
-           +----------------------------+---------------------------------+-------------------+
-
-        )doc");
+        R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc");
 
     detail::bind_ternary_backward_optional_output(
         module,
         ttnn::where_bw,
-        R"doc(Performs backward operations for where of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
-        R"doc(Supported dtypes, layouts, and ranks:
-
-           +----------------------------+---------------------------------+-------------------+
-           |     Dtypes                 |         Layouts                 |     Ranks         |
-           +----------------------------+---------------------------------+-------------------+
-           |    BFLOAT16                |       TILE                      |      2, 3, 4      |
-           +----------------------------+---------------------------------+-------------------+
-
-        )doc");
+        R"doc(Performs backward operations for where of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
+        R"doc(BFLOAT16, BFLOAT8_B)doc");
 
     detail::bind_ternary_backward_op(
        module,
        ttnn::lerp_bw,
-        R"doc(Performs backward operations for lerp of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc",
-        R"doc(Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c`
-
-           +----------------------------+---------------------------------+-------------------+
-           |     Dtypes                 |         Layouts                 |     Ranks         |
-           +----------------------------+---------------------------------+-------------------+
-           |    BFLOAT16                |       TILE                      |      2, 3, 4      |
-           +----------------------------+---------------------------------+-------------------+
-
-           Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`scalar`
-
-           +----------------------------+---------------------------------+-------------------+
-           |     Dtypes                 |         Layouts                 |     Ranks         |
-           +----------------------------+---------------------------------+-------------------+
-           |    BFLOAT16, BFLOAT8_B     |       TILE                      |      2, 3, 4      |
-           +----------------------------+---------------------------------+-------------------+
-
-        )doc");
+        R"doc(Performs backward operations for lerp of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc",
+        R"doc(BFLOAT16, BFLOAT8_B)doc");
 }
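For reference, a minimal Python sketch of how one of the ops documented above might be exercised. The positional order (grad_tensor, input_tensor_a, input_tensor_b, then a scalar weight for the lerp_bw scalar variant) is an assumption taken from the docstring text in this patch, and the shapes, device id, and 0.5 weight are illustrative only::

    # Hypothetical usage sketch: call ttnn.lerp_bw with a scalar weight, using
    # BFLOAT16 tensors in TILE layout as the restructured table documents.
    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # Host-side reference tensors (rank 4, bfloat16).
    torch_grad = torch.rand([1, 1, 32, 32], dtype=torch.bfloat16)
    torch_a = torch.rand([1, 1, 32, 32], dtype=torch.bfloat16)
    torch_b = torch.rand([1, 1, 32, 32], dtype=torch.bfloat16)

    # Move the tensors to the device in TILE layout.
    grad_tensor = ttnn.from_torch(torch_grad, layout=ttnn.TILE_LAYOUT, device=device)
    input_a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
    input_b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)

    # Assumed argument order mirroring the documented args; the docstring states
    # the op returns a list of gradient tensors.
    output = ttnn.lerp_bw(grad_tensor, input_a, input_b, 0.5)

    ttnn.close_device(device)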