From 354a8346dfea9740548b757d318c9d3cb8a6976f Mon Sep 17 00:00:00 2001
From: Akhmed Rakhmati
Date: Thu, 23 May 2024 21:27:46 +0000
Subject: [PATCH] #5389: ported ttnn::log_sigmoid to C++

---
 ttnn/cpp/pybind11/operations/unary.hpp | 1 +
 ttnn/cpp/ttnn/operations/unary.hpp     | 8 ++++++--
 ttnn/ttnn/operations/unary.py          | 4 ++--
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/ttnn/cpp/pybind11/operations/unary.hpp b/ttnn/cpp/pybind11/operations/unary.hpp
index cd6e6a722a8..bb169e3ea6f 100644
--- a/ttnn/cpp/pybind11/operations/unary.hpp
+++ b/ttnn/cpp/pybind11/operations/unary.hpp
@@ -205,6 +205,7 @@ void py_module(py::module& module) {
     detail::bind_unary_operation(module, ttnn::square);
     detail::bind_unary_operation(module, ttnn::tan);
     detail::bind_unary_operation(module, ttnn::tanh);
+    detail::bind_unary_operation(module, ttnn::log_sigmoid);
 
     // Unaries with fast_and_approximate_mode
     detail::bind_unary_operation_with_fast_and_approximate_mode(module, ttnn::exp);
diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp
index 774a7971d2e..4c578c88742 100644
--- a/ttnn/cpp/ttnn/operations/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/unary.hpp
@@ -52,7 +52,7 @@ inline Tensor execute(
 }
 }  // namespace detail
 
-template <UnaryOpType unary_op_type>
+template <UnaryOpType... unary_op_types>
 struct ExecuteUnary {
     static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }
 
@@ -62,7 +62,7 @@ struct ExecuteUnary {
     }
 
     static Tensor execute(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-        return detail::execute(input_tensor, {UnaryWithParam{unary_op_type}}, memory_config);
+        return detail::execute(input_tensor, {UnaryWithParam{unary_op_types}...}, memory_config);
     }
 };
 
@@ -181,6 +181,10 @@ REGISTER_UNARY_OPERATION(square, SQUARE);
 REGISTER_UNARY_OPERATION(tan, TAN);
 REGISTER_UNARY_OPERATION(tanh, TANH);
 
+constexpr auto log_sigmoid = ttnn::register_operation<
+    ttnn::operations::unary::ExecuteUnary<
+        ttnn::operations::unary::UnaryOpType::SIGMOID, ttnn::operations::unary::UnaryOpType::LOG>>("ttnn::log_sigmoid");
+
 // Unaries with fast_and_approximate_mode
 REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(exp, EXP);
 REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(erf, ERF);
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 32259cda06b..1ea80f6e3a7 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -47,6 +47,7 @@ def torch_prelu(x, *args, **kwargs):
     "log": torch.log,
     "log10": torch.log10,
     "log2": torch.log2,
+    "log_sigmoid": torch.nn.functional.logsigmoid,
     "logical_not": torch.logical_not,
     "ltz": lambda x: torch.lt(x, 0),
     "neg": torch.neg,
@@ -141,6 +142,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
    ttnn._ttnn.operations.unary.leaky_relu,
    # ttnn._ttnn.operations.unary.prelu,  # Alias for leaky_relu. TODO(#8544): implement PReLU properly
    # Other unaries (composite operations)
+    ttnn._ttnn.operations.unary.log_sigmoid,
    ttnn._ttnn.operations.unary.softplus,
 ]
 for unary_function in TTNN_ELTWISE_UNARY_CPP_FUNCTIONS:
@@ -185,7 +187,6 @@ def register_ttl_unary_function(name, ttl_unary_function):
         "hardtanh": torch.nn.functional.hardtanh,
         "lgamma": torch.lgamma,
         "log1p": torch.log1p,
-        "log_sigmoid": torch.nn.functional.logsigmoid,
         "mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)),
         "multigammaln": torch_multigammaln,
         "rad2deg": torch.rad2deg,
@@ -275,7 +276,6 @@ def unary_function(
     ("hardtanh", ttl.tensor.hardtanh),  # composite
     ("lgamma", ttl.tensor.lgamma),  # composite
     ("log1p", ttl.tensor.log1p),  # composite
-    ("log_sigmoid", ttl.tensor.log_sigmoid),  # composite
     ("mish", ttl.tensor.mish),  # composite
     ("multigammaln", ttl.tensor.multigammaln),  # composite
     ("rad2deg", ttl.tensor.rad2deg),  # composite
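
For reference, a minimal usage sketch of the newly exposed ttnn.log_sigmoid binding, checked against the torch golden function from this patch. The device setup (ttnn.open_device / ttnn.from_torch / ttnn.to_torch), tensor shape, and tolerances are illustrative assumptions, not part of the patch:

    # Minimal sketch: run ttnn.log_sigmoid and compare against torch.nn.functional.logsigmoid.
    # Assumes an attached device; helper calls and tolerances are illustrative only.
    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    torch_input = torch.rand((32, 32), dtype=torch.bfloat16)
    golden = torch.nn.functional.logsigmoid(torch_input.float())

    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.log_sigmoid(input_tensor)  # runs SIGMOID followed by LOG on device
    torch_output = ttnn.to_torch(output_tensor)

    assert torch.allclose(torch_output.float(), golden, atol=1e-1, rtol=1e-1)

    ttnn.close_device(device)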