Restructure supported params table for ternary ops (#14992)
### Ticket
#14980 

### Problem description
The supported-parameters table format in the ternary pybind docstrings needs to be updated so the table is not duplicated for each op that is added.

### What's changed
- Updated supported parameters table format (a sketch of the new pattern follows this list)
- Updated supported dtypes
- Reworded some descriptions
- Moved lerp doc from unary to ternary pointwise ops
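
For reference, the core of the new format is a single reStructuredText `list-table` template whose dtype row is filled in per op. The sketch below is illustrative only: `make_example_doc` is a hypothetical stand-in for the doc-building part of `bind_ternary_composite_float` in the diff, trimmed down to the table itself.

```cpp
// Hypothetical, trimmed sketch of the new docstring pattern (not the actual binder).
#include <fmt/format.h>
#include <string>

template <typename ternary_operation_t>
std::string make_example_doc(const ternary_operation_t& operation,
                             const std::string& description,
                             const std::string& supported_dtype = "BFLOAT16") {
    // One list-table template is shared by the ternary ops; only the dtype cell
    // ({3}) changes, so ops with wider support pass e.g. "BFLOAT16, BFLOAT8_B".
    return fmt::format(
        R"doc(
        {2}

        Note:
            Supported dtypes, layouts, and ranks:

            .. list-table::
               :header-rows: 1

               * - Dtypes
                 - Layouts
                 - Ranks
               * - {3}
                 - TILE
                 - 2, 3, 4
        )doc",
        operation.base_name(),
        operation.python_fully_qualified_name(),
        description,
        supported_dtype);
}
```

This keeps the per-op differences down to the `supported_dtype` argument (plus the description), instead of each op carrying its own copy of the table.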

### Checklist
- [ ] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/11815850262)
mcw-anasuya authored Nov 15, 2024
1 parent d4e7f58 commit 34b36e9
Showing 2 changed files with 50 additions and 33 deletions.
docs/source/ttnn/ttnn/api.rst (2 changes: 1 addition, 1 deletion)
@@ -141,7 +141,6 @@ Pointwise Unary
ttnn.isneginf
ttnn.isposinf
ttnn.leaky_relu
ttnn.lerp
ttnn.lgamma
ttnn.log
ttnn.log10
@@ -373,6 +372,7 @@ Pointwise Ternary
ttnn.addcmul
ttnn.mac
ttnn.where
ttnn.lerp
ttnn.addcmul_bw
ttnn.addcdiv_bw
ttnn.where_bw
ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp (81 changes: 49 additions, 32 deletions)
@@ -22,7 +22,7 @@ namespace ternary {
namespace detail {

template <typename ternary_operation_t>
void bind_ternary_composite_float(py::module& module, const ternary_operation_t& operation, const std::string& description) {
void bind_ternary_composite_float(py::module& module, const ternary_operation_t& operation, const std::string& description, const std::string& supported_dtype = "BFLOAT16") {
auto doc = fmt::format(
R"doc(
{2}
@@ -34,21 +34,26 @@ void bind_ternary_composite_float(py::module& module, const ternary_operation_t&
Keyword Args:
value (float, optional): Float value. Defaults to `1`.
value (float, optional): Scalar value to be multiplied.
memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
Returns:
ttnn.Tensor: the output tensor.
Supported dtypes and layouts:
Note:
Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
.. list-table::
:header-rows: 1
Note : bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
* - Dtypes
- Layouts
- Ranks
* - {3}
- TILE
- 2, 3, 4
bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
Example:
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
@@ -58,7 +63,8 @@ void bind_ternary_composite_float(py::module& module, const ternary_operation_t&
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
description);
description,
supported_dtype);

bind_registered_operation(
module,
@@ -77,7 +83,7 @@ void bind_ternary_composite_float(py::module& module, const ternary_operation_t&
py::arg("input_tensor_b"),
py::arg("input_tensor_c"),
py::kw_only(),
py::arg("value") = 1.0f,
py::arg("value"),
py::arg("memory_config") = std::nullopt});
}

@@ -89,7 +95,7 @@ void bind_ternary_where(py::module& module, const ternary_operation_t& operation
Args:
input_tensor_a (ttnn.Tensor): the input tensor.
input_tensor_b (ttnn.Tensor or number): the input tensor.
input_tensor_b (ttnn.Tensor or Number): the input tensor.
input_tensor_c (ttnn.Tensor or Number): the input tensor.
@@ -99,15 +105,20 @@ void bind_ternary_where(py::module& module, const ternary_operation_t& operation
queue_id (int, optional): command queue id. Defaults to `0`.
Supported dtypes and layouts:
Note:
Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
.. list-table::
:header-rows: 1
Note : bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
* - Dtypes
- Layouts
- Ranks
* - BFLOAT16, BFLOAT8_B
- TILE
- 2, 3, 4
bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
Example:
@@ -267,7 +278,7 @@ void bind_ternary_mac(py::module& module, const ternary_operation_t& operation,
Args:
input_tensor_a (ttnn.Tensor): the input tensor.
input_tensor_b (ttnn.Tensor or number): the input tensor.
input_tensor_b (ttnn.Tensor or Number): the input tensor.
input_tensor_c (ttnn.Tensor or Number): the input tensor.
Keyword Args:
@@ -276,15 +287,20 @@
Returns:
ttnn.Tensor: the output tensor.
Supported dtypes and layouts:
Note:
Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
.. list-table::
:header-rows: 1
Note : bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
* - Dtypes
- Layouts
- Ranks
* - BFLOAT16, BFLOAT8_B
- TILE
- 2, 3, 4
bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
Example:
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
@@ -336,28 +352,29 @@ void py_module(py::module& module) {
detail::bind_ternary_composite_float(
module,
ttnn::addcmul,
R"doc(compute Addcmul :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
R"doc(Computes Addcmul on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_ternary_composite_float(
module,
ttnn::addcdiv,
R"doc(compute Addcdiv :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
R"doc(Computes Addcdiv on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

detail::bind_ternary_where(
module,
ttnn::where,
R"doc(compute Addcdiv :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
R"doc(Computes Where on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

detail::bind_ternary_lerp(
module,
ttnn::lerp,
R"doc(compute Lerp :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`
R"doc(Computes Lerp on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`
.. math:: \mathrm{{input\_tensor\_a}}_i || \mathrm{{input\_tensor\_b}}_i)doc");

detail::bind_ternary_mac(
module,
ttnn::mac,
R"doc(compute Mac :attr:`input_tensor_a` and :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
R"doc(Computes Mac on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
}

} // namespace ternary
