From e09b76b84e7041fbedd0bd2dde1155c346c6bdcd Mon Sep 17 00:00:00 2001
From: Eyon
Date: Fri, 24 May 2024 15:30:56 +0000
Subject: [PATCH] #8658: Migrate composite unary ops to C++

---
 .gitignore                                    |   1 +
 tests/ttnn/unit_tests/operations/test_math.py |  15 +-
 ttnn/cpp/pybind11/operations/unary.hpp        | 161 ++++++++++++++-
 ttnn/cpp/ttnn/decorators.hpp                  |  44 +++-
 ttnn/cpp/ttnn/operations/unary.hpp            | 121 +++++++++++
 ttnn/ttnn/operations/unary.py                 | 194 +++++-------------
 6 files changed, 381 insertions(+), 155 deletions(-)

diff --git a/.gitignore b/.gitignore
index 39e91d9ebc8..1827de2f435 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ test_hlk_args_init_gen
 tt_build
 tt_debug
 build
+build_debug
 /python_env/
 /llk_out/
 
diff --git a/tests/ttnn/unit_tests/operations/test_math.py b/tests/ttnn/unit_tests/operations/test_math.py
index 3150d5982f3..8fa3ba97ff9 100644
--- a/tests/ttnn/unit_tests/operations/test_math.py
+++ b/tests/ttnn/unit_tests/operations/test_math.py
@@ -165,6 +165,7 @@ def test_rad2deg(device, h, w):
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
 def test_cbrt(device, h, w):
+    torch_cbrt = ttnn.get_golden_function(ttnn.cbrt)
     run_math_unary_test(device, h, w, ttnn.cbrt, torch_cbrt, pcc=0.999)
 
 
@@ -270,22 +271,10 @@ def run_math_unary_test_range(device, h, w, ttnn_function, torch_function, pcc=0
     assert_with_pcc(torch_output_tensor, output_tensor, pcc)
 
 
-def torch_cbrt(x, *args, **kwargs):
-    return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3)
-
-
-def torch_multigammaln(x, *args, **kwargs):
-    result = torch.lgamma(x)
-    result += torch.lgamma(x - 0.5)
-    result += torch.lgamma(x - 1.0)
-    result += torch.lgamma(x - 1.5)
-    result += 3.434189657547
-    return result
-
-
 @pytest.mark.parametrize("h", [5])
 @pytest.mark.parametrize("w", [5])
 def test_multigammaln(device, h, w):
+    torch_multigammaln = ttnn.get_golden_function(ttnn.multigammaln)
     run_math_unary_test_range(device, h, w, ttnn.multigammaln, torch_multigammaln, pcc=0.999)
 
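The updated tests pull their torch reference implementations from the golden-function registry instead of defining local helpers. A minimal sketch of that lookup pattern (the tensor values and variable names are illustrative, not from the patch):

    import torch
    import ttnn

    golden_cbrt = ttnn.get_golden_function(ttnn.cbrt)   # torch-based reference registered for the op
    torch_input = torch.rand((64, 128), dtype=torch.bfloat16)
    expected = golden_cbrt(torch_input)                 # same call shape as the old torch_cbrt helper
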
diff --git a/ttnn/cpp/pybind11/operations/unary.hpp b/ttnn/cpp/pybind11/operations/unary.hpp
index a724a1944cc..c89aa207fe9 100644
--- a/ttnn/cpp/pybind11/operations/unary.hpp
+++ b/ttnn/cpp/pybind11/operations/unary.hpp
@@ -47,7 +47,143 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation
         module,
         operation,
         doc,
-        ttnn::pybind_arguments_t{py::arg("input_tensor"), py::kw_only(), py::arg("memory_config") = std::nullopt, py::arg("output_tensor") = std::nullopt});
+        ttnn::pybind_overload_t{
+            [](const unary_operation_t& self,
+               const Tensor& input_tensor,
+               const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, memory_config); },
+            py::arg("input_tensor"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt});
+}
+
+template <typename unary_operation_t>
+void bind_unary_operation_with_scale_and_shift(py::module& module, const unary_operation_t& operation) {
+    auto doc = fmt::format(
+        R"doc({0}(input_tensor: ttnn.Tensor, scale, shift, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
+
+        Applies {0} to :attr:`input_tensor` element-wise.
+
+        .. math::
+            {0}(\\mathrm{{input\\_tensor}}_i)
+
+        Args:
+            * :attr:`input_tensor`
+            * :attr:`scale`
+            * :attr:`shift`
+
+        Keyword Args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
+
+        Example::
+
+            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> output = {1}(tensor)
+        )doc",
+        operation.name(),
+        operation.python_fully_qualified_name());
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const unary_operation_t& self,
+               const Tensor& input_tensor,
+               float scale,
+               float shift,
+               const std::optional<MemoryConfig>& memory_config) {
+                return self(input_tensor, scale, shift, memory_config);
+            },
+            py::arg("input_tensor"),
+            py::arg("scale")=1.0f/6.0f,
+            py::arg("shift")=0.5f,
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt});
+}
+
+template <typename unary_operation_t>
+void bind_unary_operation_with_low_and_high(py::module& module, const unary_operation_t& operation) {
+    auto doc = fmt::format(
+        R"doc({0}(input_tensor: ttnn.Tensor, low, high, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
+
+        Applies {0} to :attr:`input_tensor` element-wise.
+
+        .. math::
+            {0}(\\mathrm{{input\\_tensor}}_i)
+
+        Args:
+            * :attr:`input_tensor`
+            * :attr:`low`
+            * :attr:`high`
+
+        Keyword Args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
+
+        Example::
+
+            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> output = {1}(tensor)
+        )doc",
+        operation.name(),
+        operation.python_fully_qualified_name());
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const unary_operation_t& self,
+               const Tensor& input_tensor,
+               float low,
+               float high,
+               const std::optional<MemoryConfig>& memory_config) {
+                return self(input_tensor, low, high, memory_config);
+            },
+            py::arg("input_tensor"),
+            py::arg("low") = -1.0f,
+            py::arg("high") = 1.0f,
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt});
+}
+
+template <typename unary_operation_t>
+void bind_unary_operation_with_diag(py::module& module, const unary_operation_t& operation) {
+    auto doc = fmt::format(
+        R"doc({0}(input_tensor: ttnn.Tensor, diag, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
+
+        Applies {0} to :attr:`input_tensor` element-wise.
+
+        .. math::
+            {0}(\\mathrm{{input\\_tensor}}_i)
+
+        Args:
+            * :attr:`input_tensor`
+            * :attr:`diag`
+
+        Keyword Args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
+
+        Example::
+
+            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> output = {1}(tensor)
+        )doc",
+        operation.name(),
+        operation.python_fully_qualified_name());
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const unary_operation_t& self,
+               const Tensor& input_tensor,
+               int32_t diag,
+               const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, diag, memory_config); },
+            py::arg("input_tensor"),
+            py::arg("diag") = 0,
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt});
+}
+
 }
 
 template <typename unary_operation_t>
@@ -229,6 +365,29 @@ void py_module(py::module& module) {
 
     // Other unaries (composite operations)
     detail::bind_softplus(module);
+
+    detail::bind_unary_operation(module, ttnn::acosh);
+    detail::bind_unary_operation(module, ttnn::asinh);
+    detail::bind_unary_operation(module, ttnn::atanh);
+    detail::bind_unary_operation(module, ttnn::cbrt);
+    detail::bind_unary_operation(module, ttnn::cosh);
+    detail::bind_unary_operation(module, ttnn::deg2rad);
+    detail::bind_unary_operation(module, ttnn::digamma);
+    detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardswish);
+    detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardsigmoid);
+    detail::bind_unary_operation_with_low_and_high(module, ttnn::hardtanh);
+    detail::bind_unary_operation(module, ttnn::lgamma);
+    detail::bind_unary_operation(module, ttnn::log1p);
+    detail::bind_unary_operation(module, ttnn::mish);
+    detail::bind_unary_operation(module, ttnn::multigammaln);
+    detail::bind_unary_operation(module, ttnn::rad2deg);
+    detail::bind_unary_operation(module, ttnn::sigmoid_accurate);
+    detail::bind_unary_operation(module, ttnn::sinh);
+    detail::bind_unary_operation(module, ttnn::softsign);
+    detail::bind_unary_operation(module, ttnn::swish);
+    detail::bind_unary_operation(module, ttnn::tanhshrink);
+    detail::bind_unary_operation_with_diag(module, ttnn::tril);
+    detail::bind_unary_operation_with_diag(module, ttnn::triu);
 }
 
 }  // namespace unary

diff --git a/ttnn/cpp/ttnn/decorators.hpp b/ttnn/cpp/ttnn/decorators.hpp
index 9216e6e2364..3ebccd3ad41 100644
--- a/ttnn/cpp/ttnn/decorators.hpp
+++ b/ttnn/cpp/ttnn/decorators.hpp
@@ -365,13 +365,53 @@ constexpr auto register_operation(const char* name) {
     return operation_t<__COUNTER__, concrete_operation_t>{name};
 }
 
-#define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward<decltype(args)>(args)...); })
-
 template <typename lambda_t>
 constexpr auto register_operation(const char* name, const lambda_t& lambda) {
     return lambda_operation_t<__COUNTER__, lambda_t>{name, lambda};
 }
 
+// This function is used to transform the arguments of a function before calling it
+// where the lambda is applied to the type that matches T.
+// Example: https://godbolt.org/z/3P9YedMdj
+template <typename T, typename Func, typename Lambda, typename... Args>
+constexpr auto transform_args_lambda(Func func, Lambda lambda, Args&&... args) -> decltype(auto) {
+    auto transformer = [lambda](auto&& arg) -> decltype(auto) {
+        if constexpr (std::is_same_v<T, std::decay_t<decltype(arg)>>) {
+            return lambda(std::forward<decltype(arg)>(arg));
+        } else {
+            return std::forward<decltype(arg)>(arg);
+        }
+    };
+
+    return func(transformer(std::forward<Args>(args))...);
+}
+
+template <typename T, typename Lambda>
+auto transform_first_matching_arg(Lambda lambda) {
+    static_assert(!std::is_same<T, T>::value, "No matching type found");
+}
+
+template <typename T, typename Lambda, typename First, typename... Rest>
+auto transform_first_matching_arg(Lambda lambda, First&& first, Rest&&... rest) {
+    if constexpr (std::is_same_v<T, std::decay_t<First>>) {
+        return lambda(std::forward<First>(first));
+    } else {
+        return transform_first_matching_arg<T>(lambda, std::forward<Rest>(rest)...);
+    }
+}
+
+#define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward<decltype(args)>(args)...); })
+
+#define TO_LAMBDA_WITH_RESHAPE(function)                                                                     \
+    ([](auto&&... args) {                                                                                    \
+        const auto original_shape = ttnn::decorators::transform_first_matching_arg<Tensor>(                 \
+            [&](auto&& tensor) { return tensor.get_shape(); }, std::forward<decltype(args)>(args)...);      \
+        return ttnn::reshape(                                                                                \
+            ttnn::decorators::transform_args_lambda<Tensor>(                                                 \
+                function, [&](auto&& tensor) { return ttnn::unsqueeze_to_4D(tensor); }, args...),            \
+            original_shape);                                                                                 \
+    })
+
 }  // namespace decorators
 
 using ttnn::decorators::register_operation;
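TO_LAMBDA_WITH_RESHAPE reproduces, on the C++ side, the shape handling that the Python wrapper removed later in this patch used to do: record the shape of the first tensor argument, run the wrapped function on a 4D view, then reshape the result back. In Python terms the wrapped call behaves roughly like this sketch (`op` stands for any of the composite unaries; names here are illustrative only):

    import ttnn

    def call_with_reshape(op, input_tensor, *args, **kwargs):
        original_shape = input_tensor.shape
        output = op(ttnn.unsqueeze_to_4D(input_tensor), *args, **kwargs)  # run the op on a 4D view
        return ttnn.reshape(output, original_shape)                       # restore the original rank
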
diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp
index 0c28b946a23..496561d0a09 100644
--- a/ttnn/cpp/ttnn/operations/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/unary.hpp
@@ -5,6 +5,7 @@
 #pragma once
 
 #include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp"
+#include "tt_eager/tt_dnn/op_library/downsample/downsample_op.hpp"
 #include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
 #include "tt_eager/tt_dnn/op_library/run_operation.hpp"
 #include "ttnn/decorators.hpp"
@@ -128,6 +129,99 @@ struct Softplus {
             input, {UnaryWithParam{ttnn::operations::unary::UnaryOpType::SOFTPLUS, {beta, threshold}}}, memory_config, optional_output_tensor);
     }
 };
+
+// TODO: update these composite unary ops pending decision on TensorAsync implementation.
+
+Tensor acosh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::acosh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+
+Tensor asinh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::asinh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+
+Tensor atanh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::atanh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+
+Tensor cbrt(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::cbrt(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor cosh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::cosh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor deg2rad(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::deg2rad(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor digamma(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::digamma(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor hardswish(
+    const Tensor& input_tensor,
+    float scale,
+    float shift,
+    const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    //return tt::tt_metal::hardswish(input_tensor, scale, shift, memory_config.value_or(input_tensor.memory_config()));
+    return tt::tt_metal::hardswish(input_tensor, scale, shift, memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG));
+}
+Tensor hardsigmoid(
+    const Tensor& input_tensor,
+    float scale,
+    float shift,
+    const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    //return tt::tt_metal::hardsigmoid(input_tensor, scale, shift, memory_config.value_or(input_tensor.memory_config()));
+    return tt::tt_metal::hardsigmoid(input_tensor, scale, shift, memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG));
+}
+Tensor hardtanh(
+    const Tensor& input_tensor,
+    float low /* = -1.0f */,
+    float high /* = +1.0f */,
+    const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::hardtanh(input_tensor, low, high, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor lgamma(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::lgamma(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor log1p(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::log1p(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor mish(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::mish(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor multigammaln(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::multigammaln(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor rad2deg(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::rad2deg(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor sigmoid_accurate(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::sigmoid_accurate(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor sinh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::sinh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor softsign(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::softsign(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor swish(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::swish(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor tanhshrink(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::tanhshrink(input_tensor, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor tril(
+    const Tensor& input_tensor,
+    int32_t diag=0,
+    const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::tril(input_tensor, diag, memory_config.value_or(input_tensor.memory_config()));
+}
+Tensor triu(
+    const Tensor& input_tensor,
+    int32_t diag=0,
+    const std::optional<MemoryConfig>& memory_config = std::nullopt) {
+    return tt::tt_metal::triu(input_tensor, diag, memory_config.value_or(input_tensor.memory_config()));
+}
+
 }  // namespace unary
 }  // namespace operations
 
@@ -203,4 +297,31 @@ auto prelu = leaky_relu;  // Alias for leaky_relu. TODO(#8544): implement PReLU
 // Other unaries
 constexpr auto softplus = ttnn::register_operation<ttnn::operations::unary::Softplus>("ttnn::softplus");
+
+constexpr auto acosh = ttnn::register_operation("ttnn::acosh", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::acosh));
+constexpr auto asinh = ttnn::register_operation("ttnn::asinh", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::asinh));
+constexpr auto atanh = ttnn::register_operation("ttnn::atanh", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::atanh));
+constexpr auto cbrt = ttnn::register_operation("ttnn::cbrt", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::cbrt));
+constexpr auto cosh = ttnn::register_operation("ttnn::cosh", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::cosh));
+constexpr auto deg2rad = ttnn::register_operation("ttnn::deg2rad", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::deg2rad));
+constexpr auto digamma = ttnn::register_operation("ttnn::digamma", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::digamma));
+constexpr auto hardswish = ttnn::register_operation("ttnn::hardswish", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::hardswish));
+constexpr auto hardsigmoid =
+    ttnn::register_operation("ttnn::hardsigmoid", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::hardsigmoid));
+constexpr auto hardtanh = ttnn::register_operation("ttnn::hardtanh", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::hardtanh));
+constexpr auto lgamma = ttnn::register_operation("ttnn::lgamma", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::lgamma));
+constexpr auto log1p = ttnn::register_operation("ttnn::log1p", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::log1p));
+constexpr auto mish = ttnn::register_operation("ttnn::mish", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::mish));
+constexpr auto multigammaln =
+    ttnn::register_operation("ttnn::multigammaln", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::multigammaln));
+constexpr auto rad2deg = ttnn::register_operation("ttnn::rad2deg", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::rad2deg));
+constexpr auto sigmoid_accurate =
+    ttnn::register_operation("ttnn::sigmoid_accurate", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::sigmoid_accurate));
+constexpr auto sinh = ttnn::register_operation("ttnn::sinh", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::sinh));
+constexpr auto softsign = ttnn::register_operation("ttnn::softsign", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::softsign));
+constexpr auto swish = ttnn::register_operation("ttnn::swish", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::swish));
+constexpr auto tanhshrink =
+    ttnn::register_operation("ttnn::tanhshrink", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::tanhshrink));
+constexpr auto tril = ttnn::register_operation("ttnn::tril", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::tril));
+constexpr auto triu = ttnn::register_operation("ttnn::triu", TO_LAMBDA_WITH_RESHAPE(ttnn::operations::unary::triu));
+
 }  // namespace ttnn
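With the bindings above, the parameterized composites expose their scale/shift, low/high, and diag arguments with the defaults set in the pybind layer, while memory_config stays keyword-only. A short sketch of the resulting call forms (assuming `tensor` is a tiled device tensor; the argument values are illustrative):

    output = ttnn.hardsigmoid(tensor)                    # scale = 1/6, shift = 0.5 defaults from the binding
    output = ttnn.hardtanh(tensor, low=-2.0, high=2.0)   # low/high variant
    output = ttnn.tril(tensor, diag=1, memory_config=ttnn.L1_MEMORY_CONFIG)  # diag plus keyword-only memory_config
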
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 1ea80f6e3a7..2dee96359dd 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -20,6 +20,17 @@ def torch_heaviside(x, *args, **kwargs):
         result = torch.heaviside(x, torch.tensor(value, dtype=x.dtype))
         return result
 
+    def torch_cbrt(x, *args, **kwargs):
+        return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3)
+
+    def torch_multigammaln(x, *args, **kwargs):
+        result = torch.lgamma(x)
+        result += torch.lgamma(x - 0.5)
+        result += torch.lgamma(x - 1.0)
+        result += torch.lgamma(x - 1.5)
+        result += 3.434189657547
+        return result
+
     def torch_prelu(x, *args, **kwargs):
         weight = kwargs.pop("scalar")
         result = torch.nn.functional.prelu(x, torch.tensor(weight, dtype=x.dtype))
@@ -77,6 +88,28 @@ def torch_prelu(x, *args, **kwargs):
         # "prelu": torch_prelu,  # Alias for leaky_relu. TODO(#8544): implement PReLU properly
         # Other unaries (composite operations)
         "softplus": torch.nn.functional.softplus,
+        "acosh": torch.acosh,
+        "asinh": torch.asinh,
+        "atanh": torch.atanh,
+        "cbrt": torch_cbrt,
+        "cosh": torch.cosh,
+        "deg2rad": torch.deg2rad,
+        "digamma": torch.digamma,
+        "hardswish": torch.nn.functional.hardswish,
+        "hardsigmoid": torch.nn.functional.hardsigmoid,
+        "hardtanh": torch.nn.functional.hardtanh,
+        "lgamma": torch.lgamma,
+        "log1p": torch.log1p,
+        "mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)),
+        "multigammaln": torch_multigammaln,
+        "rad2deg": torch.rad2deg,
+        "sigmoid_accurate": torch.sigmoid,
+        "sinh": torch.sinh,
+        "softsign": torch.nn.functional.softsign,
+        "swish": torch.nn.functional.hardswish,
+        "tanhshrink": ttl.tensor.tanhshrink,
+        "tril": torch.tril,
+        "triu": torch.triu,
     }
 
     golden_keys = set(name_to_golden_function.keys())
@@ -144,6 +177,28 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     # Other unaries (composite operations)
     ttnn._ttnn.operations.unary.log_sigmoid,
     ttnn._ttnn.operations.unary.softplus,
+    ttnn._ttnn.operations.unary.acosh,
+    ttnn._ttnn.operations.unary.asinh,
+    ttnn._ttnn.operations.unary.atanh,
+    ttnn._ttnn.operations.unary.cbrt,
+    ttnn._ttnn.operations.unary.cosh,
+    ttnn._ttnn.operations.unary.deg2rad,
+    ttnn._ttnn.operations.unary.digamma,
+    ttnn._ttnn.operations.unary.hardswish,
+    ttnn._ttnn.operations.unary.hardsigmoid,
+    ttnn._ttnn.operations.unary.hardtanh,
+    ttnn._ttnn.operations.unary.lgamma,
+    ttnn._ttnn.operations.unary.log1p,
+    ttnn._ttnn.operations.unary.mish,
+    ttnn._ttnn.operations.unary.multigammaln,
+    ttnn._ttnn.operations.unary.rad2deg,
+    ttnn._ttnn.operations.unary.sigmoid_accurate,
+    ttnn._ttnn.operations.unary.sinh,
+    ttnn._ttnn.operations.unary.softsign,
+    ttnn._ttnn.operations.unary.swish,
+    ttnn._ttnn.operations.unary.tanhshrink,
+    ttnn._ttnn.operations.unary.tril,
+    ttnn._ttnn.operations.unary.triu,
 ]
 for unary_function in TTNN_ELTWISE_UNARY_CPP_FUNCTIONS:
     register_ttnn_cpp_unary_function(unary_function)
@@ -154,145 +209,6 @@ def prelu(*args, **kwargs):  # Alias for leaky_relu. TODO(#8544): implement PReL
     return leaky_relu(*args, **kwargs)
 
 
-def torch_cbrt(x, *args, **kwargs):
-    import torch
-
-    return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3)
-
-
-def torch_multigammaln(x, *args, **kwargs):
-    import torch
-
-    result = torch.lgamma(x)
-    result += torch.lgamma(x - 0.5)
-    result += torch.lgamma(x - 1.0)
-    result += torch.lgamma(x - 1.5)
-    result += 3.434189657547
-    return result
-
-
-def register_ttl_unary_function(name, ttl_unary_function):
-    import torch
-
-    name_to_golden_function = {
-        "acosh": torch.acosh,
-        "asinh": torch.asinh,
-        "atanh": torch.atanh,
-        "cbrt": torch_cbrt,
-        "cosh": torch.cosh,
-        "deg2rad": torch.deg2rad,
-        "digamma": torch.digamma,
-        "hardswish": torch.nn.functional.hardswish,
-        "hardsigmoid": torch.nn.functional.hardsigmoid,
-        "hardtanh": torch.nn.functional.hardtanh,
-        "lgamma": torch.lgamma,
-        "log1p": torch.log1p,
-        "mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)),
-        "multigammaln": torch_multigammaln,
-        "rad2deg": torch.rad2deg,
-        "sigmoid_accurate": torch.sigmoid,
-        "sinh": torch.sinh,
-        "softsign": torch.nn.functional.softsign,
-        "swish": torch.nn.functional.hardswish,
-        "tanhshrink": ttl.tensor.tanhshrink,
-        "tril": torch.tril,
-        "triu": torch.triu,
-    }
-
-    golden_keys = set(name_to_golden_function.keys())
-    function_names = {name for name, _ in TTL_UNARY_FUNCTIONS}
-    if golden_keys != function_names:
-        raise ImportError(f"Missing or extra golden functions:\n{golden_keys}\nshould be equal to\n{function_names}")
-
-    def _golden_function(input_tensor: ttnn.Tensor, **_):
-        torch_function = name_to_golden_function[name]
-        return torch_function(input_tensor)
-
-    def _unary_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
-        ttnn.validate_input_tensor(
-            operation_name,
-            input_tensor,
-            ranks=(2, 3, 4),
-            dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
-            layouts=(ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT),
-            can_be_on_device=True,
-            can_be_on_cpu=False,
-        )
-
-    @ttnn.register_operation(
-        name=f"ttnn.{name}",
-        validate_input_tensors=_unary_validate_input_tensors,
-        golden_function=_golden_function,
-    )
-    def unary_function(
-        input_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG
-    ) -> ttnn.Tensor:
-        original_shape = input_tensor.shape
-        input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-
-        if not isinstance(input_tensor, ttnn.Tensor):
-            raise TypeError("Expected first argument to be a ttnn.Tensor")
-
-        if not ttnn.is_tensor_storage_on_device(input_tensor):
-            raise RuntimeError("input_tensor must be on device!")
-
-        output_tensor = ttl_unary_function(input_tensor, output_mem_config=memory_config)
-        output_tensor = ttnn.reshape(output_tensor, original_shape)
-        return output_tensor
-
-    if isinstance(unary_function, ttnn.decorators.Operation):
-        unary_function.__name__ = f"ttnn.{name}"
-        unary_function.decorated_function.__doc__ = f"""{name}(input_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
-
-        Applies {name} to :attr:`input_tensor` element-wise.
-
-        .. math::
-            {name.replace('_',' ')}(\\mathrm{{input\\_tensor}}_i)
-
-        Args:
-            * :attr:`input_tensor`
-
-        Example::
-
-            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
-            >>> output = ttnn.{name}(tensor)
-
-        {unary_function.__doc__}
-
-        """
-    setattr(THIS_MODULE, name, unary_function)
-
-
-TTL_UNARY_FUNCTIONS = [
-    ("acosh", ttl.tensor.acosh),  # composite
-    ("asinh", ttl.tensor.asinh),  # composite
-    ("atanh", ttl.tensor.atanh),  # composite
-    ("cbrt", ttl.tensor.cbrt),  # composite
-    ("cosh", ttl.tensor.cosh),  # composite
-    ("deg2rad", ttl.tensor.deg2rad),  # composite
-    ("digamma", ttl.tensor.digamma),  # composite
-    ("hardswish", ttl.tensor.hardswish),  # composite
-    ("hardsigmoid", ttl.tensor.hardsigmoid),  # composite
-    ("hardtanh", ttl.tensor.hardtanh),  # composite
-    ("lgamma", ttl.tensor.lgamma),  # composite
-    ("log1p", ttl.tensor.log1p),  # composite
-    ("mish", ttl.tensor.mish),  # composite
-    ("multigammaln", ttl.tensor.multigammaln),  # composite
-    ("rad2deg", ttl.tensor.rad2deg),  # composite
-    ("sigmoid_accurate", ttl.tensor.sigmoid_accurate),  # composite
-    ("sinh", ttl.tensor.sinh),  # composite
-    ("softsign", ttl.tensor.softsign),  # composite
-    ("swish", ttl.tensor.swish),  # composite
-    ("tanhshrink", ttl.tensor.tanhshrink),  # composite
-    ("tril", ttl.tensor.tril),  # composite
-    ("triu", ttl.tensor.triu),  # composite
-]
-
-
-for unary_function_name, ttl_unary_function in TTL_UNARY_FUNCTIONS:
-    register_ttl_unary_function(unary_function_name, ttl_unary_function)
-
-
 def _is_scalar(value):
     return isinstance(value, (int, float))
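End to end, the migrated ops are exercised the same way the updated tests do: build a device tensor, call the ttnn op (now dispatching to the C++-registered composite), and compare against its golden function. A minimal sketch assuming an open `device` handle (shapes and values are illustrative):

    import torch
    import ttnn

    torch_input = torch.rand((64, 128), dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

    output = ttnn.to_torch(ttnn.cbrt(input_tensor))               # routed through ttnn::cbrt registered in C++
    expected = ttnn.get_golden_function(ttnn.cbrt)(torch_input)   # torch reference from the golden-function table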