Skip to content

Commit

Permalink
Restructure supported params table for ternary backward ops (#14994)
Browse files Browse the repository at this point in the history
### Ticket
#14981 

### Problem description
The table format needs to be updated in ternary backward pybind to avoid
duplicating it for each op that is added.

### What's changed
- Updated supported parameters table format
- Updated supported dtypes
- Moved lerp_bw doc from binary to ternary point wise ops

### Checklist
- [ ] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/11816492181)
- same as main
  • Loading branch information
mcw-anasuya authored Nov 15, 2024
1 parent 8c64dcb commit b9ef431
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 64 deletions.
2 changes: 1 addition & 1 deletion docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ Pointwise Binary
ttnn.rsub_bw
ttnn.min_bw
ttnn.max_bw
ttnn.lerp_bw

Pointwise Ternary
=================
Expand All @@ -376,6 +375,7 @@ Pointwise Ternary
ttnn.addcmul_bw
ttnn.addcdiv_bw
ttnn.where_bw
ttnn.lerp_bw

Losses
======
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace ternary_backward {
namespace detail {

template <typename ternary_backward_operation_t>
void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype ="BFLOAT16") {
auto doc = fmt::format(
R"doc(
Expand All @@ -30,17 +30,27 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
grad_tensor (ttnn.Tensor): the input gradient tensor.
input_tensor_a (ttnn.Tensor): the input tensor.
input_tensor_b (ttnn.Tensor): the input tensor.
input_tensor_c (ttnn.Tensor or Number): the input tensor.
input_tensor_c (ttnn.Tensor): the input tensor.
alpha (float): the alpha value.
Keyword args:
memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
Returns:
List of ttnn.Tensor: the output tensor.
Note:
{3}
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- TILE
- 2, 3, 4
Example:
Expand Down Expand Up @@ -81,32 +91,55 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
}

template <typename ternary_backward_operation_t>
void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype = "BFLOAT16") {
auto doc = fmt::format(
R"doc(
{2}
Args:
grad_tensor (ttnn.Tensor): the input tensor.
grad_tensor (ttnn.Tensor): the input gradient tensor.
input_tensor_a (ttnn.Tensor): the input tensor.
input_tensor_b (ttnn.Tensor): the input tensor.
input_tensor_c (ttnn.Tensor or Number): the input tensor.
Keyword args:
memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
Returns:
List of ttnn.Tensor: the output tensor.
Note:
{3}
Supported dtypes, layouts, and ranks:
For Inputs : :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c`
.. list-table::
:header-rows: 1
Note : bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
* - Dtypes
- Layouts
- Ranks
* - BFLOAT16
- TILE
- 2, 3, 4
For Inputs : :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`scalar`
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- TILE
- 2, 3, 4
bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
Example:
Expand Down Expand Up @@ -162,30 +195,45 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati
}

template <typename ternary_backward_operation_t>
void bind_ternary_backward_optional_output(py::module& module, const ternary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
void bind_ternary_backward_optional_output(
py::module& module,
const ternary_backward_operation_t& operation,
const std::string_view description,
const std::string& supported_dtype ="BFLOAT16") {
auto doc = fmt::format(
R"doc(
{2}
Args:
grad_tensor (ttnn.Tensor): the input tensor.
grad_tensor (ttnn.Tensor): the input gradient tensor.
input_tensor_a (ttnn.Tensor): the input tensor.
input_tensor_b (ttnn.Tensor): the input tensor.
input_tensor_c (ttnn.Tensor): the input tensor.
Keyword args:
are_required_outputs (List[bool], optional): List of required outputs. Defaults to `[True, True]`.
memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
output_tensor (ttnn.Tensor, optional): Preallocated output tensor. Defaults to `None`.
are_required_outputs (List[bool], optional): list of required outputs. Defaults to `[True, True]`.
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`.
queue_id (int, optional): command queue id. Defaults to `0`.
Returns:
List of ttnn.Tensor: the output tensor.
Note:
{3}
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- TILE
- 2, 3, 4
Example:
Expand Down Expand Up @@ -236,66 +284,25 @@ void py_module(py::module& module) {
detail::bind_ternary_backward(
module,
ttnn::addcmul_bw,
R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_ternary_backward(
module,
ttnn::addcdiv_bw,
R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc");

detail::bind_ternary_backward_optional_output(
module,
ttnn::where_bw,
R"doc(Performs backward operations for where of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for where of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_ternary_backward_op(
module,
ttnn::lerp_bw,
R"doc(Performs backward operations for lerp of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c`
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`scalar`
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for lerp of :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

}

Expand Down

0 comments on commit b9ef431

Please sign in to comment.