Skip to content

Commit

Permalink
#8658: Migrate composite unary ops to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
eyonland committed Jun 7, 2024
1 parent 5e55ea4 commit 75a0609
Show file tree
Hide file tree
Showing 6 changed files with 381 additions and 155 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ test_hlk_args_init_gen
tt_build
tt_debug
build
build_debug
/python_env/

/llk_out/
Expand Down
15 changes: 2 additions & 13 deletions tests/ttnn/unit_tests/operations/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def test_rad2deg(device, h, w):
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_cbrt(device, h, w):
torch_cbrt = ttnn.get_golden_function(ttnn.cbrt)
run_math_unary_test(device, h, w, ttnn.cbrt, torch_cbrt, pcc=0.999)


Expand Down Expand Up @@ -270,22 +271,10 @@ def run_math_unary_test_range(device, h, w, ttnn_function, torch_function, pcc=0
assert_with_pcc(torch_output_tensor, output_tensor, pcc)


def torch_cbrt(x, *args, **kwargs):
return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3)


def torch_multigammaln(x, *args, **kwargs):
result = torch.lgamma(x)
result += torch.lgamma(x - 0.5)
result += torch.lgamma(x - 1.0)
result += torch.lgamma(x - 1.5)
result += 3.434189657547
return result


@pytest.mark.parametrize("h", [5])
@pytest.mark.parametrize("w", [5])
def test_multigammaln(device, h, w):
torch_multigammaln = ttnn.get_golden_function(ttnn.multigammaln)
run_math_unary_test_range(device, h, w, ttnn.multigammaln, torch_multigammaln, pcc=0.999)


Expand Down
161 changes: 160 additions & 1 deletion ttnn/cpp/pybind11/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,143 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation
module,
operation,
doc,
ttnn::pybind_arguments_t{py::arg("input_tensor"), py::kw_only(), py::arg("memory_config") = std::nullopt, py::arg("output_tensor") = std::nullopt});
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, memory_config); },
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_scale_and_shift(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, scale, shift, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`scale`
* :attr:`shift`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
float scale,
float shift,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, scale, shift, memory_config);
},
py::arg("input_tensor"),
py::arg("scale")=1.0f/6.0f,
py::arg("shift")=0.5f,
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_low_and_high(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, low, high, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`low`
* :attr:`high`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
float low,
float high,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, low, high, memory_config);
},
py::arg("input_tensor"),
py::arg("low") = -1.0f,
py::arg("high") = 1.0f,
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_diag(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, diag, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`diag`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
int32_t diag,
const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, diag, memory_config); },
py::arg("input_tensor"),
py::arg("diag") = 0,
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
Expand Down Expand Up @@ -229,6 +365,29 @@ void py_module(py::module& module) {

// Other unaries (composite operations)
detail::bind_softplus(module);

detail::bind_unary_operation(module, ttnn::acosh);
detail::bind_unary_operation(module, ttnn::asinh);
detail::bind_unary_operation(module, ttnn::atanh);
detail::bind_unary_operation(module, ttnn::cbrt);
detail::bind_unary_operation(module, ttnn::cosh);
detail::bind_unary_operation(module, ttnn::deg2rad);
detail::bind_unary_operation(module, ttnn::digamma);
detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardswish);
detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardsigmoid);
detail::bind_unary_operation_with_low_and_high(module, ttnn::hardtanh);
detail::bind_unary_operation(module, ttnn::lgamma);
detail::bind_unary_operation(module, ttnn::log1p);
detail::bind_unary_operation(module, ttnn::mish);
detail::bind_unary_operation(module, ttnn::multigammaln);
detail::bind_unary_operation(module, ttnn::rad2deg);
detail::bind_unary_operation(module, ttnn::sigmoid_accurate);
detail::bind_unary_operation(module, ttnn::sinh);
detail::bind_unary_operation(module, ttnn::softsign);
detail::bind_unary_operation(module, ttnn::swish);
detail::bind_unary_operation(module, ttnn::tanhshrink);
detail::bind_unary_operation_with_diag(module, ttnn::tril);
detail::bind_unary_operation_with_diag(module, ttnn::triu);
}

} // namespace unary
Expand Down
44 changes: 42 additions & 2 deletions ttnn/cpp/ttnn/decorators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,13 +365,53 @@ constexpr auto register_operation(const char* name) {
return operation_t<__COUNTER__, concrete_operation_t>{name};
}

#define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward<decltype(args)>(args)...); })

template <typename lambda_t>
constexpr auto register_operation(const char* name, const lambda_t& lambda) {
return lambda_operation_t<__COUNTER__, lambda_t>{name, lambda};
}

// This function is used to transform the arguments of a function before calling it
// where the lambda is applied to the type that matches T.
// Example: https://godbolt.org/z/3P9YedMdj
template <typename T, typename Func, typename Lambda, typename... Args>
constexpr auto transform_args_lambda(Func func, Lambda lambda, Args&&... args) -> decltype(auto) {
auto transformer = [lambda](auto&& arg) -> decltype(auto) {
if constexpr (std::is_same_v<T, std::decay_t<decltype(arg)>>) {
return lambda(std::forward<decltype(arg)>(arg));
} else {
return std::forward<decltype(arg)>(arg);
}
};

return func(transformer(std::forward<Args>(args))...);
}

template <typename T, typename Lambda>
auto transform_first_matching_arg(Lambda lambda) {
static_assert(!std::is_same<T, T>::value, "No matching type found");
}

template <typename T, typename Lambda, typename First, typename... Rest>
auto transform_first_matching_arg(Lambda lambda, First&& first, Rest&&... rest) {
if constexpr (std::is_same_v<T, std::decay_t<First>>) {
return lambda(std::forward<First>(first));
} else {
return transform_first_matching_arg<T>(lambda, std::forward<Rest>(rest)...);
}
}

#define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward<decltype(args)>(args)...); })

#define TO_LAMBDA_WITH_RESHAPE(function) \
([](auto&&... args) { \
const auto original_shape = ttnn::decorators::transform_first_matching_arg<Tensor>( \
[&](auto&& tensor) { return tensor.get_shape(); }, std::forward<decltype(args)>(args)...); \
return ttnn::reshape( \
ttnn::decorators::transform_args_lambda<Tensor>( \
function, [&](auto&& tensor) { return ttnn::unsqueeze_to_4D(tensor); }, args...), \
original_shape); \
})

} // namespace decorators

using ttnn::decorators::register_operation;
Expand Down
Loading

0 comments on commit 75a0609

Please sign in to comment.