#8658: Migrate composite unary ops to C++ #8810

Merged: 1 commit, Jun 7, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ test_hlk_args_init_gen
tt_build
tt_debug
build
build_debug
/python_env/

/llk_out/
15 changes: 2 additions & 13 deletions tests/ttnn/unit_tests/operations/test_math.py
@@ -165,6 +165,7 @@ def test_rad2deg(device, h, w):
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_cbrt(device, h, w):
    torch_cbrt = ttnn.get_golden_function(ttnn.cbrt)
    run_math_unary_test(device, h, w, ttnn.cbrt, torch_cbrt, pcc=0.999)


@@ -270,22 +271,10 @@ def run_math_unary_test_range(device, h, w, ttnn_function, torch_function, pcc=0
    assert_with_pcc(torch_output_tensor, output_tensor, pcc)


def torch_cbrt(x, *args, **kwargs):
    return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3)


def torch_multigammaln(x, *args, **kwargs):
    result = torch.lgamma(x)
    result += torch.lgamma(x - 0.5)
    result += torch.lgamma(x - 1.0)
    result += torch.lgamma(x - 1.5)
    result += 3.434189657547
    return result


@pytest.mark.parametrize("h", [5])
@pytest.mark.parametrize("w", [5])
def test_multigammaln(device, h, w):
    torch_multigammaln = ttnn.get_golden_function(ttnn.multigammaln)
    run_math_unary_test_range(device, h, w, ttnn.multigammaln, torch_multigammaln, pcc=0.999)


161 changes: 160 additions & 1 deletion ttnn/cpp/pybind11/operations/unary.hpp
@@ -47,7 +47,143 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation
module,
operation,
doc,
ttnn::pybind_arguments_t{py::arg("input_tensor"), py::kw_only(), py::arg("memory_config") = std::nullopt, py::arg("output_tensor") = std::nullopt});
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, memory_config); },
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_scale_and_shift(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, scale, shift, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

Applies {0} to :attr:`input_tensor` element-wise.

.. math::
{0}(\\mathrm{{input\\_tensor}}_i)

Args:
* :attr:`input_tensor`
* :attr:`scale`
* :attr:`shift`

Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.

Example::

>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
float scale,
float shift,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, scale, shift, memory_config);
},
py::arg("input_tensor"),
py::arg("scale")=1.0f/6.0f,
py::arg("shift")=0.5f,
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_low_and_high(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, low, high, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

Applies {0} to :attr:`input_tensor` element-wise.

.. math::
{0}(\\mathrm{{input\\_tensor}}_i)

Args:
* :attr:`input_tensor`
* :attr:`low`
* :attr:`high`

Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.

Example::

>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
float low,
float high,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, low, high, memory_config);
},
py::arg("input_tensor"),
py::arg("low") = -1.0f,
py::arg("high") = 1.0f,
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_diag(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, diag, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

Applies {0} to :attr:`input_tensor` element-wise.

.. math::
{0}(\\mathrm{{input\\_tensor}}_i)

Args:
* :attr:`input_tensor`
* :attr:`diag`

Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.

Example::

>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
int32_t diag,
const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, diag, memory_config); },
py::arg("input_tensor"),
py::arg("diag") = 0,
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
@@ -229,6 +365,29 @@ void py_module(py::module& module) {

// Other unaries (composite operations)
detail::bind_softplus(module);

detail::bind_unary_operation(module, ttnn::acosh);
detail::bind_unary_operation(module, ttnn::asinh);
detail::bind_unary_operation(module, ttnn::atanh);
detail::bind_unary_operation(module, ttnn::cbrt);
detail::bind_unary_operation(module, ttnn::cosh);
detail::bind_unary_operation(module, ttnn::deg2rad);
detail::bind_unary_operation(module, ttnn::digamma);
detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardswish);
detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardsigmoid);
detail::bind_unary_operation_with_low_and_high(module, ttnn::hardtanh);
detail::bind_unary_operation(module, ttnn::lgamma);
detail::bind_unary_operation(module, ttnn::log1p);
detail::bind_unary_operation(module, ttnn::mish);
detail::bind_unary_operation(module, ttnn::multigammaln);
detail::bind_unary_operation(module, ttnn::rad2deg);
detail::bind_unary_operation(module, ttnn::sigmoid_accurate);
detail::bind_unary_operation(module, ttnn::sinh);
detail::bind_unary_operation(module, ttnn::softsign);
detail::bind_unary_operation(module, ttnn::swish);
detail::bind_unary_operation(module, ttnn::tanhshrink);
detail::bind_unary_operation_with_diag(module, ttnn::tril);
detail::bind_unary_operation_with_diag(module, ttnn::triu);
}

} // namespace unary
44 changes: 42 additions & 2 deletions ttnn/cpp/ttnn/decorators.hpp
@@ -365,13 +365,53 @@ constexpr auto register_operation(const char* name) {
return operation_t<__COUNTER__, concrete_operation_t>{name};
}

#define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward<decltype(args)>(args)...); })

template <typename lambda_t>
constexpr auto register_operation(const char* name, const lambda_t& lambda) {
return lambda_operation_t<__COUNTER__, lambda_t>{name, lambda};
}

// This function is used to transform the arguments of a function before calling it
// where the lambda is applied to the type that matches T.
// Example: https://godbolt.org/z/3P9YedMdj
template <typename T, typename Func, typename Lambda, typename... Args>
constexpr auto transform_args_lambda(Func func, Lambda lambda, Args&&... args) -> decltype(auto) {
auto transformer = [lambda](auto&& arg) -> decltype(auto) {
if constexpr (std::is_same_v<T, std::decay_t<decltype(arg)>>) {
return lambda(std::forward<decltype(arg)>(arg));
} else {
return std::forward<decltype(arg)>(arg);
}
};

return func(transformer(std::forward<Args>(args))...);
}
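
As an aside, here is a minimal sketch of what transform_args_lambda does with a plain function (illustrative only, not part of this diff; it assumes decorators.hpp is available on the include path): the lambda is applied to every argument whose decayed type matches T, and all other arguments are forwarded unchanged.

#include <iostream>
// #include "ttnn/decorators.hpp"  // assumed location of transform_args_lambda

int sum(int a, float b, int c) { return a + static_cast<int>(b) + c; }

int main() {
    // Every int argument is doubled before the call; the float passes through untouched.
    auto result = ttnn::decorators::transform_args_lambda<int>(
        sum, [](int x) { return x * 2; }, 1, 2.5f, 3);
    std::cout << result << std::endl;  // prints 10: (1 * 2) + 2 + (3 * 2)
    return 0;
}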

template <typename T, typename Lambda>
auto transform_first_matching_arg(Lambda lambda) {
static_assert(!std::is_same<T, T>::value, "No matching type found");
}

template <typename T, typename Lambda, typename First, typename... Rest>
auto transform_first_matching_arg(Lambda lambda, First&& first, Rest&&... rest) {
if constexpr (std::is_same_v<T, std::decay_t<First>>) {
return lambda(std::forward<First>(first));
} else {
return transform_first_matching_arg<T>(lambda, std::forward<Rest>(rest)...);
}
}

#define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward<decltype(args)>(args)...); })

#define TO_LAMBDA_WITH_RESHAPE(function) \
([](auto&&... args) { \
const auto original_shape = ttnn::decorators::transform_first_matching_arg<Tensor>( \
[&](auto&& tensor) { return tensor.get_shape(); }, std::forward<decltype(args)>(args)...); \
return ttnn::reshape( \
ttnn::decorators::transform_args_lambda<Tensor>( \
function, [&](auto&& tensor) { return ttnn::unsqueeze_to_4D(tensor); }, args...), \
original_shape); \
})
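
TO_LAMBDA_WITH_RESHAPE turns a composite function into a rank-agnostic one: it captures the shape of the first Tensor argument, unsqueezes every Tensor argument to 4D before the call, and reshapes the result back to the original shape. A hypothetical registration sketch follows (the actual call sites are not shown in this diff, and composite_cbrt_impl is a stand-in name, not a real symbol):

// Hypothetical sketch: composite_cbrt_impl stands in for the concrete composite
// unary function being registered. Callers may pass tensors of any rank; the
// wrapper upcasts them to 4D for the kernel and restores the original shape afterwards.
const auto cbrt = ttnn::decorators::register_operation(
    "ttnn::cbrt", TO_LAMBDA_WITH_RESHAPE(composite_cbrt_impl));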

} // namespace decorators

using ttnn::decorators::register_operation;