Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge branch 'refactor-api' of https://github.com/LuisaGroup/LuisaCom…
Browse files Browse the repository at this point in the history
…pute-coroutine into refactor-api
  • Loading branch information
OldNew777 committed May 17, 2024
2 parents 91b47ba + c93cc73 commit ad36f74
Show file tree
Hide file tree
Showing 17 changed files with 265 additions and 188 deletions.
4 changes: 2 additions & 2 deletions include/luisa/coro/coro_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ class LC_CORO_API CoroFrame {

private:
luisa::shared_ptr<const CoroFrameDesc> _desc;
const RefExpr *_expression;
const Expression *_expression;

public:
Var<uint3> &coro_id;
Var<uint> &target_token;

public:
CoroFrame(luisa::shared_ptr<const CoroFrameDesc> desc, const RefExpr *expr) noexcept;
CoroFrame(luisa::shared_ptr<const CoroFrameDesc> desc, const Expression *expr) noexcept;
CoroFrame(CoroFrame &&another) noexcept;
CoroFrame(const CoroFrame &another) noexcept;
CoroFrame &operator=(const CoroFrame &rhs) noexcept;
Expand Down
1 change: 1 addition & 0 deletions include/luisa/coro/coro_frame_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class LC_CORO_API CoroFrameDesc : public luisa::enable_shared_from_this<CoroFram
public:
[[nodiscard]] static luisa::shared_ptr<CoroFrameDesc> create(const Type *type, DesignatedFieldDict m) noexcept;
[[nodiscard]] auto type() const noexcept { return _type; }
[[nodiscard]] const Type *field(uint index) const noexcept;
[[nodiscard]] auto &designated_fields() const noexcept { return _designated_fields; }
[[nodiscard]] uint designated_field(luisa::string_view name) const noexcept;
[[nodiscard]] luisa::string dump() const noexcept;
Expand Down
57 changes: 51 additions & 6 deletions include/luisa/coro/coro_frame_smem.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,24 @@

#pragma once

#include <luisa/core/stl/optional.h>
#include <luisa/dsl/shared.h>
#include <luisa/coro/coro_frame.h>

namespace luisa::compute {

namespace detail {
[[noreturn]] LC_CORO_API void error_coro_frame_smem_subscript_on_soa() noexcept;
}// namespace detail

template<>
class Shared<coroutine::CoroFrame> {

private:
luisa::shared_ptr<const coroutine::CoroFrameDesc> _desc;
luisa::vector<const RefExpr *> _expressions;
// we need to return a `CoroFrame &` for operator[] with proper lifetime
luisa::vector<luisa::unique_ptr<coroutine::CoroFrame>> _temp_frames;
size_t _size;

private:
Expand All @@ -39,8 +46,9 @@ class Shared<coroutine::CoroFrame> {
}

public:
Shared(luisa::shared_ptr<const coroutine::CoroFrameDesc> desc, size_t n,
bool soa = true, luisa::span<const uint> soa_excluded_fields = {}) noexcept
Shared(luisa::shared_ptr<const coroutine::CoroFrameDesc> desc,
size_t n, bool soa = false,
luisa::span<const uint> soa_excluded_fields = {}) noexcept
: _desc{std::move(desc)}, _size{n} { _create(soa, soa_excluded_fields); }

Shared(Shared &&) noexcept = default;
Expand All @@ -52,10 +60,11 @@ class Shared<coroutine::CoroFrame> {
[[nodiscard]] auto is_soa() const noexcept { return _expressions.size() > 1; }
[[nodiscard]] auto size() const noexcept { return _size; }

public:
private:
/// Read index with active fields
template<typename I>
[[nodiscard]] auto read(I &&index, luisa::optional<luisa::span<const uint>> active_fields = luisa::nullopt) const noexcept {
requires is_integral_expr_v<I>
[[nodiscard]] auto _read(I &&index, luisa::optional<luisa::span<const uint>> active_fields) const noexcept {
auto i = def(std::forward<I>(index));
auto fb = detail::FunctionBuilder::current();
auto frame = fb->local(_desc->type());
Expand All @@ -71,7 +80,7 @@ class Shared<coroutine::CoroFrame> {
if (m == 0u) {
auto t = Type::of<uint>();
auto s = fb->access(Type::array(t, 3u), _expressions[m], i.expression());
std::array<const Expression *, 3u> elems;
std::array<const Expression *, 3u> elems{};
elems[0] = fb->access(t, s, fb->literal(t, 0u));
elems[1] = fb->access(t, s, fb->literal(t, 1u));
elems[2] = fb->access(t, s, fb->literal(t, 2u));
Expand All @@ -88,7 +97,8 @@ class Shared<coroutine::CoroFrame> {

/// Write index with active fields
template<typename I>
void write(I &&index, const coroutine::CoroFrame &frame, luisa::optional<luisa::span<const uint>> active_fields = luisa::nullopt) const noexcept {
requires is_integral_expr_v<I>
void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional<luisa::span<const uint>> active_fields) const noexcept {
auto i = def(std::forward<I>(index));
auto fb = detail::FunctionBuilder::current();
if (!is_soa()) {
Expand All @@ -115,6 +125,41 @@ class Shared<coroutine::CoroFrame> {
}
}
}

public:
/// Reference to the i-th element
template<typename I>
requires is_integral_expr_v<I>
[[nodiscard]] coroutine::CoroFrame &operator[](I &&index) noexcept {
if (is_soa()) { detail::error_coro_frame_smem_subscript_on_soa(); }
auto fb = detail::FunctionBuilder::current();
auto ref = fb->access(
_desc->type(), _expressions[0],
detail::extract_expression(std::forward<I>(index)));
auto temp_frame = luisa::make_unique<coroutine::CoroFrame>(_desc, ref);
return *_temp_frames.emplace_back(std::move(temp_frame));
}

template<typename I>
requires is_integral_expr_v<I>
[[nodiscard]] coroutine::CoroFrame read(I &&index) const noexcept {
return _read(std::forward<I>(index), luisa::nullopt);
}
template<typename I>
requires is_integral_expr_v<I>
[[nodiscard]] coroutine::CoroFrame read(I &&index, luisa::span<const uint> active_fields) const noexcept {
return _read(std::forward<I>(index), luisa::make_optional(active_fields));
}
template<typename I>
requires is_integral_expr_v<I>
void write(I &&index, const coroutine::CoroFrame &frame) const noexcept {
_write(std::forward<I>(index), frame, luisa::nullopt);
}
template<typename I>
requires is_integral_expr_v<I>
void write(I &&index, const coroutine::CoroFrame &frame, luisa::span<const uint> active_fields) const noexcept {
_write(std::forward<I>(index), frame, luisa::make_optional(active_fields));
}
};

}// namespace luisa::compute
178 changes: 95 additions & 83 deletions include/luisa/coro/coro_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,29 @@ namespace detail {
LC_CORO_API void coroutine_chained_await_impl(
CoroFrame &frame, uint node_count,
luisa::move_only_function<void(CoroToken, CoroFrame &)> node) noexcept;
LC_CORO_API void coroutine_generator_step_impl(
CoroFrame &frame, uint node_count, bool is_entry,
luisa::move_only_function<void(CoroToken, CoroFrame &)> node) noexcept;
}// namespace detail

template<typename T>
class Coroutine {
static_assert(luisa::always_false_v<T>);
};

namespace detail {

class LC_CORO_API CoroAwaiter : public concepts::Noncopyable {

private:
using Await = luisa::move_only_function<void()>;
Await _await;

public:
explicit CoroAwaiter(Await await) noexcept
: _await{std::move(await)} {}
void await() && noexcept { _await(); }
};

}// namespace detail

template<typename Ret, typename... Args>
class Coroutine<Ret(Args...)> {

Expand Down Expand Up @@ -98,41 +111,29 @@ class Coroutine<Ret(Args...)> {
[[nodiscard]] auto subroutine(luisa::string_view name) const noexcept { return (*this)[name]; }

private:
template<typename U>
class Awaiter : public concepts::Noncopyable {
private:
U _f;
luisa::optional<Expr<uint3>> _coro_id;

private:
friend class Coroutine;
explicit Awaiter(U f) noexcept : _f{std::move(f)} {}

public:
[[nodiscard]] auto set_id(Expr<uint3> coro_id) && noexcept {
_coro_id.emplace(coro_id);
return std::move(*this);
}
void await() && noexcept { return _f(std::move(_coro_id)); }
};

public:
[[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
auto f = [=, this](luisa::optional<Expr<uint3>> coro_id) noexcept {
[[nodiscard]] auto _await(luisa::optional<Expr<uint3>> coro_id,
compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return detail::CoroAwaiter{[=, coro_id = std::move(coro_id), this]() noexcept {
auto frame = coro_id ? instantiate(*coro_id) : instantiate();
detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](CoroToken token, CoroFrame &f) noexcept {
subroutine(token)(f, args...);
});
};
return Awaiter<decltype(f)>{std::move(f)};
}};
}

public:
[[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return _await(luisa::nullopt, args...);
}
[[nodiscard]] auto operator()(Expr<uint3> coro_id, compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return _await(luisa::make_optional(coro_id), args...);
}
};

namespace detail {
struct CoroAwaitInvoker {
template<typename A>
void operator%(A &&awaiter) && noexcept {
std::forward<A>(awaiter).await();
void operator%(CoroAwaiter &&awaiter) const && noexcept {
std::move(awaiter).await();
}
};
}// namespace detail
Expand All @@ -145,94 +146,105 @@ class Generator {
static_assert(luisa::always_false_v<T>);
};

template<typename Ret, typename... Args>
class Generator<Ret(Args...)> {
namespace detail {

static_assert(!std::is_same_v<Ret, void>,
"Generator function must not return void.");
LC_CORO_API void coroutine_generator_next_impl(
CoroFrame &frame, uint node_count,
const luisa::move_only_function<void(CoroFrame &, CoroToken)> &resume) noexcept;

template<typename T>
class GeneratorIter : public concepts::Noncopyable {

private:
Coroutine<void(Args...)> _coro;
uint _n;
CoroFrame _frame;
using Resume = luisa::move_only_function<void(CoroFrame &, CoroToken)>;
Resume _resume;

public:
template<typename Def>
requires std::negation_v<is_callable<std::remove_cvref_t<Def>>> &&
std::negation_v<is_kernel<std::remove_cvref_t<Def>>>
Generator(Def &&f) noexcept : _coro{std::forward<Def>(f)} {}
GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept
: _n{n}, _frame{std::move(frame)},
_resume{std::move(resume)} {}

public:
[[nodiscard]] auto coroutine() const noexcept { return _coro; }
auto &update() noexcept {
coroutine_generator_next_impl(_frame, _n, _resume);
return *this;
}
[[nodiscard]] Bool is_terminated() const noexcept { return _frame.is_terminated(); }
[[nodiscard]] Var<T> value() noexcept { return _frame.get<T>("__yielded_value"); }

private:
template<typename U>
class Iterator {
class RangeForIterator {
private:
luisa::unique_ptr<CoroFrame> _frame;
U _f;
GeneratorIter &_g;
bool _invoked{false};
LoopStmt *_loop{nullptr};

private:
friend class Generator;
Iterator(luisa::unique_ptr<CoroFrame> frame, U f) noexcept
: _frame{std::move(frame)}, _f{std::move(f)} {}
friend class GeneratorIter;
explicit RangeForIterator(GeneratorIter &g) noexcept : _g{g} {}

public:
Iterator &operator++() noexcept {
RangeForIterator &operator++() noexcept {
_invoked = true;
_f(*_frame, false);
compute::detail::FunctionBuilder::current()->pop_scope(_loop->body());
return *this;
}
[[nodiscard]] bool operator==(luisa::default_sentinel_t) const noexcept { return _invoked; }
[[nodiscard]] Var<expr_value_t<Ret>> operator*() noexcept {
_f(*_frame, true);
[[nodiscard]] Var<T> operator*() noexcept {
auto fb = compute::detail::FunctionBuilder::current();
_loop = fb->loop_();
fb->push_scope(_loop->body());
dsl::if_(_frame->is_terminated(), [] { dsl::break_(); });
return _frame->get<Ret>("yield_value");
_g.update();
dsl::if_(_g.is_terminated(), [] { dsl::break_(); });
return _g.value();
}
};

private:
template<typename U>
[[nodiscard]] auto _make_iterator(U u, luisa::optional<Expr<uint3>> coro_id) const noexcept {
auto frame = luisa::make_unique<CoroFrame>(coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate());
return Iterator<U>{std::move(frame), std::move(u)};
}
public:
[[nodiscard]] auto begin() noexcept { return RangeForIterator{*this}; }
[[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; }
};

}// namespace detail

template<typename Ret, typename... Args>
class Generator<Ret(Args...)> {

static_assert(!std::is_same_v<Ret, void>,
"Generator function must not return void.");

private:
template<typename U>
class Stepper : public concepts::Noncopyable {
private:
luisa::optional<Expr<uint3>> _coro_id;
const Generator &_g;
U _f;
Coroutine<void(Args...)> _coro;

private:
friend class Generator;
Stepper(const Generator &g, U f) noexcept : _g{g}, _f{std::move(f)} {}
public:
template<typename Def>
requires std::negation_v<is_callable<std::remove_cvref_t<Def>>> &&
std::negation_v<is_kernel<std::remove_cvref_t<Def>>>
Generator(Def &&f) noexcept : _coro{std::forward<Def>(f)} {}

public:
[[nodiscard]] auto set_id(Expr<uint3> coro_id) && noexcept {
_coro_id.emplace(coro_id);
return std::move(*this);
}
[[nodiscard]] auto begin() noexcept { return _g._make_iterator(std::move(_f), std::move(_coro_id)); }
[[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; }
};
public:
[[nodiscard]] auto coroutine() const noexcept { return _coro; }

private:
[[nodiscard]] auto _iter(luisa::optional<Expr<uint3>> coro_id,
compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return detail::GeneratorIter<Ret>{
_coro.subroutine_count(),
coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate(),
[=, this](CoroFrame &f, CoroToken token) noexcept {
_coro[token](f, args...);
},
};
}

public:
[[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
auto f = [=, this](CoroFrame &frame, bool is_entry) noexcept {
detail::coroutine_generator_step_impl(
frame, _coro.subroutine_count(), is_entry,
[&](CoroToken token, CoroFrame &ff) noexcept {
_coro.subroutine(token)(ff, args...);
});
};
return Stepper<decltype(f)>{*this, std::move(f)};
return _iter(luisa::nullopt, args...);
}
[[nodiscard]] auto operator()(Expr<uint3> coro_id, compute::detail::prototype_to_callable_invocation_t<Args>... args) const noexcept {
return _iter(luisa::make_optional(coro_id), args...);
}
};

Expand Down
Loading

0 comments on commit ad36f74

Please sign in to comment.