Skip to content

Commit

Permalink
Restructure supported params table for ternary bw ops
Browse files (browse the repository at this point in the history)
  • Loading branch information
mcw-anasuya committed Nov 13, 2024
1 parent a5d9979 commit 5104792
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 57 deletions.
2 changes: 1 addition & 1 deletion docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ Pointwise Binary
ttnn.rsub_bw
ttnn.min_bw
ttnn.max_bw
ttnn.lerp_bw

Pointwise Ternary
=================
Expand All @@ -376,6 +375,7 @@ Pointwise Ternary
ttnn.addcmul_bw
ttnn.addcdiv_bw
ttnn.where_bw
ttnn.lerp_bw

Losses
======
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace ternary_backward {
namespace detail {

template <typename ternary_backward_operation_t>
void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype ="BFLOAT16") {
auto doc = fmt::format(
R"doc(
Expand All @@ -30,7 +30,7 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
grad_tensor (ttnn.Tensor): the input gradient tensor.
input_tensor_a (ttnn.Tensor): the input tensor.
input_tensor_b (ttnn.Tensor): the input tensor.
input_tensor_c (ttnn.Tensor or Number): the input tensor.
input_tensor_c (ttnn.Tensor): the input tensor.
alpha (float): the alpha value.
Keyword args:
Expand All @@ -40,7 +40,17 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
List of ttnn.Tensor: the output tensor.
Note:
{3}
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- TILE
- 2, 3, 4
Example:
Expand Down Expand Up @@ -81,7 +91,7 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
}

template <typename ternary_backward_operation_t>
void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype = "BFLOAT16") {
auto doc = fmt::format(
R"doc(
{2}
Expand All @@ -103,10 +113,33 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati
Note:
{3}
Supported dtypes, layouts, and ranks:
For Inputs : :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c`
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - BFLOAT16
- TILE
- 2, 3, 4
Note : bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
For Inputs : :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`scalar`
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- TILE
- 2, 3, 4
bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
Example:
Expand Down Expand Up @@ -162,7 +195,11 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati
}

template <typename ternary_backward_operation_t>
void bind_ternary_backward_optional_output(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
void bind_ternary_backward_optional_output(
py::module& module,
const ternary_backward_operation_t& operation,
const std::string_view description,
const std::string& supported_dtype ="BFLOAT16") {
auto doc = fmt::format(
R"doc(
Expand All @@ -185,7 +222,18 @@ void bind_ternary_backward_optional_output(py::module& module, const ternary_bac
Note:
{3}
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- TILE
- 2, 3, 4
Example:
Expand Down Expand Up @@ -236,66 +284,25 @@ void py_module(py::module& module) {
detail::bind_ternary_backward(
module,
ttnn::addcmul_bw,
R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_ternary_backward(
module,
ttnn::addcdiv_bw,
R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc");

detail::bind_ternary_backward_optional_output(
module,
ttnn::where_bw,
R"doc(Performs backward operations for where of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for where of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_ternary_backward_op(
module,
ttnn::lerp_bw,
R"doc(Performs backward operations for lerp of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c`
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`scalar`
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for lerp of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

}

Expand Down

0 comments on commit 5104792

Please sign in to comment.