#5389: ported ttnn::log_sigmoid to C++
arakhmati committed May 23, 2024
1 parent 9b434ad commit 354a834
Showing 3 changed files with 9 additions and 4 deletions.
1 change: 1 addition & 0 deletions ttnn/cpp/pybind11/operations/unary.hpp
@@ -205,6 +205,7 @@ void py_module(py::module& module) {
detail::bind_unary_operation(module, ttnn::square);
detail::bind_unary_operation(module, ttnn::tan);
detail::bind_unary_operation(module, ttnn::tanh);
detail::bind_unary_operation(module, ttnn::log_sigmoid);

// Unaries with fast_and_approximate_mode
detail::bind_unary_operation_with_fast_and_approximate_mode(module, ttnn::exp);
8 changes: 6 additions & 2 deletions ttnn/cpp/ttnn/operations/unary.hpp
@@ -52,7 +52,7 @@ inline Tensor execute(
}
} // namespace detail

template <UnaryOpType unary_op_type>
template <UnaryOpType... unary_op_types>
struct ExecuteUnary {
static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }

@@ -62,7 +62,7 @@ struct ExecuteUnary {
}

static Tensor execute(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return detail::execute(input_tensor, {UnaryWithParam{unary_op_type}}, memory_config);
return detail::execute(input_tensor, {UnaryWithParam{unary_op_types}...}, memory_config);
}
};

@@ -181,6 +181,10 @@ REGISTER_UNARY_OPERATION(square, SQUARE);
REGISTER_UNARY_OPERATION(tan, TAN);
REGISTER_UNARY_OPERATION(tanh, TANH);

constexpr auto log_sigmoid = ttnn::register_operation<ttnn::operations::unary::ExecuteUnary<
ttnn::operations::unary::UnaryOpType::SIGMOID,
ttnn::operations::unary::UnaryOpType::LOG>>("ttnn::log_sigmoid");

// Unaries with fast_and_approximate_mode
REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(exp, EXP);
REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(erf, ERF);
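The key change above is making ExecuteUnary variadic: the UnaryOpType... parameter pack is expanded into an op chain, so ExecuteUnary<SIGMOID, LOG> applies sigmoid and then log, i.e. log_sigmoid(x) = log(sigmoid(x)). Below is a minimal, self-contained sketch of that pack-expansion mechanism; the types are simplified stand-ins for illustration only, not the real ttnn definitions, which carry more state.

#include <cstdio>
#include <vector>

// Simplified stand-ins for the ttnn types used in the diff.
enum class UnaryOpType { SIGMOID, LOG };

struct UnaryWithParam {
    UnaryOpType op_type;
};

template <UnaryOpType... unary_op_types>
struct ExecuteUnary {
    // Expands the template parameter pack into an op chain, mirroring
    // {UnaryWithParam{unary_op_types}...} in the diff above.
    static std::vector<UnaryWithParam> op_chain() {
        return {UnaryWithParam{unary_op_types}...};
    }
};

int main() {
    // log_sigmoid is registered as SIGMOID followed by LOG.
    auto chain = ExecuteUnary<UnaryOpType::SIGMOID, UnaryOpType::LOG>::op_chain();
    std::printf("ops in chain: %zu\n", chain.size());  // prints 2
    return 0;
}

In the real operation, the registered ttnn::log_sigmoid passes this two-op chain to detail::execute together with the optional MemoryConfig, as shown in the hunk above.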
4 changes: 2 additions & 2 deletions ttnn/ttnn/operations/unary.py
@@ -47,6 +47,7 @@ def torch_prelu(x, *args, **kwargs):
"log": torch.log,
"log10": torch.log10,
"log2": torch.log2,
"log_sigmoid": torch.nn.functional.logsigmoid,
"logical_not": torch.logical_not,
"ltz": lambda x: torch.lt(x, 0),
"neg": torch.neg,
@@ -141,6 +142,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
ttnn._ttnn.operations.unary.leaky_relu,
# ttnn._ttnn.operations.unary.prelu, # Alias for leaky_relu. TODO(#8544): implement PReLU properly
# Other unaries (composite operations)
ttnn._ttnn.operations.unary.log_sigmoid,
ttnn._ttnn.operations.unary.softplus,
]
for unary_function in TTNN_ELTWISE_UNARY_CPP_FUNCTIONS:
@@ -185,7 +187,6 @@ def register_ttl_unary_function(name, ttl_unary_function):
"hardtanh": torch.nn.functional.hardtanh,
"lgamma": torch.lgamma,
"log1p": torch.log1p,
"log_sigmoid": torch.nn.functional.logsigmoid,
"mish": lambda _x: torch.nn.functional.mish(_x.to(torch.float)),
"multigammaln": torch_multigammaln,
"rad2deg": torch.rad2deg,
@@ -275,7 +276,6 @@ def unary_function(
("hardtanh", ttl.tensor.hardtanh), # composite
("lgamma", ttl.tensor.lgamma), # composite
("log1p", ttl.tensor.log1p), # composite
("log_sigmoid", ttl.tensor.log_sigmoid), # composite
("mish", ttl.tensor.mish), # composite
("multigammaln", ttl.tensor.multigammaln), # composite
("rad2deg", ttl.tensor.rad2deg), # composite
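On the Python side, log_sigmoid moves from the composite, ttl-backed tables to the list of C++-bound unaries while keeping torch.nn.functional.logsigmoid as its golden function. The chained SIGMOID-then-LOG formulation agrees with that reference, since log(1 / (1 + e^(-x))) = -log(1 + e^(-x)). A small standalone check of that identity (plain C++, independent of ttnn and illustrative only):

#include <cmath>
#include <cstdio>

int main() {
    const double xs[] = {-4.0, -1.0, 0.0, 1.0, 4.0};
    for (double x : xs) {
        double chained = std::log(1.0 / (1.0 + std::exp(-x)));  // sigmoid, then log
        double reference = -std::log1p(std::exp(-x));           // logsigmoid identity
        std::printf("x = %5.1f  chained = %.6f  reference = %.6f\n", x, chained, reference);
    }
    return 0;
}

For moderate inputs the two columns match to printed precision, which is why torch.nn.functional.logsigmoid remains a suitable golden function for the new C++-bound op.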
