#8658: Migrate composite unary ops to C++
eyonland committed May 24, 2024
1 parent 2d73835 commit db5c8e2
Showing 4 changed files with 454 additions and 151 deletions.
14 changes: 1 addition & 13 deletions tests/ttnn/unit_tests/operations/test_math.py
@@ -108,6 +108,7 @@ def test_rad2deg(device, h, w):
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_cbrt(device, h, w):
torch_cbrt = ttnn.get_golden_function(ttnn.cbrt)
run_math_unary_test(device, h, w, ttnn.cbrt, torch_cbrt, pcc=0.999)
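# ttnn.get_golden_function returns the torch reference implementation that was
# attached to the op at registration time, which is why the local torch_cbrt
# and torch_multigammaln helpers are deleted further down. A minimal sketch of
# the pattern (the attach call normally lives in ttnn's own operation modules,
# not in tests; _golden_cbrt is illustrative):


def _golden_cbrt(x, *args, **kwargs):
    return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3)


ttnn.attach_golden_function(ttnn.cbrt, golden_function=_golden_cbrt)
assert ttnn.get_golden_function(ttnn.cbrt) is not None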


@@ -213,19 +214,6 @@ def run_math_unary_test_range(device, h, w, ttnn_function, torch_function, pcc=0
assert_with_pcc(torch_output_tensor, output_tensor, pcc)


def torch_cbrt(x, *args, **kwargs):
return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3)


def torch_multigammaln(x, *args, **kwargs):
result = torch.lgamma(x)
result += torch.lgamma(x - 0.5)
result += torch.lgamma(x - 1.0)
result += torch.lgamma(x - 1.5)
result += 3.434189657547
return result
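# The helper above is the p = 4 case of the multivariate log-gamma function,
# ln Gamma_4(x) = 3*ln(pi) + lgamma(x) + lgamma(x - 0.5) + lgamma(x - 1.0) + lgamma(x - 1.5),
# where 3*ln(pi) ~= 3.434189657547 is the additive constant. torch exposes the
# same quantity directly, so the registered golden function can delegate to it
# (a sketch; values chosen inside the x > 1.5 domain required for p = 4):
x = torch.tensor([2.5, 3.0, 4.0, 5.5])
manual = (
    torch.lgamma(x)
    + torch.lgamma(x - 0.5)
    + torch.lgamma(x - 1.0)
    + torch.lgamma(x - 1.5)
    + 3.434189657547
)
assert torch.allclose(manual, torch.special.multigammaln(x, 4))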


@pytest.mark.parametrize("h", [5])
@pytest.mark.parametrize("w", [5])
def test_multigammaln(device, h, w):
161 changes: 160 additions & 1 deletion ttnn/cpp/pybind11/operations/unary.hpp
@@ -47,7 +47,143 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation
module,
operation,
doc,
ttnn::pybind_arguments_t{py::arg("input_tensor"), py::kw_only(), py::arg("memory_config") = std::nullopt});
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, memory_config); },
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_scale_and_shift(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, scale, shift, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`scale`
* :attr:`shift`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
float scale,
float shift,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, scale, shift, memory_config);
},
py::arg("input_tensor"),
py::arg("scale"),
py::arg("shift"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_low_and_high(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, low, high, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`low`
* :attr:`high`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
float low,
float high,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, low, high, memory_config);
},
py::arg("input_tensor"),
py::arg("low"),
py::arg("high"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_diag(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, diag, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`diag`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
int32_t diag,
const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, diag, memory_config); },
py::arg("input_tensor"),
py::arg("diag"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
@@ -226,6 +362,29 @@ void py_module(py::module& module) {

// Other unaries (composite operations)
detail::bind_softplus(module);

detail::bind_unary_operation(module, ttnn::acosh);
detail::bind_unary_operation(module, ttnn::asinh);
detail::bind_unary_operation(module, ttnn::atanh);
detail::bind_unary_operation(module, ttnn::cbrt);
detail::bind_unary_operation(module, ttnn::cosh);
detail::bind_unary_operation(module, ttnn::deg2rad);
detail::bind_unary_operation(module, ttnn::digamma);
detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardswish);
detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardsigmoid);
detail::bind_unary_operation_with_low_and_high(module, ttnn::hardtanh);
detail::bind_unary_operation(module, ttnn::lgamma);
detail::bind_unary_operation(module, ttnn::log1p);
detail::bind_unary_operation(module, ttnn::mish);
detail::bind_unary_operation(module, ttnn::multigammaln);
detail::bind_unary_operation(module, ttnn::rad2deg);
detail::bind_unary_operation(module, ttnn::sigmoid_accurate);
detail::bind_unary_operation(module, ttnn::sinh);
detail::bind_unary_operation(module, ttnn::softsign);
detail::bind_unary_operation(module, ttnn::swish);
detail::bind_unary_operation(module, ttnn::tanhshrink);
detail::bind_unary_operation_with_diag(module, ttnn::tril);
detail::bind_unary_operation_with_diag(module, ttnn::triu);
}

} // namespace unary
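The bindings registered in py_module above keep the Python-facing call shapes unchanged: plain unaries take only the tensor plus a keyword-only memory_config, while the parameterized binders expose their extra arguments positionally. A usage sketch against the new bindings (device handling and parameter values are illustrative; scale = 1/6 and shift = 0.5 follow the conventional hard-sigmoid parameterization, and the tril/triu diag semantics are assumed to match torch's):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
tensor = ttnn.from_torch(torch.rand((32, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

out = ttnn.log1p(tensor)                                          # plain unary
out_l1 = ttnn.log1p(tensor, memory_config=ttnn.L1_MEMORY_CONFIG)  # keyword-only memory_config
scaled = ttnn.hardsigmoid(tensor, 1.0 / 6, 0.5)                   # scale, shift
clipped = ttnn.hardtanh(tensor, -1.0, 1.0)                        # low, high
lower = ttnn.tril(tensor, 0)                                      # diag = 0 keeps the main diagonal

ttnn.close_device(device)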
118 changes: 118 additions & 0 deletions ttnn/cpp/ttnn/operations/unary.hpp
@@ -127,6 +127,97 @@ struct Softplus {
return result;
}
};

// TODO: update these composite unary ops pending decision on TensorAsync implementation.

Tensor acosh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::acosh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}

Tensor asinh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::asinh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}

Tensor atanh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::atanh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}

Tensor cbrt(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::cbrt(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor cosh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::cosh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor deg2rad(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::deg2rad(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor digamma(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::digamma(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor hardswish(
const Tensor& input_tensor,
float scale,
float shift,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::hardswish(input_tensor, scale, shift, memory_config.value_or(input_tensor.memory_config()));
}
Tensor hardsigmoid(
const Tensor& input_tensor,
float scale,
float shift,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::hardsigmoid(input_tensor, scale, shift, memory_config.value_or(input_tensor.memory_config()));
}
Tensor hardtanh(
const Tensor& input_tensor,
float low /* = -1.0f */,
float high /* = +1.0f */,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::hardtanh(input_tensor, low, high, memory_config.value_or(input_tensor.memory_config()));
}
Tensor lgamma(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::lgamma(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor log1p(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::log1p(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor mish(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::mish(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor multigammaln(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::multigammaln(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor rad2deg(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::rad2deg(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor sigmoid_accurate(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::sigmoid_accurate(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor sinh(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::sinh(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor softsign(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::softsign(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor swish(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::swish(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor tanhshrink(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::tanhshrink(input_tensor, memory_config.value_or(input_tensor.memory_config()));
}
Tensor tril(
const Tensor& input_tensor,
int32_t diag /* = -1 */,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::tril(input_tensor, diag, memory_config.value_or(input_tensor.memory_config()));
}
Tensor triu(
const Tensor& input_tensor,
int32_t diag /* = -1 */,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return tt::tt_metal::triu(input_tensor, diag, memory_config.value_or(input_tensor.memory_config()));
}

} // namespace unary
} // namespace operations
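Note the memory-config convention in the wrappers above: memory_config.value_or(input_tensor.memory_config()) means an omitted argument reuses the input tensor's own memory configuration rather than falling back to a fixed global default. A sketch of the observable behavior from Python (assumes MemoryConfig compares by value):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
tensor = ttnn.from_torch(
    torch.rand((32, 32), dtype=torch.bfloat16),
    layout=ttnn.TILE_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)

output = ttnn.softsign(tensor)  # no memory_config given: inherits L1 from the input
assert output.memory_config() == tensor.memory_config()

ttnn.close_device(device)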

@@ -202,4 +293,31 @@ auto prelu = leaky_relu; // Alias for leaky_relu. TODO(#8544): implement PReLU
// Other unaries (composite operations)
constexpr auto softplus = ttnn::register_operation<ttnn::operations::unary::Softplus>("ttnn::softplus");

constexpr auto acosh = ttnn::register_operation("ttnn::acosh", TO_LAMBDA(ttnn::operations::unary::acosh));
constexpr auto asinh = ttnn::register_operation("ttnn::asinh", TO_LAMBDA(ttnn::operations::unary::asinh));
constexpr auto atanh = ttnn::register_operation("ttnn::atanh", TO_LAMBDA(ttnn::operations::unary::atanh));
constexpr auto cbrt = ttnn::register_operation("ttnn::cbrt", TO_LAMBDA(ttnn::operations::unary::cbrt));
constexpr auto cosh = ttnn::register_operation("ttnn::cosh", TO_LAMBDA(ttnn::operations::unary::cosh));
constexpr auto deg2rad = ttnn::register_operation("ttnn::deg2rad", TO_LAMBDA(ttnn::operations::unary::deg2rad));
constexpr auto digamma = ttnn::register_operation("ttnn::digamma", TO_LAMBDA(ttnn::operations::unary::digamma));
constexpr auto hardswish = ttnn::register_operation("ttnn::hardswish", TO_LAMBDA(ttnn::operations::unary::hardswish));
constexpr auto hardsigmoid =
ttnn::register_operation("ttnn::hardsigmoid", TO_LAMBDA(ttnn::operations::unary::hardsigmoid));
constexpr auto hardtanh = ttnn::register_operation("ttnn::hardtanh", TO_LAMBDA(ttnn::operations::unary::hardtanh));
constexpr auto lgamma = ttnn::register_operation("ttnn::lgamma", TO_LAMBDA(ttnn::operations::unary::lgamma));
constexpr auto log1p = ttnn::register_operation("ttnn::log1p", TO_LAMBDA(ttnn::operations::unary::log1p));
constexpr auto mish = ttnn::register_operation("ttnn::mish", TO_LAMBDA(ttnn::operations::unary::mish));
constexpr auto multigammaln =
ttnn::register_operation("ttnn::multigammaln", TO_LAMBDA(ttnn::operations::unary::multigammaln));
constexpr auto rad2deg = ttnn::register_operation("ttnn::rad2deg", TO_LAMBDA(ttnn::operations::unary::rad2deg));
constexpr auto sigmoid_accurate =
ttnn::register_operation("ttnn::sigmoid_accurate", TO_LAMBDA(ttnn::operations::unary::sigmoid_accurate));
constexpr auto sinh = ttnn::register_operation("ttnn::sinh", TO_LAMBDA(ttnn::operations::unary::sinh));
constexpr auto softsign = ttnn::register_operation("ttnn::softsign", TO_LAMBDA(ttnn::operations::unary::softsign));
constexpr auto swish = ttnn::register_operation("ttnn::swish", TO_LAMBDA(ttnn::operations::unary::swish));
constexpr auto tanhshrink =
ttnn::register_operation("ttnn::tanhshrink", TO_LAMBDA(ttnn::operations::unary::tanhshrink));
constexpr auto tril = ttnn::register_operation("ttnn::tril", TO_LAMBDA(ttnn::operations::unary::tril));
constexpr auto triu = ttnn::register_operation("ttnn::triu", TO_LAMBDA(ttnn::operations::unary::triu));

} // namespace ttnn
