From db5c8e232d5c7b7f660debc4269df19effa1531b Mon Sep 17 00:00:00 2001 From: Eyon Date: Fri, 24 May 2024 15:30:56 +0000 Subject: [PATCH] #8658: Migrate composite unary ops to C++ --- tests/ttnn/unit_tests/operations/test_math.py | 14 +- ttnn/cpp/pybind11/operations/unary.hpp | 161 ++++++++- ttnn/cpp/ttnn/operations/unary.hpp | 118 +++++++ ttnn/ttnn/operations/unary.py | 312 ++++++++++-------- 4 files changed, 454 insertions(+), 151 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_math.py b/tests/ttnn/unit_tests/operations/test_math.py index c1cf8198b43..fe89298d072 100644 --- a/tests/ttnn/unit_tests/operations/test_math.py +++ b/tests/ttnn/unit_tests/operations/test_math.py @@ -108,6 +108,7 @@ def test_rad2deg(device, h, w): @pytest.mark.parametrize("h", [64]) @pytest.mark.parametrize("w", [128]) def test_cbrt(device, h, w): + torch_cbrt = ttnn.get_golden_function(ttnn.cbrt) run_math_unary_test(device, h, w, ttnn.cbrt, torch_cbrt, pcc=0.999) @@ -213,19 +214,6 @@ def run_math_unary_test_range(device, h, w, ttnn_function, torch_function, pcc=0 assert_with_pcc(torch_output_tensor, output_tensor, pcc) -def torch_cbrt(x, *args, **kwargs): - return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3) - - -def torch_multigammaln(x, *args, **kwargs): - result = torch.lgamma(x) - result += torch.lgamma(x - 0.5) - result += torch.lgamma(x - 1.0) - result += torch.lgamma(x - 1.5) - result += 3.434189657547 - return result - - @pytest.mark.parametrize("h", [5]) @pytest.mark.parametrize("w", [5]) def test_multigammaln(device, h, w): diff --git a/ttnn/cpp/pybind11/operations/unary.hpp b/ttnn/cpp/pybind11/operations/unary.hpp index bb169e3ea6f..46810de016c 100644 --- a/ttnn/cpp/pybind11/operations/unary.hpp +++ b/ttnn/cpp/pybind11/operations/unary.hpp @@ -47,7 +47,143 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation module, operation, doc, - ttnn::pybind_arguments_t{py::arg("input_tensor"), py::kw_only(), py::arg("memory_config") = std::nullopt}); + ttnn::pybind_overload_t{ + [](const unary_operation_t& self, + const Tensor& input_tensor, + const std::optional& memory_config) { return self(input_tensor, memory_config); }, + py::arg("input_tensor"), + py::kw_only(), + py::arg("memory_config") = std::nullopt}); +} + +template +void bind_unary_operation_with_scale_and_shift(py::module& module, const unary_operation_t& operation) { + auto doc = fmt::format( + R"doc({0}(input_tensor: ttnn.Tensor, scale, shift, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + + Applies {0} to :attr:`input_tensor` element-wise. + + .. math:: + {0}(\\mathrm{{input\\_tensor}}_i) + + Args: + * :attr:`input_tensor` + * :attr:`scale` + * :attr:`shift` + + Keyword Args: + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. 
+ + Example:: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = {1}(tensor) + )doc", + operation.name(), + operation.python_fully_qualified_name()); + + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const unary_operation_t& self, + const Tensor& input_tensor, + float scale, + float shift, + const std::optional& memory_config) { + return self(input_tensor, scale, shift, memory_config); + }, + py::arg("input_tensor"), + py::arg("scale"), + py::arg("shift"), + py::kw_only(), + py::arg("memory_config") = std::nullopt}); +} + +template +void bind_unary_operation_with_low_and_high(py::module& module, const unary_operation_t& operation) { + auto doc = fmt::format( + R"doc({0}(input_tensor: ttnn.Tensor, low, high, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + + Applies {0} to :attr:`input_tensor` element-wise. + + .. math:: + {0}(\\mathrm{{input\\_tensor}}_i) + + Args: + * :attr:`input_tensor` + * :attr:`low` + * :attr:`high` + + Keyword Args: + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. + + Example:: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = {1}(tensor) + )doc", + operation.name(), + operation.python_fully_qualified_name()); + + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const unary_operation_t& self, + const Tensor& input_tensor, + float low, + float high, + const std::optional& memory_config) { + return self(input_tensor, low, high, memory_config); + }, + py::arg("input_tensor"), + py::arg("low"), + py::arg("high"), + py::kw_only(), + py::arg("memory_config") = std::nullopt}); +} + +template +void bind_unary_operation_with_diag(py::module& module, const unary_operation_t& operation) { + auto doc = fmt::format( + R"doc({0}(input_tensor: ttnn.Tensor, diag, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + + Applies {0} to :attr:`input_tensor` element-wise. + + .. math:: + {0}(\\mathrm{{input\\_tensor}}_i) + + Args: + * :attr:`input_tensor` + * :attr:`diag` + + Keyword Args: + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. 
+ + Example:: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = {1}(tensor) + )doc", + operation.name(), + operation.python_fully_qualified_name()); + + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const unary_operation_t& self, + const Tensor& input_tensor, + int32_t diag, + const std::optional& memory_config) { return self(input_tensor, diag, memory_config); }, + py::arg("input_tensor"), + py::arg("diag"), + py::kw_only(), + py::arg("memory_config") = std::nullopt}); } template @@ -226,6 +362,29 @@ void py_module(py::module& module) { // Other unaries (composite operations) detail::bind_softplus(module); + + detail::bind_unary_operation(module, ttnn::acosh); + detail::bind_unary_operation(module, ttnn::asinh); + detail::bind_unary_operation(module, ttnn::atanh); + detail::bind_unary_operation(module, ttnn::cbrt); + detail::bind_unary_operation(module, ttnn::cosh); + detail::bind_unary_operation(module, ttnn::deg2rad); + detail::bind_unary_operation(module, ttnn::digamma); + detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardswish); + detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardsigmoid); + detail::bind_unary_operation_with_low_and_high(module, ttnn::hardtanh); + detail::bind_unary_operation(module, ttnn::lgamma); + detail::bind_unary_operation(module, ttnn::log1p); + detail::bind_unary_operation(module, ttnn::mish); + detail::bind_unary_operation(module, ttnn::multigammaln); + detail::bind_unary_operation(module, ttnn::rad2deg); + detail::bind_unary_operation(module, ttnn::sigmoid_accurate); + detail::bind_unary_operation(module, ttnn::sinh); + detail::bind_unary_operation(module, ttnn::softsign); + detail::bind_unary_operation(module, ttnn::swish); + detail::bind_unary_operation(module, ttnn::tanhshrink); + detail::bind_unary_operation_with_diag(module, ttnn::tril); + detail::bind_unary_operation_with_diag(module, ttnn::triu); } } // namespace unary diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp index a5a21b4e539..99af645557b 100644 --- a/ttnn/cpp/ttnn/operations/unary.hpp +++ b/ttnn/cpp/ttnn/operations/unary.hpp @@ -127,6 +127,97 @@ struct Softplus { return result; } }; + +// TODO: update these composite unary ops pending decision on TensorAsync implementation. 
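+// Each wrapper below follows the same pattern: it forwards to the existing tt::tt_metal composite
+// implementation and, when the caller does not supply a memory config, falls back to the input
+// tensor's own memory config via memory_config.value_or(input_tensor.memory_config()).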
+ +Tensor acosh(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::acosh(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} + +Tensor asinh(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::asinh(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} + +Tensor atanh(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::atanh(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} + +Tensor cbrt(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::cbrt(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor cosh(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::cosh(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor deg2rad(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::deg2rad(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor digamma(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::digamma(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor hardswish( + const Tensor& input_tensor, + float scale, + float shift, + const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::hardswish(input_tensor, scale, shift, memory_config.value_or(input_tensor.memory_config())); +} +Tensor hardsigmoid( + const Tensor& input_tensor, + float scale, + float shift, + const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::hardsigmoid(input_tensor, scale, shift, memory_config.value_or(input_tensor.memory_config())); +} +Tensor hardtanh( + const Tensor& input_tensor, + float low /* = -1.0f */, + float high /* = +1.0f */, + const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::hardtanh(input_tensor, low, high, memory_config.value_or(input_tensor.memory_config())); +} +Tensor lgamma(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::lgamma(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor log1p(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::log1p(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor mish(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::mish(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor multigammaln(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::multigammaln(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor rad2deg(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::rad2deg(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor sigmoid_accurate(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::sigmoid_accurate(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor sinh(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::sinh(input_tensor, 
memory_config.value_or(input_tensor.memory_config())); +} +Tensor softsign(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::softsign(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor swish(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::swish(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor tanhshrink(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::tanhshrink(input_tensor, memory_config.value_or(input_tensor.memory_config())); +} +Tensor tril( + const Tensor& input_tensor, + int32_t diag /* = -1 */, + const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::tril(input_tensor, diag, memory_config.value_or(input_tensor.memory_config())); +} +Tensor triu( + const Tensor& input_tensor, + int32_t diag /* = -1 */, + const std::optional& memory_config = std::nullopt) { + return tt::tt_metal::triu(input_tensor, diag, memory_config.value_or(input_tensor.memory_config())); +} + } // namespace unary } // namespace operations @@ -202,4 +293,31 @@ auto prelu = leaky_relu; // Alias for leaky_relu. TODO(#8544): implement PReLU // Other unaries (composite operations) constexpr auto softplus = ttnn::register_operation("ttnn::softplus"); +constexpr auto acosh = ttnn::register_operation("ttnn::acosh", TO_LAMBDA(ttnn::operations::unary::acosh)); +constexpr auto asinh = ttnn::register_operation("ttnn::asinh", TO_LAMBDA(ttnn::operations::unary::asinh)); +constexpr auto atanh = ttnn::register_operation("ttnn::atanh", TO_LAMBDA(ttnn::operations::unary::atanh)); +constexpr auto cbrt = ttnn::register_operation("ttnn::cbrt", TO_LAMBDA(ttnn::operations::unary::cbrt)); +constexpr auto cosh = ttnn::register_operation("ttnn::cosh", TO_LAMBDA(ttnn::operations::unary::cosh)); +constexpr auto deg2rad = ttnn::register_operation("ttnn::deg2rad", TO_LAMBDA(ttnn::operations::unary::deg2rad)); +constexpr auto digamma = ttnn::register_operation("ttnn::digamma", TO_LAMBDA(ttnn::operations::unary::digamma)); +constexpr auto hardswish = ttnn::register_operation("ttnn::hardswish", TO_LAMBDA(ttnn::operations::unary::hardswish)); +constexpr auto hardsigmoid = + ttnn::register_operation("ttnn::hardsigmoid", TO_LAMBDA(ttnn::operations::unary::hardsigmoid)); +constexpr auto hardtanh = ttnn::register_operation("ttnn::hardtanh", TO_LAMBDA(ttnn::operations::unary::hardtanh)); +constexpr auto lgamma = ttnn::register_operation("ttnn::lgamma", TO_LAMBDA(ttnn::operations::unary::lgamma)); +constexpr auto log1p = ttnn::register_operation("ttnn::log1p", TO_LAMBDA(ttnn::operations::unary::log1p)); +constexpr auto mish = ttnn::register_operation("ttnn::mish", TO_LAMBDA(ttnn::operations::unary::mish)); +constexpr auto multigammaln = + ttnn::register_operation("ttnn::multigammaln", TO_LAMBDA(ttnn::operations::unary::multigammaln)); +constexpr auto rad2deg = ttnn::register_operation("ttnn::rad2deg", TO_LAMBDA(ttnn::operations::unary::rad2deg)); +constexpr auto sigmoid_accurate = + ttnn::register_operation("ttnn::sigmoid_accurate", TO_LAMBDA(ttnn::operations::unary::sigmoid_accurate)); +constexpr auto sinh = ttnn::register_operation("ttnn::sinh", TO_LAMBDA(ttnn::operations::unary::sinh)); +constexpr auto softsign = ttnn::register_operation("ttnn::softsign", TO_LAMBDA(ttnn::operations::unary::softsign)); +constexpr auto swish = ttnn::register_operation("ttnn::swish", TO_LAMBDA(ttnn::operations::unary::swish)); 
+constexpr auto tanhshrink = + ttnn::register_operation("ttnn::tanhshrink", TO_LAMBDA(ttnn::operations::unary::tanhshrink)); +constexpr auto tril = ttnn::register_operation("ttnn::tril", TO_LAMBDA(ttnn::operations::unary::tril)); +constexpr auto triu = ttnn::register_operation("ttnn::triu", TO_LAMBDA(ttnn::operations::unary::triu)); + } // namespace ttnn diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py index 1ea80f6e3a7..3910d294439 100644 --- a/ttnn/ttnn/operations/unary.py +++ b/ttnn/ttnn/operations/unary.py @@ -20,6 +20,17 @@ def torch_heaviside(x, *args, **kwargs): result = torch.heaviside(x, torch.tensor(value, dtype=x.dtype)) return result + def torch_cbrt(x, *args, **kwargs): + return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3) + + def torch_multigammaln(x, *args, **kwargs): + result = torch.lgamma(x) + result += torch.lgamma(x - 0.5) + result += torch.lgamma(x - 1.0) + result += torch.lgamma(x - 1.5) + result += 3.434189657547 + return result + def torch_prelu(x, *args, **kwargs): weight = kwargs.pop("scalar") result = torch.nn.functional.prelu(x, torch.tensor(weight, dtype=x.dtype)) @@ -77,6 +88,28 @@ def torch_prelu(x, *args, **kwargs): # "prelu": torch_prelu, # Alias for leaky_relu. TODO(#8544): implement PReLU properly # Other unaries (composite operations) "softplus": torch.nn.functional.softplus, + "acosh": torch.acosh, + "asinh": torch.asinh, + "atanh": torch.atanh, + "cbrt": torch_cbrt, + "cosh": torch.cosh, + "deg2rad": torch.deg2rad, + "digamma": torch.digamma, + "hardswish": torch.nn.functional.hardswish, + "hardsigmoid": torch.nn.functional.hardsigmoid, + "hardtanh": torch.nn.functional.hardtanh, + "lgamma": torch.lgamma, + "log1p": torch.log1p, + "mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)), + "multigammaln": torch_multigammaln, + "rad2deg": torch.rad2deg, + "sigmoid_accurate": torch.sigmoid, + "sinh": torch.sinh, + "softsign": torch.nn.functional.softsign, + "swish": torch.nn.functional.hardswish, + "tanhshrink": ttl.tensor.tanhshrink, + "tril": torch.tril, + "triu": torch.triu, } golden_keys = set(name_to_golden_function.keys()) @@ -144,6 +177,28 @@ def _golden_function(input_tensor: ttnn.Tensor, **_): # Other unaries (composite operations) ttnn._ttnn.operations.unary.log_sigmoid, ttnn._ttnn.operations.unary.softplus, + ttnn._ttnn.operations.unary.acosh, + ttnn._ttnn.operations.unary.asinh, + ttnn._ttnn.operations.unary.atanh, + ttnn._ttnn.operations.unary.cbrt, + ttnn._ttnn.operations.unary.cosh, + ttnn._ttnn.operations.unary.deg2rad, + ttnn._ttnn.operations.unary.digamma, + ttnn._ttnn.operations.unary.hardswish, + ttnn._ttnn.operations.unary.hardsigmoid, + ttnn._ttnn.operations.unary.hardtanh, + ttnn._ttnn.operations.unary.lgamma, + ttnn._ttnn.operations.unary.log1p, + ttnn._ttnn.operations.unary.mish, + ttnn._ttnn.operations.unary.multigammaln, + ttnn._ttnn.operations.unary.rad2deg, + ttnn._ttnn.operations.unary.sigmoid_accurate, + ttnn._ttnn.operations.unary.sinh, + ttnn._ttnn.operations.unary.softsign, + ttnn._ttnn.operations.unary.swish, + ttnn._ttnn.operations.unary.tanhshrink, + ttnn._ttnn.operations.unary.tril, + ttnn._ttnn.operations.unary.triu, ] for unary_function in TTNN_ELTWISE_UNARY_CPP_FUNCTIONS: register_ttnn_cpp_unary_function(unary_function) @@ -154,143 +209,126 @@ def prelu(*args, **kwargs): # Alias for leaky_relu. 
TODO(#8544): implement PReL return leaky_relu(*args, **kwargs) -def torch_cbrt(x, *args, **kwargs): - import torch - - return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3) - - -def torch_multigammaln(x, *args, **kwargs): - import torch - - result = torch.lgamma(x) - result += torch.lgamma(x - 0.5) - result += torch.lgamma(x - 1.0) - result += torch.lgamma(x - 1.5) - result += 3.434189657547 - return result - - -def register_ttl_unary_function(name, ttl_unary_function): - import torch - - name_to_golden_function = { - "acosh": torch.acosh, - "asinh": torch.asinh, - "atanh": torch.atanh, - "cbrt": torch_cbrt, - "cosh": torch.cosh, - "deg2rad": torch.deg2rad, - "digamma": torch.digamma, - "hardswish": torch.nn.functional.hardswish, - "hardsigmoid": torch.nn.functional.hardsigmoid, - "hardtanh": torch.nn.functional.hardtanh, - "lgamma": torch.lgamma, - "log1p": torch.log1p, - "mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)), - "multigammaln": torch_multigammaln, - "rad2deg": torch.rad2deg, - "sigmoid_accurate": torch.sigmoid, - "sinh": torch.sinh, - "softsign": torch.nn.functional.softsign, - "swish": torch.nn.functional.hardswish, - "tanhshrink": ttl.tensor.tanhshrink, - "tril": torch.tril, - "triu": torch.triu, - } - - golden_keys = set(name_to_golden_function.keys()) - function_names = {name for name, _ in TTL_UNARY_FUNCTIONS} - if golden_keys != function_names: - raise ImportError(f"Missing or extra golden functions:\n{golden_keys}\nshould be equal to\n{function_names}") - - def _golden_function(input_tensor: ttnn.Tensor, **_): - torch_function = name_to_golden_function[name] - return torch_function(input_tensor) - - def _unary_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): - ttnn.validate_input_tensor( - operation_name, - input_tensor, - ranks=(2, 3, 4), - dtypes=(ttnn.bfloat16, ttnn.bfloat8_b), - layouts=(ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT), - can_be_on_device=True, - can_be_on_cpu=False, - ) - - @ttnn.register_operation( - name=f"ttnn.{name}", - validate_input_tensors=_unary_validate_input_tensors, - golden_function=_golden_function, - ) - def unary_function( - input_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG - ) -> ttnn.Tensor: - original_shape = input_tensor.shape - input_tensor = ttnn.unsqueeze_to_4D(input_tensor) - - if not isinstance(input_tensor, ttnn.Tensor): - raise TypeError("Expected first argument to be a ttnn.Tensor") - - if not ttnn.is_tensor_storage_on_device(input_tensor): - raise RuntimeError("input_tensor must be on device!") - - output_tensor = ttl_unary_function(input_tensor, output_mem_config=memory_config) - output_tensor = ttnn.reshape(output_tensor, original_shape) - return output_tensor - - if isinstance(unary_function, ttnn.decorators.Operation): - unary_function.__name__ = f"ttnn.{name}" - unary_function.decorated_function.__doc__ = f"""{name}(input_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor - - Applies {name} to :attr:`input_tensor` element-wise. - - .. 
math:: - {name.replace('_',' ')}(\\mathrm{{input\\_tensor}}_i) - - Args: - * :attr:`input_tensor` - - Example:: - - >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) - >>> output = ttnn.{name}(tensor) - - {unary_function.__doc__} - - """ - setattr(THIS_MODULE, name, unary_function) - - -TTL_UNARY_FUNCTIONS = [ - ("acosh", ttl.tensor.acosh), # composite - ("asinh", ttl.tensor.asinh), # composite - ("atanh", ttl.tensor.atanh), # composite - ("cbrt", ttl.tensor.cbrt), # composite - ("cosh", ttl.tensor.cosh), # composite - ("deg2rad", ttl.tensor.deg2rad), # composite - ("digamma", ttl.tensor.digamma), # composite - ("hardswish", ttl.tensor.hardswish), # composite - ("hardsigmoid", ttl.tensor.hardsigmoid), # composite - ("hardtanh", ttl.tensor.hardtanh), # composite - ("lgamma", ttl.tensor.lgamma), # composite - ("log1p", ttl.tensor.log1p), # composite - ("mish", ttl.tensor.mish), # composite - ("multigammaln", ttl.tensor.multigammaln), # composite - ("rad2deg", ttl.tensor.rad2deg), # composite - ("sigmoid_accurate", ttl.tensor.sigmoid_accurate), # composite - ("sinh", ttl.tensor.sinh), # composite - ("softsign", ttl.tensor.softsign), # composite - ("swish", ttl.tensor.swish), # composite - ("tanhshrink", ttl.tensor.tanhshrink), # composite - ("tril", ttl.tensor.tril), # composite - ("triu", ttl.tensor.triu), # composite -] - - -for unary_function_name, ttl_unary_function in TTL_UNARY_FUNCTIONS: - register_ttl_unary_function(unary_function_name, ttl_unary_function) +# def register_ttl_unary_function(name, ttl_unary_function): +# import torch + +# name_to_golden_function = { +# "acosh": torch.acosh, +# "asinh": torch.asinh, +# "atanh": torch.atanh, +# "cbrt": torch_cbrt, +# "cosh": torch.cosh, +# "deg2rad": torch.deg2rad, +# "digamma": torch.digamma, +# "hardswish": torch.nn.functional.hardswish, +# "hardsigmoid": torch.nn.functional.hardsigmoid, +# "hardtanh": torch.nn.functional.hardtanh, +# "lgamma": torch.lgamma, +# "log1p": torch.log1p, +# "mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)), +# "multigammaln": torch_multigammaln, +# "rad2deg": torch.rad2deg, +# "sigmoid_accurate": torch.sigmoid, +# "sinh": torch.sinh, +# "softsign": torch.nn.functional.softsign, +# "swish": torch.nn.functional.hardswish, +# "tanhshrink": ttl.tensor.tanhshrink, +# "tril": torch.tril, +# "triu": torch.triu, +# } + +# golden_keys = set(name_to_golden_function.keys()) +# function_names = {name for name, _ in TTL_UNARY_FUNCTIONS} +# if golden_keys != function_names: +# raise ImportError(f"Missing or extra golden functions:\n{golden_keys}\nshould be equal to\n{function_names}") + +# def _golden_function(input_tensor: ttnn.Tensor, **_): +# torch_function = name_to_golden_function[name] +# return torch_function(input_tensor) + +# def _unary_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): +# ttnn.validate_input_tensor( +# operation_name, +# input_tensor, +# ranks=(2, 3, 4), +# dtypes=(ttnn.bfloat16, ttnn.bfloat8_b), +# layouts=(ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT), +# can_be_on_device=True, +# can_be_on_cpu=False, +# ) + +# @ttnn.register_operation( +# name=f"ttnn.{name}", +# validate_input_tensors=_unary_validate_input_tensors, +# golden_function=_golden_function, +# ) +# def unary_function( +# input_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG +# ) -> ttnn.Tensor: +# original_shape = input_tensor.shape +# input_tensor = ttnn.unsqueeze_to_4D(input_tensor) + +# if not isinstance(input_tensor, 
ttnn.Tensor): +# raise TypeError("Expected first argument to be a ttnn.Tensor") + +# if not ttnn.is_tensor_storage_on_device(input_tensor): +# raise RuntimeError("input_tensor must be on device!") + +# output_tensor = ttl_unary_function(input_tensor, output_mem_config=memory_config) +# output_tensor = ttnn.reshape(output_tensor, original_shape) +# return output_tensor + +# if isinstance(unary_function, ttnn.decorators.Operation): +# unary_function.__name__ = f"ttnn.{name}" +# unary_function.decorated_function.__doc__ = f"""{name}(input_tensor: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor + +# Applies {name} to :attr:`input_tensor` element-wise. + +# .. math:: +# {name.replace('_',' ')}(\\mathrm{{input\\_tensor}}_i) + +# Args: +# * :attr:`input_tensor` + +# Example:: + +# >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) +# >>> output = ttnn.{name}(tensor) + +# {unary_function.__doc__} + +# """ +# setattr(THIS_MODULE, name, unary_function) + + +# TTL_UNARY_FUNCTIONS = [ +# ("acosh", ttl.tensor.acosh), # composite +# ("asinh", ttl.tensor.asinh), # composite +# ("atanh", ttl.tensor.atanh), # composite +# ("cbrt", ttl.tensor.cbrt), # composite +# ("cosh", ttl.tensor.cosh), # composite +# ("deg2rad", ttl.tensor.deg2rad), # composite +# ("digamma", ttl.tensor.digamma), # composite +# ("hardswish", ttl.tensor.hardswish), # composite +# ("hardsigmoid", ttl.tensor.hardsigmoid), # composite +# ("hardtanh", ttl.tensor.hardtanh), # composite +# ("lgamma", ttl.tensor.lgamma), # composite +# ("log1p", ttl.tensor.log1p), # composite +# ("mish", ttl.tensor.mish), # composite +# ("multigammaln", ttl.tensor.multigammaln), # composite +# ("rad2deg", ttl.tensor.rad2deg), # composite +# ("sigmoid_accurate", ttl.tensor.sigmoid_accurate), # composite +# ("sinh", ttl.tensor.sinh), # composite +# ("softsign", ttl.tensor.softsign), # composite +# ("swish", ttl.tensor.swish), # composite +# ("tanhshrink", ttl.tensor.tanhshrink), # composite +# ("tril", ttl.tensor.tril), # composite +# ("triu", ttl.tensor.triu), # composite +# ] + + +# for unary_function_name, ttl_unary_function in TTL_UNARY_FUNCTIONS: +# register_ttl_unary_function(unary_function_name, ttl_unary_function) def _is_scalar(value):