From fcad6ac3cb94ba2b00d48776a6f28bc038d16114 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 21:52:43 +0800 Subject: [PATCH 1/5] refactor generator impl --- include/luisa/coro/coro_func.h | 116 +++++++++++++++++---------------- include/luisa/dsl/sugar.h | 8 +-- src/coro/coro_func.cpp | 23 +++---- 3 files changed, 73 insertions(+), 74 deletions(-) diff --git a/include/luisa/coro/coro_func.h b/include/luisa/coro/coro_func.h index 5adc60712..f7506cc61 100644 --- a/include/luisa/coro/coro_func.h +++ b/include/luisa/coro/coro_func.h @@ -15,9 +15,6 @@ namespace detail { LC_CORO_API void coroutine_chained_await_impl( CoroFrame &frame, uint node_count, luisa::move_only_function node) noexcept; -LC_CORO_API void coroutine_generator_step_impl( - CoroFrame &frame, uint node_count, bool is_entry, - luisa::move_only_function node) noexcept; }// namespace detail template @@ -145,94 +142,99 @@ class Generator { static_assert(luisa::always_false_v); }; -template -class Generator { +namespace detail { - static_assert(!std::is_same_v, - "Generator function must not return void."); +LC_CORO_API void coroutine_generator_next_impl( + CoroFrame &frame, uint node_count, + const luisa::move_only_function &resume) noexcept; + +template +class GeneratorIter : public concepts::Noncopyable { private: - Coroutine _coro; + uint _n; + CoroFrame _frame; + using Resume = luisa::move_only_function; + Resume resume; -public: - template - requires std::negation_v>> && - std::negation_v>> - Generator(Def &&f) noexcept : _coro{std::forward(f)} {} +private: + template + friend class Generator; + GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept + : _n{n}, _frame{std::move(frame)}, resume{std::move(resume)} {} public: - [[nodiscard]] auto coroutine() const noexcept { return _coro; } + [[nodiscard]] auto set_id(Expr coro_id) && noexcept { + _frame.coro_id = coro_id; + return std::move(*this); + } + [[nodiscard]] Bool has_next() const noexcept { return !_frame.is_terminated(); } + [[nodiscard]] Var next() noexcept { + coroutine_generator_next_impl(_frame, _n, resume); + return _frame.get("__yielded_value"); + } private: - template - class Iterator { + class RangeForIterator { private: - luisa::unique_ptr _frame; - U _f; + GeneratorIter &_g; bool _invoked{false}; LoopStmt *_loop{nullptr}; private: - friend class Generator; - Iterator(luisa::unique_ptr frame, U f) noexcept - : _frame{std::move(frame)}, _f{std::move(f)} {} + friend class GeneratorIter; + explicit RangeForIterator(GeneratorIter &g) noexcept : _g{g} {} public: - Iterator &operator++() noexcept { + RangeForIterator &operator++() noexcept { _invoked = true; - _f(*_frame, false); compute::detail::FunctionBuilder::current()->pop_scope(_loop->body()); return *this; } [[nodiscard]] bool operator==(luisa::default_sentinel_t) const noexcept { return _invoked; } - [[nodiscard]] Var> operator*() noexcept { - _f(*_frame, true); + [[nodiscard]] Var operator*() noexcept { auto fb = compute::detail::FunctionBuilder::current(); _loop = fb->loop_(); fb->push_scope(_loop->body()); - dsl::if_(_frame->is_terminated(), [] { dsl::break_(); }); - return _frame->get("yield_value"); + dsl::if_(!_g.has_next(), [] { dsl::break_(); }); + return _g.next(); } }; -private: - template - [[nodiscard]] auto _make_iterator(U u, luisa::optional> coro_id) const noexcept { - auto frame = luisa::make_unique(coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate()); - return Iterator{std::move(frame), std::move(u)}; - } +public: + [[nodiscard]] auto begin() noexcept { return RangeForIterator{*this}; } + [[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; } +}; + +}// namespace detail + +template +class Generator { + + static_assert(!std::is_same_v, + "Generator function must not return void."); private: - template - class Stepper : public concepts::Noncopyable { - private: - luisa::optional> _coro_id; - const Generator &_g; - U _f; + Coroutine _coro; - private: - friend class Generator; - Stepper(const Generator &g, U f) noexcept : _g{g}, _f{std::move(f)} {} +public: + template + requires std::negation_v>> && + std::negation_v>> + Generator(Def &&f) noexcept : _coro{std::forward(f)} {} - public: - [[nodiscard]] auto set_id(Expr coro_id) && noexcept { - _coro_id.emplace(coro_id); - return std::move(*this); - } - [[nodiscard]] auto begin() noexcept { return _g._make_iterator(std::move(_f), std::move(_coro_id)); } - [[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; } - }; +public: + [[nodiscard]] auto coroutine() const noexcept { return _coro; } public: [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { - auto f = [=, this](CoroFrame &frame, bool is_entry) noexcept { - detail::coroutine_generator_step_impl( - frame, _coro.subroutine_count(), is_entry, - [&](CoroToken token, CoroFrame &ff) noexcept { - _coro.subroutine(token)(ff, args...); - }); + return detail::GeneratorIter{ + _coro.subroutine_count(), + _coro.instantiate(), + [=, this](CoroFrame &frame, CoroToken token) noexcept { + _coro[token](frame, args...); + }, }; - return Stepper{*this, std::move(f)}; } }; diff --git a/include/luisa/dsl/sugar.h b/include/luisa/dsl/sugar.h index fb36b3a9c..c0efbf1f2 100644 --- a/include/luisa/dsl/sugar.h +++ b/include/luisa/dsl/sugar.h @@ -122,10 +122,10 @@ namespace luisa::compute::dsl_detail { #define $promise(...) ::luisa::compute::dsl::promise(__VA_ARGS__) -#define $yield(...) \ - do { \ - ::luisa::compute::dsl::promise("yield_value", __VA_ARGS__); \ - ::luisa::compute::dsl::suspend(); \ +#define $yield(...) \ + do { \ + ::luisa::compute::dsl::promise("__yielded_value", __VA_ARGS__); \ + ::luisa::compute::dsl::suspend(); \ } while (false) #define $await ::luisa::compute::coroutine::detail::CoroAwaitInvoker{} % diff --git a/src/coro/coro_func.cpp b/src/coro/coro_func.cpp index 966719c01..c8fa05541 100644 --- a/src/coro/coro_func.cpp +++ b/src/coro/coro_func.cpp @@ -23,18 +23,15 @@ void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, }; } -void coroutine_generator_step_impl(CoroFrame &frame, uint node_count, bool is_entry, - luisa::move_only_function node) noexcept { - if (is_entry) { - node(coro_token_entry, frame); - } else { - $switch (frame.target_token) { - for (auto i = 1u; i < node_count; i++) { - $case (i) { node(i, frame); }; - } - $default { dsl::unreachable(); }; - }; - } +inline void coroutine_generator_next_impl( + CoroFrame &frame, uint node_count, + const luisa::move_only_function &resume) noexcept { + $switch (frame.target_token) { + for (auto i = 0u; i < node_count; i++) { + $case (i) { resume(frame, i); }; + } + $default { dsl::unreachable(); }; + }; } -}// namespace luisa::compute::coro_v2::detail +}// namespace luisa::compute::coroutine::detail From f57ab51419702755fcdaf1b7f1a8f36a3068791a Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 22:59:33 +0800 Subject: [PATCH 2/5] refactor --- include/luisa/coro/coro_func.h | 101 +++++++++++++++++++-------------- src/tests/coro/helloworld.cpp | 9 ++- 2 files changed, 66 insertions(+), 44 deletions(-) diff --git a/include/luisa/coro/coro_func.h b/include/luisa/coro/coro_func.h index f7506cc61..fbb820eb8 100644 --- a/include/luisa/coro/coro_func.h +++ b/include/luisa/coro/coro_func.h @@ -22,6 +22,26 @@ class Coroutine { static_assert(luisa::always_false_v); }; +namespace detail { + +class LC_CORO_API CoroAwaiter : public concepts::Noncopyable { + +private: + using Await = luisa::move_only_function; + Await _await; + +private: + template + friend class Coroutine; + explicit CoroAwaiter(Await await) noexcept + : _await{std::move(await)} {} + +public: + void await() && noexcept { _await(); } +}; + +}// namespace detail + template class Coroutine { @@ -95,41 +115,29 @@ class Coroutine { [[nodiscard]] auto subroutine(luisa::string_view name) const noexcept { return (*this)[name]; } private: - template - class Awaiter : public concepts::Noncopyable { - private: - U _f; - luisa::optional> _coro_id; - - private: - friend class Coroutine; - explicit Awaiter(U f) noexcept : _f{std::move(f)} {} - - public: - [[nodiscard]] auto set_id(Expr coro_id) && noexcept { - _coro_id.emplace(coro_id); - return std::move(*this); - } - void await() && noexcept { return _f(std::move(_coro_id)); } - }; - -public: - [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { - auto f = [=, this](luisa::optional> coro_id) noexcept { + [[nodiscard]] auto _await(luisa::optional> coro_id, + compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return detail::CoroAwaiter{[=, coro_id = std::move(coro_id), this]() noexcept { auto frame = coro_id ? instantiate(*coro_id) : instantiate(); detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](CoroToken token, CoroFrame &f) noexcept { subroutine(token)(f, args...); }); - }; - return Awaiter{std::move(f)}; + }}; + } + +public: + [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _await(luisa::nullopt, args...); + } + [[nodiscard]] auto operator()(Expr coro_id, compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _await(luisa::make_optional(coro_id), args...); } }; namespace detail { struct CoroAwaitInvoker { - template - void operator%(A &&awaiter) && noexcept { - std::forward(awaiter).await(); + void operator%(CoroAwaiter &&awaiter) const && noexcept { + std::move(awaiter).await(); } }; }// namespace detail @@ -155,24 +163,21 @@ class GeneratorIter : public concepts::Noncopyable { uint _n; CoroFrame _frame; using Resume = luisa::move_only_function; - Resume resume; + Resume _resume; private: template friend class Generator; GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept - : _n{n}, _frame{std::move(frame)}, resume{std::move(resume)} {} + : _n{n}, _frame{std::move(frame)}, _resume{std::move(resume)} {} public: - [[nodiscard]] auto set_id(Expr coro_id) && noexcept { - _frame.coro_id = coro_id; - return std::move(*this); - } - [[nodiscard]] Bool has_next() const noexcept { return !_frame.is_terminated(); } - [[nodiscard]] Var next() noexcept { - coroutine_generator_next_impl(_frame, _n, resume); - return _frame.get("__yielded_value"); + auto &update() noexcept { + coroutine_generator_next_impl(_frame, _n, _resume); + return *this; } + [[nodiscard]] Bool is_terminated() const noexcept { return _frame.is_terminated(); } + [[nodiscard]] Var value() noexcept { return _frame.get("__yielded_value"); } private: class RangeForIterator { @@ -196,8 +201,9 @@ class GeneratorIter : public concepts::Noncopyable { auto fb = compute::detail::FunctionBuilder::current(); _loop = fb->loop_(); fb->push_scope(_loop->body()); - dsl::if_(!_g.has_next(), [] { dsl::break_(); }); - return _g.next(); + _g.update(); + dsl::if_(_g.is_terminated(), [] { dsl::break_(); }); + return _g.value(); } }; @@ -226,16 +232,25 @@ class Generator { public: [[nodiscard]] auto coroutine() const noexcept { return _coro; } -public: - [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { +private: + [[nodiscard]] auto _iter(luisa::optional> coro_id, + compute::detail::prototype_to_callable_invocation_t... args) const noexcept { return detail::GeneratorIter{ _coro.subroutine_count(), - _coro.instantiate(), - [=, this](CoroFrame &frame, CoroToken token) noexcept { - _coro[token](frame, args...); + coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate(), + [=, this](CoroFrame &f, CoroToken token) noexcept { + _coro[token](f, args...); }, }; } + +public: + [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _iter(luisa::nullopt, args...); + } + [[nodiscard]] auto operator()(Expr coro_id, compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _iter(luisa::make_optional(coro_id), args...); + } }; }// namespace luisa::compute::coroutine diff --git a/src/tests/coro/helloworld.cpp b/src/tests/coro/helloworld.cpp index ae3c70e4d..3509a6e0d 100644 --- a/src/tests/coro/helloworld.cpp +++ b/src/tests/coro/helloworld.cpp @@ -25,10 +25,17 @@ int main(int argc, char *argv[]) { x += 1u; }; }; - for (auto x : range(100u)) { + auto iter = range(10u); + $while (!iter.update().is_terminated()) { + auto x = iter.value(); + device_log("x = {}", x); + }; + for (auto x : range(10u)) { device_log("x = {}", x); } }; + auto shader = device.compile(test); + stream << shader().dispatch(1) << synchronize(); coroutine::Coroutine nested2 = [](UInt n) { $for (i, n) { From 1161b00aff97d2540257e956336e16cd655573e2 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Thu, 16 May 2024 01:01:07 +0800 Subject: [PATCH 3/5] support `Shared::operator[]` for non-SoA --- include/luisa/coro/coro_frame.h | 4 +- include/luisa/coro/coro_frame_desc.h | 1 + include/luisa/coro/coro_frame_smem.h | 57 +++++++++++++++++-- include/luisa/coro/coro_func.h | 13 ++--- .../coro/schedulers/persistent_threads.h | 2 +- src/coro/CMakeLists.txt | 1 + src/coro/coro_frame.cpp | 2 +- src/coro/coro_frame_desc.cpp | 6 ++ src/coro/coro_frame_smem.cpp | 7 +++ src/coro/coro_func.cpp | 2 +- src/coro/schedulers/persistent_threads.cpp | 48 ++++++++++++---- 11 files changed, 111 insertions(+), 32 deletions(-) create mode 100644 src/coro/coro_frame_smem.cpp diff --git a/include/luisa/coro/coro_frame.h b/include/luisa/coro/coro_frame.h index bac5580f8..943622ec6 100644 --- a/include/luisa/coro/coro_frame.h +++ b/include/luisa/coro/coro_frame.h @@ -14,14 +14,14 @@ class LC_CORO_API CoroFrame { private: luisa::shared_ptr _desc; - const RefExpr *_expression; + const Expression *_expression; public: Var &coro_id; Var &target_token; public: - CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept; + CoroFrame(luisa::shared_ptr desc, const Expression *expr) noexcept; CoroFrame(CoroFrame &&another) noexcept; CoroFrame(const CoroFrame &another) noexcept; CoroFrame &operator=(const CoroFrame &rhs) noexcept; diff --git a/include/luisa/coro/coro_frame_desc.h b/include/luisa/coro/coro_frame_desc.h index 0c81cbcf8..82a873e93 100644 --- a/include/luisa/coro/coro_frame_desc.h +++ b/include/luisa/coro/coro_frame_desc.h @@ -26,6 +26,7 @@ class LC_CORO_API CoroFrameDesc : public luisa::enable_shared_from_this create(const Type *type, DesignatedFieldDict m) noexcept; [[nodiscard]] auto type() const noexcept { return _type; } + [[nodiscard]] const Type *field(uint index) const noexcept; [[nodiscard]] auto &designated_fields() const noexcept { return _designated_fields; } [[nodiscard]] uint designated_field(luisa::string_view name) const noexcept; [[nodiscard]] luisa::string dump() const noexcept; diff --git a/include/luisa/coro/coro_frame_smem.h b/include/luisa/coro/coro_frame_smem.h index 42159b6a4..c23057014 100644 --- a/include/luisa/coro/coro_frame_smem.h +++ b/include/luisa/coro/coro_frame_smem.h @@ -4,17 +4,24 @@ #pragma once +#include #include #include namespace luisa::compute { +namespace detail { +[[noreturn]] LC_CORO_API void error_coro_frame_smem_subscript_on_soa() noexcept; +}// namespace detail + template<> class Shared { private: luisa::shared_ptr _desc; luisa::vector _expressions; + // we need to return a `CoroFrame &` for operator[] with proper lifetime + luisa::vector> _temp_frames; size_t _size; private: @@ -39,8 +46,9 @@ class Shared { } public: - Shared(luisa::shared_ptr desc, size_t n, - bool soa = true, luisa::span soa_excluded_fields = {}) noexcept + Shared(luisa::shared_ptr desc, + size_t n, bool soa = false, + luisa::span soa_excluded_fields = {}) noexcept : _desc{std::move(desc)}, _size{n} { _create(soa, soa_excluded_fields); } Shared(Shared &&) noexcept = default; @@ -52,10 +60,11 @@ class Shared { [[nodiscard]] auto is_soa() const noexcept { return _expressions.size() > 1; } [[nodiscard]] auto size() const noexcept { return _size; } -public: +private: /// Read index with active fields template - [[nodiscard]] auto read(I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { + requires is_integral_expr_v + [[nodiscard]] auto _read(I &&index, luisa::optional> active_fields) const noexcept { auto i = def(std::forward(index)); auto fb = detail::FunctionBuilder::current(); auto frame = fb->local(_desc->type()); @@ -71,7 +80,7 @@ class Shared { if (m == 0u) { auto t = Type::of(); auto s = fb->access(Type::array(t, 3u), _expressions[m], i.expression()); - std::array elems; + std::array elems{}; elems[0] = fb->access(t, s, fb->literal(t, 0u)); elems[1] = fb->access(t, s, fb->literal(t, 1u)); elems[2] = fb->access(t, s, fb->literal(t, 2u)); @@ -88,7 +97,8 @@ class Shared { /// Write index with active fields template - void write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields = luisa::nullopt) const noexcept { + requires is_integral_expr_v + void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields) const noexcept { auto i = def(std::forward(index)); auto fb = detail::FunctionBuilder::current(); if (!is_soa()) { @@ -115,6 +125,41 @@ class Shared { } } } + +public: + /// Reference to the i-th element + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame &operator[](I &&index) noexcept { + if (is_soa()) { detail::error_coro_frame_smem_subscript_on_soa(); } + auto fb = detail::FunctionBuilder::current(); + auto ref = fb->access( + _desc->type(), _expressions[0], + detail::extract_expression(std::forward(index))); + auto temp_frame = luisa::make_unique(_desc, ref); + return *_temp_frames.emplace_back(std::move(temp_frame)); + } + + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame read(I &&index) const noexcept { + return _read(std::forward(index), luisa::nullopt); + } + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame read(I &&index, luisa::span active_fields) const noexcept { + return _read(std::forward(index), luisa::make_optional(active_fields)); + } + template + requires is_integral_expr_v + void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { + _write(std::forward(index), frame, luisa::nullopt); + } + template + requires is_integral_expr_v + void write(I &&index, const coroutine::CoroFrame &frame, luisa::span active_fields) const noexcept { + _write(std::forward(index), frame, luisa::make_optional(active_fields)); + } }; }// namespace luisa::compute diff --git a/include/luisa/coro/coro_func.h b/include/luisa/coro/coro_func.h index fbb820eb8..e3d22d4b0 100644 --- a/include/luisa/coro/coro_func.h +++ b/include/luisa/coro/coro_func.h @@ -30,13 +30,9 @@ class LC_CORO_API CoroAwaiter : public concepts::Noncopyable { using Await = luisa::move_only_function; Await _await; -private: - template - friend class Coroutine; +public: explicit CoroAwaiter(Await await) noexcept : _await{std::move(await)} {} - -public: void await() && noexcept { _await(); } }; @@ -165,11 +161,10 @@ class GeneratorIter : public concepts::Noncopyable { using Resume = luisa::move_only_function; Resume _resume; -private: - template - friend class Generator; +public: GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept - : _n{n}, _frame{std::move(frame)}, _resume{std::move(resume)} {} + : _n{n}, _frame{std::move(frame)}, + _resume{std::move(resume)} {} public: auto &update() noexcept { diff --git a/include/luisa/coro/schedulers/persistent_threads.h b/include/luisa/coro/schedulers/persistent_threads.h index 3972b77fc..e0c647980 100644 --- a/include/luisa/coro/schedulers/persistent_threads.h +++ b/include/luisa/coro/schedulers/persistent_threads.h @@ -15,7 +15,7 @@ struct PersistentThreadsCoroSchedulerConfig { uint thread_count = 64_k; uint block_size = 128; uint fetch_size = 16; - bool shared_memory_soa = true; + bool shared_memory_soa = false; bool global_memory_ext = false; }; diff --git a/src/coro/CMakeLists.txt b/src/coro/CMakeLists.txt index 86f1c14e6..5743d6d23 100644 --- a/src/coro/CMakeLists.txt +++ b/src/coro/CMakeLists.txt @@ -2,6 +2,7 @@ set(LUISA_COMPUTE_CORO_SOURCES coro_frame.cpp coro_frame_buffer.cpp coro_frame_desc.cpp + coro_frame_smem.cpp coro_func.cpp coro_graph.cpp schedulers/persistent_threads.cpp diff --git a/src/coro/coro_frame.cpp b/src/coro/coro_frame.cpp index 3142f9d85..d23aeca5a 100644 --- a/src/coro/coro_frame.cpp +++ b/src/coro/coro_frame.cpp @@ -9,7 +9,7 @@ namespace luisa::compute::coroutine { -CoroFrame::CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept +CoroFrame::CoroFrame(luisa::shared_ptr desc, const Expression *expr) noexcept : _desc{std::move(desc)}, _expression{expr}, coro_id{this->get(0u)}, diff --git a/src/coro/coro_frame_desc.cpp b/src/coro/coro_frame_desc.cpp index b4b1d52f4..c9275a47f 100644 --- a/src/coro/coro_frame_desc.cpp +++ b/src/coro/coro_frame_desc.cpp @@ -28,6 +28,12 @@ luisa::shared_ptr CoroFrameDesc::create(const Type *type, Designa return luisa::make_shared(CoroFrameDesc{type, std::move(m)}); } +const Type *CoroFrameDesc::field(uint index) const noexcept { + auto fields = _type->members(); + LUISA_ASSERT(index < fields.size(), "CoroFrame field index out of range."); + return fields[index]; +} + uint CoroFrameDesc::designated_field(luisa::string_view name) const noexcept { if (name == "coro_id") { return 0u; } if (name == "target_token") { return 1u; } diff --git a/src/coro/coro_frame_smem.cpp b/src/coro/coro_frame_smem.cpp new file mode 100644 index 000000000..3a390d0e3 --- /dev/null +++ b/src/coro/coro_frame_smem.cpp @@ -0,0 +1,7 @@ +#include + +namespace luisa::compute::detail { +void error_coro_frame_smem_subscript_on_soa() noexcept { + LUISA_ERROR_WITH_LOCATION("Subscripting on SOA shared frame is not allowed."); +} +}// namespace luisa::compute::detail diff --git a/src/coro/coro_func.cpp b/src/coro/coro_func.cpp index c8fa05541..fe05e74c0 100644 --- a/src/coro/coro_func.cpp +++ b/src/coro/coro_func.cpp @@ -23,7 +23,7 @@ void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, }; } -inline void coroutine_generator_next_impl( +void coroutine_generator_next_impl( CoroFrame &frame, uint node_count, const luisa::move_only_function &resume) noexcept { $switch (frame.target_token) { diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index 863287c32..9b7ac37c0 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -29,7 +29,11 @@ void persistent_threads_coro_scheduler_main_kernel_impl( for (auto index : dsl::dynamic_range(q_fac)) { auto s = index * config.block_size + thread_x(); all_token[s] = 0u; - // frames.write(s, coro.instantiate(), std::array{0u, 1u}); + if (config.shared_memory_soa) { + frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{1u}); + } else { + frames[s].target_token = 0u; + } } for (auto index : dsl::dynamic_range(g_fac)) { auto s = index * config.block_size + thread_x(); @@ -120,23 +124,37 @@ void persistent_threads_coro_scheduler_main_kernel_impl( auto frame_token = all_token[dst]; $if (coro_token != 0u) { $if (frame_token != 0u) { - auto g_state = global_frames->read(global_id); - global_frames->write(global_id, frames.read(dst)); - frames.write(dst, g_state); + if (config.shared_memory_soa) { + auto g_state = global_frames->read(global_id); + global_frames->write(global_id, frames.read(dst)); + frames.write(dst, g_state); + } else { + auto g_state = global_frames->read(global_id); + global_frames->write(global_id, frames[dst]); + frames[dst] = g_state; + } all_token[shared_queue_size + g_queue_id] = frame_token; all_token[dst] = coro_token; } $else { - auto g_state = global_frames->read(global_id); - frames.write(dst, g_state); + if (config.shared_memory_soa) { + auto g_state = global_frames->read(global_id); + frames.write(dst, g_state); + } else { + frames[dst] = global_frames->read(global_id); + } all_token[shared_queue_size + g_queue_id] = frame_token; all_token[dst] = coro_token; }; } $else { $if (frame_token != 0u) { - auto frame = frames.read(dst); - global_frames->write(global_id, frame); + if (config.shared_memory_soa) { + auto frame = frames.read(dst); + global_frames->write(global_id, frame); + } else { + global_frames->write(global_id, frames[dst]); + } all_token[shared_queue_size + g_queue_id] = frame_token; all_token[dst] = coro_token; }; @@ -183,10 +201,16 @@ void persistent_threads_coro_scheduler_main_kernel_impl( for (auto i = 1u; i < subroutine_count; i++) { $case (i) { work_counter.atomic(i).fetch_sub(1u); - auto frame = frames.read(pid, graph->node(i).input_fields()); - call_subroutine(frame, i); - auto next = frame.target_token & coro_token_valid_mask; - frames.write(pid, frame, graph->node(i).output_fields()); + auto next = def(0u); + if (config.shared_memory_soa) { + auto frame = frames.read(pid, graph->node(i).input_fields()); + call_subroutine(frame, i); + next = frame.target_token & coro_token_valid_mask; + frames.write(pid, frame, graph->node(i).output_fields()); + } else { + call_subroutine(frames[pid], i); + next = frames[pid].target_token & coro_token_valid_mask; + } all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); }; From b2c36209a91e14a990532343b74b16b2df2f4a05 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Thu, 16 May 2024 14:40:25 +0800 Subject: [PATCH 4/5] remove --- scripts/print_hlsl_builtin.lua | 2 +- src/backends/common/hlsl/builtin/coroutine | 9 --------- src/backends/common/hlsl/builtin/coroutine.c | 4 ---- src/backends/common/hlsl/hlsl_codegen_util.cpp | 4 ---- 4 files changed, 1 insertion(+), 18 deletions(-) delete mode 100644 src/backends/common/hlsl/builtin/coroutine delete mode 100644 src/backends/common/hlsl/builtin/coroutine.c diff --git a/scripts/print_hlsl_builtin.lua b/scripts/print_hlsl_builtin.lua index d6f6b0932..6688ef65e 100644 --- a/scripts/print_hlsl_builtin.lua +++ b/scripts/print_hlsl_builtin.lua @@ -5,7 +5,7 @@ local files_list = {'accel_process', 'bindless_upload', 'bc6_encode_block', 'bc6 'bc6_trymode_le10cs', 'bc7_encode_block', 'bc7_header', 'bc7_trymode_02cs', 'bc7_trymode_137cs', 'bc7_trymode_456cs', 'hlsl_header', 'raytracing_header', 'tex2d_bindless', 'tex3d_bindless', 'compute_quad', 'determinant', 'inverse', 'indirect', 'resource_size', 'accel_header', 'copy_sign', - 'bindless_common', 'auto_diff', "reduce", "coroutine"} + 'bindless_common', 'auto_diff', "reduce"} local lib = import("lib") local hlsl_builtin_path = path.join(os.projectdir(), "src/backends/common/hlsl/builtin") diff --git a/src/backends/common/hlsl/builtin/coroutine b/src/backends/common/hlsl/builtin/coroutine deleted file mode 100644 index e9cc25f62..000000000 --- a/src/backends/common/hlsl/builtin/coroutine +++ /dev/null @@ -1,9 +0,0 @@ -template -uint3 CoroId(in T frame){ -return frame.v0.v; -} - -template -uint CoroToken(in T frame){ -return frame.v1; -} \ No newline at end of file diff --git a/src/backends/common/hlsl/builtin/coroutine.c b/src/backends/common/hlsl/builtin/coroutine.c deleted file mode 100644 index 6f4494609..000000000 --- a/src/backends/common/hlsl/builtin/coroutine.c +++ /dev/null @@ -1,4 +0,0 @@ -#include "hlsl_config.h" -LC_HLSL_EXTERN char coroutine[]={116,101,109,112,108,97,116,101,60,116,121,112,101,110,97,109,101,32,84,62,10,117,105,110,116,51,32,67,111,114,111,73,100,40,105,110,32,84,32,102,114,97,109,101,41,123,10,114,101,116,117,114,110,32,102,114,97,109,101,46,118,48,46,118,59,10,125,10,10,116,101,109,112,108,97,116,101,60,116,121,112,101,110,97,109,101,32,84,62,10,117,105,110,116,32,67,111,114,111,84,111,107,101,110,40,105,110,32,84,32,102,114,97,109,101,41,123,10,114,101,116,117,114,110,32,102,114,97,109,101,46,118,49,59,10,125}; -LC_HLSL_EXTERN char *get_coroutine(){return coroutine;} -LC_HLSL_EXTERN int get_coroutine_size(){return 136;} diff --git a/src/backends/common/hlsl/hlsl_codegen_util.cpp b/src/backends/common/hlsl/hlsl_codegen_util.cpp index 95757f6c7..5cba86d17 100644 --- a/src/backends/common/hlsl/hlsl_codegen_util.cpp +++ b/src/backends/common/hlsl/hlsl_codegen_util.cpp @@ -117,10 +117,6 @@ static size_t AddHeader(CallOpSet const &ops, vstd::StringBuilder &builder, bool ops.test(CallOp::MATRIX_COMPONENT_WISE_MULTIPLICATION)) { builder << CodegenUtility::ReadInternalHLSLFile("reduce"); } - if (ops.test(CallOp::CORO_ID) || - ops.test(CallOp::CORO_TOKEN)) { - builder << CodegenUtility::ReadInternalHLSLFile("coroutine"); - } return immutable_size; } }// namespace detail From b0d90b8e078d6ff51fb4e4ed97b55612852bf0e4 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Thu, 16 May 2024 16:41:47 +0800 Subject: [PATCH 5/5] trying to optimize dx --- .../coro/schedulers/persistent_threads.h | 21 ++-- src/coro/schedulers/persistent_threads.cpp | 109 ++++++++---------- 2 files changed, 62 insertions(+), 68 deletions(-) diff --git a/include/luisa/coro/schedulers/persistent_threads.h b/include/luisa/coro/schedulers/persistent_threads.h index e0c647980..7d777e2b4 100644 --- a/include/luisa/coro/schedulers/persistent_threads.h +++ b/include/luisa/coro/schedulers/persistent_threads.h @@ -14,7 +14,7 @@ namespace luisa::compute::coroutine { struct PersistentThreadsCoroSchedulerConfig { uint thread_count = 64_k; uint block_size = 128; - uint fetch_size = 16; + uint fetch_size = 4; bool shared_memory_soa = false; bool global_memory_ext = false; }; @@ -23,7 +23,7 @@ namespace detail { LC_CORO_API void persistent_threads_coro_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, - const CoroGraph *graph, Shared &frames, Expr dispatch_shape, + const CoroGraph *graph, Shared &frames, Expr dispatch_size_prefix_product, Expr> global, const Buffer &global_frames, luisa::move_only_function call_subroutine) noexcept; }// namespace detail @@ -52,7 +52,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto global_ext_size = _config.thread_count * g_fac; _global_frames = device.create_buffer(coro.shared_frame(), global_ext_size); } - Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_shape, Var... args) noexcept { + Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_size_prefix_product, Var... args) noexcept { set_block_size(_config.block_size, 1u, 1u); auto global_queue_size = _config.block_size * g_fac; auto shared_queue_size = _config.block_size * q_fac; @@ -60,7 +60,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { Shared frames{graph->shared_frame(), shared_queue_size, _config.shared_memory_soa}; detail::persistent_threads_coro_scheduler_main_kernel_impl( _config, q_fac, g_fac, shared_queue_size, global_queue_size, graph, - frames, dispatch_shape, global, _global_frames, call_subroutine); + frames, dispatch_size_prefix_product, global, _global_frames, call_subroutine); }; _pt_shader = device.compile(main_kernel); _clear_shader = device.compile<1>([](BufferUInt global) { @@ -78,12 +78,19 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { - stream << _clear_shader(_global).dispatch(1u); + auto dispatch_size_prefix_product = make_uint3( + dispatch_size.x, + dispatch_size.x * dispatch_size.y, + dispatch_size.x * dispatch_size.y * dispatch_size.z); if (_config.global_memory_ext) { auto n = static_cast(_global_frames.size()); - stream << _initialize_shader(n).dispatch(n); + stream << _clear_shader(_global).dispatch(1u) + << _initialize_shader(n).dispatch(n) + << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count); + } else { + stream << _clear_shader(_global).dispatch(1u) + << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count); } - stream << _pt_shader(_global, dispatch_size, args...).dispatch(_config.thread_count); } public: diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index 9b7ac37c0..81302e9b1 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -13,7 +13,7 @@ namespace luisa::compute::coroutine::detail { void persistent_threads_coro_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, - const CoroGraph *graph, Shared &frames, Expr dispatch_shape, + const CoroGraph *graph, Shared &frames, Expr dispatch_size_prefix_product, Expr> global, const Buffer &global_frames, luisa::move_only_function call_subroutine) noexcept { @@ -26,16 +26,17 @@ void persistent_threads_coro_scheduler_main_kernel_impl( shared_queue_size}; Shared workload{2}; Shared work_stat{2};// work_state[0] for max_count, [1] for max_id - for (auto index : dsl::dynamic_range(q_fac)) { + for (auto index = 0u; index < q_fac; index++) { auto s = index * config.block_size + thread_x(); all_token[s] = 0u; if (config.shared_memory_soa) { - frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{1u}); + frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{0u, 1u}); } else { + frames[s].coro_id = make_uint3(); frames[s].target_token = 0u; } } - for (auto index : dsl::dynamic_range(g_fac)) { + for (auto index = 0u; index < g_fac; index++) { auto s = index * config.block_size + thread_x(); all_token[shared_queue_size + s] = 0u; } @@ -58,8 +59,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( rem_local[0] = 0u; sync_block(); auto count = def(0u); - auto count_limit = def(-1); - auto dispatch_size = dispatch_shape.x * dispatch_shape.y * dispatch_shape.z; + auto count_limit = def(std::numeric_limits::max()); $while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) { sync_block();//very important, synchronize for condition rem_local[0] = 0u; @@ -70,8 +70,8 @@ void persistent_threads_coro_scheduler_main_kernel_impl( $if (thread_x() == config.block_size - 1) { $if (workload[0] >= workload[1] & rem_global[0] == 1u) {//fetch new workload workload[0] = global.atomic(0u).fetch_add(config.block_size * config.fetch_size); - workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size); - $if (workload[0] >= dispatch_size) { + workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size_prefix_product.z); + $if (workload[0] >= dispatch_size_prefix_product.z) { rem_global[0] = 0u; }; }; @@ -96,7 +96,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( work_offset[1] = 0; sync_block(); if (!config.global_memory_ext) { - for (auto index : dsl::dynamic_range(q_fac)) {//collect indices + for (auto index = 0u; index < q_fac; index++) {//collect indices auto frame_token = all_token[index * config.block_size + thread_x()]; $if (frame_token == work_stat[1]) { auto id = work_offset.atomic(0).fetch_add(1u); @@ -104,7 +104,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( }; } } else { - for (auto index : dsl::dynamic_range(q_fac)) {//collect switch out indices + for (auto index = 0u; index < q_fac; index++) {//collect switch out indices auto frame_token = all_token[index * config.block_size + thread_x()]; $if (frame_token != work_stat[1]) { auto id = work_offset.atomic(0).fetch_add(1u); @@ -113,7 +113,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( } sync_block(); $if (shared_queue_size - work_offset[0] < config.block_size) {//no enough work - for (auto index : dsl::dynamic_range(g_fac)) { //swap frames + $for (index, 0u, g_fac) { //swap frames auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x(); auto g_queue_id = index * config.block_size + thread_x(); auto coro_token = all_token[shared_queue_size + g_queue_id]; @@ -123,76 +123,62 @@ void persistent_threads_coro_scheduler_main_kernel_impl( auto dst = path_id[id]; auto frame_token = all_token[dst]; $if (coro_token != 0u) { + auto g_state = global_frames->read(global_id); $if (frame_token != 0u) { if (config.shared_memory_soa) { - auto g_state = global_frames->read(global_id); global_frames->write(global_id, frames.read(dst)); - frames.write(dst, g_state); } else { - auto g_state = global_frames->read(global_id); global_frames->write(global_id, frames[dst]); - frames[dst] = g_state; } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - } - $else { - if (config.shared_memory_soa) { - auto g_state = global_frames->read(global_id); - frames.write(dst, g_state); - } else { - frames[dst] = global_frames->read(global_id); - } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; }; + if (config.shared_memory_soa) { + frames.write(dst, g_state); + } else { + frames[dst] = g_state; + } + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; } - $else { - $if (frame_token != 0u) { - if (config.shared_memory_soa) { - auto frame = frames.read(dst); - global_frames->write(global_id, frame); - } else { - global_frames->write(global_id, frames[dst]); - } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - }; + $elif (frame_token != 0u) { + if (config.shared_memory_soa) { + auto frame = frames.read(dst); + global_frames->write(global_id, frame); + } else { + global_frames->write(global_id, frames[dst]); + } + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; }; }; }; - } + }; }; } auto gen_st = workload[0]; sync_block(); - auto pid = def(0u); - if (config.global_memory_ext) { - pid = thread_x(); - } else { - pid = path_id[thread_x()]; - } - auto launch_condition = def(true); - if (!config.global_memory_ext) { - launch_condition = (thread_x() < work_offset[0]); - } else { - launch_condition = (all_token[pid] == work_stat[1]); - } - $if (launch_condition) { + auto pid = config.global_memory_ext ? thread_x() : path_id[thread_x()]; + $if (config.global_memory_ext ? (all_token[pid] == work_stat[1]) : (thread_x() < work_offset[0])) { $switch (all_token[pid]) { $case (0u) { $if (gen_st + thread_x() < workload[1]) { work_counter.atomic(0u).fetch_sub(1u); auto global_index = gen_st + thread_x(); - auto image_size = dispatch_shape.x * dispatch_shape.y; - auto index_z = global_index / image_size; - auto index_xy = global_index % image_size; - auto index_x = index_xy % dispatch_shape.x; - auto index_y = index_xy / dispatch_shape.x; - auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); - call_subroutine(frame, coro_token_entry); - auto next = frame.target_token & coro_token_valid_mask; - frames.write(pid, frame, graph->entry().output_fields()); + auto index_z = global_index / dispatch_size_prefix_product.y; + auto index_xy = global_index % dispatch_size_prefix_product.y; + auto index_x = index_xy % dispatch_size_prefix_product.x; + auto index_y = index_xy / dispatch_size_prefix_product.x; + auto next = def(0u); + if (config.shared_memory_soa) { + auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); + call_subroutine(frame, coro_token_entry); + next = frame.target_token & coro_token_valid_mask; + frames.write(pid, frame, graph->entry().output_fields()); + } else { + frames[pid].coro_id = make_uint3(index_x, index_y, index_z); + frames[pid].target_token = coro_token_entry; + call_subroutine(frames[pid], coro_token_entry); + next = frames[pid].target_token & coro_token_valid_mask; + } all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); workload.atomic(0).fetch_add(1u); @@ -215,6 +201,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( work_counter.atomic(next).fetch_add(1u); }; } + $default { dsl::unreachable(); }; }; }; sync_block();