Skip to content

Commit

Permalink
improv(coro): add concept awaitable_type, move `dpp::detail::promis…
Browse files Browse the repository at this point in the history
…e::promise` to `dpp::basic_promise`, document some more
  • Loading branch information
Mishura4 committed Jul 11, 2024
1 parent 6eef94a commit 6ab08aa
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 33 deletions.
101 changes: 71 additions & 30 deletions include/dpp/coro/awaitable.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ struct awaitable_dummy {

#include <dpp/coro/coro.h>

// Do not include <coroutine> as coro.h includes <experimental/coroutine> or <coroutine> depending on clang version
#include <mutex>
#include <utility>
#include <type_traits>
#include <functional>
#include <exception>
#include <atomic>
#include <cstddef>

namespace dpp {

Expand Down Expand Up @@ -91,20 +91,33 @@ class promise_base;
*/
struct empty{};

/**
* @brief Variant for the 3 conceptual values of a coroutine:
*/
template <typename T>
using result_t = std::variant<std::monostate, std::conditional_t<std::is_void_v<T>, empty, T>, std::exception_ptr>;

template <typename T>
void spawn_sync_wait_job(auto* awaitable, std::condition_variable &cv, auto&& result);

} /* namespace detail::promise */

template <typename Derived>
requires (requires (Derived t) { detail::co_await_resolve(t); })
template <awaitable_type Derived>
class basic_awaitable {
protected:
/**
* @brief Implementation for sync_wait. This is code used by sync_wait, sync_wait_for, sync_wait_until.
*
* @tparam Timed Whether the wait function times out or not
* @param do_wait Function to do the actual wait on the cv
* @return If T is void, returns a boolean for which true means the awaitable completed, false means it timed out.
* @return If T is non-void, returns a std::optional<T> for which an absence of value means timed out.
*/
template <bool Timed>
auto sync_wait_impl(auto&& do_wait) {
using result_type = decltype(detail::co_await_resolve(std::declval<Derived>()).await_resume());
using storage_type = std::conditional_t<std::is_void_v<result_type>, detail::promise::empty, result_type>;
using variant_type = std::variant<std::monostate, storage_type, std::exception_ptr>;
using variant_type = detail::promise::result_t<result_type>;
variant_type result;
std::condition_variable cv;

Expand All @@ -131,6 +144,8 @@ class basic_awaitable {
* @brief Blocks this thread and waits for the awaitable to finish.
*
* @attention This will BLOCK THE THREAD. It is likely you want to use co_await instead.
* @return If T is void, returns a boolean for which true means the awaitable completed, false means it timed out.
* @return If T is non-void, returns a std::optional<T> for which an absence of value means timed out.
*/
auto sync_wait() {
return sync_wait_impl<false>([](std::condition_variable &cv, auto&& result) {
Expand All @@ -145,8 +160,8 @@ class basic_awaitable {
*
* @attention This will BLOCK THE THREAD. It is likely you want to use co_await instead.
* @param duration Maximum duration to wait for
* @retval If T is void, returns a boolean for which true means the awaitable completed, false means it timed out.
* @retval If T is non-void, returns a std::optional<T> for which an absense of value means timed out.
* @return If T is void, returns a boolean for which true means the awaitable completed, false means it timed out.
* @return If T is non-void, returns a std::optional<T> for which an absence of value means timed out.
*/
template <class Rep, class Period>
auto sync_wait_for(const std::chrono::duration<Rep, Period>& duration) {
Expand All @@ -162,8 +177,8 @@ class basic_awaitable {
*
* @attention This will BLOCK THE THREAD. It is likely you want to use co_await instead.
* @param time Maximum time point to wait for
* @retval If T is void, returns a boolean for which true means the awaitable completed, false means it timed out.
* @retval If T is non-void, returns a std::optional<T> for which an absense of value means timed out.
* @return If T is void, returns a boolean for which true means the awaitable completed, false means it timed out.
* @return If T is non-void, returns a std::optional<T> for which an absence of value means timed out.
*/
template <class Clock, class Duration>
auto sync_wait_until(const std::chrono::time_point<Clock, Duration> &time) {
Expand Down Expand Up @@ -339,6 +354,9 @@ class awaitable : public basic_awaitable<awaitable<T>> {

namespace detail::promise {

/**
* @brief Base class defining logic common to all promise types, aka the "write" end of an awaitable.
*/
template <typename T>
class promise_base {
protected:
Expand Down Expand Up @@ -377,6 +395,12 @@ class promise_base {
}
}

/**
* @brief Unlinks this promise from its currently linked awaiter and returns it.
*
* At the time of writing this is only used in the case of a serious internal error in dpp::task.
* Avoid using this as this will crash if the promise is used after this.
*/
std_coroutine::coroutine_handle<> release_awaiter() {
return std::exchange(awaiter, nullptr);
}
Expand All @@ -393,6 +417,8 @@ class promise_base {

/**
* @brief Move construction is disabled.
*
* awaitable hold a pointer to this object so moving is not possible.
*/
promise_base(promise_base&& rhs) = delete;

Expand All @@ -412,6 +438,7 @@ class promise_base {
*
* @tparam Notify Whether to resume any awaiter or not.
* @throws dpp::logic_exception if the promise is not empty.
* @throws ? Any exception thrown by the coroutine if resumed will propagate
*/
template <bool Notify = true>
void set_exception(std::exception_ptr ptr) {
Expand All @@ -427,9 +454,12 @@ class promise_base {

/**
* @brief Notify a currently awaiting coroutine that the result is ready.
*
* @note This may resume the coroutine on the current thread.
* @throws ? Any exception thrown by the coroutine if resumed will propagate
*/
void notify_awaiter() {
if (state.load(std::memory_order::acquire) & sf_awaited) {
if ((state.load(std::memory_order::acquire) & sf_awaited) != 0) {
awaiter.resume();
}
}
Expand All @@ -449,6 +479,8 @@ class promise_base {
}
};

}

/**
* @brief Generic promise class, represents the owning potion of an asynchronous value.
*
Expand All @@ -459,10 +491,10 @@ class promise_base {
* @see awaitable
*/
template <typename T>
class promise : public promise_base<T> {
class basic_promise : public detail::promise::promise_base<T> {
public:
using promise_base<T>::promise_base;
using promise_base<T>::operator=;
using detail::promise::promise_base<T>::promise_base;
using detail::promise::promise_base<T>::operator=;

/**
* @brief Construct the result in place by forwarding the arguments, and by default resume any awaiter.
Expand All @@ -479,9 +511,9 @@ class promise : public promise_base<T> {
} catch (...) {
this->value.template emplace<2>(std::current_exception());
}
[[maybe_unused]] auto previous_value = this->state.fetch_or(sf_ready, std::memory_order::acq_rel);
[[maybe_unused]] auto previous_value = this->state.fetch_or(detail::promise::sf_ready, std::memory_order::acq_rel);
if constexpr (Notify) {
if (previous_value & sf_awaited) {
if (previous_value & detail::promise::sf_awaited) {
this->awaiter.resume();
}
}
Expand Down Expand Up @@ -510,11 +542,20 @@ class promise : public promise_base<T> {
}
};


/**
* @brief Generic promise class, represents the owning potion of an asynchronous value.
*
* This class is roughly equivalent to std::promise, with the crucial distinction that the promise *IS* the shared state.
* As such, the promise needs to be kept alive for the entire time a value can be retrieved.
*
* @see awaitable
*/
template <>
class promise<void> : public promise_base<void> {
class basic_promise<void> : public detail::promise::promise_base<void> {
public:
using promise_base::promise_base;
using promise_base::operator=;
using detail::promise::promise_base<void>::promise_base;
using detail::promise::promise_base<void>::operator=;

/**
* @brief Set the promise to completed, and resume any awaiter.
Expand All @@ -525,27 +566,30 @@ class promise<void> : public promise_base<void> {
void set_value() {
throw_if_not_empty();
this->value.emplace<1>();
[[maybe_unused]] auto previous_value = this->state.fetch_or(sf_ready, std::memory_order::acq_rel);
[[maybe_unused]] auto previous_value = this->state.fetch_or(detail::promise::sf_ready, std::memory_order::acq_rel);
if constexpr (Notify) {
if (previous_value & sf_awaited) {
if (previous_value & detail::promise::sf_awaited) {
this->awaiter.resume();
}
}
}
};

}

template <typename T>
using basic_promise = detail::promise::promise<T>;

/**
* @brief Base class for a promise type.
* @brief Generic promise class, represents the owning potion of an asynchronous value.
*
* This class is roughly equivalent to std::promise, with the crucial distinction that the promise *IS* the shared state.
* As such, the promise needs to be kept alive for the entire time a value can be retrieved.
*
* The difference between basic_promise and this object is that this one is moveable as it wraps an underlying basic_promise in a std::unique_ptr.
*
* Contains the base logic for @ref promise, but does not contain the set_value methods.
* @see awaitable
*/
template <typename T>
class moveable_promise {
/**
* @brief Shared state, wrapped in a unique_ptr to allow move without disturbing an awaitable's promise pointer.
*/
std::unique_ptr<basic_promise<T>> shared_state = std::make_unique<basic_promise<T>>();

public:
Expand Down Expand Up @@ -712,9 +756,6 @@ namespace dpp {

namespace detail::promise {

template <typename T>
using result_t = std::variant<std::monostate, std::conditional_t<std::is_void_v<T>, empty, T>, std::exception_ptr>;

template <typename T>
void spawn_sync_wait_job(auto* awaitable, std::condition_variable &cv, auto&& result) {
[](auto* awaitable_, std::condition_variable &cv_, auto&& result_) -> dpp::job {
Expand Down
22 changes: 19 additions & 3 deletions include/dpp/coro/coro.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,36 +114,46 @@ decltype(auto) co_await_resolve(T&& expr) noexcept {
}

#else

/**
* @brief Concept to check if a type has a useable `operator co_await()` member
*
* @note This is actually a C++20 concept but Doxygen doesn't do well with them
*/
template <typename T>
bool has_co_await_member;
inline constexpr bool has_co_await_member;

/**
* @brief Concept to check if a type has a useable overload of the free function `operator co_await(expr)`
*
* @note This is actually a C++20 concept but Doxygen doesn't do well with them
*/
template <typename T>
bool has_free_co_await;
inline constexpr bool has_free_co_await;

/**
* @brief Concept to check if a type has useable `await_ready()`, `await_suspend()` and `await_resume()` member functions.
*
* @note This is actually a C++20 concept but Doxygen doesn't do well with them
*/
template <typename T>
bool has_await_members;
inline constexpr bool has_await_members;

/**
* @brief Concept to check if a type can be used with co_await
*
* @note This is actually a C++20 concept but Doxygen doesn't do well with them
*/
template <typename T>
inline constexpr bool awaitable_type;

/**
* @brief Mimics the compiler's behavior of using co_await. That is, it returns whichever works first, in order : `expr.operator co_await();` > `operator co_await(expr)` > `expr`
*
* This function is conditionally noexcept, if the returned expression also is.
*/
decltype(auto) co_await_resolve(auto&& expr) {}

#endif

/**
Expand All @@ -154,6 +164,12 @@ using awaitable_result = decltype(co_await_resolve(std::declval<T>()).await_resu

} // namespace detail

/**
* @brief Concept to check if a type can be used with co_await
*/
template <typename T>
concept awaitable_type = requires (T expr) { detail::co_await_resolve(expr); };

struct confirmation_callback_t;

template <typename R = confirmation_callback_t>
Expand Down

0 comments on commit 6ab08aa

Please sign in to comment.