Skip to content

Commit

Permalink
#13929: Update document and sweep test file
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Nov 14, 2024
1 parent 0b2ff25 commit 41cf35b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ttnn-run-sweeps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ on:
- eltwise.binary.ne.ne_scalar_pytorch2
- eltwise.binary.hypot.hypot
- eltwise.binary.xlogy.xlogy
- eltwise.binary_backward.ldexp_bw
- eltwise.binary_backward.ldexp_bw.ldexp_bw
- eltwise.binary_backward.logaddexp_bw
- eltwise.binary_backward.logaddexp2_bw
- eltwise.binary_backward.addalpha_bw.addalpha_bw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace binary_backward {
namespace detail {

template <typename binary_backward_operation_t>
void bind_binary_backward_ops(py::module& module, const binary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16") {
void bind_binary_backward_ops(py::module& module, const binary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "BFLOAT16", const std::string_view note = "") {
auto doc = fmt::format(
R"doc(
{2}
Expand Down Expand Up @@ -53,6 +53,8 @@ void bind_binary_backward_ops(py::module& module, const binary_backward_operatio
bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
{4}
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
Expand All @@ -65,7 +67,8 @@ void bind_binary_backward_ops(py::module& module, const binary_backward_operatio
operation.base_name(),
operation.python_fully_qualified_name(),
description,
supported_dtype);
supported_dtype,
note);

bind_registered_operation(
module,
Expand Down Expand Up @@ -1110,7 +1113,7 @@ void py_module(py::module& module) {
module,
ttnn::ldexp_bw,
R"doc(Performs backward operations for ldexp of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16)doc");
R"doc(BFLOAT16)doc", R"doc(Recommended input range : [-80, 80]. Performance of the PCC may degrade if the input falls outside this range.)doc");


detail::bind_binary_backward_ops(
Expand Down

0 comments on commit 41cf35b

Please sign in to comment.