#8721: Add forward support for lei #8722

Closed
wants to merge 1 commit
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -312,6 +312,7 @@ Pointwise Binary
ttnn/ge
ttnn/lt
ttnn/le
ttnn/le_
ttnn/eq
ttnn/ne
ttnn/isclose
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/le_.rst
@@ -0,0 +1,6 @@
.. _ttnn.le_:

ttnn.le_
#########

.. autofunction:: ttnn.le_
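
For quick orientation, here is a minimal usage sketch of the new in-place op, mirroring the two overloads this PR adds (illustrative values only; assumes an open `device` handle and tile layout for eltwise ops):

import torch
import ttnn

tensor1 = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
tensor2 = ttnn.from_torch(torch.tensor((0, 2), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

ttnn.le_(tensor1, tensor2)  # tensor <= tensor; result overwrites tensor1
ttnn.le_(tensor1, 0.5)      # tensor <= scalar; result overwrites tensor1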
41 changes: 41 additions & 0 deletions tests/ttnn/unit_tests/operations/test_binary_composite.py
@@ -534,3 +534,44 @@ def test_binary_polyval_ttnn(input_shapes, coeffs, device):

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_binary_lei_ttnn(input_shapes, device):
in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
in_data2, input_tensor2 = data_gen_with_range(input_shapes, -150, 150, device)
ttnn.le_(input_tensor1, input_tensor2)
golden_function = ttnn.get_golden_function(ttnn.le_)
golden_tensor = golden_function(in_data1, in_data2)

comp_pass = compare_pcc([input_tensor1], [golden_tensor])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize(
"scalar",
{random.randint(-100, 100) + 0.5 for _ in range(5)},
)
def test_lei_ttnn(input_shapes, scalar, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device)
ttnn.le_(input_tensor, scalar)
golden_function = ttnn.get_golden_function(ttnn.le_)
golden_tensor = golden_function(in_data, scalar)

comp_pass = compare_pcc([input_tensor], [golden_tensor])
assert comp_pass
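
The golden path in these tests leans on torch's in-place comparison semantics: `Tensor.le_` overwrites its input with 0/1 values in the input's original dtype, which is what the device result is compared against. For illustration:

import torch

t = torch.tensor([1.0, 2.0, 3.0])
t.le_(2.0)  # in place: t becomes tensor([1., 1., 0.])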
14 changes: 14 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -185,6 +185,17 @@ struct ExecuteBinaryRemainder
const std::optional<MemoryConfig>& memory_config = std::nullopt);
};

struct ExecuteBinaryLE
{
static Tensor operator()(
const Tensor& input_tensor_a,
const Tensor& input_tensor_b);

static Tensor operator()(
const Tensor& input_tensor,
float scalar);
};

} // namespace binary
} // namespace operations

@@ -254,5 +265,8 @@ constexpr auto outer = ttnn::register_operation_with_auto_launch_op<
constexpr auto polyval = ttnn::register_operation_with_auto_launch_op<
"ttnn::polyval",
operations::binary::ExecuteBinaryCompositeOpsPolyval<operations::binary::BinaryCompositeOpType::POLYVAL>>();
constexpr auto le_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::le_",
operations::binary::ExecuteBinaryLE>();

} // namespace ttnn
57 changes: 53 additions & 4 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -386,13 +386,13 @@ void bind_polyval(py::module& module, const binary_operation_t& operation, const
template <typename binary_operation_t>
void bind_binary_overload_operation(py::module& module, const binary_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(input_tensor_a: ttnn.Tensor, input_tensor_b:Union[ttnn.Tensor, int], *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
R"doc({0}(input_tensor: ttnn.Tensor, other:Union[ttnn.Tensor, float], *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

{2}

Args:
* :attr:`input_tensor_a`
* :attr:`input_tensor_b` (ttnn.Tensor or Number)
* :attr:`input_tensor`
* :attr:`other` (ttnn.Tensor or Number)

Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
@@ -431,11 +431,54 @@ void bind_binary_overload_operation(py::module& module, const binary_operation_t
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor_a, input_tensor_b, memory_config); },
py::arg("input_tensor_a"),
py::arg("inputr_tensor_b"),
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename binary_operation_t>
void bind_inplace_operation(py::module& module, const binary_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, other:Union[ttnn.Tensor, float], *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

{2}

Args:
* :attr:`input_tensor`
* :attr:`other` (ttnn.Tensor or Number)

Example::
>>> tensor1 = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> tensor2 = ttnn.from_torch(torch.tensor((0, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor1, tensor2)
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
description);

bind_registered_operation(
module,
operation,
doc,

// tensor and scalar
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor,
float scalar) {
return self(input_tensor, scalar); },
py::arg("input_tensor"),
py::arg("scalar")},

// tensor and tensor
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor_a,
const Tensor& input_tensor_b) {
return self(input_tensor_a, input_tensor_b); },
py::arg("input_tensor_a"),
py::arg("input_tensor_b")});
}

} // namespace detail

void py_module(py::module& module) {
@@ -685,6 +728,12 @@ void py_module(py::module& module) {
ttnn::remainder,
R"doc(Perform an eltwise-modulus operation a - a.div(b, rounding_mode=floor) * b.", "Support provided only for WH_B0.)doc");

detail::bind_inplace_operation(
module,
ttnn::le_,
R"doc(Perform Less than or equal to in-place operation on :attr:`input_tensor` and :attr:`other` and returns the tensor with the same layout as :attr:`input_tensor`
.. math:: \mathrm{{input\_tensor}}_i <= \mathrm{{other}}_i)doc");

}

@@ -392,4 +392,13 @@ Tensor _polyval(const Tensor& input_a, const std::vector<float>& coeffs, const s
return final_tensor;
}

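// Note: both overloads below reuse the input tensor as ttnn::le's optional output argument, so the comparison result is written in place.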
Tensor ExecuteBinaryLE::operator()(const Tensor& input_a, const Tensor& input_b) {
return ttnn::le(input_a, input_b, std::nullopt, std::nullopt, input_a);
}

Tensor ExecuteBinaryLE::operator()(const Tensor& input, float scalar) {
return ttnn::le(input, scalar, std::nullopt, std::nullopt, input);
}

} // namespace ttnn::operations::binary
9 changes: 9 additions & 0 deletions ttnn/ttnn/operations/binary.py
@@ -385,4 +385,13 @@ def _golden_function_polyval(input_tensor_a, coeffs, *args, **kwargs):
ttnn.attach_golden_function(ttnn._ttnn.operations.binary.polyval, golden_function=_golden_function_polyval)


def _golden_function_le_(input_tensor_a, input_tensor_b, *args, **kwargs):
return input_tensor_a.le_(input_tensor_b)


ttnn.attach_golden_function(ttnn._ttnn.operations.binary.le_, golden_function=_golden_function_le_)


__all__ = []
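
Putting the pieces together, a hand-run parity check in the spirit of the unit tests above (a rough sketch; assumes an open `device`, and exact layouts/tolerances may need adjusting):

import torch
import ttnn

torch_input = torch.rand(1, 1, 32, 32, dtype=torch.bfloat16) * 200 - 100
tt_input = ttnn.from_torch(torch_input.clone(), layout=ttnn.TILE_LAYOUT, device=device)

ttnn.le_(tt_input, 0.5)                                        # device op, in place
golden = ttnn.get_golden_function(ttnn.le_)(torch_input, 0.5)  # torch reference, in place

assert torch.allclose(ttnn.to_torch(tt_input), golden)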