Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed May 15, 2024
1 parent fcad6ac commit f57ab51
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 44 deletions.
101 changes: 58 additions & 43 deletions include/luisa/coro/coro_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@ class Coroutine {
static_assert(luisa::always_false_v<T>);
};

namespace detail {

class LC_CORO_API CoroAwaiter : public concepts::Noncopyable {

private:
using Await = luisa::move_only_function<void()>;
Await _await;

private:
template<typename T>
friend class Coroutine;
explicit CoroAwaiter(Await await) noexcept
: _await{std::move(await)} {}

public:
void await() && noexcept { _await(); }
};

}// namespace detail

template<typename Ret, typename... Args>
class Coroutine<Ret(Args...)> {

Expand Down Expand Up @@ -95,41 +115,29 @@ class Coroutine<Ret(Args...)> {
[[nodiscard]] auto subroutine(luisa::string_view name) const noexcept { return (*this)[name]; }

private:
template<typename U>
class Awaiter : public concepts::Noncopyable {
private:
U _f;
luisa::optional<Expr<uint3>> _coro_id;

private:
friend class Coroutine;
explicit Awaiter(U f) noexcept : _f{std::move(f)} {}

public:
[[nodiscard]] auto set_id(Expr<uint3> coro_id) && noexcept {
_coro_id.emplace(coro_id);
return std::move(*this);
}
void await() && noexcept { return _f(std::move(_coro_id)); }
};

public:
[[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
auto f = [=, this](luisa::optional<Expr<uint3>> coro_id) noexcept {
[[nodiscard]] auto _await(luisa::optional<Expr<uint3>> coro_id,
compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return detail::CoroAwaiter{[=, coro_id = std::move(coro_id), this]() noexcept {
auto frame = coro_id ? instantiate(*coro_id) : instantiate();
detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](CoroToken token, CoroFrame &f) noexcept {
subroutine(token)(f, args...);
});
};
return Awaiter<decltype(f)>{std::move(f)};
}};
}

public:
[[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return _await(luisa::nullopt, args...);
}
[[nodiscard]] auto operator()(Expr<uint3> coro_id, compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return _await(luisa::make_optional(coro_id), args...);
}
};

namespace detail {
struct CoroAwaitInvoker {
template<typename A>
void operator%(A &&awaiter) && noexcept {
std::forward<A>(awaiter).await();
void operator%(CoroAwaiter &&awaiter) const && noexcept {
std::move(awaiter).await();
}
};
}// namespace detail
Expand All @@ -155,24 +163,21 @@ class GeneratorIter : public concepts::Noncopyable {
uint _n;
CoroFrame _frame;
using Resume = luisa::move_only_function<void(CoroFrame &, CoroToken)>;
Resume resume;
Resume _resume;

private:
template<typename U>
friend class Generator;
GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept
: _n{n}, _frame{std::move(frame)}, resume{std::move(resume)} {}
: _n{n}, _frame{std::move(frame)}, _resume{std::move(resume)} {}

public:
[[nodiscard]] auto set_id(Expr<uint3> coro_id) && noexcept {
_frame.coro_id = coro_id;
return std::move(*this);
}
[[nodiscard]] Bool has_next() const noexcept { return !_frame.is_terminated(); }
[[nodiscard]] Var<T> next() noexcept {
coroutine_generator_next_impl(_frame, _n, resume);
return _frame.get<T>("__yielded_value");
auto &update() noexcept {
coroutine_generator_next_impl(_frame, _n, _resume);
return *this;
}
[[nodiscard]] Bool is_terminated() const noexcept { return _frame.is_terminated(); }
[[nodiscard]] Var<T> value() noexcept { return _frame.get<T>("__yielded_value"); }

private:
class RangeForIterator {
Expand All @@ -196,8 +201,9 @@ class GeneratorIter : public concepts::Noncopyable {
auto fb = compute::detail::FunctionBuilder::current();
_loop = fb->loop_();
fb->push_scope(_loop->body());
dsl::if_(!_g.has_next(), [] { dsl::break_(); });
return _g.next();
_g.update();
dsl::if_(_g.is_terminated(), [] { dsl::break_(); });
return _g.value();
}
};

Expand Down Expand Up @@ -226,16 +232,25 @@ class Generator<Ret(Args...)> {
public:
[[nodiscard]] auto coroutine() const noexcept { return _coro; }

public:
[[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
private:
[[nodiscard]] auto _iter(luisa::optional<Expr<uint3>> coro_id,
compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return detail::GeneratorIter<Ret>{
_coro.subroutine_count(),
_coro.instantiate(),
[=, this](CoroFrame &frame, CoroToken token) noexcept {
_coro[token](frame, args...);
coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate(),
[=, this](CoroFrame &f, CoroToken token) noexcept {
_coro[token](f, args...);
},
};
}

public:
[[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return _iter(luisa::nullopt, args...);
}
[[nodiscard]] auto operator()(Expr<uint3> coro_id, compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return _iter(luisa::make_optional(coro_id), args...);
}
};

}// namespace luisa::compute::coroutine
9 changes: 8 additions & 1 deletion src/tests/coro/helloworld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ int main(int argc, char *argv[]) {
x += 1u;
};
};
for (auto x : range(100u)) {
auto iter = range(10u);
$while (!iter.update().is_terminated()) {
auto x = iter.value();
device_log("x = {}", x);
};
for (auto x : range(10u)) {
device_log("x = {}", x);
}
};
auto shader = device.compile(test);
stream << shader().dispatch(1) << synchronize();

coroutine::Coroutine nested2 = [](UInt n) {
$for (i, n) {
Expand Down

0 comments on commit f57ab51

Please sign in to comment.