From 0ee9a0187aaae30620bd22fda2cf8b3c593f9eff Mon Sep 17 00:00:00 2001 From: Mike Date: Tue, 7 May 2024 21:19:58 +0800 Subject: [PATCH 01/67] fixing build --- include/luisa/ast/type_registry.h | 2 +- include/luisa/dsl/struct.h | 3 ++- src/backends/common/shader_print_formatter.h | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/luisa/ast/type_registry.h b/include/luisa/ast/type_registry.h index d6f820859..3d5fa6cb3 100644 --- a/include/luisa/ast/type_registry.h +++ b/include/luisa/ast/type_registry.h @@ -372,7 +372,7 @@ constexpr auto is_valid_reflection_v = is_valid_reflection::value; #define LUISA_COROFRAME_STRUCT_REFLECT(S, name) \ template<> \ - struct canonical_layout { \ + struct luisa::compute::canonical_layout { \ using type = std::tuple; \ }; \ template<> \ diff --git a/include/luisa/dsl/struct.h b/include/luisa/dsl/struct.h index 54bb27922..e7f043b34 100644 --- a/include/luisa/dsl/struct.h +++ b/include/luisa/dsl/struct.h @@ -690,9 +690,10 @@ struct luisa_compute_extension {}; : Var{Var{detail::ArgumentCreation{}}} {} \ }; \ }// namespace luisa::compute + #define LUISA_COROFRAME_STRUCT(S) \ LUISA_COROFRAME_STRUCT_REFLECT(S, #S); \ LUISA_COROFRAME_STRUCT_EXT(S) \ LUISA_DERIVE_COROFRAME_SOA(S) \ template<> \ - struct luisa_compute_extension final : luisa::compute::detail::Ref \ No newline at end of file + struct luisa_compute_extension final : luisa::compute::detail::Ref diff --git a/src/backends/common/shader_print_formatter.h b/src/backends/common/shader_print_formatter.h index 1b41c53b2..9282ca909 100644 --- a/src/backends/common/shader_print_formatter.h +++ b/src/backends/common/shader_print_formatter.h @@ -3,7 +3,6 @@ #include #include #include -#include #include namespace luisa::compute { From 2f3730707e56293ba8f8ed706fd4e3d603be9156 Mon Sep 17 00:00:00 2001 From: Mike Date: Tue, 7 May 2024 23:35:19 +0800 Subject: [PATCH 02/67] wip: coro frame --- include/luisa/coro/coro_dispatcher.h | 3 ++ include/luisa/dsl/coro/coro_frame.h | 60 +++++++++++++++++++++ include/luisa/luisa-compute.h | 10 ++++ src/dsl/CMakeLists.txt | 4 ++ src/dsl/coro/coro_frame.cpp | 81 ++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+) create mode 100644 include/luisa/dsl/coro/coro_frame.h create mode 100644 src/dsl/coro/coro_frame.cpp diff --git a/include/luisa/coro/coro_dispatcher.h b/include/luisa/coro/coro_dispatcher.h index 89789a8bf..1e9f431b3 100644 --- a/include/luisa/coro/coro_dispatcher.h +++ b/include/luisa/coro/coro_dispatcher.h @@ -6,6 +6,9 @@ #include #include #include + +#include +#include #include #include namespace luisa::compute { diff --git a/include/luisa/dsl/coro/coro_frame.h b/include/luisa/dsl/coro/coro_frame.h new file mode 100644 index 000000000..22aa8b012 --- /dev/null +++ b/include/luisa/dsl/coro/coro_frame.h @@ -0,0 +1,60 @@ +// +// Created by mike on 5/7/24. +// + +#pragma once + +#include +#include +#include +#include + +namespace luisa::compute::inline dsl::coro_v2 { + +class CoroFrame; + +class LC_DSL_API CoroFrameDesc : public luisa::enable_shared_from_this { + +public: + using DesignatedMemberDict = luisa::unordered_map; + +private: + const Type *_type{nullptr}; + DesignatedMemberDict _designated_members; + +private: + CoroFrameDesc(const Type *type, DesignatedMemberDict m) noexcept; + +public: + [[nodiscard]] static luisa::shared_ptr create(const Type *type, DesignatedMemberDict m) noexcept; + [[nodiscard]] auto type() const noexcept { return _type; } + [[nodiscard]] auto &designated_members() const noexcept { return _designated_members; } + [[nodiscard]] uint designated_member(luisa::string_view name) const noexcept; + [[nodiscard]] bool operator==(const CoroFrameDesc &rhs) const noexcept; + +public: + [[nodiscard]] CoroFrame instantiate(Expr coro_id = luisa::make_uint3()) const noexcept; +}; + +class LC_DSL_API CoroFrame { + +private: + luisa::shared_ptr _desc; + const RefExpr *_expression; + +private: + friend class CoroFrameDesc; + CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept; + +public: + CoroFrame(CoroFrame &&another) noexcept; + CoroFrame(const CoroFrame &another) noexcept; + CoroFrame &operator=(const CoroFrame &rhs) noexcept; + CoroFrame &operator=(CoroFrame &&rhs) noexcept; + +public: + [[nodiscard]] auto description() const noexcept { return _desc.get(); } + [[nodiscard]] auto expression() const noexcept { return _expression; } +}; + +}// namespace luisa::compute::inline dsl::coro_v2 diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index 975781c26..e49b840e6 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -46,8 +46,16 @@ #include #include #include +#include #include +#include +#include +#include +#include +#include +#include + #ifdef LUISA_ENABLE_DSL #include #include @@ -55,6 +63,7 @@ #include #include #include +#include #include #include #include @@ -108,6 +117,7 @@ #include #include #include +#include #include #include #include diff --git a/src/dsl/CMakeLists.txt b/src/dsl/CMakeLists.txt index e5bf95b0d..f847365c8 100644 --- a/src/dsl/CMakeLists.txt +++ b/src/dsl/CMakeLists.txt @@ -10,6 +10,9 @@ if (LUISA_COMPUTE_ENABLE_DSL) set(LUISA_COMPUTE_DSL_RASTER_SOURCES raster/raster_kernel.cpp) + set(LUISA_COMPUTE_DSL_CORO_SOURCES + coro/coro_frame.cpp) + set(LUISA_COMPUTE_DSL_SOURCES builtin.cpp dispatch_indirect.cpp @@ -20,6 +23,7 @@ if (LUISA_COMPUTE_ENABLE_DSL) soa.cpp sugar.cpp ${LUISA_COMPUTE_DSL_RTX_SOURCES} + ${LUISA_COMPUTE_DSL_CORO_SOURCES} ${LUISA_COMPUTE_DSL_RASTER_SOURCES}) add_library(luisa-compute-dsl SHARED ${LUISA_COMPUTE_DSL_SOURCES}) diff --git a/src/dsl/coro/coro_frame.cpp b/src/dsl/coro/coro_frame.cpp new file mode 100644 index 000000000..17681a374 --- /dev/null +++ b/src/dsl/coro/coro_frame.cpp @@ -0,0 +1,81 @@ +// +// Created by mike on 5/7/24. +// + +#include +#include + +namespace luisa::compute::inline dsl::coro_v2 { + +CoroFrameDesc::CoroFrameDesc(const Type *type, DesignatedMemberDict m) noexcept + : _type{type}, _designated_members{std::move(m)} { + LUISA_ASSERT(_type != nullptr, "CoroFrame underlying type must not be null."); + LUISA_ASSERT(_type->is_structure(), "CoroFrame underlying type must be a structure."); + LUISA_ASSERT(_type->members().size() >= 2u, "CoroFrame underlying type must have at least 2 members (coro_id and target_token)."); + LUISA_ASSERT(_type->members()[0] == Type::of(), "CoroFrame member 0 (coro_id) must be uint3."); + LUISA_ASSERT(_type->members()[1] == Type::of(), "CoroFrame member 1 (target_token) must be uint."); + auto member_count = _type->members().size(); + for (auto &&[name, index] : _designated_members) { + LUISA_ASSERT(name != "coro_id", "CoroFrame designated member name 'coro_id' is reserved."); + LUISA_ASSERT(name != "target_token", "CoroFrame designated member name 'target_token' is reserved."); + LUISA_ASSERT(index != 0, "CoroFrame designated member index 0 is reserved for coro_id."); + LUISA_ASSERT(index != 1, "CoroFrame designated member index 1 is reserved for target_token."); + LUISA_ASSERT(index < member_count, "CoroFrame designated member index out of range."); + } +} + +luisa::shared_ptr CoroFrameDesc::create(const Type *type, DesignatedMemberDict m) noexcept { + return luisa::make_shared(CoroFrameDesc{type, std::move(m)}); +} + +uint CoroFrameDesc::designated_member(luisa::string_view name) const noexcept { + if (name == "coro_id") { return 0u; } + if (name == "target_token") { return 1u; } + auto iter = _designated_members.find(name); + LUISA_ASSERT(iter != _designated_members.end(), "CoroFrame designated member not found."); + return iter->second; +} + +bool CoroFrameDesc::operator==(const CoroFrameDesc &rhs) const noexcept { + return this->type() == rhs.type() && this->designated_members() == rhs.designated_members(); +} + +CoroFrame CoroFrameDesc::instantiate(Expr coro_id) const noexcept { + auto fb = detail::FunctionBuilder::current(); + // create an variable for the coro frame + auto expr = fb->local(_type); + // initialize the coro frame members + auto zero_init = fb->call(_type, CallOp::ZERO, {}); + fb->assign(expr, zero_init); + // set the coro_id and target_token fields + auto i = fb->member(_type->members()[0], expr, 0u); + fb->assign(i, coro_id.expression()); + auto zero = fb->call(Type::of(), CallOp::ZERO, {}); + auto target_token = fb->member(_type->members()[1], expr, 1u); + fb->assign(target_token, zero); + return CoroFrame{shared_from_this(), expr}; +} + +CoroFrame::CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept + : _desc{std::move(desc)}, _expression{expr} { + LUISA_ASSERT(expr != nullptr, "CoroFrame expression must not be null."); + LUISA_ASSERT(expr->type() == _desc->type(), "CoroFrame expression type mismatch."); +} + +CoroFrame::CoroFrame(CoroFrame &&another) noexcept { + LUISA_NOT_IMPLEMENTED(); +} + +CoroFrame::CoroFrame(const CoroFrame &another) noexcept { + LUISA_NOT_IMPLEMENTED(); +} + +CoroFrame &CoroFrame::operator=(const CoroFrame &rhs) noexcept { + LUISA_NOT_IMPLEMENTED(); +} + +CoroFrame &CoroFrame::operator=(CoroFrame &&rhs) noexcept { + LUISA_NOT_IMPLEMENTED(); +} + +}// namespace luisa::compute::inline dsl::coro_v2 From 4af3d64174afa32c042701626f7845a10bc5536c Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 01:06:23 +0800 Subject: [PATCH 03/67] added coro frame --- include/luisa/dsl/coro/coro_frame.h | 35 ++++++++++++-- src/backends/common/shader_print_formatter.h | 1 + src/dsl/coro/coro_frame.cpp | 51 ++++++++++++-------- 3 files changed, 63 insertions(+), 24 deletions(-) diff --git a/include/luisa/dsl/coro/coro_frame.h b/include/luisa/dsl/coro/coro_frame.h index 22aa8b012..c1ff823f0 100644 --- a/include/luisa/dsl/coro/coro_frame.h +++ b/include/luisa/dsl/coro/coro_frame.h @@ -30,10 +30,10 @@ class LC_DSL_API CoroFrameDesc : public luisa::enable_shared_from_this coro_id = luisa::make_uint3()) const noexcept; + [[nodiscard]] CoroFrame instantiate() const noexcept; + [[nodiscard]] CoroFrame instantiate(Expr coro_id) const noexcept; }; class LC_DSL_API CoroFrame { @@ -42,19 +42,44 @@ class LC_DSL_API CoroFrame { luisa::shared_ptr _desc; const RefExpr *_expression; -private: - friend class CoroFrameDesc; - CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept; +public: + Var &coro_id; + Var &target_token; public: + CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept; CoroFrame(CoroFrame &&another) noexcept; CoroFrame(const CoroFrame &another) noexcept; CoroFrame &operator=(const CoroFrame &rhs) noexcept; CoroFrame &operator=(CoroFrame &&rhs) noexcept; +private: + void _check_member_index(uint index) const noexcept; + public: [[nodiscard]] auto description() const noexcept { return _desc.get(); } [[nodiscard]] auto expression() const noexcept { return _expression; } + +public: + template + [[nodiscard]] Var &m(uint index) noexcept { + _check_member_index(index); + auto fb = detail::FunctionBuilder::current(); + auto member = fb->member(_desc->type()->members()[index], _expression, index); + return *fb->create_temporary>(member); + } + template + [[nodiscard]] const Var &m(uint index) const noexcept { + return const_cast(this)->m(index); + } + template + [[nodiscard]] Var &m(luisa::string_view name) noexcept { + return m(_desc->designated_member(name)); + } + template + [[nodiscard]] const Var &m(luisa::string_view name) const noexcept { + return const_cast(this)->m(name); + } }; }// namespace luisa::compute::inline dsl::coro_v2 diff --git a/src/backends/common/shader_print_formatter.h b/src/backends/common/shader_print_formatter.h index 9282ca909..1b41c53b2 100644 --- a/src/backends/common/shader_print_formatter.h +++ b/src/backends/common/shader_print_formatter.h @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace luisa::compute { diff --git a/src/dsl/coro/coro_frame.cpp b/src/dsl/coro/coro_frame.cpp index 17681a374..ac30e741d 100644 --- a/src/dsl/coro/coro_frame.cpp +++ b/src/dsl/coro/coro_frame.cpp @@ -36,46 +36,59 @@ uint CoroFrameDesc::designated_member(luisa::string_view name) const noexcept { return iter->second; } -bool CoroFrameDesc::operator==(const CoroFrameDesc &rhs) const noexcept { - return this->type() == rhs.type() && this->designated_members() == rhs.designated_members(); -} - -CoroFrame CoroFrameDesc::instantiate(Expr coro_id) const noexcept { +CoroFrame CoroFrameDesc::instantiate() const noexcept { auto fb = detail::FunctionBuilder::current(); // create an variable for the coro frame auto expr = fb->local(_type); // initialize the coro frame members auto zero_init = fb->call(_type, CallOp::ZERO, {}); fb->assign(expr, zero_init); - // set the coro_id and target_token fields - auto i = fb->member(_type->members()[0], expr, 0u); - fb->assign(i, coro_id.expression()); - auto zero = fb->call(Type::of(), CallOp::ZERO, {}); - auto target_token = fb->member(_type->members()[1], expr, 1u); - fb->assign(target_token, zero); return CoroFrame{shared_from_this(), expr}; } +CoroFrame CoroFrameDesc::instantiate(Expr coro_id) const noexcept { + auto frame = instantiate(); + frame.coro_id = coro_id; + return frame; +} + CoroFrame::CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept - : _desc{std::move(desc)}, _expression{expr} { + : _desc{std::move(desc)}, + _expression{expr}, + coro_id{this->m(0u)}, + target_token{this->m(1u)} { LUISA_ASSERT(expr != nullptr, "CoroFrame expression must not be null."); LUISA_ASSERT(expr->type() == _desc->type(), "CoroFrame expression type mismatch."); } -CoroFrame::CoroFrame(CoroFrame &&another) noexcept { - LUISA_NOT_IMPLEMENTED(); +CoroFrame::CoroFrame(CoroFrame &&another) noexcept + : CoroFrame{std::move(another._desc), another._expression} { + another._expression = nullptr; } -CoroFrame::CoroFrame(const CoroFrame &another) noexcept { - LUISA_NOT_IMPLEMENTED(); -} +CoroFrame::CoroFrame(const CoroFrame &another) noexcept + : CoroFrame{another._desc, [e = another.expression()]() noexcept { + auto fb = detail::FunctionBuilder::current(); + auto copy = fb->local(e->type()); + fb->assign(copy, e); + return copy; + }()} {} CoroFrame &CoroFrame::operator=(const CoroFrame &rhs) noexcept { - LUISA_NOT_IMPLEMENTED(); + if (this == std::addressof(rhs)) { return *this; } + LUISA_ASSERT(this->description() == rhs.description(), "CoroFrame description mismatch."); + auto fb = detail::FunctionBuilder::current(); + fb->assign(_expression, rhs.expression()); + return *this; } CoroFrame &CoroFrame::operator=(CoroFrame &&rhs) noexcept { - LUISA_NOT_IMPLEMENTED(); + return *this = static_cast(rhs); +} + +void CoroFrame::_check_member_index(uint index) const noexcept { + LUISA_ASSERT(index < _desc->type()->members().size(), + "CoroFrame member index out of range."); } }// namespace luisa::compute::inline dsl::coro_v2 From ed74e980c649b6deafe1d18babf164d16882dc6e Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 01:07:14 +0800 Subject: [PATCH 04/67] added coro frame --- include/luisa/dsl/coro/coro_frame.h | 14 +++++++------- src/dsl/coro/coro_frame.cpp | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/luisa/dsl/coro/coro_frame.h b/include/luisa/dsl/coro/coro_frame.h index c1ff823f0..57beb9049 100644 --- a/include/luisa/dsl/coro/coro_frame.h +++ b/include/luisa/dsl/coro/coro_frame.h @@ -62,23 +62,23 @@ class LC_DSL_API CoroFrame { public: template - [[nodiscard]] Var &m(uint index) noexcept { + [[nodiscard]] Var &get(uint index) noexcept { _check_member_index(index); auto fb = detail::FunctionBuilder::current(); auto member = fb->member(_desc->type()->members()[index], _expression, index); return *fb->create_temporary>(member); } template - [[nodiscard]] const Var &m(uint index) const noexcept { - return const_cast(this)->m(index); + [[nodiscard]] const Var &get(uint index) const noexcept { + return const_cast(this)->get(index); } template - [[nodiscard]] Var &m(luisa::string_view name) noexcept { - return m(_desc->designated_member(name)); + [[nodiscard]] Var &get(luisa::string_view name) noexcept { + return get(_desc->designated_member(name)); } template - [[nodiscard]] const Var &m(luisa::string_view name) const noexcept { - return const_cast(this)->m(name); + [[nodiscard]] const Var &get(luisa::string_view name) const noexcept { + return const_cast(this)->get(name); } }; diff --git a/src/dsl/coro/coro_frame.cpp b/src/dsl/coro/coro_frame.cpp index ac30e741d..73a55bc21 100644 --- a/src/dsl/coro/coro_frame.cpp +++ b/src/dsl/coro/coro_frame.cpp @@ -55,8 +55,8 @@ CoroFrame CoroFrameDesc::instantiate(Expr coro_id) const noexcept { CoroFrame::CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept : _desc{std::move(desc)}, _expression{expr}, - coro_id{this->m(0u)}, - target_token{this->m(1u)} { + coro_id{this->get(0u)}, + target_token{this->get(1u)} { LUISA_ASSERT(expr != nullptr, "CoroFrame expression must not be null."); LUISA_ASSERT(expr->type() == _desc->type(), "CoroFrame expression type mismatch."); } From 0b1e2e85f4f18c57a790958f4cf47a768a1c1ee0 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 18:44:25 +0800 Subject: [PATCH 05/67] wip: coro func --- include/luisa/dsl/coro/coro_frame.h | 15 +- include/luisa/dsl/coro/coro_func.h | 66 + include/luisa/dsl/coro/coro_graph.h | 71 ++ include/luisa/dsl/func.h | 6 +- include/luisa/luisa-compute.h | 2 + include/luisa/rust/ir.hpp | 1 + src/dsl/CMakeLists.txt | 4 +- src/dsl/coro/coro_frame.cpp | 28 +- src/dsl/coro/coro_func.cpp | 8 + src/dsl/coro/coro_graph.cpp | 254 ++++ src/dsl/func.cpp | 1 + src/rust/luisa_compute_ir/src/ast2ir.rs | 1 + src/rust/luisa_compute_ir/src/ir.rs | 2 + .../src/transform/materialize_coro.rs | 15 +- .../src/transform/materialize_coro_v2.rs | 1097 +++++++++++++++++ .../luisa_compute_ir/src/transform/mod.rs | 5 + src/tests/CMakeLists.txt | 1 + src/tests/coro/helloworld_v2.cpp | 33 + 18 files changed, 1590 insertions(+), 20 deletions(-) create mode 100644 include/luisa/dsl/coro/coro_func.h create mode 100644 include/luisa/dsl/coro/coro_graph.h create mode 100644 src/dsl/coro/coro_func.cpp create mode 100644 src/dsl/coro/coro_graph.cpp create mode 100644 src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs create mode 100644 src/tests/coro/helloworld_v2.cpp diff --git a/include/luisa/dsl/coro/coro_frame.h b/include/luisa/dsl/coro/coro_frame.h index 57beb9049..662996678 100644 --- a/include/luisa/dsl/coro/coro_frame.h +++ b/include/luisa/dsl/coro/coro_frame.h @@ -16,20 +16,21 @@ class CoroFrame; class LC_DSL_API CoroFrameDesc : public luisa::enable_shared_from_this { public: - using DesignatedMemberDict = luisa::unordered_map; + using DesignatedFieldDict = luisa::unordered_map; private: const Type *_type{nullptr}; - DesignatedMemberDict _designated_members; + DesignatedFieldDict _designated_fields; private: - CoroFrameDesc(const Type *type, DesignatedMemberDict m) noexcept; + CoroFrameDesc(const Type *type, DesignatedFieldDict m) noexcept; public: - [[nodiscard]] static luisa::shared_ptr create(const Type *type, DesignatedMemberDict m) noexcept; + [[nodiscard]] static luisa::shared_ptr create(const Type *type, DesignatedFieldDict m) noexcept; [[nodiscard]] auto type() const noexcept { return _type; } - [[nodiscard]] auto &designated_members() const noexcept { return _designated_members; } - [[nodiscard]] uint designated_member(luisa::string_view name) const noexcept; + [[nodiscard]] auto &designated_fields() const noexcept { return _designated_fields; } + [[nodiscard]] uint designated_field(luisa::string_view name) const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; public: [[nodiscard]] CoroFrame instantiate() const noexcept; @@ -74,7 +75,7 @@ class LC_DSL_API CoroFrame { } template [[nodiscard]] Var &get(luisa::string_view name) noexcept { - return get(_desc->designated_member(name)); + return get(_desc->designated_field(name)); } template [[nodiscard]] const Var &get(luisa::string_view name) const noexcept { diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h new file mode 100644 index 000000000..00df4c44b --- /dev/null +++ b/include/luisa/dsl/coro/coro_func.h @@ -0,0 +1,66 @@ +// +// Created by Mike on 2024/5/8. +// + +#pragma once + +#include +#include +#include + +namespace luisa::compute::inline dsl::coro_v2 { + +template +class Coroutine { + static_assert(luisa::always_false_v); +}; + +template +class Coroutine { + +public: + static_assert(std::is_same_v, + "Coroutine function must return void."); + + using Token = CoroGraph::Token; + +private: + luisa::shared_ptr _graph; + +public: + explicit Coroutine(luisa::shared_ptr graph) noexcept + : _graph{std::move(graph)} { + // TODO: check arguments + } + + template + requires std::negation_v>> && + std::negation_v>> + Coroutine(Def &&f) noexcept { + auto coro = detail::FunctionBuilder::define_coroutine([&f] { + static_assert(std::is_invocable_r_v...>); + auto create = [](auto &&def, std::index_sequence) noexcept { + using arg_tuple = std::tuple; + using var_tuple = std::tuple>...>; + using tag_tuple = std::tuple...>; + auto args = detail::create_argument_definitions(std::tuple<>{}); + static_assert(std::tuple_size_v == sizeof...(Args)); + return luisa::invoke(std::forward(def), + static_cast> &&>(std::get(args))...); + }; + create(std::forward(f), std::index_sequence_for{}); + detail::FunctionBuilder::current()->return_(nullptr);// to check if any previous $return called with non-void types + }); + _graph = CoroGraph::create(coro->function()); + } + +public: + [[nodiscard]] auto graph() const noexcept { return _graph.get(); } + [[nodiscard]] auto &shared_graph() const noexcept { return _graph; } +}; + +template +Coroutine(T &&) -> Coroutine>>; + +}// namespace luisa::compute::inline dsl::coro_v2 diff --git a/include/luisa/dsl/coro/coro_graph.h b/include/luisa/dsl/coro/coro_graph.h new file mode 100644 index 000000000..81180c440 --- /dev/null +++ b/include/luisa/dsl/coro/coro_graph.h @@ -0,0 +1,71 @@ +// +// Created by Mike on 2024/5/8. +// + +#pragma once + +#include +#include + +namespace luisa::compute::inline dsl::coro_v2 { + +class CoroFrameDesc; + +class LC_DSL_API CoroGraph { + +public: + using Token = uint; + static constexpr Token entry_token = 0u; + static constexpr Token terminal_token = 0x8000'0000u; + using CC = luisa::shared_ptr;// current continuation function + +public: + class Node { + + private: + luisa::vector _input_fields; + luisa::vector _output_fields; + luisa::vector _targets; + CC _cc; + + public: + Node(luisa::vector input_fields, + luisa::vector output_fields, + luisa::vector targets, + CC current_continuation) noexcept; + ~Node() noexcept; + + public: + [[nodiscard]] auto input_fields() const noexcept { return luisa::span{_input_fields}; } + [[nodiscard]] auto output_fields() const noexcept { return luisa::span{_output_fields}; } + [[nodiscard]] auto targets() const noexcept { return luisa::span{_targets}; } + [[nodiscard]] Function cc() const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; + }; + +private: + luisa::shared_ptr _frame; + luisa::unordered_map _nodes; + luisa::unordered_map _named_tokens; + +public: + CoroGraph(luisa::shared_ptr frame_desc, + luisa::unordered_map nodes, + luisa::unordered_map named_tokens) noexcept; + ~CoroGraph() noexcept; + +public: + // create a coroutine graph from a coroutine function definition + [[nodiscard]] static luisa::shared_ptr create(Function coroutine) noexcept; + +public: + [[nodiscard]] auto frame() const noexcept { return _frame.get(); } + [[nodiscard]] auto &nodes() const noexcept { return _nodes; } + [[nodiscard]] auto &named_tokens() const noexcept { return _named_tokens; } + [[nodiscard]] const Node &entry() const noexcept; + [[nodiscard]] const Node &node(Token index) const noexcept; + [[nodiscard]] const Node &node(luisa::string_view name) const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; +}; + +}// namespace luisa::compute::inline dsl::coro_v2 diff --git a/include/luisa/dsl/func.h b/include/luisa/dsl/func.h index bd51738be..ebad45f5b 100644 --- a/include/luisa/dsl/func.h +++ b/include/luisa/dsl/func.h @@ -372,7 +372,9 @@ class Callable { public: explicit Callable(luisa::shared_ptr builder) noexcept - : _builder{std::move(builder)} {} + : _builder{std::move(builder)} { + // TODO: check arguments + } /** * @brief Construct a Callable object. * @@ -572,7 +574,7 @@ class Coroutine { [[nodiscard]] auto const &function_builder() const & noexcept { return _builder; } [[nodiscard]] auto &&function_builder() && noexcept { return std::move(_builder); } [[nodiscard]] auto const suspend_count() noexcept { return _coro_tokens.size(); } - [[nodiscard]] auto const & coro_tokens() const & noexcept { return _coro_tokens; } + [[nodiscard]] auto const &coro_tokens() const & noexcept { return _coro_tokens; } [[nodiscard]] auto const &graph() noexcept { return _coro_graph; } //Call from start of coroutine auto operator()(detail::prototype_to_callable_invocation_t type, diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index e49b840e6..4d8ec72a4 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -64,6 +64,8 @@ #include #include #include +#include +#include #include #include #include diff --git a/include/luisa/rust/ir.hpp b/include/luisa/rust/ir.hpp index 82fb07636..557654e7d 100644 --- a/include/luisa/rust/ir.hpp +++ b/include/luisa/rust/ir.hpp @@ -314,6 +314,7 @@ struct CallableModule { CBoxedSlice captures; CBoxedSlice subroutines; CBoxedSlice subroutine_ids; + CBoxedSlice coro_target_tokens; CBoxedSlice coro_frame_input_fields; CBoxedSlice coro_frame_output_fields; CBoxedSlice coro_frame_designated_fields; diff --git a/src/dsl/CMakeLists.txt b/src/dsl/CMakeLists.txt index f847365c8..ea27abc61 100644 --- a/src/dsl/CMakeLists.txt +++ b/src/dsl/CMakeLists.txt @@ -11,7 +11,9 @@ if (LUISA_COMPUTE_ENABLE_DSL) raster/raster_kernel.cpp) set(LUISA_COMPUTE_DSL_CORO_SOURCES - coro/coro_frame.cpp) + coro/coro_frame.cpp + coro/coro_func.cpp + coro/coro_graph.cpp) set(LUISA_COMPUTE_DSL_SOURCES builtin.cpp diff --git a/src/dsl/coro/coro_frame.cpp b/src/dsl/coro/coro_frame.cpp index 73a55bc21..d8cbc6750 100644 --- a/src/dsl/coro/coro_frame.cpp +++ b/src/dsl/coro/coro_frame.cpp @@ -7,15 +7,15 @@ namespace luisa::compute::inline dsl::coro_v2 { -CoroFrameDesc::CoroFrameDesc(const Type *type, DesignatedMemberDict m) noexcept - : _type{type}, _designated_members{std::move(m)} { +CoroFrameDesc::CoroFrameDesc(const Type *type, DesignatedFieldDict m) noexcept + : _type{type}, _designated_fields{std::move(m)} { LUISA_ASSERT(_type != nullptr, "CoroFrame underlying type must not be null."); LUISA_ASSERT(_type->is_structure(), "CoroFrame underlying type must be a structure."); LUISA_ASSERT(_type->members().size() >= 2u, "CoroFrame underlying type must have at least 2 members (coro_id and target_token)."); LUISA_ASSERT(_type->members()[0] == Type::of(), "CoroFrame member 0 (coro_id) must be uint3."); LUISA_ASSERT(_type->members()[1] == Type::of(), "CoroFrame member 1 (target_token) must be uint."); auto member_count = _type->members().size(); - for (auto &&[name, index] : _designated_members) { + for (auto &&[name, index] : _designated_fields) { LUISA_ASSERT(name != "coro_id", "CoroFrame designated member name 'coro_id' is reserved."); LUISA_ASSERT(name != "target_token", "CoroFrame designated member name 'target_token' is reserved."); LUISA_ASSERT(index != 0, "CoroFrame designated member index 0 is reserved for coro_id."); @@ -24,18 +24,32 @@ CoroFrameDesc::CoroFrameDesc(const Type *type, DesignatedMemberDict m) noexcept } } -luisa::shared_ptr CoroFrameDesc::create(const Type *type, DesignatedMemberDict m) noexcept { +luisa::shared_ptr CoroFrameDesc::create(const Type *type, DesignatedFieldDict m) noexcept { return luisa::make_shared(CoroFrameDesc{type, std::move(m)}); } -uint CoroFrameDesc::designated_member(luisa::string_view name) const noexcept { +uint CoroFrameDesc::designated_field(luisa::string_view name) const noexcept { if (name == "coro_id") { return 0u; } if (name == "target_token") { return 1u; } - auto iter = _designated_members.find(name); - LUISA_ASSERT(iter != _designated_members.end(), "CoroFrame designated member not found."); + auto iter = _designated_fields.find(name); + LUISA_ASSERT(iter != _designated_fields.end(), "CoroFrame designated member not found."); return iter->second; } +luisa::string CoroFrameDesc::dump() const noexcept { + luisa::string s; + for (auto i = 0u; i < _type->members().size(); i++) { + s.append(luisa::format(" Field {}: {}\n", i, _type->members()[i]->description())); + } + if (!_designated_fields.empty()) { + s.append("Designated Fields:\n"); + for (auto &&[name, index] : _designated_fields) { + s.append(luisa::format(" {} -> \"{}\"\n", index, name)); + } + } + return s; +} + CoroFrame CoroFrameDesc::instantiate() const noexcept { auto fb = detail::FunctionBuilder::current(); // create an variable for the coro frame diff --git a/src/dsl/coro/coro_func.cpp b/src/dsl/coro/coro_func.cpp new file mode 100644 index 000000000..c66e22f33 --- /dev/null +++ b/src/dsl/coro/coro_func.cpp @@ -0,0 +1,8 @@ +// +// Created by Mike on 2024/5/8. +// + +#include + +namespace luisa::compute::inline dsl::coro_v2 { +} diff --git a/src/dsl/coro/coro_graph.cpp b/src/dsl/coro/coro_graph.cpp new file mode 100644 index 000000000..23c609c33 --- /dev/null +++ b/src/dsl/coro/coro_graph.cpp @@ -0,0 +1,254 @@ +// +// Created by Mike on 2024/5/8. +// + +#include +#include +#include +#include + +#ifdef LUISA_ENABLE_IR +#include +#endif + +namespace luisa::compute::inline dsl::coro_v2 { + +CoroGraph::Node::Node(luisa::vector input_fields, + luisa::vector output_fields, + luisa::vector targets, + CC current_continuation) noexcept + : _input_fields{std::move(input_fields)}, + _output_fields{std::move(output_fields)}, + _targets{std::move(targets)}, + _cc{std::move(current_continuation)} {} + +CoroGraph::Node::~Node() noexcept = default; + +Function CoroGraph::Node::cc() const noexcept { return _cc->function(); } + +luisa::string CoroGraph::Node::dump() const noexcept { + luisa::string s; + s.append(" Input Fields: ["); + for (auto i : _input_fields) { + s.append(luisa::format("{}, ", i)); + } + if (!_input_fields.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + s.append(" Output Fields: ["); + for (auto i : _output_fields) { + s.append(luisa::format("{}, ", i)); + } + if (!_output_fields.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + s.append(" Transition Targets: ["); + for (auto i : _targets) { + s.append(luisa::format("{}, ", i)); + } + if (!_targets.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + return s; +} + +CoroGraph::CoroGraph(luisa::shared_ptr frame_desc, + luisa::unordered_map nodes, + luisa::unordered_map named_tokens) noexcept + : _frame{std::move(frame_desc)}, + _nodes{std::move(nodes)}, + _named_tokens{std::move(named_tokens)} {} + +CoroGraph::~CoroGraph() noexcept = default; + +const CoroGraph::Node &CoroGraph::entry() const noexcept { + return node(entry_token); +} + +const CoroGraph::Node &CoroGraph::node(Token token) const noexcept { + auto iter = _nodes.find(token); + LUISA_ASSERT(iter != _nodes.end(), + "Coroutine node with token {} not found.", + token); + return iter->second; +} + +const CoroGraph::Node &CoroGraph::node(luisa::string_view name) const noexcept { + auto iter = _named_tokens.find(name); + LUISA_ASSERT(iter != _named_tokens.end(), + "Coroutine node with name '{}' not found.", + name); + return node(iter->second); +} + +luisa::string CoroGraph::dump() const noexcept { + luisa::string s; + s.append("Arguments:\n"); + auto args = entry().cc().arguments(); + for (auto i = 0u; i < args.size(); i++) { + s.append(luisa::format(" Argument {}: ", i)); + s.append(args[i].type()->description()); + if (args[i].is_reference()) { s.append(" &"); } + s.append("\n"); + } + s.append("Frame:\n").append(_frame->dump()); + for (auto &&[token, node] : _nodes) { + if (token == entry_token) { + s.append("Entry:\n"); + } else { + s.append(luisa::format("Node {}:\n", token)); + } + s.append(node.dump()); + } + if (!_named_tokens.empty()) { + s.append("Named Tokens:\n"); + for (auto &&[name, token] : _named_tokens) { + s.append(luisa::format(" {} -> \"{}\"\n", token, name)); + } + } + return s; +} + +#ifndef LUISA_ENABLE_IR +luisa::shared_ptr CoroGraph::create(Function coroutine) noexcept { + LUISA_ERROR_WITH_LOCATION( + "Coroutine requires IR support but " + "LuisaCompute is built without the IR module. " + "This might be caused by missing Rust. " + "Please install the Rust toolchain and " + "recompile LuisaCompute to get the IR module."); +} +#else + +namespace detail { + +static void perform_coroutine_transform(ir::CallableModule *m) noexcept { + auto coroutine_pipeline = ir::luisa_compute_ir_transform_pipeline_new(); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "canonicalize_control_flow"); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "demote_locals"); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "defer_load"); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "extract_loop_cond"); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "split_coro"); + ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "materialize_coro_v2"); + auto converted_module = ir::luisa_compute_ir_transform_pipeline_transform_callable(coroutine_pipeline, *m); + ir::luisa_compute_ir_transform_pipeline_destroy(coroutine_pipeline); + *m = converted_module; +} + +[[nodiscard]] static auto make_subroutine_wrapper(Function coroutine, Function cc) noexcept { + using FB = luisa::compute::detail::FunctionBuilder; + return FB::define_callable([&] { + luisa::vector args; + args.reserve(1u /* frame */ + coroutine.arguments().size()); + LUISA_ASSERT(coroutine.arguments().size() == coroutine.bound_arguments().size(), + "Invalid capture list size (expected {}, got {}).", + coroutine.arguments().size(), coroutine.bound_arguments().size()); + auto fb = FB::current(); + args.emplace_back(fb->reference(cc.arguments().front().type())); + for (auto arg_i = 0u; arg_i < coroutine.arguments().size(); arg_i++) { + auto def_arg = coroutine.arguments()[arg_i]; + auto internal_arg = luisa::visit( + [&](T b) noexcept -> const Expression * { + if constexpr (std::is_same_v) { + return fb->buffer_binding(def_arg.type(), b.handle, b.offset, b.size); + } else if constexpr (std::is_same_v) { + return fb->texture_binding(def_arg.type(), b.handle, b.level); + } else if constexpr (std::is_same_v) { + return fb->bindless_array_binding(b.handle); + } else if constexpr (std::is_same_v) { + return fb->accel_binding(b.handle); + } else { + static_assert(std::is_same_v); + switch (def_arg.tag()) { + case Variable::Tag::REFERENCE: return fb->reference(def_arg.type()); + case Variable::Tag::BUFFER: return fb->buffer(def_arg.type()); + case Variable::Tag::TEXTURE: return fb->texture(def_arg.type()); + case Variable::Tag::BINDLESS_ARRAY: return fb->bindless_array(); + case Variable::Tag::ACCEL: return fb->accel(); + default: /* value argument */ return fb->argument(def_arg.type()); + } + } + }, + coroutine.bound_arguments()[arg_i]); + args.emplace_back(internal_arg); + } + LUISA_ASSERT(cc.return_type() == nullptr, + "Coroutine subroutines should not have return type."); + fb->call(cc, args); + }); +} + +}// namespace detail + +luisa::shared_ptr CoroGraph::create(Function coroutine) noexcept { + LUISA_VERBOSE_WITH_LOCATION("Performing Coroutine transform " + "on function with hash {:016x}.", + coroutine.hash()); + + // convert the coroutine function to IR, transform it, and then convert back + auto m = AST2IR::build_coroutine(coroutine); + detail::perform_coroutine_transform(m->get()); + auto entry = IR2AST::build(m->get()); + + // create the coroutine frame descriptor + auto frame = [m, entry] { + auto underlying = entry->arguments().front().type(); + CoroFrameDesc::DesignatedFieldDict members; + for (auto &&field : luisa::span{m->get()->coro_frame_designated_fields.ptr, + m->get()->coro_frame_designated_fields.len}) { + auto name = luisa::string_view{reinterpret_cast(field.name.ptr), field.name.len}; + if (!name.empty() && name.back() == '\0') { name = name.substr(0, name.size() - 1); } + auto [_, success] = members.try_emplace(name, field.index); + LUISA_ASSERT(success, "Duplicated designated field name '{}' at field {}.", name, field.index); + } + return CoroFrameDesc::create(underlying, std::move(members)); + }(); + + // extract the subroutines + auto subroutines = m->get()->subroutines; + auto subroutine_ids = m->get()->subroutine_ids; + LUISA_ASSERT(subroutines.len == subroutine_ids.len, + "Subroutine count mismatch: {} vs {}.", + subroutines.len, subroutine_ids.len); + luisa::unordered_map nodes; + nodes.reserve(subroutines.len + 1u); + auto convert_fields = [](ir::CBoxedSlice slice) noexcept { + luisa::vector fields; + fields.reserve(slice.len); + for (auto i = 0u; i < slice.len; i++) { fields.emplace_back(slice.ptr[i]); } + return fields; + }; + // add the entry node + nodes.emplace( + entry_token, + Node{convert_fields(m->get()->coro_frame_input_fields), + convert_fields(m->get()->coro_frame_output_fields), + convert_fields(m->get()->coro_target_tokens), + detail::make_subroutine_wrapper(coroutine, entry->function())}); + // add subroutine nodes + for (auto i = 0u; i < subroutines.len; i++) { + auto s = subroutines.ptr[i]._0.get(); + auto subroutine = IR2AST::build(s); + auto [_, success] = nodes.try_emplace( + subroutine_ids.ptr[i], + Node{convert_fields(s->coro_frame_input_fields), + convert_fields(s->coro_frame_output_fields), + convert_fields(s->coro_target_tokens), + detail::make_subroutine_wrapper(coroutine, subroutine->function())}); + LUISA_ASSERT(success, "Duplicated subroutine token {}.", subroutine_ids.ptr[i]); + } + // create the graph + return luisa::make_shared( + std::move(frame), std::move(nodes), + coroutine.builder()->coro_tokens()); +} + +#endif + +}// namespace luisa::compute::inline dsl::coro_v2 diff --git a/src/dsl/func.cpp b/src/dsl/func.cpp index 1ac3a99ec..cf6f78c8a 100644 --- a/src/dsl/func.cpp +++ b/src/dsl/func.cpp @@ -181,4 +181,5 @@ luisa::shared_ptr transform_coroutine( } return function.shared_builder(); } + }// namespace luisa::compute::detail diff --git a/src/rust/luisa_compute_ir/src/ast2ir.rs b/src/rust/luisa_compute_ir/src/ast2ir.rs index bc486f294..937fa26ce 100644 --- a/src/rust/luisa_compute_ir/src/ast2ir.rs +++ b/src/rust/luisa_compute_ir/src/ast2ir.rs @@ -2322,6 +2322,7 @@ impl<'a: 'b, 'b> AST2IR<'a, 'b> { captures: CBoxedSlice::new(Vec::new()), subroutine_ids: CBoxedSlice::new(Vec::new()), subroutines: CBoxedSlice::new(Vec::new()), + coro_target_tokens: CBoxedSlice::new(Vec::new()), coro_frame_input_fields: CBoxedSlice::new(Vec::new()), coro_frame_output_fields: CBoxedSlice::new(Vec::new()), coro_frame_designated_fields: CBoxedSlice::new(Vec::new()), diff --git a/src/rust/luisa_compute_ir/src/ir.rs b/src/rust/luisa_compute_ir/src/ir.rs index 38cf9e788..f18f0ed1f 100644 --- a/src/rust/luisa_compute_ir/src/ir.rs +++ b/src/rust/luisa_compute_ir/src/ir.rs @@ -2012,6 +2012,7 @@ pub struct CallableModule { pub captures: CBoxedSlice, pub subroutines: CBoxedSlice, pub subroutine_ids: CBoxedSlice, + pub coro_target_tokens: CBoxedSlice, pub coro_frame_input_fields: CBoxedSlice, pub coro_frame_output_fields: CBoxedSlice, pub coro_frame_designated_fields: CBoxedSlice, @@ -2266,6 +2267,7 @@ impl ModuleDuplicator { captures: dup_captures, subroutines: callable.subroutines.clone(), subroutine_ids: callable.subroutine_ids.clone(), + coro_target_tokens: callable.coro_target_tokens.clone(), coro_frame_input_fields: callable.coro_frame_input_fields.clone(), coro_frame_output_fields: callable.coro_frame_output_fields.clone(), coro_frame_designated_fields: callable.coro_frame_designated_fields.clone(), diff --git a/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs b/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs index b5c758804..743ece66a 100644 --- a/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs +++ b/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs @@ -455,9 +455,11 @@ impl<'a> CoroScopeMaterializer<'a> { .iter() .map(|&a| self.value_or_load(a, ctx, state)) .collect(); - let call = state - .builder - .call(Func::External(c.clone()), args.as_slice(), ret.type_().clone()); + let call = state.builder.call( + Func::External(c.clone()), + args.as_slice(), + ret.type_().clone(), + ); process_return!(call) } // replayable but need special handling @@ -990,6 +992,11 @@ impl<'a> CoroScopeMaterializer<'a> { } } + fn collect_target_tokens(&self, scope: CoroScopeRef) -> Vec { + let node = self.frame.transition_graph.nodes.get(&scope).unwrap(); + node.outlets.keys().cloned().collect::<_>() + } + fn materialize(&self) -> CallableModule { let mappings: HashMap<_, _> = self .coro @@ -1044,6 +1051,7 @@ impl<'a> CoroScopeMaterializer<'a> { index: i as u32 + designated_filed_offset, }) .collect(); + let target_tokens = self.collect_target_tokens(self.scope); // create the callable module CallableModule { module, @@ -1052,6 +1060,7 @@ impl<'a> CoroScopeMaterializer<'a> { captures: CBoxedSlice::new(Vec::new()), subroutines: CBoxedSlice::new(Vec::new()), subroutine_ids: CBoxedSlice::new(Vec::new()), + coro_target_tokens: CBoxedSlice::new(target_tokens), coro_frame_input_fields: CBoxedSlice::new(in_fields), coro_frame_output_fields: CBoxedSlice::new(out_fields), coro_frame_designated_fields: CBoxedSlice::new(designated_fields), diff --git a/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs b/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs new file mode 100644 index 000000000..461077375 --- /dev/null +++ b/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs @@ -0,0 +1,1097 @@ +// This file implements the materialization of subroutines in a coroutine. It analyzes the +// input coroutine module, generates the coroutine graph and transition graph, computes the +// coroutine frame layout, and finally materializes the subroutines into callable modules. +// Some corner cases to consider: +// - Some values might be promoted to values in the coroutine frame and should be loaded +// before use. +// - Some values might not dominate their uses any more as they might have been moved into +// a `SkipIfFirst` block. We need to promote them to locals. +// - Some "replayable" values are not included in the coroutine frame, nor defined in the +// subroutine body. We need to replay them. + +use crate::analysis::coro_frame::CoroFrame; +use crate::analysis::coro_graph::{ + CoroGraph, CoroInstrRef, CoroInstruction, CoroScope, CoroScopeRef, +}; +use crate::analysis::coro_transition_graph::CoroTransitionGraph; +use crate::analysis::coro_use_def::CoroUseDefAnalysis; +use crate::analysis::replayable_values::ReplayableValueAnalysis; +use crate::analysis::utility::{AccessChainIndex, AccessTree}; +use crate::ir::{ + collect_nodes, new_node, BasicBlock, CallableModule, CallableModuleRef, Const, + CoroFrameDesignatedField, CurveBasisSet, Func, Instruction, IrBuilder, Module, ModuleFlags, + ModuleKind, Node, NodeRef, Primitive, SwitchCase, Type, +}; +use crate::transform::canonicalize_control_flow::CanonicalizeControlFlow; +use crate::transform::defer_load::DeferLoad; +use crate::transform::demote_locals::DemoteLocals; +use crate::transform::mem2reg::Mem2Reg; +use crate::transform::reg2mem::Reg2Mem; +use crate::transform::Transform; +use crate::{CArc, CBox, CBoxedSlice, Pooled}; +use bitflags::Flags; +use std::collections::{HashMap, HashSet}; + +struct DuplicateNodeCollector<'a> { + frame: &'a CoroFrame<'a>, + scope: CoroScopeRef, +} + +// collect the nodes that violate the SSA form, which should be promoted to locals +impl<'a> DuplicateNodeCollector<'a> { + fn new(frame: &'a CoroFrame<'a>, scope: CoroScopeRef) -> Self { + Self { frame, scope } + } + + fn collect(&self) -> HashSet { + todo!() + } +} + +pub(crate) struct MaterializeCoro; + +struct CoroScopeMaterializer<'a> { + frame: &'a CoroFrame<'a>, + coro: &'a CallableModule, + token: Option, // None for entry + scope: CoroScopeRef, + args: Vec, +} + +impl<'a> CoroScopeMaterializer<'a> { + fn get_frame_node(&self) -> NodeRef { + self.args[0].clone() + } + + fn get_scope(&self) -> &CoroScope { + &self.frame.graph.get_scope(self.scope) + } + + fn get_instr(&self, instr: CoroInstrRef) -> &CoroInstruction { + self.frame.graph.get_instr(instr) + } + + fn create_args(&self) -> Vec { + let mut args = Vec::new(); + args.reserve(1 /* frame */+ self.coro.args.len()); + // the coro frame + let node = new_node( + &self.coro.pools, + Node::new( + CArc::new(Instruction::Argument { by_value: false }), + self.frame.interface_type.clone(), + ), + ); + args.push(node); + for arg in self.coro.args.iter() { + // normal args + let instr = &arg.get().instruction; + match instr.as_ref() { + Instruction::Buffer + | Instruction::Bindless + | Instruction::Texture2D + | Instruction::Texture3D + | Instruction::Accel + | Instruction::Argument { .. } => { + let node = new_node( + &self.coro.pools, + Node::new(instr.clone(), arg.type_().clone()), + ); + args.push(node); + } + _ => unreachable!("Invalid argument type"), + } + } + args + } + + fn new(frame: &'a CoroFrame<'a>, coro: &'a CallableModule, token: Option) -> Self { + let scope = if let Some(token) = token { + frame.graph.tokens[&token] + } else { + frame.graph.entry + }; + let mut m = Self { + frame, + coro, + token, + scope, + args: Vec::new(), + }; + m.args = m.create_args(); + m + } +} + +struct CoroScopeMaterializerCtx { + mappings: HashMap, // mapping from old nodes to new nodes + entry_builder: IrBuilder, // suitable for declaring locals + first_flag: Option, + uses_ray_tracing: bool, + uses_coro_id: bool, + replayable: ReplayableValueAnalysis, +} + +struct CoroScopeMaterializerState { + builder: IrBuilder, // current running build +} + +impl CoroScopeMaterializerState { + fn clone_for_branch_block(&self) -> Self { + Self { + builder: IrBuilder::new(self.builder.pools.clone()), + } + } +} + +impl<'a> CoroScopeMaterializer<'a> { + fn resume(&self, ctx: &mut CoroScopeMaterializerCtx) { + let mappings = self + .frame + .resume(self.scope, self.get_frame_node(), &mut ctx.entry_builder); + for (old_node, new_node) in mappings { + ctx.mappings.insert(old_node, new_node); + } + } + + fn suspend(&self, target: u32, ctx: &mut CoroScopeMaterializerCtx, b: &mut IrBuilder) { + self.frame.suspend( + self.scope, + target, + self.get_frame_node(), + b, + &mut ctx.mappings, + ); + } + + fn terminate(&self, b: &mut IrBuilder) { + self.frame.terminate(self.scope, self.get_frame_node(), b); + } + + fn ref_or_local( + &self, + old_node: NodeRef, + ctx: &mut CoroScopeMaterializerCtx, + state: &mut CoroScopeMaterializerState, + ) -> NodeRef { + if let Some(defined) = ctx.mappings.get(&old_node) { + defined.clone() + } else if old_node.is_gep() { + let (root, chain) = AccessTree::access_chain_from_gep_chain(old_node); + let chain: Vec<_> = chain + .iter() + .map(|node| self.value_or_load(node.clone(), ctx, state)) + .collect(); + let root = self.ref_or_local(root, ctx, state); + let gep = state + .builder + .gep(root, chain.as_slice(), old_node.type_().clone()); + ctx.mappings.insert(old_node, gep.clone()); + gep + } else { + // not defined yet, we'll define it now + let local = ctx.entry_builder.local_zero_init(old_node.type_().clone()); + ctx.mappings.insert(old_node.clone(), local.clone()); + local + } + } + + fn replay_value(&self, old_node: NodeRef, ctx: &mut CoroScopeMaterializerCtx) -> NodeRef { + if let Some(replayed) = ctx.mappings.get(&old_node) { + return replayed.clone(); + } + match old_node.get().instruction.as_ref() { + Instruction::Const(c) => ctx.entry_builder.const_(c.clone()), + Instruction::Call(func, args) => match func { + Func::Unreachable(_) | Func::ZeroInitializer | Func::WarpSize => ctx + .entry_builder + .call(func.clone(), &[], old_node.type_().clone()), + Func::ThreadId | Func::BlockId | Func::WarpLaneId | Func::DispatchSize => { + panic!("{:?} is not available in coroutines", func) + } + Func::CoroId | Func::DispatchId => { + ctx.uses_coro_id = true; + self.frame + .read_coro_id(self.get_frame_node(), &mut ctx.entry_builder) + } + Func::CoroToken => ctx + .entry_builder + .const_(Const::Uint32(self.token.unwrap_or(0))), + Func::Cast + | Func::Bitcast + | Func::Pack + | Func::Unpack + | Func::Add + | Func::Sub + | Func::Mul + | Func::Div + | Func::Rem + | Func::BitAnd + | Func::BitOr + | Func::BitXor + | Func::Shl + | Func::Shr + | Func::RotRight + | Func::RotLeft + | Func::Eq + | Func::Ne + | Func::Lt + | Func::Le + | Func::Gt + | Func::Ge + | Func::MatCompMul + | Func::Neg + | Func::Not + | Func::BitNot + | Func::All + | Func::Any + | Func::Select + | Func::Clamp + | Func::Lerp + | Func::Step + | Func::SmoothStep + | Func::Saturate + | Func::Abs + | Func::Min + | Func::Max + | Func::ReduceSum + | Func::ReduceProd + | Func::ReduceMin + | Func::ReduceMax + | Func::Clz + | Func::Ctz + | Func::PopCount + | Func::Reverse + | Func::IsInf + | Func::IsNan + | Func::Acos + | Func::Acosh + | Func::Asin + | Func::Asinh + | Func::Atan + | Func::Atan2 + | Func::Atanh + | Func::Cos + | Func::Cosh + | Func::Sin + | Func::Sinh + | Func::Tan + | Func::Tanh + | Func::Exp + | Func::Exp2 + | Func::Exp10 + | Func::Log + | Func::Log2 + | Func::Log10 + | Func::Powi + | Func::Powf + | Func::Sqrt + | Func::Rsqrt + | Func::Ceil + | Func::Floor + | Func::Fract + | Func::Trunc + | Func::Round + | Func::Fma + | Func::Copysign + | Func::Cross + | Func::Dot + | Func::OuterProduct + | Func::Length + | Func::LengthSquared + | Func::Normalize + | Func::Faceforward + | Func::Distance + | Func::Reflect + | Func::Determinant + | Func::Transpose + | Func::Inverse + | Func::Vec + | Func::Vec2 + | Func::Vec3 + | Func::Vec4 + | Func::Permute + | Func::InsertElement + | Func::ExtractElement + | Func::GetElementPtr + | Func::Struct + | Func::Array + | Func::Mat + | Func::Mat2 + | Func::Mat3 + | Func::Mat4 => { + let replayed_args: Vec<_> = args + .iter() + .map(|arg| self.replay_value(arg.clone(), ctx)) + .collect(); + ctx.entry_builder.call( + func.clone(), + replayed_args.as_slice(), + old_node.type_().clone(), + ) + } + _ => unreachable!("non-replayable value"), + }, + _ => unreachable!("non-replayable value"), + } + } + + fn try_replay(&self, old_node: NodeRef, ctx: &mut CoroScopeMaterializerCtx) -> Option { + if old_node.is_unreachable() { + ctx.mappings.get(&old_node).cloned() + } else { + if !ctx.replayable.detect(old_node) { + None + } else if let Some(defined) = ctx.mappings.get(&old_node) { + // if already defined, simply return it + Some(defined.clone()) + } else { + let replayed = self.replay_value(old_node, ctx); + ctx.mappings.insert(old_node, replayed.clone()); + Some(replayed) + } + } + } + + fn value_or_load( + &self, + old_node: NodeRef, + ctx: &mut CoroScopeMaterializerCtx, + state: &mut CoroScopeMaterializerState, + ) -> NodeRef { + if let Some(node) = self.try_replay(old_node, ctx) { + node + } else { + let node = self.ref_or_local(old_node, ctx, state); + if node.is_local() || node.is_gep() || node.is_reference_argument() { + state.builder.load(node) + } else { + node + } + } + } + + fn def_or_assign( + &self, + old_node: NodeRef, + new_value: NodeRef, + ctx: &mut CoroScopeMaterializerCtx, + state: &mut CoroScopeMaterializerState, + ) { + let var = self.ref_or_local(old_node, ctx, state); + assert!(var.is_gep() || var.is_local() || var.is_reference_argument()); + state.builder.update(var, new_value); + } + + fn materialize_branch_block( + &self, + block: &Vec, + ctx: &mut CoroScopeMaterializerCtx, + state: &CoroScopeMaterializerState, + ) -> Pooled { + let mut branch_state = state.clone_for_branch_block(); + self.materialize_instructions(block.as_slice(), ctx, &mut branch_state); + branch_state.builder.finish() + } + + fn make_first_flag(&self, ctx: &mut CoroScopeMaterializerCtx) { + assert_eq!(ctx.first_flag, None, "First flag already defined"); + let flag = { + let b = &mut ctx.entry_builder; + b.comment(CBoxedSlice::from("make first flag".as_bytes())); + let v = b.const_(Const::Bool(false)); + b.local(v) + }; + ctx.first_flag = Some(flag); + } + + fn materialize_call( + &self, + ret: NodeRef, + func: Func, + args: &[NodeRef], + ctx: &mut CoroScopeMaterializerCtx, + state: &mut CoroScopeMaterializerState, + ) { + macro_rules! process_return { + ($call: expr) => { + match $call.type_().as_ref() { + Type::Void => { /* nothing */ } + Type::UserData => todo!(), + Type::Opaque(_) => { + // as non-copyable reference + ctx.mappings.insert(ret.clone(), $call); + } + _ => self.def_or_assign(ret.clone(), $call, ctx, state), + } + }; + } + match func { + // callable + Func::Callable(c) => { + let args: Vec<_> = + c.0.args + .iter() + .zip(args.iter()) + .map(|(formal, &given)| { + if formal.is_reference_argument() || formal.type_().is_opaque("") { + self.ref_or_local(given, ctx, state) + } else { + self.value_or_load(given, ctx, state) + } + }) + .collect(); + let call = state.builder.call( + Func::Callable(c.clone()), + args.as_slice(), + ret.type_().clone(), + ); + process_return!(call) + } + Func::External(c) => { + let args: Vec<_> = args + .iter() + .map(|&a| self.value_or_load(a, ctx, state)) + .collect(); + let call = state.builder.call( + Func::External(c.clone()), + args.as_slice(), + ret.type_().clone(), + ); + process_return!(call) + } + // replayable but need special handling + Func::Unreachable(_) => { + let call = state.builder.call(func.clone(), &[], ret.type_().clone()); + match call.type_().as_ref() { + Type::Void => {} + _ => { + ctx.mappings.insert(ret, call); + } + } + } + // always replayable functions, should not appear here + Func::CoroId + | Func::CoroToken + | Func::ZeroInitializer + | Func::ThreadId + | Func::BlockId + | Func::WarpSize + | Func::WarpLaneId + | Func::DispatchId + | Func::DispatchSize => unreachable!(), + // local variable operations + Func::Load => { + let loaded = self.value_or_load(args[0].clone(), ctx, state); + self.def_or_assign(ret, loaded, ctx, state); + } + Func::AddressOf => { + // the first argument should be reference + let var = self.ref_or_local(args[0].clone(), ctx, state); + let addr = state + .builder + .call(func.clone(), &[var], ret.type_().clone()); + self.def_or_assign(ret, addr, ctx, state); + } + Func::GetElementPtr => { + let (root, chain) = AccessTree::access_chain_from_gep_chain(ret); + let root = self.ref_or_local(root, ctx, state); + let chain: Vec<_> = chain + .iter() + .map(|&i| self.value_or_load(i, ctx, state)) + .collect(); + let gep = state + .builder + .gep(root, chain.as_slice(), ret.type_().clone()); + ctx.mappings.insert(ret, gep); + } + // AD functions + Func::PropagateGrad => todo!(), + Func::OutputGrad => todo!(), + Func::RequiresGradient => todo!(), + Func::Backward => todo!(), + Func::Gradient => todo!(), + Func::GradientMarker => todo!(), + Func::AccGrad => todo!(), + Func::Detach => todo!(), + // resource functions, the first argument should always be a reference + Func::RayTracingQueryAll + | Func::RayTracingQueryAny + | Func::RayTracingInstanceTransform + | Func::RayTracingInstanceVisibilityMask + | Func::RayTracingInstanceUserId + | Func::RayTracingSetInstanceTransform + | Func::RayTracingSetInstanceOpacity + | Func::RayTracingSetInstanceVisibility + | Func::RayTracingSetInstanceUserId + | Func::RayTracingTraceClosest + | Func::RayTracingTraceAny + | Func::RayQueryWorldSpaceRay + | Func::RayQueryProceduralCandidateHit + | Func::RayQueryTriangleCandidateHit + | Func::RayQueryCommittedHit + | Func::RayQueryCommitTriangle + | Func::RayQueryCommitProcedural + | Func::RayQueryTerminate + | Func::IndirectDispatchSetCount + | Func::IndirectDispatchSetKernel + | Func::AtomicRef + | Func::AtomicExchange + | Func::AtomicCompareExchange + | Func::AtomicFetchAdd + | Func::AtomicFetchSub + | Func::AtomicFetchAnd + | Func::AtomicFetchOr + | Func::AtomicFetchXor + | Func::AtomicFetchMin + | Func::AtomicFetchMax + | Func::BufferRead + | Func::BufferWrite + | Func::BufferSize + | Func::BufferAddress + | Func::ByteBufferRead + | Func::ByteBufferWrite + | Func::ByteBufferSize + | Func::Texture2dRead + | Func::Texture2dWrite + | Func::Texture2dSize + | Func::Texture3dRead + | Func::Texture3dWrite + | Func::Texture3dSize + | Func::BindlessTexture2dSample + | Func::BindlessTexture2dSampleLevel + | Func::BindlessTexture2dSampleGrad + | Func::BindlessTexture2dSampleGradLevel + | Func::BindlessTexture3dSample + | Func::BindlessTexture3dSampleLevel + | Func::BindlessTexture3dSampleGrad + | Func::BindlessTexture3dSampleGradLevel + | Func::BindlessTexture2dRead + | Func::BindlessTexture3dRead + | Func::BindlessTexture2dReadLevel + | Func::BindlessTexture3dReadLevel + | Func::BindlessTexture2dSize + | Func::BindlessTexture3dSize + | Func::BindlessTexture2dSizeLevel + | Func::BindlessTexture3dSizeLevel + | Func::BindlessBufferRead + | Func::BindlessBufferWrite + | Func::BindlessBufferSize + | Func::BindlessBufferAddress + | Func::BindlessBufferType + | Func::BindlessByteBufferRead => { + let args: Vec<_> = args + .iter() + .enumerate() + .map(|(i, &a)| { + if i == 0 { + // resource + self.ref_or_local(a, ctx, state) + } else { + // value + self.value_or_load(a, ctx, state) + } + }) + .collect(); + let call = state + .builder + .call(func.clone(), args.as_slice(), ret.type_().clone()); + process_return!(call) + } + // functions with all value arguments + Func::Assume + | Func::Assert(_) + | Func::RasterDiscard + | Func::Cast + | Func::Bitcast + | Func::Pack + | Func::Unpack + | Func::Add + | Func::Sub + | Func::Mul + | Func::Div + | Func::Rem + | Func::BitAnd + | Func::BitOr + | Func::BitXor + | Func::Shl + | Func::Shr + | Func::RotRight + | Func::RotLeft + | Func::Eq + | Func::Ne + | Func::Lt + | Func::Le + | Func::Gt + | Func::Ge + | Func::MatCompMul + | Func::Neg + | Func::Not + | Func::BitNot + | Func::All + | Func::Any + | Func::Select + | Func::Clamp + | Func::Lerp + | Func::Step + | Func::SmoothStep + | Func::Saturate + | Func::Abs + | Func::Min + | Func::Max + | Func::ReduceSum + | Func::ReduceProd + | Func::ReduceMin + | Func::ReduceMax + | Func::Clz + | Func::Ctz + | Func::PopCount + | Func::Reverse + | Func::IsInf + | Func::IsNan + | Func::Acos + | Func::Acosh + | Func::Asin + | Func::Asinh + | Func::Atan + | Func::Atan2 + | Func::Atanh + | Func::Cos + | Func::Cosh + | Func::Sin + | Func::Sinh + | Func::Tan + | Func::Tanh + | Func::Exp + | Func::Exp2 + | Func::Exp10 + | Func::Log + | Func::Log2 + | Func::Log10 + | Func::Powi + | Func::Powf + | Func::Sqrt + | Func::Rsqrt + | Func::Ceil + | Func::Floor + | Func::Fract + | Func::Trunc + | Func::Round + | Func::Fma + | Func::Copysign + | Func::Cross + | Func::Dot + | Func::OuterProduct + | Func::Length + | Func::LengthSquared + | Func::Normalize + | Func::Faceforward + | Func::Distance + | Func::Reflect + | Func::Determinant + | Func::Transpose + | Func::Inverse + | Func::WarpIsFirstActiveLane + | Func::WarpFirstActiveLane + | Func::WarpActiveAllEqual + | Func::WarpActiveBitAnd + | Func::WarpActiveBitOr + | Func::WarpActiveBitXor + | Func::WarpActiveCountBits + | Func::WarpActiveMax + | Func::WarpActiveMin + | Func::WarpActiveProduct + | Func::WarpActiveSum + | Func::WarpActiveAll + | Func::WarpActiveAny + | Func::WarpActiveBitMask + | Func::WarpPrefixCountBits + | Func::WarpPrefixSum + | Func::WarpPrefixProduct + | Func::WarpReadLaneAt + | Func::WarpReadFirstLane + | Func::SynchronizeBlock + | Func::Vec + | Func::Vec2 + | Func::Vec3 + | Func::Vec4 + | Func::Permute + | Func::InsertElement + | Func::ExtractElement + | Func::Struct + | Func::Array + | Func::Mat + | Func::Mat2 + | Func::Mat3 + | Func::Mat4 + | Func::ShaderExecutionReorder + | Func::CpuCustomOp(_) => { + let args: Vec<_> = args + .iter() + .map(|&a| self.value_or_load(a, ctx, state)) + .collect(); + let call = state + .builder + .call(func.clone(), args.as_slice(), ret.type_().clone()); + process_return!(call) + } + // other, unused + Func::Unknown0 => todo!(), + Func::Unknown1 => todo!(), + } + } + + fn materialize_simple( + &self, + node: NodeRef, + ctx: &mut CoroScopeMaterializerCtx, + state: &mut CoroScopeMaterializerState, + ) { + if ctx.replayable.detect(node) && !node.is_unreachable() { + self.replay_value(node.clone(), ctx); + return; + } + match node.get().instruction.as_ref() { + Instruction::Local { init } => { + let init = self.value_or_load(init.clone(), ctx, state); + let this = self.ref_or_local(node, ctx, state); + state.builder.update(this, init); + } + Instruction::Update { var, value } => { + let value = self.value_or_load(value.clone(), ctx, state); + self.def_or_assign(var.clone(), value, ctx, state); + } + Instruction::Call(func, args) => { + self.materialize_call(node, func.clone(), args.iter().as_slice(), ctx, state); + } + Instruction::Loop { body, cond } => { + let mut body_state = state.clone_for_branch_block(); + for node in body.iter() { + self.materialize_simple(node, ctx, &mut body_state); + } + let cond = self.value_or_load(cond.clone(), ctx, &mut body_state); + let body = body_state.builder.finish(); + state.builder.loop_(body, cond); + } + Instruction::If { + cond, + true_branch, + false_branch, + } => { + let cond = self.value_or_load(cond.clone(), ctx, state); + let true_branch = self.materialize_branch_in_simple(true_branch, ctx, state); + let false_branch = self.materialize_branch_in_simple(false_branch, ctx, state); + state.builder.if_(cond, true_branch, false_branch); + } + Instruction::Switch { + value, + cases, + default, + } => { + let value = self.value_or_load(value.clone(), ctx, state); + let cases: Vec<_> = cases + .iter() + .map(|case| SwitchCase { + value: case.value, + block: self.materialize_branch_in_simple(&case.block, ctx, state), + }) + .collect(); + let default = self.materialize_branch_in_simple(default, ctx, state); + state.builder.switch(value, cases.as_slice(), default); + } + Instruction::AdScope { + body, + n_forward_grads, + forward, + } => { + let body = self.materialize_branch_in_simple(body, ctx, state); + if *forward { + state.builder.fwd_ad_scope(body, *n_forward_grads); + } else { + state.builder.ad_scope(body); + } + } + Instruction::RayQuery { + ray_query, + on_triangle_hit, + on_procedural_hit, + } => { + let ray_query = self.ref_or_local(ray_query.clone(), ctx, state); + let on_triangle_hit = + self.materialize_branch_in_simple(on_triangle_hit, ctx, state); + let on_procedural_hit = + self.materialize_branch_in_simple(on_procedural_hit, ctx, state); + ctx.uses_ray_tracing = true; + state.builder.ray_query( + ray_query, + on_triangle_hit, + on_procedural_hit, + node.type_().clone(), + ); + } + Instruction::Print { fmt, args } => { + let args: Vec<_> = args + .iter() + .map(|arg| self.value_or_load(arg.clone(), ctx, state)) + .collect(); + state.builder.print(fmt.clone(), args.as_slice()); + } + Instruction::AdDetach(body) => { + let body = self.materialize_branch_in_simple(body.as_ref(), ctx, state); + state.builder.ad_detach(body); + } + Instruction::Comment(msg) => { + state.builder.comment(msg.clone()); + } + Instruction::CoroRegister { .. } => { + // nothing to do + } + _ => unreachable!(), + } + } + + fn materialize_branch_in_simple( + &self, + block: &BasicBlock, + ctx: &mut CoroScopeMaterializerCtx, + state: &CoroScopeMaterializerState, + ) -> Pooled { + let mut branch_state = state.clone_for_branch_block(); + for node in block.iter() { + self.materialize_simple(node, ctx, &mut branch_state); + } + branch_state.builder.finish() + } + + fn materialize_instr( + &self, + instr: &CoroInstruction, + ctx: &mut CoroScopeMaterializerCtx, + state: &mut CoroScopeMaterializerState, + ) { + match instr { + CoroInstruction::Simple(node) => { + self.materialize_simple(node.clone(), ctx, state); + } + CoroInstruction::ConditionStackReplay { items } => { + state.builder.comment(CBoxedSlice::from( + "condition stack replay begin".to_string(), + )); + for item in items.iter() { + macro_rules! decode_value { + ($t:tt, $value: expr) => { + state.builder.const_(Const::$t($value)) + }; + } + let value = match item.node.type_().as_ref() { + Type::Primitive(p) => match p { + Primitive::Bool => decode_value!(Bool, item.value != 0), + Primitive::Int8 => decode_value!(Int8, item.value as i8), + Primitive::Uint8 => decode_value!(Uint8, item.value as u8), + Primitive::Int16 => decode_value!(Int16, item.value as i16), + Primitive::Uint16 => decode_value!(Uint16, item.value as u16), + Primitive::Int32 => decode_value!(Int32, item.value), + Primitive::Uint32 => decode_value!(Uint32, item.value as u32), + Primitive::Int64 => decode_value!(Int64, item.value as i64), + Primitive::Uint64 => decode_value!(Uint64, item.value as u64), + _ => unreachable!(), + }, + _ => unreachable!(), + }; + self.def_or_assign(item.node.clone(), value, ctx, state); + } + state + .builder + .comment(CBoxedSlice::from("condition stack replay end".to_string())); + } + CoroInstruction::MakeFirstFlag => { + self.make_first_flag(ctx); + } + CoroInstruction::SkipIfFirstFlag { body, .. } => { + state + .builder + .comment(CBoxedSlice::from("skip if first flag".to_string())); + let flag = state.builder.load(ctx.first_flag.unwrap().clone()); + let true_branch = self.materialize_branch_block(body, ctx, state); + let false_branch = IrBuilder::new(state.builder.pools.clone()).finish(); + state.builder.if_(flag, true_branch, false_branch); + state + .builder + .comment(CBoxedSlice::from("after skip if first flag".to_string())); + } + CoroInstruction::ClearFirstFlag(_) => { + state + .builder + .comment(CBoxedSlice::from("clear first flag".to_string())); + let v = state.builder.const_(Const::Bool(true)); + state.builder.update(ctx.first_flag.unwrap(), v); + } + CoroInstruction::Loop { body, cond } => { + // note: cond is inside the scope of body, so we have to convert it before pop + let mut body_state = state.clone_for_branch_block(); + self.materialize_instructions(body.as_slice(), ctx, &mut body_state); + let cond = if let CoroInstruction::Simple(cond) = self.get_instr(*cond) { + self.value_or_load(*cond, ctx, &mut body_state) + } else { + unreachable!() + }; + // now we can pop the body and build the instruction + let body = body_state.builder.finish(); + state.builder.loop_(body, cond); + } + CoroInstruction::If { + cond, + true_branch, + false_branch, + } => { + let cond = if let CoroInstruction::Simple(cond) = self.get_instr(*cond) { + self.value_or_load(cond.clone(), ctx, state) + } else { + unreachable!() + }; + let true_branch = self.materialize_branch_block(true_branch, ctx, state); + let false_branch = self.materialize_branch_block(false_branch, ctx, state); + state.builder.if_(cond, true_branch, false_branch); + } + CoroInstruction::Switch { + cond, + cases, + default, + } => { + let cond = if let CoroInstruction::Simple(cond) = self.get_instr(*cond) { + self.value_or_load(cond.clone(), ctx, state) + } else { + unreachable!() + }; + let cases: Vec<_> = cases + .iter() + .map(|case| SwitchCase { + value: case.value, + block: self.materialize_branch_block(&case.body, ctx, state), + }) + .collect(); + let default = self.materialize_branch_block(default, ctx, state); + state.builder.switch(cond, cases.as_slice(), default); + } + CoroInstruction::Suspend { token } => self.suspend(*token, ctx, &mut state.builder), + CoroInstruction::Terminate => self.terminate(&mut state.builder), + _ => unreachable!(), + } + } + + fn materialize_instructions( + &self, + instructions: &[CoroInstrRef], + ctx: &mut CoroScopeMaterializerCtx, + state: &mut CoroScopeMaterializerState, + ) { + for &instr in instructions { + self.materialize_instr(self.get_instr(instr), ctx, state); + } + } + + fn collect_target_tokens(&self, scope: CoroScopeRef) -> Vec { + let node = self.frame.transition_graph.nodes.get(&scope).unwrap(); + node.outlets.keys().cloned().collect::<_>() + } + + fn materialize(&self) -> CallableModule { + let mappings: HashMap<_, _> = self + .coro + .args + .iter() + .cloned() + .zip(self.args.iter().cloned()) + .collect(); + let mut entry_builder = IrBuilder::new(self.coro.pools.clone()); + let mut ctx = CoroScopeMaterializerCtx { + mappings, + entry_builder, + first_flag: None, + uses_ray_tracing: false, + uses_coro_id: false, + replayable: ReplayableValueAnalysis::new(false), + }; + // resume states and generate first flag if not entry + if let Some(_) = self.token { + self.resume(&mut ctx); + } + // materialize the body + let mut b = IrBuilder::new_without_bb(self.coro.pools.clone()); + b.set_insert_point(ctx.entry_builder.get_insert_point()); + b.comment(CBoxedSlice::from(format!( + "coro body (token = {})", + self.token.unwrap_or(0) + ))); + let mut state = CoroScopeMaterializerState { builder: b }; + self.materialize_instructions(&self.get_scope().instructions, &mut ctx, &mut state); + let module = Module { + kind: ModuleKind::Function, + entry: ctx.entry_builder.finish(), + flags: ModuleFlags::empty(), + curve_basis_set: if ctx.uses_ray_tracing { + self.coro.module.curve_basis_set + } else { + CurveBasisSet::empty() + }, + pools: self.coro.pools.clone(), + }; + // compute the input/output coro frame fields so that the frontend scheduler can optimize the I/O + let (in_fields, out_fields) = self.frame.collect_io_fields(self.scope, ctx.uses_coro_id); + let designated_filed_offset = self.frame.get_designated_field_offset(); + let designated_fields: Vec<_> = self + .frame + .designated_field_names + .iter() + .enumerate() + .map(|(i, name)| CoroFrameDesignatedField { + name: CBoxedSlice::from(name.as_bytes()), + index: i as u32 + designated_filed_offset, + }) + .collect(); + let target_tokens = self.collect_target_tokens(self.scope); + // create the callable module + CallableModule { + module, + ret_type: Type::void(), + args: CBoxedSlice::new(self.args.clone()), + captures: CBoxedSlice::new(Vec::new()), + subroutines: CBoxedSlice::new(Vec::new()), + subroutine_ids: CBoxedSlice::new(Vec::new()), + coro_target_tokens: CBoxedSlice::new(target_tokens), + coro_frame_input_fields: CBoxedSlice::new(in_fields), + coro_frame_output_fields: CBoxedSlice::new(out_fields), + coro_frame_designated_fields: CBoxedSlice::new(designated_fields), + cpu_custom_ops: CBoxedSlice::new(Vec::new()), + pools: self.coro.pools.clone(), + } + } +} + +impl Transform for MaterializeCoro { + fn transform_callable(&self, callable: CallableModule) -> CallableModule { + let callable = CanonicalizeControlFlow.transform_callable(callable); + // let callable = Mem2Reg.transform_callable(callable); + let callable = DemoteLocals.transform_callable(callable); + let callable = DeferLoad.transform_callable(callable); + let coro_graph = CoroGraph::from(&callable.module); + let coro_use_def = CoroUseDefAnalysis::analyze(&coro_graph); + let coro_transition_graph = CoroTransitionGraph::build(&coro_graph, &coro_use_def); + let coro_frame = CoroFrame::build(&coro_graph, &coro_transition_graph); + coro_frame.dump(); + let mut entry = CoroScopeMaterializer::new(&coro_frame, &callable, None).materialize(); + let subroutines: Vec<_> = coro_graph + .tokens + .keys() + .map(|token| { + let r = + CoroScopeMaterializer::new(&coro_frame, &callable, Some(*token)).materialize(); + CallableModuleRef(CArc::new(r)) + }) + .collect(); + let subroutine_token: Vec<_> = coro_graph.tokens.keys().copied().collect(); + entry.subroutines = CBoxedSlice::new(subroutines); + entry.subroutine_ids = CBoxedSlice::new(subroutine_token); + entry + } +} diff --git a/src/rust/luisa_compute_ir/src/transform/mod.rs b/src/rust/luisa_compute_ir/src/transform/mod.rs index e6ed7f19a..c29ec31e2 100644 --- a/src/rust/luisa_compute_ir/src/transform/mod.rs +++ b/src/rust/luisa_compute_ir/src/transform/mod.rs @@ -20,6 +20,7 @@ pub mod copy_propagation; pub mod defer_load; pub mod inliner; pub mod materialize_coro; +pub mod materialize_coro_v2; pub mod remove_phi; use crate::ir::{self, CallableModule, KernelModule, Module, ModuleFlags}; @@ -136,6 +137,10 @@ pub extern "C" fn luisa_compute_ir_transform_pipeline_add_transform( let transform = materialize_coro::MaterializeCoro; unsafe { (*pipeline).add_transform(Box::new(transform)) }; } + "materialize_coro_v2" => { + let transform = materialize_coro_v2::MaterializeCoro; + unsafe { (*pipeline).add_transform(Box::new(transform)) }; + } "mem2reg" => { let transform = mem2reg::Mem2Reg; unsafe { (*pipeline).add_transform(Box::new(transform)) }; diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 9ece0d5b7..ae5fe50cd 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -201,4 +201,5 @@ luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_renderer_wo_dispatcher.cpp) luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) +luisa_compute_add_executable(test_coro_helloworld_v2 coro/helloworld_v2.cpp) luisa_compute_add_executable(test_coro_playground coro/playground.cpp) diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp new file mode 100644 index 000000000..907870839 --- /dev/null +++ b/src/tests/coro/helloworld_v2.cpp @@ -0,0 +1,33 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace luisa; +using namespace luisa::compute; + +int main(int argc, char *argv[]) { + + Context context{argv[0]}; + if (argc <= 1) { exit(1); } + Device device = context.create_device(argv[1]); + Stream stream = device.create_stream(); + constexpr uint2 resolution = make_uint2(1024, 1024); + Image image{device.create_image(PixelStorage::BYTE4, resolution)}; + luisa::vector host_image(image.view().size_bytes()); + + coro_v2::Coroutine coro = [&]() noexcept { + $suspend("1"); + Var coord = coro_id().xy(); + $suspend("2"); + Var uv = (make_float2(coord) + 0.5f) / make_float2(resolution); + $suspend("3", std::make_pair(uv, "uv")); + image->write(coord, make_float4(uv, 0.5f, 1.0f)); + }; + + LUISA_INFO("CoroGraph:\n{}", coro.graph()->dump()); +} From fbbf1860e6f722483a4c415578a2df7dd4f82610 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 20:02:46 +0800 Subject: [PATCH 06/67] good --- include/luisa/dsl/coro/coro_frame.h | 1 + include/luisa/dsl/coro/coro_func.h | 31 +- include/luisa/dsl/coro/coro_graph.h | 2 +- include/luisa/dsl/func.h | 13 +- src/ir/ast2ir.cpp | 2 +- .../src/transform/materialize_coro_v2.rs | 2 +- src/tests/CMakeLists.txt | 1 + src/tests/coro/path_tracing_v2.cpp | 375 ++++++++++++++++++ 8 files changed, 419 insertions(+), 8 deletions(-) create mode 100644 src/tests/coro/path_tracing_v2.cpp diff --git a/include/luisa/dsl/coro/coro_frame.h b/include/luisa/dsl/coro/coro_frame.h index 662996678..0cc2b4698 100644 --- a/include/luisa/dsl/coro/coro_frame.h +++ b/include/luisa/dsl/coro/coro_frame.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace luisa::compute::inline dsl::coro_v2 { diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index 00df4c44b..8d46c376c 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -4,9 +4,9 @@ #pragma once -#include #include #include +#include namespace luisa::compute::inline dsl::coro_v2 { @@ -23,6 +23,25 @@ class Coroutine { "Coroutine function must return void."); using Token = CoroGraph::Token; + static constexpr auto entry_token = CoroGraph::entry_token; + + class Subroutine { + + private: + Function f; + + private: + friend class Coroutine; + explicit Subroutine(Function function) noexcept : f{function} {} + + public: + void operator()(CoroFrame &frame, detail::prototype_to_callable_invocation_t... args) const noexcept { + detail::CallableInvoke invoke; + invoke << frame.expression(); + static_cast((invoke << ... << args)); + detail::FunctionBuilder::current()->call(f, invoke.args()); + } + }; private: luisa::shared_ptr _graph; @@ -58,6 +77,16 @@ class Coroutine { public: [[nodiscard]] auto graph() const noexcept { return _graph.get(); } [[nodiscard]] auto &shared_graph() const noexcept { return _graph; } + +public: + [[nodiscard]] auto instantiate() const noexcept { return _graph->frame()->instantiate(); } + [[nodiscard]] auto instantiate(Expr coro_id) const noexcept { return _graph->frame()->instantiate(coro_id); } + [[nodiscard]] auto subroutine_count() const noexcept { return _graph->nodes().size(); } + [[nodiscard]] auto operator[](Token token) const noexcept { return Subroutine{_graph->node(token).cc()}; } + [[nodiscard]] auto operator[](luisa::string_view name) const noexcept { return Subroutine{_graph->node(name).cc()}; } + [[nodiscard]] auto entry() const noexcept { return (*this)[entry_token]; } + [[nodiscard]] auto subroutine(Token token) const noexcept { return (*this)[token]; } + [[nodiscard]] auto subroutine(luisa::string_view name) const noexcept { return (*this)[name]; } }; template diff --git a/include/luisa/dsl/coro/coro_graph.h b/include/luisa/dsl/coro/coro_graph.h index 81180c440..6602425c8 100644 --- a/include/luisa/dsl/coro/coro_graph.h +++ b/include/luisa/dsl/coro/coro_graph.h @@ -20,7 +20,7 @@ class LC_DSL_API CoroGraph { using CC = luisa::shared_ptr;// current continuation function public: - class Node { + class LC_DSL_API Node { private: luisa::vector _input_fields; diff --git a/include/luisa/dsl/func.h b/include/luisa/dsl/func.h index ebad45f5b..693c78980 100644 --- a/include/luisa/dsl/func.h +++ b/include/luisa/dsl/func.h @@ -306,6 +306,14 @@ class CallableInvoke { public: CallableInvoke() noexcept = default; + /// Add an argument + CallableInvoke &operator<<(const Expression *expr) noexcept { + if (_arg_count == max_argument_count) [[unlikely]] { + _error_too_many_arguments(); + } + _args[_arg_count++] = expr; + return *this; + } /// Add an argument. template CallableInvoke &operator<<(Expr arg) noexcept { @@ -314,10 +322,7 @@ class CallableInvoke { } else if constexpr (is_soa_expr_v) { callable_encode_soa(*this, arg); } else { - if (_arg_count == max_argument_count) [[unlikely]] { - _error_too_many_arguments(); - } - _args[_arg_count++] = arg.expression(); + *this << arg.expression(); } return *this; } diff --git a/src/ir/ast2ir.cpp b/src/ir/ast2ir.cpp index 77651fb59..e76d8c573 100644 --- a/src/ir/ast2ir.cpp +++ b/src/ir/ast2ir.cpp @@ -33,7 +33,7 @@ namespace luisa::compute { [](ir::CArc *p) noexcept { p->release(); luisa::delete_with_allocator(p); - }}; + }}; } [[nodiscard]] luisa::shared_ptr> AST2IR::build_coroutine(Function function) noexcept { diff --git a/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs b/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs index 461077375..9c9d41eb8 100644 --- a/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs +++ b/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs @@ -1001,7 +1001,7 @@ impl<'a> CoroScopeMaterializer<'a> { .args .iter() .cloned() - .zip(self.args.iter().cloned()) + .zip(self.args.iter().skip(1).cloned()) .collect(); let mut entry_builder = IrBuilder::new(self.coro.pools.clone()); let mut ctx = CoroScopeMaterializerCtx { diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index ae5fe50cd..e9b53812f 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -200,6 +200,7 @@ endif () luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_renderer_wo_dispatcher.cpp) luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) +luisa_compute_add_executable(test_coro_path_tracing_v2 coro/path_tracing_v2.cpp) luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) luisa_compute_add_executable(test_coro_helloworld_v2 coro/helloworld_v2.cpp) luisa_compute_add_executable(test_coro_playground coro/playground.cpp) diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp new file mode 100644 index 000000000..9c79279b5 --- /dev/null +++ b/src/tests/coro/path_tracing_v2.cpp @@ -0,0 +1,375 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../common/cornell_box.h" + +#define TINYOBJLOADER_IMPLEMENTATION +#include "../common/tiny_obj_loader.h" + +using namespace luisa; +using namespace luisa::compute; + +struct Onb { + float3 tangent; + float3 binormal; + float3 normal; +}; + +LUISA_STRUCT(Onb, tangent, binormal, normal) { + [[nodiscard]] Float3 to_world(Expr v) const noexcept { + return v.x * tangent + v.y * binormal + v.z * normal; + } +}; + +int main(int argc, char *argv[]) { + + log_level_verbose(); + + Context context{argv[0]}; + if (argc <= 1) { + LUISA_INFO("Usage: {} . : cuda, dx, cpu, metal", argv[0]); + exit(1); + } + Device device = context.create_device(argv[1]); + + // load the Cornell Box scene + tinyobj::ObjReaderConfig obj_reader_config; + obj_reader_config.triangulate = true; + obj_reader_config.vertex_color = false; + tinyobj::ObjReader obj_reader; + if (!obj_reader.ParseFromString(obj_string, "", obj_reader_config)) { + luisa::string_view error_message = "unknown error."; + if (auto &&e = obj_reader.Error(); !e.empty()) { error_message = e; } + LUISA_ERROR_WITH_LOCATION("Failed to load OBJ file: {}", error_message); + } + if (auto &&e = obj_reader.Warning(); !e.empty()) { + LUISA_WARNING_WITH_LOCATION("{}", e); + } + + auto &&p = obj_reader.GetAttrib().vertices; + luisa::vector vertices; + vertices.reserve(p.size() / 3u); + for (uint i = 0u; i < p.size(); i += 3u) { + vertices.emplace_back(make_float3( + p[i + 0u], p[i + 1u], p[i + 2u])); + } + LUISA_INFO( + "Loaded mesh with {} shape(s) and {} vertices.", + obj_reader.GetShapes().size(), vertices.size()); + + BindlessArray heap = device.create_bindless_array(); + Stream stream = device.create_stream(StreamTag::GRAPHICS); + Buffer vertex_buffer = device.create_buffer(vertices.size()); + stream << vertex_buffer.copy_from(vertices.data()); + luisa::vector meshes; + luisa::vector> triangle_buffers; + for (auto &&shape : obj_reader.GetShapes()) { + uint index = static_cast(meshes.size()); + std::vector const &t = shape.mesh.indices; + uint triangle_count = t.size() / 3u; + LUISA_INFO( + "Processing shape '{}' at index {} with {} triangle(s).", + shape.name, index, triangle_count); + luisa::vector indices; + indices.reserve(t.size()); + for (tinyobj::index_t i : t) { indices.emplace_back(i.vertex_index); } + Buffer &triangle_buffer = triangle_buffers.emplace_back(device.create_buffer(triangle_count)); + Mesh &mesh = meshes.emplace_back(device.create_mesh(vertex_buffer, triangle_buffer)); + heap.emplace_on_update(index, triangle_buffer); + stream << triangle_buffer.copy_from(indices.data()) + << mesh.build(); + } + + Accel accel = device.create_accel({}); + for (Mesh &m : meshes) { + accel.emplace_back(m, make_float4x4(1.0f)); + } + stream << heap.update() + << accel.build() + << synchronize(); + + Constant materials{ + make_float3(0.725f, 0.710f, 0.680f),// floor + make_float3(0.725f, 0.710f, 0.680f),// ceiling + make_float3(0.725f, 0.710f, 0.680f),// back wall + make_float3(0.140f, 0.450f, 0.091f),// right wall + make_float3(0.630f, 0.065f, 0.050f),// left wall + make_float3(0.725f, 0.710f, 0.680f),// short box + make_float3(0.725f, 0.710f, 0.680f),// tall box + make_float3(0.000f, 0.000f, 0.000f),// light + }; + + Callable linear_to_srgb = [&](Var x) noexcept { + return saturate(select(1.055f * pow(x, 1.0f / 2.4f) - 0.055f, + 12.92f * x, + x <= 0.00031308f)); + }; + + Callable tea = [](UInt v0, UInt v1) noexcept { + UInt s0 = def(0u); + for (uint n = 0u; n < 4u; n++) { + s0 += 0x9e3779b9u; + v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); + v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); + } + return v0; + }; + + Kernel2D make_sampler_kernel = [&](ImageUInt seed_image) noexcept { + UInt2 p = dispatch_id().xy(); + UInt state = tea(p.x, p.y); + seed_image.write(p, make_uint4(state)); + }; + + Callable lcg = [](UInt &state) noexcept { + constexpr uint lcg_a = 1664525u; + constexpr uint lcg_c = 1013904223u; + state = lcg_a * state + lcg_c; + return cast(state & 0x00ffffffu) * + (1.0f / static_cast(0x01000000u)); + }; + + Callable make_onb = [](const Float3 &normal) noexcept { + Float3 binormal = normalize(ite( + abs(normal.x) > abs(normal.z), + make_float3(-normal.y, normal.x, 0.0f), + make_float3(0.0f, -normal.z, normal.y))); + Float3 tangent = normalize(cross(binormal, normal)); + return def(tangent, binormal, normal); + }; + + Callable generate_ray = [](Float2 p) noexcept { + static constexpr float fov = radians(27.8f); + static constexpr float3 origin = make_float3(-0.01f, 0.995f, 5.0f); + Float3 pixel = origin + make_float3(p * tan(0.5f * fov), -1.0f); + Float3 direction = normalize(pixel - origin); + return make_ray(origin, direction); + }; + + Callable cosine_sample_hemisphere = [](Float2 u) noexcept { + Float r = sqrt(u.x); + Float phi = 2.0f * constants::pi * u.y; + return make_float3(r * cos(phi), r * sin(phi), sqrt(1.0f - u.x)); + }; + + Callable balanced_heuristic = [](Float pdf_a, Float pdf_b) noexcept { + return pdf_a / max(pdf_a + pdf_b, 1e-4f); + }; + + auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; + + coro_v2::Coroutine raytracing_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + UInt2 coord = dispatch_id().xy(); + Float frame_size = min(resolution.x, resolution.y).cast(); + UInt state = seed_image.read(coord).x; + Float rx = lcg(state); + Float ry = lcg(state); + Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; + Float3 radiance = def(make_float3(0.0f)); + $suspend("per_spp"); + $for (i, spp_per_dispatch) { + Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); + Float3 beta = def(make_float3(1.0f)); + Float pdf_bsdf = def(0.0f); + constexpr float3 light_position = make_float3(-0.24f, 1.98f, 0.16f); + constexpr float3 light_u = make_float3(-0.24f, 1.98f, -0.22f) - light_position; + constexpr float3 light_v = make_float3(0.23f, 1.98f, 0.16f) - light_position; + constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); + Float light_area = length(cross(light_u, light_v)); + Float3 light_normal = normalize(cross(light_u, light_v)); + $suspend("per_depth"); + $for (depth, 10u) { + // trace + $suspend("before_tracing"); + Var hit = accel.intersect(ray, {}); + reorder_shader_execution(); + $if (hit->miss()) { $break; }; + Var triangle = heap->buffer(hit.inst).read(hit.prim); + Float3 p0 = vertex_buffer->read(triangle.i0); + Float3 p1 = vertex_buffer->read(triangle.i1); + Float3 p2 = vertex_buffer->read(triangle.i2); + Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); + Float3 n = normalize(cross(p1 - p0, p2 - p0)); + $suspend("after_tracing"); + + Float cos_wo = dot(-ray->direction(), n); + $if (cos_wo < 1e-4f) { $break; }; + + // hit light + $if (hit.inst == static_cast(meshes.size() - 1u)) { + $if (depth == 0u) { + radiance += light_emission; + } + $else { + Float pdf_light = length_squared(p - ray->origin()) / (light_area * cos_wo); + Float mis_weight = balanced_heuristic(pdf_bsdf, pdf_light); + radiance += mis_weight * beta * light_emission; + }; + $break; + }; + + // sample light + $suspend("sample_light"); + Float ux_light = lcg(state); + Float uy_light = lcg(state); + Float3 p_light = light_position + ux_light * light_u + uy_light * light_v; + Float3 pp = offset_ray_origin(p, n); + Float3 pp_light = offset_ray_origin(p_light, light_normal); + Float d_light = distance(pp, pp_light); + Float3 wi_light = normalize(pp_light - pp); + Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); + Bool occluded = accel.intersect_any(shadow_ray, {}); + Float cos_wi_light = dot(wi_light, n); + Float cos_light = -dot(light_normal, wi_light); + Float3 albedo = materials.read(hit.inst); + $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { + Float pdf_light = (d_light * d_light) / (light_area * cos_light); + Float pdf_bsdf = cos_wi_light * inv_pi; + Float mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); + Float3 bsdf = albedo * inv_pi * cos_wi_light; + radiance += beta * bsdf * mis_weight * light_emission / max(pdf_light, 1e-4f); + }; + + // sample BSDF + $suspend("sample_bsdf"); + Var onb = make_onb(n); + Float ux = lcg(state); + Float uy = lcg(state); + Float3 wi_local = cosine_sample_hemisphere(make_float2(ux, uy)); + Float cos_wi = abs(wi_local.z); + Float3 new_direction = onb->to_world(wi_local); + ray = make_ray(pp, new_direction); + pdf_bsdf = cos_wi * inv_pi; + beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f + + // rr + $suspend("rr"); + Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); + $if (l == 0.0f) { $break; }; + Float q = max(l, 0.05f); + Float r = lcg(state); + $if (r >= q) { $break; }; + beta *= 1.0f / q; + }; + }; + $suspend("write_film"); + radiance /= static_cast(spp_per_dispatch); + seed_image.write(coord, make_uint4(state)); + $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; + image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); + }; + + Kernel2D mega_kernel = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) { + auto frame = raytracing_coro.instantiate(dispatch_id()); + raytracing_coro.entry()(frame, image, seed_image, accel, resolution); + $loop { + $switch (frame.target_token) { + for (auto i = 1u; i < raytracing_coro.subroutine_count(); i++) { + $case (i) { + raytracing_coro[i](frame, image, seed_image, accel, resolution); + }; + } + $default { + $return(); + }; + }; + }; + }; + + Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { + UInt2 p = dispatch_id().xy(); + Float4 accum = accum_image.read(p); + Float3 curr = curr_image.read(p).xyz(); + accum_image.write(p, accum + make_float4(curr, 1.f)); + }; + + Callable aces_tonemapping = [](Float3 x) noexcept { + static constexpr float a = 2.51f; + static constexpr float b = 0.03f; + static constexpr float c = 2.43f; + static constexpr float d = 0.59f; + static constexpr float e = 0.14f; + return clamp((x * (a * x + b)) / (x * (c * x + d) + e), 0.0f, 1.0f); + }; + + Kernel2D clear_kernel = [](ImageFloat image) noexcept { + image.write(dispatch_id().xy(), make_float4(0.0f)); + }; + + Kernel2D hdr2ldr_kernel = [&](ImageFloat hdr_image, ImageFloat ldr_image, Float scale, Bool is_hdr) noexcept { + UInt2 coord = dispatch_id().xy(); + Float4 hdr = hdr_image.read(coord); + Float3 ldr = hdr.xyz() / hdr.w * scale; + $if (!is_hdr) { + ldr = linear_to_srgb(ldr); + }; + ldr_image.write(coord, make_float4(ldr, 1.0f)); + }; + + ShaderOption o{.enable_debug_info = false}; + auto clear_shader = device.compile(clear_kernel, o); + auto hdr2ldr_shader = device.compile(hdr2ldr_kernel, o); + auto accumulate_shader = device.compile(accumulate_kernel, o); + auto raytracing_shader = device.compile(mega_kernel, o); + auto make_sampler_shader = device.compile(make_sampler_kernel, o); + + static constexpr uint2 resolution = make_uint2(1024u); + Image framebuffer = device.create_image(PixelStorage::HALF4, resolution); + Image accum_image = device.create_image(PixelStorage::FLOAT4, resolution); + luisa::vector> host_image(resolution.x * resolution.y); + CommandList cmd_list; + Image seed_image = device.create_image(PixelStorage::INT1, resolution); + cmd_list << clear_shader(accum_image).dispatch(resolution) + << make_sampler_shader(seed_image).dispatch(resolution); + + Window window{"path tracing", resolution}; + Swapchain swap_chain = device.create_swapchain( + stream, + SwapchainOption{ + .display = window.native_display(), + .window = window.native_handle(), + .size = make_uint2(resolution), + .wants_hdr = false, + .wants_vsync = false, + .back_buffer_count = 3, + }); + Image ldr_image = device.create_image(swap_chain.backend_storage(), resolution); + double last_time = 0.0; + uint frame_count = 0u; + Clock clock; + + while (!window.should_close()) { + cmd_list << raytracing_shader(framebuffer, seed_image, accel, resolution) + .dispatch(resolution) + << accumulate_shader(accum_image, framebuffer) + .dispatch(resolution); + cmd_list << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution); + stream << cmd_list.commit() + << swap_chain.present(ldr_image) << synchronize(); + window.poll_events(); + double dt = clock.toc() - last_time; + last_time = clock.toc(); + frame_count += spp_per_dispatch; + LUISA_INFO("spp: {}, time: {} ms, spp/s: {}", + frame_count, dt, spp_per_dispatch / dt * 1000); + } + stream + << ldr_image.copy_to(host_image.data()) + << synchronize(); + + LUISA_INFO("FPS: {}", frame_count / clock.toc() * 1000); + stbi_write_png("test_path_tracing.png", resolution.x, resolution.y, 4, host_image.data(), 0); +} From 57a45acccf5bcbcc63fa9f2af09f277b1d47b585 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 20:38:20 +0800 Subject: [PATCH 07/67] add coro-v2 examples --- include/luisa/dsl/coro/coro_func.h | 34 ++++++++++++++++++++++++++++++ src/tests/coro/path_tracing_v2.cpp | 10 ++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index 8d46c376c..998c3f0f3 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -87,6 +87,40 @@ class Coroutine { [[nodiscard]] auto entry() const noexcept { return (*this)[entry_token]; } [[nodiscard]] auto subroutine(Token token) const noexcept { return (*this)[token]; } [[nodiscard]] auto subroutine(luisa::string_view name) const noexcept { return (*this)[name]; } + +private: + template + class Awaiter : public concepts::Noncopyable { + private: + U _f; + + private: + friend class Coroutine; + explicit Awaiter(U f) noexcept : _f{std::move(f)} {} + + public: + void await() && noexcept { return _f(luisa::nullopt); } + void await(Expr coro_id) && noexcept { return _f(luisa::make_optional(coro_id)); } + }; + +public: + [[nodiscard]] auto operator()(detail::prototype_to_callable_invocation_t... args) const noexcept { + return Awaiter{[=](luisa::optional> coro_id) noexcept { + auto frame = coro_id ? instantiate(*coro_id) : instantiate(); + entry()(frame, args...); + dsl::loop([&] { + dsl::if_(frame.target_token == CoroGraph::terminal_token, + [&] { dsl::break_(); }); + dsl::suspend(); + auto s = dsl::switch_(frame.target_token); + for (auto i = 1u; i < subroutine_count(); i++) { + std::move(s).case_(i, [&] { + subroutine(i)(frame, args...); + }); + } + }); + }}; + } }; template diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index 9c79279b5..2c9d3bb8e 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -171,7 +171,7 @@ int main(int argc, char *argv[]) { auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; - coro_v2::Coroutine raytracing_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + coro_v2::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { UInt2 coord = dispatch_id().xy(); Float frame_size = min(resolution.x, resolution.y).cast(); UInt state = seed_image.read(coord).x; @@ -272,6 +272,14 @@ int main(int argc, char *argv[]) { image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); }; + coro_v2::Coroutine raytrace_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + coro(image, seed_image, accel, resolution).await(dispatch_id()); + }; + + coro_v2::Coroutine raytracing_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + raytrace_coro(image, seed_image, accel, resolution).await(dispatch_id()); + }; + Kernel2D mega_kernel = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) { auto frame = raytracing_coro.instantiate(dispatch_id()); raytracing_coro.entry()(frame, image, seed_image, accel, resolution); From 70390d9d40bb12f434161d129e980f2a0a8d8793 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 21:03:51 +0800 Subject: [PATCH 08/67] minor api revise --- include/luisa/dsl/coro/coro_func.h | 41 ++++++++++++++---------------- src/dsl/coro/coro_func.cpp | 20 ++++++++++++++- src/tests/coro/path_tracing_v2.cpp | 7 ++--- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index 998c3f0f3..e403d0bf2 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -10,6 +10,12 @@ namespace luisa::compute::inline dsl::coro_v2 { +namespace detail { +LC_DSL_API void coroutine_chained_await_impl( + CoroFrame &frame, uint node_count, + luisa::move_only_function node) noexcept; +}// namespace detail + template class Coroutine { static_assert(luisa::always_false_v); @@ -35,11 +41,11 @@ class Coroutine { explicit Subroutine(Function function) noexcept : f{function} {} public: - void operator()(CoroFrame &frame, detail::prototype_to_callable_invocation_t... args) const noexcept { - detail::CallableInvoke invoke; + void operator()(CoroFrame &frame, compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + compute::detail::CallableInvoke invoke; invoke << frame.expression(); static_cast((invoke << ... << args)); - detail::FunctionBuilder::current()->call(f, invoke.args()); + compute::detail::FunctionBuilder::current()->call(f, invoke.args()); } }; @@ -56,20 +62,20 @@ class Coroutine { requires std::negation_v>> && std::negation_v>> Coroutine(Def &&f) noexcept { - auto coro = detail::FunctionBuilder::define_coroutine([&f] { - static_assert(std::is_invocable_r_v...>); + auto coro = compute::detail::FunctionBuilder::define_coroutine([&f] { + static_assert(std::is_invocable_r_v...>); auto create = [](auto &&def, std::index_sequence) noexcept { using arg_tuple = std::tuple; using var_tuple = std::tuple>...>; - using tag_tuple = std::tuple...>; - auto args = detail::create_argument_definitions(std::tuple<>{}); + using tag_tuple = std::tuple...>; + auto args = compute::detail::create_argument_definitions(std::tuple<>{}); static_assert(std::tuple_size_v == sizeof...(Args)); return luisa::invoke(std::forward(def), - static_cast> &&>(std::get(args))...); }; create(std::forward(f), std::index_sequence_for{}); - detail::FunctionBuilder::current()->return_(nullptr);// to check if any previous $return called with non-void types + compute::detail::FunctionBuilder::current()->return_(nullptr);// to check if any previous $return called with non-void types }); _graph = CoroGraph::create(coro->function()); } @@ -104,26 +110,17 @@ class Coroutine { }; public: - [[nodiscard]] auto operator()(detail::prototype_to_callable_invocation_t... args) const noexcept { + [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { return Awaiter{[=](luisa::optional> coro_id) noexcept { auto frame = coro_id ? instantiate(*coro_id) : instantiate(); - entry()(frame, args...); - dsl::loop([&] { - dsl::if_(frame.target_token == CoroGraph::terminal_token, - [&] { dsl::break_(); }); - dsl::suspend(); - auto s = dsl::switch_(frame.target_token); - for (auto i = 1u; i < subroutine_count(); i++) { - std::move(s).case_(i, [&] { - subroutine(i)(frame, args...); - }); - } + detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](Token token, CoroFrame &f) noexcept { + subroutine(token)(f, args...); }); }}; } }; template -Coroutine(T &&) -> Coroutine>>; +Coroutine(T &&) -> Coroutine>>; }// namespace luisa::compute::inline dsl::coro_v2 diff --git a/src/dsl/coro/coro_func.cpp b/src/dsl/coro/coro_func.cpp index c66e22f33..acd3ae2c6 100644 --- a/src/dsl/coro/coro_func.cpp +++ b/src/dsl/coro/coro_func.cpp @@ -2,7 +2,25 @@ // Created by Mike on 2024/5/8. // +#include #include -namespace luisa::compute::inline dsl::coro_v2 { +namespace luisa::compute::inline dsl::coro_v2::detail { + +void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, + luisa::move_only_function node) noexcept { + node(CoroGraph::entry_token, frame); + $while (frame.target_token != CoroGraph::terminal_token) { + $suspend(); + $switch (frame.target_token) { + for (auto i = 1u; i < node_count; i++) { + $case (i) { + node(i, frame); + }; + } + $default { dsl::unreachable(); }; + }; + }; } + +}// namespace luisa::compute::inline dsl::coro_v2::detail diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index 2c9d3bb8e..f95b02385 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -272,12 +272,13 @@ int main(int argc, char *argv[]) { image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); }; - coro_v2::Coroutine raytrace_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { - coro(image, seed_image, accel, resolution).await(dispatch_id()); + coro_v2::Coroutine raytrace_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution, UInt2 pixel_id) noexcept { + auto coro_id = make_uint3(pixel_id, 0u); + coro(image, seed_image, accel, resolution).await(coro_id); }; coro_v2::Coroutine raytracing_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { - raytrace_coro(image, seed_image, accel, resolution).await(dispatch_id()); + raytrace_coro(image, seed_image, accel, resolution, dispatch_id().xy()).await(); }; Kernel2D mega_kernel = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) { From e3be885e6d35d192e6a2fc01de9782b316be4418 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 23:16:19 +0800 Subject: [PATCH 09/67] generator good --- include/luisa/dsl/coro/coro_frame.h | 1 + include/luisa/dsl/coro/coro_func.h | 96 ++++++++++++++++++++++++++++- include/luisa/dsl/stmt.h | 7 +++ include/luisa/dsl/sugar.h | 10 ++- src/dsl/coro/coro_frame.cpp | 5 ++ src/dsl/coro/coro_func.cpp | 16 ++++- src/tests/coro/helloworld_v2.cpp | 21 ++++--- 7 files changed, 144 insertions(+), 12 deletions(-) diff --git a/include/luisa/dsl/coro/coro_frame.h b/include/luisa/dsl/coro/coro_frame.h index 0cc2b4698..4601c813a 100644 --- a/include/luisa/dsl/coro/coro_frame.h +++ b/include/luisa/dsl/coro/coro_frame.h @@ -82,6 +82,7 @@ class LC_DSL_API CoroFrame { [[nodiscard]] const Var &get(luisa::string_view name) const noexcept { return const_cast(this)->get(name); } + [[nodiscard]] Var is_terminated() const noexcept; }; }// namespace luisa::compute::inline dsl::coro_v2 diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index e403d0bf2..b4947197d 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -14,6 +14,9 @@ namespace detail { LC_DSL_API void coroutine_chained_await_impl( CoroFrame &frame, uint node_count, luisa::move_only_function node) noexcept; +LC_DSL_API void coroutine_generator_step_impl( + CoroFrame &frame, uint node_count, bool is_entry, + luisa::move_only_function node) noexcept; }// namespace detail template @@ -111,16 +114,105 @@ class Coroutine { public: [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { - return Awaiter{[=](luisa::optional> coro_id) noexcept { + auto f = [=](luisa::optional> coro_id) noexcept { auto frame = coro_id ? instantiate(*coro_id) : instantiate(); detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](Token token, CoroFrame &f) noexcept { subroutine(token)(f, args...); }); - }}; + }; + return Awaiter{std::move(f)}; } }; template Coroutine(T &&) -> Coroutine>>; +template +class Generator { + static_assert(luisa::always_false_v); +}; + +template +class Generator { + + static_assert(!std::is_same_v, + "Generator function must not return void."); + +private: + Coroutine _coro; + +public: + template + requires std::negation_v>> && + std::negation_v>> + Generator(Def &&f) noexcept : _coro{std::forward(f)} {} + +private: + template + class Iterator { + private: + luisa::unique_ptr _frame; + U _f; + bool _invoked{false}; + LoopStmt *_loop{nullptr}; + + private: + friend class Generator; + Iterator(luisa::unique_ptr frame, U f) noexcept + : _frame{std::move(frame)}, _f{std::move(f)} {} + + public: + Iterator &operator++() noexcept { + _invoked = true; + _f(*_frame, false); + compute::detail::FunctionBuilder::current()->pop_scope(_loop->body()); + return *this; + } + [[nodiscard]] auto operator==(luisa::default_sentinel_t) const noexcept { return _invoked; } + [[nodiscard]] Var> operator*() noexcept { + _f(*_frame, true); + auto fb = compute::detail::FunctionBuilder::current(); + _loop = fb->loop_(); + fb->push_scope(_loop->body()); + dsl::if_(_frame->is_terminated(), [] { dsl::break_(); }); + return _frame->get("yield_value"); + } + }; + +private: + template + [[nodiscard]] auto _make_iterator(U u, luisa::optional> coro_id) const noexcept { + auto frame = luisa::make_unique(coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate()); + return Iterator{std::move(frame), std::move(u)}; + } + +private: + template + class Stepper : public concepts::Noncopyable { + private: + const Generator &_g; + U _f; + + private: + friend class Generator; + Stepper(const Generator &g, U f) noexcept : _g{g}, _f{std::move(f)} {} + + public: + [[nodiscard]] auto begin() noexcept { return _g._make_iterator(std::move(_f), luisa::nullopt); } + [[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; } + }; + +public: + [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + auto f = [=](CoroFrame &frame, bool is_entry) noexcept { + detail::coroutine_generator_step_impl( + frame, _coro.subroutine_count(), is_entry, + [&](CoroGraph::Token token, CoroFrame &f) noexcept { + _coro.subroutine(token)(f, args...); + }); + }; + return Stepper{*this, std::move(f)}; + } +}; + }// namespace luisa::compute::inline dsl::coro_v2 diff --git a/include/luisa/dsl/stmt.h b/include/luisa/dsl/stmt.h index d24e6dd5b..0b166ad9d 100644 --- a/include/luisa/dsl/stmt.h +++ b/include/luisa/dsl/stmt.h @@ -494,6 +494,13 @@ inline auto suspend(S &&first, Args &&...args) noexcept { return compute::dsl::suspend(""sv, std::forward(first), std::forward(args)...); } +template +inline void promise(luisa::string name, T &&value) noexcept { + detail::FunctionBuilder::current()->bind_promise_( + detail::extract_expression(std::forward(value)), + std::move(name)); +} + inline void return_() noexcept { detail::FunctionBuilder::current()->return_(); } diff --git a/include/luisa/dsl/sugar.h b/include/luisa/dsl/sugar.h index 45bce5d26..18676c53d 100644 --- a/include/luisa/dsl/sugar.h +++ b/include/luisa/dsl/sugar.h @@ -117,9 +117,17 @@ namespace luisa::compute::dsl_detail { ::luisa::compute::detail::comment( \ ::luisa::compute::dsl_detail::format_source_location( \ __FILE__, __LINE__)); \ - return ::luisa::compute::suspend(__VA_ARGS__); \ + return ::luisa::compute::dsl::suspend(__VA_ARGS__); \ }()) +#define $promise(...) ::luisa::compute::dsl::promise(__VA_ARGS__) + +#define $yield(...) \ + do { \ + ::luisa::compute::dsl::promise("yield_value", __VA_ARGS__); \ + ::luisa::compute::dsl::suspend(); \ + } while (false) + #define $loop \ ::luisa::compute::detail::LoopStmtBuilder::create_with_comment( \ ::luisa::compute::dsl_detail::format_source_location(__FILE__, __LINE__)) % \ diff --git a/src/dsl/coro/coro_frame.cpp b/src/dsl/coro/coro_frame.cpp index d8cbc6750..543d4d60b 100644 --- a/src/dsl/coro/coro_frame.cpp +++ b/src/dsl/coro/coro_frame.cpp @@ -3,6 +3,7 @@ // #include +#include #include namespace luisa::compute::inline dsl::coro_v2 { @@ -105,4 +106,8 @@ void CoroFrame::_check_member_index(uint index) const noexcept { "CoroFrame member index out of range."); } +Var CoroFrame::is_terminated() const noexcept { + return target_token == CoroGraph::terminal_token; +} + }// namespace luisa::compute::inline dsl::coro_v2 diff --git a/src/dsl/coro/coro_func.cpp b/src/dsl/coro/coro_func.cpp index acd3ae2c6..7e667a389 100644 --- a/src/dsl/coro/coro_func.cpp +++ b/src/dsl/coro/coro_func.cpp @@ -10,7 +10,7 @@ namespace luisa::compute::inline dsl::coro_v2::detail { void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, luisa::move_only_function node) noexcept { node(CoroGraph::entry_token, frame); - $while (frame.target_token != CoroGraph::terminal_token) { + $while (!frame.is_terminated()) { $suspend(); $switch (frame.target_token) { for (auto i = 1u; i < node_count; i++) { @@ -23,4 +23,18 @@ void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, }; } +void coroutine_generator_step_impl(CoroFrame &frame, uint node_count, bool is_entry, + luisa::move_only_function node) noexcept { + if (is_entry) { + node(CoroGraph::entry_token, frame); + } else { + $switch (frame.target_token) { + for (auto i = 1u; i < node_count; i++) { + $case (i) { node(i, frame); }; + } + $default { dsl::unreachable(); }; + }; + } +} + }// namespace luisa::compute::inline dsl::coro_v2::detail diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp index 907870839..4452a757e 100644 --- a/src/tests/coro/helloworld_v2.cpp +++ b/src/tests/coro/helloworld_v2.cpp @@ -20,14 +20,19 @@ int main(int argc, char *argv[]) { Image image{device.create_image(PixelStorage::BYTE4, resolution)}; luisa::vector host_image(image.view().size_bytes()); - coro_v2::Coroutine coro = [&]() noexcept { - $suspend("1"); - Var coord = coro_id().xy(); - $suspend("2"); - Var uv = (make_float2(coord) + 0.5f) / make_float2(resolution); - $suspend("3", std::make_pair(uv, "uv")); - image->write(coord, make_float4(uv, 0.5f, 1.0f)); + Kernel1D test = [] { + coro_v2::Generator g = [] { + auto x = def(0u); + $while (x < 10u) { + $yield(x); + x += 1u; + }; + }; + for (auto x : g()) { + device_log("x = {}", x); + } }; - LUISA_INFO("CoroGraph:\n{}", coro.graph()->dump()); + auto shader = device.compile(test); + stream << shader().dispatch(1u) << synchronize(); } From d34d6c40c989645031f5ca81b59765265c54ca97 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 23:26:09 +0800 Subject: [PATCH 10/67] fix --- include/luisa/dsl/coro/coro_func.h | 15 ++++++++++++--- src/tests/coro/helloworld_v2.cpp | 2 +- src/tests/coro/path_tracing_v2.cpp | 2 +- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index b4947197d..11f8eef04 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -102,14 +102,18 @@ class Coroutine { class Awaiter : public concepts::Noncopyable { private: U _f; + luisa::optional> _coro_id; private: friend class Coroutine; explicit Awaiter(U f) noexcept : _f{std::move(f)} {} public: - void await() && noexcept { return _f(luisa::nullopt); } - void await(Expr coro_id) && noexcept { return _f(luisa::make_optional(coro_id)); } + [[nodiscard]] auto set_id(Expr coro_id) && noexcept { + _coro_id.emplace(coro_id); + return std::move(*this); + } + void await() && noexcept { return _f(std::move(_coro_id)); } }; public: @@ -190,6 +194,7 @@ class Generator { template class Stepper : public concepts::Noncopyable { private: + luisa::optional> _coro_id; const Generator &_g; U _f; @@ -198,7 +203,11 @@ class Generator { Stepper(const Generator &g, U f) noexcept : _g{g}, _f{std::move(f)} {} public: - [[nodiscard]] auto begin() noexcept { return _g._make_iterator(std::move(_f), luisa::nullopt); } + [[nodiscard]] auto set_id(Expr coro_id) && noexcept { + _coro_id.emplace(coro_id); + return std::move(*this); + } + [[nodiscard]] auto begin() noexcept { return _g._make_iterator(std::move(_f), std::move(_coro_id)); } [[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; } }; diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp index 4452a757e..9a120f7cc 100644 --- a/src/tests/coro/helloworld_v2.cpp +++ b/src/tests/coro/helloworld_v2.cpp @@ -28,7 +28,7 @@ int main(int argc, char *argv[]) { x += 1u; }; }; - for (auto x : g()) { + for (auto x : g().set_id(dispatch_id())) { device_log("x = {}", x); } }; diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index f95b02385..4e41877a3 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -274,7 +274,7 @@ int main(int argc, char *argv[]) { coro_v2::Coroutine raytrace_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution, UInt2 pixel_id) noexcept { auto coro_id = make_uint3(pixel_id, 0u); - coro(image, seed_image, accel, resolution).await(coro_id); + coro(image, seed_image, accel, resolution).set_id(coro_id).await(); }; coro_v2::Coroutine raytracing_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { From b69dbd4f9c90723f94f867c1a7fa29498456c1a0 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 23:31:33 +0800 Subject: [PATCH 11/67] fix --- src/tests/coro/helloworld_v2.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp index 9a120f7cc..89ceb0c83 100644 --- a/src/tests/coro/helloworld_v2.cpp +++ b/src/tests/coro/helloworld_v2.cpp @@ -21,14 +21,14 @@ int main(int argc, char *argv[]) { luisa::vector host_image(image.view().size_bytes()); Kernel1D test = [] { - coro_v2::Generator g = [] { + coro_v2::Generator g = [](UInt n) { auto x = def(0u); - $while (x < 10u) { + $while (x < n) { $yield(x); x += 1u; }; }; - for (auto x : g().set_id(dispatch_id())) { + for (auto x : g(100u)) { device_log("x = {}", x); } }; From 479f33c4bf10c2ffb75de2d875ae07f54b3d505c Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 8 May 2024 23:41:41 +0800 Subject: [PATCH 12/67] minor --- include/luisa/dsl/coro/coro_func.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index 11f8eef04..91d63c3ef 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -151,6 +151,9 @@ class Generator { std::negation_v>> Generator(Def &&f) noexcept : _coro{std::forward(f)} {} +public: + [[nodiscard]] auto coroutine() const noexcept { return _coro; } + private: template class Iterator { From a65538dec9bf1d918d82528cc22aa6630221a33d Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Thu, 9 May 2024 00:33:57 +0800 Subject: [PATCH 13/67] fix build --- include/luisa/dsl/coro/coro_func.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index 91d63c3ef..391016ef8 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -118,7 +118,7 @@ class Coroutine { public: [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { - auto f = [=](luisa::optional> coro_id) noexcept { + auto f = [=, this](luisa::optional> coro_id) noexcept { auto frame = coro_id ? instantiate(*coro_id) : instantiate(); detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](Token token, CoroFrame &f) noexcept { subroutine(token)(f, args...); @@ -175,7 +175,7 @@ class Generator { compute::detail::FunctionBuilder::current()->pop_scope(_loop->body()); return *this; } - [[nodiscard]] auto operator==(luisa::default_sentinel_t) const noexcept { return _invoked; } + [[nodiscard]] bool operator==(luisa::default_sentinel_t) const noexcept { return _invoked; } [[nodiscard]] Var> operator*() noexcept { _f(*_frame, true); auto fb = compute::detail::FunctionBuilder::current(); @@ -216,7 +216,7 @@ class Generator { public: [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { - auto f = [=](CoroFrame &frame, bool is_entry) noexcept { + auto f = [=, this](CoroFrame &frame, bool is_entry) noexcept { detail::coroutine_generator_step_impl( frame, _coro.subroutine_count(), is_entry, [&](CoroGraph::Token token, CoroFrame &f) noexcept { From 41e1061d6e900f1c029972ab1b738c2a940a3f1a Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 12:57:50 +0800 Subject: [PATCH 14/67] refactor; prepare for schedulers --- include/luisa/dsl/coro/coro_frame.h | 38 ++------- include/luisa/dsl/coro/coro_func.h | 27 +++--- include/luisa/luisa-compute.h | 4 +- include/luisa/runtime/coro/coro_frame_desc.h | 34 ++++++++ .../luisa/{dsl => runtime}/coro/coro_graph.h | 29 +++---- include/luisa/runtime/coro/coro_token.h | 11 +++ src/dsl/CMakeLists.txt | 3 +- src/dsl/coro/coro_frame.cpp | 83 +++++-------------- src/dsl/coro/coro_func.cpp | 12 +-- src/runtime/CMakeLists.txt | 5 ++ src/runtime/coro/coro_frame_desc.cpp | 53 ++++++++++++ src/{dsl => runtime}/coro/coro_graph.cpp | 25 +++--- src/tests/coro/helloworld_v2.cpp | 2 +- src/tests/coro/path_tracing_v2.cpp | 6 +- 14 files changed, 184 insertions(+), 148 deletions(-) create mode 100644 include/luisa/runtime/coro/coro_frame_desc.h rename include/luisa/{dsl => runtime}/coro/coro_graph.h (69%) create mode 100644 include/luisa/runtime/coro/coro_token.h create mode 100644 src/runtime/coro/coro_frame_desc.cpp rename src/{dsl => runtime}/coro/coro_graph.cpp (94%) diff --git a/include/luisa/dsl/coro/coro_frame.h b/include/luisa/dsl/coro/coro_frame.h index 4601c813a..3fef791a1 100644 --- a/include/luisa/dsl/coro/coro_frame.h +++ b/include/luisa/dsl/coro/coro_frame.h @@ -5,38 +5,10 @@ #pragma once #include -#include -#include -#include +#include #include -namespace luisa::compute::inline dsl::coro_v2 { - -class CoroFrame; - -class LC_DSL_API CoroFrameDesc : public luisa::enable_shared_from_this { - -public: - using DesignatedFieldDict = luisa::unordered_map; - -private: - const Type *_type{nullptr}; - DesignatedFieldDict _designated_fields; - -private: - CoroFrameDesc(const Type *type, DesignatedFieldDict m) noexcept; - -public: - [[nodiscard]] static luisa::shared_ptr create(const Type *type, DesignatedFieldDict m) noexcept; - [[nodiscard]] auto type() const noexcept { return _type; } - [[nodiscard]] auto &designated_fields() const noexcept { return _designated_fields; } - [[nodiscard]] uint designated_field(luisa::string_view name) const noexcept; - [[nodiscard]] luisa::string dump() const noexcept; - -public: - [[nodiscard]] CoroFrame instantiate() const noexcept; - [[nodiscard]] CoroFrame instantiate(Expr coro_id) const noexcept; -}; +namespace luisa::compute::coroutine { class LC_DSL_API CoroFrame { @@ -55,6 +27,10 @@ class LC_DSL_API CoroFrame { CoroFrame &operator=(const CoroFrame &rhs) noexcept; CoroFrame &operator=(CoroFrame &&rhs) noexcept; +public: + [[nodiscard]] static CoroFrame create(luisa::shared_ptr desc) noexcept; + [[nodiscard]] static CoroFrame create(luisa::shared_ptr desc, Expr coro_id) noexcept; + private: void _check_member_index(uint index) const noexcept; @@ -85,4 +61,4 @@ class LC_DSL_API CoroFrame { [[nodiscard]] Var is_terminated() const noexcept; }; -}// namespace luisa::compute::inline dsl::coro_v2 +}// namespace luisa::compute::coro_v2 diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index 91d63c3ef..1187f53f6 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -5,18 +5,18 @@ #pragma once #include -#include +#include #include -namespace luisa::compute::inline dsl::coro_v2 { +namespace luisa::compute::coroutine { namespace detail { LC_DSL_API void coroutine_chained_await_impl( CoroFrame &frame, uint node_count, - luisa::move_only_function node) noexcept; + luisa::move_only_function node) noexcept; LC_DSL_API void coroutine_generator_step_impl( CoroFrame &frame, uint node_count, bool is_entry, - luisa::move_only_function node) noexcept; + luisa::move_only_function node) noexcept; }// namespace detail template @@ -31,9 +31,6 @@ class Coroutine { static_assert(std::is_same_v, "Coroutine function must return void."); - using Token = CoroGraph::Token; - static constexpr auto entry_token = CoroGraph::entry_token; - class Subroutine { private: @@ -88,13 +85,13 @@ class Coroutine { [[nodiscard]] auto &shared_graph() const noexcept { return _graph; } public: - [[nodiscard]] auto instantiate() const noexcept { return _graph->frame()->instantiate(); } - [[nodiscard]] auto instantiate(Expr coro_id) const noexcept { return _graph->frame()->instantiate(coro_id); } + [[nodiscard]] auto instantiate() const noexcept { return CoroFrame::create(_graph->shared_frame()); } + [[nodiscard]] auto instantiate(Expr coro_id) const noexcept { return CoroFrame::create(_graph->shared_frame(), coro_id); } [[nodiscard]] auto subroutine_count() const noexcept { return _graph->nodes().size(); } - [[nodiscard]] auto operator[](Token token) const noexcept { return Subroutine{_graph->node(token).cc()}; } + [[nodiscard]] auto operator[](CoroToken token) const noexcept { return Subroutine{_graph->node(token).cc()}; } [[nodiscard]] auto operator[](luisa::string_view name) const noexcept { return Subroutine{_graph->node(name).cc()}; } - [[nodiscard]] auto entry() const noexcept { return (*this)[entry_token]; } - [[nodiscard]] auto subroutine(Token token) const noexcept { return (*this)[token]; } + [[nodiscard]] auto entry() const noexcept { return (*this)[coro_token_entry]; } + [[nodiscard]] auto subroutine(CoroToken token) const noexcept { return (*this)[token]; } [[nodiscard]] auto subroutine(luisa::string_view name) const noexcept { return (*this)[name]; } private: @@ -120,7 +117,7 @@ class Coroutine { [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { auto f = [=](luisa::optional> coro_id) noexcept { auto frame = coro_id ? instantiate(*coro_id) : instantiate(); - detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](Token token, CoroFrame &f) noexcept { + detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](CoroToken token, CoroFrame &f) noexcept { subroutine(token)(f, args...); }); }; @@ -219,7 +216,7 @@ class Generator { auto f = [=](CoroFrame &frame, bool is_entry) noexcept { detail::coroutine_generator_step_impl( frame, _coro.subroutine_count(), is_entry, - [&](CoroGraph::Token token, CoroFrame &f) noexcept { + [&](CoroToken token, CoroFrame &f) noexcept { _coro.subroutine(token)(f, args...); }); }; @@ -227,4 +224,4 @@ class Generator { } }; -}// namespace luisa::compute::inline dsl::coro_v2 +}// namespace luisa::compute::coro_v2 diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index 4d8ec72a4..ea2f7726e 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -65,7 +65,6 @@ #include #include #include -#include #include #include #include @@ -122,6 +121,9 @@ #include #include #include +#include +#include +#include #include #include #include diff --git a/include/luisa/runtime/coro/coro_frame_desc.h b/include/luisa/runtime/coro/coro_frame_desc.h new file mode 100644 index 000000000..5893c7b38 --- /dev/null +++ b/include/luisa/runtime/coro/coro_frame_desc.h @@ -0,0 +1,34 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include +#include +#include + +namespace luisa::compute::coroutine { + +class LC_RUNTIME_API CoroFrameDesc : public luisa::enable_shared_from_this { + +public: + using DesignatedFieldDict = luisa::unordered_map; + +private: + const Type *_type{nullptr}; + DesignatedFieldDict _designated_fields; + +private: + CoroFrameDesc(const Type *type, DesignatedFieldDict m) noexcept; + +public: + [[nodiscard]] static luisa::shared_ptr create(const Type *type, DesignatedFieldDict m) noexcept; + [[nodiscard]] auto type() const noexcept { return _type; } + [[nodiscard]] auto &designated_fields() const noexcept { return _designated_fields; } + [[nodiscard]] uint designated_field(luisa::string_view name) const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; +}; + +}// namespace luisa::compute::coroutine diff --git a/include/luisa/dsl/coro/coro_graph.h b/include/luisa/runtime/coro/coro_graph.h similarity index 69% rename from include/luisa/dsl/coro/coro_graph.h rename to include/luisa/runtime/coro/coro_graph.h index 6602425c8..8ef56956c 100644 --- a/include/luisa/dsl/coro/coro_graph.h +++ b/include/luisa/runtime/coro/coro_graph.h @@ -5,33 +5,33 @@ #pragma once #include +#include +#include #include +#include -namespace luisa::compute::inline dsl::coro_v2 { +namespace luisa::compute::coroutine { class CoroFrameDesc; -class LC_DSL_API CoroGraph { +class LC_RUNTIME_API CoroGraph { public: - using Token = uint; - static constexpr Token entry_token = 0u; - static constexpr Token terminal_token = 0x8000'0000u; using CC = luisa::shared_ptr;// current continuation function public: - class LC_DSL_API Node { + class LC_RUNTIME_API Node { private: luisa::vector _input_fields; luisa::vector _output_fields; - luisa::vector _targets; + luisa::vector _targets; CC _cc; public: Node(luisa::vector input_fields, luisa::vector output_fields, - luisa::vector targets, + luisa::vector targets, CC current_continuation) noexcept; ~Node() noexcept; @@ -45,13 +45,13 @@ class LC_DSL_API CoroGraph { private: luisa::shared_ptr _frame; - luisa::unordered_map _nodes; - luisa::unordered_map _named_tokens; + luisa::unordered_map _nodes; + luisa::unordered_map _named_tokens; public: CoroGraph(luisa::shared_ptr frame_desc, - luisa::unordered_map nodes, - luisa::unordered_map named_tokens) noexcept; + luisa::unordered_map nodes, + luisa::unordered_map named_tokens) noexcept; ~CoroGraph() noexcept; public: @@ -60,12 +60,13 @@ class LC_DSL_API CoroGraph { public: [[nodiscard]] auto frame() const noexcept { return _frame.get(); } + [[nodiscard]] auto &shared_frame() const noexcept { return _frame; } [[nodiscard]] auto &nodes() const noexcept { return _nodes; } [[nodiscard]] auto &named_tokens() const noexcept { return _named_tokens; } [[nodiscard]] const Node &entry() const noexcept; - [[nodiscard]] const Node &node(Token index) const noexcept; + [[nodiscard]] const Node &node(CoroToken index) const noexcept; [[nodiscard]] const Node &node(luisa::string_view name) const noexcept; [[nodiscard]] luisa::string dump() const noexcept; }; -}// namespace luisa::compute::inline dsl::coro_v2 +}// namespace luisa::compute::co diff --git a/include/luisa/runtime/coro/coro_token.h b/include/luisa/runtime/coro/coro_token.h new file mode 100644 index 000000000..258aaa7cd --- /dev/null +++ b/include/luisa/runtime/coro/coro_token.h @@ -0,0 +1,11 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +namespace luisa::compute::coroutine { +using CoroToken = unsigned int; +constexpr CoroToken coro_token_entry = 0u; +constexpr CoroToken coro_token_terminal = 0x8000'0000u; +}// namespace luisa::compute::coro_v2 diff --git a/src/dsl/CMakeLists.txt b/src/dsl/CMakeLists.txt index ea27abc61..2895759fb 100644 --- a/src/dsl/CMakeLists.txt +++ b/src/dsl/CMakeLists.txt @@ -12,8 +12,7 @@ if (LUISA_COMPUTE_ENABLE_DSL) set(LUISA_COMPUTE_DSL_CORO_SOURCES coro/coro_frame.cpp - coro/coro_func.cpp - coro/coro_graph.cpp) + coro/coro_func.cpp) set(LUISA_COMPUTE_DSL_SOURCES builtin.cpp diff --git a/src/dsl/coro/coro_frame.cpp b/src/dsl/coro/coro_frame.cpp index 543d4d60b..d7d9f7aa6 100644 --- a/src/dsl/coro/coro_frame.cpp +++ b/src/dsl/coro/coro_frame.cpp @@ -3,69 +3,10 @@ // #include -#include +#include #include -namespace luisa::compute::inline dsl::coro_v2 { - -CoroFrameDesc::CoroFrameDesc(const Type *type, DesignatedFieldDict m) noexcept - : _type{type}, _designated_fields{std::move(m)} { - LUISA_ASSERT(_type != nullptr, "CoroFrame underlying type must not be null."); - LUISA_ASSERT(_type->is_structure(), "CoroFrame underlying type must be a structure."); - LUISA_ASSERT(_type->members().size() >= 2u, "CoroFrame underlying type must have at least 2 members (coro_id and target_token)."); - LUISA_ASSERT(_type->members()[0] == Type::of(), "CoroFrame member 0 (coro_id) must be uint3."); - LUISA_ASSERT(_type->members()[1] == Type::of(), "CoroFrame member 1 (target_token) must be uint."); - auto member_count = _type->members().size(); - for (auto &&[name, index] : _designated_fields) { - LUISA_ASSERT(name != "coro_id", "CoroFrame designated member name 'coro_id' is reserved."); - LUISA_ASSERT(name != "target_token", "CoroFrame designated member name 'target_token' is reserved."); - LUISA_ASSERT(index != 0, "CoroFrame designated member index 0 is reserved for coro_id."); - LUISA_ASSERT(index != 1, "CoroFrame designated member index 1 is reserved for target_token."); - LUISA_ASSERT(index < member_count, "CoroFrame designated member index out of range."); - } -} - -luisa::shared_ptr CoroFrameDesc::create(const Type *type, DesignatedFieldDict m) noexcept { - return luisa::make_shared(CoroFrameDesc{type, std::move(m)}); -} - -uint CoroFrameDesc::designated_field(luisa::string_view name) const noexcept { - if (name == "coro_id") { return 0u; } - if (name == "target_token") { return 1u; } - auto iter = _designated_fields.find(name); - LUISA_ASSERT(iter != _designated_fields.end(), "CoroFrame designated member not found."); - return iter->second; -} - -luisa::string CoroFrameDesc::dump() const noexcept { - luisa::string s; - for (auto i = 0u; i < _type->members().size(); i++) { - s.append(luisa::format(" Field {}: {}\n", i, _type->members()[i]->description())); - } - if (!_designated_fields.empty()) { - s.append("Designated Fields:\n"); - for (auto &&[name, index] : _designated_fields) { - s.append(luisa::format(" {} -> \"{}\"\n", index, name)); - } - } - return s; -} - -CoroFrame CoroFrameDesc::instantiate() const noexcept { - auto fb = detail::FunctionBuilder::current(); - // create an variable for the coro frame - auto expr = fb->local(_type); - // initialize the coro frame members - auto zero_init = fb->call(_type, CallOp::ZERO, {}); - fb->assign(expr, zero_init); - return CoroFrame{shared_from_this(), expr}; -} - -CoroFrame CoroFrameDesc::instantiate(Expr coro_id) const noexcept { - auto frame = instantiate(); - frame.coro_id = coro_id; - return frame; -} +namespace luisa::compute::coroutine { CoroFrame::CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept : _desc{std::move(desc)}, @@ -101,13 +42,29 @@ CoroFrame &CoroFrame::operator=(CoroFrame &&rhs) noexcept { return *this = static_cast(rhs); } +CoroFrame CoroFrame::create(luisa::shared_ptr desc) noexcept { + auto fb = detail::FunctionBuilder::current(); + // create an variable for the coro frame + auto expr = fb->local(desc->type()); + // initialize the coro frame members + auto zero_init = fb->call(desc->type(), CallOp::ZERO, {}); + fb->assign(expr, zero_init); + return CoroFrame{std::move(desc), expr}; +} + +CoroFrame CoroFrame::create(luisa::shared_ptr desc, Expr coro_id) noexcept { + auto frame = create(std::move(desc)); + frame.coro_id = coro_id; + return frame; +} + void CoroFrame::_check_member_index(uint index) const noexcept { LUISA_ASSERT(index < _desc->type()->members().size(), "CoroFrame member index out of range."); } Var CoroFrame::is_terminated() const noexcept { - return target_token == CoroGraph::terminal_token; + return target_token == coro_token_terminal; } -}// namespace luisa::compute::inline dsl::coro_v2 +}// namespace luisa::compute::coro_v2 diff --git a/src/dsl/coro/coro_func.cpp b/src/dsl/coro/coro_func.cpp index 7e667a389..2fb583bd1 100644 --- a/src/dsl/coro/coro_func.cpp +++ b/src/dsl/coro/coro_func.cpp @@ -5,11 +5,11 @@ #include #include -namespace luisa::compute::inline dsl::coro_v2::detail { +namespace luisa::compute::coroutine::detail { void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, - luisa::move_only_function node) noexcept { - node(CoroGraph::entry_token, frame); + luisa::move_only_function node) noexcept { + node(coro_token_entry, frame); $while (!frame.is_terminated()) { $suspend(); $switch (frame.target_token) { @@ -24,9 +24,9 @@ void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, } void coroutine_generator_step_impl(CoroFrame &frame, uint node_count, bool is_entry, - luisa::move_only_function node) noexcept { + luisa::move_only_function node) noexcept { if (is_entry) { - node(CoroGraph::entry_token, frame); + node(coro_token_entry, frame); } else { $switch (frame.target_token) { for (auto i = 1u; i < node_count; i++) { @@ -37,4 +37,4 @@ void coroutine_generator_step_impl(CoroFrame &frame, uint node_count, bool is_en } } -}// namespace luisa::compute::inline dsl::coro_v2::detail +}// namespace luisa::compute::coro_v2::detail diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index 8401261d6..ccbedd707 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -18,6 +18,10 @@ set(LUISA_COMPUTE_RUNTIME_REMOTE_SOURCES remote/client_interface.cpp remote/server_interface.cpp) +set(LUISA_COMPUTE_RUNTIME_CORO_SOURCES + coro/coro_frame_desc.cpp + coro/coro_graph.cpp) + set(LUISA_COMPUTE_RUNTIME_SOURCES bindless_array.cpp buffer.cpp @@ -38,6 +42,7 @@ set(LUISA_COMPUTE_RUNTIME_SOURCES volume.cpp ${LUISA_COMPUTE_RUNTIME_RHI_SOURCES} ${LUISA_COMPUTE_RUNTIME_RTX_SOURCES} + ${LUISA_COMPUTE_RUNTIME_CORO_SOURCES} ${LUISA_COMPUTE_RUNTIME_RASTER_SOURCES} ${LUISA_COMPUTE_RUNTIME_REMOTE_SOURCES}) diff --git a/src/runtime/coro/coro_frame_desc.cpp b/src/runtime/coro/coro_frame_desc.cpp new file mode 100644 index 000000000..96ffd4916 --- /dev/null +++ b/src/runtime/coro/coro_frame_desc.cpp @@ -0,0 +1,53 @@ +// +// Created by Mike on 2024/5/10. +// + +#include +#include + +namespace luisa::compute::coroutine { + +CoroFrameDesc::CoroFrameDesc(const Type *type, DesignatedFieldDict m) noexcept + : _type{type}, _designated_fields{std::move(m)} { + LUISA_ASSERT(_type != nullptr, "CoroFrame underlying type must not be null."); + LUISA_ASSERT(_type->is_structure(), "CoroFrame underlying type must be a structure."); + LUISA_ASSERT(_type->members().size() >= 2u, "CoroFrame underlying type must have at least 2 members (coro_id and target_token)."); + LUISA_ASSERT(_type->members()[0]->is_uint32_vector() && _type->members()[0]->dimension() == 3u, "CoroFrame member 0 (coro_id) must be uint3."); + LUISA_ASSERT(_type->members()[1]->is_uint32(), "CoroFrame member 1 (target_token) must be uint."); + auto member_count = _type->members().size(); + for (auto &&[name, index] : _designated_fields) { + LUISA_ASSERT(name != "coro_id", "CoroFrame designated member name 'coro_id' is reserved."); + LUISA_ASSERT(name != "target_token", "CoroFrame designated member name 'target_token' is reserved."); + LUISA_ASSERT(index != 0, "CoroFrame designated member index 0 is reserved for coro_id."); + LUISA_ASSERT(index != 1, "CoroFrame designated member index 1 is reserved for target_token."); + LUISA_ASSERT(index < member_count, "CoroFrame designated member index out of range."); + } +} + +luisa::shared_ptr CoroFrameDesc::create(const Type *type, DesignatedFieldDict m) noexcept { + return luisa::make_shared(CoroFrameDesc{type, std::move(m)}); +} + +uint CoroFrameDesc::designated_field(luisa::string_view name) const noexcept { + if (name == "coro_id") { return 0u; } + if (name == "target_token") { return 1u; } + auto iter = _designated_fields.find(name); + LUISA_ASSERT(iter != _designated_fields.end(), "CoroFrame designated member not found."); + return iter->second; +} + +luisa::string CoroFrameDesc::dump() const noexcept { + luisa::string s; + for (auto i = 0u; i < _type->members().size(); i++) { + s.append(luisa::format(" Field {}: {}\n", i, _type->members()[i]->description())); + } + if (!_designated_fields.empty()) { + s.append("Designated Fields:\n"); + for (auto &&[name, index] : _designated_fields) { + s.append(luisa::format(" {} -> \"{}\"\n", index, name)); + } + } + return s; +} + +}// namespace luisa::compute::co diff --git a/src/dsl/coro/coro_graph.cpp b/src/runtime/coro/coro_graph.cpp similarity index 94% rename from src/dsl/coro/coro_graph.cpp rename to src/runtime/coro/coro_graph.cpp index 23c609c33..ed98433b9 100644 --- a/src/dsl/coro/coro_graph.cpp +++ b/src/runtime/coro/coro_graph.cpp @@ -4,18 +4,19 @@ #include #include -#include -#include +#include +#include +#include #ifdef LUISA_ENABLE_IR #include #endif -namespace luisa::compute::inline dsl::coro_v2 { +namespace luisa::compute::coroutine { CoroGraph::Node::Node(luisa::vector input_fields, luisa::vector output_fields, - luisa::vector targets, + luisa::vector targets, CC current_continuation) noexcept : _input_fields{std::move(input_fields)}, _output_fields{std::move(output_fields)}, @@ -59,8 +60,8 @@ luisa::string CoroGraph::Node::dump() const noexcept { } CoroGraph::CoroGraph(luisa::shared_ptr frame_desc, - luisa::unordered_map nodes, - luisa::unordered_map named_tokens) noexcept + luisa::unordered_map nodes, + luisa::unordered_map named_tokens) noexcept : _frame{std::move(frame_desc)}, _nodes{std::move(nodes)}, _named_tokens{std::move(named_tokens)} {} @@ -68,10 +69,10 @@ CoroGraph::CoroGraph(luisa::shared_ptr frame_desc, CoroGraph::~CoroGraph() noexcept = default; const CoroGraph::Node &CoroGraph::entry() const noexcept { - return node(entry_token); + return node(coro_token_entry); } -const CoroGraph::Node &CoroGraph::node(Token token) const noexcept { +const CoroGraph::Node &CoroGraph::node(CoroToken token) const noexcept { auto iter = _nodes.find(token); LUISA_ASSERT(iter != _nodes.end(), "Coroutine node with token {} not found.", @@ -99,7 +100,7 @@ luisa::string CoroGraph::dump() const noexcept { } s.append("Frame:\n").append(_frame->dump()); for (auto &&[token, node] : _nodes) { - if (token == entry_token) { + if (token == coro_token_entry) { s.append("Entry:\n"); } else { s.append(luisa::format("Node {}:\n", token)); @@ -216,7 +217,7 @@ luisa::shared_ptr CoroGraph::create(Function coroutine) noexcep LUISA_ASSERT(subroutines.len == subroutine_ids.len, "Subroutine count mismatch: {} vs {}.", subroutines.len, subroutine_ids.len); - luisa::unordered_map nodes; + luisa::unordered_map nodes; nodes.reserve(subroutines.len + 1u); auto convert_fields = [](ir::CBoxedSlice slice) noexcept { luisa::vector fields; @@ -226,7 +227,7 @@ luisa::shared_ptr CoroGraph::create(Function coroutine) noexcep }; // add the entry node nodes.emplace( - entry_token, + coro_token_entry, Node{convert_fields(m->get()->coro_frame_input_fields), convert_fields(m->get()->coro_frame_output_fields), convert_fields(m->get()->coro_target_tokens), @@ -251,4 +252,4 @@ luisa::shared_ptr CoroGraph::create(Function coroutine) noexcep #endif -}// namespace luisa::compute::inline dsl::coro_v2 +}// namespace luisa::compute::co diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp index 89ceb0c83..2a0d76f05 100644 --- a/src/tests/coro/helloworld_v2.cpp +++ b/src/tests/coro/helloworld_v2.cpp @@ -21,7 +21,7 @@ int main(int argc, char *argv[]) { luisa::vector host_image(image.view().size_bytes()); Kernel1D test = [] { - coro_v2::Generator g = [](UInt n) { + coroutine::Generator g = [](UInt n) { auto x = def(0u); $while (x < n) { $yield(x); diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index 4e41877a3..8827fe5cf 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -171,7 +171,7 @@ int main(int argc, char *argv[]) { auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; - coro_v2::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { UInt2 coord = dispatch_id().xy(); Float frame_size = min(resolution.x, resolution.y).cast(); UInt state = seed_image.read(coord).x; @@ -272,12 +272,12 @@ int main(int argc, char *argv[]) { image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); }; - coro_v2::Coroutine raytrace_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution, UInt2 pixel_id) noexcept { + coroutine::Coroutine raytrace_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution, UInt2 pixel_id) noexcept { auto coro_id = make_uint3(pixel_id, 0u); coro(image, seed_image, accel, resolution).set_id(coro_id).await(); }; - coro_v2::Coroutine raytracing_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + coroutine::Coroutine raytracing_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { raytrace_coro(image, seed_image, accel, resolution, dispatch_id().xy()).await(); }; From fa779bec31b92c735fac8173a649907ebea4497e Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 12:59:17 +0800 Subject: [PATCH 15/67] minor --- include/luisa/dsl/coro/coro_func.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/dsl/coro/coro_func.h index 1ff8785d7..c72ea61b5 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/dsl/coro/coro_func.h @@ -216,8 +216,8 @@ class Generator { auto f = [=, this](CoroFrame &frame, bool is_entry) noexcept { detail::coroutine_generator_step_impl( frame, _coro.subroutine_count(), is_entry, - [&](CoroToken token, CoroFrame &f) noexcept { - _coro.subroutine(token)(f, args...); + [&](CoroToken token, CoroFrame &ff) noexcept { + _coro.subroutine(token)(ff, args...); }); }; return Stepper{*this, std::move(f)}; From ceeab9e89c9d8b21d47bdcf4dff10a05b2cc3734 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 15:37:18 +0800 Subject: [PATCH 16/67] refactor --- include/luisa/coro/coro_graph.h | 35 ++- include/luisa/coro/coro_node.h | 5 +- .../luisa/{dsl/coro => coro/v2}/coro_frame.h | 4 +- .../coro => coro/v2}/coro_frame_desc.h | 2 +- .../luisa/{dsl/coro => coro/v2}/coro_func.h | 9 +- .../{runtime/coro => coro/v2}/coro_graph.h | 6 +- include/luisa/coro/v2/coro_scheduler.h | 87 ++++++ .../{runtime/coro => coro/v2}/coro_token.h | 0 .../luisa/coro/v2/schedulers/state_machine.h | 51 ++++ include/luisa/luisa-compute.h | 12 +- src/coro/CMakeLists.txt | 13 +- src/coro/coro_dispatcher.cpp | 0 src/{dsl => }/coro/coro_frame.cpp | 4 +- src/{runtime => }/coro/coro_frame_desc.cpp | 2 +- src/{dsl => }/coro/coro_func.cpp | 2 +- src/coro/coro_graph.cpp | 264 ++++++++++++++++-- src/coro/coro_node.cpp | 8 - src/coro/schedulers/state_machine.cpp | 21 ++ src/coro/shader_scheduler.cpp | 0 src/dsl/CMakeLists.txt | 7 +- src/runtime/CMakeLists.txt | 5 - src/runtime/coro/coro_graph.cpp | 255 ----------------- src/tests/coro/helloworld_v2.cpp | 28 +- src/tests/coro/path_tracing_v2.cpp | 14 +- 24 files changed, 479 insertions(+), 355 deletions(-) rename include/luisa/{dsl/coro => coro/v2}/coro_frame.h (96%) rename include/luisa/{runtime/coro => coro/v2}/coro_frame_desc.h (91%) rename include/luisa/{dsl/coro => coro/v2}/coro_func.h (97%) rename include/luisa/{runtime/coro => coro/v2}/coro_graph.h (95%) create mode 100644 include/luisa/coro/v2/coro_scheduler.h rename include/luisa/{runtime/coro => coro/v2}/coro_token.h (100%) create mode 100644 include/luisa/coro/v2/schedulers/state_machine.h delete mode 100644 src/coro/coro_dispatcher.cpp rename src/{dsl => }/coro/coro_frame.cpp (96%) rename src/{runtime => }/coro/coro_frame_desc.cpp (98%) rename src/{dsl => }/coro/coro_func.cpp (96%) delete mode 100644 src/coro/coro_node.cpp create mode 100644 src/coro/schedulers/state_machine.cpp delete mode 100644 src/coro/shader_scheduler.cpp delete mode 100644 src/runtime/coro/coro_graph.cpp diff --git a/include/luisa/coro/coro_graph.h b/include/luisa/coro/coro_graph.h index 07b2ed697..0bf726033 100644 --- a/include/luisa/coro/coro_graph.h +++ b/include/luisa/coro/coro_graph.h @@ -1,12 +1,11 @@ #pragma once -#include #include #include namespace luisa::compute::inline coro { -class LC_CORO_API CoroGraph { +class CoroGraph { private: luisa::unordered_map _nodes; @@ -16,17 +15,37 @@ class LC_CORO_API CoroGraph { public: // for construction only - CoroGraph(uint entry, const Type *state_type) noexcept; - [[nodiscard]] CoroNode *add_node(uint token, CoroNode::Func f) noexcept; - void designate_state_member(luisa::string name, uint index) noexcept; + CoroGraph(uint entry, const Type *state_type) noexcept : _entry{entry}, _state_type{state_type} {} + [[nodiscard]] CoroNode *add_node(uint token, CoroNode::Func f) noexcept { + auto node = CoroNode{this, std::move(f)}; + auto [iter, success] = _nodes.emplace(token, std::move(node)); + LUISA_ASSERT(success, "Coroutine node (token = {}) already exists.", token); + return &(iter->second); + } + void designate_state_member(luisa::string name, uint index) noexcept { + auto [iter, success] = _designated_state_members.emplace(name, index); + LUISA_ASSERT(success, "State member '{}' already designated.", name); + } public: - [[nodiscard]] const CoroNode *entry() const noexcept; - [[nodiscard]] const CoroNode *node(uint token) const noexcept; + [[nodiscard]] const CoroNode *entry() const noexcept { + return this->node(_entry); + } + [[nodiscard]] const CoroNode *node(uint token) const noexcept { + auto iter = _nodes.find(token); + LUISA_ASSERT(iter != _nodes.cend(), + "Coroutine node (token = {}) not found.", token); + return &(iter->second); + } [[nodiscard]] auto &nodes() const noexcept { return _nodes; } [[nodiscard]] auto state_type() const noexcept { return _state_type; } [[nodiscard]] auto &designated_state_members() const noexcept { return _designated_state_members; } - [[nodiscard]] uint designated_state_member(luisa::string_view name) const noexcept; + [[nodiscard]] uint designated_state_member(luisa::string_view name) const noexcept { + auto iter = _designated_state_members.find(name); + LUISA_ASSERT(iter != _designated_state_members.cend(), + "State member '{}' not designated.", name); + return iter->second; + } }; }// namespace luisa::compute::inline coro diff --git a/include/luisa/coro/coro_node.h b/include/luisa/coro/coro_node.h index cb76522c9..b91d05201 100644 --- a/include/luisa/coro/coro_node.h +++ b/include/luisa/coro/coro_node.h @@ -8,7 +8,7 @@ namespace luisa::compute::inline coro { class CoroGraph; -class LC_CORO_API CoroNode { +class CoroNode { friend class CoroGraph; @@ -20,7 +20,8 @@ class LC_CORO_API CoroNode { Func _function; protected: - CoroNode(const CoroGraph *graph, Func function) noexcept; + CoroNode(const CoroGraph *graph, Func function) noexcept + : _graph{graph}, _function{std::move(function)} {} public: luisa::vector input_state_members; diff --git a/include/luisa/dsl/coro/coro_frame.h b/include/luisa/coro/v2/coro_frame.h similarity index 96% rename from include/luisa/dsl/coro/coro_frame.h rename to include/luisa/coro/v2/coro_frame.h index 3fef791a1..c258b11ff 100644 --- a/include/luisa/dsl/coro/coro_frame.h +++ b/include/luisa/coro/v2/coro_frame.h @@ -5,12 +5,12 @@ #pragma once #include -#include +#include #include namespace luisa::compute::coroutine { -class LC_DSL_API CoroFrame { +class LC_CORO_API CoroFrame { private: luisa::shared_ptr _desc; diff --git a/include/luisa/runtime/coro/coro_frame_desc.h b/include/luisa/coro/v2/coro_frame_desc.h similarity index 91% rename from include/luisa/runtime/coro/coro_frame_desc.h rename to include/luisa/coro/v2/coro_frame_desc.h index 5893c7b38..0c81cbcf8 100644 --- a/include/luisa/runtime/coro/coro_frame_desc.h +++ b/include/luisa/coro/v2/coro_frame_desc.h @@ -11,7 +11,7 @@ namespace luisa::compute::coroutine { -class LC_RUNTIME_API CoroFrameDesc : public luisa::enable_shared_from_this { +class LC_CORO_API CoroFrameDesc : public luisa::enable_shared_from_this { public: using DesignatedFieldDict = luisa::unordered_map; diff --git a/include/luisa/dsl/coro/coro_func.h b/include/luisa/coro/v2/coro_func.h similarity index 97% rename from include/luisa/dsl/coro/coro_func.h rename to include/luisa/coro/v2/coro_func.h index c72ea61b5..5534f0fad 100644 --- a/include/luisa/dsl/coro/coro_func.h +++ b/include/luisa/coro/v2/coro_func.h @@ -4,17 +4,18 @@ #pragma once -#include -#include +#include +#include +#include #include namespace luisa::compute::coroutine { namespace detail { -LC_DSL_API void coroutine_chained_await_impl( +LC_CORO_API void coroutine_chained_await_impl( CoroFrame &frame, uint node_count, luisa::move_only_function node) noexcept; -LC_DSL_API void coroutine_generator_step_impl( +LC_CORO_API void coroutine_generator_step_impl( CoroFrame &frame, uint node_count, bool is_entry, luisa::move_only_function node) noexcept; }// namespace detail diff --git a/include/luisa/runtime/coro/coro_graph.h b/include/luisa/coro/v2/coro_graph.h similarity index 95% rename from include/luisa/runtime/coro/coro_graph.h rename to include/luisa/coro/v2/coro_graph.h index 8ef56956c..3f2b765f9 100644 --- a/include/luisa/runtime/coro/coro_graph.h +++ b/include/luisa/coro/v2/coro_graph.h @@ -8,19 +8,19 @@ #include #include #include -#include +#include namespace luisa::compute::coroutine { class CoroFrameDesc; -class LC_RUNTIME_API CoroGraph { +class LC_CORO_API CoroGraph { public: using CC = luisa::shared_ptr;// current continuation function public: - class LC_RUNTIME_API Node { + class LC_CORO_API Node { private: luisa::vector _input_fields; diff --git a/include/luisa/coro/v2/coro_scheduler.h b/include/luisa/coro/v2/coro_scheduler.h new file mode 100644 index 000000000..336dc4f2e --- /dev/null +++ b/include/luisa/coro/v2/coro_scheduler.h @@ -0,0 +1,87 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include + +namespace luisa::compute::coroutine { + +template +class CoroScheduler; + +namespace detail { + +template +class CoroSchedulerInvoke; + +class CoroSchedulerDispatch : public concepts::Noncopyable { + +private: + luisa::move_only_function _impl; + +private: + template + friend class CoroSchedulerInvoke; + explicit CoroSchedulerDispatch(luisa::move_only_function impl) noexcept + : _impl{std::move(impl)} {} + +public: + void operator()(Stream &stream) && noexcept { _impl(stream); } +}; + +template +class CoroSchedulerInvoke : public concepts::Noncopyable { + +private: + using Scheduler = CoroScheduler; + Scheduler *_scheduler; + std::tuple...> _args; + +private: + friend class Scheduler; + CoroSchedulerInvoke(Scheduler *scheduler, compute::detail::prototype_to_shader_invocation_t... args) noexcept + : _scheduler{scheduler}, _args{args...} {} + +public: + [[nodiscard]] auto dispatch(uint3 size) && noexcept { + return CoroSchedulerDispatch{[s = _scheduler, args = std::move(_args), size](Stream &stream) noexcept { + std::apply( + [s, size, &stream](A &&...a) noexcept { + s->_dispatch(stream, size, std::forward(a)...); + }, + args); + }}; + } + [[nodiscard]] auto dispatch(uint nx, uint ny, uint nz) && noexcept { + return std::move(*this).dispatch(make_uint3(nx, ny, nz)); + } + [[nodiscard]] auto dispatch(uint nx, uint ny) && noexcept { + return std::move(*this).dispatch(make_uint3(nx, ny, 1u)); + } + [[nodiscard]] auto dispatch(uint nx) && noexcept { + return std::move(*this).dispatch(make_uint3(nx, 1u, 1u)); + } +}; + +}// namespace detail + +template +class CoroScheduler { + +private: + friend class detail::CoroSchedulerInvoke; + virtual void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept = 0; + +public: + virtual ~CoroScheduler() noexcept = default; + [[nodiscard]] auto operator()(compute::detail::prototype_to_shader_invocation_t... args) noexcept { + return detail::CoroSchedulerInvoke{this, args...}; + } +}; + +}// namespace luisa::compute::coroutine + +LUISA_MARK_STREAM_EVENT_TYPE(luisa::compute::coroutine::detail::CoroSchedulerDispatch) diff --git a/include/luisa/runtime/coro/coro_token.h b/include/luisa/coro/v2/coro_token.h similarity index 100% rename from include/luisa/runtime/coro/coro_token.h rename to include/luisa/coro/v2/coro_token.h diff --git a/include/luisa/coro/v2/schedulers/state_machine.h b/include/luisa/coro/v2/schedulers/state_machine.h new file mode 100644 index 000000000..16d81e3b6 --- /dev/null +++ b/include/luisa/coro/v2/schedulers/state_machine.h @@ -0,0 +1,51 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include + +namespace luisa::compute::coroutine { + +namespace detail { +LC_CORO_API void coro_scheduler_state_machine_impl( + CoroFrame &frame, uint state_count, + luisa::move_only_function node) noexcept; +}// namespace detail + +template +class StateMachineCoroScheduler : public CoroScheduler { + +private: + Shader3D _shader; + +private: + [[nodiscard]] static auto _create_shader(Device &device, const Coroutine &coro) noexcept { + Kernel3D kernel = [&coro](Var... args) noexcept { + set_block_size(128u, 1u, 1u); + auto frame = coro.instantiate(dispatch_id()); + detail::coro_scheduler_state_machine_impl( + frame, coro.subroutine_count(), + [&](CoroToken token) noexcept { + coro.subroutine(token)(frame, args...); + }); + }; + return device.compile(kernel); + } + + void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept override { + stream << _shader(args...).dispatch(dispatch_size); + } + +public: + StateMachineCoroScheduler(Device &device, const Coroutine &coro) noexcept + : _shader{_create_shader(device, coro)} {} +}; + +template +StateMachineCoroScheduler(Device &, const Coroutine &) -> StateMachineCoroScheduler; + +}// namespace luisa::compute::coroutine diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index ea2f7726e..81016f709 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -55,6 +55,13 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include #ifdef LUISA_ENABLE_DSL #include @@ -63,8 +70,6 @@ #include #include #include -#include -#include #include #include #include @@ -121,9 +126,6 @@ #include #include #include -#include -#include -#include #include #include #include diff --git a/src/coro/CMakeLists.txt b/src/coro/CMakeLists.txt index 5c6645170..b24854a8c 100644 --- a/src/coro/CMakeLists.txt +++ b/src/coro/CMakeLists.txt @@ -1,11 +1,16 @@ set(LUISA_COMPUTE_CORO_SOURCES - shader_scheduler.cpp - coro_dispatcher.cpp + coro_frame.cpp + coro_frame_desc.cpp + coro_func.cpp coro_graph.cpp - coro_node.cpp) + schedulers/state_machine.cpp) add_library(luisa-compute-coro SHARED ${LUISA_COMPUTE_CORO_SOURCES}) -target_link_libraries(luisa-compute-coro PUBLIC luisa-compute-ast luisa-compute-ir) +target_link_libraries(luisa-compute-coro PUBLIC + luisa-compute-ast + luisa-compute-ir + luisa-compute-dsl + luisa-compute-runtime) target_precompile_headers(luisa-compute-coro PRIVATE pch.h) target_compile_definitions(luisa-compute-coro PRIVATE LC_CORO_EXPORT_DLL=1) diff --git a/src/coro/coro_dispatcher.cpp b/src/coro/coro_dispatcher.cpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/dsl/coro/coro_frame.cpp b/src/coro/coro_frame.cpp similarity index 96% rename from src/dsl/coro/coro_frame.cpp rename to src/coro/coro_frame.cpp index d7d9f7aa6..704fed6d0 100644 --- a/src/dsl/coro/coro_frame.cpp +++ b/src/coro/coro_frame.cpp @@ -3,8 +3,8 @@ // #include -#include -#include +#include +#include namespace luisa::compute::coroutine { diff --git a/src/runtime/coro/coro_frame_desc.cpp b/src/coro/coro_frame_desc.cpp similarity index 98% rename from src/runtime/coro/coro_frame_desc.cpp rename to src/coro/coro_frame_desc.cpp index 96ffd4916..302f08911 100644 --- a/src/runtime/coro/coro_frame_desc.cpp +++ b/src/coro/coro_frame_desc.cpp @@ -3,7 +3,7 @@ // #include -#include +#include namespace luisa::compute::coroutine { diff --git a/src/dsl/coro/coro_func.cpp b/src/coro/coro_func.cpp similarity index 96% rename from src/dsl/coro/coro_func.cpp rename to src/coro/coro_func.cpp index 2fb583bd1..1386da21e 100644 --- a/src/dsl/coro/coro_func.cpp +++ b/src/coro/coro_func.cpp @@ -3,7 +3,7 @@ // #include -#include +#include namespace luisa::compute::coroutine::detail { diff --git a/src/coro/coro_graph.cpp b/src/coro/coro_graph.cpp index e57371f17..9214e654e 100644 --- a/src/coro/coro_graph.cpp +++ b/src/coro/coro_graph.cpp @@ -1,39 +1,255 @@ +// +// Created by Mike on 2024/5/8. +// + #include -#include +#include +#include +#include +#include + +#ifdef LUISA_ENABLE_IR +#include +#endif + +namespace luisa::compute::coroutine { -namespace luisa::compute::inline coro { +CoroGraph::Node::Node(luisa::vector input_fields, + luisa::vector output_fields, + luisa::vector targets, + CC current_continuation) noexcept + : _input_fields{std::move(input_fields)}, + _output_fields{std::move(output_fields)}, + _targets{std::move(targets)}, + _cc{std::move(current_continuation)} {} -const CoroNode *CoroGraph::entry() const noexcept { - return this->node(_entry); +CoroGraph::Node::~Node() noexcept = default; + +Function CoroGraph::Node::cc() const noexcept { return _cc->function(); } + +luisa::string CoroGraph::Node::dump() const noexcept { + luisa::string s; + s.append(" Input Fields: ["); + for (auto i : _input_fields) { + s.append(luisa::format("{}, ", i)); + } + if (!_input_fields.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + s.append(" Output Fields: ["); + for (auto i : _output_fields) { + s.append(luisa::format("{}, ", i)); + } + if (!_output_fields.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + s.append(" Transition Targets: ["); + for (auto i : _targets) { + s.append(luisa::format("{}, ", i)); + } + if (!_targets.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + return s; } -const CoroNode *CoroGraph::node(uint token) const noexcept { +CoroGraph::CoroGraph(luisa::shared_ptr frame_desc, + luisa::unordered_map nodes, + luisa::unordered_map named_tokens) noexcept + : _frame{std::move(frame_desc)}, + _nodes{std::move(nodes)}, + _named_tokens{std::move(named_tokens)} {} + +CoroGraph::~CoroGraph() noexcept = default; + +const CoroGraph::Node &CoroGraph::entry() const noexcept { + return node(coro_token_entry); +} + +const CoroGraph::Node &CoroGraph::node(CoroToken token) const noexcept { auto iter = _nodes.find(token); - LUISA_ASSERT(iter != _nodes.cend(), - "Coroutine node (token = {}) not found.", token); - return &(iter->second); + LUISA_ASSERT(iter != _nodes.end(), + "Coroutine node with token {} not found.", + token); + return iter->second; } -CoroNode *CoroGraph::add_node(uint token, CoroNode::Func f) noexcept { - auto node = CoroNode{this, std::move(f)}; - auto [iter, success] = _nodes.emplace(token, std::move(node)); - LUISA_ASSERT(success, "Coroutine node (token = {}) already exists.", token); - return &(iter->second); +const CoroGraph::Node &CoroGraph::node(luisa::string_view name) const noexcept { + auto iter = _named_tokens.find(name); + LUISA_ASSERT(iter != _named_tokens.end(), + "Coroutine node with name '{}' not found.", + name); + return node(iter->second); } -uint CoroGraph::designated_state_member(luisa::string_view name) const noexcept { - auto iter = _designated_state_members.find(name); - LUISA_ASSERT(iter != _designated_state_members.cend(), - "State member '{}' not designated.", name); - return iter->second; +luisa::string CoroGraph::dump() const noexcept { + luisa::string s; + s.append("Arguments:\n"); + auto args = entry().cc().arguments(); + for (auto i = 0u; i < args.size(); i++) { + s.append(luisa::format(" Argument {}: ", i)); + s.append(args[i].type()->description()); + if (args[i].is_reference()) { s.append(" &"); } + s.append("\n"); + } + s.append("Frame:\n").append(_frame->dump()); + for (auto &&[token, node] : _nodes) { + if (token == coro_token_entry) { + s.append("Entry:\n"); + } else { + s.append(luisa::format("Node {}:\n", token)); + } + s.append(node.dump()); + } + if (!_named_tokens.empty()) { + s.append("Named Tokens:\n"); + for (auto &&[name, token] : _named_tokens) { + s.append(luisa::format(" {} -> \"{}\"\n", token, name)); + } + } + return s; } -CoroGraph::CoroGraph(uint entry, const Type *state_type) noexcept - : _entry{entry}, _state_type{state_type} {} +#ifndef LUISA_ENABLE_IR +luisa::shared_ptr CoroGraph::create(Function coroutine) noexcept { + LUISA_ERROR_WITH_LOCATION( + "Coroutine requires IR support but " + "LuisaCompute is built without the IR module. " + "This might be caused by missing Rust. " + "Please install the Rust toolchain and " + "recompile LuisaCompute to get the IR module."); +} +#else + +namespace detail { + +static void perform_coroutine_transform(ir::CallableModule *m) noexcept { + auto coroutine_pipeline = ir::luisa_compute_ir_transform_pipeline_new(); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "canonicalize_control_flow"); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "demote_locals"); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "defer_load"); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "extract_loop_cond"); + // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "split_coro"); + ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "materialize_coro_v2"); + auto converted_module = ir::luisa_compute_ir_transform_pipeline_transform_callable(coroutine_pipeline, *m); + ir::luisa_compute_ir_transform_pipeline_destroy(coroutine_pipeline); + *m = converted_module; +} -void CoroGraph::designate_state_member(luisa::string name, uint index) noexcept { - auto [iter, success] = _designated_state_members.emplace(name, index); - LUISA_ASSERT(success, "State member '{}' already designated.", name); +[[nodiscard]] static auto make_subroutine_wrapper(Function coroutine, Function cc) noexcept { + using FB = luisa::compute::detail::FunctionBuilder; + return FB::define_callable([&] { + luisa::vector args; + args.reserve(1u /* frame */ + coroutine.arguments().size()); + LUISA_ASSERT(coroutine.arguments().size() == coroutine.bound_arguments().size(), + "Invalid capture list size (expected {}, got {}).", + coroutine.arguments().size(), coroutine.bound_arguments().size()); + auto fb = FB::current(); + args.emplace_back(fb->reference(cc.arguments().front().type())); + for (auto arg_i = 0u; arg_i < coroutine.arguments().size(); arg_i++) { + auto def_arg = coroutine.arguments()[arg_i]; + auto internal_arg = luisa::visit( + [&](T b) noexcept -> const Expression * { + if constexpr (std::is_same_v) { + return fb->buffer_binding(def_arg.type(), b.handle, b.offset, b.size); + } else if constexpr (std::is_same_v) { + return fb->texture_binding(def_arg.type(), b.handle, b.level); + } else if constexpr (std::is_same_v) { + return fb->bindless_array_binding(b.handle); + } else if constexpr (std::is_same_v) { + return fb->accel_binding(b.handle); + } else { + static_assert(std::is_same_v); + switch (def_arg.tag()) { + case Variable::Tag::REFERENCE: return fb->reference(def_arg.type()); + case Variable::Tag::BUFFER: return fb->buffer(def_arg.type()); + case Variable::Tag::TEXTURE: return fb->texture(def_arg.type()); + case Variable::Tag::BINDLESS_ARRAY: return fb->bindless_array(); + case Variable::Tag::ACCEL: return fb->accel(); + default: /* value argument */ return fb->argument(def_arg.type()); + } + } + }, + coroutine.bound_arguments()[arg_i]); + args.emplace_back(internal_arg); + } + LUISA_ASSERT(cc.return_type() == nullptr, + "Coroutine subroutines should not have return type."); + fb->call(cc, args); + }); } -}// namespace luisa::compute::inline coro +}// namespace detail + +luisa::shared_ptr CoroGraph::create(Function coroutine) noexcept { + LUISA_VERBOSE_WITH_LOCATION("Performing Coroutine transform " + "on function with hash {:016x}.", + coroutine.hash()); + + // convert the coroutine function to IR, transform it, and then convert back + auto m = AST2IR::build_coroutine(coroutine); + detail::perform_coroutine_transform(m->get()); + auto entry = IR2AST::build(m->get()); + + // create the coroutine frame descriptor + auto frame = [m, entry] { + auto underlying = entry->arguments().front().type(); + CoroFrameDesc::DesignatedFieldDict members; + for (auto &&field : luisa::span{m->get()->coro_frame_designated_fields.ptr, + m->get()->coro_frame_designated_fields.len}) { + auto name = luisa::string_view{reinterpret_cast(field.name.ptr), field.name.len}; + if (!name.empty() && name.back() == '\0') { name = name.substr(0, name.size() - 1); } + auto [_, success] = members.try_emplace(name, field.index); + LUISA_ASSERT(success, "Duplicated designated field name '{}' at field {}.", name, field.index); + } + return CoroFrameDesc::create(underlying, std::move(members)); + }(); + + // extract the subroutines + auto subroutines = m->get()->subroutines; + auto subroutine_ids = m->get()->subroutine_ids; + LUISA_ASSERT(subroutines.len == subroutine_ids.len, + "Subroutine count mismatch: {} vs {}.", + subroutines.len, subroutine_ids.len); + luisa::unordered_map nodes; + nodes.reserve(subroutines.len + 1u); + auto convert_fields = [](ir::CBoxedSlice slice) noexcept { + luisa::vector fields; + fields.reserve(slice.len); + for (auto i = 0u; i < slice.len; i++) { fields.emplace_back(slice.ptr[i]); } + return fields; + }; + // add the entry node + nodes.emplace( + coro_token_entry, + Node{convert_fields(m->get()->coro_frame_input_fields), + convert_fields(m->get()->coro_frame_output_fields), + convert_fields(m->get()->coro_target_tokens), + detail::make_subroutine_wrapper(coroutine, entry->function())}); + // add subroutine nodes + for (auto i = 0u; i < subroutines.len; i++) { + auto s = subroutines.ptr[i]._0.get(); + auto subroutine = IR2AST::build(s); + auto [_, success] = nodes.try_emplace( + subroutine_ids.ptr[i], + Node{convert_fields(s->coro_frame_input_fields), + convert_fields(s->coro_frame_output_fields), + convert_fields(s->coro_target_tokens), + detail::make_subroutine_wrapper(coroutine, subroutine->function())}); + LUISA_ASSERT(success, "Duplicated subroutine token {}.", subroutine_ids.ptr[i]); + } + // create the graph + return luisa::make_shared( + std::move(frame), std::move(nodes), + coroutine.builder()->coro_tokens()); +} + +#endif + +}// namespace luisa::compute::coroutine diff --git a/src/coro/coro_node.cpp b/src/coro/coro_node.cpp deleted file mode 100644 index 499c55d0a..000000000 --- a/src/coro/coro_node.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include - -namespace luisa::compute::inline coro { - -CoroNode::CoroNode(const CoroGraph *graph, - CoroNode::Func function) noexcept - : _graph{graph}, _function{std::move(function)} {} -}// namespace luisa::compute::inline coro diff --git a/src/coro/schedulers/state_machine.cpp b/src/coro/schedulers/state_machine.cpp new file mode 100644 index 000000000..f46ef5ce0 --- /dev/null +++ b/src/coro/schedulers/state_machine.cpp @@ -0,0 +1,21 @@ +// +// Created by Mike on 2024/5/10. +// + +#include +#include + +namespace luisa::compute::coroutine::detail { +inline void coro_scheduler_state_machine_impl(CoroFrame &frame, uint state_count, + luisa::move_only_function node) noexcept { + node(coro_token_entry); + $loop { + $switch (frame.target_token) { + for (auto i = 1u; i < state_count; i++) { + $case (i) { node(i); }; + } + $default { $return(); }; + }; + }; +} +}// namespace luisa::compute::coroutine::detail diff --git a/src/coro/shader_scheduler.cpp b/src/coro/shader_scheduler.cpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/dsl/CMakeLists.txt b/src/dsl/CMakeLists.txt index 2895759fb..cffae679a 100644 --- a/src/dsl/CMakeLists.txt +++ b/src/dsl/CMakeLists.txt @@ -10,10 +10,6 @@ if (LUISA_COMPUTE_ENABLE_DSL) set(LUISA_COMPUTE_DSL_RASTER_SOURCES raster/raster_kernel.cpp) - set(LUISA_COMPUTE_DSL_CORO_SOURCES - coro/coro_frame.cpp - coro/coro_func.cpp) - set(LUISA_COMPUTE_DSL_SOURCES builtin.cpp dispatch_indirect.cpp @@ -24,11 +20,10 @@ if (LUISA_COMPUTE_ENABLE_DSL) soa.cpp sugar.cpp ${LUISA_COMPUTE_DSL_RTX_SOURCES} - ${LUISA_COMPUTE_DSL_CORO_SOURCES} ${LUISA_COMPUTE_DSL_RASTER_SOURCES}) add_library(luisa-compute-dsl SHARED ${LUISA_COMPUTE_DSL_SOURCES}) - target_link_libraries(luisa-compute-dsl PUBLIC luisa-compute-ast luisa-compute-runtime luisa-compute-coro) + target_link_libraries(luisa-compute-dsl PUBLIC luisa-compute-ast luisa-compute-runtime) target_compile_definitions(luisa-compute-dsl PRIVATE LC_DSL_EXPORT_DLL=1 PUBLIC LUISA_ENABLE_DSL=1) diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index ccbedd707..8401261d6 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -18,10 +18,6 @@ set(LUISA_COMPUTE_RUNTIME_REMOTE_SOURCES remote/client_interface.cpp remote/server_interface.cpp) -set(LUISA_COMPUTE_RUNTIME_CORO_SOURCES - coro/coro_frame_desc.cpp - coro/coro_graph.cpp) - set(LUISA_COMPUTE_RUNTIME_SOURCES bindless_array.cpp buffer.cpp @@ -42,7 +38,6 @@ set(LUISA_COMPUTE_RUNTIME_SOURCES volume.cpp ${LUISA_COMPUTE_RUNTIME_RHI_SOURCES} ${LUISA_COMPUTE_RUNTIME_RTX_SOURCES} - ${LUISA_COMPUTE_RUNTIME_CORO_SOURCES} ${LUISA_COMPUTE_RUNTIME_RASTER_SOURCES} ${LUISA_COMPUTE_RUNTIME_REMOTE_SOURCES}) diff --git a/src/runtime/coro/coro_graph.cpp b/src/runtime/coro/coro_graph.cpp deleted file mode 100644 index ed98433b9..000000000 --- a/src/runtime/coro/coro_graph.cpp +++ /dev/null @@ -1,255 +0,0 @@ -// -// Created by Mike on 2024/5/8. -// - -#include -#include -#include -#include -#include - -#ifdef LUISA_ENABLE_IR -#include -#endif - -namespace luisa::compute::coroutine { - -CoroGraph::Node::Node(luisa::vector input_fields, - luisa::vector output_fields, - luisa::vector targets, - CC current_continuation) noexcept - : _input_fields{std::move(input_fields)}, - _output_fields{std::move(output_fields)}, - _targets{std::move(targets)}, - _cc{std::move(current_continuation)} {} - -CoroGraph::Node::~Node() noexcept = default; - -Function CoroGraph::Node::cc() const noexcept { return _cc->function(); } - -luisa::string CoroGraph::Node::dump() const noexcept { - luisa::string s; - s.append(" Input Fields: ["); - for (auto i : _input_fields) { - s.append(luisa::format("{}, ", i)); - } - if (!_input_fields.empty()) { - s.pop_back(); - s.pop_back(); - } - s.append("]\n"); - s.append(" Output Fields: ["); - for (auto i : _output_fields) { - s.append(luisa::format("{}, ", i)); - } - if (!_output_fields.empty()) { - s.pop_back(); - s.pop_back(); - } - s.append("]\n"); - s.append(" Transition Targets: ["); - for (auto i : _targets) { - s.append(luisa::format("{}, ", i)); - } - if (!_targets.empty()) { - s.pop_back(); - s.pop_back(); - } - s.append("]\n"); - return s; -} - -CoroGraph::CoroGraph(luisa::shared_ptr frame_desc, - luisa::unordered_map nodes, - luisa::unordered_map named_tokens) noexcept - : _frame{std::move(frame_desc)}, - _nodes{std::move(nodes)}, - _named_tokens{std::move(named_tokens)} {} - -CoroGraph::~CoroGraph() noexcept = default; - -const CoroGraph::Node &CoroGraph::entry() const noexcept { - return node(coro_token_entry); -} - -const CoroGraph::Node &CoroGraph::node(CoroToken token) const noexcept { - auto iter = _nodes.find(token); - LUISA_ASSERT(iter != _nodes.end(), - "Coroutine node with token {} not found.", - token); - return iter->second; -} - -const CoroGraph::Node &CoroGraph::node(luisa::string_view name) const noexcept { - auto iter = _named_tokens.find(name); - LUISA_ASSERT(iter != _named_tokens.end(), - "Coroutine node with name '{}' not found.", - name); - return node(iter->second); -} - -luisa::string CoroGraph::dump() const noexcept { - luisa::string s; - s.append("Arguments:\n"); - auto args = entry().cc().arguments(); - for (auto i = 0u; i < args.size(); i++) { - s.append(luisa::format(" Argument {}: ", i)); - s.append(args[i].type()->description()); - if (args[i].is_reference()) { s.append(" &"); } - s.append("\n"); - } - s.append("Frame:\n").append(_frame->dump()); - for (auto &&[token, node] : _nodes) { - if (token == coro_token_entry) { - s.append("Entry:\n"); - } else { - s.append(luisa::format("Node {}:\n", token)); - } - s.append(node.dump()); - } - if (!_named_tokens.empty()) { - s.append("Named Tokens:\n"); - for (auto &&[name, token] : _named_tokens) { - s.append(luisa::format(" {} -> \"{}\"\n", token, name)); - } - } - return s; -} - -#ifndef LUISA_ENABLE_IR -luisa::shared_ptr CoroGraph::create(Function coroutine) noexcept { - LUISA_ERROR_WITH_LOCATION( - "Coroutine requires IR support but " - "LuisaCompute is built without the IR module. " - "This might be caused by missing Rust. " - "Please install the Rust toolchain and " - "recompile LuisaCompute to get the IR module."); -} -#else - -namespace detail { - -static void perform_coroutine_transform(ir::CallableModule *m) noexcept { - auto coroutine_pipeline = ir::luisa_compute_ir_transform_pipeline_new(); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "canonicalize_control_flow"); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "demote_locals"); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "defer_load"); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "extract_loop_cond"); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "split_coro"); - ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "materialize_coro_v2"); - auto converted_module = ir::luisa_compute_ir_transform_pipeline_transform_callable(coroutine_pipeline, *m); - ir::luisa_compute_ir_transform_pipeline_destroy(coroutine_pipeline); - *m = converted_module; -} - -[[nodiscard]] static auto make_subroutine_wrapper(Function coroutine, Function cc) noexcept { - using FB = luisa::compute::detail::FunctionBuilder; - return FB::define_callable([&] { - luisa::vector args; - args.reserve(1u /* frame */ + coroutine.arguments().size()); - LUISA_ASSERT(coroutine.arguments().size() == coroutine.bound_arguments().size(), - "Invalid capture list size (expected {}, got {}).", - coroutine.arguments().size(), coroutine.bound_arguments().size()); - auto fb = FB::current(); - args.emplace_back(fb->reference(cc.arguments().front().type())); - for (auto arg_i = 0u; arg_i < coroutine.arguments().size(); arg_i++) { - auto def_arg = coroutine.arguments()[arg_i]; - auto internal_arg = luisa::visit( - [&](T b) noexcept -> const Expression * { - if constexpr (std::is_same_v) { - return fb->buffer_binding(def_arg.type(), b.handle, b.offset, b.size); - } else if constexpr (std::is_same_v) { - return fb->texture_binding(def_arg.type(), b.handle, b.level); - } else if constexpr (std::is_same_v) { - return fb->bindless_array_binding(b.handle); - } else if constexpr (std::is_same_v) { - return fb->accel_binding(b.handle); - } else { - static_assert(std::is_same_v); - switch (def_arg.tag()) { - case Variable::Tag::REFERENCE: return fb->reference(def_arg.type()); - case Variable::Tag::BUFFER: return fb->buffer(def_arg.type()); - case Variable::Tag::TEXTURE: return fb->texture(def_arg.type()); - case Variable::Tag::BINDLESS_ARRAY: return fb->bindless_array(); - case Variable::Tag::ACCEL: return fb->accel(); - default: /* value argument */ return fb->argument(def_arg.type()); - } - } - }, - coroutine.bound_arguments()[arg_i]); - args.emplace_back(internal_arg); - } - LUISA_ASSERT(cc.return_type() == nullptr, - "Coroutine subroutines should not have return type."); - fb->call(cc, args); - }); -} - -}// namespace detail - -luisa::shared_ptr CoroGraph::create(Function coroutine) noexcept { - LUISA_VERBOSE_WITH_LOCATION("Performing Coroutine transform " - "on function with hash {:016x}.", - coroutine.hash()); - - // convert the coroutine function to IR, transform it, and then convert back - auto m = AST2IR::build_coroutine(coroutine); - detail::perform_coroutine_transform(m->get()); - auto entry = IR2AST::build(m->get()); - - // create the coroutine frame descriptor - auto frame = [m, entry] { - auto underlying = entry->arguments().front().type(); - CoroFrameDesc::DesignatedFieldDict members; - for (auto &&field : luisa::span{m->get()->coro_frame_designated_fields.ptr, - m->get()->coro_frame_designated_fields.len}) { - auto name = luisa::string_view{reinterpret_cast(field.name.ptr), field.name.len}; - if (!name.empty() && name.back() == '\0') { name = name.substr(0, name.size() - 1); } - auto [_, success] = members.try_emplace(name, field.index); - LUISA_ASSERT(success, "Duplicated designated field name '{}' at field {}.", name, field.index); - } - return CoroFrameDesc::create(underlying, std::move(members)); - }(); - - // extract the subroutines - auto subroutines = m->get()->subroutines; - auto subroutine_ids = m->get()->subroutine_ids; - LUISA_ASSERT(subroutines.len == subroutine_ids.len, - "Subroutine count mismatch: {} vs {}.", - subroutines.len, subroutine_ids.len); - luisa::unordered_map nodes; - nodes.reserve(subroutines.len + 1u); - auto convert_fields = [](ir::CBoxedSlice slice) noexcept { - luisa::vector fields; - fields.reserve(slice.len); - for (auto i = 0u; i < slice.len; i++) { fields.emplace_back(slice.ptr[i]); } - return fields; - }; - // add the entry node - nodes.emplace( - coro_token_entry, - Node{convert_fields(m->get()->coro_frame_input_fields), - convert_fields(m->get()->coro_frame_output_fields), - convert_fields(m->get()->coro_target_tokens), - detail::make_subroutine_wrapper(coroutine, entry->function())}); - // add subroutine nodes - for (auto i = 0u; i < subroutines.len; i++) { - auto s = subroutines.ptr[i]._0.get(); - auto subroutine = IR2AST::build(s); - auto [_, success] = nodes.try_emplace( - subroutine_ids.ptr[i], - Node{convert_fields(s->coro_frame_input_fields), - convert_fields(s->coro_frame_output_fields), - convert_fields(s->coro_target_tokens), - detail::make_subroutine_wrapper(coroutine, subroutine->function())}); - LUISA_ASSERT(success, "Duplicated subroutine token {}.", subroutine_ids.ptr[i]); - } - // create the graph - return luisa::make_shared( - std::move(frame), std::move(nodes), - coroutine.builder()->coro_tokens()); -} - -#endif - -}// namespace luisa::compute::co diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp index 2a0d76f05..8d9f662ea 100644 --- a/src/tests/coro/helloworld_v2.cpp +++ b/src/tests/coro/helloworld_v2.cpp @@ -1,15 +1,12 @@ -#include -#include -#include -#include -#include -#include -#include -#include +#include using namespace luisa; using namespace luisa::compute; +namespace luisa::compute::coroutine { + +}// namespace luisa::compute::coroutine + int main(int argc, char *argv[]) { Context context{argv[0]}; @@ -21,18 +18,25 @@ int main(int argc, char *argv[]) { luisa::vector host_image(image.view().size_bytes()); Kernel1D test = [] { - coroutine::Generator g = [](UInt n) { + coroutine::Generator range = [](UInt n) { auto x = def(0u); $while (x < n) { $yield(x); x += 1u; }; }; - for (auto x : g(100u)) { + for (auto x : range(100u)) { device_log("x = {}", x); } }; - auto shader = device.compile(test); - stream << shader().dispatch(1u) << synchronize(); + coroutine::Coroutine coro = [] { + $for (i, 10u) { + device_log("i = {}", i); + $suspend(); + }; + }; + + coroutine::StateMachineCoroScheduler sched{device, coro}; + stream << sched().dispatch(1u, 1u, 1u) << synchronize(); } diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index 8827fe5cf..2b630144a 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -1,19 +1,9 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include -#include -#include -#include +#include #include "../common/cornell_box.h" #define TINYOBJLOADER_IMPLEMENTATION From 394c9a378f77bd7ffb95a128e34802a46d7ade36 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 15:44:04 +0800 Subject: [PATCH 17/67] re-design scheduler interface --- include/luisa/coro/v2/coro_scheduler.h | 9 ++++-- src/tests/coro/path_tracing_v2.cpp | 38 ++++++++------------------ 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/include/luisa/coro/v2/coro_scheduler.h b/include/luisa/coro/v2/coro_scheduler.h index 336dc4f2e..a6565a9ec 100644 --- a/include/luisa/coro/v2/coro_scheduler.h +++ b/include/luisa/coro/v2/coro_scheduler.h @@ -55,13 +55,16 @@ class CoroSchedulerInvoke : public concepts::Noncopyable { }}; } [[nodiscard]] auto dispatch(uint nx, uint ny, uint nz) && noexcept { - return std::move(*this).dispatch(make_uint3(nx, ny, nz)); + return std::move(*this).dispatch(luisa::make_uint3(nx, ny, nz)); } [[nodiscard]] auto dispatch(uint nx, uint ny) && noexcept { - return std::move(*this).dispatch(make_uint3(nx, ny, 1u)); + return std::move(*this).dispatch(luisa::make_uint3(nx, ny, 1u)); + } + [[nodiscard]] auto dispatch(uint2 size) && noexcept { + return std::move(*this).dispatch(luisa::make_uint3(size, 1u)); } [[nodiscard]] auto dispatch(uint nx) && noexcept { - return std::move(*this).dispatch(make_uint3(nx, 1u, 1u)); + return std::move(*this).dispatch(luisa::make_uint3(nx, 1u, 1u)); } }; diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index 2b630144a..1a233e9f1 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -271,22 +271,7 @@ int main(int argc, char *argv[]) { raytrace_coro(image, seed_image, accel, resolution, dispatch_id().xy()).await(); }; - Kernel2D mega_kernel = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) { - auto frame = raytracing_coro.instantiate(dispatch_id()); - raytracing_coro.entry()(frame, image, seed_image, accel, resolution); - $loop { - $switch (frame.target_token) { - for (auto i = 1u; i < raytracing_coro.subroutine_count(); i++) { - $case (i) { - raytracing_coro[i](frame, image, seed_image, accel, resolution); - }; - } - $default { - $return(); - }; - }; - }; - }; + coroutine::StateMachineCoroScheduler scheduler{device, raytracing_coro}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { UInt2 p = dispatch_id().xy(); @@ -322,17 +307,16 @@ int main(int argc, char *argv[]) { auto clear_shader = device.compile(clear_kernel, o); auto hdr2ldr_shader = device.compile(hdr2ldr_kernel, o); auto accumulate_shader = device.compile(accumulate_kernel, o); - auto raytracing_shader = device.compile(mega_kernel, o); auto make_sampler_shader = device.compile(make_sampler_kernel, o); static constexpr uint2 resolution = make_uint2(1024u); Image framebuffer = device.create_image(PixelStorage::HALF4, resolution); Image accum_image = device.create_image(PixelStorage::FLOAT4, resolution); luisa::vector> host_image(resolution.x * resolution.y); - CommandList cmd_list; + Image seed_image = device.create_image(PixelStorage::INT1, resolution); - cmd_list << clear_shader(accum_image).dispatch(resolution) - << make_sampler_shader(seed_image).dispatch(resolution); + stream << clear_shader(accum_image).dispatch(resolution) + << make_sampler_shader(seed_image).dispatch(resolution); Window window{"path tracing", resolution}; Swapchain swap_chain = device.create_swapchain( @@ -351,13 +335,13 @@ int main(int argc, char *argv[]) { Clock clock; while (!window.should_close()) { - cmd_list << raytracing_shader(framebuffer, seed_image, accel, resolution) - .dispatch(resolution) - << accumulate_shader(accum_image, framebuffer) - .dispatch(resolution); - cmd_list << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution); - stream << cmd_list.commit() - << swap_chain.present(ldr_image) << synchronize(); + stream << scheduler(framebuffer, seed_image, accel, resolution) + .dispatch(resolution) + << accumulate_shader(accum_image, framebuffer) + .dispatch(resolution) + << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) + << swap_chain.present(ldr_image) + << synchronize(); window.poll_events(); double dt = clock.toc() - last_time; last_time = clock.toc(); From 1028b67adeaf4d4486b1b69736e48be33faa58b7 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 19:19:04 +0800 Subject: [PATCH 18/67] fix build --- include/luisa/coro/v2/coro_scheduler.h | 2 +- src/ext/imgui | 2 +- src/ext/magic_enum | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/luisa/coro/v2/coro_scheduler.h b/include/luisa/coro/v2/coro_scheduler.h index a6565a9ec..e6e9ccde5 100644 --- a/include/luisa/coro/v2/coro_scheduler.h +++ b/include/luisa/coro/v2/coro_scheduler.h @@ -40,7 +40,7 @@ class CoroSchedulerInvoke : public concepts::Noncopyable { std::tuple...> _args; private: - friend class Scheduler; + friend Scheduler; CoroSchedulerInvoke(Scheduler *scheduler, compute::detail::prototype_to_shader_invocation_t... args) noexcept : _scheduler{scheduler}, _args{args...} {} diff --git a/src/ext/imgui b/src/ext/imgui index e391fe2e6..8b2c6dd42 160000 --- a/src/ext/imgui +++ b/src/ext/imgui @@ -1 +1 @@ -Subproject commit e391fe2e66eb1c96b1624ae8444dc64c23146ef4 +Subproject commit 8b2c6dd42fb02dc95a581bb8c7db0199064cfced diff --git a/src/ext/magic_enum b/src/ext/magic_enum index 7afc57b19..f34f967c4 160000 --- a/src/ext/magic_enum +++ b/src/ext/magic_enum @@ -1 +1 @@ -Subproject commit 7afc57b194dd08631d5e96e42b217bb52933828f +Subproject commit f34f967c4e70ec60d1c561fbd2fcabaeedce5957 From 7fb826146a6c2b869be170c163e259f8b8b20674 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 20:53:43 +0800 Subject: [PATCH 19/67] prepare --- include/luisa/coro/v2/coro_frame_buffer.h | 300 ++++++++++++++ include/luisa/coro/v2/coro_func.h | 2 + .../coro/v2/schedulers/persistent_threads.h | 8 + .../luisa/coro/v2/schedulers/state_machine.h | 24 +- include/luisa/coro/v2/schedulers/wavefront.h | 8 + include/luisa/luisa-compute.h | 3 + include/luisa/runtime/device.h | 21 +- src/coro/CMakeLists.txt | 1 + src/coro/coro_frame_buffer.cpp | 14 + src/tests/CMakeLists.txt | 1 + src/tests/coro/path_tracing_wavefront_v2.cpp | 388 ++++++++++++++++++ 11 files changed, 763 insertions(+), 7 deletions(-) create mode 100644 include/luisa/coro/v2/coro_frame_buffer.h create mode 100644 include/luisa/coro/v2/schedulers/persistent_threads.h create mode 100644 include/luisa/coro/v2/schedulers/wavefront.h create mode 100644 src/coro/coro_frame_buffer.cpp create mode 100644 src/tests/coro/path_tracing_wavefront_v2.cpp diff --git a/include/luisa/coro/v2/coro_frame_buffer.h b/include/luisa/coro/v2/coro_frame_buffer.h new file mode 100644 index 000000000..d771425eb --- /dev/null +++ b/include/luisa/coro/v2/coro_frame_buffer.h @@ -0,0 +1,300 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include +#include + +namespace luisa::compute { + +namespace detail { + +[[noreturn]] LC_CORO_API void error_coro_frame_buffer_invalid_element_size(size_t stride, size_t expected) noexcept; + +template typename B> +class BufferExprProxy> { + +private: + using CoroFrameBuffer = B; + CoroFrameBuffer _buffer; + +public: + LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(BufferExprProxy) + +public: + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index) const noexcept { + return Expr{_buffer}.read(std::forward(index)); + } + template + requires is_integral_expr_v + void write(I &&index, V &&value) const noexcept { + Expr{_buffer}.write(std::forward(index), std::forward(value)); + } + [[nodiscard]] Expr device_address() const noexcept { + return Expr{_buffer}.device_address(); + } +}; + +}// namespace detail + +template<> +class BufferView { + + friend class lc::validation::Stream; + +private: + luisa::shared_ptr _desc; + void *_native_handle; + uint64_t _handle; + size_t _offset_bytes; + size_t _size; + size_t _total_size; + +private: + friend class Buffer; + friend class SparseBuffer; + + template + friend class BufferView; + +public: + BufferView(luisa::shared_ptr desc, + void *native_handle, uint64_t handle, + size_t offset_bytes, size_t size, size_t total_size) noexcept + : _desc{std::move(desc)}, _native_handle{native_handle}, _handle{handle}, + _offset_bytes{offset_bytes}, _size{size}, _total_size{total_size} { + if (auto a = _desc->type()->alignment(); _offset_bytes % a != 0u) [[unlikely]] { + detail::error_buffer_invalid_alignment(_offset_bytes, a); + } + } + + template typename B> + requires(is_buffer_v>) + BufferView(const B &buffer) noexcept : BufferView{buffer.view()} {} + + BufferView() noexcept : BufferView{nullptr, nullptr, invalid_resource_handle, 0, 0, 0} {} + [[nodiscard]] explicit operator bool() const noexcept { return _handle != invalid_resource_handle; } + + // properties + [[nodiscard]] auto coro_frame() const noexcept { return _desc.get(); } + [[nodiscard]] auto handle() const noexcept { return _handle; } + [[nodiscard]] auto native_handle() const noexcept { return _native_handle; } + [[nodiscard]] auto stride() const noexcept { return _desc->type()->size(); } + [[nodiscard]] auto size() const noexcept { return _size; } + [[nodiscard]] auto offset() const noexcept { return _offset_bytes / stride(); } + [[nodiscard]] auto offset_bytes() const noexcept { return _offset_bytes; } + [[nodiscard]] auto size_bytes() const noexcept { return _size * stride(); } + + [[nodiscard]] auto original() const noexcept { + return BufferView{_desc, _native_handle, _handle, 0u, _total_size, _total_size}; + } + [[nodiscard]] auto subview(size_t offset_elements, size_t size_elements) const noexcept { + if (size_elements + offset_elements > _size) [[unlikely]] { + detail::error_buffer_subview_overflow(offset_elements, size_elements, _size); + } + return BufferView{_desc, _native_handle, _handle, + _offset_bytes + offset_elements * stride(), + size_elements, _total_size}; + } + // reinterpret cast buffer to another type U + template + requires(!is_custom_struct_v) + [[nodiscard]] auto as() const noexcept { + if (this->size_bytes() < sizeof(U)) [[unlikely]] { + detail::error_buffer_reinterpret_size_too_small(sizeof(U), this->size_bytes()); + } + auto total_size_bytes = _total_size * stride(); + return BufferView{_native_handle, _handle, sizeof(U), _offset_bytes, + this->size_bytes() / sizeof(U), total_size_bytes / sizeof(U)}; + } + // commands + // copy buffer's data to pointer + [[nodiscard]] auto copy_to(void *data) const noexcept { + return luisa::make_unique(_handle, offset_bytes(), size_bytes(), data); + } + // copy pointer's data to buffer + [[nodiscard]] auto copy_from(const void *data) noexcept { + return luisa::make_unique(this->handle(), this->offset_bytes(), this->size_bytes(), data); + } + // copy source buffer's data to buffer + [[nodiscard]] auto copy_from(BufferView source) noexcept { + if (source.size() != this->size()) [[unlikely]] { + detail::error_buffer_copy_sizes_mismatch(source.size(), this->size()); + } + return luisa::make_unique( + source.handle(), this->handle(), + source.offset_bytes(), this->offset_bytes(), + this->size_bytes()); + } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast> *>(this); + } +}; + +template<> +class Buffer final : public Resource { + +private: + luisa::shared_ptr _desc; + size_t _size{}; + +private: + friend class Device; + friend class ResourceGenerator; + friend class DxCudaInterop; + friend class PinnedMemoryExt; + + Buffer(DeviceInterface *device, + const BufferCreationInfo &info, + luisa::shared_ptr desc) noexcept + : Resource{device, Tag::BUFFER, info}, + _desc{std::move(desc)}, + _size{info.total_size_bytes / info.element_stride} { + if (info.element_stride != _desc->type()->size()) [[unlikely]] { + detail::error_coro_frame_buffer_invalid_element_size( + info.element_stride, _desc->type()->size()); + } + } + + Buffer(DeviceInterface *device, + const luisa::shared_ptr &desc, + size_t size) noexcept + : Buffer{device, + [&] { + if (size == 0) [[unlikely]] { + detail::error_buffer_size_is_zero(); + } + return device->create_buffer(desc->type(), size, nullptr); + }(), + desc} {} + + Buffer(DeviceInterface *device, + const coroutine::CoroFrameDesc *desc, + size_t size) noexcept + : Buffer{device, desc->shared_from_this(), size} {} + +public: + Buffer() noexcept = default; + ~Buffer() noexcept override { + if (*this) { device()->destroy_buffer(handle()); } + } + Buffer(Buffer &&) noexcept = default; + Buffer(Buffer const &) noexcept = delete; + Buffer &operator=(Buffer &&rhs) noexcept { + _move_from(std::move(rhs)); + return *this; + } + Buffer &operator=(Buffer const &) noexcept = delete; + using Resource::operator bool; + // properties + [[nodiscard]] auto coro_frame() const noexcept { return _desc.get(); } + [[nodiscard]] auto size() const noexcept { + _check_is_valid(); + return _size; + } + [[nodiscard]] auto stride() const noexcept { + _check_is_valid(); + return _desc->type()->size(); + } + [[nodiscard]] auto size_bytes() const noexcept { + _check_is_valid(); + return _size * stride(); + } + [[nodiscard]] auto view() const noexcept { + _check_is_valid(); + return BufferView{_desc, this->native_handle(), this->handle(), 0u, _size, _size}; + } + [[nodiscard]] auto view(size_t offset, size_t count) const noexcept { + return view().subview(offset, count); + } + // commands + // copy buffer's data to pointer + [[nodiscard]] auto copy_to(void *data) const noexcept { + return this->view().copy_to(data); + } + // copy pointer's data to buffer + [[nodiscard]] auto copy_from(const void *data) noexcept { + return this->view().copy_from(data); + } + // copy source buffer's data to buffer + [[nodiscard]] auto copy_from(BufferView source) noexcept { + return this->view().copy_from(source); + } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + _check_is_valid(); + return reinterpret_cast> *>(this); + } +}; + +template<> +struct Expr> { + +private: + luisa::shared_ptr _desc; + const RefExpr *_expression{nullptr}; + +public: + explicit Expr(luisa::shared_ptr desc, + const RefExpr *expr) noexcept + : _desc{std::move(desc)}, _expression{expr} {} + + Expr(BufferView buffer) noexcept + : _desc{buffer.coro_frame()->shared_from_this()}, + _expression{detail::FunctionBuilder::current()->buffer_binding( + Type::buffer(buffer.coro_frame()->type()), buffer.handle(), + buffer.offset_bytes(), buffer.size_bytes())} {} + + /// Contruct from Buffer. Will call buffer_binding() to bind buffer + Expr(const Buffer &buffer) noexcept + : Expr{BufferView{buffer}} {} + + /// Construct from Var>. + Expr(const Var> &buffer) noexcept + : Expr{buffer.coro_frame()->shared_from_this(), buffer.expression()} {} + + /// Construct from Var>. + Expr(const Var> &buffer) noexcept + : Expr{buffer.coro_frame()->shared_from_this(), buffer.expression()} {} + + [[nodiscard]] const coroutine::CoroFrameDesc *coro_frame() const noexcept { return _desc.get(); } + [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } + + /// Read buffer at index + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index) const noexcept { + auto f = detail::FunctionBuilder::current(); + auto expr = f->call(_desc->type(), CallOp::BUFFER_READ, + {_expression, detail::extract_expression(std::forward(index))}); + auto frame = f->local(_desc->type()); + f->assign(frame, expr); + return coroutine::CoroFrame{_desc, frame}; + } + + template + requires is_integral_expr_v + void write(I &&index, const coroutine::CoroFrame &value) const noexcept { + detail::FunctionBuilder::current()->call( + CallOp::BUFFER_WRITE, + {_expression, + detail::extract_expression(std::forward(index)), + value.expression()}); + } + + [[nodiscard]] Expr device_address() const noexcept { + return def(detail::FunctionBuilder::current()->call( + Type::of(), CallOp::BUFFER_ADDRESS, {_expression})); + } + +public: + [[nodiscard]] auto operator->() const noexcept { return this; } +}; + +}// namespace luisa::compute diff --git a/include/luisa/coro/v2/coro_func.h b/include/luisa/coro/v2/coro_func.h index 5534f0fad..c8d969369 100644 --- a/include/luisa/coro/v2/coro_func.h +++ b/include/luisa/coro/v2/coro_func.h @@ -84,6 +84,8 @@ class Coroutine { public: [[nodiscard]] auto graph() const noexcept { return _graph.get(); } [[nodiscard]] auto &shared_graph() const noexcept { return _graph; } + [[nodiscard]] auto frame() const noexcept { return _graph->frame(); } + [[nodiscard]] auto &shared_frame() const noexcept { return _graph->shared_frame(); } public: [[nodiscard]] auto instantiate() const noexcept { return CoroFrame::create(_graph->shared_frame()); } diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/v2/schedulers/persistent_threads.h new file mode 100644 index 000000000..a2856c99c --- /dev/null +++ b/include/luisa/coro/v2/schedulers/persistent_threads.h @@ -0,0 +1,8 @@ +// +// Created by Mike on 2024/5/10. +// + +#ifndef PERSISTENT_THREADS_H +#define PERSISTENT_THREADS_H + +#endif //PERSISTENT_THREADS_H diff --git a/include/luisa/coro/v2/schedulers/state_machine.h b/include/luisa/coro/v2/schedulers/state_machine.h index 16d81e3b6..ed6c5b5c9 100644 --- a/include/luisa/coro/v2/schedulers/state_machine.h +++ b/include/luisa/coro/v2/schedulers/state_machine.h @@ -15,6 +15,10 @@ LC_CORO_API void coro_scheduler_state_machine_impl( luisa::move_only_function node) noexcept; }// namespace detail +struct StateMachineCoroSchedulerConfig { + uint3 block_size = luisa::make_uint3(128, 1, 1); +}; + template class StateMachineCoroScheduler : public CoroScheduler { @@ -22,9 +26,10 @@ class StateMachineCoroScheduler : public CoroScheduler { Shader3D _shader; private: - [[nodiscard]] static auto _create_shader(Device &device, const Coroutine &coro) noexcept { - Kernel3D kernel = [&coro](Var... args) noexcept { - set_block_size(128u, 1u, 1u); + [[nodiscard]] static auto _create_shader(Device &device, const Coroutine &coro, + const StateMachineCoroSchedulerConfig &config) noexcept { + Kernel3D kernel = [&coro, &config](Var... args) noexcept { + set_block_size(config.block_size); auto frame = coro.instantiate(dispatch_id()); detail::coro_scheduler_state_machine_impl( frame, coro.subroutine_count(), @@ -41,11 +46,18 @@ class StateMachineCoroScheduler : public CoroScheduler { } public: - StateMachineCoroScheduler(Device &device, const Coroutine &coro) noexcept - : _shader{_create_shader(device, coro)} {} + StateMachineCoroScheduler(Device &device, const Coroutine &coro, + const StateMachineCoroSchedulerConfig &config = {}) noexcept + : _shader{_create_shader(device, coro, config)} {} }; template -StateMachineCoroScheduler(Device &, const Coroutine &) -> StateMachineCoroScheduler; +StateMachineCoroScheduler(Device &, const Coroutine &) + -> StateMachineCoroScheduler; + +template +StateMachineCoroScheduler(Device &, const Coroutine &, + const StateMachineCoroSchedulerConfig &) + -> StateMachineCoroScheduler; }// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h new file mode 100644 index 000000000..ccc752c27 --- /dev/null +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -0,0 +1,8 @@ +// +// Created by Mike on 2024/5/10. +// + +#ifndef WAVEFRONT_H +#define WAVEFRONT_H + +#endif //WAVEFRONT_H diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index 81016f709..673672185 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -56,12 +56,15 @@ #include #include #include +#include #include #include #include #include #include +#include #include +#include #ifdef LUISA_ENABLE_DSL #include diff --git a/include/luisa/runtime/device.h b/include/luisa/runtime/device.h index 4c966f5d9..72aae35dc 100644 --- a/include/luisa/runtime/device.h +++ b/include/luisa/runtime/device.h @@ -29,6 +29,11 @@ class SparseBufferHeap; class SparseTextureHeap; class ByteBuffer; +namespace coroutine { +class CoroFrameDesc; +class CoroFrame; +}// namespace coroutine + template class SOA; @@ -222,11 +227,25 @@ class LC_RUNTIME_API Device { [[nodiscard]] ByteBuffer import_external_byte_buffer(void *external_memory, size_t byte_size) noexcept; template - requires(!is_custom_struct_v)//backend-specific type not allowed + requires(!is_custom_struct_v) /* backend-specific types not allowed */ && + (!std::same_as, coroutine::CoroFrame>) [[nodiscard]] auto create_buffer(size_t size) noexcept { return _create>(size); } + template + requires std::same_as, luisa::shared_ptr> || + std::same_as, const coroutine::CoroFrameDesc *> + [[nodiscard]] auto create_coro_frame_buffer(Desc &&desc, size_t size) noexcept { + return _create>(std::forward(desc), size); + } + + template + requires std::same_as, coroutine::CoroFrame> + [[nodiscard]] auto create_buffer(Desc &&desc, size_t size) noexcept { + return create_coro_frame_buffer(std::forward(desc), size); + } + template requires(!is_custom_struct_v)//backend-specific type not allowed [[nodiscard]] auto import_external_buffer(void *external_memory, size_t elem_count) noexcept { diff --git a/src/coro/CMakeLists.txt b/src/coro/CMakeLists.txt index b24854a8c..545482e66 100644 --- a/src/coro/CMakeLists.txt +++ b/src/coro/CMakeLists.txt @@ -1,5 +1,6 @@ set(LUISA_COMPUTE_CORO_SOURCES coro_frame.cpp + coro_frame_buffer.cpp coro_frame_desc.cpp coro_func.cpp coro_graph.cpp diff --git a/src/coro/coro_frame_buffer.cpp b/src/coro/coro_frame_buffer.cpp new file mode 100644 index 000000000..83ea5d0f3 --- /dev/null +++ b/src/coro/coro_frame_buffer.cpp @@ -0,0 +1,14 @@ +// +// Created by Mike on 2024/5/10. +// + +#include +#include + +namespace luisa::compute::detail { +void error_coro_frame_buffer_invalid_element_size(size_t stride, size_t expected) noexcept { + LUISA_ERROR( + "Invalid coroutine frame buffer view element size {} (expected {}).", + stride, expected); +} +}// namespace luisa::compute::detail diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index e9b53812f..132e52a9e 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -201,6 +201,7 @@ luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_renderer_wo_dispatcher.cpp) luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) luisa_compute_add_executable(test_coro_path_tracing_v2 coro/path_tracing_v2.cpp) +luisa_compute_add_executable(test_coro_path_tracing_wavefront_v2 coro/path_tracing_wavefront_v2.cpp) luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) luisa_compute_add_executable(test_coro_helloworld_v2 coro/helloworld_v2.cpp) luisa_compute_add_executable(test_coro_playground coro/playground.cpp) diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp new file mode 100644 index 000000000..9badcba39 --- /dev/null +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -0,0 +1,388 @@ +#include + +#include +#include + +#include +#include "../common/cornell_box.h" + +#define TINYOBJLOADER_IMPLEMENTATION +#include "../common/tiny_obj_loader.h" + +// namespace luisa::compute::coroutine { +// +// template +// class WavefrontCoroScheduler final : public CoroScheduler { +// +// private: +// Device &_device; +// Coroutine _coro; +// Buffer _frame_buffer; +// luisa>> _shaders; +// +// private: +// void _prepare(uint n) noexcept { +// _frame_buffer = _device.create_buffer(_coro.frame(), n); +// } +// +// private: +// void _dispatch(Stream &stream, uint3 dispatch_size, +// compute::detail::prototype_to_shader_invocation_t... args) noexcept override { +// auto s = luisa::make_ulong3(dispatch_size); +// auto n = s.x * s.y * s.z; +// LUISA_ASSERT(n < std::numeric_limits::max(), "Dispatch size is too large."); +// LUISA_ASSERT(n > 0u, "Dispatch size must be greater than zero."); +// if (!_frame_buffer || _frame_buffer.size() < n) { _prepare(n); } +// // generate +// stream << _shaders[0](args...).dispatch(dispatch_size); +// // loop over the subroutines until we found that all of them are done +// +// } +// +// public: +// explicit WavefrontCoroScheduler(Device &device, Coroutine coro) noexcept +// : _device{device}, _coro{std::move(coro)} {} +// }; +// +// }// namespace luisa::compute::coroutine + +using namespace luisa; +using namespace luisa::compute; + +struct Onb { + float3 tangent; + float3 binormal; + float3 normal; +}; + +LUISA_STRUCT(Onb, tangent, binormal, normal) { + [[nodiscard]] Float3 to_world(Expr v) const noexcept { + return v.x * tangent + v.y * binormal + v.z * normal; + } +}; + +int main(int argc, char *argv[]) { + + log_level_verbose(); + + Context context{argv[0]}; + if (argc <= 1) { + LUISA_INFO("Usage: {} . : cuda, dx, cpu, metal", argv[0]); + exit(1); + } + Device device = context.create_device(argv[1]); + + // load the Cornell Box scene + tinyobj::ObjReaderConfig obj_reader_config; + obj_reader_config.triangulate = true; + obj_reader_config.vertex_color = false; + tinyobj::ObjReader obj_reader; + if (!obj_reader.ParseFromString(obj_string, "", obj_reader_config)) { + luisa::string_view error_message = "unknown error."; + if (auto &&e = obj_reader.Error(); !e.empty()) { error_message = e; } + LUISA_ERROR_WITH_LOCATION("Failed to load OBJ file: {}", error_message); + } + if (auto &&e = obj_reader.Warning(); !e.empty()) { + LUISA_WARNING_WITH_LOCATION("{}", e); + } + + auto &&p = obj_reader.GetAttrib().vertices; + luisa::vector vertices; + vertices.reserve(p.size() / 3u); + for (uint i = 0u; i < p.size(); i += 3u) { + vertices.emplace_back(make_float3( + p[i + 0u], p[i + 1u], p[i + 2u])); + } + LUISA_INFO( + "Loaded mesh with {} shape(s) and {} vertices.", + obj_reader.GetShapes().size(), vertices.size()); + + BindlessArray heap = device.create_bindless_array(); + Stream stream = device.create_stream(StreamTag::GRAPHICS); + Buffer vertex_buffer = device.create_buffer(vertices.size()); + stream << vertex_buffer.copy_from(vertices.data()); + luisa::vector meshes; + luisa::vector> triangle_buffers; + for (auto &&shape : obj_reader.GetShapes()) { + uint index = static_cast(meshes.size()); + std::vector const &t = shape.mesh.indices; + uint triangle_count = t.size() / 3u; + LUISA_INFO( + "Processing shape '{}' at index {} with {} triangle(s).", + shape.name, index, triangle_count); + luisa::vector indices; + indices.reserve(t.size()); + for (tinyobj::index_t i : t) { indices.emplace_back(i.vertex_index); } + Buffer &triangle_buffer = triangle_buffers.emplace_back(device.create_buffer(triangle_count)); + Mesh &mesh = meshes.emplace_back(device.create_mesh(vertex_buffer, triangle_buffer)); + heap.emplace_on_update(index, triangle_buffer); + stream << triangle_buffer.copy_from(indices.data()) + << mesh.build(); + } + + Accel accel = device.create_accel({}); + for (Mesh &m : meshes) { + accel.emplace_back(m, make_float4x4(1.0f)); + } + stream << heap.update() + << accel.build() + << synchronize(); + + Constant materials{ + make_float3(0.725f, 0.710f, 0.680f),// floor + make_float3(0.725f, 0.710f, 0.680f),// ceiling + make_float3(0.725f, 0.710f, 0.680f),// back wall + make_float3(0.140f, 0.450f, 0.091f),// right wall + make_float3(0.630f, 0.065f, 0.050f),// left wall + make_float3(0.725f, 0.710f, 0.680f),// short box + make_float3(0.725f, 0.710f, 0.680f),// tall box + make_float3(0.000f, 0.000f, 0.000f),// light + }; + + Callable linear_to_srgb = [&](Var x) noexcept { + return saturate(select(1.055f * pow(x, 1.0f / 2.4f) - 0.055f, + 12.92f * x, + x <= 0.00031308f)); + }; + + Callable tea = [](UInt v0, UInt v1) noexcept { + UInt s0 = def(0u); + for (uint n = 0u; n < 4u; n++) { + s0 += 0x9e3779b9u; + v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); + v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); + } + return v0; + }; + + Kernel2D make_sampler_kernel = [&](ImageUInt seed_image) noexcept { + UInt2 p = dispatch_id().xy(); + UInt state = tea(p.x, p.y); + seed_image.write(p, make_uint4(state)); + }; + + Callable lcg = [](UInt &state) noexcept { + constexpr uint lcg_a = 1664525u; + constexpr uint lcg_c = 1013904223u; + state = lcg_a * state + lcg_c; + return cast(state & 0x00ffffffu) * + (1.0f / static_cast(0x01000000u)); + }; + + Callable make_onb = [](const Float3 &normal) noexcept { + Float3 binormal = normalize(ite( + abs(normal.x) > abs(normal.z), + make_float3(-normal.y, normal.x, 0.0f), + make_float3(0.0f, -normal.z, normal.y))); + Float3 tangent = normalize(cross(binormal, normal)); + return def(tangent, binormal, normal); + }; + + Callable generate_ray = [](Float2 p) noexcept { + static constexpr float fov = radians(27.8f); + static constexpr float3 origin = make_float3(-0.01f, 0.995f, 5.0f); + Float3 pixel = origin + make_float3(p * tan(0.5f * fov), -1.0f); + Float3 direction = normalize(pixel - origin); + return make_ray(origin, direction); + }; + + Callable cosine_sample_hemisphere = [](Float2 u) noexcept { + Float r = sqrt(u.x); + Float phi = 2.0f * constants::pi * u.y; + return make_float3(r * cos(phi), r * sin(phi), sqrt(1.0f - u.x)); + }; + + Callable balanced_heuristic = [](Float pdf_a, Float pdf_b) noexcept { + return pdf_a / max(pdf_a + pdf_b, 1e-4f); + }; + + auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; + + coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + UInt2 coord = dispatch_id().xy(); + Float frame_size = min(resolution.x, resolution.y).cast(); + UInt state = seed_image.read(coord).x; + Float rx = lcg(state); + Float ry = lcg(state); + Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; + Float3 radiance = def(make_float3(0.0f)); + $suspend("per_spp"); + $for (i, spp_per_dispatch) { + Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); + Float3 beta = def(make_float3(1.0f)); + Float pdf_bsdf = def(0.0f); + constexpr float3 light_position = make_float3(-0.24f, 1.98f, 0.16f); + constexpr float3 light_u = make_float3(-0.24f, 1.98f, -0.22f) - light_position; + constexpr float3 light_v = make_float3(0.23f, 1.98f, 0.16f) - light_position; + constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); + Float light_area = length(cross(light_u, light_v)); + Float3 light_normal = normalize(cross(light_u, light_v)); + $suspend("per_depth"); + $for (depth, 10u) { + // trace + $suspend("before_tracing"); + Var hit = accel.intersect(ray, {}); + reorder_shader_execution(); + $if (hit->miss()) { $break; }; + Var triangle = heap->buffer(hit.inst).read(hit.prim); + Float3 p0 = vertex_buffer->read(triangle.i0); + Float3 p1 = vertex_buffer->read(triangle.i1); + Float3 p2 = vertex_buffer->read(triangle.i2); + Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); + Float3 n = normalize(cross(p1 - p0, p2 - p0)); + $suspend("after_tracing"); + + Float cos_wo = dot(-ray->direction(), n); + $if (cos_wo < 1e-4f) { $break; }; + + // hit light + $if (hit.inst == static_cast(meshes.size() - 1u)) { + $if (depth == 0u) { + radiance += light_emission; + } + $else { + Float pdf_light = length_squared(p - ray->origin()) / (light_area * cos_wo); + Float mis_weight = balanced_heuristic(pdf_bsdf, pdf_light); + radiance += mis_weight * beta * light_emission; + }; + $break; + }; + + // sample light + $suspend("sample_light"); + Float ux_light = lcg(state); + Float uy_light = lcg(state); + Float3 p_light = light_position + ux_light * light_u + uy_light * light_v; + Float3 pp = offset_ray_origin(p, n); + Float3 pp_light = offset_ray_origin(p_light, light_normal); + Float d_light = distance(pp, pp_light); + Float3 wi_light = normalize(pp_light - pp); + Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); + Bool occluded = accel.intersect_any(shadow_ray, {}); + Float cos_wi_light = dot(wi_light, n); + Float cos_light = -dot(light_normal, wi_light); + Float3 albedo = materials.read(hit.inst); + $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { + Float pdf_light = (d_light * d_light) / (light_area * cos_light); + Float pdf_bsdf = cos_wi_light * inv_pi; + Float mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); + Float3 bsdf = albedo * inv_pi * cos_wi_light; + radiance += beta * bsdf * mis_weight * light_emission / max(pdf_light, 1e-4f); + }; + + // sample BSDF + $suspend("sample_bsdf"); + Var onb = make_onb(n); + Float ux = lcg(state); + Float uy = lcg(state); + Float3 wi_local = cosine_sample_hemisphere(make_float2(ux, uy)); + Float cos_wi = abs(wi_local.z); + Float3 new_direction = onb->to_world(wi_local); + ray = make_ray(pp, new_direction); + pdf_bsdf = cos_wi * inv_pi; + beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f + + // rr + $suspend("rr"); + Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); + $if (l == 0.0f) { $break; }; + Float q = max(l, 0.05f); + Float r = lcg(state); + $if (r >= q) { $break; }; + beta *= 1.0f / q; + }; + }; + $suspend("write_film"); + radiance /= static_cast(spp_per_dispatch); + seed_image.write(coord, make_uint4(state)); + $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; + image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); + }; + + auto coro_buffer = device.create_coro_frame_buffer(coro.frame(), 1024u); + + coroutine::StateMachineCoroScheduler scheduler{device, coro}; + + Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { + UInt2 p = dispatch_id().xy(); + Float4 accum = accum_image.read(p); + Float3 curr = curr_image.read(p).xyz(); + accum_image.write(p, accum + make_float4(curr, 1.f)); + }; + + Callable aces_tonemapping = [](Float3 x) noexcept { + static constexpr float a = 2.51f; + static constexpr float b = 0.03f; + static constexpr float c = 2.43f; + static constexpr float d = 0.59f; + static constexpr float e = 0.14f; + return clamp((x * (a * x + b)) / (x * (c * x + d) + e), 0.0f, 1.0f); + }; + + Kernel2D clear_kernel = [](ImageFloat image) noexcept { + image.write(dispatch_id().xy(), make_float4(0.0f)); + }; + + Kernel2D hdr2ldr_kernel = [&](ImageFloat hdr_image, ImageFloat ldr_image, Float scale, Bool is_hdr) noexcept { + UInt2 coord = dispatch_id().xy(); + Float4 hdr = hdr_image.read(coord); + Float3 ldr = hdr.xyz() / hdr.w * scale; + $if (!is_hdr) { + ldr = linear_to_srgb(ldr); + }; + ldr_image.write(coord, make_float4(ldr, 1.0f)); + }; + + ShaderOption o{.enable_debug_info = false}; + auto clear_shader = device.compile(clear_kernel, o); + auto hdr2ldr_shader = device.compile(hdr2ldr_kernel, o); + auto accumulate_shader = device.compile(accumulate_kernel, o); + auto make_sampler_shader = device.compile(make_sampler_kernel, o); + + static constexpr uint2 resolution = make_uint2(1024u); + Image framebuffer = device.create_image(PixelStorage::HALF4, resolution); + Image accum_image = device.create_image(PixelStorage::FLOAT4, resolution); + luisa::vector> host_image(resolution.x * resolution.y); + + Image seed_image = device.create_image(PixelStorage::INT1, resolution); + stream << clear_shader(accum_image).dispatch(resolution) + << make_sampler_shader(seed_image).dispatch(resolution); + + Window window{"path tracing", resolution}; + Swapchain swap_chain = device.create_swapchain( + stream, + SwapchainOption{ + .display = window.native_display(), + .window = window.native_handle(), + .size = make_uint2(resolution), + .wants_hdr = false, + .wants_vsync = false, + .back_buffer_count = 3, + }); + Image ldr_image = device.create_image(swap_chain.backend_storage(), resolution); + double last_time = 0.0; + uint frame_count = 0u; + Clock clock; + + while (!window.should_close()) { + stream << scheduler(framebuffer, seed_image, accel, resolution) + .dispatch(resolution) + << accumulate_shader(accum_image, framebuffer) + .dispatch(resolution) + << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) + << swap_chain.present(ldr_image) + << synchronize(); + window.poll_events(); + double dt = clock.toc() - last_time; + last_time = clock.toc(); + frame_count += spp_per_dispatch; + LUISA_INFO("spp: {}, time: {} ms, spp/s: {}", + frame_count, dt, spp_per_dispatch / dt * 1000); + } + stream + << ldr_image.copy_to(host_image.data()) + << synchronize(); + + LUISA_INFO("FPS: {}", frame_count / clock.toc() * 1000); + stbi_write_png("test_path_tracing.png", resolution.x, resolution.y, 4, host_image.data(), 0); +} From 8c5b72eec5ecade3bc581cb3f74b9823186f852f Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 21:46:04 +0800 Subject: [PATCH 20/67] prepare for pthreads schedulers --- include/luisa/coro/v2/coro_frame_buffer.h | 14 +-- include/luisa/coro/v2/coro_frame_smem.h | 109 ++++++++++++++++++ .../luisa/coro/v2/schedulers/state_machine.h | 34 ++++-- include/luisa/luisa-compute.h | 1 + src/tests/coro/path_tracing_v2.cpp | 3 +- 5 files changed, 146 insertions(+), 15 deletions(-) create mode 100644 include/luisa/coro/v2/coro_frame_smem.h diff --git a/include/luisa/coro/v2/coro_frame_buffer.h b/include/luisa/coro/v2/coro_frame_buffer.h index d771425eb..a76f51b41 100644 --- a/include/luisa/coro/v2/coro_frame_buffer.h +++ b/include/luisa/coro/v2/coro_frame_buffer.h @@ -81,7 +81,7 @@ class BufferView { [[nodiscard]] explicit operator bool() const noexcept { return _handle != invalid_resource_handle; } // properties - [[nodiscard]] auto coro_frame() const noexcept { return _desc.get(); } + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } [[nodiscard]] auto handle() const noexcept { return _handle; } [[nodiscard]] auto native_handle() const noexcept { return _native_handle; } [[nodiscard]] auto stride() const noexcept { return _desc->type()->size(); } @@ -193,7 +193,7 @@ class Buffer final : public Resource { Buffer &operator=(Buffer const &) noexcept = delete; using Resource::operator bool; // properties - [[nodiscard]] auto coro_frame() const noexcept { return _desc.get(); } + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } [[nodiscard]] auto size() const noexcept { _check_is_valid(); return _size; @@ -246,9 +246,9 @@ struct Expr> { : _desc{std::move(desc)}, _expression{expr} {} Expr(BufferView buffer) noexcept - : _desc{buffer.coro_frame()->shared_from_this()}, + : _desc{buffer.desc()->shared_from_this()}, _expression{detail::FunctionBuilder::current()->buffer_binding( - Type::buffer(buffer.coro_frame()->type()), buffer.handle(), + Type::buffer(buffer.desc()->type()), buffer.handle(), buffer.offset_bytes(), buffer.size_bytes())} {} /// Contruct from Buffer. Will call buffer_binding() to bind buffer @@ -257,13 +257,13 @@ struct Expr> { /// Construct from Var>. Expr(const Var> &buffer) noexcept - : Expr{buffer.coro_frame()->shared_from_this(), buffer.expression()} {} + : Expr{buffer.desc()->shared_from_this(), buffer.expression()} {} /// Construct from Var>. Expr(const Var> &buffer) noexcept - : Expr{buffer.coro_frame()->shared_from_this(), buffer.expression()} {} + : Expr{buffer.desc()->shared_from_this(), buffer.expression()} {} - [[nodiscard]] const coroutine::CoroFrameDesc *coro_frame() const noexcept { return _desc.get(); } + [[nodiscard]] const coroutine::CoroFrameDesc *desc() const noexcept { return _desc.get(); } [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } /// Read buffer at index diff --git a/include/luisa/coro/v2/coro_frame_smem.h b/include/luisa/coro/v2/coro_frame_smem.h new file mode 100644 index 000000000..65569c68e --- /dev/null +++ b/include/luisa/coro/v2/coro_frame_smem.h @@ -0,0 +1,109 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include + +namespace luisa::compute { + +template<> +class Shared { + +private: + luisa::shared_ptr _desc; + luisa::vector _expressions; + size_t _size; + +private: + void _create(bool soa) noexcept { + auto fb = detail::FunctionBuilder::current(); + if (!soa) { + auto s = fb->shared(Type::array(_desc->type(), _size)); + _expressions.emplace_back(s); + } else { + auto fields = _desc->type()->members(); + _expressions.reserve(fields.size()); + for (auto i = 0u; i < fields.size(); i++) { + auto type = i == 0u ? Type::array(Type::of(), 3u) : fields[i]; + auto s = fb->shared(Type::array(type, _size)); + _expressions.emplace_back(s); + } + } + } + +public: + Shared(luisa::shared_ptr desc, + size_t n, bool soa = true) noexcept + : _desc{std::move(desc)}, _size{n} { _create(soa); } + + Shared(Shared &&) noexcept = default; + Shared(const Shared &) noexcept = delete; + Shared &operator=(Shared &&) noexcept = delete; + Shared &operator=(const Shared &) noexcept = delete; + + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto is_soa() const noexcept { return _expressions.size() > 1; } + [[nodiscard]] auto size() const noexcept { return _size; } + + /// Read index + template + [[nodiscard]] auto read(I &&index) const noexcept { + auto i = def(std::forward(index)); + auto fb = detail::FunctionBuilder::current(); + auto frame = fb->local(_desc->type()); + if (!is_soa()) { + auto expr = fb->access(_desc->type(), _expressions[0], i.expression()); + fb->assign(frame, expr); + } else { + auto fields = _desc->type()->members(); + for (auto m = 0u; m < fields.size(); m++) { + auto s = fb->access(fields[m], _expressions[m], i.expression()); + auto f = fb->member(fields[m], frame, m); + if (m == 0u) { + auto t = Type::of(); + std::array elems; + elems[0] = fb->access(t, s, fb->literal(t, 0u)); + elems[1] = fb->access(t, s, fb->literal(t, 1u)); + elems[2] = fb->access(t, s, fb->literal(t, 2u)); + auto v = fb->call(Type::of(), CallOp::MAKE_UINT3, elems); + fb->assign(f, v); + } else { + fb->assign(f, s); + } + } + } + return coroutine::CoroFrame{_desc, frame}; + } + + /// Write index + template + void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { + auto i = def(std::forward(index)); + auto fb = detail::FunctionBuilder::current(); + if (!is_soa()) { + auto ref = fb->access(_desc->type(), _expressions[0], i.expression()); + fb->assign(ref, frame.expression()); + } else { + auto fields = _desc->type()->members(); + for (auto m = 0u; m < fields.size(); m++) { + auto s = fb->access(fields[m], _expressions[m], i.expression()); + auto f = fb->member(fields[m], frame.expression(), m); + if (m == 0u) { + auto t = Type::of(); + for (auto j = 0u; j < 3u; j++) { + auto fj = fb->swizzle(t, f, 1u, j); + auto sj = fb->access(t, s, fb->literal(t, j)); + fb->assign(sj, fj); + } + } else { + fb->assign(s, f); + } + } + } + } +}; + +}// namespace luisa::compute diff --git a/include/luisa/coro/v2/schedulers/state_machine.h b/include/luisa/coro/v2/schedulers/state_machine.h index ed6c5b5c9..ae386437f 100644 --- a/include/luisa/coro/v2/schedulers/state_machine.h +++ b/include/luisa/coro/v2/schedulers/state_machine.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include namespace luisa::compute::coroutine { @@ -17,6 +18,8 @@ LC_CORO_API void coro_scheduler_state_machine_impl( struct StateMachineCoroSchedulerConfig { uint3 block_size = luisa::make_uint3(128, 1, 1); + bool shared_memory = false; + bool shared_memory_soa = true; }; template @@ -30,12 +33,27 @@ class StateMachineCoroScheduler : public CoroScheduler { const StateMachineCoroSchedulerConfig &config) noexcept { Kernel3D kernel = [&coro, &config](Var... args) noexcept { set_block_size(config.block_size); - auto frame = coro.instantiate(dispatch_id()); - detail::coro_scheduler_state_machine_impl( - frame, coro.subroutine_count(), - [&](CoroToken token) noexcept { - coro.subroutine(token)(frame, args...); - }); + if (config.shared_memory) { + auto n = config.block_size.x * config.block_size.y * config.block_size.z; + Shared sm{coro.shared_frame(), n, config.shared_memory_soa}; + auto tid = thread_z() * block_size().x * block_size().y + thread_y() * block_size().x + thread_x(); + auto frame = coro.instantiate(dispatch_id()); + sm.write(tid, frame); + detail::coro_scheduler_state_machine_impl( + frame, coro.subroutine_count(), + [&](CoroToken token) noexcept { + frame = sm.read(tid); + coro.subroutine(token)(frame, args...); + sm.write(tid, frame); + }); + } else { + auto frame = coro.instantiate(dispatch_id()); + detail::coro_scheduler_state_machine_impl( + frame, coro.subroutine_count(), + [&](CoroToken token) noexcept { + coro.subroutine(token)(frame, args...); + }); + } }; return device.compile(kernel); } @@ -47,8 +65,10 @@ class StateMachineCoroScheduler : public CoroScheduler { public: StateMachineCoroScheduler(Device &device, const Coroutine &coro, - const StateMachineCoroSchedulerConfig &config = {}) noexcept + const StateMachineCoroSchedulerConfig &config) noexcept : _shader{_create_shader(device, coro, config)} {} + StateMachineCoroScheduler(Device &device, const Coroutine &coro) noexcept + : StateMachineCoroScheduler{device, coro, StateMachineCoroSchedulerConfig{}} {} }; template diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index 673672185..d1eb7371f 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -58,6 +58,7 @@ #include #include #include +#include #include #include #include diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index 1a233e9f1..4c0af315b 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -271,7 +271,8 @@ int main(int argc, char *argv[]) { raytrace_coro(image, seed_image, accel, resolution, dispatch_id().xy()).await(); }; - coroutine::StateMachineCoroScheduler scheduler{device, raytracing_coro}; + coroutine::StateMachineCoroSchedulerConfig config{.shared_memory = false, .shared_memory_soa = true}; + coroutine::StateMachineCoroScheduler scheduler{device, raytracing_coro, config}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { UInt2 p = dispatch_id().xy(); From a1be882055aea4700e67dfe1d1d9f743fa4958c9 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 22:31:17 +0800 Subject: [PATCH 21/67] fix smem for coro frame --- include/luisa/coro/v2/coro_frame_smem.h | 55 +++++++++++++++---- .../luisa/coro/v2/schedulers/state_machine.h | 17 +++--- src/coro/schedulers/state_machine.cpp | 30 +++++++++- src/tests/coro/path_tracing_v2.cpp | 4 +- 4 files changed, 82 insertions(+), 24 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_smem.h b/include/luisa/coro/v2/coro_frame_smem.h index 65569c68e..0de2d1f92 100644 --- a/include/luisa/coro/v2/coro_frame_smem.h +++ b/include/luisa/coro/v2/coro_frame_smem.h @@ -4,6 +4,8 @@ #pragma once +#include "coro_token.h" + #include #include @@ -18,7 +20,7 @@ class Shared { size_t _size; private: - void _create(bool soa) noexcept { + void _create(bool soa, luisa::span soa_excluded) noexcept { auto fb = detail::FunctionBuilder::current(); if (!soa) { auto s = fb->shared(Type::array(_desc->type(), _size)); @@ -27,17 +29,21 @@ class Shared { auto fields = _desc->type()->members(); _expressions.reserve(fields.size()); for (auto i = 0u; i < fields.size(); i++) { - auto type = i == 0u ? Type::array(Type::of(), 3u) : fields[i]; - auto s = fb->shared(Type::array(type, _size)); - _expressions.emplace_back(s); + if (std::find(soa_excluded.begin(), soa_excluded.end(), i) != soa_excluded.end()) { + _expressions.emplace_back(nullptr); + } else { + auto type = i == 0u ? Type::array(Type::of(), 3u) : fields[i]; + auto s = fb->shared(Type::array(type, _size)); + _expressions.emplace_back(s); + } } } } public: - Shared(luisa::shared_ptr desc, - size_t n, bool soa = true) noexcept - : _desc{std::move(desc)}, _size{n} { _create(soa); } + Shared(luisa::shared_ptr desc, size_t n, + bool soa = true, luisa::span soa_excluded_fields = {}) noexcept + : _desc{std::move(desc)}, _size{n} { _create(soa, soa_excluded_fields); } Shared(Shared &&) noexcept = default; Shared(const Shared &) noexcept = delete; @@ -48,9 +54,10 @@ class Shared { [[nodiscard]] auto is_soa() const noexcept { return _expressions.size() > 1; } [[nodiscard]] auto size() const noexcept { return _size; } - /// Read index + +private: template - [[nodiscard]] auto read(I &&index) const noexcept { + [[nodiscard]] auto _read(I &&index, luisa::optional> active_fields) const noexcept { auto i = def(std::forward(index)); auto fb = detail::FunctionBuilder::current(); auto frame = fb->local(_desc->type()); @@ -60,10 +67,12 @@ class Shared { } else { auto fields = _desc->type()->members(); for (auto m = 0u; m < fields.size(); m++) { - auto s = fb->access(fields[m], _expressions[m], i.expression()); + if (_expressions[m] == nullptr) { continue; } + if (active_fields && std::find(active_fields->begin(), active_fields->end(), m) == active_fields->end()) { continue; } auto f = fb->member(fields[m], frame, m); if (m == 0u) { auto t = Type::of(); + auto s = fb->access(Type::array(t, 3u), _expressions[m], i.expression()); std::array elems; elems[0] = fb->access(t, s, fb->literal(t, 0u)); elems[1] = fb->access(t, s, fb->literal(t, 1u)); @@ -71,6 +80,7 @@ class Shared { auto v = fb->call(Type::of(), CallOp::MAKE_UINT3, elems); fb->assign(f, v); } else { + auto s = fb->access(fields[m], _expressions[m], i.expression()); fb->assign(f, s); } } @@ -80,7 +90,7 @@ class Shared { /// Write index template - void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { + void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields) const noexcept { auto i = def(std::forward(index)); auto fb = detail::FunctionBuilder::current(); if (!is_soa()) { @@ -89,21 +99,42 @@ class Shared { } else { auto fields = _desc->type()->members(); for (auto m = 0u; m < fields.size(); m++) { - auto s = fb->access(fields[m], _expressions[m], i.expression()); + if (_expressions[m] == nullptr) { continue; } + if (active_fields && std::find(active_fields->begin(), active_fields->end(), m) == active_fields->end()) { continue; } auto f = fb->member(fields[m], frame.expression(), m); if (m == 0u) { auto t = Type::of(); + auto s = fb->access(Type::array(t, 3u), _expressions[m], i.expression()); for (auto j = 0u; j < 3u; j++) { auto fj = fb->swizzle(t, f, 1u, j); auto sj = fb->access(t, s, fb->literal(t, j)); fb->assign(sj, fj); } } else { + auto s = fb->access(fields[m], _expressions[m], i.expression()); fb->assign(s, f); } } } } + +public: + template + [[nodiscard]] auto read(I &&index) const noexcept { + return _read(std::forward(index), luisa::nullopt); + } + template + [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { + return _read(std::forward(index), luisa::make_optional(active_fields)); + } + template + void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { + _write(std::forward(index), frame, luisa::nullopt); + } + template + void write(I &&index, const coroutine::CoroFrame &frame, luisa::span active_fields) const noexcept { + _write(std::forward(index), frame, luisa::make_optional(active_fields)); + } }; }// namespace luisa::compute diff --git a/include/luisa/coro/v2/schedulers/state_machine.h b/include/luisa/coro/v2/schedulers/state_machine.h index ae386437f..56324b8f1 100644 --- a/include/luisa/coro/v2/schedulers/state_machine.h +++ b/include/luisa/coro/v2/schedulers/state_machine.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include @@ -14,6 +15,9 @@ namespace detail { LC_CORO_API void coro_scheduler_state_machine_impl( CoroFrame &frame, uint state_count, luisa::move_only_function node) noexcept; +LC_CORO_API void coro_scheduler_state_machine_smem_impl( + Shared &smem, const CoroGraph *graph, + luisa::move_only_function node) noexcept; }// namespace detail struct StateMachineCoroSchedulerConfig { @@ -35,16 +39,11 @@ class StateMachineCoroScheduler : public CoroScheduler { set_block_size(config.block_size); if (config.shared_memory) { auto n = config.block_size.x * config.block_size.y * config.block_size.z; - Shared sm{coro.shared_frame(), n, config.shared_memory_soa}; - auto tid = thread_z() * block_size().x * block_size().y + thread_y() * block_size().x + thread_x(); - auto frame = coro.instantiate(dispatch_id()); - sm.write(tid, frame); - detail::coro_scheduler_state_machine_impl( - frame, coro.subroutine_count(), - [&](CoroToken token) noexcept { - frame = sm.read(tid); + Shared sm{coro.shared_frame(), n, config.shared_memory_soa, std::array{0u, 1u}}; + detail::coro_scheduler_state_machine_smem_impl( + sm, coro.graph(), + [&](CoroToken token, CoroFrame &frame) noexcept { coro.subroutine(token)(frame, args...); - sm.write(tid, frame); }); } else { auto frame = coro.instantiate(dispatch_id()); diff --git a/src/coro/schedulers/state_machine.cpp b/src/coro/schedulers/state_machine.cpp index f46ef5ce0..2c639f807 100644 --- a/src/coro/schedulers/state_machine.cpp +++ b/src/coro/schedulers/state_machine.cpp @@ -6,8 +6,9 @@ #include namespace luisa::compute::coroutine::detail { -inline void coro_scheduler_state_machine_impl(CoroFrame &frame, uint state_count, - luisa::move_only_function node) noexcept { + +void coro_scheduler_state_machine_impl(CoroFrame &frame, uint state_count, + luisa::move_only_function node) noexcept { node(coro_token_entry); $loop { $switch (frame.target_token) { @@ -18,4 +19,29 @@ inline void coro_scheduler_state_machine_impl(CoroFrame &frame, uint state_count }; }; } + +void coro_scheduler_state_machine_smem_impl(Shared &smem, const CoroGraph *graph, + luisa::move_only_function node) noexcept { + auto target_token = def(0u); + auto frame = CoroFrame::create(graph->shared_frame(), dispatch_id()); + node(coro_token_entry, frame); + target_token = frame.target_token; + auto tid = thread_z() * block_size().x * block_size().y + thread_y() * block_size().x + thread_x(); + smem.write(tid, frame, graph->entry().output_fields()); + $loop { + $switch (target_token) { + for (auto i = 1u; i < graph->nodes().size(); i++) { + $case (i) { + auto f = smem.read(tid, graph->node(i).input_fields()); + f.coro_id = dispatch_id(); + node(i, f); + target_token = f.target_token; + smem.write(tid, f, graph->node(i).output_fields()); + }; + } + $default { $return(); }; + }; + }; +} + }// namespace luisa::compute::coroutine::detail diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index 4c0af315b..a27867a7e 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -271,7 +271,9 @@ int main(int argc, char *argv[]) { raytrace_coro(image, seed_image, accel, resolution, dispatch_id().xy()).await(); }; - coroutine::StateMachineCoroSchedulerConfig config{.shared_memory = false, .shared_memory_soa = true}; + coroutine::StateMachineCoroSchedulerConfig config{.block_size = make_uint3(64u, 1u, 1u), + .shared_memory = true, + .shared_memory_soa = true}; coroutine::StateMachineCoroScheduler scheduler{device, raytracing_coro, config}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { From 3fbc580fa8c7226116a59bf80d61412f7955fefb Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 May 2024 22:36:40 +0800 Subject: [PATCH 22/67] minor --- src/tests/coro/path_tracing_v2.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index a27867a7e..8d09cb1fe 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -271,8 +271,8 @@ int main(int argc, char *argv[]) { raytrace_coro(image, seed_image, accel, resolution, dispatch_id().xy()).await(); }; - coroutine::StateMachineCoroSchedulerConfig config{.block_size = make_uint3(64u, 1u, 1u), - .shared_memory = true, + coroutine::StateMachineCoroSchedulerConfig config{.block_size = make_uint3(8u, 8u, 1u), + .shared_memory = false, .shared_memory_soa = true}; coroutine::StateMachineCoroScheduler scheduler{device, raytracing_coro, config}; From e0069ac3b4d05cf0274f3dc3eab5460db8965f92 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Sat, 11 May 2024 02:20:25 +0800 Subject: [PATCH 23/67] add `$await` --- include/luisa/coro/v2/coro_func.h | 11 ++++++++++- include/luisa/dsl/sugar.h | 2 ++ src/tests/coro/helloworld_v2.cpp | 14 +++++++++----- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/include/luisa/coro/v2/coro_func.h b/include/luisa/coro/v2/coro_func.h index c8d969369..776484db8 100644 --- a/include/luisa/coro/v2/coro_func.h +++ b/include/luisa/coro/v2/coro_func.h @@ -128,6 +128,15 @@ class Coroutine { } }; +namespace detail { +struct CoroAwaitInvoker { + template + void operator%(A &&awaiter) && noexcept { + std::forward(awaiter).await(); + } +}; +}// namespace detail + template Coroutine(T &&) -> Coroutine>>; @@ -227,4 +236,4 @@ class Generator { } }; -}// namespace luisa::compute::coro_v2 +}// namespace luisa::compute::coroutine diff --git a/include/luisa/dsl/sugar.h b/include/luisa/dsl/sugar.h index 18676c53d..fb36b3a9c 100644 --- a/include/luisa/dsl/sugar.h +++ b/include/luisa/dsl/sugar.h @@ -128,6 +128,8 @@ namespace luisa::compute::dsl_detail { ::luisa::compute::dsl::suspend(); \ } while (false) +#define $await ::luisa::compute::coroutine::detail::CoroAwaitInvoker{} % + #define $loop \ ::luisa::compute::detail::LoopStmtBuilder::create_with_comment( \ ::luisa::compute::dsl_detail::format_source_location(__FILE__, __LINE__)) % \ diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp index 8d9f662ea..586eb10f7 100644 --- a/src/tests/coro/helloworld_v2.cpp +++ b/src/tests/coro/helloworld_v2.cpp @@ -30,13 +30,17 @@ int main(int argc, char *argv[]) { } }; - coroutine::Coroutine coro = [] { - $for (i, 10u) { - device_log("i = {}", i); + coroutine::Coroutine coro = [](UInt n) { + $for (i, n) { + device_log("{} / {}", i, n); $suspend(); }; }; - coroutine::StateMachineCoroScheduler sched{device, coro}; - stream << sched().dispatch(1u, 1u, 1u) << synchronize(); + coroutine::Coroutine awaiter = [&coro] { + $await coro(dispatch_x()); + }; + + coroutine::StateMachineCoroScheduler sched{device, awaiter}; + stream << sched().dispatch(10u) << synchronize(); } From 7755aa6773c6c91774d2954bc9352a43df66d8e9 Mon Sep 17 00:00:00 2001 From: chenxin Date: Sat, 11 May 2024 20:50:18 +0800 Subject: [PATCH 24/67] fix coro_frame is_terminated --- include/luisa/coro/v2/schedulers/wavefront.h | 13 ++++++++++--- src/coro/coro_frame.cpp | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index ccc752c27..2da44e090 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -2,7 +2,14 @@ // Created by Mike on 2024/5/10. // -#ifndef WAVEFRONT_H -#define WAVEFRONT_H +#pragma once -#endif //WAVEFRONT_H +#include +#include +#include + +namespace luisa::compute::coroutine { + + + +} // namespace luisa::compute::coroutine \ No newline at end of file diff --git a/src/coro/coro_frame.cpp b/src/coro/coro_frame.cpp index 704fed6d0..20b61f485 100644 --- a/src/coro/coro_frame.cpp +++ b/src/coro/coro_frame.cpp @@ -64,7 +64,7 @@ void CoroFrame::_check_member_index(uint index) const noexcept { } Var CoroFrame::is_terminated() const noexcept { - return target_token == coro_token_terminal; + return target_token & coro_token_terminal; } }// namespace luisa::compute::coro_v2 From 40da4d2cafbc4d559b6e1543db47c02660b36c66 Mon Sep 17 00:00:00 2001 From: chenxin Date: Sun, 12 May 2024 00:57:23 +0800 Subject: [PATCH 25/67] minor changes; add wavefrontcoroscheduler (undone) --- .../luisa/coro/v2/schedulers/state_machine.h | 9 ++-- include/luisa/coro/v2/schedulers/wavefront.h | 45 ++++++++++++++++++- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/state_machine.h b/include/luisa/coro/v2/schedulers/state_machine.h index 56324b8f1..1e5cecc56 100644 --- a/include/luisa/coro/v2/schedulers/state_machine.h +++ b/include/luisa/coro/v2/schedulers/state_machine.h @@ -33,7 +33,7 @@ class StateMachineCoroScheduler : public CoroScheduler { Shader3D _shader; private: - [[nodiscard]] static auto _create_shader(Device &device, const Coroutine &coro, + void _create_shader(Device &device, const Coroutine &coro, const StateMachineCoroSchedulerConfig &config) noexcept { Kernel3D kernel = [&coro, &config](Var... args) noexcept { set_block_size(config.block_size); @@ -54,7 +54,7 @@ class StateMachineCoroScheduler : public CoroScheduler { }); } }; - return device.compile(kernel); + _shader = device.compile(kernel); } void _dispatch(Stream &stream, uint3 dispatch_size, @@ -64,8 +64,9 @@ class StateMachineCoroScheduler : public CoroScheduler { public: StateMachineCoroScheduler(Device &device, const Coroutine &coro, - const StateMachineCoroSchedulerConfig &config) noexcept - : _shader{_create_shader(device, coro, config)} {} + const StateMachineCoroSchedulerConfig &config) noexcept { + _create_shader(device, coro, config); + } StateMachineCoroScheduler(Device &device, const Coroutine &coro) noexcept : StateMachineCoroScheduler{device, coro, StateMachineCoroSchedulerConfig{}} {} }; diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index 2da44e090..0f826ee87 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -10,6 +10,49 @@ namespace luisa::compute::coroutine { +struct WavefrontCoroSchedulerConfig { + uint3 block_size = luisa::make_uint3(128, 1, 1); + uint max_instance_count = 2_M; + bool soa = true; + bool sort = true;// use sort for coro token gathering + bool compact = true; + bool debug = false; + uint hint_range = 0xffff'ffff; + luisa::vector hint_fields; +}; +template +class WavefrontCoroScheduler : public CoroScheduler { -} // namespace luisa::compute::coroutine \ No newline at end of file +private: + Shader3D _shader; + +private: + void _create_shader(Device &device, const Coroutine &coro, + const WavefrontCoroSchedulerConfig &config) noexcept { + } + + void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept override { + LUISA_ERROR_WITH_LOCATION("Unimplemented"); + } + +public: + WavefrontCoroScheduler(Device &device, const Coroutine &coro, + const WavefrontCoroSchedulerConfig &config) noexcept { + _create_shader(device, coro, config); + } + WavefrontCoroScheduler(Device &device, const Coroutine &coro) noexcept + : WavefrontCoroScheduler{device, coro, WavefrontCoroSchedulerConfig{}} {} +}; + +template +WavefrontCoroScheduler(Device &, const Coroutine &) + -> WavefrontCoroScheduler; + +template +WavefrontCoroScheduler(Device &, const Coroutine &, + const WavefrontCoroSchedulerConfig &) + -> WavefrontCoroScheduler; + +}// namespace luisa::compute::coroutine \ No newline at end of file From a96d5ce5fe563072b82e9381b92d9150a4c6efb3 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Sun, 12 May 2024 12:50:15 +0800 Subject: [PATCH 26/67] wip --- src/coro/CMakeLists.txt | 4 +++- src/coro/schedulers/persistent_threads.cpp | 3 +++ src/coro/schedulers/wavefront.cpp | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 src/coro/schedulers/persistent_threads.cpp create mode 100644 src/coro/schedulers/wavefront.cpp diff --git a/src/coro/CMakeLists.txt b/src/coro/CMakeLists.txt index 545482e66..86f1c14e6 100644 --- a/src/coro/CMakeLists.txt +++ b/src/coro/CMakeLists.txt @@ -4,7 +4,9 @@ set(LUISA_COMPUTE_CORO_SOURCES coro_frame_desc.cpp coro_func.cpp coro_graph.cpp - schedulers/state_machine.cpp) + schedulers/persistent_threads.cpp + schedulers/state_machine.cpp + schedulers/wavefront.cpp) add_library(luisa-compute-coro SHARED ${LUISA_COMPUTE_CORO_SOURCES}) target_link_libraries(luisa-compute-coro PUBLIC diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp new file mode 100644 index 000000000..cc81a9278 --- /dev/null +++ b/src/coro/schedulers/persistent_threads.cpp @@ -0,0 +1,3 @@ +// +// Created by Mike on 2024/5/10. +// diff --git a/src/coro/schedulers/wavefront.cpp b/src/coro/schedulers/wavefront.cpp new file mode 100644 index 000000000..cc81a9278 --- /dev/null +++ b/src/coro/schedulers/wavefront.cpp @@ -0,0 +1,3 @@ +// +// Created by Mike on 2024/5/10. +// From 0b32699b114369e6385e906028271e26bd3cc839 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Sun, 12 May 2024 12:56:48 +0800 Subject: [PATCH 27/67] minor type fix --- src/coro/coro_frame.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/coro/coro_frame.cpp b/src/coro/coro_frame.cpp index 20b61f485..479ea0dbc 100644 --- a/src/coro/coro_frame.cpp +++ b/src/coro/coro_frame.cpp @@ -64,7 +64,7 @@ void CoroFrame::_check_member_index(uint index) const noexcept { } Var CoroFrame::is_terminated() const noexcept { - return target_token & coro_token_terminal; + return (target_token & coro_token_terminal) != 0u; } }// namespace luisa::compute::coro_v2 From e87a1acbcb347d15286d09081e78b20168ae0884 Mon Sep 17 00:00:00 2001 From: chenxin Date: Sun, 12 May 2024 18:58:01 +0800 Subject: [PATCH 28/67] add SOA --- include/luisa/coro/v2/coro_frame_smem.h | 3 +- include/luisa/coro/v2/coro_frame_soa.h | 103 +++++++++++++++++++ include/luisa/coro/v2/schedulers/wavefront.h | 26 ++++- include/luisa/runtime/byte_buffer.h | 7 ++ include/luisa/runtime/device.h | 6 ++ 5 files changed, 143 insertions(+), 2 deletions(-) create mode 100644 include/luisa/coro/v2/coro_frame_soa.h diff --git a/include/luisa/coro/v2/coro_frame_smem.h b/include/luisa/coro/v2/coro_frame_smem.h index 0de2d1f92..b7e70889e 100644 --- a/include/luisa/coro/v2/coro_frame_smem.h +++ b/include/luisa/coro/v2/coro_frame_smem.h @@ -56,6 +56,7 @@ class Shared { private: + /// Read index with active fields template [[nodiscard]] auto _read(I &&index, luisa::optional> active_fields) const noexcept { auto i = def(std::forward(index)); @@ -88,7 +89,7 @@ class Shared { return coroutine::CoroFrame{_desc, frame}; } - /// Write index + /// Write index with active fields template void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields) const noexcept { auto i = def(std::forward(index)); diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h new file mode 100644 index 000000000..98ad3c32d --- /dev/null +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -0,0 +1,103 @@ +// +// Created by ChenXin on 2024/5/12. +// + +#pragma once + +#include "luisa/runtime/device.h" +#include "spdlog/fmt/bundled/compile.h" + +#include +#include +#include + +namespace luisa::compute { + +template<> +class SOA { + +private: + luisa::shared_ptr _desc; + size_t _count{0u}; + ByteBuffer _buffer; + luisa::vector _field_offsets; + +public: + SOA(Device &device, luisa::shared_ptr desc, size_t n) noexcept + : _desc{std::move(desc)}, _count{n} { + auto size_bytes = 0u; + const auto type = _desc->type(); + auto fields = type->members(); + for (auto field : fields) { + auto aligned_offset = (size_bytes + field->alignment() - 1u) & ~(field->alignment() - 1u); + _field_offsets.emplace_back(aligned_offset); + size_bytes = aligned_offset; + size_bytes += field->size() * _count; + } + _buffer = device.create_byte_buffer(size_bytes); + } + SOA() noexcept = default; + SOA(const SOA &) = delete; + SOA(SOA &&) noexcept = default; + SOA &operator=(const SOA &) = delete; + SOA &operator=(SOA &&) noexcept = default; + ~SOA() noexcept = default; + +private: + /// Read index with active fields + template + [[nodiscard]] auto _read(I &&index, luisa::optional> active_fields) const noexcept { + auto fb = detail::FunctionBuilder::current(); + auto frame = fb->local(_desc->type()); + auto fields = _desc->type()->members(); + for (auto i = 0u; i < fields.size(); i++) { + if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } + auto field = fields[i]; + auto offset = _field_offsets[i]; + auto field_buffer = fb->buffer_binding(field, _buffer.handle(), offset, field->size() * _count); + auto f = fb->member(field, frame, i); + auto s = fb->call( + field, CallOp::BUFFER_READ, + {field_buffer, detail::extract_expression(std::forward(index))}); + fb->assign(f, s); + } + return coroutine::CoroFrame{_desc, frame}; + } + + /// Write index with active fields + template + void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields) const noexcept { + auto fb = detail::FunctionBuilder::current(); + auto fields = _desc->type()->members(); + for (auto i = 0u; i < fields.size(); i++) { + if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } + auto field = fields[i]; + auto offset = _field_offsets[i]; + auto field_buffer = fb->buffer_binding(field, _buffer.handle(), offset, field->size() * _count); + auto f = fb->member(field, frame.expression(), i); + auto s = fb->call( + field, CallOp::BUFFER_WRITE, + {field_buffer, detail::extract_expression(std::forward(index)), f}); + } + } + +public: + template + [[nodiscard]] auto read(I &&index) const noexcept { + return _read(std::forward(index), luisa::nullopt); + } + template + [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { + return _read(std::forward(index), luisa::make_optional(active_fields)); + } + template + void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { + _write(std::forward(index), frame, luisa::nullopt); + } + template + void write(I &&index, const coroutine::CoroFrame &frame, luisa::span active_fields) const noexcept { + _write(std::forward(index), frame, luisa::make_optional(active_fields)); + } +}; + +}// namespace luisa::compute \ No newline at end of file diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index 0f826ee87..f8e21630a 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -25,7 +25,31 @@ template class WavefrontCoroScheduler : public CoroScheduler { private: - Shader3D _shader; + // Shader1D, Buffer, uint, Container, uint, uint, Args...> _gen_shader; + // luisa::vector, Buffer, Container, uint, Args...>> _resume_shaders; + // Shader1D, Buffer, uint> _count_prefix_shader; + // Shader1D, Buffer, Container, uint> _gather_shader; + // Shader1D, Container, uint> _initialize_shader; + // Shader1D, Container, uint, uint> _compact_shader; + // Shader1D, uint> _clear_shader; + // compute::Buffer _resume_index; + // compute::Buffer _resume_count; + // ///offset calculate from count, will be end after gathering + // compute::Buffer _resume_offset; + // compute::Buffer _global_buffer; + // compute::Buffer _debug_buffer; + // luisa::vector _host_count; + // luisa::vector _host_offset; + // bool _host_empty; + // uint _dispatch_counter; + // uint _max_sub_coro; + // uint _max_frame_count; + // radix_sort::temp_storage _sort_temp_storage; + // radix_sort::instance<> _sort_token; + // radix_sort::instance> _sort_hint; + // luisa::vector _have_hint; + // compute::Buffer _temp_key[2]; + // compute::Buffer _temp_index; private: void _create_shader(Device &device, const Coroutine &coro, diff --git a/include/luisa/runtime/byte_buffer.h b/include/luisa/runtime/byte_buffer.h index 146f70009..c0a3b1c1f 100644 --- a/include/luisa/runtime/byte_buffer.h +++ b/include/luisa/runtime/byte_buffer.h @@ -30,6 +30,13 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { return *this; } ByteBuffer &operator=(ByteBuffer const &) noexcept = delete; + [[nodiscard]] auto view() const noexcept { + _check_is_valid(); + return BufferView{this->native_handle(), this->handle(), 1u, 0u, _size_bytes, _size_bytes}; + } + [[nodiscard]] auto view(size_t offset, size_t count) const noexcept { + return view().subview(offset, count); + } using Resource::operator bool; [[nodiscard]] auto copy_to(void *data) const noexcept { _check_is_valid(); diff --git a/include/luisa/runtime/device.h b/include/luisa/runtime/device.h index 72aae35dc..298e86bbd 100644 --- a/include/luisa/runtime/device.h +++ b/include/luisa/runtime/device.h @@ -37,6 +37,8 @@ class CoroFrame; template class SOA; +class SOA; + template class Buffer; @@ -258,6 +260,10 @@ class LC_RUNTIME_API Device { return SOA{*this, size}; } + [[nodiscard]] auto create_soa(luisa::shared_ptr desc, size_t size) noexcept { + return SOA{*this, desc, size}; + } + template requires(!is_custom_struct_v)//backend-specific type not allowed [[nodiscard]] auto create_sparse_buffer(size_t size) noexcept { From 9c296ff759b3988f0c397edddef8afcc0a0889fd Mon Sep 17 00:00:00 2001 From: chenxin Date: Sun, 12 May 2024 23:46:47 +0800 Subject: [PATCH 29/67] checkpoint --- include/luisa/coro/v2/coro_frame_smem.h | 25 +-- include/luisa/coro/v2/coro_frame_soa.h | 196 ++++++++++++++++--- include/luisa/coro/v2/schedulers/wavefront.h | 60 +++--- include/luisa/runtime/byte_buffer.h | 12 +- include/luisa/runtime/device.h | 11 +- 5 files changed, 213 insertions(+), 91 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_smem.h b/include/luisa/coro/v2/coro_frame_smem.h index b7e70889e..08a9b849b 100644 --- a/include/luisa/coro/v2/coro_frame_smem.h +++ b/include/luisa/coro/v2/coro_frame_smem.h @@ -54,11 +54,10 @@ class Shared { [[nodiscard]] auto is_soa() const noexcept { return _expressions.size() > 1; } [[nodiscard]] auto size() const noexcept { return _size; } - -private: +public: /// Read index with active fields template - [[nodiscard]] auto _read(I &&index, luisa::optional> active_fields) const noexcept { + [[nodiscard]] auto read(I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { auto i = def(std::forward(index)); auto fb = detail::FunctionBuilder::current(); auto frame = fb->local(_desc->type()); @@ -91,7 +90,7 @@ class Shared { /// Write index with active fields template - void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields) const noexcept { + void write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields = luisa::nullopt) const noexcept { auto i = def(std::forward(index)); auto fb = detail::FunctionBuilder::current(); if (!is_soa()) { @@ -118,24 +117,6 @@ class Shared { } } } - -public: - template - [[nodiscard]] auto read(I &&index) const noexcept { - return _read(std::forward(index), luisa::nullopt); - } - template - [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { - return _read(std::forward(index), luisa::make_optional(active_fields)); - } - template - void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { - _write(std::forward(index), frame, luisa::nullopt); - } - template - void write(I &&index, const coroutine::CoroFrame &frame, luisa::span active_fields) const noexcept { - _write(std::forward(index), frame, luisa::make_optional(active_fields)); - } }; }// namespace luisa::compute diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index 98ad3c32d..55e61c4df 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -11,42 +11,175 @@ #include #include +#include + namespace luisa::compute { +namespace detail { + +template +class CoroFrameSOAExprProxy { + +private: + T _soa; + using CoroFrameSOA = T; + +public: + LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(CoroFrameSOAExprProxy) + +public: + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { + return Expr{_soa}.read(std::forward(index), std::move(active_fields)); + } + template + requires is_integral_expr_v + void write(I &&index, V &&value, luisa::optional> active_fields = luisa::nullopt) const noexcept { + Expr{_soa}.write(std::forward(index), + std::forward(value), + std::move(active_fields)); + } + [[nodiscard]] Expr device_address() const noexcept { + return Expr{_soa}.device_address(); + } +}; + +}// namespace detail + +template<> +class SOAView { + +private: + luisa::shared_ptr _desc; + ByteBuffer *_bufferview{nullptr}; + +private: + friend class SOA; + +public: + SOAView() noexcept = default; + SOAView(luisa::shared_ptr desc, ByteBuffer *_buffer) noexcept + : _desc{std::move(desc)}, + _bufferview{_buffer} {} + ~SOAView() noexcept = default; + + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto handle() const noexcept { return _bufferview->handle(); } + [[nodiscard]] auto size_bytes() const noexcept { return _bufferview->size_bytes(); } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast> *>(this); + } +}; + template<> class SOA { private: luisa::shared_ptr _desc; - size_t _count{0u}; - ByteBuffer _buffer; + size_t _size{0u}; luisa::vector _field_offsets; + size_t _buffer_element_count{0u}; + ByteBuffer _buffer; -public: - SOA(Device &device, luisa::shared_ptr desc, size_t n) noexcept - : _desc{std::move(desc)}, _count{n} { +private: + [[nodiscard]] auto _init(DeviceInterface *device) noexcept { auto size_bytes = 0u; + size_bytes = 0u; const auto type = _desc->type(); auto fields = type->members(); - for (auto field : fields) { - auto aligned_offset = (size_bytes + field->alignment() - 1u) & ~(field->alignment() - 1u); - _field_offsets.emplace_back(aligned_offset); - size_bytes = aligned_offset; - size_bytes += field->size() * _count; + _field_offsets.reserve(fields.size()); + for (const auto field : fields) { + size_bytes = (size_bytes + field->alignment() - 1u) & ~(field->alignment() - 1u); + _field_offsets.emplace_back(size_bytes); + size_bytes += field->size() * _size; } - _buffer = device.create_byte_buffer(size_bytes); + _buffer_element_count = (size_bytes + element_type()->size() - 1u) / element_type()->size(); + return device->create_buffer( + element_type(), + _buffer_element_count, + nullptr); + } + +public: + [[nodiscard]] static const Type *element_type() noexcept { + return Type::of(); + } + +public: + SOA(DeviceInterface *device, luisa::shared_ptr desc, size_t n, bool soa) noexcept + : _size{n}, _desc{std::move(desc)} { + auto info = _init(device); + _buffer = std::move(ByteBuffer{device, info}); + LUISA_ASSERT(_buffer.size_bytes() == _buffer_element_count * element_type()->size(), + "CoroFrame SOA Buffer size mismatch."); } SOA() noexcept = default; SOA(const SOA &) = delete; SOA(SOA &&) noexcept = default; SOA &operator=(const SOA &) = delete; - SOA &operator=(SOA &&) noexcept = default; + SOA &operator=(SOA &&x) noexcept { + _desc = std::move(x._desc); + _size = x._size; + _field_offsets = std::move(x._field_offsets); + _buffer_element_count = x._buffer_element_count; + _buffer = std::move(x._buffer); + return *this; + } ~SOA() noexcept = default; + // properties + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto size() const noexcept { return _size; } + [[nodiscard]] auto handle() const noexcept { return _buffer.handle(); } + [[nodiscard]] auto size_bytes() const noexcept { return _buffer.size_bytes(); } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast> *>(this); + } +}; + +/// Class of Expr> +template<> +struct Expr> { + private: + luisa::shared_ptr _desc; + const RefExpr *_expression{nullptr}; + +public: + /// Construct from CoroFrameDesc and RefExpr + explicit Expr(luisa::shared_ptr desc, + const RefExpr *expr) noexcept + : _desc{std::move(desc)}, _expression{expr} {} + + Expr(SOAView soa_view) noexcept + : _desc{soa_view.desc()->shared_from_this()}, + _expression{detail::FunctionBuilder::current()->buffer_binding( + Type::buffer(SOA::element_type()), soa_view.handle(), + 0u, soa_view.size_bytes())} {} + + /// Construct from SOA. Will call buffer_binding() to bind buffer + Expr(const SOA &soa) noexcept + : _expression{detail::FunctionBuilder::current()->buffer_binding( + SOA::element_type(), soa.handle(), + 0u, soa.size_bytes())} {} + + /// Construct from Var>. + Expr(const Var> &soa) noexcept + : Expr{soa.desc()->shread, soa.expression()} {} + + /// Construct from Var>. + Expr(const Var> &buffer) noexcept + : Expr{buffer.expression()} {} + + /// Return RefExpr + [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } + /// Read index with active fields template - [[nodiscard]] auto _read(I &&index, luisa::optional> active_fields) const noexcept { + [[nodiscard]] auto read(I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { auto fb = detail::FunctionBuilder::current(); auto frame = fb->local(_desc->type()); auto fields = _desc->type()->members(); @@ -54,7 +187,7 @@ class SOA { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field = fields[i]; auto offset = _field_offsets[i]; - auto field_buffer = fb->buffer_binding(field, _buffer.handle(), offset, field->size() * _count); + auto field_buffer = fb->buffer_binding(field, handle(), offset, field->size() * _size);// FIXME auto f = fb->member(field, frame, i); auto s = fb->call( field, CallOp::BUFFER_READ, @@ -66,14 +199,14 @@ class SOA { /// Write index with active fields template - void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields) const noexcept { + void write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields = luisa::nullopt) const noexcept { auto fb = detail::FunctionBuilder::current(); auto fields = _desc->type()->members(); for (auto i = 0u; i < fields.size(); i++) { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field = fields[i]; auto offset = _field_offsets[i]; - auto field_buffer = fb->buffer_binding(field, _buffer.handle(), offset, field->size() * _count); + auto field_buffer = fb->buffer_binding(field, handle(), offset, field->size() * _size);// FIXME auto f = fb->member(field, frame.expression(), i); auto s = fb->call( field, CallOp::BUFFER_WRITE, @@ -81,23 +214,22 @@ class SOA { } } -public: - template - [[nodiscard]] auto read(I &&index) const noexcept { - return _read(std::forward(index), luisa::nullopt); - } - template - [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { - return _read(std::forward(index), luisa::make_optional(active_fields)); - } - template - void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { - _write(std::forward(index), frame, luisa::nullopt); - } - template - void write(I &&index, const coroutine::CoroFrame &frame, luisa::span active_fields) const noexcept { - _write(std::forward(index), frame, luisa::make_optional(active_fields)); + [[nodiscard]] Expr device_address() const noexcept { + return def(detail::FunctionBuilder::current()->call( + Type::of(), CallOp::BUFFER_ADDRESS, {_expression})); } + + /// Self-pointer to unify the interfaces of the captured Buffer and Expr> + [[nodiscard]] auto operator->() const noexcept { return this; } +}; + +/// Class of Var> +template<> +struct Var> : public Expr> { + explicit Var(detail::ArgumentCreation) noexcept + : Expr>{detail::FunctionBuilder::current()->buffer(SOA::element_type())} {} + Var(Var &&) noexcept = default; + Var(const Var &) noexcept = delete; }; }// namespace luisa::compute \ No newline at end of file diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index f8e21630a..bc4f186cd 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -4,9 +4,10 @@ #pragma once +#include #include #include -#include +#include namespace luisa::compute::coroutine { @@ -25,35 +26,42 @@ template class WavefrontCoroScheduler : public CoroScheduler { private: - // Shader1D, Buffer, uint, Container, uint, uint, Args...> _gen_shader; - // luisa::vector, Buffer, Container, uint, Args...>> _resume_shaders; - // Shader1D, Buffer, uint> _count_prefix_shader; - // Shader1D, Buffer, Container, uint> _gather_shader; - // Shader1D, Container, uint> _initialize_shader; - // Shader1D, Container, uint, uint> _compact_shader; - // Shader1D, uint> _clear_shader; - // compute::Buffer _resume_index; - // compute::Buffer _resume_count; - // ///offset calculate from count, will be end after gathering - // compute::Buffer _resume_offset; - // compute::Buffer _global_buffer; - // compute::Buffer _debug_buffer; - // luisa::vector _host_count; - // luisa::vector _host_offset; - // bool _host_empty; - // uint _dispatch_counter; - // uint _max_sub_coro; - // uint _max_frame_count; - // radix_sort::temp_storage _sort_temp_storage; - // radix_sort::instance<> _sort_token; - // radix_sort::instance> _sort_hint; - // luisa::vector _have_hint; - // compute::Buffer _temp_key[2]; - // compute::Buffer _temp_index; + using Container = SOA; + Shader1D, Buffer, uint, Container, uint, uint, Args...> _gen_shader; + luisa::vector, Buffer, Container, uint, Args...>> _resume_shaders; + Shader1D, Buffer, uint> _count_prefix_shader; + Shader1D, Buffer, Container, uint> _gather_shader; + Shader1D, Container, uint> _initialize_shader; + Shader1D, Container, uint, uint> _compact_shader; + Shader1D, uint> _clear_shader; + Container _frame; + Buffer _resume_index; + Buffer _resume_count; + ///offset calculate from count, will be end after gathering + Buffer _resume_offset; + Buffer _global_buffer; + Buffer _debug_buffer; + luisa::vector _host_count; + luisa::vector _host_offset; + bool _host_empty; + uint _dispatch_counter; + uint _max_sub_coro; + uint _max_frame_count; + radix_sort::temp_storage _sort_temp_storage; + radix_sort::instance<> _sort_token; + radix_sort::instance> _sort_hint; + luisa::vector _have_hint; + Buffer _temp_key[2]; + Buffer _temp_index; private: void _create_shader(Device &device, const Coroutine &coro, const WavefrontCoroSchedulerConfig &config) noexcept { + if (config.soa) { + _frame = device.create_coro_frame_soa(coro.shared_frame(), config.max_instance_count, config.soa); + } else { + + } } void _dispatch(Stream &stream, uint3 dispatch_size, diff --git a/include/luisa/runtime/byte_buffer.h b/include/luisa/runtime/byte_buffer.h index c0a3b1c1f..bf4161b07 100644 --- a/include/luisa/runtime/byte_buffer.h +++ b/include/luisa/runtime/byte_buffer.h @@ -1,5 +1,6 @@ #pragma once +#include "luisa/coro/v2/coro_frame.h" #include namespace luisa::compute { @@ -8,6 +9,9 @@ namespace detail { class ByteBufferExprProxy; }// namespace detail +template<> +class SOA; + class LC_RUNTIME_API ByteBuffer final : public Resource { private: @@ -16,6 +20,7 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { private: friend class Device; friend class ResourceGenerator; + friend class SOA; ByteBuffer(DeviceInterface *device, const BufferCreationInfo &info) noexcept; ByteBuffer(DeviceInterface *device, size_t size_bytes) noexcept; @@ -30,13 +35,6 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { return *this; } ByteBuffer &operator=(ByteBuffer const &) noexcept = delete; - [[nodiscard]] auto view() const noexcept { - _check_is_valid(); - return BufferView{this->native_handle(), this->handle(), 1u, 0u, _size_bytes, _size_bytes}; - } - [[nodiscard]] auto view(size_t offset, size_t count) const noexcept { - return view().subview(offset, count); - } using Resource::operator bool; [[nodiscard]] auto copy_to(void *data) const noexcept { _check_is_valid(); diff --git a/include/luisa/runtime/device.h b/include/luisa/runtime/device.h index 298e86bbd..49325765b 100644 --- a/include/luisa/runtime/device.h +++ b/include/luisa/runtime/device.h @@ -7,6 +7,8 @@ #include #include +#include + namespace luisa { class BinaryIO; }// namespace luisa @@ -37,8 +39,6 @@ class CoroFrame; template class SOA; -class SOA; - template class Buffer; @@ -260,8 +260,11 @@ class LC_RUNTIME_API Device { return SOA{*this, size}; } - [[nodiscard]] auto create_soa(luisa::shared_ptr desc, size_t size) noexcept { - return SOA{*this, desc, size}; + template + requires std::same_as, luisa::shared_ptr> || + std::same_as, const coroutine::CoroFrameDesc *> + [[nodiscard]] auto create_coro_frame_soa(Desc &&desc, size_t size, bool soa) noexcept { + return SOA{impl(), std::forward(desc), size, soa}; } template From 4832794f0d2d0fda1d89ed7eb36a745d154d6d97 Mon Sep 17 00:00:00 2001 From: chenxin Date: Mon, 13 May 2024 02:14:51 +0800 Subject: [PATCH 30/67] add ByteBufferView (bug); SOA TBD --- include/luisa/coro/v2/coro_frame_buffer.h | 55 ++++++++++--------- include/luisa/coro/v2/coro_frame_soa.h | 5 +- include/luisa/dsl/resource.h | 58 ++++++++++++++------ include/luisa/runtime/byte_buffer.h | 64 ++++++++++++++++++++++- 4 files changed, 137 insertions(+), 45 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_buffer.h b/include/luisa/coro/v2/coro_frame_buffer.h index a76f51b41..d5619b9b0 100644 --- a/include/luisa/coro/v2/coro_frame_buffer.h +++ b/include/luisa/coro/v2/coro_frame_buffer.h @@ -15,30 +15,7 @@ namespace detail { [[noreturn]] LC_CORO_API void error_coro_frame_buffer_invalid_element_size(size_t stride, size_t expected) noexcept; template typename B> -class BufferExprProxy> { - -private: - using CoroFrameBuffer = B; - CoroFrameBuffer _buffer; - -public: - LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(BufferExprProxy) - -public: - template - requires is_integral_expr_v - [[nodiscard]] auto read(I &&index) const noexcept { - return Expr{_buffer}.read(std::forward(index)); - } - template - requires is_integral_expr_v - void write(I &&index, V &&value) const noexcept { - Expr{_buffer}.write(std::forward(index), std::forward(value)); - } - [[nodiscard]] Expr device_address() const noexcept { - return Expr{_buffer}.device_address(); - } -}; +class BufferExprProxy>; }// namespace detail @@ -297,4 +274,34 @@ struct Expr> { [[nodiscard]] auto operator->() const noexcept { return this; } }; +namespace detail { + +template typename B> +class BufferExprProxy> { + +private: + using CoroFrameBuffer = B; + CoroFrameBuffer _buffer; + +public: + LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(BufferExprProxy) + +public: + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index) const noexcept { + return Expr{_buffer}.read(std::forward(index)); + } + template + requires is_integral_expr_v + void write(I &&index, V &&value) const noexcept { + Expr{_buffer}.write(std::forward(index), std::forward(value)); + } + [[nodiscard]] Expr device_address() const noexcept { + return Expr{_buffer}.device_address(); + } +}; + +}// namespace detail + }// namespace luisa::compute diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index 55e61c4df..68c4b5f90 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -134,6 +134,7 @@ class SOA { [[nodiscard]] auto size() const noexcept { return _size; } [[nodiscard]] auto handle() const noexcept { return _buffer.handle(); } [[nodiscard]] auto size_bytes() const noexcept { return _buffer.size_bytes(); } + [[nodiscard]] auto view() const noexcept { return SOAView{_desc, &_buffer}; } // DSL interface [[nodiscard]] auto operator->() const noexcept { return reinterpret_cast> *>(this); @@ -162,9 +163,7 @@ struct Expr> { /// Construct from SOA. Will call buffer_binding() to bind buffer Expr(const SOA &soa) noexcept - : _expression{detail::FunctionBuilder::current()->buffer_binding( - SOA::element_type(), soa.handle(), - 0u, soa.size_bytes())} {} + : Expr{SOAView{soa}} {} /// Construct from Var>. Expr(const Var> &soa) noexcept diff --git a/include/luisa/dsl/resource.h b/include/luisa/dsl/resource.h index 7ede08fe1..47924cbaf 100644 --- a/include/luisa/dsl/resource.h +++ b/include/luisa/dsl/resource.h @@ -96,20 +96,39 @@ struct Expr> [[nodiscard]] auto operator->() const noexcept { return this; } }; +/// Same as Expr> +template +struct Expr> : public Expr> { + using Expr>::Expr; +}; + template<> struct Expr { private: const RefExpr *_expression{nullptr}; + public: /// Construct from RefExpr explicit Expr(const RefExpr *expr) noexcept : _expression{expr} {} - /// Construct from BufferView. Will call buffer_binding() to bind buffer - Expr(const ByteBuffer &buffer) noexcept + /// Construct from ByteBufferView. Will call buffer_binding() to bind buffer + Expr(const ByteBufferView &buffer) noexcept : _expression{detail::FunctionBuilder::current()->buffer_binding( Type::of(), buffer.handle(), - 0u, buffer.size_bytes())} {} + buffer.offset(), buffer.size_bytes())} {} + + /// Contruct from ByteBuffer. Will call buffer_binding() to bind buffer + Expr(const ByteBuffer &buffer) noexcept + : Expr{ByteBufferView{buffer}} {} + + /// Construct from Var. + Expr(const Var &buffer) noexcept + : Expr{buffer.expression()} {} + + /// Construct from Var. + Expr(const Var &buffer) noexcept + : Expr{buffer.expression()} {} /// Return RefExpr [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } @@ -139,12 +158,6 @@ struct Expr { } }; -/// Same as Expr> -template -struct Expr> : public Expr> { - using Expr>::Expr; -}; - /// Class of Expr> template struct Expr> { @@ -531,10 +544,11 @@ class BufferExprProxy { } }; +template class ByteBufferExprProxy { private: - ByteBuffer _buffer; + BufferOrExpr _buffer; public: LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(ByteBufferExprProxy) @@ -543,16 +557,16 @@ class ByteBufferExprProxy { template requires is_integral_expr_v [[nodiscard]] auto read(I &&index) const noexcept { - return Expr{_buffer}.read(std::forward(index)); + return Expr{_buffer}.read(std::forward(index)); } template requires is_integral_expr_v void write(I &&index, V &&value) const noexcept { - Expr{_buffer}.write(std::forward(index), - std::forward(value)); + Expr{_buffer}.write(std::forward(index), + std::forward(value)); } [[nodiscard]] Expr device_address() const noexcept { - return Expr{_buffer}.device_address(); + return Expr{_buffer}.device_address(); } }; @@ -644,6 +658,14 @@ struct Var> : public Expr> { Var(const Var &) noexcept = delete; }; +template +struct Var> : public Expr> { + explicit Var(detail::ArgumentCreation) noexcept + : Expr>{detail::FunctionBuilder::current()->buffer(Type::of>())} {} + Var(Var &&) noexcept = default; + Var(const Var &) noexcept = delete; +}; + template<> struct Var : public Expr { explicit Var(detail::ArgumentCreation) noexcept @@ -652,10 +674,10 @@ struct Var : public Expr { Var(const Var &) noexcept = delete; }; -template -struct Var> : public Expr> { +template<> +struct Var : public Expr { explicit Var(detail::ArgumentCreation) noexcept - : Expr>{detail::FunctionBuilder::current()->buffer(Type::of>())} {} + : Expr{detail::FunctionBuilder::current()->buffer(Type::of())} {} Var(Var &&) noexcept = default; Var(const Var &) noexcept = delete; }; @@ -704,6 +726,8 @@ struct Var : public Expr { template using BufferVar = Var>; +using ByteBufferVar = Var; + template using ImageVar = Var>; diff --git a/include/luisa/runtime/byte_buffer.h b/include/luisa/runtime/byte_buffer.h index bf4161b07..489687bda 100644 --- a/include/luisa/runtime/byte_buffer.h +++ b/include/luisa/runtime/byte_buffer.h @@ -6,12 +6,15 @@ namespace luisa::compute { namespace detail { +LC_RUNTIME_API void error_buffer_size_not_aligned(size_t align) noexcept; class ByteBufferExprProxy; }// namespace detail template<> class SOA; +class ByteBufferView; + class LC_RUNTIME_API ByteBuffer final : public Resource { private: @@ -35,6 +38,10 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { return *this; } ByteBuffer &operator=(ByteBuffer const &) noexcept = delete; + [[nodiscard]] auto view() const noexcept { + _check_is_valid(); + return ByteBufferView{this->native_handle(), this->handle(), 0u, size_bytes()}; + } using Resource::operator bool; [[nodiscard]] auto copy_to(void *data) const noexcept { _check_is_valid(); @@ -79,10 +86,65 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { } }; +class ByteBufferView { + +private: + void *_native_handle; + uint64_t _handle; + size_t _offset_bytes; + size_t _size; + size_t _total_size; + +private: + friend class ByteBuffer; + +public: + ByteBufferView(void *native_handle, uint64_t handle, + size_t offset_bytes, + size_t size, size_t total_size) noexcept + : _native_handle{native_handle}, _handle{handle}, _offset_bytes{offset_bytes}, + _size{size}, _total_size{total_size} {} + + ByteBufferView(const ByteBuffer &buffer) noexcept : ByteBufferView{buffer.view()} {} + ByteBufferView() noexcept : ByteBufferView{nullptr, invalid_resource_handle, 0u, 0u, 0u} {} + [[nodiscard]] explicit operator bool() const noexcept { return _handle != invalid_resource_handle; } + + [[nodiscard]] auto handle() const noexcept { return _handle; } + [[nodiscard]] auto native_handle() const noexcept { return _native_handle; } + [[nodiscard]] auto offset() const noexcept { return _offset_bytes; } + [[nodiscard]] auto size() const noexcept { return _size; } + [[nodiscard]] auto size_bytes() const noexcept { return _size; } + [[nodiscard]] auto original() const noexcept { + return ByteBufferView{_native_handle, _handle, 0u, _total_size, _total_size}; + } + [[nodiscard]] auto subview(size_t offset_elements, size_t size_elements) const noexcept { + if (size_elements + offset_elements > _size) [[unlikely]] { + detail::error_buffer_subview_overflow(offset_elements, size_elements, _size); + } + return ByteBufferView{_native_handle, _handle, _offset_bytes + offset_elements, size_elements, _total_size}; + } + // reinterpret cast buffer to another type U + template + requires(!is_custom_struct_v) + [[nodiscard]] auto as() const noexcept { + if (this->size_bytes() < sizeof(U)) [[unlikely]] { + detail::error_buffer_reinterpret_size_too_small(sizeof(U), this->size_bytes()); + } + auto total_size_bytes = _total_size; + return BufferView{_native_handle, _handle, sizeof(U), _offset_bytes, + this->size_bytes() / sizeof(U), total_size_bytes / sizeof(U)}; + } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast(this); + } +}; + namespace detail { -LC_RUNTIME_API void error_buffer_size_not_aligned(size_t align) noexcept; + template<> struct is_buffer_impl : std::true_type {}; + }// namespace detail }// namespace luisa::compute From 4ebdb3294e4c1acec95e0b770eb2e365960371cf Mon Sep 17 00:00:00 2001 From: chenxin Date: Mon, 13 May 2024 14:20:18 +0800 Subject: [PATCH 31/67] finish ByteBufferView; SOA TBD --- include/luisa/coro/v2/coro_frame_soa.h | 136 +++++++++++++++---------- include/luisa/dsl/resource.h | 10 +- include/luisa/runtime/byte_buffer.h | 11 +- src/runtime/byte_buffer.cpp | 5 + 4 files changed, 97 insertions(+), 65 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index 68c4b5f90..da1c9ef5b 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -30,12 +30,16 @@ class CoroFrameSOAExprProxy { public: template requires is_integral_expr_v - [[nodiscard]] auto read(I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { + [[nodiscard]] auto read( + I &&index, + luisa::optional &> active_fields = luisa::nullopt) const noexcept { return Expr{_soa}.read(std::forward(index), std::move(active_fields)); } template requires is_integral_expr_v - void write(I &&index, V &&value, luisa::optional> active_fields = luisa::nullopt) const noexcept { + void write( + I &&index, V &&value, + luisa::optional &> active_fields = luisa::nullopt) const noexcept { Expr{_soa}.write(std::forward(index), std::forward(value), std::move(active_fields)); @@ -47,39 +51,13 @@ class CoroFrameSOAExprProxy { }// namespace detail -template<> -class SOAView { - -private: - luisa::shared_ptr _desc; - ByteBuffer *_bufferview{nullptr}; - -private: - friend class SOA; - -public: - SOAView() noexcept = default; - SOAView(luisa::shared_ptr desc, ByteBuffer *_buffer) noexcept - : _desc{std::move(desc)}, - _bufferview{_buffer} {} - ~SOAView() noexcept = default; - - [[nodiscard]] auto desc() const noexcept { return _desc.get(); } - [[nodiscard]] auto handle() const noexcept { return _bufferview->handle(); } - [[nodiscard]] auto size_bytes() const noexcept { return _bufferview->size_bytes(); } - // DSL interface - [[nodiscard]] auto operator->() const noexcept { - return reinterpret_cast> *>(this); - } -}; - template<> class SOA { private: luisa::shared_ptr _desc; size_t _size{0u}; - luisa::vector _field_offsets; + luisa::shared_ptr> _field_offsets; size_t _buffer_element_count{0u}; ByteBuffer _buffer; @@ -89,10 +67,12 @@ class SOA { size_bytes = 0u; const auto type = _desc->type(); auto fields = type->members(); - _field_offsets.reserve(fields.size()); + LUISA_ASSERT(_field_offsets == nullptr, "CoroFrame SOA Buffer already initialized."); + _field_offsets = luisa::make_shared>(); + _field_offsets->reserve(fields.size()); for (const auto field : fields) { size_bytes = (size_bytes + field->alignment() - 1u) & ~(field->alignment() - 1u); - _field_offsets.emplace_back(size_bytes); + _field_offsets->emplace_back(size_bytes); size_bytes += field->size() * _size; } _buffer_element_count = (size_bytes + element_type()->size() - 1u) / element_type()->size(); @@ -134,7 +114,40 @@ class SOA { [[nodiscard]] auto size() const noexcept { return _size; } [[nodiscard]] auto handle() const noexcept { return _buffer.handle(); } [[nodiscard]] auto size_bytes() const noexcept { return _buffer.size_bytes(); } - [[nodiscard]] auto view() const noexcept { return SOAView{_desc, &_buffer}; } + [[nodiscard]] auto view() const noexcept { return SOAView{_desc, _buffer.view()}; } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast> *>(this); + } +}; + +template<> +class SOAView { +private: + luisa::shared_ptr _desc; + ByteBufferView _buffer_view; + luisa::shared_ptr> _field_offsets; + +private: + friend class SOA; + +public: + SOAView() noexcept = default; + SOAView(luisa::shared_ptr desc, ByteBufferView buffer_view, + luisa::shared_ptr> field_offsets) noexcept + : _desc{std::move(desc)}, + _buffer_view{buffer_view}, + _field_offsets{field_offsets} {} + SOAView(const SOA &soa) noexcept + : SOAView{soa.view()} {} + ~SOAView() noexcept = default; + + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto handle() const noexcept { return _buffer_view.handle(); } + [[nodiscard]] auto offset() const noexcept { return _buffer_view.offset(); } + [[nodiscard]] auto size_bytes() const noexcept { return _buffer_view.size_bytes(); } + [[nodiscard]] auto total_size() const noexcept { return _buffer_view.total_size(); } + [[nodiscard]] auto field_offsets() const noexcept { return _field_offsets; } // DSL interface [[nodiscard]] auto operator->() const noexcept { return reinterpret_cast> *>(this); @@ -147,50 +160,61 @@ struct Expr> { private: luisa::shared_ptr _desc; + luisa::shared_ptr> _field_offsets; const RefExpr *_expression{nullptr}; public: - /// Construct from CoroFrameDesc and RefExpr - explicit Expr(luisa::shared_ptr desc, - const RefExpr *expr) noexcept - : _desc{std::move(desc)}, _expression{expr} {} - - Expr(SOAView soa_view) noexcept - : _desc{soa_view.desc()->shared_from_this()}, - _expression{detail::FunctionBuilder::current()->buffer_binding( + /// Construct from RefExpr + explicit Expr(const RefExpr *expr) noexcept + : _expression{expr} {} + + /// Construct from SOAView. Will call buffer_binding() to bind buffer + Expr(const SOAView &soa_view) noexcept + : _desc{soa_view.desc()->shared_from_this()}, + _expression{detail::FunctionBuilder::current()->buffer_binding( Type::buffer(SOA::element_type()), soa_view.handle(), - 0u, soa_view.size_bytes())} {} + soa_view.offset(), soa_view.size_bytes())}, + _field_offsets{soa_view.field_offsets()} { + LUISA_ASSERT((soa_view.offset() == 0u) && (soa_view.size_bytes() == soa_view.total_size()), + "The ByteBufferView in SOAView must be an original view."); + } /// Construct from SOA. Will call buffer_binding() to bind buffer Expr(const SOA &soa) noexcept : Expr{SOAView{soa}} {} - /// Construct from Var>. - Expr(const Var> &soa) noexcept + /// Construct from Var>. + template + requires std::same_as> + Expr(const T &soa) noexcept : Expr{soa.desc()->shread, soa.expression()} {} - /// Construct from Var>. - Expr(const Var> &buffer) noexcept - : Expr{buffer.expression()} {} + /// Construct from Var>. + template + requires std::same_as>> + Expr(const T &soa_view) noexcept + : Expr{soa_view.expression()} {} /// Return RefExpr [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } /// Read index with active fields template - [[nodiscard]] auto read(I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { + [[nodiscard]] auto read( + I &&index, + luisa::optional &> active_fields = luisa::nullopt) const noexcept { auto fb = detail::FunctionBuilder::current(); auto frame = fb->local(_desc->type()); auto fields = _desc->type()->members(); for (auto i = 0u; i < fields.size(); i++) { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field = fields[i]; - auto offset = _field_offsets[i]; - auto field_buffer = fb->buffer_binding(field, handle(), offset, field->size() * _size);// FIXME + auto offset = _field_offsets->at(i); auto f = fb->member(field, frame, i); + auto offset_var = offset + index * field->size(); auto s = fb->call( field, CallOp::BUFFER_READ, - {field_buffer, detail::extract_expression(std::forward(index))}); + {_expression, detail::extract_expression(offset_var)}); fb->assign(f, s); } return coroutine::CoroFrame{_desc, frame}; @@ -198,18 +222,20 @@ struct Expr> { /// Write index with active fields template - void write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields = luisa::nullopt) const noexcept { + void write( + I &&index, const coroutine::CoroFrame &frame, + luisa::optional &> active_fields = luisa::nullopt) const noexcept { auto fb = detail::FunctionBuilder::current(); auto fields = _desc->type()->members(); for (auto i = 0u; i < fields.size(); i++) { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field = fields[i]; - auto offset = _field_offsets[i]; - auto field_buffer = fb->buffer_binding(field, handle(), offset, field->size() * _size);// FIXME + auto offset = _field_offsets->at(i); + auto offset_var = offset + index * field->size(); auto f = fb->member(field, frame.expression(), i); auto s = fb->call( field, CallOp::BUFFER_WRITE, - {field_buffer, detail::extract_expression(std::forward(index)), f}); + {_expression, detail::extract_expression(offset_var), f}); } } @@ -226,7 +252,7 @@ struct Expr> { template<> struct Var> : public Expr> { explicit Var(detail::ArgumentCreation) noexcept - : Expr>{detail::FunctionBuilder::current()->buffer(SOA::element_type())} {} + : Expr{detail::FunctionBuilder::current()->buffer(SOA::element_type())} {} Var(Var &&) noexcept = default; Var(const Var &) noexcept = delete; }; diff --git a/include/luisa/dsl/resource.h b/include/luisa/dsl/resource.h index 47924cbaf..b7a3a6ae9 100644 --- a/include/luisa/dsl/resource.h +++ b/include/luisa/dsl/resource.h @@ -123,11 +123,15 @@ struct Expr { : Expr{ByteBufferView{buffer}} {} /// Construct from Var. - Expr(const Var &buffer) noexcept + template + requires std::same_as + Expr(const Var &buffer) noexcept : Expr{buffer.expression()} {} - /// Construct from Var. - Expr(const Var &buffer) noexcept + /// Construct from Var. + template + requires std::same_as + Expr(const Var &buffer) noexcept : Expr{buffer.expression()} {} /// Return RefExpr diff --git a/include/luisa/runtime/byte_buffer.h b/include/luisa/runtime/byte_buffer.h index 489687bda..6cebb6344 100644 --- a/include/luisa/runtime/byte_buffer.h +++ b/include/luisa/runtime/byte_buffer.h @@ -10,8 +10,8 @@ LC_RUNTIME_API void error_buffer_size_not_aligned(size_t align) noexcept; class ByteBufferExprProxy; }// namespace detail -template<> -class SOA; +template +class SOA; class ByteBufferView; @@ -38,10 +38,7 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { return *this; } ByteBuffer &operator=(ByteBuffer const &) noexcept = delete; - [[nodiscard]] auto view() const noexcept { - _check_is_valid(); - return ByteBufferView{this->native_handle(), this->handle(), 0u, size_bytes()}; - } + [[nodiscard]] ByteBufferView view() const noexcept; using Resource::operator bool; [[nodiscard]] auto copy_to(void *data) const noexcept { _check_is_valid(); @@ -112,8 +109,8 @@ class ByteBufferView { [[nodiscard]] auto handle() const noexcept { return _handle; } [[nodiscard]] auto native_handle() const noexcept { return _native_handle; } [[nodiscard]] auto offset() const noexcept { return _offset_bytes; } - [[nodiscard]] auto size() const noexcept { return _size; } [[nodiscard]] auto size_bytes() const noexcept { return _size; } + [[nodiscard]] auto total_size() const noexcept { return _total_size; } [[nodiscard]] auto original() const noexcept { return ByteBufferView{_native_handle, _handle, 0u, _total_size, _total_size}; } diff --git a/src/runtime/byte_buffer.cpp b/src/runtime/byte_buffer.cpp index 33307129e..78c775beb 100644 --- a/src/runtime/byte_buffer.cpp +++ b/src/runtime/byte_buffer.cpp @@ -35,6 +35,11 @@ ByteBuffer::~ByteBuffer() noexcept { if (*this) { device()->destroy_buffer(handle()); } } +ByteBufferView ByteBuffer::view() const noexcept { + _check_is_valid(); + return ByteBufferView{this->native_handle(), this->handle(), 0u, size_bytes(), size_bytes()}; +} + ByteBuffer Device::create_byte_buffer(size_t byte_size) noexcept { return ByteBuffer{impl(), byte_size}; } From 0070b606b6afcc6465321abe915abb324d555562 Mon Sep 17 00:00:00 2001 From: chenxin Date: Mon, 13 May 2024 16:46:01 +0800 Subject: [PATCH 32/67] finish SOA, untested --- include/luisa/coro/v2/coro_frame_soa.h | 212 ++++++++++++------------- include/luisa/runtime/byte_buffer.h | 5 +- src/runtime/byte_buffer.cpp | 9 +- 3 files changed, 103 insertions(+), 123 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index da1c9ef5b..7163237e9 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -21,8 +21,7 @@ template class CoroFrameSOAExprProxy { private: - T _soa; - using CoroFrameSOA = T; + T _soa; public: LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(CoroFrameSOAExprProxy) @@ -32,68 +31,107 @@ class CoroFrameSOAExprProxy { requires is_integral_expr_v [[nodiscard]] auto read( I &&index, - luisa::optional &> active_fields = luisa::nullopt) const noexcept { - return Expr{_soa}.read(std::forward(index), std::move(active_fields)); + luisa::optional> active_fields = luisa::nullopt) const noexcept { + return Expr{_soa}.read(std::forward(index), std::move(active_fields)); } template requires is_integral_expr_v void write( I &&index, V &&value, - luisa::optional &> active_fields = luisa::nullopt) const noexcept { - Expr{_soa}.write(std::forward(index), + luisa::optional> active_fields = luisa::nullopt) const noexcept { + Expr{_soa}.write(std::forward(index), std::forward(value), std::move(active_fields)); } [[nodiscard]] Expr device_address() const noexcept { - return Expr{_soa}.device_address(); + return Expr{_soa}.device_address(); } }; }// namespace detail -template<> -class SOA { +template +class SOA; -private: +struct SOABase { +protected: luisa::shared_ptr _desc; - size_t _size{0u}; luisa::shared_ptr> _field_offsets; - size_t _buffer_element_count{0u}; - ByteBuffer _buffer; + uint _range_start{0u}, _size{0u}; + +public: + SOABase() noexcept = default; + SOABase(luisa::shared_ptr desc, + luisa::shared_ptr> field_offsets, + uint range_start, uint size) noexcept + : _desc{std::move(desc)}, + _field_offsets{std::move(field_offsets)}, + _range_start{range_start}, _size{size} {} +}; +template<> +class SOAView : public SOABase { private: - [[nodiscard]] auto _init(DeviceInterface *device) noexcept { + ByteBufferView _buffer_view; + +public: + SOAView() noexcept = default; + SOAView(luisa::shared_ptr desc, ByteBufferView buffer_view, + luisa::shared_ptr> field_offsets, + uint range_start, uint size) noexcept + : SOABase{std::move(desc), field_offsets, range_start, size}, + _buffer_view{buffer_view} { + LUISA_ASSERT((_buffer_view.offset() == 0u) && + (_buffer_view.size_bytes() == _buffer_view.total_size()), + "Invalid buffer view for SOA."); + } + template + requires std::same_as> + SOAView(const T &soa) noexcept + : SOAView{soa.view()} {} + ~SOAView() noexcept = default; + + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto handle() const noexcept { return _buffer_view.handle(); } + [[nodiscard]] auto range_start() const noexcept { return _range_start; } + [[nodiscard]] auto size() const noexcept { return _size; } + [[nodiscard]] auto size_bytes() const noexcept { return _buffer_view.size_bytes(); } + [[nodiscard]] auto field_offsets() const noexcept { return _field_offsets; } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast> *>(this); + } +}; + + + +template<> +class SOA : public SOABase { + +private: + ByteBuffer _buffer; + +public: + SOA(DeviceInterface *device, luisa::shared_ptr desc, uint n, bool soa) noexcept + : SOABase{std::move(desc), luisa::make_shared>(), 0u, n} { auto size_bytes = 0u; size_bytes = 0u; - const auto type = _desc->type(); - auto fields = type->members(); - LUISA_ASSERT(_field_offsets == nullptr, "CoroFrame SOA Buffer already initialized."); - _field_offsets = luisa::make_shared>(); + auto fields = _desc->type()->members(); _field_offsets->reserve(fields.size()); for (const auto field : fields) { size_bytes = (size_bytes + field->alignment() - 1u) & ~(field->alignment() - 1u); _field_offsets->emplace_back(size_bytes); + if (field->size() % field->alignment() != 0u) [[unlikely]] { + detail::error_buffer_invalid_alignment(size_bytes + field->size(), field->alignment()); + } size_bytes += field->size() * _size; } - _buffer_element_count = (size_bytes + element_type()->size() - 1u) / element_type()->size(); - return device->create_buffer( - element_type(), - _buffer_element_count, + auto buffer_element_count = (size_bytes + sizeof(uint) - 1u) / sizeof(uint); + auto info = device->create_buffer( + Type::of(), + buffer_element_count, nullptr); - } - -public: - [[nodiscard]] static const Type *element_type() noexcept { - return Type::of(); - } - -public: - SOA(DeviceInterface *device, luisa::shared_ptr desc, size_t n, bool soa) noexcept - : _size{n}, _desc{std::move(desc)} { - auto info = _init(device); _buffer = std::move(ByteBuffer{device, info}); - LUISA_ASSERT(_buffer.size_bytes() == _buffer_element_count * element_type()->size(), - "CoroFrame SOA Buffer size mismatch."); } SOA() noexcept = default; SOA(const SOA &) = delete; @@ -103,7 +141,6 @@ class SOA { _desc = std::move(x._desc); _size = x._size; _field_offsets = std::move(x._field_offsets); - _buffer_element_count = x._buffer_element_count; _buffer = std::move(x._buffer); return *this; } @@ -111,43 +148,14 @@ class SOA { // properties [[nodiscard]] auto desc() const noexcept { return _desc.get(); } - [[nodiscard]] auto size() const noexcept { return _size; } - [[nodiscard]] auto handle() const noexcept { return _buffer.handle(); } - [[nodiscard]] auto size_bytes() const noexcept { return _buffer.size_bytes(); } - [[nodiscard]] auto view() const noexcept { return SOAView{_desc, _buffer.view()}; } - // DSL interface - [[nodiscard]] auto operator->() const noexcept { - return reinterpret_cast> *>(this); + [[nodiscard]] auto view() const noexcept { + return SOAView{ + _desc, + _buffer.view(), + _field_offsets, + 0u, + _size}; } -}; - -template<> -class SOAView { -private: - luisa::shared_ptr _desc; - ByteBufferView _buffer_view; - luisa::shared_ptr> _field_offsets; - -private: - friend class SOA; - -public: - SOAView() noexcept = default; - SOAView(luisa::shared_ptr desc, ByteBufferView buffer_view, - luisa::shared_ptr> field_offsets) noexcept - : _desc{std::move(desc)}, - _buffer_view{buffer_view}, - _field_offsets{field_offsets} {} - SOAView(const SOA &soa) noexcept - : SOAView{soa.view()} {} - ~SOAView() noexcept = default; - - [[nodiscard]] auto desc() const noexcept { return _desc.get(); } - [[nodiscard]] auto handle() const noexcept { return _buffer_view.handle(); } - [[nodiscard]] auto offset() const noexcept { return _buffer_view.offset(); } - [[nodiscard]] auto size_bytes() const noexcept { return _buffer_view.size_bytes(); } - [[nodiscard]] auto total_size() const noexcept { return _buffer_view.total_size(); } - [[nodiscard]] auto field_offsets() const noexcept { return _field_offsets; } // DSL interface [[nodiscard]] auto operator->() const noexcept { return reinterpret_cast> *>(this); @@ -156,45 +164,24 @@ class SOAView { /// Class of Expr> template<> -struct Expr> { +struct Expr> : SOABase { private: - luisa::shared_ptr _desc; - luisa::shared_ptr> _field_offsets; const RefExpr *_expression{nullptr}; public: - /// Construct from RefExpr - explicit Expr(const RefExpr *expr) noexcept - : _expression{expr} {} - /// Construct from SOAView. Will call buffer_binding() to bind buffer Expr(const SOAView &soa_view) noexcept - : _desc{soa_view.desc()->shared_from_this()}, + : SOABase{soa_view.desc()->shared_from_this(), soa_view.field_offsets(), + soa_view.range_start(), soa_view.size()}, _expression{detail::FunctionBuilder::current()->buffer_binding( - Type::buffer(SOA::element_type()), soa_view.handle(), - soa_view.offset(), soa_view.size_bytes())}, - _field_offsets{soa_view.field_offsets()} { - LUISA_ASSERT((soa_view.offset() == 0u) && (soa_view.size_bytes() == soa_view.total_size()), - "The ByteBufferView in SOAView must be an original view."); - } + Type::buffer(Type::of()), soa_view.handle(), + 0u, soa_view.size_bytes())} {} /// Construct from SOA. Will call buffer_binding() to bind buffer Expr(const SOA &soa) noexcept : Expr{SOAView{soa}} {} - /// Construct from Var>. - template - requires std::same_as> - Expr(const T &soa) noexcept - : Expr{soa.desc()->shread, soa.expression()} {} - - /// Construct from Var>. - template - requires std::same_as>> - Expr(const T &soa_view) noexcept - : Expr{soa_view.expression()} {} - /// Return RefExpr [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } @@ -202,18 +189,18 @@ struct Expr> { template [[nodiscard]] auto read( I &&index, - luisa::optional &> active_fields = luisa::nullopt) const noexcept { + luisa::optional> active_fields = luisa::nullopt) const noexcept { auto fb = detail::FunctionBuilder::current(); auto frame = fb->local(_desc->type()); auto fields = _desc->type()->members(); for (auto i = 0u; i < fields.size(); i++) { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } - auto field = fields[i]; + auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto f = fb->member(field, frame, i); - auto offset_var = offset + index * field->size(); + auto f = fb->member(field_type, frame, i); + auto offset_var = offset + index * field_type->size(); auto s = fb->call( - field, CallOp::BUFFER_READ, + field_type, CallOp::BYTE_BUFFER_READ, {_expression, detail::extract_expression(offset_var)}); fb->assign(f, s); } @@ -224,17 +211,17 @@ struct Expr> { template void write( I &&index, const coroutine::CoroFrame &frame, - luisa::optional &> active_fields = luisa::nullopt) const noexcept { + luisa::optional> active_fields = luisa::nullopt) const noexcept { auto fb = detail::FunctionBuilder::current(); auto fields = _desc->type()->members(); for (auto i = 0u; i < fields.size(); i++) { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } - auto field = fields[i]; + auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto offset_var = offset + index * field->size(); - auto f = fb->member(field, frame.expression(), i); + auto offset_var = offset + index * field_type->size(); + auto f = fb->member(field_type, frame.expression(), i); auto s = fb->call( - field, CallOp::BUFFER_WRITE, + field_type, CallOp::BYTE_BUFFER_WRITE, {_expression, detail::extract_expression(offset_var), f}); } } @@ -248,13 +235,10 @@ struct Expr> { [[nodiscard]] auto operator->() const noexcept { return this; } }; -/// Class of Var> +/// Class of Expr> template<> -struct Var> : public Expr> { - explicit Var(detail::ArgumentCreation) noexcept - : Expr{detail::FunctionBuilder::current()->buffer(SOA::element_type())} {} - Var(Var &&) noexcept = default; - Var(const Var &) noexcept = delete; +struct Expr> : public Expr> { + using Expr>::Expr; }; }// namespace luisa::compute \ No newline at end of file diff --git a/include/luisa/runtime/byte_buffer.h b/include/luisa/runtime/byte_buffer.h index 6cebb6344..3e1eb10d0 100644 --- a/include/luisa/runtime/byte_buffer.h +++ b/include/luisa/runtime/byte_buffer.h @@ -7,6 +7,7 @@ namespace luisa::compute { namespace detail { LC_RUNTIME_API void error_buffer_size_not_aligned(size_t align) noexcept; +template class ByteBufferExprProxy; }// namespace detail @@ -79,7 +80,7 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { // DSL interface [[nodiscard]] auto operator->() const noexcept { _check_is_valid(); - return reinterpret_cast(this); + return reinterpret_cast *>(this); } }; @@ -133,7 +134,7 @@ class ByteBufferView { } // DSL interface [[nodiscard]] auto operator->() const noexcept { - return reinterpret_cast(this); + return reinterpret_cast *>(this); } }; diff --git a/src/runtime/byte_buffer.cpp b/src/runtime/byte_buffer.cpp index 78c775beb..178178f96 100644 --- a/src/runtime/byte_buffer.cpp +++ b/src/runtime/byte_buffer.cpp @@ -25,10 +25,7 @@ ByteBuffer::ByteBuffer(DeviceInterface *device, size_t size_bytes) noexcept if ((size_bytes & 3) != 0) [[unlikely]] { detail::error_buffer_size_not_aligned(4); } - return device->create_buffer( - Type::of(), - (size_bytes + sizeof(uint) - 1u) / sizeof(uint), - nullptr); + return device->create_buffer(Type::of(), size_bytes, nullptr); }()} {} ByteBuffer::~ByteBuffer() noexcept { @@ -45,9 +42,7 @@ ByteBuffer Device::create_byte_buffer(size_t byte_size) noexcept { } ByteBuffer Device::import_external_byte_buffer(void *external_memory, size_t byte_size) noexcept { - auto info = impl()->create_buffer(Type::of(), - (byte_size + sizeof(uint) - 1u) / sizeof(uint), - external_memory); + auto info = impl()->create_buffer(Type::of(), byte_size, external_memory); return ByteBuffer{impl(), info}; } From de7064d494a20467cb16246aec359d5c074200e5 Mon Sep 17 00:00:00 2001 From: chenxin Date: Mon, 13 May 2024 17:13:36 +0800 Subject: [PATCH 33/67] fix SOA, untested --- include/luisa/coro/v2/coro_frame_soa.h | 54 +++++++++++--------- include/luisa/coro/v2/schedulers/wavefront.h | 2 +- include/luisa/runtime/device.h | 4 +- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index 7163237e9..8fecd1d8c 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -40,8 +40,8 @@ class CoroFrameSOAExprProxy { I &&index, V &&value, luisa::optional> active_fields = luisa::nullopt) const noexcept { Expr{_soa}.write(std::forward(index), - std::forward(value), - std::move(active_fields)); + std::forward(value), + std::move(active_fields)); } [[nodiscard]] Expr device_address() const noexcept { return Expr{_soa}.device_address(); @@ -56,17 +56,18 @@ class SOA; struct SOABase { protected: luisa::shared_ptr _desc; - luisa::shared_ptr> _field_offsets; - uint _range_start{0u}, _size{0u}; + luisa::shared_ptr> _field_offsets; + size_t _offset_elements{0u}, _size_elements{0u}; public: SOABase() noexcept = default; + SOABase(SOABase &&) noexcept = default; SOABase(luisa::shared_ptr desc, - luisa::shared_ptr> field_offsets, - uint range_start, uint size) noexcept + luisa::shared_ptr> field_offsets, + size_t offset_elements, size_t size_elements) noexcept : _desc{std::move(desc)}, _field_offsets{std::move(field_offsets)}, - _range_start{range_start}, _size{size} {} + _offset_elements{offset_elements}, _size_elements{size_elements} {} }; template<> @@ -77,9 +78,9 @@ class SOAView : public SOABase { public: SOAView() noexcept = default; SOAView(luisa::shared_ptr desc, ByteBufferView buffer_view, - luisa::shared_ptr> field_offsets, - uint range_start, uint size) noexcept - : SOABase{std::move(desc), field_offsets, range_start, size}, + luisa::shared_ptr> field_offsets, + size_t offset_elements, size_t size_elements) noexcept + : SOABase{std::move(desc), std::move(field_offsets), offset_elements, size_elements}, _buffer_view{buffer_view} { LUISA_ASSERT((_buffer_view.offset() == 0u) && (_buffer_view.size_bytes() == _buffer_view.total_size()), @@ -91,10 +92,14 @@ class SOAView : public SOABase { : SOAView{soa.view()} {} ~SOAView() noexcept = default; + [[nodiscard]] auto subview(uint offset_elements, uint size_elements) noexcept { + return SOAView{_desc, _buffer_view.subview(offset_elements, size_elements), + _field_offsets, _offset_elements, _size_elements}; + } [[nodiscard]] auto desc() const noexcept { return _desc.get(); } [[nodiscard]] auto handle() const noexcept { return _buffer_view.handle(); } - [[nodiscard]] auto range_start() const noexcept { return _range_start; } - [[nodiscard]] auto size() const noexcept { return _size; } + [[nodiscard]] auto offset_elements() const noexcept { return _offset_elements; } + [[nodiscard]] auto size_elements() const noexcept { return _size_elements; } [[nodiscard]] auto size_bytes() const noexcept { return _buffer_view.size_bytes(); } [[nodiscard]] auto field_offsets() const noexcept { return _field_offsets; } // DSL interface @@ -103,8 +108,6 @@ class SOAView : public SOABase { } }; - - template<> class SOA : public SOABase { @@ -112,9 +115,10 @@ class SOA : public SOABase { ByteBuffer _buffer; public: - SOA(DeviceInterface *device, luisa::shared_ptr desc, uint n, bool soa) noexcept - : SOABase{std::move(desc), luisa::make_shared>(), 0u, n} { - auto size_bytes = 0u; + SOA(DeviceInterface *device, luisa::shared_ptr desc, uint n) noexcept + : SOABase{std::move(desc), luisa::make_shared>(), + 0u, n} { + size_t size_bytes = 0u; size_bytes = 0u; auto fields = _desc->type()->members(); _field_offsets->reserve(fields.size()); @@ -124,11 +128,11 @@ class SOA : public SOABase { if (field->size() % field->alignment() != 0u) [[unlikely]] { detail::error_buffer_invalid_alignment(size_bytes + field->size(), field->alignment()); } - size_bytes += field->size() * _size; + size_bytes += field->size() * _size_elements; } - auto buffer_element_count = (size_bytes + sizeof(uint) - 1u) / sizeof(uint); + auto buffer_element_count = (size_bytes + 3u) & ~3u; auto info = device->create_buffer( - Type::of(), + Type::of(), buffer_element_count, nullptr); _buffer = std::move(ByteBuffer{device, info}); @@ -139,7 +143,7 @@ class SOA : public SOABase { SOA &operator=(const SOA &) = delete; SOA &operator=(SOA &&x) noexcept { _desc = std::move(x._desc); - _size = x._size; + _size_elements = x._size_elements; _field_offsets = std::move(x._field_offsets); _buffer = std::move(x._buffer); return *this; @@ -154,7 +158,7 @@ class SOA : public SOABase { _buffer.view(), _field_offsets, 0u, - _size}; + _size_elements}; } // DSL interface [[nodiscard]] auto operator->() const noexcept { @@ -173,7 +177,7 @@ struct Expr> : SOABase { /// Construct from SOAView. Will call buffer_binding() to bind buffer Expr(const SOAView &soa_view) noexcept : SOABase{soa_view.desc()->shared_from_this(), soa_view.field_offsets(), - soa_view.range_start(), soa_view.size()}, + soa_view.offset_elements(), soa_view.size_elements()}, _expression{detail::FunctionBuilder::current()->buffer_binding( Type::buffer(Type::of()), soa_view.handle(), 0u, soa_view.size_bytes())} {} @@ -198,7 +202,7 @@ struct Expr> : SOABase { auto field_type = fields[i]; auto offset = _field_offsets->at(i); auto f = fb->member(field_type, frame, i); - auto offset_var = offset + index * field_type->size(); + auto offset_var = offset + (_offset_elements + index) * field_type->size(); auto s = fb->call( field_type, CallOp::BYTE_BUFFER_READ, {_expression, detail::extract_expression(offset_var)}); @@ -218,7 +222,7 @@ struct Expr> : SOABase { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto offset_var = offset + index * field_type->size(); + auto offset_var = offset + (_offset_elements + index) * field_type->size(); auto f = fb->member(field_type, frame.expression(), i); auto s = fb->call( field_type, CallOp::BYTE_BUFFER_WRITE, diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index bc4f186cd..78bb727bc 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -58,7 +58,7 @@ class WavefrontCoroScheduler : public CoroScheduler { void _create_shader(Device &device, const Coroutine &coro, const WavefrontCoroSchedulerConfig &config) noexcept { if (config.soa) { - _frame = device.create_coro_frame_soa(coro.shared_frame(), config.max_instance_count, config.soa); + _frame = device.create_coro_frame_soa(coro.shared_frame(), config.max_instance_count); } else { } diff --git a/include/luisa/runtime/device.h b/include/luisa/runtime/device.h index 49325765b..6e157b100 100644 --- a/include/luisa/runtime/device.h +++ b/include/luisa/runtime/device.h @@ -263,8 +263,8 @@ class LC_RUNTIME_API Device { template requires std::same_as, luisa::shared_ptr> || std::same_as, const coroutine::CoroFrameDesc *> - [[nodiscard]] auto create_coro_frame_soa(Desc &&desc, size_t size, bool soa) noexcept { - return SOA{impl(), std::forward(desc), size, soa}; + [[nodiscard]] auto create_coro_frame_soa(Desc &&desc, size_t size) noexcept { + return SOA{impl(), std::forward(desc), size}; } template From 1f6c4374924269c68c698b493b92e49c0217f7c7 Mon Sep 17 00:00:00 2001 From: chenxin Date: Mon, 13 May 2024 20:41:44 +0800 Subject: [PATCH 34/67] SOA minor fix, add read_field, untested --- include/luisa/coro/v2/coro_frame_soa.h | 72 ++++++--- .../luisa/coro/v2/schedulers/state_machine.h | 16 +- include/luisa/coro/v2/schedulers/wavefront.h | 138 ++++++++++++++++-- src/ast/type.cpp | 3 + src/coro/coro_frame.cpp | 2 +- 5 files changed, 193 insertions(+), 38 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index 8fecd1d8c..cd967926a 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -17,34 +17,54 @@ namespace luisa::compute { namespace detail { -template +template class CoroFrameSOAExprProxy { - private: - T _soa; + SOAOrView _soa; public: LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(CoroFrameSOAExprProxy) public: + template + requires is_integral_expr_v + [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { + return Expr{_soa}.template read_field(std::forward(index), field_index); + } + template + requires is_integral_expr_v + [[nodiscard]] Var read_field(I &&index, luisa::string_view name) const noexcept { + return read_field(std::forward(index), _soa.desc()->designated_field(name)); + } + template requires is_integral_expr_v - [[nodiscard]] auto read( - I &&index, - luisa::optional> active_fields = luisa::nullopt) const noexcept { - return Expr{_soa}.read(std::forward(index), std::move(active_fields)); + [[nodiscard]] auto read(I &&index) const noexcept { + return Expr{_soa}.read(std::forward(index), luisa::nullopt); + } + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { + return Expr{_soa}.read(std::forward(index), luisa::make_optional(active_fields)); } + template requires is_integral_expr_v - void write( - I &&index, V &&value, - luisa::optional> active_fields = luisa::nullopt) const noexcept { - Expr{_soa}.write(std::forward(index), + void write(I &&index, V &&value) const noexcept { + Expr{_soa}.write(std::forward(index), std::forward(value), - std::move(active_fields)); + luisa::nullopt); } + template + requires is_integral_expr_v + void write(I &&index, V &&value, luisa::span active_fields) const noexcept { + Expr{_soa}.write(std::forward(index), + std::forward(value), + luisa::make_optional(active_fields)); + } + [[nodiscard]] Expr device_address() const noexcept { - return Expr{_soa}.device_address(); + return Expr{_soa}.device_address(); } }; @@ -189,11 +209,26 @@ struct Expr> : SOABase { /// Return RefExpr [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } + /// Read field at index + template + requires is_integral_expr_v + [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { + auto fb = detail::FunctionBuilder::current(); + auto field_type = _desc->type()->members()[field_index]; + auto offset = _field_offsets->at(field_index); + auto offset_var = offset + (_offset_elements + index) * field_type->size(); + auto f = fb->local(field_type); + auto s = fb->call( + field_type, CallOp::BYTE_BUFFER_READ, + {_expression, detail::extract_expression(offset_var)}); + fb->assign(f, s); + return Var(f); + } + /// Read index with active fields template [[nodiscard]] auto read( - I &&index, - luisa::optional> active_fields = luisa::nullopt) const noexcept { + I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { auto fb = detail::FunctionBuilder::current(); auto frame = fb->local(_desc->type()); auto fields = _desc->type()->members(); @@ -201,11 +236,11 @@ struct Expr> : SOABase { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto f = fb->member(field_type, frame, i); auto offset_var = offset + (_offset_elements + index) * field_type->size(); auto s = fb->call( field_type, CallOp::BYTE_BUFFER_READ, {_expression, detail::extract_expression(offset_var)}); + auto f = fb->member(field_type, frame, i); fb->assign(f, s); } return coroutine::CoroFrame{_desc, frame}; @@ -213,9 +248,8 @@ struct Expr> : SOABase { /// Write index with active fields template - void write( - I &&index, const coroutine::CoroFrame &frame, - luisa::optional> active_fields = luisa::nullopt) const noexcept { + void write(I &&index, const coroutine::CoroFrame &frame, + luisa::optional> active_fields = luisa::nullopt) const noexcept { auto fb = detail::FunctionBuilder::current(); auto fields = _desc->type()->members(); for (auto i = 0u; i < fields.size(); i++) { diff --git a/include/luisa/coro/v2/schedulers/state_machine.h b/include/luisa/coro/v2/schedulers/state_machine.h index 1e5cecc56..b62825a99 100644 --- a/include/luisa/coro/v2/schedulers/state_machine.h +++ b/include/luisa/coro/v2/schedulers/state_machine.h @@ -33,24 +33,24 @@ class StateMachineCoroScheduler : public CoroScheduler { Shader3D _shader; private: - void _create_shader(Device &device, const Coroutine &coro, + void _create_shader(Device &device, const Coroutine &coroutine, const StateMachineCoroSchedulerConfig &config) noexcept { - Kernel3D kernel = [&coro, &config](Var... args) noexcept { + Kernel3D kernel = [&coroutine, &config](Var... args) noexcept { set_block_size(config.block_size); if (config.shared_memory) { auto n = config.block_size.x * config.block_size.y * config.block_size.z; - Shared sm{coro.shared_frame(), n, config.shared_memory_soa, std::array{0u, 1u}}; + Shared sm{coroutine.shared_frame(), n, config.shared_memory_soa, std::array{0u, 1u}}; detail::coro_scheduler_state_machine_smem_impl( - sm, coro.graph(), + sm, coroutine.graph(), [&](CoroToken token, CoroFrame &frame) noexcept { - coro.subroutine(token)(frame, args...); + coroutine.subroutine(token)(frame, args...); }); } else { - auto frame = coro.instantiate(dispatch_id()); + auto frame = coroutine.instantiate(dispatch_id()); detail::coro_scheduler_state_machine_impl( - frame, coro.subroutine_count(), + frame, coroutine.subroutine_count(), [&](CoroToken token) noexcept { - coro.subroutine(token)(frame, args...); + coroutine.subroutine(token)(frame, args...); }); } }; diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index 78bb727bc..8cca49c5c 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace luisa::compute::coroutine { @@ -26,21 +27,20 @@ template class WavefrontCoroScheduler : public CoroScheduler { private: - using Container = SOA; - Shader1D, Buffer, uint, Container, uint, uint, Args...> _gen_shader; - luisa::vector, Buffer, Container, uint, Args...>> _resume_shaders; + Shader1D, Buffer, uint, uint, uint, Args...> _gen_shader; + luisa::vector, Buffer, uint, Args...>> _resume_shaders; Shader1D, Buffer, uint> _count_prefix_shader; - Shader1D, Buffer, Container, uint> _gather_shader; - Shader1D, Container, uint> _initialize_shader; - Shader1D, Container, uint, uint> _compact_shader; + Shader1D, Buffer, uint> _gather_shader; + Shader1D, uint> _initialize_shader; + Shader1D, uint, uint> _compact_shader; Shader1D, uint> _clear_shader; - Container _frame; + SOA _frame_soa; + Buffer _frame_buffer; Buffer _resume_index; Buffer _resume_count; ///offset calculate from count, will be end after gathering Buffer _resume_offset; Buffer _global_buffer; - Buffer _debug_buffer; luisa::vector _host_count; luisa::vector _host_offset; bool _host_empty; @@ -55,13 +55,131 @@ class WavefrontCoroScheduler : public CoroScheduler { Buffer _temp_index; private: - void _create_shader(Device &device, const Coroutine &coro, + void _create_shader(Device &device, const Coroutine &coroutine, const WavefrontCoroSchedulerConfig &config) noexcept { + luisa::shared_ptr desc = coroutine.shared_frame(); if (config.soa) { - _frame = device.create_coro_frame_soa(coro.shared_frame(), config.max_instance_count); + _frame_soa = device.create_coro_frame_soa(coroutine.shared_frame(), config.max_instance_count); } else { + _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), config.max_instance_count); + } + bool use_sort = config.sort || !config.hint_fields.empty(); + _max_sub_coro = coroutine->suspend_count() + 1; + _resume_index = device.create_buffer(_max_frame_count); + if (use_sort) { + _temp_index = device.create_buffer(_max_frame_count); + _temp_key[0] = device.create_buffer(_max_frame_count); + _temp_key[1] = device.create_buffer(_max_frame_count); + } + _resume_count = device.create_buffer(_max_sub_coro); + _resume_offset = device.create_buffer(_max_sub_coro); + _global_buffer = device.create_buffer(1); + _host_empty = true; + _dispatch_counter = 0; + _host_offset.resize(_max_sub_coro); + _host_count.resize(_max_sub_coro); + _have_hint.resize(_max_sub_coro, false); + for (auto &token : config.hint_fields) { + auto id = coroutine->coro_tokens().find(token); + if (id != coroutine->coro_tokens().end()) { + LUISA_ASSERT(id->second < _max_sub_coro, + "coroutine token {} of id {} out of range {}", token, id->second, _max_sub_coro); + _have_hint[id->second] = true; + } else { + LUISA_WARNING("coroutine token {} not found, hint disabled", token); + } + } + for (auto i = 0u; i < _max_sub_coro; i++) { + if (i) { + _host_count[i] = 0; + _host_offset[i] = _max_frame_count; + } else { + _host_count[i] = _max_frame_count; + _host_offset[i] = 0; + } + } + Callable get_coro_token = [&](UInt index) { + $if (index > _max_frame_count) { + device_log("index {} out of range {}", index, _max_frame_count); + }; + if (config.soa) { + return _frame_soa->read_field(index, luisa::string_view("target_token")) & token_mask; + } else { + CoroFrame frame = _frame_buffer->read(index); + return frame.get("target_token") & token_mask; + } + }; + Callable identical = [](UInt index) { + return index; + }; + Callable keep_index = [](UInt index, BufferUInt val) { + return val.read(index); + }; + Callable get_coro_hint = [&](UInt index, BufferUInt val) { + if (!config.hint_fields.empty()) { + auto id = keep_index(index, val); + CoroFrame frame = CoroFrame::create(desc); + if (config.soa) { + frame = _frame_soa->read(id, std::array{desc->designated_field("coro_hint")}); + } else { + frame = _frame_buffer->read(id); + } + return frame.get("coro_hint"); + } + return def(0u); + }; + if (use_sort) { + _sort_temp_storage = radix_sort::temp_storage( + device, _max_frame_count, std::max(std::min(config.hint_range, 128u), _max_sub_coro)); } + if (config.sort) { + _sort_token = radix_sort::instance<>( + device, _max_frame_count, _sort_temp_storage, &get_coro_token, &identical, + &get_coro_token, 1, _max_sub_coro); + } + if (!config.hint_fields.empty()) { + if (config.hint_range <= 128) { + _sort_hint = radix_sort::instance>( + device, _max_frame_count, _sort_temp_storage, &get_coro_hint, &keep_index, + &get_coro_hint, 1, config.hint_range); + } else { + auto highbit = 0; + while ((config.hint_range >> highbit) != 1) { + highbit++; + } + _sort_hint = radix_sort::instance>( + device, _max_frame_count, _sort_temp_storage, &get_coro_hint, &keep_index, + &get_coro_hint, 0, 128, 0, highbit); + } + } + Kernel1D gen_kernel = [&](BufferUInt index, BufferUInt count, UInt offset, UInt st_task_id, UInt n, Var... args) { + auto x = dispatch_x(); + $if (x >= n) { + $return(); + }; + UInt frame_id; + if (!config.compact) { + frame_id = index->read(x); + } else { + frame_id = offset + x; + } + CoroFrame frame = CoroFrame::create(desc, def(st_task_id + x, 0, 0)); + if (!config.sort) { + count.atomic(0u).fetch_add(-1u); + } + + coroutine.subroutine(0u)(frame, args...); + if (config.soa) { + _frame_soa->write(frame_id, frame, coroutine->graph().node(0u)->output_state_members); + } else { + _frame_buffer->write(frame_id, frame); + } + if (!config.sort) { + auto nxt = read_promise(frame, "coro_token") & token_mask; + count.atomic(nxt).fetch_add(1u); + } + }; } void _dispatch(Stream &stream, uint3 dispatch_size, diff --git a/src/ast/type.cpp b/src/ast/type.cpp index ceed43f72..52a6f29e4 100644 --- a/src/ast/type.cpp +++ b/src/ast/type.cpp @@ -810,18 +810,21 @@ void Type::update_from(const Type *type) { } size_t Type::add_member(const luisa::string &name) noexcept { + LUISA_ERROR_WITH_LOCATION("Deprecated."); LUISA_ASSERT(name != "coro_id" && name != "coro_token", "{} is a reserved name for coroframe type.", name); return detail::_add_member(this, name); } void Type::set_member_name(size_t index, luisa::string name) noexcept { + LUISA_ERROR_WITH_LOCATION("Deprecated."); LUISA_ASSERT(name != "coro_id" && name != "coro_token", "{} is a reserved name for coroframe type.", name); detail::_set_member_name(this, index, std::move(name)); } size_t Type::member(luisa::string_view name) const noexcept { + LUISA_ERROR_WITH_LOCATION("Deprecated."); if (name == "coro_id") return 0; if (name == "coro_token") return 1u; auto &map = static_cast(this)->member_names; diff --git a/src/coro/coro_frame.cpp b/src/coro/coro_frame.cpp index 479ea0dbc..3f8661dfd 100644 --- a/src/coro/coro_frame.cpp +++ b/src/coro/coro_frame.cpp @@ -67,4 +67,4 @@ Var CoroFrame::is_terminated() const noexcept { return (target_token & coro_token_terminal) != 0u; } -}// namespace luisa::compute::coro_v2 +}// namespace luisa::compute::coroutine From 327faf164b1eab67d74159a8dd709025defb1eba Mon Sep 17 00:00:00 2001 From: chenxin Date: Mon, 13 May 2024 22:23:56 +0800 Subject: [PATCH 35/67] checkpoint for wavefront coro scheduler --- include/luisa/coro/v2/coro_frame_soa.h | 46 ++- include/luisa/coro/v2/schedulers/wavefront.h | 280 ++++++++++++++++++- 2 files changed, 310 insertions(+), 16 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index cd967926a..54d234158 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -26,41 +26,64 @@ class CoroFrameSOAExprProxy { LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(CoroFrameSOAExprProxy) public: + /// Read field with field_index at index template requires is_integral_expr_v [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { return Expr{_soa}.template read_field(std::forward(index), field_index); } + /// Read field named with "name" at index template requires is_integral_expr_v [[nodiscard]] Var read_field(I &&index, luisa::string_view name) const noexcept { return read_field(std::forward(index), _soa.desc()->designated_field(name)); } + /// Write field with field_index at index + template + requires is_integral_expr_v + void write_field(I &&index, V &&value, uint field_index) const noexcept { + Expr{_soa}.write_field(std::forward(index), + std::forward(value), + field_index); + } + /// Write field named with "name" at index + template + requires is_integral_expr_v + void write_field(I &&index, V &&value, luisa::string_view name) const noexcept { + write_field(std::forward(index), + std::forward(value), + _soa.desc()->designated_field(name)); + } + + /// Read index template requires is_integral_expr_v [[nodiscard]] auto read(I &&index) const noexcept { return Expr{_soa}.read(std::forward(index), luisa::nullopt); } + /// Read index with active fields template requires is_integral_expr_v [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { return Expr{_soa}.read(std::forward(index), luisa::make_optional(active_fields)); } + /// Write index template requires is_integral_expr_v void write(I &&index, V &&value) const noexcept { Expr{_soa}.write(std::forward(index), - std::forward(value), - luisa::nullopt); + std::forward(value), + luisa::nullopt); } + /// Write index with active fields template requires is_integral_expr_v void write(I &&index, V &&value, luisa::span active_fields) const noexcept { Expr{_soa}.write(std::forward(index), - std::forward(value), - luisa::make_optional(active_fields)); + std::forward(value), + luisa::make_optional(active_fields)); } [[nodiscard]] Expr device_address() const noexcept { @@ -209,7 +232,7 @@ struct Expr> : SOABase { /// Return RefExpr [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } - /// Read field at index + /// Read field with field_index at index template requires is_integral_expr_v [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { @@ -225,6 +248,19 @@ struct Expr> : SOABase { return Var(f); } + /// Write field with field_index at index + template + requires is_integral_expr_v + void write_field(I &&index, V &&value, uint field_index) const noexcept { + auto fb = detail::FunctionBuilder::current(); + auto field_type = _desc->type()->members()[field_index]; + auto offset = _field_offsets->at(field_index); + auto offset_var = offset + (_offset_elements + index) * field_type->size(); + auto s = fb->call( + field_type, CallOp::BYTE_BUFFER_WRITE, + {_expression, detail::extract_expression(offset_var), detail::extract_expression(value)}); + } + /// Read index with active fields template [[nodiscard]] auto read( diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index 8cca49c5c..717b4cbac 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -27,6 +27,7 @@ template class WavefrontCoroScheduler : public CoroScheduler { private: + WavefrontCoroSchedulerConfig _config; Shader1D, Buffer, uint, uint, uint, Args...> _gen_shader; luisa::vector, Buffer, uint, Args...>> _resume_shaders; Shader1D, Buffer, uint> _count_prefix_shader; @@ -55,8 +56,30 @@ class WavefrontCoroScheduler : public CoroScheduler { Buffer _temp_index; private: + void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept override { + LUISA_ASSERT(dispatch_size.y == 1u && dispatch_size.z == 1u, + "WavefrontCoroScheduler only supports 1D dispatch for now."); + _config.block_size.x = dispatch_size.x; + _dispatch_counter = 0; + _host_empty = true; + for (auto i = 0u; i < _max_sub_coro; i++) { + if (i) { + _host_count[i] = 0; + _host_offset[i] = _max_frame_count; + } else { + _host_count[i] = _max_frame_count; + _host_offset[i] = 0; + } + } + stream << _initialize_shader(_resume_count, _max_frame_count).dispatch(_max_frame_count); + + // TODO + } + void _create_shader(Device &device, const Coroutine &coroutine, const WavefrontCoroSchedulerConfig &config) noexcept { + _config = config; luisa::shared_ptr desc = coroutine.shared_frame(); if (config.soa) { _frame_soa = device.create_coro_frame_soa(coroutine.shared_frame(), config.max_instance_count); @@ -100,10 +123,10 @@ class WavefrontCoroScheduler : public CoroScheduler { } Callable get_coro_token = [&](UInt index) { $if (index > _max_frame_count) { - device_log("index {} out of range {}", index, _max_frame_count); + device_log("Index out of range {}/{}", index, _max_frame_count); }; if (config.soa) { - return _frame_soa->read_field(index, luisa::string_view("target_token")) & token_mask; + return _frame_soa->read_field(index, "target_token") & token_mask; } else { CoroFrame frame = _frame_buffer->read(index); return frame.get("target_token") & token_mask; @@ -119,13 +142,12 @@ class WavefrontCoroScheduler : public CoroScheduler { Callable get_coro_hint = [&](UInt index, BufferUInt val) { if (!config.hint_fields.empty()) { auto id = keep_index(index, val); - CoroFrame frame = CoroFrame::create(desc); if (config.soa) { - frame = _frame_soa->read(id, std::array{desc->designated_field("coro_hint")}); + return _frame_soa->read_field(id, "coro_hint"); } else { - frame = _frame_buffer->read(id); + CoroFrame frame = _frame_buffer->read(id); + return frame.get("coro_hint"); } - return frame.get("coro_hint"); } return def(0u); }; @@ -164,11 +186,11 @@ class WavefrontCoroScheduler : public CoroScheduler { } else { frame_id = offset + x; } - CoroFrame frame = CoroFrame::create(desc, def(st_task_id + x, 0, 0)); if (!config.sort) { count.atomic(0u).fetch_add(-1u); } + CoroFrame frame = CoroFrame::create(desc, def(st_task_id + x, 0, 0)); coroutine.subroutine(0u)(frame, args...); if (config.soa) { _frame_soa->write(frame_id, frame, coroutine->graph().node(0u)->output_state_members); @@ -176,15 +198,251 @@ class WavefrontCoroScheduler : public CoroScheduler { _frame_buffer->write(frame_id, frame); } if (!config.sort) { - auto nxt = read_promise(frame, "coro_token") & token_mask; + auto nxt = frame.get("coro_hint") & token_mask; count.atomic(nxt).fetch_add(1u); } }; + ShaderOption o{}; + _gen_shader = device.compile(gen_kernel, o); + _gen_shader.set_name("gen"); + _resume_shaders.resize(_max_sub_coro); + + for (auto i = 1u; i < _max_sub_coro; ++i) { + Kernel1D resume_kernel = [&](BufferUInt index, BufferUInt count, UInt n, Var... args) { + auto x = dispatch_x(); + $if (x >= n) { + $return(); + }; + auto frame_id = index.read(x); + CoroFrame frame = CoroFrame::create(desc); + if (config.soa) { + //frame = frame_buffer.read(frame_id); + frame = _frame_soa->read(frame_id, coroutine->graph().node(i)->input_state_members); + } else { + frame = _frame_buffer->read(frame_id); + } + if (!config.sort) { + count.atomic(i).fetch_add(-1u); + } + coroutine.subroutine(i)(frame, args...); + if (config.soa) { + _frame_soa->write(frame_id, frame, coroutine->graph().node(i)->output_state_members); + } else { + _frame_buffer->write(frame_id, frame); + } + + if (!config.sort) { + auto nxt = frame.get("target_token") & token_mask; + $if (nxt < _max_sub_coro) { + count.atomic(nxt).fetch_add(1u); + }; + } + }; + _resume_shaders[i] = device.compile(resume_kernel, o); + _resume_shaders[i].set_name("resume" + std::to_string(i)); + } + + Kernel1D _prefix_kernel = [&](BufferUInt count, BufferUInt prefix, UInt n) { + $if (dispatch_x() == 0) { + auto pre = def(0u); + for (auto i = 0u; i < _max_sub_coro; ++i) { + auto val = count.read(i); + prefix.write(def(i), pre); + pre = pre + val; + } + }; + }; + _count_prefix_shader = device.compile(_prefix_kernel); + + Kernel1D _gather_kernel = [&](BufferUInt index, BufferUInt prefix, UInt n) { + auto x = dispatch_x(); + auto r_id = def(0u); + if (config.soa) { + r_id = _frame_soa->read_field(x, "target_token") & token_mask; + } else { + auto frame = _frame_buffer->read(x); + r_id = frame.get("target_token") & token_mask; + } + auto q_id = prefix.atomic(r_id).fetch_add(1u); + index.write(q_id, x); + }; + _gather_shader = device.compile(_gather_kernel); + + Kernel1D _compact_kernel_2 = [&](BufferUInt index, UInt empty_offset, UInt n) { + //_global_buffer->write(0u, 0u); + auto x = dispatch_x(); + $if (empty_offset + x < n) { + auto token = def(0u); + if (config.soa) { + token = _frame_soa->read_field(empty_offset + x, "target_token"); + } else { + CoroFrame frame = _frame_buffer->read(empty_offset + x); + token = frame.get("target_token"); + } + $if ((token & token_mask) != 0u) { + auto res = _global_buffer->atomic(0u).fetch_add(1u); + auto slot = index.read(res); + if (!config.sort) { + $while (slot >= empty_offset) { + res = _global_buffer->atomic(0u).fetch_add(1u); + slot = index.read(res); + }; + } + if (config.soa) { + // TODO: active fields here? + auto frame = _frame_soa->read(empty_offset + x); + _frame_soa->write(slot, frame); + } else { + auto frame = _frame_buffer->read(empty_offset + x); + _frame_buffer->write(slot, frame); + } + if (config.soa) { + _frame_soa->write_field(empty_offset + x, 0u, "target_token"); + } else { + CoroFrame empty_frame = CoroFrame::create(desc); + _frame_buffer->write(empty_offset + x, empty_frame); + } + }; + }; + }; + _compact_shader = device.compile(_compact_kernel_2); + _compact_shader.set_name("compact"); + + Kernel1D _initialize_kernel = [&](BufferUInt count, UInt n) { + auto x = dispatch_x(); + $if (x < n) { + if (config.soa) { + CoroFrame frame = coroutine.instantiate(dispatch_id()); + _frame_soa->write(x, frame, std::array{0u, 1u}); + } else { + CoroFrame frame = coroutine.instantiate(dispatch_id()); + _frame_buffer->write(x, frame); + } + }; + $if (x < _max_sub_coro) { + count.write(x, ite(x == 0u, _max_frame_count, 0u)); + }; + }; + Kernel1D clear = [&](BufferUInt buffer, UInt n) { + auto x = dispatch_x(); + $if (x < n) { + buffer.write(x, 0u); + }; + }; + _clear_shader = device.compile(clear); + _initialize_shader = device.compile(_initialize_kernel); } - void _dispatch(Stream &stream, uint3 dispatch_size, - compute::detail::prototype_to_shader_invocation_t... args) noexcept override { - LUISA_ERROR_WITH_LOCATION("Unimplemented"); + [[nodiscard]] bool _all_dispatched() const noexcept { + return _dispatch_counter == _max_frame_count; + } + [[nodiscard]] bool _all_done() const noexcept { + return _host_empty && _all_dispatched(); + } + + void _await_all(Stream &stream) noexcept { + while (!_all_done()) { + _await_step(stream); + } + } + void _await_step(Stream &stream) noexcept { + if (_config.sort) { + auto host_update = [&] { + _host_empty = true; + for (uint i = 0u; i < _max_sub_coro; i++) { + _host_count[i] = (i + 1u == _max_sub_coro ? _max_frame_count : _host_offset[i + 1u]) - _host_offset[i]; + _host_empty = _host_empty && (i == 0u || _host_count[i] == 0u); + } + }; + _sort_token.sort(stream, _temp_key[0], _resume_index, _temp_key[1], + _resume_index, _max_frame_count); + + stream << _sort_temp_storage.hist_buffer.view(0u, _max_sub_coro).copy_to(_host_offset.data()) + << host_update + << synchronize(); + + if (_host_count[0] > _max_frame_count * (0.5) && !_all_dispatched()) { + auto gen_count = std::min(_config.block_size.x - _dispatch_counter, _host_count[0]); + if (_host_count[0] != _max_frame_count && _config.compact) { + stream << _clear_shader(_global_buffer, 1).dispatch(1u); + stream << _compact_shader(_resume_index, _max_frame_count - _host_count[0], _max_frame_count).dispatch(_host_count[0]); + } + stream << this->template call_shader<1, Buffer, Buffer, uint, uint, uint>( + _gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), + _resume_count, _max_frame_count - _host_count[0], _dispatch_counter, _max_frame_count) + .dispatch(gen_count); + _dispatch_counter += gen_count; + _host_empty = false; + } else { + for (uint i = 1; i < _max_sub_coro; i++) { + if (_host_count[i] > 0) { + if (_have_hint[i]) { + BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; + BufferView _key[2] = {_temp_key[1].view(_host_offset[i], _host_count[i]), _temp_key[0].view(_host_offset[i], _host_count[i])}; + uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); + stream << this->template call_shader<1, Buffer, Buffer, uint>( + _resume_shaders[i], _index[out], _resume_count, _max_frame_count) + .dispatch(_host_count[i]); + } else { + + stream << this->template call_shader<1, Buffer, Buffer, uint>( + _resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), _resume_count, _max_frame_count) + .dispatch(_host_count[i]); + } + } + } + } + stream << synchronize(); + } else { + stream << _count_prefix_shader(_resume_count, _resume_offset, _max_sub_coro).dispatch(1u); + stream << _gather_shader(_resume_index, _resume_offset, _max_frame_count).dispatch(_max_frame_count); + if (_host_count[0] > _max_frame_count / 2 && !_all_dispatched()) { + auto gen_count = std::min(_config.block_size.x - _dispatch_counter, _host_count[0]); + if (_host_count[0] != _max_frame_count && _config.compact) { + stream << _clear_shader(_global_buffer, 1).dispatch(1u); + stream + << _compact_shader(_resume_index.view(_host_offset[0], _host_count[0]), + _max_frame_count - _host_count[0], _max_frame_count) + .dispatch(_host_count[0]); + } + stream << this->template call_shader<1, Buffer, Buffer, uint, uint, uint>( + _gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), _resume_count, + _max_frame_count - _host_count[0], _dispatch_counter, _max_frame_count) + .dispatch(gen_count); + _dispatch_counter += gen_count; + _host_empty = false; + } else { + for (uint i = 1; i < _max_sub_coro; i++) { + if (_host_count[i] > 0) { + if (_have_hint[i]) { + BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; + BufferView _key[2] = {_temp_key[0].view(_host_offset[i], _host_count[i]), _temp_key[1].view(_host_offset[i], _host_count[i])}; + uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); + stream << this->template call_shader<1, Buffer, Buffer, uint>( + _resume_shaders[i], _index[out], _resume_count, _max_frame_count) + .dispatch(_host_count[i]); + } else { + stream << this->template call_shader<1, Buffer, Buffer, uint>( + _resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), + _resume_count, _max_frame_count) + .dispatch(_host_count[i]); + } + } + } + } + auto host_update = [&] { + _host_empty = true; + auto sum = 0u; + for (uint i = 0; i < _max_sub_coro; i++) { + _host_offset[i] = sum; + sum += _host_count[i]; + _host_empty = _host_empty && (i == 0 || _host_count[i] == 0); + } + }; + stream << _resume_count.view(0, _max_sub_coro).copy_to(_host_count.data()) + << host_update; + stream << synchronize(); + } } public: From 3a051128bdce56a021c8f8568dffe3766d70500c Mon Sep 17 00:00:00 2001 From: chenxin Date: Mon, 13 May 2024 23:56:14 +0800 Subject: [PATCH 36/67] checkpoint for wavefront coro scheduler --- include/luisa/coro/coro_dispatcher.h | 120 +++++++++---------- include/luisa/coro/v2/coro_frame.h | 2 +- include/luisa/coro/v2/coro_frame_soa.h | 17 ++- include/luisa/coro/v2/schedulers/wavefront.h | 72 +++++------ include/luisa/dsl/resource.h | 4 +- include/luisa/runtime/byte_buffer.h | 2 + include/luisa/runtime/device.h | 11 +- src/tests/coro/path_tracing_wavefront_v2.cpp | 3 +- 8 files changed, 120 insertions(+), 111 deletions(-) diff --git a/include/luisa/coro/coro_dispatcher.h b/include/luisa/coro/coro_dispatcher.h index 1e9f431b3..f4695637c 100644 --- a/include/luisa/coro/coro_dispatcher.h +++ b/include/luisa/coro/coro_dispatcher.h @@ -141,7 +141,7 @@ struct WavefrontCoroDispatcherConfig { bool sort = true;//use sort for coro token gathering bool compact = true; bool debug = false; - uint hint_range=0xffff'ffff; + uint hint_range = 0xffff'ffff; luisa::vector hint_fields; }; @@ -187,8 +187,8 @@ class WavefrontCoroDispatcher : public CoroDispatcherBase _temp_index; Stream &_stream; public: - bool all_dispatched() const noexcept; - bool all_done() const noexcept; + [[nodiscard]] bool all_dispatched() const noexcept; + [[nodiscard]] bool all_done() const noexcept; WavefrontCoroDispatcher(Coroutine *coroutine, Device &device, Stream &stream, @@ -205,7 +205,7 @@ class WavefrontCoroDispatcher : public CoroDispatcherBasesuspend_count() + 1; _max_sub_coro = max_sub_coro; _resume_index = device.create_buffer(_max_frame_count); @@ -228,10 +228,9 @@ class WavefrontCoroDispatcher : public CoroDispatcherBasecoro_tokens().find(token); if (id != coroutine->coro_tokens().end()) { - LUISA_ASSERT(id->secondsecond,max_sub_coro); + LUISA_ASSERT(id->second < max_sub_coro, "coroutine token {} of id {} out of range {}", token, id->second, max_sub_coro); _have_hint[id->second] = true; - } - else + } else LUISA_WARNING("coroutine token {} not found, hint disabled", token); } for (auto i = 0u; i < max_sub_coro; i++) { @@ -254,34 +253,33 @@ class WavefrontCoroDispatcher : public CoroDispatcherBaseread(id); - auto x=read_promise(frame, "coro_hint"); + auto x = read_promise(frame, "coro_hint"); return x; } else { return def(0u); } }; if (use_sort) { - _sort_temp_storage = radix_sort::temp_storage(device, _max_frame_count, std::max(std::min(config.hint_range,128u), max_sub_coro)); + _sort_temp_storage = radix_sort::temp_storage(device, _max_frame_count, std::max(std::min(config.hint_range, 128u), max_sub_coro)); } if (sort_base_gather) { _sort_token = radix_sort::instance<>(device, _max_frame_count, _sort_temp_storage, - &get_coro_token, &identical, &get_coro_token, 1, max_sub_coro); + &get_coro_token, &identical, &get_coro_token, 1, max_sub_coro); } if (!config.hint_fields.empty()) { - if(config.hint_range<=128){ + if (config.hint_range <= 128) { _sort_hint = radix_sort::instance>(device, _max_frame_count, _sort_temp_storage, - &get_coro_hint, &keep_index, &get_coro_hint, 1, config.hint_range); - } - else{ - auto highbit=0; - while((config.hint_range>>highbit)!=1){ + &get_coro_hint, &keep_index, &get_coro_hint, 1, config.hint_range); + } else { + auto highbit = 0; + while ((config.hint_range >> highbit) != 1) { highbit++; } _sort_hint = radix_sort::instance>(device, _max_frame_count, _sort_temp_storage, @@ -528,19 +526,19 @@ class PersistentCoroDispatcher : public CoroDispatcherBase{coroutine, device}, - _max_thread_count{(config.max_thread_count+config.block_size-1)/config.block_size*config.block_size}, + _max_thread_count{(config.max_thread_count + config.block_size - 1) / config.block_size * config.block_size}, _block_size{config.block_size}, _debug{config.debug}, _stream{stream} { - auto use_global=config.global; + auto use_global = config.global; _global = device.create_buffer(1); auto q_fac = 1u; uint max_sub_coro = coroutine->suspend_count() + 1; - auto g_fac=(uint)std::max((int)(max_sub_coro-q_fac),0); - auto global_queue_size= config.block_size * g_fac; - _global_size=0; - if(use_global) { - _global_frame = device.create_buffer(_max_thread_count*g_fac); - _global_size = _max_thread_count*g_fac; + auto g_fac = (uint)std::max((int)(max_sub_coro - q_fac), 0); + auto global_queue_size = config.block_size * g_fac; + _global_size = 0; + if (use_global) { + _global_frame = device.create_buffer(_max_thread_count * g_fac); + _global_size = _max_thread_count * g_fac; } _max_sub_coro = max_sub_coro; _dispatched = false; @@ -552,7 +550,7 @@ class PersistentCoroDispatcher : public CoroDispatcherBase path_id{shared_queue_size}; Shared work_counter{max_sub_coro}; Shared work_offset{2u}; - Shared all_token{use_global?(shared_queue_size+global_queue_size):shared_queue_size}; + Shared all_token{use_global ? (shared_queue_size + global_queue_size) : shared_queue_size}; Shared workload{2}; Shared work_stat{2};//0 max_count,1 max_id //Shared tag_counter{use_tag_sort ? pipeline().surfaces().size() : 0}; @@ -562,11 +560,11 @@ class PersistentCoroDispatcher : public CoroDispatcherBase(0, 0, 0)); }; $for (index, 0u, g_fac) { - all_token[shared_queue_size+index * config.block_size + thread_x()] = 0u; + all_token[shared_queue_size + index * config.block_size + thread_x()] = 0u; }; $if (thread_x() < max_sub_coro) { $if (thread_x() == 0) { - work_counter[thread_x()] = use_global?(shared_queue_size+global_queue_size):shared_queue_size; + work_counter[thread_x()] = use_global ? (shared_queue_size + global_queue_size) : shared_queue_size; } $else { work_counter[thread_x()] = 0u; @@ -638,7 +636,7 @@ class PersistentCoroDispatcher : public CoroDispatcherBaseread(global_id); - _global_frame->write(global_id,frames[dst]); - frames[dst]=g_state; - all_token[shared_queue_size+g_queue_id]=frame_token; - all_token[dst]=coro_token; - + _global_frame->write(global_id, frames[dst]); + frames[dst] = g_state; + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; } $else { auto g_state = _global_frame->read(global_id); - frames[dst]=g_state; - all_token[shared_queue_size+g_queue_id]=frame_token; - all_token[dst]=coro_token; + frames[dst] = g_state; + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; }; } $else { - $if(frame_token != 0u) { - _global_frame->write(global_id,frames[dst]); - all_token[shared_queue_size+g_queue_id]=frame_token; - all_token[dst]=coro_token; + $if (frame_token != 0u) { + _global_frame->write(global_id, frames[dst]); + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; }; }; }; @@ -704,9 +701,8 @@ class PersistentCoroDispatcher : public CoroDispatcherBase(gen_st + thread_x(), 0, 0)); (*coroutine)(frames[pid], args...);//only work when kernel 0s are continue auto nxt = read_promise(frames[pid], "coro_token") & token_mask; - all_token[pid]=nxt; + all_token[pid] = nxt; work_counter.atomic(nxt).fetch_add(1u); workload.atomic(0).fetch_add(1u); }; @@ -727,7 +723,7 @@ class PersistentCoroDispatcher : public CoroDispatcherBase(frames[pid], "coro_token") & token_mask; - all_token[pid]=nxt; + all_token[pid] = nxt; work_counter.atomic(nxt).fetch_add(1u); }; } @@ -751,7 +747,7 @@ class PersistentCoroDispatcher : public CoroDispatcherBase(); initialize_coroframe(frame, def(0, 0, 0)); - frame_buffer.write(x,frame); + frame_buffer.write(x, frame); }; }; _clear_shader = device.compile(clear); @@ -763,7 +759,7 @@ class PersistentCoroDispatcher : public CoroDispatcherBase::_await_step(Stream &stream) noe if (_have_hint[i]) { BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; BufferView _key[2] = {_temp_key[1].view(_host_offset[i], _host_count[i]), _temp_key[0].view(_host_offset[i], _host_count[i])}; - uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i],_resume_index.view(_host_offset[i], _host_count[i])); + uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _index[out], _resume_count, _frame, _max_frame_count) .dispatch(_host_count[i]); @@ -924,7 +920,7 @@ void WavefrontCoroDispatcher::_await_step(Stream &stream) noe if (_have_hint[i]) { BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; BufferView _key[2] = {_temp_key[0].view(_host_offset[i], _host_count[i]), _temp_key[1].view(_host_offset[i], _host_count[i])}; - uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i],_resume_index.view(_host_offset[i], _host_count[i])); + uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _index[out], _resume_count, _frame, _max_frame_count) .dispatch(_host_count[i]); diff --git a/include/luisa/coro/v2/coro_frame.h b/include/luisa/coro/v2/coro_frame.h index c258b11ff..0eb374bb4 100644 --- a/include/luisa/coro/v2/coro_frame.h +++ b/include/luisa/coro/v2/coro_frame.h @@ -42,7 +42,7 @@ class LC_CORO_API CoroFrame { template [[nodiscard]] Var &get(uint index) noexcept { _check_member_index(index); - auto fb = detail::FunctionBuilder::current(); + auto fb = luisa::compute::detail::FunctionBuilder::current(); auto member = fb->member(_desc->type()->members()[index], _expression, index); return *fb->create_temporary>(member); } diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index 54d234158..ef8cfb7af 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -9,7 +9,9 @@ #include #include +#include #include +#include #include @@ -105,6 +107,7 @@ struct SOABase { public: SOABase() noexcept = default; SOABase(SOABase &&) noexcept = default; + SOABase(const SOABase &) noexcept = default; SOABase(luisa::shared_ptr desc, luisa::shared_ptr> field_offsets, size_t offset_elements, size_t size_elements) noexcept @@ -133,6 +136,8 @@ class SOAView : public SOABase { requires std::same_as> SOAView(const T &soa) noexcept : SOAView{soa.view()} {} + SOAView(const SOAView &) noexcept = default; + SOAView(SOAView &&) noexcept = default; ~SOAView() noexcept = default; [[nodiscard]] auto subview(uint offset_elements, uint size_elements) noexcept { @@ -158,7 +163,7 @@ class SOA : public SOABase { ByteBuffer _buffer; public: - SOA(DeviceInterface *device, luisa::shared_ptr desc, uint n) noexcept + SOA(DeviceInterface *device, luisa::shared_ptr desc, size_t n) noexcept : SOABase{std::move(desc), luisa::make_shared>(), 0u, n} { size_t size_bytes = 0u; @@ -169,7 +174,7 @@ class SOA : public SOABase { size_bytes = (size_bytes + field->alignment() - 1u) & ~(field->alignment() - 1u); _field_offsets->emplace_back(size_bytes); if (field->size() % field->alignment() != 0u) [[unlikely]] { - detail::error_buffer_invalid_alignment(size_bytes + field->size(), field->alignment()); + luisa::compute::detail::error_buffer_invalid_alignment(size_bytes + field->size(), field->alignment()); } size_bytes += field->size() * _size_elements; } @@ -239,7 +244,7 @@ struct Expr> : SOABase { auto fb = detail::FunctionBuilder::current(); auto field_type = _desc->type()->members()[field_index]; auto offset = _field_offsets->at(field_index); - auto offset_var = offset + (_offset_elements + index) * field_type->size(); + auto offset_var = offset + (_offset_elements + ULong(index)) * field_type->size(); auto f = fb->local(field_type); auto s = fb->call( field_type, CallOp::BYTE_BUFFER_READ, @@ -255,7 +260,7 @@ struct Expr> : SOABase { auto fb = detail::FunctionBuilder::current(); auto field_type = _desc->type()->members()[field_index]; auto offset = _field_offsets->at(field_index); - auto offset_var = offset + (_offset_elements + index) * field_type->size(); + auto offset_var = offset + (_offset_elements + ULong(index)) * field_type->size(); auto s = fb->call( field_type, CallOp::BYTE_BUFFER_WRITE, {_expression, detail::extract_expression(offset_var), detail::extract_expression(value)}); @@ -272,7 +277,7 @@ struct Expr> : SOABase { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto offset_var = offset + (_offset_elements + index) * field_type->size(); + auto offset_var = offset + (_offset_elements + ULong(index)) * field_type->size(); auto s = fb->call( field_type, CallOp::BYTE_BUFFER_READ, {_expression, detail::extract_expression(offset_var)}); @@ -292,7 +297,7 @@ struct Expr> : SOABase { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto offset_var = offset + (_offset_elements + index) * field_type->size(); + auto offset_var = offset + (_offset_elements + ULong(index)) * field_type->size(); auto f = fb->member(field_type, frame.expression(), i); auto s = fb->call( field_type, CallOp::BYTE_BUFFER_WRITE, diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index 717b4cbac..ed8db68eb 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -28,6 +28,7 @@ class WavefrontCoroScheduler : public CoroScheduler { private: WavefrontCoroSchedulerConfig _config; + luisa::optional...>> _args; Shader1D, Buffer, uint, uint, uint, Args...> _gen_shader; luisa::vector, Buffer, uint, Args...>> _resume_shaders; Shader1D, Buffer, uint> _count_prefix_shader; @@ -73,21 +74,21 @@ class WavefrontCoroScheduler : public CoroScheduler { } } stream << _initialize_shader(_resume_count, _max_frame_count).dispatch(_max_frame_count); - - // TODO + _args = std::make_tuple(std::forward>(args)...); + _await_all(stream); } void _create_shader(Device &device, const Coroutine &coroutine, const WavefrontCoroSchedulerConfig &config) noexcept { _config = config; - luisa::shared_ptr desc = coroutine.shared_frame(); + const luisa::shared_ptr desc = coroutine.shared_frame(); if (config.soa) { - _frame_soa = device.create_coro_frame_soa(coroutine.shared_frame(), config.max_instance_count); + _frame_soa = device.create_soa(coroutine.shared_frame(), config.max_instance_count); } else { _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), config.max_instance_count); } bool use_sort = config.sort || !config.hint_fields.empty(); - _max_sub_coro = coroutine->suspend_count() + 1; + _max_sub_coro = coroutine.subroutine_count() + 1; _resume_index = device.create_buffer(_max_frame_count); if (use_sort) { _temp_index = device.create_buffer(_max_frame_count); @@ -103,8 +104,8 @@ class WavefrontCoroScheduler : public CoroScheduler { _host_count.resize(_max_sub_coro); _have_hint.resize(_max_sub_coro, false); for (auto &token : config.hint_fields) { - auto id = coroutine->coro_tokens().find(token); - if (id != coroutine->coro_tokens().end()) { + auto id = coroutine.frame()->designated_fields().find(token); + if (id != coroutine.frame()->designated_fields().end()) { LUISA_ASSERT(id->second < _max_sub_coro, "coroutine token {} of id {} out of range {}", token, id->second, _max_sub_coro); _have_hint[id->second] = true; @@ -191,9 +192,9 @@ class WavefrontCoroScheduler : public CoroScheduler { } CoroFrame frame = CoroFrame::create(desc, def(st_task_id + x, 0, 0)); - coroutine.subroutine(0u)(frame, args...); + coroutine[0u](frame, args...); if (config.soa) { - _frame_soa->write(frame_id, frame, coroutine->graph().node(0u)->output_state_members); + _frame_soa->write(frame_id, frame, coroutine.graph()->node(0u).output_fields()); } else { _frame_buffer->write(frame_id, frame); } @@ -217,16 +218,16 @@ class WavefrontCoroScheduler : public CoroScheduler { CoroFrame frame = CoroFrame::create(desc); if (config.soa) { //frame = frame_buffer.read(frame_id); - frame = _frame_soa->read(frame_id, coroutine->graph().node(i)->input_state_members); + frame = _frame_soa->read(frame_id, coroutine.graph()->node(i).input_fields()); } else { frame = _frame_buffer->read(frame_id); } if (!config.sort) { count.atomic(i).fetch_add(-1u); } - coroutine.subroutine(i)(frame, args...); + coroutine[i](frame, args...); if (config.soa) { - _frame_soa->write(frame_id, frame, coroutine->graph().node(i)->output_state_members); + _frame_soa->write(frame_id, frame, coroutine.graph()->node(i).output_fields()); } else { _frame_buffer->write(frame_id, frame); } @@ -367,10 +368,10 @@ class WavefrontCoroScheduler : public CoroScheduler { stream << _clear_shader(_global_buffer, 1).dispatch(1u); stream << _compact_shader(_resume_index, _max_frame_count - _host_count[0], _max_frame_count).dispatch(_host_count[0]); } - stream << this->template call_shader<1, Buffer, Buffer, uint, uint, uint>( - _gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), - _resume_count, _max_frame_count - _host_count[0], _dispatch_counter, _max_frame_count) - .dispatch(gen_count); + auto invoke = _gen_shader(_resume_index.view(_host_offset[0], _host_count[0]), + _resume_count, _max_frame_count - _host_count[0], _dispatch_counter, + _max_frame_count); + stream << _invoke_args(invoke, _args).dispatch(gen_count); _dispatch_counter += gen_count; _host_empty = false; } else { @@ -380,14 +381,12 @@ class WavefrontCoroScheduler : public CoroScheduler { BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; BufferView _key[2] = {_temp_key[1].view(_host_offset[i], _host_count[i]), _temp_key[0].view(_host_offset[i], _host_count[i])}; uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); - stream << this->template call_shader<1, Buffer, Buffer, uint>( - _resume_shaders[i], _index[out], _resume_count, _max_frame_count) - .dispatch(_host_count[i]); + auto invoke = _resume_shaders[i](_index[out], _resume_count, _max_frame_count); + stream << _invoke_args(invoke, _args).dispatch(_host_count[i]); } else { - - stream << this->template call_shader<1, Buffer, Buffer, uint>( - _resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), _resume_count, _max_frame_count) - .dispatch(_host_count[i]); + auto invoke = _resume_shaders[i](_resume_index.view(_host_offset[i], _host_count[i]), + _resume_count, _max_frame_count); + stream << _invoke_args(invoke, _args).dispatch(_host_count[i]); } } } @@ -405,10 +404,10 @@ class WavefrontCoroScheduler : public CoroScheduler { _max_frame_count - _host_count[0], _max_frame_count) .dispatch(_host_count[0]); } - stream << this->template call_shader<1, Buffer, Buffer, uint, uint, uint>( - _gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), _resume_count, - _max_frame_count - _host_count[0], _dispatch_counter, _max_frame_count) - .dispatch(gen_count); + auto invoke = _gen_shader(_resume_index.view(_host_offset[0], _host_count[0]), + _resume_count, _max_frame_count - _host_count[0], _dispatch_counter, + _max_frame_count); + stream << _invoke_args(invoke, _args).dispatch(gen_count); _dispatch_counter += gen_count; _host_empty = false; } else { @@ -418,14 +417,12 @@ class WavefrontCoroScheduler : public CoroScheduler { BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; BufferView _key[2] = {_temp_key[0].view(_host_offset[i], _host_count[i]), _temp_key[1].view(_host_offset[i], _host_count[i])}; uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); - stream << this->template call_shader<1, Buffer, Buffer, uint>( - _resume_shaders[i], _index[out], _resume_count, _max_frame_count) - .dispatch(_host_count[i]); + auto invoke = _resume_shaders[i](_index[out], _resume_count, _max_frame_count); + stream << _invoke_args(invoke, _args).dispatch(_host_count[i]); } else { - stream << this->template call_shader<1, Buffer, Buffer, uint>( - _resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _max_frame_count) - .dispatch(_host_count[i]); + auto invoke = _resume_shaders[i](_resume_index.view(_host_offset[i], _host_count[i]), + _resume_count, _max_frame_count); + stream << _invoke_args(invoke, _args).dispatch(_host_count[i]); } } } @@ -444,6 +441,13 @@ class WavefrontCoroScheduler : public CoroScheduler { stream << synchronize(); } } + template + [[nodiscard]] auto _invoke_args(compute::detail::ShaderInvoke &invoke) noexcept { + std::apply([&](auto &&...args) { + static_cast((invoke << ... << args)); + }, *_args); + return invoke; + } public: WavefrontCoroScheduler(Device &device, const Coroutine &coro, diff --git a/include/luisa/dsl/resource.h b/include/luisa/dsl/resource.h index b7a3a6ae9..29fbe570e 100644 --- a/include/luisa/dsl/resource.h +++ b/include/luisa/dsl/resource.h @@ -558,10 +558,10 @@ class ByteBufferExprProxy { LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(ByteBufferExprProxy) public: - template + template requires is_integral_expr_v [[nodiscard]] auto read(I &&index) const noexcept { - return Expr{_buffer}.read(std::forward(index)); + return Expr{_buffer}.template read(std::forward(index)); } template requires is_integral_expr_v diff --git a/include/luisa/runtime/byte_buffer.h b/include/luisa/runtime/byte_buffer.h index 3e1eb10d0..69e82d146 100644 --- a/include/luisa/runtime/byte_buffer.h +++ b/include/luisa/runtime/byte_buffer.h @@ -104,6 +104,8 @@ class ByteBufferView { _size{size}, _total_size{total_size} {} ByteBufferView(const ByteBuffer &buffer) noexcept : ByteBufferView{buffer.view()} {} + ByteBufferView(const ByteBufferView &) noexcept = default; + ByteBufferView(ByteBufferView &&) noexcept = default; ByteBufferView() noexcept : ByteBufferView{nullptr, invalid_resource_handle, 0u, 0u, 0u} {} [[nodiscard]] explicit operator bool() const noexcept { return _handle != invalid_resource_handle; } diff --git a/include/luisa/runtime/device.h b/include/luisa/runtime/device.h index 6e157b100..1d24f1861 100644 --- a/include/luisa/runtime/device.h +++ b/include/luisa/runtime/device.h @@ -260,11 +260,12 @@ class LC_RUNTIME_API Device { return SOA{*this, size}; } - template - requires std::same_as, luisa::shared_ptr> || - std::same_as, const coroutine::CoroFrameDesc *> - [[nodiscard]] auto create_coro_frame_soa(Desc &&desc, size_t size) noexcept { - return SOA{impl(), std::forward(desc), size}; + template + requires(std::same_as, luisa::shared_ptr> || + std::same_as, const coroutine::CoroFrameDesc *>) && + std::same_as + [[nodiscard]] auto create_soa(Desc &&desc, size_t size) noexcept { + return SOA{impl(), std::forward(desc), size}; } template diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp index 9badcba39..7fc44575f 100644 --- a/src/tests/coro/path_tracing_wavefront_v2.cpp +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -301,7 +301,8 @@ int main(int argc, char *argv[]) { auto coro_buffer = device.create_coro_frame_buffer(coro.frame(), 1024u); - coroutine::StateMachineCoroScheduler scheduler{device, coro}; + // coroutine::StateMachineCoroScheduler scheduler{device, coro}; + coroutine::WavefrontCoroScheduler scheduler{device, coro}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { UInt2 p = dispatch_id().xy(); From 487aa7d3a3d22f3227471370cbd309589a95880e Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 12:04:02 +0800 Subject: [PATCH 37/67] minor --- src/tests/coro/helloworld_v2.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp index 586eb10f7..ae3c70e4d 100644 --- a/src/tests/coro/helloworld_v2.cpp +++ b/src/tests/coro/helloworld_v2.cpp @@ -30,17 +30,24 @@ int main(int argc, char *argv[]) { } }; - coroutine::Coroutine coro = [](UInt n) { + coroutine::Coroutine nested2 = [](UInt n) { $for (i, n) { - device_log("{} / {}", i, n); + device_log("nested2: {} / {}", i, n); $suspend(); }; }; - coroutine::Coroutine awaiter = [&coro] { - $await coro(dispatch_x()); + coroutine::Coroutine nested1 = [&](UInt n) { + $for (i, n) { + $await nested2(i); + device_log("nested1: {} / {}", i, n); + }; + }; + + coroutine::Coroutine top_level = [&]() { + $await nested1(10u); }; - coroutine::StateMachineCoroScheduler sched{device, awaiter}; - stream << sched().dispatch(10u) << synchronize(); + coroutine::StateMachineCoroScheduler sched{device, top_level}; + stream << sched().dispatch(1u) << synchronize(); } From 6f2fc0a10a79cc1e454246f6f1fee28fbe58764c Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 12:08:54 +0800 Subject: [PATCH 38/67] minor fix --- include/luisa/dsl/resource.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/luisa/dsl/resource.h b/include/luisa/dsl/resource.h index b7a3a6ae9..b77884e4a 100644 --- a/include/luisa/dsl/resource.h +++ b/include/luisa/dsl/resource.h @@ -561,7 +561,7 @@ class ByteBufferExprProxy { template requires is_integral_expr_v [[nodiscard]] auto read(I &&index) const noexcept { - return Expr{_buffer}.read(std::forward(index)); + return Expr{_buffer}.template read(std::forward(index)); } template requires is_integral_expr_v From e4f6e225e8cbd1b22711205527c04d6a431d9c5c Mon Sep 17 00:00:00 2001 From: chenxin Date: Tue, 14 May 2024 14:03:17 +0800 Subject: [PATCH 39/67] debugging coro frame soa --- include/luisa/coro/v2/coro_frame_buffer.h | 27 ----- include/luisa/coro/v2/coro_frame_soa.h | 120 +++++-------------- include/luisa/coro/v2/schedulers/wavefront.h | 2 +- 3 files changed, 28 insertions(+), 121 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_buffer.h b/include/luisa/coro/v2/coro_frame_buffer.h index bf741e965..72d54f159 100644 --- a/include/luisa/coro/v2/coro_frame_buffer.h +++ b/include/luisa/coro/v2/coro_frame_buffer.h @@ -276,33 +276,6 @@ struct Expr> { namespace detail { -// template -// requires std::same_as> || -// std::same_as> -// class BufferExprProxy { -// -// private: -// T _buffer; -// -// public: -// LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(BufferExprProxy) -// -// public: -// template -// requires is_integral_expr_v -// [[nodiscard]] auto read(I &&index) const noexcept { -// return Expr{_buffer}.read(std::forward(index)); -// } -// template -// requires is_integral_expr_v -// void write(I &&index, V &&value) const noexcept { -// Expr{_buffer}.write(std::forward(index), std::forward(value)); -// } -// [[nodiscard]] Expr device_address() const noexcept { -// return Expr{_buffer}.device_address(); -// } -// }; - template typename B> class BufferExprProxy> { diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index a4661d6d2..350b7bde4 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -95,21 +95,19 @@ class SOA : public SOABase { : SOABase{std::move(desc), luisa::make_shared>(), 0u, n} { size_t size_bytes = 0u; - size_bytes = 0u; auto fields = _desc->type()->members(); _field_offsets->reserve(fields.size()); - for (const auto field : fields) { - size_bytes = (size_bytes + field->alignment() - 1u) & ~(field->alignment() - 1u); + for (const auto field_type : fields) { + auto alignment = std::max(field_type->alignment(), 4ull); + size_bytes = (size_bytes + alignment - 1u) & ~(alignment - 1u); _field_offsets->emplace_back(size_bytes); - if (field->size() % field->alignment() != 0u) [[unlikely]] { - luisa::compute::detail::error_buffer_invalid_alignment(size_bytes + field->size(), field->alignment()); - } - size_bytes += field->size() * _size_elements; + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + size_bytes += aligned_size * _size_elements; } - auto buffer_element_count = (size_bytes + 3u) & ~3u; + size_bytes = (size_bytes + 3u) & ~3u; auto info = device->create_buffer( Type::of(), - buffer_element_count, + size_bytes, nullptr); _buffer = std::move(ByteBuffer{device, info}); } @@ -172,7 +170,9 @@ struct Expr> : SOABase { auto fb = detail::FunctionBuilder::current(); auto field_type = _desc->type()->members()[field_index]; auto offset = _field_offsets->at(field_index); - auto offset_var = offset + (_offset_elements + ULong(index)) * field_type->size(); + auto alignment = std::max(field_type->alignment(), 4ull); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; auto f = fb->local(field_type); auto s = fb->call( field_type, CallOp::BYTE_BUFFER_READ, @@ -188,10 +188,13 @@ struct Expr> : SOABase { auto fb = detail::FunctionBuilder::current(); auto field_type = _desc->type()->members()[field_index]; auto offset = _field_offsets->at(field_index); - auto offset_var = offset + (_offset_elements + ULong(index)) * field_type->size(); - auto s = fb->call( - field_type, CallOp::BYTE_BUFFER_WRITE, - {_expression, detail::extract_expression(offset_var), detail::extract_expression(value)}); + auto alignment = std::max(field_type->alignment(), 4ull); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; + fb->call(CallOp::BYTE_BUFFER_WRITE, + {_expression, + detail::extract_expression(offset_var), + detail::extract_expression(value)}); } /// Read index with active fields @@ -205,7 +208,9 @@ struct Expr> : SOABase { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto offset_var = offset + (_offset_elements + ULong(index)) * field_type->size(); + auto alignment = std::max(field_type->alignment(), 4ull); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; auto s = fb->call( field_type, CallOp::BYTE_BUFFER_READ, {_expression, detail::extract_expression(offset_var)}); @@ -225,11 +230,14 @@ struct Expr> : SOABase { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto offset_var = offset + (_offset_elements + ULong(index)) * field_type->size(); + auto alignment = std::max(field_type->alignment(), 4ull); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; auto f = fb->member(field_type, frame.expression(), i); - auto s = fb->call( - field_type, CallOp::BYTE_BUFFER_WRITE, - {_expression, detail::extract_expression(offset_var), f}); + fb->call(CallOp::BYTE_BUFFER_WRITE, + {_expression, + detail::extract_expression(offset_var), + f}); } } @@ -325,80 +333,6 @@ class SOAExprProxy> { } }; -// template -// class SOAExprProxy> { -// private: -// SOAOrView _soa; -// -// public: -// LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(SOAExprProxy>) -// -// public: -// /// Read field with field_index at index -// template -// requires is_integral_expr_v -// [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { -// return Expr{_soa}.template read_field(std::forward(index), field_index); -// } -// /// Read field named with "name" at index -// template -// requires is_integral_expr_v -// [[nodiscard]] Var read_field(I &&index, luisa::string_view name) const noexcept { -// return read_field(std::forward(index), _soa.desc()->designated_field(name)); -// } -// -// /// Write field with field_index at index -// template -// requires is_integral_expr_v -// void write_field(I &&index, V &&value, uint field_index) const noexcept { -// Expr{_soa}.write_field(std::forward(index), -// std::forward(value), -// field_index); -// } -// /// Write field named with "name" at index -// template -// requires is_integral_expr_v -// void write_field(I &&index, V &&value, luisa::string_view name) const noexcept { -// write_field(std::forward(index), -// std::forward(value), -// _soa.desc()->designated_field(name)); -// } -// -// /// Read index -// template -// requires is_integral_expr_v -// [[nodiscard]] auto read(I &&index) const noexcept { -// return Expr{_soa}.read(std::forward(index), luisa::nullopt); -// } -// /// Read index with active fields -// template -// requires is_integral_expr_v -// [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { -// return Expr{_soa}.read(std::forward(index), luisa::make_optional(active_fields)); -// } -// -// /// Write index -// template -// requires is_integral_expr_v -// void write(I &&index, V &&value) const noexcept { -// Expr{_soa}.write(std::forward(index), -// std::forward(value), -// luisa::nullopt); -// } -// /// Write index with active fields -// template -// requires is_integral_expr_v -// void write(I &&index, V &&value, luisa::span active_fields) const noexcept { -// Expr{_soa}.write(std::forward(index), -// std::forward(value), -// luisa::make_optional(active_fields)); -// } -// -// [[nodiscard]] Expr device_address() const noexcept { -// return Expr{_soa}.device_address(); -// } -// }; - }// namespace detail }// namespace luisa::compute \ No newline at end of file diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index cd660ab98..70d96c672 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -140,7 +140,7 @@ class WavefrontCoroScheduler : public CoroScheduler { return _frame_soa->read_field(index, "target_token") & token_mask; } else { CoroFrame frame = _frame_buffer->read(index); - return frame.get("target_token") & token_mask; + return frame.target_token & token_mask; } }; Callable identical = [](UInt index) { From ca8ad883f157a3a18f9cc2199a3c21e7f18cdb14 Mon Sep 17 00:00:00 2001 From: chenxin Date: Tue, 14 May 2024 15:21:19 +0800 Subject: [PATCH 40/67] bad --- src/tests/coro/path_tracing_wavefront_v2.cpp | 23 ++++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp index 739fc6d92..9d8427e90 100644 --- a/src/tests/coro/path_tracing_wavefront_v2.cpp +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -203,13 +203,17 @@ int main(int argc, char *argv[]) { UInt pixel_id = coro_id().x % pixel_count; UInt2 coord = make_uint2(pixel_id % resolution.x, pixel_id / resolution.x); + // UInt2 coord = dispatch_id().xy(); + Float frame_size = min(resolution.x, resolution.y).cast(); UInt state = seed_image.read(coord).x; Float rx = lcg(state); Float ry = lcg(state); Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; Float3 radiance = def(make_float3(0.0f)); - $suspend("per_spp"); + + $suspend("per_dispatch"); + // $if (all(coord == make_uint2(50u, 500u))) { // device_log("coord: {}", coord); // }; @@ -223,10 +227,10 @@ int main(int argc, char *argv[]) { constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); Float light_area = length(cross(light_u, light_v)); Float3 light_normal = normalize(cross(light_u, light_v)); - $suspend("per_depth"); + $for (depth, 10u) { // trace - $suspend("before_tracing"); + $suspend("intersect"); Var hit = accel.intersect(ray, {}); reorder_shader_execution(); $if (hit->miss()) { $break; }; @@ -236,7 +240,6 @@ int main(int argc, char *argv[]) { Float3 p2 = vertex_buffer->read(triangle.i2); Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); Float3 n = normalize(cross(p1 - p0, p2 - p0)); - $suspend("after_tracing"); Float cos_wo = dot(-ray->direction(), n); $if (cos_wo < 1e-4f) { $break; }; @@ -263,10 +266,13 @@ int main(int argc, char *argv[]) { Float3 pp_light = offset_ray_origin(p_light, light_normal); Float d_light = distance(pp, pp_light); Float3 wi_light = normalize(pp_light - pp); + Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); Bool occluded = accel.intersect_any(shadow_ray, {}); Float cos_wi_light = dot(wi_light, n); Float cos_light = -dot(light_normal, wi_light); + + $suspend("evaluate_surface", std::make_pair(hit.inst, "coro_hint")); Float3 albedo = materials.read(hit.inst); $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { Float pdf_light = (d_light * d_light) / (light_area * cos_light); @@ -277,7 +283,6 @@ int main(int argc, char *argv[]) { }; // sample BSDF - $suspend("sample_bsdf"); Var onb = make_onb(n); Float ux = lcg(state); Float uy = lcg(state); @@ -297,6 +302,7 @@ int main(int argc, char *argv[]) { $if (r >= q) { $break; }; beta *= 1.0f / q; }; + $suspend("write_film"); seed_image.write(coord, make_uint4(state)); $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; @@ -304,11 +310,14 @@ int main(int argc, char *argv[]) { image.write(coord, color); }; + LUISA_INFO_WITH_LOCATION("askjdhfpahfdsalk;fd"); + coroutine::WavefrontCoroSchedulerConfig config{ .thread_count = 16_M, - .soa = true, + .soa = false, }; coroutine::WavefrontCoroScheduler scheduler{device, coro, config}; + // coroutine::PersistentThreadsCoroScheduler scheduler{device, coro}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { UInt2 p = dispatch_id().xy(); @@ -373,7 +382,7 @@ int main(int argc, char *argv[]) { while (!window.should_close()) { stream << scheduler(framebuffer, seed_image, accel, resolution) - .dispatch(resolution.x * resolution.y * spp_per_dispatch) + .dispatch(resolution.x * resolution.y) << accumulate_shader(accum_image, framebuffer) .dispatch(resolution) << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) From 6b0eca3daf08a332dc120fb8d86faa11c888505b Mon Sep 17 00:00:00 2001 From: chenxin Date: Tue, 14 May 2024 15:43:21 +0800 Subject: [PATCH 41/67] fix wavefront coro scheduler --- src/tests/coro/path_tracing_wavefront_v2.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp index 9d8427e90..3eddcaf76 100644 --- a/src/tests/coro/path_tracing_wavefront_v2.cpp +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -272,7 +272,7 @@ int main(int argc, char *argv[]) { Float cos_wi_light = dot(wi_light, n); Float cos_light = -dot(light_normal, wi_light); - $suspend("evaluate_surface", std::make_pair(hit.inst, "coro_hint")); + $suspend("evaluate_surface"); Float3 albedo = materials.read(hit.inst); $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { Float pdf_light = (d_light * d_light) / (light_area * cos_light); @@ -310,11 +310,10 @@ int main(int argc, char *argv[]) { image.write(coord, color); }; - LUISA_INFO_WITH_LOCATION("askjdhfpahfdsalk;fd"); - coroutine::WavefrontCoroSchedulerConfig config{ - .thread_count = 16_M, - .soa = false, + .thread_count = 256_k, + .soa = true, + .sort = true, }; coroutine::WavefrontCoroScheduler scheduler{device, coro, config}; // coroutine::PersistentThreadsCoroScheduler scheduler{device, coro}; From b0e03814041ed6a71fccc3e67e1ce9a1b0386f35 Mon Sep 17 00:00:00 2001 From: chenxin Date: Tue, 14 May 2024 16:59:52 +0800 Subject: [PATCH 42/67] wavefront coro scheduler checkpoint --- include/luisa/coro/v2/coro_frame_buffer.h | 31 ++- include/luisa/coro/v2/coro_frame_soa.h | 239 +++++++++++++------ include/luisa/coro/v2/schedulers/wavefront.h | 120 +++++----- src/tests/coro/path_tracing_wavefront_v2.cpp | 4 + 4 files changed, 253 insertions(+), 141 deletions(-) diff --git a/include/luisa/coro/v2/coro_frame_buffer.h b/include/luisa/coro/v2/coro_frame_buffer.h index d5619b9b0..bf741e965 100644 --- a/include/luisa/coro/v2/coro_frame_buffer.h +++ b/include/luisa/coro/v2/coro_frame_buffer.h @@ -14,8 +14,8 @@ namespace detail { [[noreturn]] LC_CORO_API void error_coro_frame_buffer_invalid_element_size(size_t stride, size_t expected) noexcept; -template typename B> -class BufferExprProxy>; +template +class BufferExprProxy; }// namespace detail @@ -276,6 +276,33 @@ struct Expr> { namespace detail { +// template +// requires std::same_as> || +// std::same_as> +// class BufferExprProxy { +// +// private: +// T _buffer; +// +// public: +// LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(BufferExprProxy) +// +// public: +// template +// requires is_integral_expr_v +// [[nodiscard]] auto read(I &&index) const noexcept { +// return Expr{_buffer}.read(std::forward(index)); +// } +// template +// requires is_integral_expr_v +// void write(I &&index, V &&value) const noexcept { +// Expr{_buffer}.write(std::forward(index), std::forward(value)); +// } +// [[nodiscard]] Expr device_address() const noexcept { +// return Expr{_buffer}.device_address(); +// } +// }; + template typename B> class BufferExprProxy> { diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/v2/coro_frame_soa.h index ef8cfb7af..a4661d6d2 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/v2/coro_frame_soa.h @@ -18,81 +18,8 @@ namespace luisa::compute { namespace detail { - -template -class CoroFrameSOAExprProxy { -private: - SOAOrView _soa; - -public: - LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(CoroFrameSOAExprProxy) - -public: - /// Read field with field_index at index - template - requires is_integral_expr_v - [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { - return Expr{_soa}.template read_field(std::forward(index), field_index); - } - /// Read field named with "name" at index - template - requires is_integral_expr_v - [[nodiscard]] Var read_field(I &&index, luisa::string_view name) const noexcept { - return read_field(std::forward(index), _soa.desc()->designated_field(name)); - } - - /// Write field with field_index at index - template - requires is_integral_expr_v - void write_field(I &&index, V &&value, uint field_index) const noexcept { - Expr{_soa}.write_field(std::forward(index), - std::forward(value), - field_index); - } - /// Write field named with "name" at index - template - requires is_integral_expr_v - void write_field(I &&index, V &&value, luisa::string_view name) const noexcept { - write_field(std::forward(index), - std::forward(value), - _soa.desc()->designated_field(name)); - } - - /// Read index - template - requires is_integral_expr_v - [[nodiscard]] auto read(I &&index) const noexcept { - return Expr{_soa}.read(std::forward(index), luisa::nullopt); - } - /// Read index with active fields - template - requires is_integral_expr_v - [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { - return Expr{_soa}.read(std::forward(index), luisa::make_optional(active_fields)); - } - - /// Write index - template - requires is_integral_expr_v - void write(I &&index, V &&value) const noexcept { - Expr{_soa}.write(std::forward(index), - std::forward(value), - luisa::nullopt); - } - /// Write index with active fields - template - requires is_integral_expr_v - void write(I &&index, V &&value, luisa::span active_fields) const noexcept { - Expr{_soa}.write(std::forward(index), - std::forward(value), - luisa::make_optional(active_fields)); - } - - [[nodiscard]] Expr device_address() const noexcept { - return Expr{_soa}.device_address(); - } -}; - +template +class SOAExprProxy; }// namespace detail template @@ -148,11 +75,12 @@ class SOAView : public SOABase { [[nodiscard]] auto handle() const noexcept { return _buffer_view.handle(); } [[nodiscard]] auto offset_elements() const noexcept { return _offset_elements; } [[nodiscard]] auto size_elements() const noexcept { return _size_elements; } + [[nodiscard]] auto buffer_offset() const noexcept { return _buffer_view.offset(); } [[nodiscard]] auto size_bytes() const noexcept { return _buffer_view.size_bytes(); } [[nodiscard]] auto field_offsets() const noexcept { return _field_offsets; } // DSL interface [[nodiscard]] auto operator->() const noexcept { - return reinterpret_cast> *>(this); + return reinterpret_cast> *>(this); } }; @@ -210,7 +138,7 @@ class SOA : public SOABase { } // DSL interface [[nodiscard]] auto operator->() const noexcept { - return reinterpret_cast> *>(this); + return reinterpret_cast> *>(this); } }; @@ -227,8 +155,8 @@ struct Expr> : SOABase { : SOABase{soa_view.desc()->shared_from_this(), soa_view.field_offsets(), soa_view.offset_elements(), soa_view.size_elements()}, _expression{detail::FunctionBuilder::current()->buffer_binding( - Type::buffer(Type::of()), soa_view.handle(), - 0u, soa_view.size_bytes())} {} + Type::of(), soa_view.handle(), + soa_view.buffer_offset(), soa_view.size_bytes())} {} /// Construct from SOA. Will call buffer_binding() to bind buffer Expr(const SOA &soa) noexcept @@ -320,4 +248,157 @@ struct Expr> : public Expr>::Expr; }; +namespace detail { + +template typename B> +class SOAExprProxy> { +private: + using SOAOrView = B; + SOAOrView _soa; + +public: + LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(SOAExprProxy) + +public: + /// Read field with field_index at index + template + requires is_integral_expr_v + [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { + return Expr{_soa}.template read_field(std::forward(index), field_index); + } + /// Read field named with "name" at index + template + requires is_integral_expr_v + [[nodiscard]] Var read_field(I &&index, luisa::string_view name) const noexcept { + return read_field(std::forward(index), _soa.desc()->designated_field(name)); + } + + /// Write field with field_index at index + template + requires is_integral_expr_v + void write_field(I &&index, V &&value, uint field_index) const noexcept { + Expr{_soa}.write_field(std::forward(index), + std::forward(value), + field_index); + } + /// Write field named with "name" at index + template + requires is_integral_expr_v + void write_field(I &&index, V &&value, luisa::string_view name) const noexcept { + write_field(std::forward(index), + std::forward(value), + _soa.desc()->designated_field(name)); + } + + /// Read index + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index) const noexcept { + return Expr{_soa}.read(std::forward(index), luisa::nullopt); + } + /// Read index with active fields + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { + return Expr{_soa}.read(std::forward(index), luisa::make_optional(active_fields)); + } + + /// Write index + template + requires is_integral_expr_v + void write(I &&index, V &&value) const noexcept { + Expr{_soa}.write(std::forward(index), + std::forward(value), + luisa::nullopt); + } + /// Write index with active fields + template + requires is_integral_expr_v + void write(I &&index, V &&value, luisa::span active_fields) const noexcept { + Expr{_soa}.write(std::forward(index), + std::forward(value), + luisa::make_optional(active_fields)); + } + + [[nodiscard]] Expr device_address() const noexcept { + return Expr{_soa}.device_address(); + } +}; + +// template +// class SOAExprProxy> { +// private: +// SOAOrView _soa; +// +// public: +// LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(SOAExprProxy>) +// +// public: +// /// Read field with field_index at index +// template +// requires is_integral_expr_v +// [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { +// return Expr{_soa}.template read_field(std::forward(index), field_index); +// } +// /// Read field named with "name" at index +// template +// requires is_integral_expr_v +// [[nodiscard]] Var read_field(I &&index, luisa::string_view name) const noexcept { +// return read_field(std::forward(index), _soa.desc()->designated_field(name)); +// } +// +// /// Write field with field_index at index +// template +// requires is_integral_expr_v +// void write_field(I &&index, V &&value, uint field_index) const noexcept { +// Expr{_soa}.write_field(std::forward(index), +// std::forward(value), +// field_index); +// } +// /// Write field named with "name" at index +// template +// requires is_integral_expr_v +// void write_field(I &&index, V &&value, luisa::string_view name) const noexcept { +// write_field(std::forward(index), +// std::forward(value), +// _soa.desc()->designated_field(name)); +// } +// +// /// Read index +// template +// requires is_integral_expr_v +// [[nodiscard]] auto read(I &&index) const noexcept { +// return Expr{_soa}.read(std::forward(index), luisa::nullopt); +// } +// /// Read index with active fields +// template +// requires is_integral_expr_v +// [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { +// return Expr{_soa}.read(std::forward(index), luisa::make_optional(active_fields)); +// } +// +// /// Write index +// template +// requires is_integral_expr_v +// void write(I &&index, V &&value) const noexcept { +// Expr{_soa}.write(std::forward(index), +// std::forward(value), +// luisa::nullopt); +// } +// /// Write index with active fields +// template +// requires is_integral_expr_v +// void write(I &&index, V &&value, luisa::span active_fields) const noexcept { +// Expr{_soa}.write(std::forward(index), +// std::forward(value), +// luisa::make_optional(active_fields)); +// } +// +// [[nodiscard]] Expr device_address() const noexcept { +// return Expr{_soa}.device_address(); +// } +// }; + +}// namespace detail + }// namespace luisa::compute \ No newline at end of file diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index ed8db68eb..09ba738e3 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -13,7 +13,7 @@ namespace luisa::compute::coroutine { struct WavefrontCoroSchedulerConfig { - uint3 block_size = luisa::make_uint3(128, 1, 1); + uint3 block_size = make_uint3(8u, 8u, 1u); uint max_instance_count = 2_M; bool soa = true; bool sort = true;// use sort for coro token gathering @@ -28,7 +28,10 @@ class WavefrontCoroScheduler : public CoroScheduler { private: WavefrontCoroSchedulerConfig _config; - luisa::optional...>> _args; + using ArgPack = std::tuple...>; + luisa::optional _args; + SOA _frame_soa; + Buffer _frame_buffer; Shader1D, Buffer, uint, uint, uint, Args...> _gen_shader; luisa::vector, Buffer, uint, Args...>> _resume_shaders; Shader1D, Buffer, uint> _count_prefix_shader; @@ -36,8 +39,6 @@ class WavefrontCoroScheduler : public CoroScheduler { Shader1D, uint> _initialize_shader; Shader1D, uint, uint> _compact_shader; Shader1D, uint> _clear_shader; - SOA _frame_soa; - Buffer _frame_buffer; Buffer _resume_index; Buffer _resume_count; ///offset calculate from count, will be end after gathering @@ -48,7 +49,6 @@ class WavefrontCoroScheduler : public CoroScheduler { bool _host_empty; uint _dispatch_counter; uint _max_sub_coro; - uint _max_frame_count; radix_sort::temp_storage _sort_temp_storage; radix_sort::instance<> _sort_token; radix_sort::instance> _sort_hint; @@ -59,25 +59,33 @@ class WavefrontCoroScheduler : public CoroScheduler { private: void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { - LUISA_ASSERT(dispatch_size.y == 1u && dispatch_size.z == 1u, - "WavefrontCoroScheduler only supports 1D dispatch for now."); _config.block_size.x = dispatch_size.x; _dispatch_counter = 0; _host_empty = true; for (auto i = 0u; i < _max_sub_coro; i++) { if (i) { _host_count[i] = 0; - _host_offset[i] = _max_frame_count; + _host_offset[i] = _config.max_instance_count; } else { - _host_count[i] = _max_frame_count; + _host_count[i] = _config.max_instance_count; _host_offset[i] = 0; } } - stream << _initialize_shader(_resume_count, _max_frame_count).dispatch(_max_frame_count); - _args = std::make_tuple(std::forward>(args)...); + stream << _initialize_shader(_resume_count, _config.max_instance_count).dispatch(_config.max_instance_count); + _args.emplace(std::forward>(args)...); _await_all(stream); } + template + [[nodiscard]] auto _invoke(ShaderTytpe &shader, PrefixArgsType &&...prefix_args) const noexcept { + return std::apply( + [&](PostfixArgsType &&...postfix_args) { + return shader(std::forward(prefix_args)..., + std::forward(postfix_args)...); + }, + _args.value()); + } + void _create_shader(Device &device, const Coroutine &coroutine, const WavefrontCoroSchedulerConfig &config) noexcept { _config = config; @@ -88,12 +96,12 @@ class WavefrontCoroScheduler : public CoroScheduler { _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), config.max_instance_count); } bool use_sort = config.sort || !config.hint_fields.empty(); - _max_sub_coro = coroutine.subroutine_count() + 1; - _resume_index = device.create_buffer(_max_frame_count); + _max_sub_coro = coroutine.subroutine_count(); + _resume_index = device.create_buffer(_config.max_instance_count); if (use_sort) { - _temp_index = device.create_buffer(_max_frame_count); - _temp_key[0] = device.create_buffer(_max_frame_count); - _temp_key[1] = device.create_buffer(_max_frame_count); + _temp_index = device.create_buffer(_config.max_instance_count); + _temp_key[0] = device.create_buffer(_config.max_instance_count); + _temp_key[1] = device.create_buffer(_config.max_instance_count); } _resume_count = device.create_buffer(_max_sub_coro); _resume_offset = device.create_buffer(_max_sub_coro); @@ -116,15 +124,15 @@ class WavefrontCoroScheduler : public CoroScheduler { for (auto i = 0u; i < _max_sub_coro; i++) { if (i) { _host_count[i] = 0; - _host_offset[i] = _max_frame_count; + _host_offset[i] = _config.max_instance_count; } else { - _host_count[i] = _max_frame_count; + _host_count[i] = _config.max_instance_count; _host_offset[i] = 0; } } Callable get_coro_token = [&](UInt index) { - $if (index > _max_frame_count) { - device_log("Index out of range {}/{}", index, _max_frame_count); + $if (index > _config.max_instance_count) { + device_log("Index out of range {}/{}", index, _config.max_instance_count); }; if (config.soa) { return _frame_soa->read_field(index, "target_token") & token_mask; @@ -154,17 +162,17 @@ class WavefrontCoroScheduler : public CoroScheduler { }; if (use_sort) { _sort_temp_storage = radix_sort::temp_storage( - device, _max_frame_count, std::max(std::min(config.hint_range, 128u), _max_sub_coro)); + device, _config.max_instance_count, std::max(std::min(config.hint_range, 128u), _max_sub_coro)); } if (config.sort) { _sort_token = radix_sort::instance<>( - device, _max_frame_count, _sort_temp_storage, &get_coro_token, &identical, + device, _config.max_instance_count, _sort_temp_storage, &get_coro_token, &identical, &get_coro_token, 1, _max_sub_coro); } if (!config.hint_fields.empty()) { if (config.hint_range <= 128) { _sort_hint = radix_sort::instance>( - device, _max_frame_count, _sort_temp_storage, &get_coro_hint, &keep_index, + device, _config.max_instance_count, _sort_temp_storage, &get_coro_hint, &keep_index, &get_coro_hint, 1, config.hint_range); } else { auto highbit = 0; @@ -172,7 +180,7 @@ class WavefrontCoroScheduler : public CoroScheduler { highbit++; } _sort_hint = radix_sort::instance>( - device, _max_frame_count, _sort_temp_storage, &get_coro_hint, &keep_index, + device, _config.max_instance_count, _sort_temp_storage, &get_coro_hint, &keep_index, &get_coro_hint, 0, 128, 0, highbit); } } @@ -321,7 +329,7 @@ class WavefrontCoroScheduler : public CoroScheduler { } }; $if (x < _max_sub_coro) { - count.write(x, ite(x == 0u, _max_frame_count, 0u)); + count.write(x, ite(x == 0u, _config.max_instance_count, 0u)); }; }; Kernel1D clear = [&](BufferUInt buffer, UInt n) { @@ -335,7 +343,7 @@ class WavefrontCoroScheduler : public CoroScheduler { } [[nodiscard]] bool _all_dispatched() const noexcept { - return _dispatch_counter == _max_frame_count; + return _dispatch_counter == _config.max_instance_count; } [[nodiscard]] bool _all_done() const noexcept { return _host_empty && _all_dispatched(); @@ -351,27 +359,27 @@ class WavefrontCoroScheduler : public CoroScheduler { auto host_update = [&] { _host_empty = true; for (uint i = 0u; i < _max_sub_coro; i++) { - _host_count[i] = (i + 1u == _max_sub_coro ? _max_frame_count : _host_offset[i + 1u]) - _host_offset[i]; + _host_count[i] = (i + 1u == _max_sub_coro ? _config.max_instance_count : _host_offset[i + 1u]) - _host_offset[i]; _host_empty = _host_empty && (i == 0u || _host_count[i] == 0u); } }; _sort_token.sort(stream, _temp_key[0], _resume_index, _temp_key[1], - _resume_index, _max_frame_count); + _resume_index, _config.max_instance_count); stream << _sort_temp_storage.hist_buffer.view(0u, _max_sub_coro).copy_to(_host_offset.data()) << host_update << synchronize(); - if (_host_count[0] > _max_frame_count * (0.5) && !_all_dispatched()) { + if (_host_count[0] > _config.max_instance_count * (0.5) && !_all_dispatched()) { auto gen_count = std::min(_config.block_size.x - _dispatch_counter, _host_count[0]); - if (_host_count[0] != _max_frame_count && _config.compact) { + if (_host_count[0] != _config.max_instance_count && _config.compact) { stream << _clear_shader(_global_buffer, 1).dispatch(1u); - stream << _compact_shader(_resume_index, _max_frame_count - _host_count[0], _max_frame_count).dispatch(_host_count[0]); + stream << _compact_shader(_resume_index, _config.max_instance_count - _host_count[0], _config.max_instance_count).dispatch(_host_count[0]); } - auto invoke = _gen_shader(_resume_index.view(_host_offset[0], _host_count[0]), - _resume_count, _max_frame_count - _host_count[0], _dispatch_counter, - _max_frame_count); - stream << _invoke_args(invoke, _args).dispatch(gen_count); + stream << _invoke(_gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), + _resume_count, _config.max_instance_count - _host_count[0], _dispatch_counter, + _config.max_instance_count) + .dispatch(gen_count); _dispatch_counter += gen_count; _host_empty = false; } else { @@ -381,12 +389,12 @@ class WavefrontCoroScheduler : public CoroScheduler { BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; BufferView _key[2] = {_temp_key[1].view(_host_offset[i], _host_count[i]), _temp_key[0].view(_host_offset[i], _host_count[i])}; uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); - auto invoke = _resume_shaders[i](_index[out], _resume_count, _max_frame_count); - stream << _invoke_args(invoke, _args).dispatch(_host_count[i]); + stream << _invoke(_resume_shaders[i], _index[out], _resume_count, _config.max_instance_count) + .dispatch(_host_count[i]); } else { - auto invoke = _resume_shaders[i](_resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _max_frame_count); - stream << _invoke_args(invoke, _args).dispatch(_host_count[i]); + stream << _invoke(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), + _resume_count, _config.max_instance_count) + .dispatch(_host_count[i]); } } } @@ -394,20 +402,19 @@ class WavefrontCoroScheduler : public CoroScheduler { stream << synchronize(); } else { stream << _count_prefix_shader(_resume_count, _resume_offset, _max_sub_coro).dispatch(1u); - stream << _gather_shader(_resume_index, _resume_offset, _max_frame_count).dispatch(_max_frame_count); - if (_host_count[0] > _max_frame_count / 2 && !_all_dispatched()) { + stream << _gather_shader(_resume_index, _resume_offset, _config.max_instance_count).dispatch(_config.max_instance_count); + if (_host_count[0] > _config.max_instance_count / 2 && !_all_dispatched()) { auto gen_count = std::min(_config.block_size.x - _dispatch_counter, _host_count[0]); - if (_host_count[0] != _max_frame_count && _config.compact) { + if (_host_count[0] != _config.max_instance_count && _config.compact) { stream << _clear_shader(_global_buffer, 1).dispatch(1u); stream << _compact_shader(_resume_index.view(_host_offset[0], _host_count[0]), - _max_frame_count - _host_count[0], _max_frame_count) + _config.max_instance_count - _host_count[0], _config.max_instance_count) .dispatch(_host_count[0]); } - auto invoke = _gen_shader(_resume_index.view(_host_offset[0], _host_count[0]), - _resume_count, _max_frame_count - _host_count[0], _dispatch_counter, - _max_frame_count); - stream << _invoke_args(invoke, _args).dispatch(gen_count); + stream << _invoke(_gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), + _resume_count, _config.max_instance_count - _host_count[0], _dispatch_counter, _config.max_instance_count) + .dispatch(gen_count); _dispatch_counter += gen_count; _host_empty = false; } else { @@ -417,12 +424,12 @@ class WavefrontCoroScheduler : public CoroScheduler { BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; BufferView _key[2] = {_temp_key[0].view(_host_offset[i], _host_count[i]), _temp_key[1].view(_host_offset[i], _host_count[i])}; uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); - auto invoke = _resume_shaders[i](_index[out], _resume_count, _max_frame_count); - stream << _invoke_args(invoke, _args).dispatch(_host_count[i]); + stream << _invoke(_resume_shaders[i], _index[out], _resume_count, _config.max_instance_count) + .dispatch(_host_count[i]); } else { - auto invoke = _resume_shaders[i](_resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _max_frame_count); - stream << _invoke_args(invoke, _args).dispatch(_host_count[i]); + stream << _invoke(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), + _resume_count, _config.max_instance_count) + .dispatch(_host_count[i]); } } } @@ -441,13 +448,6 @@ class WavefrontCoroScheduler : public CoroScheduler { stream << synchronize(); } } - template - [[nodiscard]] auto _invoke_args(compute::detail::ShaderInvoke &invoke) noexcept { - std::apply([&](auto &&...args) { - static_cast((invoke << ... << args)); - }, *_args); - return invoke; - } public: WavefrontCoroScheduler(Device &device, const Coroutine &coro, diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp index 7fc44575f..bed5b72f7 100644 --- a/src/tests/coro/path_tracing_wavefront_v2.cpp +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -302,6 +302,10 @@ int main(int argc, char *argv[]) { auto coro_buffer = device.create_coro_frame_buffer(coro.frame(), 1024u); // coroutine::StateMachineCoroScheduler scheduler{device, coro}; + coroutine::WavefrontCoroSchedulerConfig config{ + .block_size = make_uint3(1024u, 1u, 1u), + .max_instance_count = 16777216, + }; coroutine::WavefrontCoroScheduler scheduler{device, coro}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { From 19406613695f5da341e53ed8bc060bc96dac4b2c Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 17:48:52 +0800 Subject: [PATCH 43/67] wip: persistent-threads scheduler --- include/luisa/coro/coro_dispatcher.h | 21 -- .../coro/v2/schedulers/persistent_threads.h | 276 +++++++++++++- .../luisa/coro/v2/schedulers/state_machine.h | 15 +- src/tests/CMakeLists.txt | 1 + .../path_tracing_persistent_threads_v2.cpp | 350 ++++++++++++++++++ 5 files changed, 633 insertions(+), 30 deletions(-) create mode 100644 src/tests/coro/path_tracing_persistent_threads_v2.cpp diff --git a/include/luisa/coro/coro_dispatcher.h b/include/luisa/coro/coro_dispatcher.h index 1e9f431b3..6e8dba0b0 100644 --- a/include/luisa/coro/coro_dispatcher.h +++ b/include/luisa/coro/coro_dispatcher.h @@ -583,30 +583,10 @@ class PersistentCoroDispatcher : public CoroDispatcherBase(-1); $while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) { sync_block();//very important, synchronize for condition - rem_local[0] = 0u; count += 1; work_stat[0] = 0; work_stat[1] = -1; - /* $if(thread_x() < (uint)KERNEL_COUNT) {//clear counter - work_counter[thread_x()] = 0u; - }; - sync_block(); - $for(index, 0u, q_factor) { - for (auto i = 0u; i < KERNEL_COUNT; ++i) {//count the kernels - auto state = path_state[index*block_size+thread_x()]; - $if(state.kernel_index == i) { - if (i != (uint)INVALID) { - rem_local[0] = 1u; - } else { - $if(workload[0] < workload[1]) { - rem_local[0] = 1u; - }; - } - work_counter.atomic(i).fetch_add(1u); - }; - } - };*/ sync_block(); $if (thread_x() == config.block_size - 1) { $if ((workload[0] >= workload[1]) & (rem_global[0] == 1u)) {//fetch new workload @@ -764,7 +744,6 @@ class PersistentCoroDispatcher : public CoroDispatcherBase + +namespace luisa::compute::coroutine { + +struct PersistentThreadsCoroSchedulerConfig { + uint thread_count = 64_k; + uint block_size = 128; + uint fetch_size = 16; + bool shared_memory_soa = true; + bool global_ext_memory = false; +}; + +template +class PersistentThreadsCoroScheduler : public CoroScheduler { + +public: + using Coro = Coroutine; + using Config = PersistentThreadsCoroSchedulerConfig; + +private: + Coro _coro; + Config _config; + Shader1D, uint, uint2, Args...> _pt_shader; + Shader1D> _clear_shader; + Buffer _global; + Buffer _global_frames; + Shader1D _initialize_shader; + +private: + void _prepare(Device &device) noexcept { + _global = device.create_buffer(1); + auto q_fac = 1u; + auto g_fac = std::max(_coro.subroutine_count() - q_fac, 0); + auto global_queue_size = _config.block_size * g_fac; + if (_config.global_ext_memory) { + auto global_ext_size = _config.thread_count * g_fac; + _global_frames = device.create_buffer(global_ext_size); + } + Kernel1D main_kernel = [&](BufferUInt global, UInt dispatch_size, UInt2 dispatch_shape, Var... args) noexcept { + set_block_size(_config.block_size, 1u, 1u); + auto shared_queue_size = _config.block_size * q_fac; + Shared frames{_coro.shared_frame(), shared_queue_size, _config.shared_memory_soa}; + Shared path_id{shared_queue_size}; + Shared work_counter{_coro.subroutine_count()}; + Shared work_offset{2u}; + Shared all_token{_config.global_ext_memory ? + shared_queue_size + global_queue_size : + shared_queue_size}; + Shared workload{2}; + Shared work_stat{2};// work_state[0] for max_count, [1] for max_id + for (auto index : dsl::dynamic_range(q_fac)) { + auto s = index * _config.block_size + thread_x(); + all_token[s] = 0u; + frames.write(s, _coro.instantiate()); + } + for (auto index : dsl::dynamic_range(g_fac)) { + auto s = index * _config.block_size + thread_x(); + all_token[shared_queue_size + s] = 0u; + } + if_(thread_x() < _coro.subroutine_count(), [&] { + if_(thread_x() == 0u, [&] { + work_counter[thread_x()] = + _config.global_ext_memory ? + shared_queue_size + global_queue_size : + shared_queue_size; + }).else_([&] { + work_counter[thread_x()] = 0u; + }); + }); + workload[0] = 0u; + workload[1] = 0u; + Shared rem_global{1}; + Shared rem_local{1}; + rem_global[0] = 1u; + rem_local[0] = 0u; + sync_block(); + auto count = def(0u); + auto count_limit = def(-1); + loop([&] { + if_(!(rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit), [&] { break_(); }); + sync_block();//very important, synchronize for condition + rem_local[0] = 0u; + count += 1; + work_stat[0] = 0; + work_stat[1] = -1; + sync_block(); + if_(thread_x() == _config.block_size - 1, [&] { + if_(workload[0] >= workload[1] & rem_global[0] == 1u, [&] {//fetch new workload + workload[0] = global.atomic(0u).fetch_add(_config.block_size * _config.fetch_size); + workload[1] = min(workload[0] + _config.block_size * _config.fetch_size, dispatch_size); + if_(workload[0] >= dispatch_size, [&] { + rem_global[0] = 0u; + }); + }); + }); + sync_block(); + if_(thread_x() < _coro.subroutine_count(), [&] {//get max + if_(workload[0] < workload[1] | thread_x() != 0u, [&] { + if_(work_counter[thread_x()] != 0, [&] { + rem_local[0] = 1u; + work_stat.atomic(0).fetch_max(work_counter[thread_x()]); + }); + }); + }); + sync_block(); + if_(thread_x() < _coro.subroutine_count(), [&] {//get argmax + if_(work_stat[0] == work_counter[thread_x()] & (workload[0] < workload[1] | thread_x() != 0u), [&] { + work_stat[1] = thread_x(); + }); + }); + sync_block(); + work_offset[0] = 0; + work_offset[1] = 0; + sync_block(); + if (!_config.global_ext_memory) { + for (auto index : dsl::dynamic_range(q_fac)) {//collect indices + auto frame_token = all_token[index * _config.block_size + thread_x()]; + if_(frame_token == work_stat[1], [&] { + auto id = work_offset.atomic(0).fetch_add(1u); + path_id[id] = index * _config.block_size + thread_x(); + }); + } + } else { + for (auto index : dsl::dynamic_range(q_fac)) {//collect switch out indices + auto frame_token = all_token[index * _config.block_size + thread_x()]; + if_(frame_token != work_stat[1], [&] { + auto id = work_offset.atomic(0).fetch_add(1u); + path_id[id] = index * _config.block_size + thread_x(); + }); + } + sync_block(); + if_(shared_queue_size - work_offset[0] < _config.block_size, [&] {//no enough work + for (auto index : dsl::dynamic_range(g_fac)) { //swap frames + auto global_id = block_x() * global_queue_size + index * _config.block_size + thread_x(); + auto g_queue_id = index * _config.block_size + thread_x(); + auto coro_token = all_token[shared_queue_size + g_queue_id]; + if_(coro_token == work_stat[1], [&] { + auto id = work_offset.atomic(1).fetch_add(1u); + if_(id < work_offset[0], [&] { + auto dst = path_id[id]; + auto frame_token = all_token[dst]; + if_(coro_token != 0u, [&] { + $if (frame_token != 0u) { + auto g_state = _global_frames->read(global_id); + _global_frames->write(global_id, frames.read(dst)); + frames.write(dst, g_state); + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; + } + $else { + auto g_state = _global_frames->read(global_id); + frames.write(dst, g_state); + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; + }; + }).else_([&] { + $if (frame_token != 0u) { + auto frame = frames.read(dst); + _global_frames->write(global_id, frame); + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; + }; + }); + }); + }); + } + }); + } + auto gen_st = workload[0]; + sync_block(); + auto pid = def(0u); + if (_config.global_ext_memory) { + pid = thread_x(); + } else { + pid = path_id[thread_x()]; + } + auto launch_condition = def(true); + if (!_config.global_ext_memory) { + launch_condition = (thread_x() < work_offset[0]); + } else { + launch_condition = (all_token[pid] == work_stat[1]); + } + if_(launch_condition, [&] { + auto switch_stmt = switch_(all_token[pid]); + std::move(switch_stmt).case_(0u, [&] { + if_(gen_st + thread_x() < workload[1], [&] { + work_counter.atomic(0u).fetch_sub(1u); + auto work_id = gen_st + thread_x(); + auto global_index = gen_st + thread_x(); + auto image_size = dispatch_shape.x * dispatch_shape.y; + auto index_z = global_index / image_size; + auto index_xy = global_index % image_size; + auto index_x = index_xy % dispatch_shape.x; + auto index_y = index_xy / dispatch_shape.x; + auto frame = _coro.instantiate(make_uint3(index_x, index_y, index_z)); + _coro.entry()(frame, args...); + auto next = frame.target_token; + frames.write(pid, frame); + all_token[pid] = next; + work_counter.atomic(next).fetch_add(1u); + workload.atomic(0).fetch_add(1u); + }); + }); + for (auto i = 1u; i < _coro.subroutine_count(); i++) { + std::move(switch_stmt).case_(i, [&] { + work_counter.atomic(i).fetch_sub(1u); + auto frame = frames.read(pid); + _coro[i](frame, args...); + auto next = frame.target_token; + frames.write(pid, frame); + all_token[pid] = next; + work_counter.atomic(next).fetch_add(1u); + }); + } + }); + sync_block(); +#ifndef NDEBUG + if_(count >= count_limit, [&] { + device_log("block_id{},thread_id {}, loop not break! local:{}, global:{}", block_x(), thread_x(), rem_local[0], rem_global[0]); + if_(thread_x() < _coro.subroutine_count(), [&] { + device_log("work rem: id {}, size {}", thread_x(), work_counter[thread_x()]); + }); + }); +#endif + }); + _pt_shader = device.compile(main_kernel); + _clear_shader = device.compile<1>([](BufferUInt global) { + global->write(dispatch_x(), 0u); + }); + _initialize_shader = device.compile<1>([&](UInt n) noexcept { + auto x = dispatch_x(); + $if (x < n) { + auto frame = _coro.instantiate(); + _global_frames->write(x, frame); + }; + }); + }; + } + + void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept override { + stream << _clear_shader(_global).dispatch(1u); + if (_config.global_ext_memory) { + auto n = static_cast(_global_frames.size()); + stream << _initialize_shader(n).dispatch(n); + } + auto size = make_ulong3(dispatch_size); + auto n = size.x * size.y * size.z; + stream << _pt_shader(_global, static_cast(n), dispatch_size.xy(), args...).dispatch(static_cast(n)); + } + +public: + PersistentThreadsCoroScheduler(Device &device, const Coro &coro, const Config &config) noexcept + : _coro{coro}, _config{config} { + _config.thread_count = luisa::align(_config.thread_count, _config.block_size); + _prepare(device); + } + PersistentThreadsCoroScheduler(Device &device, const Coro &coro) noexcept + : PersistentThreadsCoroScheduler{device, coro, Config{}} {} +}; + +// User-defined CTAD guides +template +PersistentThreadsCoroScheduler(Device &device, const Coroutine &coro, + const PersistentThreadsCoroSchedulerConfig &config) noexcept + -> PersistentThreadsCoroScheduler; + +template +PersistentThreadsCoroScheduler(Device &device, const Coroutine &coro) noexcept + -> PersistentThreadsCoroScheduler; + +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/v2/schedulers/state_machine.h b/include/luisa/coro/v2/schedulers/state_machine.h index 1e5cecc56..5398c00f3 100644 --- a/include/luisa/coro/v2/schedulers/state_machine.h +++ b/include/luisa/coro/v2/schedulers/state_machine.h @@ -29,12 +29,15 @@ struct StateMachineCoroSchedulerConfig { template class StateMachineCoroScheduler : public CoroScheduler { +public: + using Coro = Coroutine; + using Config = StateMachineCoroSchedulerConfig; + private: Shader3D _shader; private: - void _create_shader(Device &device, const Coroutine &coro, - const StateMachineCoroSchedulerConfig &config) noexcept { + void _create_shader(Device &device, const Coro &coro, const Config &config) noexcept { Kernel3D kernel = [&coro, &config](Var... args) noexcept { set_block_size(config.block_size); if (config.shared_memory) { @@ -63,14 +66,14 @@ class StateMachineCoroScheduler : public CoroScheduler { } public: - StateMachineCoroScheduler(Device &device, const Coroutine &coro, - const StateMachineCoroSchedulerConfig &config) noexcept { + StateMachineCoroScheduler(Device &device, const Coro &coro, const Config &config) noexcept { _create_shader(device, coro, config); } - StateMachineCoroScheduler(Device &device, const Coroutine &coro) noexcept - : StateMachineCoroScheduler{device, coro, StateMachineCoroSchedulerConfig{}} {} + StateMachineCoroScheduler(Device &device, const Coro &coro) noexcept + : StateMachineCoroScheduler{device, coro, Config{}} {} }; +// User-defined CTAD guides template StateMachineCoroScheduler(Device &, const Coroutine &) -> StateMachineCoroScheduler; diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 132e52a9e..b78f9dd07 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -202,6 +202,7 @@ luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_rende luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) luisa_compute_add_executable(test_coro_path_tracing_v2 coro/path_tracing_v2.cpp) luisa_compute_add_executable(test_coro_path_tracing_wavefront_v2 coro/path_tracing_wavefront_v2.cpp) +luisa_compute_add_executable(test_coro_path_tracing_persistent_threads_v2 coro/path_tracing_persistent_threads_v2.cpp) luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) luisa_compute_add_executable(test_coro_helloworld_v2 coro/helloworld_v2.cpp) luisa_compute_add_executable(test_coro_playground coro/playground.cpp) diff --git a/src/tests/coro/path_tracing_persistent_threads_v2.cpp b/src/tests/coro/path_tracing_persistent_threads_v2.cpp new file mode 100644 index 000000000..d60356a9f --- /dev/null +++ b/src/tests/coro/path_tracing_persistent_threads_v2.cpp @@ -0,0 +1,350 @@ +#include + +#include +#include + +#include +#include "../common/cornell_box.h" + +#define TINYOBJLOADER_IMPLEMENTATION +#include "../common/tiny_obj_loader.h" + +using namespace luisa; +using namespace luisa::compute; + +struct Onb { + float3 tangent; + float3 binormal; + float3 normal; +}; + +LUISA_STRUCT(Onb, tangent, binormal, normal) { + [[nodiscard]] Float3 to_world(Expr v) const noexcept { + return v.x * tangent + v.y * binormal + v.z * normal; + } +}; + +int main(int argc, char *argv[]) { + + log_level_verbose(); + + Context context{argv[0]}; + if (argc <= 1) { + LUISA_INFO("Usage: {} . : cuda, dx, cpu, metal", argv[0]); + exit(1); + } + Device device = context.create_device(argv[1]); + + // load the Cornell Box scene + tinyobj::ObjReaderConfig obj_reader_config; + obj_reader_config.triangulate = true; + obj_reader_config.vertex_color = false; + tinyobj::ObjReader obj_reader; + if (!obj_reader.ParseFromString(obj_string, "", obj_reader_config)) { + luisa::string_view error_message = "unknown error."; + if (auto &&e = obj_reader.Error(); !e.empty()) { error_message = e; } + LUISA_ERROR_WITH_LOCATION("Failed to load OBJ file: {}", error_message); + } + if (auto &&e = obj_reader.Warning(); !e.empty()) { + LUISA_WARNING_WITH_LOCATION("{}", e); + } + + auto &&p = obj_reader.GetAttrib().vertices; + luisa::vector vertices; + vertices.reserve(p.size() / 3u); + for (uint i = 0u; i < p.size(); i += 3u) { + vertices.emplace_back(make_float3( + p[i + 0u], p[i + 1u], p[i + 2u])); + } + LUISA_INFO( + "Loaded mesh with {} shape(s) and {} vertices.", + obj_reader.GetShapes().size(), vertices.size()); + + BindlessArray heap = device.create_bindless_array(); + Stream stream = device.create_stream(StreamTag::GRAPHICS); + Buffer vertex_buffer = device.create_buffer(vertices.size()); + stream << vertex_buffer.copy_from(vertices.data()); + luisa::vector meshes; + luisa::vector> triangle_buffers; + for (auto &&shape : obj_reader.GetShapes()) { + uint index = static_cast(meshes.size()); + std::vector const &t = shape.mesh.indices; + uint triangle_count = t.size() / 3u; + LUISA_INFO( + "Processing shape '{}' at index {} with {} triangle(s).", + shape.name, index, triangle_count); + luisa::vector indices; + indices.reserve(t.size()); + for (tinyobj::index_t i : t) { indices.emplace_back(i.vertex_index); } + Buffer &triangle_buffer = triangle_buffers.emplace_back(device.create_buffer(triangle_count)); + Mesh &mesh = meshes.emplace_back(device.create_mesh(vertex_buffer, triangle_buffer)); + heap.emplace_on_update(index, triangle_buffer); + stream << triangle_buffer.copy_from(indices.data()) + << mesh.build(); + } + + Accel accel = device.create_accel({}); + for (Mesh &m : meshes) { + accel.emplace_back(m, make_float4x4(1.0f)); + } + stream << heap.update() + << accel.build() + << synchronize(); + + Constant materials{ + make_float3(0.725f, 0.710f, 0.680f),// floor + make_float3(0.725f, 0.710f, 0.680f),// ceiling + make_float3(0.725f, 0.710f, 0.680f),// back wall + make_float3(0.140f, 0.450f, 0.091f),// right wall + make_float3(0.630f, 0.065f, 0.050f),// left wall + make_float3(0.725f, 0.710f, 0.680f),// short box + make_float3(0.725f, 0.710f, 0.680f),// tall box + make_float3(0.000f, 0.000f, 0.000f),// light + }; + + Callable linear_to_srgb = [&](Var x) noexcept { + return saturate(select(1.055f * pow(x, 1.0f / 2.4f) - 0.055f, + 12.92f * x, + x <= 0.00031308f)); + }; + + Callable tea = [](UInt v0, UInt v1) noexcept { + UInt s0 = def(0u); + for (uint n = 0u; n < 4u; n++) { + s0 += 0x9e3779b9u; + v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); + v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); + } + return v0; + }; + + Kernel2D make_sampler_kernel = [&](ImageUInt seed_image) noexcept { + UInt2 p = dispatch_id().xy(); + UInt state = tea(p.x, p.y); + seed_image.write(p, make_uint4(state)); + }; + + Callable lcg = [](UInt &state) noexcept { + constexpr uint lcg_a = 1664525u; + constexpr uint lcg_c = 1013904223u; + state = lcg_a * state + lcg_c; + return cast(state & 0x00ffffffu) * + (1.0f / static_cast(0x01000000u)); + }; + + Callable make_onb = [](const Float3 &normal) noexcept { + Float3 binormal = normalize(ite( + abs(normal.x) > abs(normal.z), + make_float3(-normal.y, normal.x, 0.0f), + make_float3(0.0f, -normal.z, normal.y))); + Float3 tangent = normalize(cross(binormal, normal)); + return def(tangent, binormal, normal); + }; + + Callable generate_ray = [](Float2 p) noexcept { + static constexpr float fov = radians(27.8f); + static constexpr float3 origin = make_float3(-0.01f, 0.995f, 5.0f); + Float3 pixel = origin + make_float3(p * tan(0.5f * fov), -1.0f); + Float3 direction = normalize(pixel - origin); + return make_ray(origin, direction); + }; + + Callable cosine_sample_hemisphere = [](Float2 u) noexcept { + Float r = sqrt(u.x); + Float phi = 2.0f * constants::pi * u.y; + return make_float3(r * cos(phi), r * sin(phi), sqrt(1.0f - u.x)); + }; + + Callable balanced_heuristic = [](Float pdf_a, Float pdf_b) noexcept { + return pdf_a / max(pdf_a + pdf_b, 1e-4f); + }; + + auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; + + coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + UInt2 coord = dispatch_id().xy(); + Float frame_size = min(resolution.x, resolution.y).cast(); + UInt state = seed_image.read(coord).x; + Float rx = lcg(state); + Float ry = lcg(state); + Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; + Float3 radiance = def(make_float3(0.0f)); + $suspend("per_spp"); + $for (i, spp_per_dispatch) { + Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); + Float3 beta = def(make_float3(1.0f)); + Float pdf_bsdf = def(0.0f); + constexpr float3 light_position = make_float3(-0.24f, 1.98f, 0.16f); + constexpr float3 light_u = make_float3(-0.24f, 1.98f, -0.22f) - light_position; + constexpr float3 light_v = make_float3(0.23f, 1.98f, 0.16f) - light_position; + constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); + Float light_area = length(cross(light_u, light_v)); + Float3 light_normal = normalize(cross(light_u, light_v)); + $suspend("per_depth"); + $for (depth, 10u) { + // trace + $suspend("before_tracing"); + Var hit = accel.intersect(ray, {}); + reorder_shader_execution(); + $if (hit->miss()) { $break; }; + Var triangle = heap->buffer(hit.inst).read(hit.prim); + Float3 p0 = vertex_buffer->read(triangle.i0); + Float3 p1 = vertex_buffer->read(triangle.i1); + Float3 p2 = vertex_buffer->read(triangle.i2); + Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); + Float3 n = normalize(cross(p1 - p0, p2 - p0)); + $suspend("after_tracing"); + + Float cos_wo = dot(-ray->direction(), n); + $if (cos_wo < 1e-4f) { $break; }; + + // hit light + $if (hit.inst == static_cast(meshes.size() - 1u)) { + $if (depth == 0u) { + radiance += light_emission; + } + $else { + Float pdf_light = length_squared(p - ray->origin()) / (light_area * cos_wo); + Float mis_weight = balanced_heuristic(pdf_bsdf, pdf_light); + radiance += mis_weight * beta * light_emission; + }; + $break; + }; + + // sample light + $suspend("sample_light"); + Float ux_light = lcg(state); + Float uy_light = lcg(state); + Float3 p_light = light_position + ux_light * light_u + uy_light * light_v; + Float3 pp = offset_ray_origin(p, n); + Float3 pp_light = offset_ray_origin(p_light, light_normal); + Float d_light = distance(pp, pp_light); + Float3 wi_light = normalize(pp_light - pp); + Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); + Bool occluded = accel.intersect_any(shadow_ray, {}); + Float cos_wi_light = dot(wi_light, n); + Float cos_light = -dot(light_normal, wi_light); + Float3 albedo = materials.read(hit.inst); + $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { + Float pdf_light = (d_light * d_light) / (light_area * cos_light); + Float pdf_bsdf = cos_wi_light * inv_pi; + Float mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); + Float3 bsdf = albedo * inv_pi * cos_wi_light; + radiance += beta * bsdf * mis_weight * light_emission / max(pdf_light, 1e-4f); + }; + + // sample BSDF + $suspend("sample_bsdf"); + Var onb = make_onb(n); + Float ux = lcg(state); + Float uy = lcg(state); + Float3 wi_local = cosine_sample_hemisphere(make_float2(ux, uy)); + Float cos_wi = abs(wi_local.z); + Float3 new_direction = onb->to_world(wi_local); + ray = make_ray(pp, new_direction); + pdf_bsdf = cos_wi * inv_pi; + beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f + + // rr + $suspend("rr"); + Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); + $if (l == 0.0f) { $break; }; + Float q = max(l, 0.05f); + Float r = lcg(state); + $if (r >= q) { $break; }; + beta *= 1.0f / q; + }; + }; + $suspend("write_film"); + radiance /= static_cast(spp_per_dispatch); + seed_image.write(coord, make_uint4(state)); + $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; + image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); + }; + + coroutine::PersistentThreadsCoroSchedulerConfig config{}; + coroutine::PersistentThreadsCoroScheduler scheduler{device, coro, config}; + + Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { + UInt2 p = dispatch_id().xy(); + Float4 accum = accum_image.read(p); + Float3 curr = curr_image.read(p).xyz(); + accum_image.write(p, accum + make_float4(curr, 1.f)); + }; + + Callable aces_tonemapping = [](Float3 x) noexcept { + static constexpr float a = 2.51f; + static constexpr float b = 0.03f; + static constexpr float c = 2.43f; + static constexpr float d = 0.59f; + static constexpr float e = 0.14f; + return clamp((x * (a * x + b)) / (x * (c * x + d) + e), 0.0f, 1.0f); + }; + + Kernel2D clear_kernel = [](ImageFloat image) noexcept { + image.write(dispatch_id().xy(), make_float4(0.0f)); + }; + + Kernel2D hdr2ldr_kernel = [&](ImageFloat hdr_image, ImageFloat ldr_image, Float scale, Bool is_hdr) noexcept { + UInt2 coord = dispatch_id().xy(); + Float4 hdr = hdr_image.read(coord); + Float3 ldr = hdr.xyz() / hdr.w * scale; + $if (!is_hdr) { + ldr = linear_to_srgb(ldr); + }; + ldr_image.write(coord, make_float4(ldr, 1.0f)); + }; + + ShaderOption o{.enable_debug_info = false}; + auto clear_shader = device.compile(clear_kernel, o); + auto hdr2ldr_shader = device.compile(hdr2ldr_kernel, o); + auto accumulate_shader = device.compile(accumulate_kernel, o); + auto make_sampler_shader = device.compile(make_sampler_kernel, o); + + static constexpr uint2 resolution = make_uint2(1024u); + Image framebuffer = device.create_image(PixelStorage::HALF4, resolution); + Image accum_image = device.create_image(PixelStorage::FLOAT4, resolution); + luisa::vector> host_image(resolution.x * resolution.y); + + Image seed_image = device.create_image(PixelStorage::INT1, resolution); + stream << clear_shader(accum_image).dispatch(resolution) + << make_sampler_shader(seed_image).dispatch(resolution); + + Window window{"path tracing", resolution}; + Swapchain swap_chain = device.create_swapchain( + stream, + SwapchainOption{ + .display = window.native_display(), + .window = window.native_handle(), + .size = make_uint2(resolution), + .wants_hdr = false, + .wants_vsync = false, + .back_buffer_count = 3, + }); + Image ldr_image = device.create_image(swap_chain.backend_storage(), resolution); + double last_time = 0.0; + uint frame_count = 0u; + Clock clock; + + while (!window.should_close()) { + stream << scheduler(framebuffer, seed_image, accel, resolution) + .dispatch(resolution) + << accumulate_shader(accum_image, framebuffer) + .dispatch(resolution) + << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) + << swap_chain.present(ldr_image) + << synchronize(); + window.poll_events(); + double dt = clock.toc() - last_time; + last_time = clock.toc(); + frame_count += spp_per_dispatch; + LUISA_INFO("spp: {}, time: {} ms, spp/s: {}", + frame_count, dt, spp_per_dispatch / dt * 1000); + } + stream + << ldr_image.copy_to(host_image.data()) + << synchronize(); + + LUISA_INFO("FPS: {}", frame_count / clock.toc() * 1000); + stbi_write_png("test_path_tracing.png", resolution.x, resolution.y, 4, host_image.data(), 0); +} From a3a3dc0ea4055a36bc99e30dca9c5fb1f7e51d62 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 18:15:35 +0800 Subject: [PATCH 44/67] trying to debug --- include/luisa/coro/v2/coro_func.h | 2 +- .../coro/v2/schedulers/persistent_threads.h | 67 ++++++++++--------- .../common/hlsl/hlsl_codegen_util.cpp | 2 +- 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/include/luisa/coro/v2/coro_func.h b/include/luisa/coro/v2/coro_func.h index 776484db8..8cceb9ceb 100644 --- a/include/luisa/coro/v2/coro_func.h +++ b/include/luisa/coro/v2/coro_func.h @@ -90,7 +90,7 @@ class Coroutine { public: [[nodiscard]] auto instantiate() const noexcept { return CoroFrame::create(_graph->shared_frame()); } [[nodiscard]] auto instantiate(Expr coro_id) const noexcept { return CoroFrame::create(_graph->shared_frame(), coro_id); } - [[nodiscard]] auto subroutine_count() const noexcept { return _graph->nodes().size(); } + [[nodiscard]] auto subroutine_count() const noexcept { return static_cast(_graph->nodes().size()); } [[nodiscard]] auto operator[](CoroToken token) const noexcept { return Subroutine{_graph->node(token).cc()}; } [[nodiscard]] auto operator[](luisa::string_view name) const noexcept { return Subroutine{_graph->node(name).cc()}; } [[nodiscard]] auto entry() const noexcept { return (*this)[coro_token_entry]; } diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/v2/schedulers/persistent_threads.h index c8bd9869e..b93535684 100644 --- a/include/luisa/coro/v2/schedulers/persistent_threads.h +++ b/include/luisa/coro/v2/schedulers/persistent_threads.h @@ -24,7 +24,6 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { using Config = PersistentThreadsCoroSchedulerConfig; private: - Coro _coro; Config _config; Shader1D, uint, uint2, Args...> _pt_shader; Shader1D> _clear_shader; @@ -33,21 +32,21 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { Shader1D _initialize_shader; private: - void _prepare(Device &device) noexcept { + void _prepare(Device &device, const Coro &coro) noexcept { _global = device.create_buffer(1); auto q_fac = 1u; - auto g_fac = std::max(_coro.subroutine_count() - q_fac, 0); + auto g_fac = std::max(coro.subroutine_count() - q_fac, 0); auto global_queue_size = _config.block_size * g_fac; if (_config.global_ext_memory) { auto global_ext_size = _config.thread_count * g_fac; - _global_frames = device.create_buffer(global_ext_size); + _global_frames = device.create_buffer(coro.shared_frame(), global_ext_size); } - Kernel1D main_kernel = [&](BufferUInt global, UInt dispatch_size, UInt2 dispatch_shape, Var... args) noexcept { + Kernel1D, uint, uint2, Args...> main_kernel = [&](BufferUInt global, UInt dispatch_size, UInt2 dispatch_shape, Var... args) noexcept { set_block_size(_config.block_size, 1u, 1u); auto shared_queue_size = _config.block_size * q_fac; - Shared frames{_coro.shared_frame(), shared_queue_size, _config.shared_memory_soa}; + Shared frames{coro.shared_frame(), shared_queue_size, _config.shared_memory_soa}; Shared path_id{shared_queue_size}; - Shared work_counter{_coro.subroutine_count()}; + Shared work_counter{coro.subroutine_count()}; Shared work_offset{2u}; Shared all_token{_config.global_ext_memory ? shared_queue_size + global_queue_size : @@ -57,13 +56,13 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { for (auto index : dsl::dynamic_range(q_fac)) { auto s = index * _config.block_size + thread_x(); all_token[s] = 0u; - frames.write(s, _coro.instantiate()); + frames.write(s, coro.instantiate()); } for (auto index : dsl::dynamic_range(g_fac)) { auto s = index * _config.block_size + thread_x(); all_token[shared_queue_size + s] = 0u; } - if_(thread_x() < _coro.subroutine_count(), [&] { + if_(thread_x() < coro.subroutine_count(), [&] { if_(thread_x() == 0u, [&] { work_counter[thread_x()] = _config.global_ext_memory ? @@ -100,7 +99,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { }); }); sync_block(); - if_(thread_x() < _coro.subroutine_count(), [&] {//get max + if_(thread_x() < coro.subroutine_count(), [&] {//get max if_(workload[0] < workload[1] | thread_x() != 0u, [&] { if_(work_counter[thread_x()] != 0, [&] { rem_local[0] = 1u; @@ -109,7 +108,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { }); }); sync_block(); - if_(thread_x() < _coro.subroutine_count(), [&] {//get argmax + if_(thread_x() < coro.subroutine_count(), [&] {//get argmax if_(work_stat[0] == work_counter[thread_x()] & (workload[0] < workload[1] | thread_x() != 0u), [&] { work_stat[1] = thread_x(); }); @@ -198,8 +197,8 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto index_xy = global_index % image_size; auto index_x = index_xy % dispatch_shape.x; auto index_y = index_xy / dispatch_shape.x; - auto frame = _coro.instantiate(make_uint3(index_x, index_y, index_z)); - _coro.entry()(frame, args...); + auto frame = coro.instantiate(make_uint3(index_x, index_y, index_z)); + coro.entry()(frame, args...); auto next = frame.target_token; frames.write(pid, frame); all_token[pid] = next; @@ -207,11 +206,11 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { workload.atomic(0).fetch_add(1u); }); }); - for (auto i = 1u; i < _coro.subroutine_count(); i++) { + for (auto i = 1u; i < coro.subroutine_count(); i++) { std::move(switch_stmt).case_(i, [&] { work_counter.atomic(i).fetch_sub(1u); auto frame = frames.read(pid); - _coro[i](frame, args...); + coro[i](frame, args...); auto next = frame.target_token; frames.write(pid, frame); all_token[pid] = next; @@ -220,27 +219,29 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { } }); sync_block(); + }); #ifndef NDEBUG - if_(count >= count_limit, [&] { - device_log("block_id{},thread_id {}, loop not break! local:{}, global:{}", block_x(), thread_x(), rem_local[0], rem_global[0]); - if_(thread_x() < _coro.subroutine_count(), [&] { - device_log("work rem: id {}, size {}", thread_x(), work_counter[thread_x()]); - }); + if_(count >= count_limit, [&] { + device_log("block_id{},thread_id {}, loop not break! local:{}, global:{}", block_x(), thread_x(), rem_local[0], rem_global[0]); + if_(thread_x() < coro.subroutine_count(), [&] { + device_log("work rem: id {}, size {}", thread_x(), work_counter[thread_x()]); }); -#endif - }); - _pt_shader = device.compile(main_kernel); - _clear_shader = device.compile<1>([](BufferUInt global) { - global->write(dispatch_x(), 0u); }); +#endif + }; + _pt_shader = device.compile(main_kernel); + _clear_shader = device.compile<1>([](BufferUInt global) { + global->write(dispatch_x(), 0u); + }); + if (_config.global_ext_memory) { _initialize_shader = device.compile<1>([&](UInt n) noexcept { - auto x = dispatch_x(); - $if (x < n) { - auto frame = _coro.instantiate(); - _global_frames->write(x, frame); - }; + auto x = dispatch_x(); + $if (x < n) { + auto frame = coro.instantiate(); + _global_frames->write(x, frame); + }; }); - }; + } } void _dispatch(Stream &stream, uint3 dispatch_size, @@ -257,9 +258,9 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { public: PersistentThreadsCoroScheduler(Device &device, const Coro &coro, const Config &config) noexcept - : _coro{coro}, _config{config} { + : _config{config} { _config.thread_count = luisa::align(_config.thread_count, _config.block_size); - _prepare(device); + _prepare(device, coro); } PersistentThreadsCoroScheduler(Device &device, const Coro &coro) noexcept : PersistentThreadsCoroScheduler{device, coro, Config{}} {} diff --git a/src/backends/common/hlsl/hlsl_codegen_util.cpp b/src/backends/common/hlsl/hlsl_codegen_util.cpp index 624c06f1d..2a5368444 100644 --- a/src/backends/common/hlsl/hlsl_codegen_util.cpp +++ b/src/backends/common/hlsl/hlsl_codegen_util.cpp @@ -2106,7 +2106,7 @@ uint4 dsp_c; LUISA_ERROR("Arguments binding size: {} exceeds 64 32-bit units not supported by hardware device. Try to use bindless instead.", bind_count); } else if (bind_count > 16) [[unlikely]] { if (!rootsig_exceed_warned.exchange(true)) { - LUISA_WARNING("Arguments binding size exceeds 16 32-bit unit (max 64 allowed). This may cause extra performance cost, try to use bindless instead."); + LUISA_WARNING("Arguments binding size: {} exceeds 16 32-bit unit (max 64 allowed). This may cause extra performance cost, try to use bindless instead.", bind_count); } } return { From 4433630f37141a1c2af09964b7df02a5b53e8657 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 18:16:05 +0800 Subject: [PATCH 45/67] minor --- include/luisa/coro/v2/schedulers/persistent_threads.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/v2/schedulers/persistent_threads.h index b93535684..e8c0c3326 100644 --- a/include/luisa/coro/v2/schedulers/persistent_threads.h +++ b/include/luisa/coro/v2/schedulers/persistent_threads.h @@ -235,11 +235,11 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { }); if (_config.global_ext_memory) { _initialize_shader = device.compile<1>([&](UInt n) noexcept { - auto x = dispatch_x(); - $if (x < n) { - auto frame = coro.instantiate(); - _global_frames->write(x, frame); - }; + auto x = dispatch_x(); + $if (x < n) { + auto frame = coro.instantiate(); + _global_frames->write(x, frame); + }; }); } } From 42c77ba5be8aca6680f56d101a43f4c8b1c0f831 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 20:06:12 +0800 Subject: [PATCH 46/67] persistent-threads scheduler good --- include/luisa/coro/coro_dispatcher.h | 1 - .../coro/v2/schedulers/persistent_threads.h | 29 ++- src/tests/CMakeLists.txt | 1 + src/tests/coro/sdf_renderer_v2.cpp | 212 ++++++++++++++++++ 4 files changed, 226 insertions(+), 17 deletions(-) create mode 100644 src/tests/coro/sdf_renderer_v2.cpp diff --git a/include/luisa/coro/coro_dispatcher.h b/include/luisa/coro/coro_dispatcher.h index 946420e29..19abb23b2 100644 --- a/include/luisa/coro/coro_dispatcher.h +++ b/include/luisa/coro/coro_dispatcher.h @@ -689,7 +689,6 @@ class PersistentCoroDispatcher : public CoroDispatcherBase(gen_st + thread_x(), 0, 0)); (*coroutine)(frames[pid], args...);//only work when kernel 0s are continue auto nxt = read_promise(frames[pid], "coro_token") & token_mask; diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/v2/schedulers/persistent_threads.h index e8c0c3326..968535812 100644 --- a/include/luisa/coro/v2/schedulers/persistent_threads.h +++ b/include/luisa/coro/v2/schedulers/persistent_threads.h @@ -25,7 +25,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { private: Config _config; - Shader1D, uint, uint2, Args...> _pt_shader; + Shader1D, uint3, Args...> _pt_shader; Shader1D> _clear_shader; Buffer _global; Buffer _global_frames; @@ -35,13 +35,13 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { void _prepare(Device &device, const Coro &coro) noexcept { _global = device.create_buffer(1); auto q_fac = 1u; - auto g_fac = std::max(coro.subroutine_count() - q_fac, 0); + auto g_fac = coro.subroutine_count() - q_fac; auto global_queue_size = _config.block_size * g_fac; if (_config.global_ext_memory) { auto global_ext_size = _config.thread_count * g_fac; _global_frames = device.create_buffer(coro.shared_frame(), global_ext_size); } - Kernel1D, uint, uint2, Args...> main_kernel = [&](BufferUInt global, UInt dispatch_size, UInt2 dispatch_shape, Var... args) noexcept { + Kernel1D main_kernel = [&](BufferUInt global, UInt3 dispatch_shape, Var... args) noexcept { set_block_size(_config.block_size, 1u, 1u); auto shared_queue_size = _config.block_size * q_fac; Shared frames{coro.shared_frame(), shared_queue_size, _config.shared_memory_soa}; @@ -56,7 +56,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { for (auto index : dsl::dynamic_range(q_fac)) { auto s = index * _config.block_size + thread_x(); all_token[s] = 0u; - frames.write(s, coro.instantiate()); + // frames.write(s, coro.instantiate(), std::array{0u, 1u}); } for (auto index : dsl::dynamic_range(g_fac)) { auto s = index * _config.block_size + thread_x(); @@ -81,8 +81,9 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { sync_block(); auto count = def(0u); auto count_limit = def(-1); + auto dispatch_size = dispatch_shape.x * dispatch_shape.y * dispatch_shape.z; loop([&] { - if_(!(rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit), [&] { break_(); }); + if_(!((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)), [&] { break_(); }); sync_block();//very important, synchronize for condition rem_local[0] = 0u; count += 1; @@ -190,7 +191,6 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { std::move(switch_stmt).case_(0u, [&] { if_(gen_st + thread_x() < workload[1], [&] { work_counter.atomic(0u).fetch_sub(1u); - auto work_id = gen_st + thread_x(); auto global_index = gen_st + thread_x(); auto image_size = dispatch_shape.x * dispatch_shape.y; auto index_z = global_index / image_size; @@ -199,8 +199,8 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto index_y = index_xy / dispatch_shape.x; auto frame = coro.instantiate(make_uint3(index_x, index_y, index_z)); coro.entry()(frame, args...); - auto next = frame.target_token; - frames.write(pid, frame); + auto next = frame.target_token & token_mask; + frames.write(pid, frame, coro.graph()->entry().output_fields()); all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); workload.atomic(0).fetch_add(1u); @@ -209,10 +209,10 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { for (auto i = 1u; i < coro.subroutine_count(); i++) { std::move(switch_stmt).case_(i, [&] { work_counter.atomic(i).fetch_sub(1u); - auto frame = frames.read(pid); + auto frame = frames.read(pid, coro.graph()->node(i).input_fields()); coro[i](frame, args...); - auto next = frame.target_token; - frames.write(pid, frame); + auto next = frame.target_token & token_mask; + frames.write(pid, frame, coro.graph()->node(i).output_fields()); all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); }); @@ -237,8 +237,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { _initialize_shader = device.compile<1>([&](UInt n) noexcept { auto x = dispatch_x(); $if (x < n) { - auto frame = coro.instantiate(); - _global_frames->write(x, frame); + _global_frames->write(x, coro.instantiate()); }; }); } @@ -251,9 +250,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto n = static_cast(_global_frames.size()); stream << _initialize_shader(n).dispatch(n); } - auto size = make_ulong3(dispatch_size); - auto n = size.x * size.y * size.z; - stream << _pt_shader(_global, static_cast(n), dispatch_size.xy(), args...).dispatch(static_cast(n)); + stream << _pt_shader(_global, dispatch_size, args...).dispatch(_config.thread_count); } public: diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index b78f9dd07..4df8b0638 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -198,6 +198,7 @@ endif () # coroutine tests luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) +luisa_compute_add_executable(test_coro_sdf_renderer_v2 coro/sdf_renderer_v2.cpp) luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_renderer_wo_dispatcher.cpp) luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) luisa_compute_add_executable(test_coro_path_tracing_v2 coro/path_tracing_v2.cpp) diff --git a/src/tests/coro/sdf_renderer_v2.cpp b/src/tests/coro/sdf_renderer_v2.cpp new file mode 100644 index 000000000..0332ebe6a --- /dev/null +++ b/src/tests/coro/sdf_renderer_v2.cpp @@ -0,0 +1,212 @@ +#include +#include +#include + +using namespace luisa; +using namespace luisa::compute; + +int main(int argc, char *argv[]) { + + Context context{argv[0]}; + if (argc <= 1) { exit(1); } + + static constexpr uint width = 1280u; + static constexpr uint height = 720u; + + static constexpr int max_ray_depth = 6; + static constexpr float eps = 1e-4f; + static constexpr float inf = 1e10f; + static constexpr float fov = 0.23f; + static constexpr float dist_limit = 100.0f; + static constexpr float3 camera_pos = make_float3(0.0f, 0.32f, 3.7f); + static constexpr float3 light_pos = make_float3(-1.5f, 0.6f, 0.3f); + static constexpr float3 light_normal = make_float3(1.0f, 0.0f, 0.0f); + static constexpr float light_radius = 2.0f; + + Clock clock; + + Callable intersect_light = [](Float3 pos, Float3 d) noexcept { + Float cos_w = dot(-d, light_normal); + Float dist = dot(d, light_pos - pos); + Float D = dist / cos_w; + Float dist_to_center = distance_squared(light_pos, pos + D * d); + Bool valid = cos_w > 0.0f & dist > 0.0f & dist_to_center < light_radius * light_radius; + return ite(valid, D, inf); + }; + + Callable tea = [](UInt v0, UInt v1) noexcept { + Var s0 = 0u; + for (uint n = 0u; n < 4u; n++) { + s0 += 0x9e3779b9u; + v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); + v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); + } + return v0; + }; + + Callable rand = [](UInt &state) noexcept { + constexpr uint lcg_a = 1664525u; + constexpr uint lcg_c = 1013904223u; + state = lcg_a * state + lcg_c; + return cast(state) / cast(std::numeric_limits::max()); + }; + + Callable out_dir = [&rand](Float3 n, UInt &seed) noexcept { + Float3 u = ite( + abs(n.y) < 1.0f - eps, + normalize(cross(n, make_float3(0.0f, 1.0f, 0.0f))), + make_float3(1.f, 0.f, 0.f)); + Float3 v = cross(n, u); + Float phi = 2.0f * pi * rand(seed); + Float ay = sqrt(rand(seed)); + Float ax = sqrt(1.0f - ay * ay); + return ax * (cos(phi) * u + sin(phi) * v) + ay * n; + }; + + Callable make_nested = [](Float f) noexcept { + static constexpr float freq = 40.0f; + f *= freq; + f = ite(f < 0.f, ite(cast(f) % 2 == 0, 1.f - fract(f), fract(f)), f); + return (f - 0.2f) * (1.0f / freq); + }; + + Callable sdf = [&make_nested](Float3 o) noexcept { + Float wall = min(o.y + 0.1f, o.z + 0.4f); + Float sphere = distance(o, make_float3(0.0f, 0.35f, 0.0f)) - 0.36f; + Float3 q = abs(o - make_float3(0.8f, 0.3f, 0.0f)) - 0.3f; + Float box = length(max(q, 0.0f)) + min(max(max(q.x, q.y), q.z), 0.0f); + Float3 O = o - make_float3(-0.8f, 0.3f, 0.0f); + Float2 d = make_float2(length(make_float2(O.x, O.z)) - 0.3f, abs(O.y) - 0.3f); + Float cylinder = min(max(d.x, d.y), 0.0f) + length(max(d, 0.0f)); + Float geometry = make_nested(min(min(sphere, box), cylinder)); + Float g = max(geometry, -(0.32f - (o.y * 0.6f + o.z * 0.8f))); + return min(wall, g); + }; + + Callable ray_march = [&sdf](Float3 p, Float3 d) noexcept { + Float dist = def(0.0f); + $for (j, 100) { + Float s = sdf(p + dist * d); + $if (s <= 1e-6f | dist >= inf) { $break; }; + dist += s; + }; + return min(dist, inf); + }; + + Callable sdf_normal = [&sdf](Float3 p) noexcept { + static constexpr float d = 1e-3f; + Float3 n = def(make_float3()); + Float sdf_center = sdf(p); + for (uint i = 0; i < 3; i++) { + Float3 inc = p; + inc[i] += d; + n[i] = (1.0f / d) * (sdf(inc) - sdf_center); + } + return normalize(n); + }; + + Callable next_hit = [&ray_march, &sdf_normal](Float &closest, Float3 &normal, Float3 &c, Float3 pos, Float3 d) noexcept { + closest = inf; + normal = make_float3(); + c = make_float3(); + Float ray_march_dist = ray_march(pos, d); + $if (ray_march_dist < min(dist_limit, closest)) { + closest = ray_march_dist; + Float3 hit_pos = pos + d * closest; + normal = sdf_normal(hit_pos); + Int t = cast((hit_pos.x + 10.0f) * 1.1f + 0.5f) % 3; + c = make_float3(0.4f) + make_float3(0.3f, 0.2f, 0.3f) * ite(t == make_int3(0, 1, 2), 1.0f, 0.0f); + }; + }; + + coroutine::Coroutine coro = [&](ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) noexcept { + Float2 resolution = make_float2(width, height); + UInt2 coord = dispatch_id().xy(); + + $if (frame_index == 0u) { + seed_image.write(coord, make_uint4(tea(coord.x, coord.y))); + accum_image.write(coord, make_float4(make_float3(0.0f), 1.0f)); + }; + + Float aspect_ratio = resolution.x / resolution.y; + Float3 pos = def(camera_pos); + UInt seed = seed_image.read(coord).x; + Float ux = rand(seed); + Float uy = rand(seed); + Float2 uv = make_float2(coord.x + ux, height - 1u - coord.y + uy); + Float3 d = make_float3( + 2.0f * fov * uv / resolution.y - fov * make_float2(aspect_ratio, 1.0f) - 1e-5f, -1.0f); + d = normalize(d); + Float3 throughput = def(make_float3(1.0f, 1.0f, 1.0f)); + Float hit_light = def(0.0f); + $for (depth, max_ray_depth) { + + $suspend("2"); + Float closest = def(0.0f); + Float3 normal = def(make_float3()); + Float3 c = def(make_float3()); + next_hit(closest, normal, c, pos, d); + Float dist_to_light = intersect_light(pos, d); + $if (dist_to_light < closest) { + // $if (depth == 0) { + // device_log("xxxxxcoord {} {}", coord.x, coord.y); + // }; + hit_light = 1.0f; + $break; + }; + $if (length_squared(normal) == 0.0f) { + // $if (depth == 0) { + // device_log("coord {} {}", coord.x, coord.y); + // }; + $break; + }; + Float3 hit_pos = pos + closest * d; + d = out_dir(normal, seed); + pos = hit_pos + 1e-4f * d; + throughput *= c; + }; + $suspend("3"); + Float3 accum_color = lerp(accum_image.read(coord).xyz(), throughput.xyz() * hit_light, 1.0f / (frame_index + 1.0f)); + accum_image.write(coord, make_float4(accum_color, 1.0f)); + //$suspend("4"); + seed_image.write(coord, make_uint4(seed)); + }; + + auto device = context.create_device(argv[1]); + auto stream = device.create_stream(); + constexpr auto resolution = make_uint2(width, height); + + Image seed_image = device.create_image(PixelStorage::INT1, width, height); + Image accum_image = device.create_image(PixelStorage::FLOAT4, width, height); + + coroutine::PersistentThreadsCoroSchedulerConfig config{ + .thread_count = 64_k, + .block_size = 64u, + .fetch_size = 3u, + .shared_memory_soa = false, + .global_ext_memory = false}; + coroutine::PersistentThreadsCoroScheduler scheduler{device, coro, config}; + + auto clear_shader = device.compile<2>([&] { + auto coord = dispatch_id().xy(); + accum_image->write(coord, make_float4(make_float3(0.0f), 1.0f)); + seed_image->write(coord, make_uint4(coord.y * resolution.y + coord.x)); + }); + + luisa::vector host_image(accum_image.view().size_bytes()); + stream << clear_shader().dispatch(resolution) << synchronize(); + Clock clk; + auto samples = 1024; + for (auto i = 0u; i < samples; ++i) { + LUISA_INFO("spp {}", i); + stream << scheduler(seed_image, accum_image, i).dispatch(width, height) + << synchronize(); + } + stream << synchronize(); + auto dt = clk.toc(); + LUISA_INFO("Time: {} ms ({} spp/s)", dt, samples * 1e3 / dt); + + stream << accum_image.copy_to(host_image.data()) + << synchronize(); + stbi_write_hdr("test_sdf.hdr", resolution.x, resolution.y, 4, reinterpret_cast(host_image.data())); +} From 37b899f3ecd88643a5f7935ecf106a63d13210dc Mon Sep 17 00:00:00 2001 From: chenxin Date: Tue, 14 May 2024 20:13:01 +0800 Subject: [PATCH 47/67] wavefront checkpoint --- include/luisa/coro/v2/schedulers/wavefront.h | 79 ++++++++++---------- src/tests/coro/path_tracing_v2.cpp | 11 +-- src/tests/coro/path_tracing_wavefront_v2.cpp | 20 +++-- 3 files changed, 53 insertions(+), 57 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index 09ba738e3..71baa5e49 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -18,7 +18,6 @@ struct WavefrontCoroSchedulerConfig { bool soa = true; bool sort = true;// use sort for coro token gathering bool compact = true; - bool debug = false; uint hint_range = 0xffff'ffff; luisa::vector hint_fields; }; @@ -49,6 +48,7 @@ class WavefrontCoroScheduler : public CoroScheduler { bool _host_empty; uint _dispatch_counter; uint _max_sub_coro; + uint _dispatch_size; radix_sort::temp_storage _sort_temp_storage; radix_sort::instance<> _sort_token; radix_sort::instance> _sort_hint; @@ -59,7 +59,7 @@ class WavefrontCoroScheduler : public CoroScheduler { private: void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { - _config.block_size.x = dispatch_size.x; + _dispatch_size = dispatch_size.x * dispatch_size.y * dispatch_size.z; // TODO _dispatch_counter = 0; _host_empty = true; for (auto i = 0u; i < _max_sub_coro; i++) { @@ -73,7 +73,7 @@ class WavefrontCoroScheduler : public CoroScheduler { } stream << _initialize_shader(_resume_count, _config.max_instance_count).dispatch(_config.max_instance_count); _args.emplace(std::forward>(args)...); - _await_all(stream); + this->_await_all(stream); } template @@ -90,12 +90,12 @@ class WavefrontCoroScheduler : public CoroScheduler { const WavefrontCoroSchedulerConfig &config) noexcept { _config = config; const luisa::shared_ptr desc = coroutine.shared_frame(); - if (config.soa) { - _frame_soa = device.create_soa(coroutine.shared_frame(), config.max_instance_count); + if (_config.soa) { + _frame_soa = device.create_soa(coroutine.shared_frame(), _config.max_instance_count); } else { - _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), config.max_instance_count); + _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), _config.max_instance_count); } - bool use_sort = config.sort || !config.hint_fields.empty(); + bool use_sort = _config.sort || !_config.hint_fields.empty(); _max_sub_coro = coroutine.subroutine_count(); _resume_index = device.create_buffer(_config.max_instance_count); if (use_sort) { @@ -111,7 +111,7 @@ class WavefrontCoroScheduler : public CoroScheduler { _host_offset.resize(_max_sub_coro); _host_count.resize(_max_sub_coro); _have_hint.resize(_max_sub_coro, false); - for (auto &token : config.hint_fields) { + for (auto &token : _config.hint_fields) { auto id = coroutine.frame()->designated_fields().find(token); if (id != coroutine.frame()->designated_fields().end()) { LUISA_ASSERT(id->second < _max_sub_coro, @@ -134,7 +134,7 @@ class WavefrontCoroScheduler : public CoroScheduler { $if (index > _config.max_instance_count) { device_log("Index out of range {}/{}", index, _config.max_instance_count); }; - if (config.soa) { + if (_config.soa) { return _frame_soa->read_field(index, "target_token") & token_mask; } else { CoroFrame frame = _frame_buffer->read(index); @@ -149,9 +149,9 @@ class WavefrontCoroScheduler : public CoroScheduler { return val.read(index); }; Callable get_coro_hint = [&](UInt index, BufferUInt val) { - if (!config.hint_fields.empty()) { + if (!_config.hint_fields.empty()) { auto id = keep_index(index, val); - if (config.soa) { + if (_config.soa) { return _frame_soa->read_field(id, "coro_hint"); } else { CoroFrame frame = _frame_buffer->read(id); @@ -162,21 +162,21 @@ class WavefrontCoroScheduler : public CoroScheduler { }; if (use_sort) { _sort_temp_storage = radix_sort::temp_storage( - device, _config.max_instance_count, std::max(std::min(config.hint_range, 128u), _max_sub_coro)); + device, _config.max_instance_count, std::max(std::min(_config.hint_range, 128u), _max_sub_coro)); } - if (config.sort) { + if (_config.sort) { _sort_token = radix_sort::instance<>( device, _config.max_instance_count, _sort_temp_storage, &get_coro_token, &identical, &get_coro_token, 1, _max_sub_coro); } - if (!config.hint_fields.empty()) { - if (config.hint_range <= 128) { + if (!_config.hint_fields.empty()) { + if (_config.hint_range <= 128) { _sort_hint = radix_sort::instance>( device, _config.max_instance_count, _sort_temp_storage, &get_coro_hint, &keep_index, - &get_coro_hint, 1, config.hint_range); + &get_coro_hint, 1, _config.hint_range); } else { auto highbit = 0; - while ((config.hint_range >> highbit) != 1) { + while ((_config.hint_range >> highbit) != 1) { highbit++; } _sort_hint = radix_sort::instance>( @@ -190,23 +190,23 @@ class WavefrontCoroScheduler : public CoroScheduler { $return(); }; UInt frame_id; - if (!config.compact) { + if (!_config.compact) { frame_id = index->read(x); } else { frame_id = offset + x; } - if (!config.sort) { + if (!_config.sort) { count.atomic(0u).fetch_add(-1u); } CoroFrame frame = CoroFrame::create(desc, def(st_task_id + x, 0, 0)); coroutine[0u](frame, args...); - if (config.soa) { + if (_config.soa) { _frame_soa->write(frame_id, frame, coroutine.graph()->node(0u).output_fields()); } else { _frame_buffer->write(frame_id, frame); } - if (!config.sort) { + if (!_config.sort) { auto nxt = frame.get("coro_hint") & token_mask; count.atomic(nxt).fetch_add(1u); } @@ -224,23 +224,23 @@ class WavefrontCoroScheduler : public CoroScheduler { }; auto frame_id = index.read(x); CoroFrame frame = CoroFrame::create(desc); - if (config.soa) { + if (_config.soa) { //frame = frame_buffer.read(frame_id); frame = _frame_soa->read(frame_id, coroutine.graph()->node(i).input_fields()); } else { frame = _frame_buffer->read(frame_id); } - if (!config.sort) { + if (!_config.sort) { count.atomic(i).fetch_add(-1u); } coroutine[i](frame, args...); - if (config.soa) { + if (_config.soa) { _frame_soa->write(frame_id, frame, coroutine.graph()->node(i).output_fields()); } else { _frame_buffer->write(frame_id, frame); } - if (!config.sort) { + if (!_config.sort) { auto nxt = frame.get("target_token") & token_mask; $if (nxt < _max_sub_coro) { count.atomic(nxt).fetch_add(1u); @@ -266,7 +266,7 @@ class WavefrontCoroScheduler : public CoroScheduler { Kernel1D _gather_kernel = [&](BufferUInt index, BufferUInt prefix, UInt n) { auto x = dispatch_x(); auto r_id = def(0u); - if (config.soa) { + if (_config.soa) { r_id = _frame_soa->read_field(x, "target_token") & token_mask; } else { auto frame = _frame_buffer->read(x); @@ -282,7 +282,7 @@ class WavefrontCoroScheduler : public CoroScheduler { auto x = dispatch_x(); $if (empty_offset + x < n) { auto token = def(0u); - if (config.soa) { + if (_config.soa) { token = _frame_soa->read_field(empty_offset + x, "target_token"); } else { CoroFrame frame = _frame_buffer->read(empty_offset + x); @@ -291,13 +291,13 @@ class WavefrontCoroScheduler : public CoroScheduler { $if ((token & token_mask) != 0u) { auto res = _global_buffer->atomic(0u).fetch_add(1u); auto slot = index.read(res); - if (!config.sort) { + if (!_config.sort) { $while (slot >= empty_offset) { res = _global_buffer->atomic(0u).fetch_add(1u); slot = index.read(res); }; } - if (config.soa) { + if (_config.soa) { // TODO: active fields here? auto frame = _frame_soa->read(empty_offset + x); _frame_soa->write(slot, frame); @@ -305,7 +305,7 @@ class WavefrontCoroScheduler : public CoroScheduler { auto frame = _frame_buffer->read(empty_offset + x); _frame_buffer->write(slot, frame); } - if (config.soa) { + if (_config.soa) { _frame_soa->write_field(empty_offset + x, 0u, "target_token"); } else { CoroFrame empty_frame = CoroFrame::create(desc); @@ -320,7 +320,7 @@ class WavefrontCoroScheduler : public CoroScheduler { Kernel1D _initialize_kernel = [&](BufferUInt count, UInt n) { auto x = dispatch_x(); $if (x < n) { - if (config.soa) { + if (_config.soa) { CoroFrame frame = coroutine.instantiate(dispatch_id()); _frame_soa->write(x, frame, std::array{0u, 1u}); } else { @@ -332,6 +332,8 @@ class WavefrontCoroScheduler : public CoroScheduler { count.write(x, ite(x == 0u, _config.max_instance_count, 0u)); }; }; + _initialize_shader = device.compile(_initialize_kernel); + Kernel1D clear = [&](BufferUInt buffer, UInt n) { auto x = dispatch_x(); $if (x < n) { @@ -339,19 +341,18 @@ class WavefrontCoroScheduler : public CoroScheduler { }; }; _clear_shader = device.compile(clear); - _initialize_shader = device.compile(_initialize_kernel); } [[nodiscard]] bool _all_dispatched() const noexcept { - return _dispatch_counter == _config.max_instance_count; + return _dispatch_counter == _dispatch_size; } [[nodiscard]] bool _all_done() const noexcept { - return _host_empty && _all_dispatched(); + return this->_all_dispatched() && _host_empty; } void _await_all(Stream &stream) noexcept { - while (!_all_done()) { - _await_step(stream); + while (!this->_all_done()) { + this->_await_step(stream); } } void _await_step(Stream &stream) noexcept { @@ -370,8 +371,8 @@ class WavefrontCoroScheduler : public CoroScheduler { << host_update << synchronize(); - if (_host_count[0] > _config.max_instance_count * (0.5) && !_all_dispatched()) { - auto gen_count = std::min(_config.block_size.x - _dispatch_counter, _host_count[0]); + if (_host_count[0] > _config.max_instance_count * 0.5f && !this->_all_dispatched()) { + auto gen_count = std::min(_dispatch_size - _dispatch_counter, _host_count[0]); if (_host_count[0] != _config.max_instance_count && _config.compact) { stream << _clear_shader(_global_buffer, 1).dispatch(1u); stream << _compact_shader(_resume_index, _config.max_instance_count - _host_count[0], _config.max_instance_count).dispatch(_host_count[0]); @@ -404,7 +405,7 @@ class WavefrontCoroScheduler : public CoroScheduler { stream << _count_prefix_shader(_resume_count, _resume_offset, _max_sub_coro).dispatch(1u); stream << _gather_shader(_resume_index, _resume_offset, _config.max_instance_count).dispatch(_config.max_instance_count); if (_host_count[0] > _config.max_instance_count / 2 && !_all_dispatched()) { - auto gen_count = std::min(_config.block_size.x - _dispatch_counter, _host_count[0]); + auto gen_count = std::min(_dispatch_size - _dispatch_counter, _host_count[0]); if (_host_count[0] != _config.max_instance_count && _config.compact) { stream << _clear_shader(_global_buffer, 1).dispatch(1u); stream diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_v2.cpp index 8d09cb1fe..a5821c266 100644 --- a/src/tests/coro/path_tracing_v2.cpp +++ b/src/tests/coro/path_tracing_v2.cpp @@ -262,19 +262,10 @@ int main(int argc, char *argv[]) { image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); }; - coroutine::Coroutine raytrace_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution, UInt2 pixel_id) noexcept { - auto coro_id = make_uint3(pixel_id, 0u); - coro(image, seed_image, accel, resolution).set_id(coro_id).await(); - }; - - coroutine::Coroutine raytracing_coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { - raytrace_coro(image, seed_image, accel, resolution, dispatch_id().xy()).await(); - }; - coroutine::StateMachineCoroSchedulerConfig config{.block_size = make_uint3(8u, 8u, 1u), .shared_memory = false, .shared_memory_soa = true}; - coroutine::StateMachineCoroScheduler scheduler{device, raytracing_coro, config}; + coroutine::StateMachineCoroScheduler scheduler{device, coro, config}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { UInt2 p = dispatch_id().xy(); diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp index bed5b72f7..b73e0a716 100644 --- a/src/tests/coro/path_tracing_wavefront_v2.cpp +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -196,10 +196,13 @@ int main(int argc, char *argv[]) { return pdf_a / max(pdf_a + pdf_b, 1e-4f); }; - auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; + auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 16u; coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { - UInt2 coord = dispatch_id().xy(); + UInt pixel_count = resolution.x * resolution.y; + UInt pixel_id = coro_id().x % pixel_count; + UInt2 coord = make_uint2(pixel_id % resolution.x, pixel_id / resolution.x); + Float frame_size = min(resolution.x, resolution.y).cast(); UInt state = seed_image.read(coord).x; Float rx = lcg(state); @@ -207,6 +210,9 @@ int main(int argc, char *argv[]) { Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; Float3 radiance = def(make_float3(0.0f)); $suspend("per_spp"); + // $if (all(coord == make_uint2(50u, 500u))) { + // device_log("coord: {}", coord); + // }; $for (i, spp_per_dispatch) { Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); Float3 beta = def(make_float3(1.0f)); @@ -296,17 +302,15 @@ int main(int argc, char *argv[]) { radiance /= static_cast(spp_per_dispatch); seed_image.write(coord, make_uint4(state)); $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; - image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); + auto color = make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f); + image.write(coord, color); }; - auto coro_buffer = device.create_coro_frame_buffer(coro.frame(), 1024u); - - // coroutine::StateMachineCoroScheduler scheduler{device, coro}; coroutine::WavefrontCoroSchedulerConfig config{ - .block_size = make_uint3(1024u, 1u, 1u), .max_instance_count = 16777216, + .soa = false, }; - coroutine::WavefrontCoroScheduler scheduler{device, coro}; + coroutine::WavefrontCoroScheduler scheduler{device, coro, config}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { UInt2 p = dispatch_id().xy(); From 76bd43a1949bb9359580d862c97235cfa79f8f6e Mon Sep 17 00:00:00 2001 From: chenxin Date: Tue, 14 May 2024 20:25:10 +0800 Subject: [PATCH 48/67] minor --- src/tests/coro/path_tracing_wavefront_v2.cpp | 162 +++++++++---------- 1 file changed, 80 insertions(+), 82 deletions(-) diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp index b73e0a716..0fec30994 100644 --- a/src/tests/coro/path_tracing_wavefront_v2.cpp +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -213,93 +213,91 @@ int main(int argc, char *argv[]) { // $if (all(coord == make_uint2(50u, 500u))) { // device_log("coord: {}", coord); // }; - $for (i, spp_per_dispatch) { - Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); - Float3 beta = def(make_float3(1.0f)); - Float pdf_bsdf = def(0.0f); - constexpr float3 light_position = make_float3(-0.24f, 1.98f, 0.16f); - constexpr float3 light_u = make_float3(-0.24f, 1.98f, -0.22f) - light_position; - constexpr float3 light_v = make_float3(0.23f, 1.98f, 0.16f) - light_position; - constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); - Float light_area = length(cross(light_u, light_v)); - Float3 light_normal = normalize(cross(light_u, light_v)); - $suspend("per_depth"); - $for (depth, 10u) { - // trace - $suspend("before_tracing"); - Var hit = accel.intersect(ray, {}); - reorder_shader_execution(); - $if (hit->miss()) { $break; }; - Var triangle = heap->buffer(hit.inst).read(hit.prim); - Float3 p0 = vertex_buffer->read(triangle.i0); - Float3 p1 = vertex_buffer->read(triangle.i1); - Float3 p2 = vertex_buffer->read(triangle.i2); - Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); - Float3 n = normalize(cross(p1 - p0, p2 - p0)); - $suspend("after_tracing"); - - Float cos_wo = dot(-ray->direction(), n); - $if (cos_wo < 1e-4f) { $break; }; - - // hit light - $if (hit.inst == static_cast(meshes.size() - 1u)) { - $if (depth == 0u) { - radiance += light_emission; - } - $else { - Float pdf_light = length_squared(p - ray->origin()) / (light_area * cos_wo); - Float mis_weight = balanced_heuristic(pdf_bsdf, pdf_light); - radiance += mis_weight * beta * light_emission; - }; - $break; - }; - // sample light - $suspend("sample_light"); - Float ux_light = lcg(state); - Float uy_light = lcg(state); - Float3 p_light = light_position + ux_light * light_u + uy_light * light_v; - Float3 pp = offset_ray_origin(p, n); - Float3 pp_light = offset_ray_origin(p_light, light_normal); - Float d_light = distance(pp, pp_light); - Float3 wi_light = normalize(pp_light - pp); - Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); - Bool occluded = accel.intersect_any(shadow_ray, {}); - Float cos_wi_light = dot(wi_light, n); - Float cos_light = -dot(light_normal, wi_light); - Float3 albedo = materials.read(hit.inst); - $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { - Float pdf_light = (d_light * d_light) / (light_area * cos_light); - Float pdf_bsdf = cos_wi_light * inv_pi; - Float mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); - Float3 bsdf = albedo * inv_pi * cos_wi_light; - radiance += beta * bsdf * mis_weight * light_emission / max(pdf_light, 1e-4f); + Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); + Float3 beta = def(make_float3(1.0f)); + Float pdf_bsdf = def(0.0f); + constexpr float3 light_position = make_float3(-0.24f, 1.98f, 0.16f); + constexpr float3 light_u = make_float3(-0.24f, 1.98f, -0.22f) - light_position; + constexpr float3 light_v = make_float3(0.23f, 1.98f, 0.16f) - light_position; + constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); + Float light_area = length(cross(light_u, light_v)); + Float3 light_normal = normalize(cross(light_u, light_v)); + $suspend("per_depth"); + $for (depth, 10u) { + // trace + $suspend("before_tracing"); + Var hit = accel.intersect(ray, {}); + reorder_shader_execution(); + $if (hit->miss()) { $break; }; + Var triangle = heap->buffer(hit.inst).read(hit.prim); + Float3 p0 = vertex_buffer->read(triangle.i0); + Float3 p1 = vertex_buffer->read(triangle.i1); + Float3 p2 = vertex_buffer->read(triangle.i2); + Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); + Float3 n = normalize(cross(p1 - p0, p2 - p0)); + $suspend("after_tracing"); + + Float cos_wo = dot(-ray->direction(), n); + $if (cos_wo < 1e-4f) { $break; }; + + // hit light + $if (hit.inst == static_cast(meshes.size() - 1u)) { + $if (depth == 0u) { + radiance += light_emission; + } + $else { + Float pdf_light = length_squared(p - ray->origin()) / (light_area * cos_wo); + Float mis_weight = balanced_heuristic(pdf_bsdf, pdf_light); + radiance += mis_weight * beta * light_emission; }; + $break; + }; - // sample BSDF - $suspend("sample_bsdf"); - Var onb = make_onb(n); - Float ux = lcg(state); - Float uy = lcg(state); - Float3 wi_local = cosine_sample_hemisphere(make_float2(ux, uy)); - Float cos_wi = abs(wi_local.z); - Float3 new_direction = onb->to_world(wi_local); - ray = make_ray(pp, new_direction); - pdf_bsdf = cos_wi * inv_pi; - beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f - - // rr - $suspend("rr"); - Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); - $if (l == 0.0f) { $break; }; - Float q = max(l, 0.05f); - Float r = lcg(state); - $if (r >= q) { $break; }; - beta *= 1.0f / q; + // sample light + $suspend("sample_light"); + Float ux_light = lcg(state); + Float uy_light = lcg(state); + Float3 p_light = light_position + ux_light * light_u + uy_light * light_v; + Float3 pp = offset_ray_origin(p, n); + Float3 pp_light = offset_ray_origin(p_light, light_normal); + Float d_light = distance(pp, pp_light); + Float3 wi_light = normalize(pp_light - pp); + Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); + Bool occluded = accel.intersect_any(shadow_ray, {}); + Float cos_wi_light = dot(wi_light, n); + Float cos_light = -dot(light_normal, wi_light); + Float3 albedo = materials.read(hit.inst); + $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { + Float pdf_light = (d_light * d_light) / (light_area * cos_light); + Float pdf_bsdf = cos_wi_light * inv_pi; + Float mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); + Float3 bsdf = albedo * inv_pi * cos_wi_light; + radiance += beta * bsdf * mis_weight * light_emission / max(pdf_light, 1e-4f); }; + + // sample BSDF + $suspend("sample_bsdf"); + Var onb = make_onb(n); + Float ux = lcg(state); + Float uy = lcg(state); + Float3 wi_local = cosine_sample_hemisphere(make_float2(ux, uy)); + Float cos_wi = abs(wi_local.z); + Float3 new_direction = onb->to_world(wi_local); + ray = make_ray(pp, new_direction); + pdf_bsdf = cos_wi * inv_pi; + beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f + + // rr + $suspend("rr"); + Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); + $if (l == 0.0f) { $break; }; + Float q = max(l, 0.05f); + Float r = lcg(state); + $if (r >= q) { $break; }; + beta *= 1.0f / q; }; $suspend("write_film"); - radiance /= static_cast(spp_per_dispatch); seed_image.write(coord, make_uint4(state)); $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; auto color = make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f); @@ -375,7 +373,7 @@ int main(int argc, char *argv[]) { while (!window.should_close()) { stream << scheduler(framebuffer, seed_image, accel, resolution) - .dispatch(resolution) + .dispatch(resolution.x * resolution.y * spp_per_dispatch) << accumulate_shader(accum_image, framebuffer) .dispatch(resolution) << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) From 2d8f72ab4319eac497b54b1fb4c6ad46a31ab16b Mon Sep 17 00:00:00 2001 From: chenxin Date: Tue, 14 May 2024 20:40:40 +0800 Subject: [PATCH 49/67] minor --- include/luisa/coro/v2/schedulers/wavefront.h | 17 ++++++++--------- src/tests/coro/path_tracing_wavefront_v2.cpp | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index 71baa5e49..ac46aeb41 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -59,7 +59,7 @@ class WavefrontCoroScheduler : public CoroScheduler { private: void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { - _dispatch_size = dispatch_size.x * dispatch_size.y * dispatch_size.z; // TODO + _dispatch_size = dispatch_size.x * dispatch_size.y * dispatch_size.z;// TODO _dispatch_counter = 0; _host_empty = true; for (auto i = 0u; i < _max_sub_coro; i++) { @@ -89,7 +89,6 @@ class WavefrontCoroScheduler : public CoroScheduler { void _create_shader(Device &device, const Coroutine &coroutine, const WavefrontCoroSchedulerConfig &config) noexcept { _config = config; - const luisa::shared_ptr desc = coroutine.shared_frame(); if (_config.soa) { _frame_soa = device.create_soa(coroutine.shared_frame(), _config.max_instance_count); } else { @@ -199,7 +198,7 @@ class WavefrontCoroScheduler : public CoroScheduler { count.atomic(0u).fetch_add(-1u); } - CoroFrame frame = CoroFrame::create(desc, def(st_task_id + x, 0, 0)); + CoroFrame frame = coroutine.instantiate(def(st_task_id + x, 0, 0)); coroutine[0u](frame, args...); if (_config.soa) { _frame_soa->write(frame_id, frame, coroutine.graph()->node(0u).output_fields()); @@ -223,7 +222,7 @@ class WavefrontCoroScheduler : public CoroScheduler { $return(); }; auto frame_id = index.read(x); - CoroFrame frame = CoroFrame::create(desc); + CoroFrame frame = coroutine.instantiate(); if (_config.soa) { //frame = frame_buffer.read(frame_id); frame = _frame_soa->read(frame_id, coroutine.graph()->node(i).input_fields()); @@ -265,7 +264,7 @@ class WavefrontCoroScheduler : public CoroScheduler { Kernel1D _gather_kernel = [&](BufferUInt index, BufferUInt prefix, UInt n) { auto x = dispatch_x(); - auto r_id = def(0u); + UInt r_id; if (_config.soa) { r_id = _frame_soa->read_field(x, "target_token") & token_mask; } else { @@ -281,7 +280,7 @@ class WavefrontCoroScheduler : public CoroScheduler { //_global_buffer->write(0u, 0u); auto x = dispatch_x(); $if (empty_offset + x < n) { - auto token = def(0u); + UInt token; if (_config.soa) { token = _frame_soa->read_field(empty_offset + x, "target_token"); } else { @@ -308,7 +307,7 @@ class WavefrontCoroScheduler : public CoroScheduler { if (_config.soa) { _frame_soa->write_field(empty_offset + x, 0u, "target_token"); } else { - CoroFrame empty_frame = CoroFrame::create(desc); + CoroFrame empty_frame = coroutine.instantiate(); _frame_buffer->write(empty_offset + x, empty_frame); } }; @@ -321,10 +320,10 @@ class WavefrontCoroScheduler : public CoroScheduler { auto x = dispatch_x(); $if (x < n) { if (_config.soa) { - CoroFrame frame = coroutine.instantiate(dispatch_id()); + CoroFrame frame = coroutine.instantiate(); _frame_soa->write(x, frame, std::array{0u, 1u}); } else { - CoroFrame frame = coroutine.instantiate(dispatch_id()); + CoroFrame frame = coroutine.instantiate(); _frame_buffer->write(x, frame); } }; diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp index 0fec30994..6e710219a 100644 --- a/src/tests/coro/path_tracing_wavefront_v2.cpp +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -306,7 +306,7 @@ int main(int argc, char *argv[]) { coroutine::WavefrontCoroSchedulerConfig config{ .max_instance_count = 16777216, - .soa = false, + .soa = true, }; coroutine::WavefrontCoroScheduler scheduler{device, coro, config}; From 09cbadee9518e8b4fdf7237bfc4f10f652391c1e Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 21:32:00 +0800 Subject: [PATCH 50/67] debugging wavefront scheduler --- .../coro/v2/schedulers/persistent_threads.h | 18 +-- include/luisa/coro/v2/schedulers/wavefront.h | 103 +++++++++--------- src/tests/coro/path_tracing_wavefront_v2.cpp | 2 +- src/tests/coro/sdf_renderer_v2.cpp | 2 +- 4 files changed, 63 insertions(+), 62 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/v2/schedulers/persistent_threads.h index 968535812..4fef7fa55 100644 --- a/include/luisa/coro/v2/schedulers/persistent_threads.h +++ b/include/luisa/coro/v2/schedulers/persistent_threads.h @@ -13,7 +13,7 @@ struct PersistentThreadsCoroSchedulerConfig { uint block_size = 128; uint fetch_size = 16; bool shared_memory_soa = true; - bool global_ext_memory = false; + bool global_memory_ext = false; }; template @@ -37,7 +37,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto q_fac = 1u; auto g_fac = coro.subroutine_count() - q_fac; auto global_queue_size = _config.block_size * g_fac; - if (_config.global_ext_memory) { + if (_config.global_memory_ext) { auto global_ext_size = _config.thread_count * g_fac; _global_frames = device.create_buffer(coro.shared_frame(), global_ext_size); } @@ -48,7 +48,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { Shared path_id{shared_queue_size}; Shared work_counter{coro.subroutine_count()}; Shared work_offset{2u}; - Shared all_token{_config.global_ext_memory ? + Shared all_token{_config.global_memory_ext ? shared_queue_size + global_queue_size : shared_queue_size}; Shared workload{2}; @@ -65,7 +65,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { if_(thread_x() < coro.subroutine_count(), [&] { if_(thread_x() == 0u, [&] { work_counter[thread_x()] = - _config.global_ext_memory ? + _config.global_memory_ext ? shared_queue_size + global_queue_size : shared_queue_size; }).else_([&] { @@ -118,7 +118,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { work_offset[0] = 0; work_offset[1] = 0; sync_block(); - if (!_config.global_ext_memory) { + if (!_config.global_memory_ext) { for (auto index : dsl::dynamic_range(q_fac)) {//collect indices auto frame_token = all_token[index * _config.block_size + thread_x()]; if_(frame_token == work_stat[1], [&] { @@ -175,13 +175,13 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto gen_st = workload[0]; sync_block(); auto pid = def(0u); - if (_config.global_ext_memory) { + if (_config.global_memory_ext) { pid = thread_x(); } else { pid = path_id[thread_x()]; } auto launch_condition = def(true); - if (!_config.global_ext_memory) { + if (!_config.global_memory_ext) { launch_condition = (thread_x() < work_offset[0]); } else { launch_condition = (all_token[pid] == work_stat[1]); @@ -233,7 +233,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { _clear_shader = device.compile<1>([](BufferUInt global) { global->write(dispatch_x(), 0u); }); - if (_config.global_ext_memory) { + if (_config.global_memory_ext) { _initialize_shader = device.compile<1>([&](UInt n) noexcept { auto x = dispatch_x(); $if (x < n) { @@ -246,7 +246,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { stream << _clear_shader(_global).dispatch(1u); - if (_config.global_ext_memory) { + if (_config.global_memory_ext) { auto n = static_cast(_global_frames.size()); stream << _initialize_shader(n).dispatch(n); } diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index ac46aeb41..cd660ab98 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -13,8 +13,8 @@ namespace luisa::compute::coroutine { struct WavefrontCoroSchedulerConfig { - uint3 block_size = make_uint3(8u, 8u, 1u); - uint max_instance_count = 2_M; + // uint3 block_size = make_uint3(8u, 8u, 1u); + uint thread_count = 2_M; bool soa = true; bool sort = true;// use sort for coro token gathering bool compact = true; @@ -25,8 +25,11 @@ struct WavefrontCoroSchedulerConfig { template class WavefrontCoroScheduler : public CoroScheduler { +public: + using Config = WavefrontCoroSchedulerConfig; + private: - WavefrontCoroSchedulerConfig _config; + Config _config; using ArgPack = std::tuple...>; luisa::optional _args; SOA _frame_soa; @@ -59,20 +62,20 @@ class WavefrontCoroScheduler : public CoroScheduler { private: void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { + _args.emplace(std::forward>(args)...); _dispatch_size = dispatch_size.x * dispatch_size.y * dispatch_size.z;// TODO _dispatch_counter = 0; _host_empty = true; for (auto i = 0u; i < _max_sub_coro; i++) { if (i) { _host_count[i] = 0; - _host_offset[i] = _config.max_instance_count; + _host_offset[i] = _config.thread_count; } else { - _host_count[i] = _config.max_instance_count; + _host_count[i] = _config.thread_count; _host_offset[i] = 0; } } - stream << _initialize_shader(_resume_count, _config.max_instance_count).dispatch(_config.max_instance_count); - _args.emplace(std::forward>(args)...); + stream << _initialize_shader(_resume_count, _config.thread_count).dispatch(_config.thread_count); this->_await_all(stream); } @@ -90,17 +93,17 @@ class WavefrontCoroScheduler : public CoroScheduler { const WavefrontCoroSchedulerConfig &config) noexcept { _config = config; if (_config.soa) { - _frame_soa = device.create_soa(coroutine.shared_frame(), _config.max_instance_count); + _frame_soa = device.create_soa(coroutine.shared_frame(), _config.thread_count); } else { - _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), _config.max_instance_count); + _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), _config.thread_count); } bool use_sort = _config.sort || !_config.hint_fields.empty(); _max_sub_coro = coroutine.subroutine_count(); - _resume_index = device.create_buffer(_config.max_instance_count); + _resume_index = device.create_buffer(_config.thread_count); if (use_sort) { - _temp_index = device.create_buffer(_config.max_instance_count); - _temp_key[0] = device.create_buffer(_config.max_instance_count); - _temp_key[1] = device.create_buffer(_config.max_instance_count); + _temp_index = device.create_buffer(_config.thread_count); + _temp_key[0] = device.create_buffer(_config.thread_count); + _temp_key[1] = device.create_buffer(_config.thread_count); } _resume_count = device.create_buffer(_max_sub_coro); _resume_offset = device.create_buffer(_max_sub_coro); @@ -111,8 +114,8 @@ class WavefrontCoroScheduler : public CoroScheduler { _host_count.resize(_max_sub_coro); _have_hint.resize(_max_sub_coro, false); for (auto &token : _config.hint_fields) { - auto id = coroutine.frame()->designated_fields().find(token); - if (id != coroutine.frame()->designated_fields().end()) { + auto id = coroutine.graph()->named_tokens().find(token); + if (id != coroutine.graph()->named_tokens().end()) { LUISA_ASSERT(id->second < _max_sub_coro, "coroutine token {} of id {} out of range {}", token, id->second, _max_sub_coro); _have_hint[id->second] = true; @@ -123,15 +126,15 @@ class WavefrontCoroScheduler : public CoroScheduler { for (auto i = 0u; i < _max_sub_coro; i++) { if (i) { _host_count[i] = 0; - _host_offset[i] = _config.max_instance_count; + _host_offset[i] = _config.thread_count; } else { - _host_count[i] = _config.max_instance_count; + _host_count[i] = _config.thread_count; _host_offset[i] = 0; } } Callable get_coro_token = [&](UInt index) { - $if (index > _config.max_instance_count) { - device_log("Index out of range {}/{}", index, _config.max_instance_count); + $if (index > _config.thread_count) { + device_log("Index out of range {}/{}", index, _config.thread_count); }; if (_config.soa) { return _frame_soa->read_field(index, "target_token") & token_mask; @@ -161,17 +164,17 @@ class WavefrontCoroScheduler : public CoroScheduler { }; if (use_sort) { _sort_temp_storage = radix_sort::temp_storage( - device, _config.max_instance_count, std::max(std::min(_config.hint_range, 128u), _max_sub_coro)); + device, _config.thread_count, std::max(std::min(_config.hint_range, 128u), _max_sub_coro)); } if (_config.sort) { _sort_token = radix_sort::instance<>( - device, _config.max_instance_count, _sort_temp_storage, &get_coro_token, &identical, + device, _config.thread_count, _sort_temp_storage, &get_coro_token, &identical, &get_coro_token, 1, _max_sub_coro); } if (!_config.hint_fields.empty()) { if (_config.hint_range <= 128) { _sort_hint = radix_sort::instance>( - device, _config.max_instance_count, _sort_temp_storage, &get_coro_hint, &keep_index, + device, _config.thread_count, _sort_temp_storage, &get_coro_hint, &keep_index, &get_coro_hint, 1, _config.hint_range); } else { auto highbit = 0; @@ -179,7 +182,7 @@ class WavefrontCoroScheduler : public CoroScheduler { highbit++; } _sort_hint = radix_sort::instance>( - device, _config.max_instance_count, _sort_temp_storage, &get_coro_hint, &keep_index, + device, _config.thread_count, _sort_temp_storage, &get_coro_hint, &keep_index, &get_coro_hint, 0, 128, 0, highbit); } } @@ -198,15 +201,15 @@ class WavefrontCoroScheduler : public CoroScheduler { count.atomic(0u).fetch_add(-1u); } - CoroFrame frame = coroutine.instantiate(def(st_task_id + x, 0, 0)); - coroutine[0u](frame, args...); + CoroFrame frame = coroutine.instantiate(make_uint3(st_task_id + x, 0u, 0u)); + coroutine.entry()(frame, args...); if (_config.soa) { _frame_soa->write(frame_id, frame, coroutine.graph()->node(0u).output_fields()); } else { _frame_buffer->write(frame_id, frame); } if (!_config.sort) { - auto nxt = frame.get("coro_hint") & token_mask; + auto nxt = frame.target_token & token_mask; count.atomic(nxt).fetch_add(1u); } }; @@ -230,7 +233,7 @@ class WavefrontCoroScheduler : public CoroScheduler { frame = _frame_buffer->read(frame_id); } if (!_config.sort) { - count.atomic(i).fetch_add(-1u); + count.atomic(i).fetch_sub(1u); } coroutine[i](frame, args...); if (_config.soa) { @@ -240,7 +243,7 @@ class WavefrontCoroScheduler : public CoroScheduler { } if (!_config.sort) { - auto nxt = frame.get("target_token") & token_mask; + auto nxt = frame.target_token & token_mask; $if (nxt < _max_sub_coro) { count.atomic(nxt).fetch_add(1u); }; @@ -255,7 +258,7 @@ class WavefrontCoroScheduler : public CoroScheduler { auto pre = def(0u); for (auto i = 0u; i < _max_sub_coro; ++i) { auto val = count.read(i); - prefix.write(def(i), pre); + prefix.write(i, pre); pre = pre + val; } }; @@ -269,7 +272,7 @@ class WavefrontCoroScheduler : public CoroScheduler { r_id = _frame_soa->read_field(x, "target_token") & token_mask; } else { auto frame = _frame_buffer->read(x); - r_id = frame.get("target_token") & token_mask; + r_id = frame.target_token & token_mask; } auto q_id = prefix.atomic(r_id).fetch_add(1u); index.write(q_id, x); @@ -285,7 +288,7 @@ class WavefrontCoroScheduler : public CoroScheduler { token = _frame_soa->read_field(empty_offset + x, "target_token"); } else { CoroFrame frame = _frame_buffer->read(empty_offset + x); - token = frame.get("target_token"); + token = frame.target_token; } $if ((token & token_mask) != 0u) { auto res = _global_buffer->atomic(0u).fetch_add(1u); @@ -297,7 +300,6 @@ class WavefrontCoroScheduler : public CoroScheduler { }; } if (_config.soa) { - // TODO: active fields here? auto frame = _frame_soa->read(empty_offset + x); _frame_soa->write(slot, frame); } else { @@ -319,16 +321,15 @@ class WavefrontCoroScheduler : public CoroScheduler { Kernel1D _initialize_kernel = [&](BufferUInt count, UInt n) { auto x = dispatch_x(); $if (x < n) { + CoroFrame frame = coroutine.instantiate(); if (_config.soa) { - CoroFrame frame = coroutine.instantiate(); _frame_soa->write(x, frame, std::array{0u, 1u}); } else { - CoroFrame frame = coroutine.instantiate(); _frame_buffer->write(x, frame); } }; $if (x < _max_sub_coro) { - count.write(x, ite(x == 0u, _config.max_instance_count, 0u)); + count.write(x, ite(x == 0u, _config.thread_count, 0u)); }; }; _initialize_shader = device.compile(_initialize_kernel); @@ -359,26 +360,26 @@ class WavefrontCoroScheduler : public CoroScheduler { auto host_update = [&] { _host_empty = true; for (uint i = 0u; i < _max_sub_coro; i++) { - _host_count[i] = (i + 1u == _max_sub_coro ? _config.max_instance_count : _host_offset[i + 1u]) - _host_offset[i]; + _host_count[i] = (i + 1u == _max_sub_coro ? _config.thread_count : _host_offset[i + 1u]) - _host_offset[i]; _host_empty = _host_empty && (i == 0u || _host_count[i] == 0u); } }; _sort_token.sort(stream, _temp_key[0], _resume_index, _temp_key[1], - _resume_index, _config.max_instance_count); + _resume_index, _config.thread_count); stream << _sort_temp_storage.hist_buffer.view(0u, _max_sub_coro).copy_to(_host_offset.data()) << host_update << synchronize(); - if (_host_count[0] > _config.max_instance_count * 0.5f && !this->_all_dispatched()) { + if (_host_count[0] > _config.thread_count * 0.5f && !this->_all_dispatched()) { auto gen_count = std::min(_dispatch_size - _dispatch_counter, _host_count[0]); - if (_host_count[0] != _config.max_instance_count && _config.compact) { + if (_host_count[0] != _config.thread_count && _config.compact) { stream << _clear_shader(_global_buffer, 1).dispatch(1u); - stream << _compact_shader(_resume_index, _config.max_instance_count - _host_count[0], _config.max_instance_count).dispatch(_host_count[0]); + stream << _compact_shader(_resume_index, _config.thread_count - _host_count[0], _config.thread_count).dispatch(_host_count[0]); } stream << _invoke(_gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), - _resume_count, _config.max_instance_count - _host_count[0], _dispatch_counter, - _config.max_instance_count) + _resume_count, _config.thread_count - _host_count[0], _dispatch_counter, + _config.thread_count) .dispatch(gen_count); _dispatch_counter += gen_count; _host_empty = false; @@ -389,11 +390,11 @@ class WavefrontCoroScheduler : public CoroScheduler { BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; BufferView _key[2] = {_temp_key[1].view(_host_offset[i], _host_count[i]), _temp_key[0].view(_host_offset[i], _host_count[i])}; uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); - stream << _invoke(_resume_shaders[i], _index[out], _resume_count, _config.max_instance_count) + stream << _invoke(_resume_shaders[i], _index[out], _resume_count, _config.thread_count) .dispatch(_host_count[i]); } else { stream << _invoke(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _config.max_instance_count) + _resume_count, _config.thread_count) .dispatch(_host_count[i]); } } @@ -402,18 +403,18 @@ class WavefrontCoroScheduler : public CoroScheduler { stream << synchronize(); } else { stream << _count_prefix_shader(_resume_count, _resume_offset, _max_sub_coro).dispatch(1u); - stream << _gather_shader(_resume_index, _resume_offset, _config.max_instance_count).dispatch(_config.max_instance_count); - if (_host_count[0] > _config.max_instance_count / 2 && !_all_dispatched()) { + stream << _gather_shader(_resume_index, _resume_offset, _config.thread_count).dispatch(_config.thread_count); + if (_host_count[0] > _config.thread_count / 2 && !_all_dispatched()) { auto gen_count = std::min(_dispatch_size - _dispatch_counter, _host_count[0]); - if (_host_count[0] != _config.max_instance_count && _config.compact) { + if (_host_count[0] != _config.thread_count && _config.compact) { stream << _clear_shader(_global_buffer, 1).dispatch(1u); stream << _compact_shader(_resume_index.view(_host_offset[0], _host_count[0]), - _config.max_instance_count - _host_count[0], _config.max_instance_count) + _config.thread_count - _host_count[0], _config.thread_count) .dispatch(_host_count[0]); } stream << _invoke(_gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), - _resume_count, _config.max_instance_count - _host_count[0], _dispatch_counter, _config.max_instance_count) + _resume_count, _config.thread_count - _host_count[0], _dispatch_counter, _config.thread_count) .dispatch(gen_count); _dispatch_counter += gen_count; _host_empty = false; @@ -424,11 +425,11 @@ class WavefrontCoroScheduler : public CoroScheduler { BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; BufferView _key[2] = {_temp_key[0].view(_host_offset[i], _host_count[i]), _temp_key[1].view(_host_offset[i], _host_count[i])}; uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); - stream << _invoke(_resume_shaders[i], _index[out], _resume_count, _config.max_instance_count) + stream << _invoke(_resume_shaders[i], _index[out], _resume_count, _config.thread_count) .dispatch(_host_count[i]); } else { stream << _invoke(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _config.max_instance_count) + _resume_count, _config.thread_count) .dispatch(_host_count[i]); } } diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront_v2.cpp index 6e710219a..739fc6d92 100644 --- a/src/tests/coro/path_tracing_wavefront_v2.cpp +++ b/src/tests/coro/path_tracing_wavefront_v2.cpp @@ -305,7 +305,7 @@ int main(int argc, char *argv[]) { }; coroutine::WavefrontCoroSchedulerConfig config{ - .max_instance_count = 16777216, + .thread_count = 16_M, .soa = true, }; coroutine::WavefrontCoroScheduler scheduler{device, coro, config}; diff --git a/src/tests/coro/sdf_renderer_v2.cpp b/src/tests/coro/sdf_renderer_v2.cpp index 0332ebe6a..cc8aba1f2 100644 --- a/src/tests/coro/sdf_renderer_v2.cpp +++ b/src/tests/coro/sdf_renderer_v2.cpp @@ -184,7 +184,7 @@ int main(int argc, char *argv[]) { .block_size = 64u, .fetch_size = 3u, .shared_memory_soa = false, - .global_ext_memory = false}; + .global_memory_ext = false}; coroutine::PersistentThreadsCoroScheduler scheduler{device, coro, config}; auto clear_shader = device.compile<2>([&] { From d0d5eb4647c2f1a7a5d8849d706dcc625d66a6bc Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 22:06:36 +0800 Subject: [PATCH 51/67] refactor --- .../coro/v2/schedulers/persistent_threads.h | 204 ++--------------- src/coro/schedulers/persistent_threads.cpp | 207 ++++++++++++++++++ 2 files changed, 225 insertions(+), 186 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/v2/schedulers/persistent_threads.h index 4fef7fa55..270a4b27b 100644 --- a/include/luisa/coro/v2/schedulers/persistent_threads.h +++ b/include/luisa/coro/v2/schedulers/persistent_threads.h @@ -4,6 +4,8 @@ #pragma once +#include +#include #include namespace luisa::compute::coroutine { @@ -16,6 +18,15 @@ struct PersistentThreadsCoroSchedulerConfig { bool global_memory_ext = false; }; +namespace detail { +LC_CORO_API void create_persistent_threads_scheduler_main_kernel( + const PersistentThreadsCoroSchedulerConfig &config, + uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, + const CoroGraph *graph, Shared &frames, Expr dispatch_shape, + Expr> global, const Buffer &global_frames, + luisa::move_only_function call_subroutine) noexcept; +}// namespace detail + template class PersistentThreadsCoroScheduler : public CoroScheduler { @@ -36,198 +47,19 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { _global = device.create_buffer(1); auto q_fac = 1u; auto g_fac = coro.subroutine_count() - q_fac; - auto global_queue_size = _config.block_size * g_fac; if (_config.global_memory_ext) { auto global_ext_size = _config.thread_count * g_fac; _global_frames = device.create_buffer(coro.shared_frame(), global_ext_size); } - Kernel1D main_kernel = [&](BufferUInt global, UInt3 dispatch_shape, Var... args) noexcept { + Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_shape, Var... args) noexcept { set_block_size(_config.block_size, 1u, 1u); + auto global_queue_size = _config.block_size * g_fac; auto shared_queue_size = _config.block_size * q_fac; - Shared frames{coro.shared_frame(), shared_queue_size, _config.shared_memory_soa}; - Shared path_id{shared_queue_size}; - Shared work_counter{coro.subroutine_count()}; - Shared work_offset{2u}; - Shared all_token{_config.global_memory_ext ? - shared_queue_size + global_queue_size : - shared_queue_size}; - Shared workload{2}; - Shared work_stat{2};// work_state[0] for max_count, [1] for max_id - for (auto index : dsl::dynamic_range(q_fac)) { - auto s = index * _config.block_size + thread_x(); - all_token[s] = 0u; - // frames.write(s, coro.instantiate(), std::array{0u, 1u}); - } - for (auto index : dsl::dynamic_range(g_fac)) { - auto s = index * _config.block_size + thread_x(); - all_token[shared_queue_size + s] = 0u; - } - if_(thread_x() < coro.subroutine_count(), [&] { - if_(thread_x() == 0u, [&] { - work_counter[thread_x()] = - _config.global_memory_ext ? - shared_queue_size + global_queue_size : - shared_queue_size; - }).else_([&] { - work_counter[thread_x()] = 0u; - }); - }); - workload[0] = 0u; - workload[1] = 0u; - Shared rem_global{1}; - Shared rem_local{1}; - rem_global[0] = 1u; - rem_local[0] = 0u; - sync_block(); - auto count = def(0u); - auto count_limit = def(-1); - auto dispatch_size = dispatch_shape.x * dispatch_shape.y * dispatch_shape.z; - loop([&] { - if_(!((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)), [&] { break_(); }); - sync_block();//very important, synchronize for condition - rem_local[0] = 0u; - count += 1; - work_stat[0] = 0; - work_stat[1] = -1; - sync_block(); - if_(thread_x() == _config.block_size - 1, [&] { - if_(workload[0] >= workload[1] & rem_global[0] == 1u, [&] {//fetch new workload - workload[0] = global.atomic(0u).fetch_add(_config.block_size * _config.fetch_size); - workload[1] = min(workload[0] + _config.block_size * _config.fetch_size, dispatch_size); - if_(workload[0] >= dispatch_size, [&] { - rem_global[0] = 0u; - }); - }); - }); - sync_block(); - if_(thread_x() < coro.subroutine_count(), [&] {//get max - if_(workload[0] < workload[1] | thread_x() != 0u, [&] { - if_(work_counter[thread_x()] != 0, [&] { - rem_local[0] = 1u; - work_stat.atomic(0).fetch_max(work_counter[thread_x()]); - }); - }); - }); - sync_block(); - if_(thread_x() < coro.subroutine_count(), [&] {//get argmax - if_(work_stat[0] == work_counter[thread_x()] & (workload[0] < workload[1] | thread_x() != 0u), [&] { - work_stat[1] = thread_x(); - }); - }); - sync_block(); - work_offset[0] = 0; - work_offset[1] = 0; - sync_block(); - if (!_config.global_memory_ext) { - for (auto index : dsl::dynamic_range(q_fac)) {//collect indices - auto frame_token = all_token[index * _config.block_size + thread_x()]; - if_(frame_token == work_stat[1], [&] { - auto id = work_offset.atomic(0).fetch_add(1u); - path_id[id] = index * _config.block_size + thread_x(); - }); - } - } else { - for (auto index : dsl::dynamic_range(q_fac)) {//collect switch out indices - auto frame_token = all_token[index * _config.block_size + thread_x()]; - if_(frame_token != work_stat[1], [&] { - auto id = work_offset.atomic(0).fetch_add(1u); - path_id[id] = index * _config.block_size + thread_x(); - }); - } - sync_block(); - if_(shared_queue_size - work_offset[0] < _config.block_size, [&] {//no enough work - for (auto index : dsl::dynamic_range(g_fac)) { //swap frames - auto global_id = block_x() * global_queue_size + index * _config.block_size + thread_x(); - auto g_queue_id = index * _config.block_size + thread_x(); - auto coro_token = all_token[shared_queue_size + g_queue_id]; - if_(coro_token == work_stat[1], [&] { - auto id = work_offset.atomic(1).fetch_add(1u); - if_(id < work_offset[0], [&] { - auto dst = path_id[id]; - auto frame_token = all_token[dst]; - if_(coro_token != 0u, [&] { - $if (frame_token != 0u) { - auto g_state = _global_frames->read(global_id); - _global_frames->write(global_id, frames.read(dst)); - frames.write(dst, g_state); - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - } - $else { - auto g_state = _global_frames->read(global_id); - frames.write(dst, g_state); - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - }; - }).else_([&] { - $if (frame_token != 0u) { - auto frame = frames.read(dst); - _global_frames->write(global_id, frame); - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - }; - }); - }); - }); - } - }); - } - auto gen_st = workload[0]; - sync_block(); - auto pid = def(0u); - if (_config.global_memory_ext) { - pid = thread_x(); - } else { - pid = path_id[thread_x()]; - } - auto launch_condition = def(true); - if (!_config.global_memory_ext) { - launch_condition = (thread_x() < work_offset[0]); - } else { - launch_condition = (all_token[pid] == work_stat[1]); - } - if_(launch_condition, [&] { - auto switch_stmt = switch_(all_token[pid]); - std::move(switch_stmt).case_(0u, [&] { - if_(gen_st + thread_x() < workload[1], [&] { - work_counter.atomic(0u).fetch_sub(1u); - auto global_index = gen_st + thread_x(); - auto image_size = dispatch_shape.x * dispatch_shape.y; - auto index_z = global_index / image_size; - auto index_xy = global_index % image_size; - auto index_x = index_xy % dispatch_shape.x; - auto index_y = index_xy / dispatch_shape.x; - auto frame = coro.instantiate(make_uint3(index_x, index_y, index_z)); - coro.entry()(frame, args...); - auto next = frame.target_token & token_mask; - frames.write(pid, frame, coro.graph()->entry().output_fields()); - all_token[pid] = next; - work_counter.atomic(next).fetch_add(1u); - workload.atomic(0).fetch_add(1u); - }); - }); - for (auto i = 1u; i < coro.subroutine_count(); i++) { - std::move(switch_stmt).case_(i, [&] { - work_counter.atomic(i).fetch_sub(1u); - auto frame = frames.read(pid, coro.graph()->node(i).input_fields()); - coro[i](frame, args...); - auto next = frame.target_token & token_mask; - frames.write(pid, frame, coro.graph()->node(i).output_fields()); - all_token[pid] = next; - work_counter.atomic(next).fetch_add(1u); - }); - } - }); - sync_block(); - }); -#ifndef NDEBUG - if_(count >= count_limit, [&] { - device_log("block_id{},thread_id {}, loop not break! local:{}, global:{}", block_x(), thread_x(), rem_local[0], rem_global[0]); - if_(thread_x() < coro.subroutine_count(), [&] { - device_log("work rem: id {}, size {}", thread_x(), work_counter[thread_x()]); - }); - }); -#endif + auto call_subroutine = [&](CoroFrame &frame, CoroToken token) noexcept { coro[token](frame, args...); }; + Shared frames{graph->shared_frame(), shared_queue_size, _config.shared_memory_soa}; + detail::create_persistent_threads_scheduler_main_kernel( + _config, q_fac, g_fac, shared_queue_size, global_queue_size, graph, + frames, dispatch_shape, global, _global_frames, call_subroutine); }; _pt_shader = device.compile(main_kernel); _clear_shader = device.compile<1>([](BufferUInt global) { diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index cc81a9278..fd800f363 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -1,3 +1,210 @@ // // Created by Mike on 2024/5/10. // + +#include +#include +#include +#include +#include + +namespace luisa::compute::coroutine::detail { + +void create_persistent_threads_scheduler_main_kernel( + const PersistentThreadsCoroSchedulerConfig &config, + uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, + const CoroGraph *graph, Shared &frames, Expr dispatch_shape, + Expr> global, const Buffer &global_frames, + luisa::move_only_function call_subroutine) noexcept { + + auto subroutine_count = static_cast(graph->nodes().size()); + Shared path_id{shared_queue_size}; + Shared work_counter{subroutine_count}; + Shared work_offset{2u}; + Shared all_token{config.global_memory_ext ? + shared_queue_size + global_queue_size : + shared_queue_size}; + Shared workload{2}; + Shared work_stat{2};// work_state[0] for max_count, [1] for max_id + for (auto index : dsl::dynamic_range(q_fac)) { + auto s = index * config.block_size + thread_x(); + all_token[s] = 0u; + // frames.write(s, coro.instantiate(), std::array{0u, 1u}); + } + for (auto index : dsl::dynamic_range(g_fac)) { + auto s = index * config.block_size + thread_x(); + all_token[shared_queue_size + s] = 0u; + } + $if (thread_x() < subroutine_count) { + $if (thread_x() == 0u) { + work_counter[thread_x()] = + config.global_memory_ext ? + shared_queue_size + global_queue_size : + shared_queue_size; + } + $else { + work_counter[thread_x()] = 0u; + }; + }; + workload[0] = 0u; + workload[1] = 0u; + Shared rem_global{1}; + Shared rem_local{1}; + rem_global[0] = 1u; + rem_local[0] = 0u; + sync_block(); + auto count = def(0u); + auto count_limit = def(-1); + auto dispatch_size = dispatch_shape.x * dispatch_shape.y * dispatch_shape.z; + $while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) { + sync_block();//very important, synchronize for condition + rem_local[0] = 0u; + count += 1; + work_stat[0] = 0; + work_stat[1] = -1; + sync_block(); + $if (thread_x() == config.block_size - 1) { + $if (workload[0] >= workload[1] & rem_global[0] == 1u) {//fetch new workload + workload[0] = global.atomic(0u).fetch_add(config.block_size * config.fetch_size); + workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size); + $if (workload[0] >= dispatch_size) { + rem_global[0] = 0u; + }; + }; + }; + sync_block(); + $if (thread_x() < subroutine_count) {//get max + $if (workload[0] < workload[1] | thread_x() != 0u) { + $if (work_counter[thread_x()] != 0) { + rem_local[0] = 1u; + work_stat.atomic(0).fetch_max(work_counter[thread_x()]); + }; + }; + }; + sync_block(); + $if (thread_x() < subroutine_count) {//get argmax + $if (work_stat[0] == work_counter[thread_x()] & (workload[0] < workload[1] | thread_x() != 0u)) { + work_stat[1] = thread_x(); + }; + }; + sync_block(); + work_offset[0] = 0; + work_offset[1] = 0; + sync_block(); + if (!config.global_memory_ext) { + for (auto index : dsl::dynamic_range(q_fac)) {//collect indices + auto frame_token = all_token[index * config.block_size + thread_x()]; + $if (frame_token == work_stat[1]) { + auto id = work_offset.atomic(0).fetch_add(1u); + path_id[id] = index * config.block_size + thread_x(); + }; + } + } else { + for (auto index : dsl::dynamic_range(q_fac)) {//collect switch out indices + auto frame_token = all_token[index * config.block_size + thread_x()]; + $if (frame_token != work_stat[1]) { + auto id = work_offset.atomic(0).fetch_add(1u); + path_id[id] = index * config.block_size + thread_x(); + }; + } + sync_block(); + $if (shared_queue_size - work_offset[0] < config.block_size) {//no enough work + for (auto index : dsl::dynamic_range(g_fac)) { //swap frames + auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x(); + auto g_queue_id = index * config.block_size + thread_x(); + auto coro_token = all_token[shared_queue_size + g_queue_id]; + $if (coro_token == work_stat[1]) { + auto id = work_offset.atomic(1).fetch_add(1u); + $if (id < work_offset[0]) { + auto dst = path_id[id]; + auto frame_token = all_token[dst]; + $if (coro_token != 0u) { + $if (frame_token != 0u) { + auto g_state = global_frames->read(global_id); + global_frames->write(global_id, frames.read(dst)); + frames.write(dst, g_state); + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; + } + $else { + auto g_state = global_frames->read(global_id); + frames.write(dst, g_state); + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; + }; + } + $else { + $if (frame_token != 0u) { + auto frame = frames.read(dst); + global_frames->write(global_id, frame); + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; + }; + }; + }; + }; + } + }; + } + auto gen_st = workload[0]; + sync_block(); + auto pid = def(0u); + if (config.global_memory_ext) { + pid = thread_x(); + } else { + pid = path_id[thread_x()]; + } + auto launch_condition = def(true); + if (!config.global_memory_ext) { + launch_condition = (thread_x() < work_offset[0]); + } else { + launch_condition = (all_token[pid] == work_stat[1]); + } + $if (launch_condition) { + constexpr auto valid_token_mask = coro_token_terminal - 1u; + static_assert(valid_token_mask == 0x7fff'ffffu); + $switch (all_token[pid]) { + $case (0u) { + $if (gen_st + thread_x() < workload[1]) { + work_counter.atomic(0u).fetch_sub(1u); + auto global_index = gen_st + thread_x(); + auto image_size = dispatch_shape.x * dispatch_shape.y; + auto index_z = global_index / image_size; + auto index_xy = global_index % image_size; + auto index_x = index_xy % dispatch_shape.x; + auto index_y = index_xy / dispatch_shape.x; + auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); + call_subroutine(frame, coro_token_entry); + auto next = frame.target_token & valid_token_mask; + frames.write(pid, frame, graph->entry().output_fields()); + all_token[pid] = next; + work_counter.atomic(next).fetch_add(1u); + workload.atomic(0).fetch_add(1u); + }; + }; + for (auto i = 1u; i < subroutine_count; i++) { + $case (i) { + work_counter.atomic(i).fetch_sub(1u); + auto frame = frames.read(pid, graph->node(i).input_fields()); + call_subroutine(frame, i); + auto next = frame.target_token & valid_token_mask; + frames.write(pid, frame, graph->node(i).output_fields()); + all_token[pid] = next; + work_counter.atomic(next).fetch_add(1u); + }; + } + }; + }; + sync_block(); + }; +#ifndef NDEBUG + $if (count >= count_limit) { + device_log("block_id{},thread_id {}, loop not break! local:{}, global:{}", block_x(), thread_x(), rem_local[0], rem_global[0]); + $if (thread_x() < subroutine_count) { + device_log("work rem: id {}, size {}", thread_x(), work_counter[thread_x()]); + }; + }; +#endif +} + +}// namespace luisa::compute::coroutine::detail From 9d1a62ec00be0de209825e7db0d1e6e758286c18 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 22:07:25 +0800 Subject: [PATCH 52/67] minor --- include/luisa/coro/v2/schedulers/persistent_threads.h | 4 ++-- src/coro/schedulers/persistent_threads.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/v2/schedulers/persistent_threads.h index 270a4b27b..f806d537c 100644 --- a/include/luisa/coro/v2/schedulers/persistent_threads.h +++ b/include/luisa/coro/v2/schedulers/persistent_threads.h @@ -19,7 +19,7 @@ struct PersistentThreadsCoroSchedulerConfig { }; namespace detail { -LC_CORO_API void create_persistent_threads_scheduler_main_kernel( +LC_CORO_API void persistent_threads_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, const CoroGraph *graph, Shared &frames, Expr dispatch_shape, @@ -57,7 +57,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto shared_queue_size = _config.block_size * q_fac; auto call_subroutine = [&](CoroFrame &frame, CoroToken token) noexcept { coro[token](frame, args...); }; Shared frames{graph->shared_frame(), shared_queue_size, _config.shared_memory_soa}; - detail::create_persistent_threads_scheduler_main_kernel( + detail::persistent_threads_scheduler_main_kernel_impl( _config, q_fac, g_fac, shared_queue_size, global_queue_size, graph, frames, dispatch_shape, global, _global_frames, call_subroutine); }; diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index fd800f363..f5f848200 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -10,7 +10,7 @@ namespace luisa::compute::coroutine::detail { -void create_persistent_threads_scheduler_main_kernel( +void persistent_threads_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, const CoroGraph *graph, Shared &frames, Expr dispatch_shape, From 48413d0172e159248bce1c692e72d908b8c883e6 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 23:09:36 +0800 Subject: [PATCH 53/67] minor --- include/luisa/coro/v2/coro_scheduler.h | 2 +- include/luisa/coro/v2/schedulers/persistent_threads.h | 4 ++-- src/coro/schedulers/persistent_threads.cpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/luisa/coro/v2/coro_scheduler.h b/include/luisa/coro/v2/coro_scheduler.h index e6e9ccde5..82b9444d0 100644 --- a/include/luisa/coro/v2/coro_scheduler.h +++ b/include/luisa/coro/v2/coro_scheduler.h @@ -41,7 +41,7 @@ class CoroSchedulerInvoke : public concepts::Noncopyable { private: friend Scheduler; - CoroSchedulerInvoke(Scheduler *scheduler, compute::detail::prototype_to_shader_invocation_t... args) noexcept + explicit CoroSchedulerInvoke(Scheduler *scheduler, compute::detail::prototype_to_shader_invocation_t... args) noexcept : _scheduler{scheduler}, _args{args...} {} public: diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/v2/schedulers/persistent_threads.h index f806d537c..e233187e5 100644 --- a/include/luisa/coro/v2/schedulers/persistent_threads.h +++ b/include/luisa/coro/v2/schedulers/persistent_threads.h @@ -19,7 +19,7 @@ struct PersistentThreadsCoroSchedulerConfig { }; namespace detail { -LC_CORO_API void persistent_threads_scheduler_main_kernel_impl( +LC_CORO_API void persistent_threads_coro_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, const CoroGraph *graph, Shared &frames, Expr dispatch_shape, @@ -57,7 +57,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto shared_queue_size = _config.block_size * q_fac; auto call_subroutine = [&](CoroFrame &frame, CoroToken token) noexcept { coro[token](frame, args...); }; Shared frames{graph->shared_frame(), shared_queue_size, _config.shared_memory_soa}; - detail::persistent_threads_scheduler_main_kernel_impl( + detail::persistent_threads_coro_scheduler_main_kernel_impl( _config, q_fac, g_fac, shared_queue_size, global_queue_size, graph, frames, dispatch_shape, global, _global_frames, call_subroutine); }; diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index f5f848200..624e562f4 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -10,7 +10,7 @@ namespace luisa::compute::coroutine::detail { -void persistent_threads_scheduler_main_kernel_impl( +void persistent_threads_coro_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, const CoroGraph *graph, Shared &frames, Expr dispatch_shape, From 1e724c702c7872659453def0f66e783a71fd23a9 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 14 May 2024 23:20:34 +0800 Subject: [PATCH 54/67] add support for 3d dispatch --- include/luisa/coro/v2/schedulers/wavefront.h | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/v2/schedulers/wavefront.h index 70d96c672..9a999a264 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/v2/schedulers/wavefront.h @@ -34,7 +34,7 @@ class WavefrontCoroScheduler : public CoroScheduler { luisa::optional _args; SOA _frame_soa; Buffer _frame_buffer; - Shader1D, Buffer, uint, uint, uint, Args...> _gen_shader; + Shader1D, Buffer, uint3, uint, uint, uint, Args...> _gen_shader; luisa::vector, Buffer, uint, Args...>> _resume_shaders; Shader1D, Buffer, uint> _count_prefix_shader; Shader1D, Buffer, uint> _gather_shader; @@ -48,6 +48,7 @@ class WavefrontCoroScheduler : public CoroScheduler { Buffer _global_buffer; luisa::vector _host_count; luisa::vector _host_offset; + uint3 _dispatch_shape; bool _host_empty; uint _dispatch_counter; uint _max_sub_coro; @@ -63,6 +64,7 @@ class WavefrontCoroScheduler : public CoroScheduler { void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { _args.emplace(std::forward>(args)...); + _dispatch_shape = dispatch_size; _dispatch_size = dispatch_size.x * dispatch_size.y * dispatch_size.z;// TODO _dispatch_counter = 0; _host_empty = true; @@ -186,7 +188,7 @@ class WavefrontCoroScheduler : public CoroScheduler { &get_coro_hint, 0, 128, 0, highbit); } } - Kernel1D gen_kernel = [&](BufferUInt index, BufferUInt count, UInt offset, UInt st_task_id, UInt n, Var... args) { + Kernel1D gen_kernel = [&](BufferUInt index, BufferUInt count, UInt3 dispatch_shape, UInt offset, UInt st_task_id, UInt n, Var... args) { auto x = dispatch_x(); $if (x >= n) { $return(); @@ -200,8 +202,13 @@ class WavefrontCoroScheduler : public CoroScheduler { if (!_config.sort) { count.atomic(0u).fetch_add(-1u); } - - CoroFrame frame = coroutine.instantiate(make_uint3(st_task_id + x, 0u, 0u)); + auto global_id = st_task_id + x; + auto image_size = dispatch_shape.x * dispatch_shape.y; + auto global_id_z = global_id / image_size; + auto global_id_xy = global_id % image_size; + auto global_id_x = global_id_xy % dispatch_shape.x; + auto global_id_y = global_id_xy / dispatch_shape.x; + CoroFrame frame = coroutine.instantiate(make_uint3(global_id_x, global_id_y, global_id_z)); coroutine.entry()(frame, args...); if (_config.soa) { _frame_soa->write(frame_id, frame, coroutine.graph()->node(0u).output_fields()); @@ -378,7 +385,7 @@ class WavefrontCoroScheduler : public CoroScheduler { stream << _compact_shader(_resume_index, _config.thread_count - _host_count[0], _config.thread_count).dispatch(_host_count[0]); } stream << _invoke(_gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), - _resume_count, _config.thread_count - _host_count[0], _dispatch_counter, + _resume_count, _dispatch_shape, _config.thread_count - _host_count[0], _dispatch_counter, _config.thread_count) .dispatch(gen_count); _dispatch_counter += gen_count; @@ -414,7 +421,7 @@ class WavefrontCoroScheduler : public CoroScheduler { .dispatch(_host_count[0]); } stream << _invoke(_gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), - _resume_count, _config.thread_count - _host_count[0], _dispatch_counter, _config.thread_count) + _resume_count, _dispatch_shape, _config.thread_count - _host_count[0], _dispatch_counter, _config.thread_count) .dispatch(gen_count); _dispatch_counter += gen_count; _host_empty = false; From 28ec6e2eaa4e9a0ba019e3eef428faf0c60e0991 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 00:25:36 +0800 Subject: [PATCH 55/67] wip: cleanup outdated coro api --- include/luisa/ast/function_builder.h | 6 - include/luisa/ast/type.h | 19 - include/luisa/ast/type_registry.h | 19 - include/luisa/coro/coro_dispatcher.h | 956 ------------------ include/luisa/coro/{v2 => }/coro_frame.h | 2 +- .../luisa/coro/{v2 => }/coro_frame_buffer.h | 2 +- include/luisa/coro/{v2 => }/coro_frame_desc.h | 0 include/luisa/coro/{v2 => }/coro_frame_smem.h | 4 +- include/luisa/coro/{v2 => }/coro_frame_soa.h | 6 +- include/luisa/coro/{v2 => }/coro_func.h | 4 +- include/luisa/coro/coro_graph.h | 95 +- include/luisa/coro/coro_node.h | 35 - include/luisa/coro/{v2 => }/coro_scheduler.h | 0 include/luisa/coro/{v2 => }/coro_token.h | 1 + include/luisa/coro/coro_transition.h | 12 - include/luisa/coro/radix_sort.h | 2 +- .../{v2 => }/schedulers/persistent_threads.h | 3 +- .../coro/{v2 => }/schedulers/state_machine.h | 8 +- .../coro/{v2 => }/schedulers/wavefront.h | 24 +- include/luisa/coro/shader_scheduler.h | 0 include/luisa/coro/v2/coro_graph.h | 72 -- include/luisa/dsl/builtin.h | 18 - include/luisa/dsl/func.h | 117 +-- include/luisa/dsl/struct.h | 364 ------- include/luisa/luisa-compute.h | 26 +- include/luisa/runtime/byte_buffer.h | 5 +- include/luisa/runtime/shader.h | 9 - src/ast/ast2json.cpp | 3 - src/ast/function_builder.cpp | 23 - src/ast/type.cpp | 95 +- .../common/hlsl/hlsl_codegen_util.cpp | 4 - src/backends/common/shader_print_formatter.h | 3 +- src/backends/cuda/cuda_codegen_ast.cpp | 6 - src/backends/metal/metal_codegen_ast.cpp | 2 - src/coro/coro_frame.cpp | 5 +- src/coro/coro_frame_buffer.cpp | 2 +- src/coro/coro_frame_desc.cpp | 2 +- src/coro/coro_func.cpp | 2 +- src/coro/coro_graph.cpp | 4 +- src/coro/schedulers/persistent_threads.cpp | 14 +- src/coro/schedulers/state_machine.cpp | 2 +- src/dsl/func.cpp | 112 -- src/tests/CMakeLists.txt | 12 +- src/tests/test_skyline.cpp | 9 +- src/tests/test_sort.cpp | 1 - 45 files changed, 130 insertions(+), 1980 deletions(-) delete mode 100644 include/luisa/coro/coro_dispatcher.h rename include/luisa/coro/{v2 => }/coro_frame.h (97%) rename include/luisa/coro/{v2 => }/coro_frame_buffer.h (99%) rename include/luisa/coro/{v2 => }/coro_frame_desc.h (100%) rename include/luisa/coro/{v2 => }/coro_frame_smem.h (98%) rename include/luisa/coro/{v2 => }/coro_frame_soa.h (99%) rename include/luisa/coro/{v2 => }/coro_func.h (99%) delete mode 100644 include/luisa/coro/coro_node.h rename include/luisa/coro/{v2 => }/coro_scheduler.h (100%) rename include/luisa/coro/{v2 => }/coro_token.h (81%) delete mode 100644 include/luisa/coro/coro_transition.h rename include/luisa/coro/{v2 => }/schedulers/persistent_threads.h (98%) rename include/luisa/coro/{v2 => }/schedulers/state_machine.h (94%) rename include/luisa/coro/{v2 => }/schedulers/wavefront.h (97%) delete mode 100644 include/luisa/coro/shader_scheduler.h delete mode 100644 include/luisa/coro/v2/coro_graph.h diff --git a/include/luisa/ast/function_builder.h b/include/luisa/ast/function_builder.h index a1e89e1b8..a0fd25297 100644 --- a/include/luisa/ast/function_builder.h +++ b/include/luisa/ast/function_builder.h @@ -242,8 +242,6 @@ class LC_AST_API FunctionBuilder : public luisa::enable_shared_from_this &coro_tokens() const noexcept { return _coro_tokens; } - void coroframe_replace(const Type *type) noexcept; - // build primitives /// Define a kernel function with given definition template static auto define_kernel(Def &&def) { @@ -407,12 +405,8 @@ class LC_AST_API FunctionBuilder : public luisa::enable_shared_from_this::value; template struct is_custom_struct : std::false_type {}; -template -struct is_coroframe_struct : std::false_type {}; - template constexpr auto is_custom_struct_v = is_custom_struct::value; -template -constexpr auto is_coroframe_struct_v = is_coroframe_struct::value; - namespace detail { template @@ -379,9 +373,6 @@ class LC_AST_API Type { /// Return custom type with the specified name [[nodiscard]] static const Type *custom(luisa::string_view name) noexcept; - /// Return custom type for coroframe - [[nodiscard]] static const Type *coroframe(luisa::string_view name) noexcept; - /// Construct Type object from description /// @param description Type description in the following syntax: \n /// TYPE := DATA | RESOURCE | CUSTOM \n @@ -427,14 +418,6 @@ class LC_AST_API Type { [[nodiscard]] luisa::span members() const noexcept; [[nodiscard]] luisa::span member_attributes() const noexcept; [[nodiscard]] const Type *element() const noexcept; - [[nodiscard]] const Type *corotype() const noexcept; - ///change the corresponding Type to another Type in registry - void update_from(const Type *type); - /// add a member with name and type - size_t add_member(const luisa::string &name) noexcept; - void set_member_name(size_t index, luisa::string name) noexcept; - /// get member index with string - [[nodiscard]] size_t member(luisa::string_view name) const noexcept; /// Scalar = bool || float || int || uint [[nodiscard]] bool is_scalar() const noexcept; [[nodiscard]] bool is_bool() const noexcept; @@ -470,8 +453,6 @@ class LC_AST_API Type { [[nodiscard]] bool is_accel() const noexcept; [[nodiscard]] bool is_resource() const noexcept; [[nodiscard]] bool is_custom() const noexcept; - [[nodiscard]] bool is_coroframe() const noexcept; - [[nodiscard]] bool is_materialized_coroframe() const noexcept; }; }// namespace luisa::compute diff --git a/include/luisa/ast/type_registry.h b/include/luisa/ast/type_registry.h index 3d5fa6cb3..b5f7976b8 100644 --- a/include/luisa/ast/type_registry.h +++ b/include/luisa/ast/type_registry.h @@ -256,9 +256,6 @@ const Type *Type::of() noexcept { if constexpr (is_custom_struct_v) { static thread_local auto t = Type::custom(desc); return t; - } else if constexpr (is_coroframe_struct_v) { - static thread_local auto t = Type::coroframe(desc); - return t; } else { static thread_local auto t = Type::from(desc); return t; @@ -275,8 +272,6 @@ template struct is_valid_reflection, std::integer_sequence> { static_assert(alignof(S) >= 4u, "Structs must be aligned to at least 4 bytes."); - static_assert(std::negation_v...>>, - "Structs cannot contain CoroFrame Type"); private: [[nodiscard]] constexpr static auto _check() noexcept { @@ -369,17 +364,3 @@ constexpr auto is_valid_reflection_v = is_valid_reflection::value; return name; \ } \ }; - -#define LUISA_COROFRAME_STRUCT_REFLECT(S, name) \ - template<> \ - struct luisa::compute::canonical_layout { \ - using type = std::tuple; \ - }; \ - template<> \ - struct luisa::compute::is_coroframe_struct : std::true_type {}; \ - template<> \ - struct luisa::compute::detail::TypeDesc { \ - static constexpr luisa::string_view description() noexcept { \ - return name; \ - } \ - }; diff --git a/include/luisa/coro/coro_dispatcher.h b/include/luisa/coro/coro_dispatcher.h deleted file mode 100644 index 19abb23b2..000000000 --- a/include/luisa/coro/coro_dispatcher.h +++ /dev/null @@ -1,956 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -namespace luisa::compute { - -class Stream; -inline namespace coro { -template -struct prototype_to_coro_dispatcher { - using type = T; -}; - -template -struct prototype_to_coro_dispatcher> { - using type = BufferView; -}; - -template -struct prototype_to_coro_dispatcher> { - using type = ImageView; -}; - -template -struct prototype_to_coro_dispatcher> { - using type = VolumeView; -}; - -template -struct prototype_to_coro_dispatcher> { - using type = SOAView; -}; - -template -using prototype_to_coro_dispatcher_t = typename prototype_to_coro_dispatcher::type; - -template -class CoroAwait; -template -class CoroDispatcherBase { - static_assert(always_false_v); -}; -const uint token_mask = 0x7fffffff; -template -class CoroDispatcherBase : public concepts::Noncopyable { - using FuncType = void(FrameRef, Args...); - friend class CoroAwait; -public: - - luisa::queue> _dispatcher; -protected: - std::tuple...> _args; - Coroutine *_coro; - virtual void _await_step(Stream &stream) noexcept = 0; - virtual void _await_all(Stream &stream) noexcept; - uint _dispatch_size; - Device _device; - ///helper function for calling a shader by omitting suffix coroutine args - template - detail::ShaderInvoke call_shader(Shader &shader, prototype_to_coro_dispatcher_t... prefix_args); - -protected: -public: - CoroDispatcherBase(Coroutine *coro_ptr, Device &device) noexcept - : _coro{std::move(coro_ptr)}, _device{device}, _dispatch_size{} {} - virtual ~CoroDispatcherBase() noexcept = default; - [[nodiscard]] virtual bool all_dispatched() const noexcept = 0; - [[nodiscard]] virtual bool all_done() const noexcept = 0; - [[nodiscard]] CoroAwait await_step() noexcept; - [[nodiscard]] CoroAwait await_all() noexcept; - virtual void operator()(prototype_to_coro_dispatcher_t... args, uint dispatch_size) noexcept {//initialize the dispatcher - _args = std::make_tuple(std::forward>(args)...); - _dispatch_size = dispatch_size; - } -}; - -template -class SimpleCoroDispatcher : public CoroDispatcherBase { -private: - using FrameType = std::remove_reference_t; - Shader1D _shader; - bool _done; - void _await_step(Stream &stream) noexcept override { - stream << this->template call_shader<1, uint>(_shader, this->_dispatch_size).dispatch(this->_dispatch_size); - _done = true; - stream << synchronize(); - } -public: - SimpleCoroDispatcher(Coroutine *coroutine, Device &device, - uint max_frame_count) : CoroDispatcherBase{coroutine, device}, - _done{false} { - uint max_sub_coro = coroutine->suspend_count() + 1; - Kernel1D run_kernel = [&](UInt n, Var... args) { - set_block_size(128, 1, 1); - auto x = dispatch_x(); - $if (x < n) { - Var frame; - initialize_coroframe(frame, def(x, 0, 0)); - (*coroutine)(frame, args...); - $loop { - auto token = read_promise(frame, "coro_token"); - $if (token == 0x8000'0000) { - $break; - }; - $switch (token) { - for (auto i = 1u; i < max_sub_coro; ++i) { - $case (i) { - (*coroutine)[i](frame, args...); - }; - } - }; - }; - }; - }; - _shader = device.compile(run_kernel); - } - void operator()(prototype_to_coro_dispatcher_t... args, uint dispatch_size) noexcept { - CoroDispatcherBase::operator()(args..., dispatch_size); - _done = false; - } - bool all_dispatched() const noexcept override { - return _done; - } - bool all_done() const noexcept override { - return _done; - } -}; - -struct WavefrontCoroDispatcherConfig { - uint max_instance_count = 2_M; - bool soa = true; - bool sort = true;//use sort for coro token gathering - bool compact = true; - bool debug = false; - uint hint_range = 0xffff'ffff; - luisa::vector hint_fields; -}; - -template -class WavefrontCoroDispatcher : public CoroDispatcherBase { - -public: - using Config = WavefrontCoroDispatcherConfig; - -private: - using FrameType = std::remove_reference_t; - using Container = compute::SOA; - bool is_soa = true; - bool sort_base_gather = false; - bool use_compact = true; - Shader1D, Buffer, uint, Container, uint, uint, Args...> _gen_shader; - luisa::vector, Buffer, Container, uint, Args...>> _resume_shaders; - Shader1D, Buffer, uint> _count_prefix_shader; - Shader1D, Buffer, Container, uint> _gather_shader; - Shader1D, Container, uint> _initialize_shader; - Shader1D, Container, uint, uint> _compact_shader; - Shader1D, uint> _clear_shader; - Container _frame; - compute::Buffer _resume_index; - compute::Buffer _resume_count; - ///offset calculate from count, will be end after gathering - compute::Buffer _resume_offset; - compute::Buffer _global_buffer; - compute::Buffer _debug_buffer; - luisa::vector _host_count; - luisa::vector _host_offset; - bool _host_empty; - uint _dispatch_counter; - uint _max_sub_coro; - uint _max_frame_count; - bool _debug; - void _await_step(Stream &stream) noexcept override; - radix_sort::temp_storage _sort_temp_storage; - radix_sort::instance<> _sort_token; - radix_sort::instance> _sort_hint; - luisa::vector _have_hint; - compute::Buffer _temp_key[2]; - compute::Buffer _temp_index; - Stream &_stream; -public: - [[nodiscard]] bool all_dispatched() const noexcept; - [[nodiscard]] bool all_done() const noexcept; - - WavefrontCoroDispatcher(Coroutine *coroutine, - Device &device, Stream &stream, - const WavefrontCoroDispatcherConfig &config) noexcept - : CoroDispatcherBase{coroutine, device}, - is_soa{config.soa}, - sort_base_gather{config.sort}, - use_compact{config.compact}, - _max_frame_count{config.max_instance_count}, - _stream{stream}, _debug{config.debug}, - _frame{device.create_soa(config.max_instance_count)} { - /*if (device.backend_name() != "cuda") {//only cuda can sort - sort_base_gather = false; - hint_token = {}; - LUISA_INFO("Using wavefront dispatcher without cuda, the sorting will be disabled!"); - }*/ - bool use_sort = sort_base_gather || !config.hint_fields.empty(); - uint max_sub_coro = coroutine->suspend_count() + 1; - _max_sub_coro = max_sub_coro; - _resume_index = device.create_buffer(_max_frame_count); - if (use_sort) { - _temp_index = device.create_buffer(_max_frame_count); - _temp_key[0] = device.create_buffer(_max_frame_count); - _temp_key[1] = device.create_buffer(_max_frame_count); - } - _resume_count = device.create_buffer(max_sub_coro); - _resume_offset = device.create_buffer(max_sub_coro); - _global_buffer = device.create_buffer(1); - if (_debug) { - _debug_buffer = device.create_buffer(max_sub_coro); - } - _host_empty = true; - _dispatch_counter = 0; - _host_offset.resize(max_sub_coro); - _host_count.resize(max_sub_coro); - _have_hint.resize(max_sub_coro, false); - for (auto &token : config.hint_fields) { - auto id = coroutine->coro_tokens().find(token); - if (id != coroutine->coro_tokens().end()) { - LUISA_ASSERT(id->second < max_sub_coro, "coroutine token {} of id {} out of range {}", token, id->second, max_sub_coro); - _have_hint[id->second] = true; - } else - LUISA_WARNING("coroutine token {} not found, hint disabled", token); - } - for (auto i = 0u; i < max_sub_coro; i++) { - if (i) { - _host_count[i] = 0; - _host_offset[i] = _max_frame_count; - } else { - _host_count[i] = _max_frame_count; - _host_offset[i] = 0; - } - } - Callable get_coro_token = [&](UInt index) { - $if (index > _max_frame_count) { - device_log("index {} out of range {}", index, _max_frame_count); - }; - auto frame = _frame->read(index); - return read_promise(frame, "coro_token") & token_mask; - }; - Callable identical = [&](UInt index) { - return index; - }; - - Callable keep_index = [&](UInt index, BufferUInt val) { - return val.read(index); - }; - Callable get_coro_hint = [&](UInt index, BufferUInt val) { - if (!config.hint_fields.empty()) { - auto id = keep_index(index, val); - auto frame = _frame->read(id); - auto x = read_promise(frame, "coro_hint"); - return x; - } else { - return def(0u); - } - }; - if (use_sort) { - _sort_temp_storage = radix_sort::temp_storage(device, _max_frame_count, std::max(std::min(config.hint_range, 128u), max_sub_coro)); - } - if (sort_base_gather) { - _sort_token = radix_sort::instance<>(device, _max_frame_count, _sort_temp_storage, - &get_coro_token, &identical, &get_coro_token, 1, max_sub_coro); - } - if (!config.hint_fields.empty()) { - if (config.hint_range <= 128) { - _sort_hint = radix_sort::instance>(device, _max_frame_count, _sort_temp_storage, - &get_coro_hint, &keep_index, &get_coro_hint, 1, config.hint_range); - } else { - auto highbit = 0; - while ((config.hint_range >> highbit) != 1) { - highbit++; - } - _sort_hint = radix_sort::instance>(device, _max_frame_count, _sort_temp_storage, - &get_coro_hint, &keep_index, &get_coro_hint, 0, 128, 0, highbit); - } - } - - Kernel1D gen_kernel = [&](BufferUInt index, BufferUInt count, UInt offset, Var frame_buffer, UInt st_task_id, UInt n, Var... args) { - auto x = dispatch_x(); - $if (x >= n) { - $return(); - }; - UInt frame_id; - if (!use_compact) { - frame_id = index->read(x); - } else { - frame_id = offset + x; - } - Var frame; - if (_debug) { - frame = frame_buffer.read(frame_id); - $if ((read_promise(frame, "coro_token") & token_mask) != 0u) { - device_log("wrong gen for frame {} at kernel {}when dispatch {}", frame_id, read_promise(frame, "coro_token"), dispatch_x()); - }; - } - initialize_coroframe(frame, def(st_task_id + x, 0, 0)); - if (!sort_base_gather) { - count.atomic(0).fetch_add(-1u); - } - - (*coroutine)(frame, args...); - if (is_soa) { - frame_buffer.write(frame_id, frame, coroutine->graph().node(0u)->output_state_members); - } else { - frame_buffer.write(frame_id, frame); - } - if (!sort_base_gather) { - auto nxt = read_promise(frame, "coro_token") & token_mask; - count.atomic(nxt).fetch_add(1u); - } - }; - ShaderOption o{}; - _gen_shader = device.compile(gen_kernel, o); - _gen_shader.set_name("gen"); - _resume_shaders.resize(max_sub_coro); - for (int i = 1; i < max_sub_coro; ++i) { - Kernel1D resume_kernel = [&](BufferUInt index, BufferUInt count, Var frame_buffer, UInt n, Var... args) { - auto x = dispatch_x(); - $if (x >= n) { - $return(); - }; - auto frame_id = index.read(x); - Var frame; - if (is_soa) { - //frame = frame_buffer.read(frame_id); - frame = frame_buffer.read(frame_id, coroutine->graph().node(i)->input_state_members); - } else { - frame = frame_buffer.read(frame_id); - } - if (!sort_base_gather) { - count.atomic(i).fetch_add(-1u); - } - if (_debug) { - auto token = frame_buffer.read(frame_id); - auto check = read_promise(token, "coro_token") & token_mask; - /*$if (check != i) { - device_log("wrong launch for frame {} at kernel {} as kernel {} when dispatch {}", frame_id, check, i, dispatch_x()); - };*/ - } - /*if(_have_hint[i]) { - auto token = frame_buffer.read(frame_id); - device_log("dispatch:{}, index:{}, hint:{}",dispatch_x(),frame_id, read_promise(token,"coro_hint")&token_mask); - }*/ - (*coroutine)[i](frame, args...); - if (is_soa) { - frame_buffer.write(frame_id, frame, coroutine->graph().node(i)->output_state_members); - } else { - frame_buffer.write(frame_id, frame); - } - //if (debug) - // device_log("resume kernel {} : id {} goto kernel {}", i, frame_id, nxt); - - if (!sort_base_gather) { - auto nxt = read_promise(frame, "coro_token") & token_mask; - $switch (nxt) { - for (int i = 0; i < max_sub_coro; ++i) { - $case (i) { - count.atomic(i).fetch_add(1u); - }; - } - }; - //count.atomic(nxt).fetch_add(1u); - } - }; - if (_debug) o.name = "resume" + std::to_string(i); - _resume_shaders[i] = device.compile(resume_kernel, o); - _resume_shaders[i].set_name("resume" + std::to_string(i)); - } - Kernel1D _prefix_kernel = [&](BufferUInt count, BufferUInt prefix, UInt n) { - $if (dispatch_x() == 0) { - auto pre = def(0u); - $for (i, 0u, _max_sub_coro) { - auto val = count.read(i); - prefix.write(i, pre); - pre = pre + val; - if (_debug) { - _debug_buffer->write(i, pre); - } - }; - }; - }; - _count_prefix_shader = device.compile(_prefix_kernel); - Kernel1D _collect_kernel = [&](BufferUInt index, BufferUInt prefix, Var frame_buffer, UInt n) { - auto x = dispatch_x(); - auto frame = frame_buffer.read(x); - auto r_id = read_promise(frame, "coro_token") & token_mask; - auto q_id = prefix.atomic(r_id).fetch_add(1u); - if (_debug) { - $if (q_id >= _debug_buffer->read(r_id)) { - device_log("collect: indices overflow!!!! frame_id:{}>buffer_offset:{}", q_id, _debug_buffer->read(r_id)); - }; - /* - $if (q_id == _debug_buffer->read(r_id) - 1) { - device_log("finish gather: kernel {}, tot :{}", r_id, _debug_buffer->read(r_id)); - }; - */ - } - index.write(q_id, x); - }; - _gather_shader = device.compile(_collect_kernel); - Kernel1D _compact_kernel_2 = [&](BufferUInt index, Var frame_buffer, UInt empty_offset, UInt n) { - //_global_buffer->write(0u, 0u); - auto x = dispatch_x(); - $if (empty_offset + x < n) { - auto token = frame_buffer.read_coro_token(empty_offset + x); - $if ((token & token_mask) != 0u) { - - auto res = _global_buffer->atomic(0).fetch_add(1u); - auto slot = index.read(res); - if (!sort_base_gather) { - $while (slot >= empty_offset) { - res = _global_buffer->atomic(0).fetch_add(1u); - slot = index.read(res); - }; - } - /*if (_debug) { - $if (slot >= empty_offset) { - device_log("compact: new slot is in empty set!!!! slot:{}>empty_offset:{}", slot, empty_offset); - }; - }*/ - if (_debug) { - auto empty = frame_buffer.read(slot); - $if ((read_promise(empty, "coro_token") & token_mask) != 0u) { - device_log("wrong compact for frame {} at kernel {} when dispatch {}", slot, (read_promise(empty, "coro_token") & token_mask), dispatch_x()); - }; - } - auto frame = frame_buffer.read(empty_offset + x); - frame_buffer.write(slot, frame); - Var empty_frame; - initialize_coroframe(empty_frame, def(0, 0, 0)); - if (is_soa) { - frame_buffer.write(empty_offset + x, empty_frame, {1}); - } else { - frame_buffer.write(empty_offset + x, empty_frame); - } - }; - }; - }; - _compact_shader = device.compile(_compact_kernel_2); - _compact_shader.set_name("compact"); - - Kernel1D _initialize_kernel = [&](BufferUInt count, Var frame_buffer, UInt n) { - auto x = dispatch_x(); - $if (x < n) { - auto frame = frame_buffer.read(x); - initialize_coroframe(frame, def(0, 0, 0)); - }; - $if (x < max_sub_coro) { - count.write(x, ite(x == 0, _max_frame_count, 0u)); - }; - }; - Kernel1D clear = [&](BufferUInt buffer, UInt n) { - auto x = dispatch_x(); - $if (x < n) { - buffer.write(x, 0u); - }; - }; - _clear_shader = device.compile(clear); - _initialize_shader = device.compile(_initialize_kernel); - stream << _initialize_shader(_resume_count, _frame, _max_frame_count).dispatch(_max_frame_count); - } - void operator()(prototype_to_coro_dispatcher_t... args, uint dispatch_size) noexcept override { - CoroDispatcherBase::operator()(args..., dispatch_size); - _dispatch_counter = 0; - _host_empty = true; - for (auto i = 0u; i < _max_sub_coro; i++) { - if (i) { - _host_count[i] = 0; - _host_offset[i] = _max_frame_count; - } else { - _host_count[i] = _max_frame_count; - _host_offset[i] = 0; - } - } - _stream << _initialize_shader(_resume_count, _frame, _max_frame_count).dispatch(_max_frame_count); - } -}; - -struct PersistentCoroDispatcherConfig { - uint max_thread_count = 64_k; - uint block_size = 128; - uint fetch_size = 16; - bool global = false; - bool debug = false; -}; - -template -class PersistentCoroDispatcher : public CoroDispatcherBase { - -public: - using Config = PersistentCoroDispatcherConfig; - -private: - using FrameType = std::remove_reference_t; - Shader1D, uint, Args...> _pt_shader; - Shader1D> _clear_shader; - Buffer _global; - Buffer _global_frame; - Shader1D, uint> _initialize_shader; - uint _max_sub_coro; - uint _max_thread_count; - uint _block_size; - uint _global_size; - bool _dispatched; - bool _done; - void _await_step(Stream &stream) noexcept; - void _await_all(Stream &stream) noexcept; - bool _debug; - Stream &_stream; -public: - bool all_dispatched() const noexcept; - bool all_done() const noexcept; - PersistentCoroDispatcher(Coroutine *coroutine, - Device &device, Stream &stream, - const PersistentCoroDispatcherConfig &config) noexcept - : CoroDispatcherBase{coroutine, device}, - _max_thread_count{(config.max_thread_count + config.block_size - 1) / config.block_size * config.block_size}, - _block_size{config.block_size}, - _debug{config.debug}, _stream{stream} { - auto use_global = config.global; - _global = device.create_buffer(1); - auto q_fac = 1u; - uint max_sub_coro = coroutine->suspend_count() + 1; - auto g_fac = (uint)std::max((int)(max_sub_coro - q_fac), 0); - auto global_queue_size = config.block_size * g_fac; - _global_size = 0; - if (use_global) { - _global_frame = device.create_buffer(_max_thread_count * g_fac); - _global_size = _max_thread_count * g_fac; - } - _max_sub_coro = max_sub_coro; - _dispatched = false; - _done = false; - Kernel1D main_kernel = [&](BufferUInt global, UInt dispatch_size, Var... args) { - set_block_size(config.block_size, 1, 1); - auto shared_queue_size = config.block_size * q_fac; - Shared frames{shared_queue_size}; - Shared path_id{shared_queue_size}; - Shared work_counter{max_sub_coro}; - Shared work_offset{2u}; - Shared all_token{use_global ? (shared_queue_size + global_queue_size) : shared_queue_size}; - Shared workload{2}; - Shared work_stat{2};//0 max_count,1 max_id - //Shared tag_counter{use_tag_sort ? pipeline().surfaces().size() : 0}; - //Shared tag_offset{pipeline().surfaces().size()}; - $for (index, 0u, q_fac) { - all_token[index * config.block_size + thread_x()] = 0u; - initialize_coroframe(frames[index * config.block_size + thread_x()], def(0, 0, 0)); - }; - $for (index, 0u, g_fac) { - all_token[shared_queue_size + index * config.block_size + thread_x()] = 0u; - }; - $if (thread_x() < max_sub_coro) { - $if (thread_x() == 0) { - work_counter[thread_x()] = use_global ? (shared_queue_size + global_queue_size) : shared_queue_size; - } - $else { - work_counter[thread_x()] = 0u; - }; - }; - workload[0] = 0; - workload[1] = 0; - Shared rem_global{1}; - Shared rem_local{1}; - rem_global[0] = 1u; - rem_local[0] = 0u; - sync_block(); - auto count = def(0u); - auto count_limit = def(-1); - $while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) { - sync_block();//very important, synchronize for condition - rem_local[0] = 0u; - count += 1; - work_stat[0] = 0; - work_stat[1] = -1; - sync_block(); - $if (thread_x() == config.block_size - 1) { - $if ((workload[0] >= workload[1]) & (rem_global[0] == 1u)) {//fetch new workload - workload[0] = global.atomic(0u).fetch_add(config.block_size * config.fetch_size); - if (_debug) - device_log("block {}, fetch workload: {}", block_x(), workload[0]); - workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size); - $if (workload[0] >= dispatch_size) { - rem_global[0] = 0u; - }; - }; - }; - sync_block(); - $if (thread_x() < max_sub_coro) {//get max - $if ((workload[0] < workload[1]) | (thread_x() != 0u)) { - $if (work_counter[thread_x()] != 0) { - rem_local[0] = 1u; - work_stat.atomic(0).fetch_max(work_counter[thread_x()]); - }; - }; - }; - sync_block(); - $if (thread_x() < max_sub_coro) {//get argmax - $if ((work_stat[0] == work_counter[thread_x()]) & ((workload[0] < workload[1]) | (thread_x() != 0u))) { - work_stat[1] = thread_x(); - }; - }; - sync_block(); - work_offset[0] = 0; - work_offset[1] = 0; - sync_block(); - if (!use_global) { - $for (index, 0u, q_fac) {//collect indices - auto frame_token = all_token[index * config.block_size + thread_x()]; - $if (frame_token == work_stat[1]) { - auto id = work_offset.atomic(0).fetch_add(1u); - path_id[id] = index * config.block_size + thread_x(); - }; - }; - } else { - $for (index, 0u, q_fac) {//collect switch out indices - auto frame_token = all_token[index * config.block_size + thread_x()]; - $if (frame_token != work_stat[1]) { - auto id = work_offset.atomic(0).fetch_add(1u); - path_id[id] = index * config.block_size + thread_x(); - }; - }; - sync_block(); - $if (shared_queue_size - work_offset[0] < config.block_size) {//no enough work - $for (index, 0u, g_fac) { //swap frames - auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x(); - auto g_queue_id = index * config.block_size + thread_x(); - auto coro_token = all_token[shared_queue_size + g_queue_id]; - $if (coro_token == work_stat[1]) { - auto id = work_offset.atomic(1).fetch_add(1u); - $if (id < work_offset[0]) { - auto dst = path_id[id]; - auto frame_token = all_token[dst]; - $if (coro_token != 0u) { - $if (frame_token != 0u) { - auto g_state = _global_frame->read(global_id); - _global_frame->write(global_id, frames[dst]); - frames[dst] = g_state; - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - } - $else { - auto g_state = _global_frame->read(global_id); - frames[dst] = g_state; - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - }; - } - $else { - $if (frame_token != 0u) { - _global_frame->write(global_id, frames[dst]); - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - }; - }; - }; - }; - }; - }; - } - auto gen_st = workload[0]; - sync_block(); - auto pid = def(0u); - if (use_global) { - pid = thread_x(); - } else { - pid = path_id[thread_x()]; - } - auto launch_condition = def(true); - if (!use_global) { - launch_condition = (thread_x() < work_offset[0]); - } else { - launch_condition = (all_token[pid] == work_stat[1]); - } - $if (launch_condition) { - $switch (all_token[pid]) { - $case (0u) { - $if (gen_st + thread_x() < workload[1]) { - work_counter.atomic(0u).fetch_sub(1u); - initialize_coroframe(frames[pid], def(gen_st + thread_x(), 0, 0)); - (*coroutine)(frames[pid], args...);//only work when kernel 0s are continue - auto nxt = read_promise(frames[pid], "coro_token") & token_mask; - all_token[pid] = nxt; - work_counter.atomic(nxt).fetch_add(1u); - workload.atomic(0).fetch_add(1u); - }; - }; - for (auto i = 1u; i < max_sub_coro; ++i) { - $case (i) { - work_counter.atomic(i).fetch_sub(1u); - (*coroutine)[i](frames[pid], args...); - auto nxt = read_promise(frames[pid], "coro_token") & token_mask; - all_token[pid] = nxt; - work_counter.atomic(nxt).fetch_add(1u); - }; - } - }; - }; - sync_block(); - }; - $if (count >= count_limit) { - device_log("block_id{},thread_id {}, loop not break! local:{}, global:{}", block_x(), thread_x(), rem_local[0], rem_global[0]); - $if (thread_x() < max_sub_coro) { - device_log("work rem: id {}, size {}", thread_x(), work_counter[thread_x()]); - }; - }; - }; - _pt_shader = device.compile(main_kernel); - Kernel1D clear = [&](BufferUInt global) { - global->write(dispatch_x(), 0u); - }; - Kernel1D initialize_frame = [&](Var> frame_buffer, UInt n) { - auto x = dispatch_x(); - $if (x < n) { - auto frame = def(); - initialize_coroframe(frame, def(0, 0, 0)); - frame_buffer.write(x, frame); - }; - }; - _clear_shader = device.compile(clear); - _initialize_shader = device.compile(initialize_frame); - stream << _clear_shader(_global).dispatch(1u); - } - void operator()(prototype_to_coro_dispatcher_t... args, uint dispatch_size) noexcept override { - CoroDispatcherBase::operator()(args..., dispatch_size); - _dispatched = false; - _done = false; - _stream << _clear_shader(_global).dispatch(1u); - if(_global_size) { - _stream << _initialize_shader(_global_frame, _global_size).dispatch(_global_size); - } - } -}; - -// a simple wrap that helps submit a coroutine dispatch to the stream -enum struct CmdTag { - AWAIT_STEP, - AWAIT_ALL -}; -template -class CoroAwait { - - friend class CoroDispatcherBase; - -public: - - -private: - CmdTag _tag; - CoroDispatcherBase *_dispatcher; - -private: - CoroAwait(CmdTag tag, - CoroDispatcherBase *dispatcher) noexcept - : _tag{tag}, _dispatcher{dispatcher} {} -public: - void operator()(Stream &stream) && noexcept; -}; - -}// namespace coro - -template -struct luisa::compute::detail::is_stream_event_impl> : std::true_type {}; -}// namespace luisa::compute - -namespace luisa::compute::inline coro { -template -void CoroDispatcherBase::_await_all(Stream &stream) noexcept { - while (!this->all_done()) { - this->_await_step(stream); - } -} -template -CoroAwait CoroDispatcherBase::await_step() noexcept { - return CoroAwait{CmdTag::AWAIT_STEP, this}; -} -template -CoroAwait CoroDispatcherBase::await_all() noexcept { - return CoroAwait{CmdTag::AWAIT_ALL, this}; -} -template -template -detail::ShaderInvoke CoroDispatcherBase::call_shader(Shader &shader, - prototype_to_coro_dispatcher_t... prefix_args) { - auto invoke = shader.partial_invoke(prefix_args...); - std::apply([&invoke](prototype_to_coro_dispatcher_t... args) { - static_cast((invoke << ... << args)); - }, - _args); - return invoke; -} -template -void CoroAwait::operator()(Stream &stream) && noexcept { - switch (_tag) { - case CmdTag::AWAIT_STEP: this->_dispatcher->_await_step(stream); break; - case CmdTag::AWAIT_ALL: this->_dispatcher->_await_all(stream); break; - } -} -/*template -void WavefrontCoroDispatcher::_await_step(Stream &stream) noexcept { - -}*/ -template -void WavefrontCoroDispatcher::_await_step(Stream &stream) noexcept { - if (sort_base_gather) { - auto host_update = [&] { - _host_empty = true; - auto sum = 0u; - for (uint i = 0; i < _max_sub_coro; i++) { - _host_count[i] = (i + 1 == _max_sub_coro ? _max_frame_count : _host_offset[i + 1]) - _host_offset[i]; - _host_empty = _host_empty && (i == 0 || _host_count[i] == 0); - } - }; - _sort_token.sort(stream, _temp_key[0], _resume_index, _temp_key[1], _resume_index, _max_frame_count); - - stream << _sort_temp_storage.hist_buffer.view(0, _max_sub_coro).copy_to(_host_offset.data()) - << host_update - << synchronize(); - if (_debug) - for (int i = 0; i < _max_sub_coro; ++i) { - LUISA_INFO("kernel {}: total {}", i, _host_count[i]); - } - if (_host_count[0] > _max_frame_count * (0.5) && !all_dispatched()) { - auto gen_count = std::min(this->_dispatch_size - this->_dispatch_counter, _host_count[0]); - if (_debug) { - LUISA_INFO("Gen {} new frame", gen_count); - } - if (_host_count[0] != _max_frame_count && use_compact) { - stream << _clear_shader(_global_buffer, 1).dispatch(1u); - stream << _compact_shader(_resume_index, _frame, _max_frame_count - _host_count[0], _max_frame_count).dispatch(_host_count[0]); - } - stream << this->template call_shader<1, Buffer, Buffer, uint, Container, uint, uint>( - _gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), _resume_count, _max_frame_count - _host_count[0], _frame, _dispatch_counter, _max_frame_count) - .dispatch(gen_count); - _dispatch_counter += gen_count; - _host_empty = false; - } else { - for (uint i = 1; i < _max_sub_coro; i++) { - if (_debug) - LUISA_INFO("launch kernel {} for count {}", i, _host_count[i]); - if (_host_count[i] > 0) { - if (_have_hint[i]) { - BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; - BufferView _key[2] = {_temp_key[1].view(_host_offset[i], _host_count[i]), _temp_key[0].view(_host_offset[i], _host_count[i])}; - uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); - stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _index[out], - _resume_count, _frame, _max_frame_count) - .dispatch(_host_count[i]); - } else { - - stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _frame, _max_frame_count) - .dispatch(_host_count[i]); - } - } - } - } - stream << synchronize(); - } else { - if (_debug) - for (int i = 0; i < _max_sub_coro; ++i) { - LUISA_INFO("kernel {}: total {}", i, _host_count[i]); - } - stream << _count_prefix_shader(_resume_count, _resume_offset, _max_sub_coro).dispatch(1u); - stream << _gather_shader(_resume_index, _resume_offset, _frame, _max_frame_count).dispatch(_max_frame_count); - if (_host_count[0] > _max_frame_count / 2 && !all_dispatched()) { - auto gen_count = std::min(this->_dispatch_size - this->_dispatch_counter, _host_count[0]); - if (_debug) { - LUISA_INFO("Gen {} new frame", gen_count); - } - if (_host_count[0] != _max_frame_count && use_compact) { - stream << _clear_shader(_global_buffer, 1).dispatch(1u); - stream - << _compact_shader(_resume_index.view(_host_offset[0], _host_count[0]), _frame, _max_frame_count - _host_count[0], _max_frame_count).dispatch(_host_count[0]); - } - stream << this->template call_shader<1, Buffer, Buffer, uint, Container, uint, uint>( - _gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), _resume_count, _max_frame_count - _host_count[0], _frame, _dispatch_counter, _max_frame_count) - .dispatch(gen_count); - _dispatch_counter += gen_count; - _host_empty = false; - } else { - for (uint i = 1; i < _max_sub_coro; i++) { - if (_debug) - LUISA_INFO("launch kernel {} for count {}", i, _host_count[i]); - if (_host_count[i] > 0) { - if (_have_hint[i]) { - BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; - BufferView _key[2] = {_temp_key[0].view(_host_offset[i], _host_count[i]), _temp_key[1].view(_host_offset[i], _host_count[i])}; - uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); - stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _index[out], - _resume_count, _frame, _max_frame_count) - .dispatch(_host_count[i]); - } else { - stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _frame, _max_frame_count) - .dispatch(_host_count[i]); - } - } - } - } - auto host_update = [&] { - _host_empty = true; - auto sum = 0u; - for (uint i = 0; i < _max_sub_coro; i++) { - _host_offset[i] = sum; - sum += _host_count[i]; - _host_empty = _host_empty && (i == 0 || _host_count[i] == 0); - } - }; - stream << _resume_count.view(0, _max_sub_coro).copy_to(_host_count.data()) - << host_update; - stream << synchronize(); - } -} -template -bool WavefrontCoroDispatcher::all_dispatched() const noexcept { - return this->_dispatch_size == this->_dispatch_counter; -}; -template -bool WavefrontCoroDispatcher::all_done() const noexcept { - return all_dispatched() && _host_empty; -}; - -template -void PersistentCoroDispatcher::_await_step(Stream &stream) noexcept { - LUISA_ERROR("PersistentCoroDispatcher can only be used with await_all!"); -} -template -void PersistentCoroDispatcher::_await_all(Stream &stream) noexcept { - _dispatched = true; - stream << this->template call_shader<1, Buffer, uint>(_pt_shader, _global, this->_dispatch_size).dispatch(_max_thread_count) - << [&] { _done = true; }; - stream << synchronize(); -} - -template -bool PersistentCoroDispatcher::all_dispatched() const noexcept { - return _dispatched; -} -template -bool PersistentCoroDispatcher::all_done() const noexcept { - return _done; -} -}// namespace luisa::compute::inline coro diff --git a/include/luisa/coro/v2/coro_frame.h b/include/luisa/coro/coro_frame.h similarity index 97% rename from include/luisa/coro/v2/coro_frame.h rename to include/luisa/coro/coro_frame.h index 0eb374bb4..bac5580f8 100644 --- a/include/luisa/coro/v2/coro_frame.h +++ b/include/luisa/coro/coro_frame.h @@ -5,7 +5,7 @@ #pragma once #include -#include +#include #include namespace luisa::compute::coroutine { diff --git a/include/luisa/coro/v2/coro_frame_buffer.h b/include/luisa/coro/coro_frame_buffer.h similarity index 99% rename from include/luisa/coro/v2/coro_frame_buffer.h rename to include/luisa/coro/coro_frame_buffer.h index 72d54f159..d2d15c9ac 100644 --- a/include/luisa/coro/v2/coro_frame_buffer.h +++ b/include/luisa/coro/coro_frame_buffer.h @@ -6,7 +6,7 @@ #include #include -#include +#include namespace luisa::compute { diff --git a/include/luisa/coro/v2/coro_frame_desc.h b/include/luisa/coro/coro_frame_desc.h similarity index 100% rename from include/luisa/coro/v2/coro_frame_desc.h rename to include/luisa/coro/coro_frame_desc.h diff --git a/include/luisa/coro/v2/coro_frame_smem.h b/include/luisa/coro/coro_frame_smem.h similarity index 98% rename from include/luisa/coro/v2/coro_frame_smem.h rename to include/luisa/coro/coro_frame_smem.h index 08a9b849b..42159b6a4 100644 --- a/include/luisa/coro/v2/coro_frame_smem.h +++ b/include/luisa/coro/coro_frame_smem.h @@ -4,10 +4,8 @@ #pragma once -#include "coro_token.h" - #include -#include +#include namespace luisa::compute { diff --git a/include/luisa/coro/v2/coro_frame_soa.h b/include/luisa/coro/coro_frame_soa.h similarity index 99% rename from include/luisa/coro/v2/coro_frame_soa.h rename to include/luisa/coro/coro_frame_soa.h index 350b7bde4..c1bc45d68 100644 --- a/include/luisa/coro/v2/coro_frame_soa.h +++ b/include/luisa/coro/coro_frame_soa.h @@ -4,13 +4,11 @@ #pragma once -#include "luisa/runtime/device.h" -#include "spdlog/fmt/bundled/compile.h" - +#include #include #include #include -#include +#include #include #include diff --git a/include/luisa/coro/v2/coro_func.h b/include/luisa/coro/coro_func.h similarity index 99% rename from include/luisa/coro/v2/coro_func.h rename to include/luisa/coro/coro_func.h index 8cceb9ceb..5adc60712 100644 --- a/include/luisa/coro/v2/coro_func.h +++ b/include/luisa/coro/coro_func.h @@ -5,8 +5,8 @@ #pragma once #include -#include -#include +#include +#include #include namespace luisa::compute::coroutine { diff --git a/include/luisa/coro/coro_graph.h b/include/luisa/coro/coro_graph.h index 0bf726033..99a26f066 100644 --- a/include/luisa/coro/coro_graph.h +++ b/include/luisa/coro/coro_graph.h @@ -1,51 +1,72 @@ +// +// Created by Mike on 2024/5/8. +// + #pragma once +#include +#include #include -#include +#include +#include + +namespace luisa::compute::coroutine { + +class CoroFrameDesc; -namespace luisa::compute::inline coro { +class LC_CORO_API CoroGraph { -class CoroGraph { +public: + using CC = luisa::shared_ptr;// current continuation function + +public: + class LC_CORO_API Node { + + private: + luisa::vector _input_fields; + luisa::vector _output_fields; + luisa::vector _targets; + CC _cc; + + public: + Node(luisa::vector input_fields, + luisa::vector output_fields, + luisa::vector targets, + CC current_continuation) noexcept; + ~Node() noexcept; + + public: + [[nodiscard]] auto input_fields() const noexcept { return luisa::span{_input_fields}; } + [[nodiscard]] auto output_fields() const noexcept { return luisa::span{_output_fields}; } + [[nodiscard]] auto targets() const noexcept { return luisa::span{_targets}; } + [[nodiscard]] Function cc() const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; + }; private: - luisa::unordered_map _nodes; - uint _entry; - const Type *_state_type; - luisa::unordered_map _designated_state_members; + luisa::shared_ptr _frame; + luisa::unordered_map _nodes; + luisa::unordered_map _named_tokens; + +public: + CoroGraph(luisa::shared_ptr frame_desc, + luisa::unordered_map nodes, + luisa::unordered_map named_tokens) noexcept; + ~CoroGraph() noexcept; public: - // for construction only - CoroGraph(uint entry, const Type *state_type) noexcept : _entry{entry}, _state_type{state_type} {} - [[nodiscard]] CoroNode *add_node(uint token, CoroNode::Func f) noexcept { - auto node = CoroNode{this, std::move(f)}; - auto [iter, success] = _nodes.emplace(token, std::move(node)); - LUISA_ASSERT(success, "Coroutine node (token = {}) already exists.", token); - return &(iter->second); - } - void designate_state_member(luisa::string name, uint index) noexcept { - auto [iter, success] = _designated_state_members.emplace(name, index); - LUISA_ASSERT(success, "State member '{}' already designated.", name); - } + // create a coroutine graph from a coroutine function definition + [[nodiscard]] static luisa::shared_ptr create(Function coroutine) noexcept; public: - [[nodiscard]] const CoroNode *entry() const noexcept { - return this->node(_entry); - } - [[nodiscard]] const CoroNode *node(uint token) const noexcept { - auto iter = _nodes.find(token); - LUISA_ASSERT(iter != _nodes.cend(), - "Coroutine node (token = {}) not found.", token); - return &(iter->second); - } + [[nodiscard]] auto frame() const noexcept { return _frame.get(); } + [[nodiscard]] auto &shared_frame() const noexcept { return _frame; } [[nodiscard]] auto &nodes() const noexcept { return _nodes; } - [[nodiscard]] auto state_type() const noexcept { return _state_type; } - [[nodiscard]] auto &designated_state_members() const noexcept { return _designated_state_members; } - [[nodiscard]] uint designated_state_member(luisa::string_view name) const noexcept { - auto iter = _designated_state_members.find(name); - LUISA_ASSERT(iter != _designated_state_members.cend(), - "State member '{}' not designated.", name); - return iter->second; - } + [[nodiscard]] auto &named_tokens() const noexcept { return _named_tokens; } + [[nodiscard]] const Node &entry() const noexcept; + [[nodiscard]] const Node &node(CoroToken index) const noexcept; + [[nodiscard]] const Node &node(luisa::string_view name) const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; }; -}// namespace luisa::compute::inline coro +}// namespace luisa::compute::co diff --git a/include/luisa/coro/coro_node.h b/include/luisa/coro/coro_node.h deleted file mode 100644 index b91d05201..000000000 --- a/include/luisa/coro/coro_node.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace luisa::compute::inline coro { - -class CoroGraph; - -class CoroNode { - - friend class CoroGraph; - -public: - using Func = luisa::shared_ptr; - -private: - const CoroGraph *_graph; - Func _function; - -protected: - CoroNode(const CoroGraph *graph, Func function) noexcept - : _graph{graph}, _function{std::move(function)} {} - -public: - luisa::vector input_state_members; - luisa::vector output_state_members; - -public: - [[nodiscard]] auto graph() const noexcept { return _graph; } - [[nodiscard]] auto function() const noexcept { return _function->function(); } -}; - -}// namespace luisa::compute::inline coro diff --git a/include/luisa/coro/v2/coro_scheduler.h b/include/luisa/coro/coro_scheduler.h similarity index 100% rename from include/luisa/coro/v2/coro_scheduler.h rename to include/luisa/coro/coro_scheduler.h diff --git a/include/luisa/coro/v2/coro_token.h b/include/luisa/coro/coro_token.h similarity index 81% rename from include/luisa/coro/v2/coro_token.h rename to include/luisa/coro/coro_token.h index 258aaa7cd..15032d4e1 100644 --- a/include/luisa/coro/v2/coro_token.h +++ b/include/luisa/coro/coro_token.h @@ -8,4 +8,5 @@ namespace luisa::compute::coroutine { using CoroToken = unsigned int; constexpr CoroToken coro_token_entry = 0u; constexpr CoroToken coro_token_terminal = 0x8000'0000u; +constexpr CoroToken coro_token_valid_mask = 0x7fff'ffffu; }// namespace luisa::compute::coro_v2 diff --git a/include/luisa/coro/coro_transition.h b/include/luisa/coro/coro_transition.h deleted file mode 100644 index f63815461..000000000 --- a/include/luisa/coro/coro_transition.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include -#include - -namespace luisa::compute::inline coro { - -struct CoroTransition { - uint destination; -}; - -}// namespace luisa::compute::inline coro diff --git a/include/luisa/coro/radix_sort.h b/include/luisa/coro/radix_sort.h index 96d8ed046..0072123ac 100644 --- a/include/luisa/coro/radix_sort.h +++ b/include/luisa/coro/radix_sort.h @@ -6,7 +6,7 @@ #include #include #include -#include +#include namespace luisa::compute { namespace radix_sort { const uint HIST_BLOCK_SIZE = 128; diff --git a/include/luisa/coro/v2/schedulers/persistent_threads.h b/include/luisa/coro/schedulers/persistent_threads.h similarity index 98% rename from include/luisa/coro/v2/schedulers/persistent_threads.h rename to include/luisa/coro/schedulers/persistent_threads.h index e233187e5..3972b77fc 100644 --- a/include/luisa/coro/v2/schedulers/persistent_threads.h +++ b/include/luisa/coro/schedulers/persistent_threads.h @@ -6,7 +6,8 @@ #include #include -#include +#include +#include namespace luisa::compute::coroutine { diff --git a/include/luisa/coro/v2/schedulers/state_machine.h b/include/luisa/coro/schedulers/state_machine.h similarity index 94% rename from include/luisa/coro/v2/schedulers/state_machine.h rename to include/luisa/coro/schedulers/state_machine.h index d8136c8c9..a2045b69f 100644 --- a/include/luisa/coro/v2/schedulers/state_machine.h +++ b/include/luisa/coro/schedulers/state_machine.h @@ -4,10 +4,10 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include +#include namespace luisa::compute::coroutine { diff --git a/include/luisa/coro/v2/schedulers/wavefront.h b/include/luisa/coro/schedulers/wavefront.h similarity index 97% rename from include/luisa/coro/v2/schedulers/wavefront.h rename to include/luisa/coro/schedulers/wavefront.h index 9a999a264..5602b2012 100644 --- a/include/luisa/coro/v2/schedulers/wavefront.h +++ b/include/luisa/coro/schedulers/wavefront.h @@ -4,11 +4,11 @@ #pragma once -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include namespace luisa::compute::coroutine { @@ -139,10 +139,10 @@ class WavefrontCoroScheduler : public CoroScheduler { device_log("Index out of range {}/{}", index, _config.thread_count); }; if (_config.soa) { - return _frame_soa->read_field(index, "target_token") & token_mask; + return _frame_soa->read_field(index, "target_token") & coro_token_valid_mask; } else { CoroFrame frame = _frame_buffer->read(index); - return frame.target_token & token_mask; + return frame.target_token & coro_token_valid_mask; } }; Callable identical = [](UInt index) { @@ -216,7 +216,7 @@ class WavefrontCoroScheduler : public CoroScheduler { _frame_buffer->write(frame_id, frame); } if (!_config.sort) { - auto nxt = frame.target_token & token_mask; + auto nxt = frame.target_token & coro_token_valid_mask; count.atomic(nxt).fetch_add(1u); } }; @@ -250,7 +250,7 @@ class WavefrontCoroScheduler : public CoroScheduler { } if (!_config.sort) { - auto nxt = frame.target_token & token_mask; + auto nxt = frame.target_token & coro_token_valid_mask; $if (nxt < _max_sub_coro) { count.atomic(nxt).fetch_add(1u); }; @@ -276,10 +276,10 @@ class WavefrontCoroScheduler : public CoroScheduler { auto x = dispatch_x(); UInt r_id; if (_config.soa) { - r_id = _frame_soa->read_field(x, "target_token") & token_mask; + r_id = _frame_soa->read_field(x, "target_token") & coro_token_valid_mask; } else { auto frame = _frame_buffer->read(x); - r_id = frame.target_token & token_mask; + r_id = frame.target_token & coro_token_valid_mask; } auto q_id = prefix.atomic(r_id).fetch_add(1u); index.write(q_id, x); @@ -297,7 +297,7 @@ class WavefrontCoroScheduler : public CoroScheduler { CoroFrame frame = _frame_buffer->read(empty_offset + x); token = frame.target_token; } - $if ((token & token_mask) != 0u) { + $if ((token & coro_token_valid_mask) != 0u) { auto res = _global_buffer->atomic(0u).fetch_add(1u); auto slot = index.read(res); if (!_config.sort) { diff --git a/include/luisa/coro/shader_scheduler.h b/include/luisa/coro/shader_scheduler.h deleted file mode 100644 index e69de29bb..000000000 diff --git a/include/luisa/coro/v2/coro_graph.h b/include/luisa/coro/v2/coro_graph.h deleted file mode 100644 index 3f2b765f9..000000000 --- a/include/luisa/coro/v2/coro_graph.h +++ /dev/null @@ -1,72 +0,0 @@ -// -// Created by Mike on 2024/5/8. -// - -#pragma once - -#include -#include -#include -#include -#include - -namespace luisa::compute::coroutine { - -class CoroFrameDesc; - -class LC_CORO_API CoroGraph { - -public: - using CC = luisa::shared_ptr;// current continuation function - -public: - class LC_CORO_API Node { - - private: - luisa::vector _input_fields; - luisa::vector _output_fields; - luisa::vector _targets; - CC _cc; - - public: - Node(luisa::vector input_fields, - luisa::vector output_fields, - luisa::vector targets, - CC current_continuation) noexcept; - ~Node() noexcept; - - public: - [[nodiscard]] auto input_fields() const noexcept { return luisa::span{_input_fields}; } - [[nodiscard]] auto output_fields() const noexcept { return luisa::span{_output_fields}; } - [[nodiscard]] auto targets() const noexcept { return luisa::span{_targets}; } - [[nodiscard]] Function cc() const noexcept; - [[nodiscard]] luisa::string dump() const noexcept; - }; - -private: - luisa::shared_ptr _frame; - luisa::unordered_map _nodes; - luisa::unordered_map _named_tokens; - -public: - CoroGraph(luisa::shared_ptr frame_desc, - luisa::unordered_map nodes, - luisa::unordered_map named_tokens) noexcept; - ~CoroGraph() noexcept; - -public: - // create a coroutine graph from a coroutine function definition - [[nodiscard]] static luisa::shared_ptr create(Function coroutine) noexcept; - -public: - [[nodiscard]] auto frame() const noexcept { return _frame.get(); } - [[nodiscard]] auto &shared_frame() const noexcept { return _frame; } - [[nodiscard]] auto &nodes() const noexcept { return _nodes; } - [[nodiscard]] auto &named_tokens() const noexcept { return _named_tokens; } - [[nodiscard]] const Node &entry() const noexcept; - [[nodiscard]] const Node &node(CoroToken index) const noexcept; - [[nodiscard]] const Node &node(luisa::string_view name) const noexcept; - [[nodiscard]] luisa::string dump() const noexcept; -}; - -}// namespace luisa::compute::co diff --git a/include/luisa/dsl/builtin.h b/include/luisa/dsl/builtin.h index 00b5e4560..a9382b6cb 100644 --- a/include/luisa/dsl/builtin.h +++ b/include/luisa/dsl/builtin.h @@ -1771,24 +1771,6 @@ template {LUISA_EXPR(x)})); } -/// Coroutines -/// coroutine initialization -template - requires is_coroframe_struct_v> && - is_dsl_v && - is_same_expr_v -inline void initialize_coroframe(F &&frame, V &&coro_id) noexcept { - detail::FunctionBuilder::current()->initialize_coroframe(LUISA_EXPR(frame), LUISA_EXPR(coro_id)); -} - -template - requires is_coroframe_struct_v> -[[nodiscard]] inline auto make_coroframe(Expr coro_id) noexcept { - Var f; - initialize_coroframe(f, coro_id); - return f; -} - [[nodiscard]] inline auto coro_id() noexcept { return def(detail::FunctionBuilder::current()->coro_id()); } diff --git a/include/luisa/dsl/func.h b/include/luisa/dsl/func.h index 693c78980..840332e64 100644 --- a/include/luisa/dsl/func.h +++ b/include/luisa/dsl/func.h @@ -15,7 +15,6 @@ #include #include #include -#include namespace luisa::compute { @@ -150,9 +149,6 @@ class FunctionBuilder; [[nodiscard]] LC_DSL_API luisa::shared_ptr transform_function(Function callable) noexcept; -[[nodiscard]] LC_DSL_API luisa::shared_ptr -transform_coroutine(Type *corotype, coro::CoroGraph &graph, luisa::unordered_map> &sub_builders, Function callable) noexcept; - }// namespace detail template @@ -335,18 +331,7 @@ class CallableInvoke { return luisa::span{_args.data(), _arg_count}; } }; -class CoroutineInvoke : public ShaderInvokeBase { - CoroutineInvoke(uint64_t handle, size_t arg_count, size_t uniform_size) noexcept - : ShaderInvokeBase{handle, arg_count, uniform_size} {} - [[nodiscard]] auto dispatch(uint size_x) && noexcept { - return this->_parallelize(uint3{size_x, 1u, 1u}); - } - [[nodiscard]] auto dispatch(const IndirectDispatchBuffer &indirect_buffer, - uint32_t offset = 0, - uint32_t max_dispatch_size = std::numeric_limits::max()) && noexcept { - return this->_parallelize(indirect_buffer, offset, max_dispatch_size).build(); - } -}; + }// namespace detail /// Callable class. Callable is not allowed, unless T is a function type. @@ -362,8 +347,6 @@ struct is_callable> : std::true_type {}; template class Callable { friend class CallableLibrary; - template - friend class Coroutine; static_assert( std::negation_v, @@ -519,96 +502,6 @@ class ExternalCallable { } }; -/// Coroutine class. Coroutine is not allowed, unless T is a function type. -template -class Coroutine { - static_assert(always_false_v); -}; - -template -struct is_callable> : std::true_type {}; -template -class Coroutine { - static_assert(std::negation_v...>>); - -private: - using CoroID = uint; - static_assert(std::is_lvalue_reference_v); - static_assert(is_coroframe_struct_v>>); - luisa::shared_ptr _builder; - luisa::unordered_map> _sub_callables; - uint _uniform_size; - luisa::unordered_map _coro_tokens; - coro::CoroGraph _coro_graph; -public: - template - requires std::negation_v>> && - std::negation_v>> - Coroutine(Def &&f) noexcept : _coro_graph{0, Type::of>()} { - auto ast = detail::FunctionBuilder::define_coroutine([&f] { - static_assert(std::is_invocable_v, detail::prototype_to_creation_t...>); - auto create = [](auto &&def, std::index_sequence) noexcept { - using arg_tuple = std::tuple; - using var_tuple = std::tuple>, - Var>...>; - using tag_tuple = std::tuple, detail::prototype_to_creation_tag_t...>; - - auto args = detail::create_argument_definitions(std::tuple<>{}); - static_assert(std::tuple_size_v == 1 + sizeof...(Args)); - return luisa::invoke(std::forward(def), - static_cast> &&>(std::get(args))...); - }; - //static_assert(std::is_same_v, "coroutine should return void"); - create(std::forward(f), std::index_sequence_for{}); - detail::FunctionBuilder::current()->return_(nullptr);// to check if any previous $return called with non-void types - }); - _coro_tokens = ast->coro_tokens(); - Type *frame = const_cast(Type::of>()); - auto sub_builder = luisa::unordered_map>{}; - _builder = detail::transform_coroutine(frame, _coro_graph, sub_builder, ast->function()); - _sub_callables.clear(); - _uniform_size = ShaderDispatchCmdEncoder::compute_uniform_size(_builder->function().unbound_arguments()); - for (auto v : sub_builder) { - _sub_callables.insert(std::make_pair(v.first, Callable(v.second))); - } - } - - /// Get the underlying AST - [[nodiscard]] auto function() const noexcept { return Function{_builder.get()}; } - [[nodiscard]] auto const &function_builder() const & noexcept { return _builder; } - [[nodiscard]] auto &&function_builder() && noexcept { return std::move(_builder); } - [[nodiscard]] auto const suspend_count() noexcept { return _coro_tokens.size(); } - [[nodiscard]] auto const &coro_tokens() const & noexcept { return _coro_tokens; } - [[nodiscard]] auto const &graph() noexcept { return _coro_graph; } - //Call from start of coroutine - auto operator()(detail::prototype_to_callable_invocation_t type, - detail::prototype_to_callable_invocation_t... args) const noexcept { - - detail::CallableInvoke invoke; - static_cast((invoke << type)); - static_cast((invoke << ... << args)); - detail::FunctionBuilder::current()->call( - _builder->function(), invoke.args()); - } - auto operator()(detail::prototype_to_shader_invocation_t... args) const noexcept { - using invoke_type = detail::CoroutineInvoke; - auto arg_count = (0u + ... + detail::shader_argument_encode_count::value); - invoke_type invoke{arg_count, _uniform_size}; - static_cast((invoke << ... << args)); - return invoke; - } - //Get Callable from certain suspend point - auto operator[](CoroID index) const noexcept { - auto builder = _sub_callables.find(index); - LUISA_ASSERT(builder != _sub_callables.end(), "coroutine index out of range"); - return builder->second; - } - auto operator[](luisa::string_view index) const noexcept { - auto coro_token = function_builder()->coro_tokens().at(index); - return (*this)[coro_token]; - } -}; namespace detail { template @@ -702,11 +595,6 @@ struct dsl_function> { using type = T; }; -template -struct dsl_function> { - using type = T; -}; - template using dsl_function_t = typename dsl_function::type; @@ -724,9 +612,6 @@ Kernel3D(T &&) -> Kernel3D>>; template Callable(T &&) -> Callable>>; -template -Coroutine(T &&) -> Coroutine>>; - namespace detail { struct CallableOutliner { diff --git a/include/luisa/dsl/struct.h b/include/luisa/dsl/struct.h index e7f043b34..264064654 100644 --- a/include/luisa/dsl/struct.h +++ b/include/luisa/dsl/struct.h @@ -333,367 +333,3 @@ struct luisa_compute_extension {}; } \ template<> \ struct luisa_compute_extension final : luisa::compute::detail::Ref -#define LUISA_COROFRAME_STRUCT_EXT(S) \ - template<> \ - struct luisa_compute_extension; \ - namespace luisa::compute { \ - template<> \ - struct Expr { \ - private: \ - using this_type = S; \ - const Expression *_expression; \ - \ - public: \ - explicit Expr(const Expression *e) noexcept \ - : _expression{e} {} \ - [[nodiscard]] auto expression() const noexcept { return this->_expression; } \ - Expr(Expr &&another) noexcept = default; \ - Expr(const Expr &another) noexcept = default; \ - Expr &operator=(Expr) noexcept = delete; \ - template \ - [[nodiscard]] auto promise(luisa::string_view name) const noexcept { \ - return luisa::compute::dsl::def( \ - luisa::compute::detail::FunctionBuilder::current()->read_promise_( \ - luisa::compute::Type::of(), this->_expression, name)); \ - } \ - [[nodiscard]] auto target_token() const noexcept { \ - return this->promise("coro_token"); \ - } \ - [[nodiscard]] auto coro_id() const noexcept { \ - return this->promise("coro_id"); \ - } \ - }; \ - namespace detail { \ - template<> \ - struct Ref { \ - private: \ - using this_type = S; \ - const Expression *_expression; \ - \ - public: \ - explicit Ref(const Expression *e) noexcept \ - : _expression{e} {} \ - [[nodiscard]] auto expression() const noexcept { return this->_expression; } \ - Ref(Ref &&another) noexcept = default; \ - Ref(const Ref &another) noexcept = default; \ - [[nodiscard]] operator Expr() const noexcept { \ - return Expr{this->expression()}; \ - } \ - template \ - void operator=(Rhs &&rhs) & noexcept { dsl::assign(*this, std::forward(rhs)); } \ - void operator=(Ref rhs) & noexcept { (*this) = Expr{rhs}; } \ - [[nodiscard]] auto operator->() noexcept { \ - return reinterpret_cast *>(this); \ - } \ - [[nodiscard]] auto operator->() const noexcept { \ - return reinterpret_cast *>(this); \ - } \ - template \ - [[nodiscard]] auto promise(luisa::string_view name) const noexcept { \ - return luisa::compute::dsl::def( \ - luisa::compute::detail::FunctionBuilder::current()->read_promise_( \ - luisa::compute::Type::of(), this->_expression, name)); \ - } \ - [[nodiscard]] auto target_token() const noexcept { \ - return this->promise("coro_token"); \ - } \ - [[nodiscard]] auto coro_id() const noexcept { \ - return this->promise("coro_id"); \ - } \ - }; \ - } \ - } - -#define LUISA_DERIVE_COROFRAME_SOA(S) \ - namespace luisa::compute { \ - template<> \ - class SOA; \ - template<> \ - class SOAView; \ - \ - template<> \ - struct Expr> { \ - private: \ - Expr _buffer; \ - Expr _soa_offset; \ - Expr _soa_size; \ - Expr _element_offset; \ - luisa::vector> _member_offsets; \ - public: \ - Expr(Expr buffer, \ - Expr soa_offset, \ - Expr soa_size, \ - Expr element_offset) noexcept \ - : _buffer{buffer}, \ - _soa_offset{soa_offset}, \ - _soa_size{soa_size}, \ - _element_offset{element_offset} { \ - auto type = Type::of() -> corotype(); \ - auto tot_size = dsl::def(soa_offset); \ - for (auto &mem : type->members()) { \ - _member_offsets.push_back(dsl::def(tot_size)); \ - uint stride = ((mem->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto member_size = soa_size * stride; \ - member_size = align_to_soa_cache_line(member_size); \ - tot_size += member_size; \ - } \ - } \ - [[nodiscard]] auto buffer() const noexcept { return _buffer; } \ - [[nodiscard]] auto soa_offset() const noexcept { return _soa_offset; } \ - [[nodiscard]] auto soa_size() const noexcept { return _soa_size; } \ - [[nodiscard]] auto element_offset() const noexcept { return _element_offset; } \ - private: \ - using this_type = S; \ - \ - public: \ - \ - Expr(SOAView soa) noexcept; \ - \ - Expr(const SOA &soa) noexcept; \ - \ - template \ - [[nodiscard]] auto read_coro_id(I &&index) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - constexpr auto i = 0u; \ - auto type = Type::of(); \ - uint stride = ((type->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto id = dsl::def(std::forward(index)); \ - auto data = builder->call( \ - type, CallOp::BYTE_BUFFER_READ, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint))}); \ - auto v = builder->local(type); \ - builder->assign(v, data); \ - return Var{v}; \ - } \ - \ - template \ - [[nodiscard]] auto read_coro_token(I &&index) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - constexpr auto i = 1u; \ - auto type = Type::of(); \ - uint stride = ((type->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto id = dsl::def(std::forward(index)); \ - auto data = builder->call( \ - type, CallOp::BYTE_BUFFER_READ, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint))}); \ - auto v = builder->local(type); \ - builder->assign(v, data); \ - return Var{v}; \ - } \ - \ - template \ - [[nodiscard]] auto read(I &&index) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - auto type = Type::of() -> corotype(); \ - auto ret = builder->local(Type::of()); \ - auto member_index = 0u; \ - auto id = dsl::def(std::forward(index)); \ - for (auto i = 0u; i < type->members().size(); ++i) { \ - auto &mem = type->members()[i]; \ - uint stride = ((mem->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto data = builder->call( \ - mem, CallOp::BYTE_BUFFER_READ, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint))}); \ - builder->assign(builder->member(mem, ret, i), data); \ - member_index += mem->size(); \ - } \ - return Var{ret}; \ - } \ - template \ - [[nodiscard]] auto read(I &&index, luisa::vector members) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - auto type = Type::of() -> corotype(); \ - auto ret = builder->local(Type::of()); \ - auto member_index = 0u; \ - auto id = dsl::def(std::forward(index)); \ - for (auto r_mem = 0u; r_mem < members.size(); ++r_mem) { \ - auto i = members[r_mem]; \ - auto &mem = type->members()[i]; \ - uint stride = ((mem->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto data = builder->call( \ - mem, CallOp::BYTE_BUFFER_READ, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint))}); \ - builder->assign(builder->member(mem, ret, i), data); \ - member_index += mem->size(); \ - } \ - return Var{ret}; \ - } \ - template \ - void write(I &&index, Expr value) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - auto type = Type::of() -> corotype(); \ - auto member_index = 0u; \ - auto id = dsl::def(std::forward(index)); \ - for (auto i = 0u; i < type->members().size(); ++i) { \ - auto &mem = type->members()[i]; \ - auto stride = (((uint)mem->size() + (uint)sizeof(uint) - 1u) / (uint)sizeof(uint)); \ - builder->call(CallOp::BYTE_BUFFER_WRITE, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint)), \ - builder->member(mem, value.expression(), i)}); \ - member_index += mem->size(); \ - } \ - } \ - template \ - void write(I &&index, Expr value, luisa::vector members) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - auto type = Type::of() -> corotype(); \ - auto member_index = 0u; \ - auto id = dsl::def(std::forward(index)); \ - for (auto w_mem = 0u; w_mem < members.size(); ++w_mem) { \ - auto i = members[w_mem]; \ - auto &mem = type->members()[i]; \ - auto stride = (((uint)mem->size() + (uint)sizeof(uint) - 1u) / (uint)sizeof(uint)); \ - builder->call(CallOp::BYTE_BUFFER_WRITE, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint)), \ - builder->member(mem, value.expression(), i)}); \ - member_index += mem->size(); \ - } \ - } \ - [[nodiscard]] auto operator->() const noexcept { return this; } \ - }; \ - template<> \ - class SOAView { \ - protected: \ - ByteBuffer *_bufferview{nullptr}; \ - private: \ - using View = SOAView; \ - \ - private: \ - uint _soa_offset{}; \ - uint _soa_size{}; \ - uint _elem_offset{}; \ - uint _elem_size{}; \ - public: \ - static constexpr auto element_stride = 0; \ - SOAView() noexcept = default; \ - SOAView(ByteBuffer *buffer, \ - size_t soa_offset, \ - size_t soa_size, \ - size_t elem_offset, \ - size_t elem_size) noexcept \ - : _bufferview{buffer}, \ - _soa_offset{static_cast(soa_offset)}, \ - _soa_size{static_cast(soa_size)}, \ - _elem_offset{static_cast(elem_offset)}, \ - _elem_size{static_cast(elem_size)} {} \ - [[nodiscard]] auto &buffer() const noexcept { return *_bufferview; } \ - [[nodiscard]] auto bufferview() const noexcept { return _bufferview; } \ - [[nodiscard]] auto soa_offset() const noexcept { return _soa_offset; } \ - [[nodiscard]] auto soa_size() const noexcept { return _soa_size; } \ - [[nodiscard]] auto element_offset() const noexcept { return _elem_offset; } \ - [[nodiscard]] auto element_size() const noexcept { return _elem_size; } \ - [[nodiscard]] auto operator->() const noexcept { return Expr{*this}; } \ - [[nodiscard]] auto subview(size_t offset, size_t size) const noexcept { \ - if (!(offset + size <= this->element_size())) [[unlikely]] { \ - LUISA_ERROR_WITH_LOCATION("SOAView::subview out of range."); \ - } \ - return View{this->bufferview(), \ - this->soa_offset(), \ - this->soa_size(), \ - this->element_offset() + offset, \ - size}; \ - } \ - \ - public: \ - [[nodiscard]] static auto compute_soa_size(auto soa_size) noexcept { \ - auto type = Type::of() -> corotype(); \ - size_t tot_size = 0u; \ - for (auto mem : type->members()) { \ - tot_size += align_to_soa_cache_line(soa_size * \ - ((mem->size() + (uint)sizeof(uint) - 1u) / (uint)sizeof(uint))); \ - } \ - return tot_size; \ - } \ - }; \ - template<> \ - class SOA : public SOAView { \ - \ - private: \ - ByteBuffer _buffer; \ - \ - private: \ - SOA(ByteBuffer buffer, size_t size) noexcept \ - : _buffer{std::move(buffer)}, \ - SOAView{&buffer, 0u, size, 0u, size} { \ - this->_bufferview = &_buffer; \ - } \ - \ - public: \ - SOA(SOA &&x) noexcept { \ - *this = std::move(x); \ - this->_bufferview = &_buffer; \ - } \ - SOA &operator=(SOA &&x) noexcept { \ - *this = std::move(x); \ - this->_bufferview = &_buffer; \ - return *this; \ - } \ - SOA() noexcept = default; \ - SOA(Device &device, size_t elem_count) noexcept \ - : SOA{device.create_byte_buffer(SOAView::compute_soa_size(elem_count) * sizeof(uint)), \ - elem_count} {} \ - [[nodiscard]] auto view() const noexcept { return SOAView{*this}; } \ - }; \ - Expr>::Expr(SOAView soa) noexcept \ - : Expr{*soa.bufferview(), soa.soa_offset(), soa.soa_size(), soa.element_offset()} {} \ - \ - Expr>::Expr(const SOA &soa) noexcept \ - : Expr{soa.view()} {} \ - \ - template<> \ - struct Var> : public Expr> { \ - \ - private: \ - using Base = Expr>; \ - \ - Var(Expr buffer, \ - Expr soa_offset, \ - Expr soa_size, \ - Expr elem_offset) noexcept \ - : Base{buffer, soa_offset, soa_size, elem_offset} {} \ - \ - Var(Expr buffer, \ - Expr soa_offset, \ - Expr soa_size) noexcept \ - : Var{buffer, soa_offset, soa_size, \ - Var{detail::ArgumentCreation{}}} {} \ - \ - Var(Expr buffer, \ - Expr soa_offset) noexcept \ - : Var{buffer, soa_offset, \ - Var{detail::ArgumentCreation{}}} {} \ - \ - Var(Expr buffer) noexcept \ - : Var{buffer, \ - Var{detail::ArgumentCreation{}}} {} \ - \ - public: \ - Var(detail::ArgumentCreation) noexcept \ - : Var{Var{detail::ArgumentCreation{}}} {} \ - }; \ - }// namespace luisa::compute - -#define LUISA_COROFRAME_STRUCT(S) \ - LUISA_COROFRAME_STRUCT_REFLECT(S, #S); \ - LUISA_COROFRAME_STRUCT_EXT(S) \ - LUISA_DERIVE_COROFRAME_SOA(S) \ - template<> \ - struct luisa_compute_extension final : luisa::compute::detail::Ref diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index d1eb7371f..1eb0a4745 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -49,23 +49,19 @@ #include #include -#include +#include +#include +#include +#include +#include +#include #include -#include -#include +#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include #ifdef LUISA_ENABLE_DSL #include diff --git a/include/luisa/runtime/byte_buffer.h b/include/luisa/runtime/byte_buffer.h index 69e82d146..68f34c61b 100644 --- a/include/luisa/runtime/byte_buffer.h +++ b/include/luisa/runtime/byte_buffer.h @@ -1,6 +1,5 @@ #pragma once -#include "luisa/coro/v2/coro_frame.h" #include namespace luisa::compute { @@ -14,6 +13,10 @@ class ByteBufferExprProxy; template class SOA; +namespace coroutine { +class CoroFrame; +}// namespace coroutine + class ByteBufferView; class LC_RUNTIME_API ByteBuffer final : public Resource { diff --git a/include/luisa/runtime/shader.h b/include/luisa/runtime/shader.h index 51cb0490a..1d3537c08 100644 --- a/include/luisa/runtime/shader.h +++ b/include/luisa/runtime/shader.h @@ -322,15 +322,6 @@ class Shader final : public ShaderBase { return invoke; } - template - [[nodiscard]] auto partial_invoke(Pre... args) const noexcept { - _check_is_valid(); - using invoke_type = detail::ShaderInvoke; - auto arg_count = (0u + ... + detail::shader_argument_encode_count::value); - invoke_type invoke{handle(), arg_count, _uniform_size}; - static_cast((invoke << ... << args)); - return invoke; - } [[nodiscard]] uint3 block_size() const noexcept { _check_is_valid(); return make_uint3(_block_size[0], _block_size[1], _block_size[2]); diff --git a/src/ast/ast2json.cpp b/src/ast/ast2json.cpp index cb3e34d17..c1128dc0d 100644 --- a/src/ast/ast2json.cpp +++ b/src/ast/ast2json.cpp @@ -588,9 +588,6 @@ class AST2JSON { private: [[nodiscard]] uint _type_index(const Type *type) noexcept { - if (type != nullptr && type->is_materialized_coroframe()) { - return _type_index(type->corotype()); - } if (auto iter = _type_to_index.find(type); iter != _type_to_index.end()) { return iter->second; diff --git a/src/ast/function_builder.cpp b/src/ast/function_builder.cpp index a42371253..085a5079c 100644 --- a/src/ast/function_builder.cpp +++ b/src/ast/function_builder.cpp @@ -133,24 +133,6 @@ void FunctionBuilder::bind_promise_(const Expression *expr, luisa::string name) _create_and_append_statement(expr, std::move(name)); } -const MemberExpr *FunctionBuilder::read_promise_(const Type *type, const Expression *expr, luisa::string_view name) noexcept { - LUISA_ASSERT(expr->type()->is_coroframe(), "Promise reading is only allowed for CoroFrame type"); - auto var = expr->type()->member(name); - if (var != -1) { - auto mem_type = expr->type()->corotype()->members()[var]; - LUISA_ASSERT(*mem_type == *type, - "Promise '{}' type mismatch: expected {}, got {}.", - name, type->description(), mem_type->description()); - return _create_expression(expr->type()->corotype()->members()[var], expr, var); - } - return nullptr; -} -void FunctionBuilder::initialize_coroframe(const luisa::compute::Expression *expr, const luisa::compute::Expression *coro_id) noexcept { - auto member_coro_id = member(Type::of(), expr, 0u); - auto member_coro_token = member(Type::of(), expr, 1u); - assign(member_coro_id, coro_id); - assign(member_coro_token, literal(Type::of(), 0u)); -} const CallExpr *FunctionBuilder::coro_id() noexcept { check_is_coroutine(); return call(Type::of(), CallOp::CORO_ID, {}); @@ -898,11 +880,6 @@ bool FunctionBuilder::requires_autodiff() const noexcept { return _propagated_builtin_callables.uses_autodiff(); } -void FunctionBuilder::coroframe_replace(const Type *type) noexcept { - LUISA_ASSERT(_arguments.size() > 0, "Lack of parameter for coroutine generated callables!"); - _arguments[0]._type = type; -} - bool FunctionBuilder::requires_printing() const noexcept { return _requires_printing; } diff --git a/src/ast/type.cpp b/src/ast/type.cpp index 52a6f29e4..34bce9897 100644 --- a/src/ast/type.cpp +++ b/src/ast/type.cpp @@ -207,43 +207,6 @@ const Type *TypeRegistry::coroframe_type(luisa::string_view name) noexcept { return _register(t); } -void _update(Type *dst, const Type *src) { - std::scoped_lock lock{TypeRegistry::instance().mutex()}; - auto dst_inst = static_cast(dst); - auto src_inst = static_cast(src); - dst_inst->alignment = src_inst->alignment; - dst_inst->size = src_inst->size; - LUISA_ASSERT(dst_inst->members.empty(), "{} used as coroframe type " - "for second time!", - dst_inst->description); - dst_inst->members.push_back(src); - //dst_inst->description = src_inst->description; -} - -size_t _add_member(Type *type, const luisa::string &name) { - std::scoped_lock lock{TypeRegistry::instance().mutex()}; - auto inst = static_cast(type); - size_t id = inst->member_names.size(); - auto ret = inst->member_names.insert(std::make_pair(name, id)); - return ret.second ? id : -1; -} - -void _set_member_name(Type *type, size_t index, luisa::string name) { - std::scoped_lock lock{TypeRegistry::instance().mutex()}; - auto inst = static_cast(type); - auto old_iter = inst->member_names.find(name); - if (old_iter != inst->member_names.end() && - old_iter->second != index) [[unlikely]] { - LUISA_ERROR_WITH_LOCATION( - "Duplicate member name at index {}: {}.", - index, name); - } - LUISA_ASSERT(index < inst->members.size() || index < inst->corotype()->members().size(), - "Invalid member index: {} (total = {}).", - index, std::max(inst->members.size(), inst->corotype()->members().size())); - inst->member_names.insert_or_assign(std::move(name), index); -} - size_t TypeRegistry::type_count() const noexcept { std::lock_guard lock{_mutex}; return _types.size(); @@ -533,7 +496,7 @@ const TypeImpl *TypeRegistry::_decode(luisa::string_view desc) noexcept { }// namespace detail luisa::span Type::members() const noexcept { - LUISA_ASSERT(is_structure() || (is_coroframe()), + LUISA_ASSERT(is_structure(), "Calling members() on a non-structure type {}.", description()); return static_cast(this)->members; @@ -554,16 +517,6 @@ const Type *Type::element() const noexcept { return static_cast(this)->members.front(); } -const Type *Type::corotype() const noexcept { - LUISA_ASSERT(is_coroframe(), - "Calling corotype() on a non-coroframe type {}.", - description()); - LUISA_ASSERT(!static_cast(this)->members.empty(), - "Calling corotype() on a coroframe before analyze.\n" - "Define Coroutine with this coroframe to specify the backend type!"); - return static_cast(this)->members.front(); -} - const Type *Type::from(std::string_view description) noexcept { return detail::TypeRegistry::instance().decode_type(description); } @@ -610,10 +563,8 @@ uint64_t Type::hash() const noexcept { } size_t Type::size() const noexcept { - LUISA_ASSERT(!(is_coroframe() && members().empty()), - "Cannot find size of {}. " - "Usages of CoroFrame types should be " - "after the coroutine definition!", + LUISA_ASSERT(!is_custom(), + "Unknown size for custom type {}.", description()); return static_cast(this)->size; } @@ -689,11 +640,6 @@ bool Type::is_texture() const noexcept { return tag() == Tag::TEXTURE; } bool Type::is_bindless_array() const noexcept { return tag() == Tag::BINDLESS_ARRAY; } bool Type::is_accel() const noexcept { return tag() == Tag::ACCEL; } bool Type::is_custom() const noexcept { return tag() == Tag::CUSTOM; } -bool Type::is_coroframe() const noexcept { return tag() == Tag::COROFRAME; } -bool Type::is_materialized_coroframe() const noexcept { - return tag() == Tag::COROFRAME && - !reinterpret_cast(this)->members.empty(); -} const Type *Type::array(const Type *elem, size_t n) noexcept { return from(luisa::format("array<{},{}>", elem->description(), n)); @@ -801,41 +747,6 @@ const Type *Type::custom(luisa::string_view name) noexcept { return detail::TypeRegistry::instance().custom_type(name); } -const Type *Type::coroframe(luisa::string_view name) noexcept { - return detail::TypeRegistry::instance().coroframe_type(name); -} - -void Type::update_from(const Type *type) { - detail::_update(this, type); -} - -size_t Type::add_member(const luisa::string &name) noexcept { - LUISA_ERROR_WITH_LOCATION("Deprecated."); - LUISA_ASSERT(name != "coro_id" && name != "coro_token", - "{} is a reserved name for coroframe type.", name); - return detail::_add_member(this, name); -} - -void Type::set_member_name(size_t index, luisa::string name) noexcept { - LUISA_ERROR_WITH_LOCATION("Deprecated."); - LUISA_ASSERT(name != "coro_id" && name != "coro_token", - "{} is a reserved name for coroframe type.", name); - detail::_set_member_name(this, index, std::move(name)); -} - -size_t Type::member(luisa::string_view name) const noexcept { - LUISA_ERROR_WITH_LOCATION("Deprecated."); - if (name == "coro_id") return 0; - if (name == "coro_token") return 1u; - auto &map = static_cast(this)->member_names; - auto it = map.find(name); - if (it == map.end()) { - return -1; - } else { - return it->second; - } -} - bool Type::is_bool() const noexcept { return tag() == Tag::BOOL; } bool Type::is_int32() const noexcept { return tag() == Tag::INT32; } bool Type::is_uint32() const noexcept { return tag() == Tag::UINT32; } diff --git a/src/backends/common/hlsl/hlsl_codegen_util.cpp b/src/backends/common/hlsl/hlsl_codegen_util.cpp index 2a5368444..95757f6c7 100644 --- a/src/backends/common/hlsl/hlsl_codegen_util.cpp +++ b/src/backends/common/hlsl/hlsl_codegen_util.cpp @@ -381,10 +381,6 @@ void CodegenUtility::GetTypeName(Type const &type, vstd::StringBuilder &str, Usa case Type::Tag::CUSTOM: { str << '_' << type.description(); } break; - case Type::Tag::COROFRAME: { - auto customType = opt->CreateStruct(type.corotype()); - str << customType; - } break; default: LUISA_ERROR_WITH_LOCATION("Bad."); break; diff --git a/src/backends/common/shader_print_formatter.h b/src/backends/common/shader_print_formatter.h index 96b39508b..a849d0661 100644 --- a/src/backends/common/shader_print_formatter.h +++ b/src/backends/common/shader_print_formatter.h @@ -117,8 +117,7 @@ class ShaderPrintFormatter { } s.push_back('>'); commit_s(); - } else if (arg->is_structure() || arg->is_coroframe()) { - if (arg->is_coroframe()) { arg = arg->corotype(); } + } else if (arg->is_structure()) { s.push_back('{'); commit_s(); for (auto i = 0u; i < arg->members().size(); i++) { diff --git a/src/backends/cuda/cuda_codegen_ast.cpp b/src/backends/cuda/cuda_codegen_ast.cpp index bccfe92b9..0cfd52b7b 100644 --- a/src/backends/cuda/cuda_codegen_ast.cpp +++ b/src/backends/cuda/cuda_codegen_ast.cpp @@ -1528,8 +1528,6 @@ static void collect_types_in_function(Function f, for (auto m : t->members()) { self(self, m); } - } else if (t->is_coroframe()) { - self(self, t->corotype()); } } }; @@ -1763,10 +1761,6 @@ void CUDACodegenAST::_emit_type_name(const Type *type, bool hack_float_to_int) n } break; } - case Type::Tag::COROFRAME: { - _emit_type_name(type->corotype()); - break; - } default: break; } } diff --git a/src/backends/metal/metal_codegen_ast.cpp b/src/backends/metal/metal_codegen_ast.cpp index 6fc31145c..f3277d1e7 100644 --- a/src/backends/metal/metal_codegen_ast.cpp +++ b/src/backends/metal/metal_codegen_ast.cpp @@ -137,8 +137,6 @@ static void collect_types_in_function(Function f, for (auto m : t->members()) { self(self, m); } - } else if (t->is_coroframe()) { - self(self, t->corotype()); } } }; diff --git a/src/coro/coro_frame.cpp b/src/coro/coro_frame.cpp index 3f8661dfd..3142f9d85 100644 --- a/src/coro/coro_frame.cpp +++ b/src/coro/coro_frame.cpp @@ -3,8 +3,9 @@ // #include -#include -#include +#include +#include +#include namespace luisa::compute::coroutine { diff --git a/src/coro/coro_frame_buffer.cpp b/src/coro/coro_frame_buffer.cpp index 83ea5d0f3..9dc4d0d7f 100644 --- a/src/coro/coro_frame_buffer.cpp +++ b/src/coro/coro_frame_buffer.cpp @@ -3,7 +3,7 @@ // #include -#include +#include namespace luisa::compute::detail { void error_coro_frame_buffer_invalid_element_size(size_t stride, size_t expected) noexcept { diff --git a/src/coro/coro_frame_desc.cpp b/src/coro/coro_frame_desc.cpp index 302f08911..b4b1d52f4 100644 --- a/src/coro/coro_frame_desc.cpp +++ b/src/coro/coro_frame_desc.cpp @@ -3,7 +3,7 @@ // #include -#include +#include namespace luisa::compute::coroutine { diff --git a/src/coro/coro_func.cpp b/src/coro/coro_func.cpp index 1386da21e..966719c01 100644 --- a/src/coro/coro_func.cpp +++ b/src/coro/coro_func.cpp @@ -3,7 +3,7 @@ // #include -#include +#include namespace luisa::compute::coroutine::detail { diff --git a/src/coro/coro_graph.cpp b/src/coro/coro_graph.cpp index 9214e654e..5ab11eed6 100644 --- a/src/coro/coro_graph.cpp +++ b/src/coro/coro_graph.cpp @@ -5,8 +5,8 @@ #include #include #include -#include -#include +#include +#include #ifdef LUISA_ENABLE_IR #include diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index 624e562f4..863287c32 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -3,10 +3,10 @@ // #include -#include -#include -#include -#include +#include +#include +#include +#include namespace luisa::compute::coroutine::detail { @@ -161,8 +161,6 @@ void persistent_threads_coro_scheduler_main_kernel_impl( launch_condition = (all_token[pid] == work_stat[1]); } $if (launch_condition) { - constexpr auto valid_token_mask = coro_token_terminal - 1u; - static_assert(valid_token_mask == 0x7fff'ffffu); $switch (all_token[pid]) { $case (0u) { $if (gen_st + thread_x() < workload[1]) { @@ -175,7 +173,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( auto index_y = index_xy / dispatch_shape.x; auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); call_subroutine(frame, coro_token_entry); - auto next = frame.target_token & valid_token_mask; + auto next = frame.target_token & coro_token_valid_mask; frames.write(pid, frame, graph->entry().output_fields()); all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); @@ -187,7 +185,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( work_counter.atomic(i).fetch_sub(1u); auto frame = frames.read(pid, graph->node(i).input_fields()); call_subroutine(frame, i); - auto next = frame.target_token & valid_token_mask; + auto next = frame.target_token & coro_token_valid_mask; frames.write(pid, frame, graph->node(i).output_fields()); all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); diff --git a/src/coro/schedulers/state_machine.cpp b/src/coro/schedulers/state_machine.cpp index 2c639f807..2a389c0c2 100644 --- a/src/coro/schedulers/state_machine.cpp +++ b/src/coro/schedulers/state_machine.cpp @@ -3,7 +3,7 @@ // #include -#include +#include namespace luisa::compute::coroutine::detail { diff --git a/src/dsl/func.cpp b/src/dsl/func.cpp index cf6f78c8a..21ab66c91 100644 --- a/src/dsl/func.cpp +++ b/src/dsl/func.cpp @@ -70,116 +70,4 @@ transform_function(Function function) noexcept { return function.shared_builder(); } -luisa::shared_ptr transform_coroutine( - Type *corotype, - coro::CoroGraph &graph, - luisa::unordered_map> &sub_builders, - Function function) noexcept { - if (true) { -#ifndef LUISA_ENABLE_IR - LUISA_ERROR_WITH_LOCATION( - "Coroutine requires IR support but " - "LuisaCompute is built without the IR module. " - "This might be caused by missing Rust. " - "Please install the Rust toolchain and " - "recompile LuisaCompute to get the IR module."); -#else - LUISA_VERBOSE_WITH_LOCATION("Performing Coroutine transform " - "on function with hash {:016x}.", - function.hash()); - - auto make_wrapper = [&](const FunctionBuilder *sub) noexcept { - return FunctionBuilder::define_callable([&] { - luisa::vector args; - args.reserve(function.arguments().size()); - LUISA_ASSERT(function.arguments().size() == function.bound_arguments().size(), - "Invalid capture list size (expected {}, got {}).", - function.arguments().size(), function.bound_arguments().size()); - auto fb = FunctionBuilder::current(); - for (auto arg_i = 0u; arg_i < function.arguments().size(); arg_i++) { - auto def_arg = function.arguments()[arg_i]; - auto internal_arg = luisa::visit( - [&](auto b) noexcept -> const Expression * { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return fb->buffer_binding(def_arg.type(), b.handle, b.offset, b.size); - } else if constexpr (std::is_same_v) { - return fb->texture_binding(def_arg.type(), b.handle, b.level); - } else if constexpr (std::is_same_v) { - return fb->bindless_array_binding(b.handle); - } else if constexpr (std::is_same_v) { - return fb->accel_binding(b.handle); - } else { - static_assert(std::is_same_v); - switch (def_arg.tag()) { - case Variable::Tag::REFERENCE: return fb->reference(def_arg.type()); - case Variable::Tag::BUFFER: return fb->buffer(def_arg.type()); - case Variable::Tag::TEXTURE: return fb->texture(def_arg.type()); - case Variable::Tag::BINDLESS_ARRAY: return fb->bindless_array(); - case Variable::Tag::ACCEL: return fb->accel(); - default: /* value argument */ return fb->argument(def_arg.type()); - } - } - }, - function.bound_arguments()[arg_i]); - args.emplace_back(internal_arg); - } - LUISA_ASSERT(sub->return_type() == nullptr, - "Coroutine subroutines should not have return type."); - fb->call(sub->function(), args); - }); - }; - - graph = coro::CoroGraph(0, corotype); - //idea: send in function-> module with .subroutine-> seperate transform to callable-> register to coroutine - luisa::shared_ptr converted; - auto m = AST2IR::build_coroutine(function); - perform_coroutine_transform(m->get()); - converted = IR2AST::build(m->get()); - auto subroutines = m->get()->subroutines; - auto subroutine_ids = m->get()->subroutine_ids; - auto coroframe = corotype; - auto coroframe_new = converted->arguments()[0].type(); - coroframe->update_from(coroframe_new); - for (auto &&field : luisa::span{m->get()->coro_frame_designated_fields.ptr, - m->get()->coro_frame_designated_fields.len}) { - auto name = luisa::string_view{reinterpret_cast(field.name.ptr), field.name.len}; - if (!name.empty() && name.back() == '\0') { name = name.substr(0, name.size() - 1); } - graph.designate_state_member(luisa::string{name}, field.index); - coroframe->set_member_name(field.index, luisa::string{name}); - } - const_cast(converted.get())->coroframe_replace(corotype); - for (int i = 0; i < subroutines.len; ++i) { - auto sub = IR2AST::build(subroutines.ptr[i]._0.get()); - auto wrapper = make_wrapper(sub.get()); - sub_builders.insert(std::make_pair(subroutine_ids.ptr[i], wrapper)); - auto node = graph.add_node(subroutine_ids.ptr[i], wrapper); - for (int mem = 0u; mem < subroutines.ptr[i]._0->coro_frame_input_fields.len; ++mem) { - auto field = subroutines.ptr[i]._0->coro_frame_input_fields.ptr[mem]; - node->input_state_members.push_back(field); - } - for (int mem = 0u; mem < subroutines.ptr[i]._0->coro_frame_output_fields.len; ++mem) { - auto field = subroutines.ptr[i]._0->coro_frame_output_fields.ptr[mem]; - node->output_state_members.push_back(field); - } - } - auto node = graph.add_node(0, make_wrapper(converted.get())); - for (int mem = 0u; mem < m->get()->coro_frame_input_fields.len; ++mem) { - auto field = m->get()->coro_frame_input_fields.ptr[mem]; - node->input_state_members.push_back(field); - } - for (int mem = 0u; mem < m->get()->coro_frame_output_fields.len; ++mem) { - auto field = m->get()->coro_frame_output_fields.ptr[mem]; - node->output_state_members.push_back(field); - } - LUISA_VERBOSE_WITH_LOCATION("Converted IR to AST for " - "kernel with hash {:016x}. " - "Coroutine transform is done.", - function.hash()); - return make_wrapper(converted.get()); -#endif - } - return function.shared_builder(); -} - }// namespace luisa::compute::detail diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 4df8b0638..3c1a6ea98 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -43,7 +43,7 @@ if (LUISA_COMPUTE_ENABLE_RUST) luisa_compute_add_executable(test_ast2ir_ir2ast test_ast2ir_ir2ast.cpp) if (LUISA_COMPUTE_ENABLE_GUI) - luisa_compute_add_executable(test_skyline test_skyline.cpp) +# luisa_compute_add_executable(test_skyline test_skyline.cpp) luisa_compute_add_executable(test_kernel_ir test_kernel_ir.cpp) luisa_compute_add_executable(test_sdf_renderer_ir test_sdf_renderer_ir.cpp) luisa_compute_add_executable(test_path_tracing_ir test_path_tracing_ir.cpp) @@ -197,13 +197,13 @@ if (LUISA_COMPUTE_TEST_QT_INTEROP) endif () # coroutine tests -luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) +#luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) luisa_compute_add_executable(test_coro_sdf_renderer_v2 coro/sdf_renderer_v2.cpp) -luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_renderer_wo_dispatcher.cpp) -luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) +#luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_renderer_wo_dispatcher.cpp) +#luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) luisa_compute_add_executable(test_coro_path_tracing_v2 coro/path_tracing_v2.cpp) luisa_compute_add_executable(test_coro_path_tracing_wavefront_v2 coro/path_tracing_wavefront_v2.cpp) luisa_compute_add_executable(test_coro_path_tracing_persistent_threads_v2 coro/path_tracing_persistent_threads_v2.cpp) -luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) +#luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) luisa_compute_add_executable(test_coro_helloworld_v2 coro/helloworld_v2.cpp) -luisa_compute_add_executable(test_coro_playground coro/playground.cpp) +#luisa_compute_add_executable(test_coro_playground coro/playground.cpp) diff --git a/src/tests/test_skyline.cpp b/src/tests/test_skyline.cpp index 72ca9dbd7..1eab730d9 100644 --- a/src/tests/test_skyline.cpp +++ b/src/tests/test_skyline.cpp @@ -8,13 +8,12 @@ #include #include #include -#include + using namespace luisa; using namespace luisa::compute; -struct alignas(4) Skyline { -}; + const bool SHOW = true; -LUISA_COROFRAME_STRUCT(Skyline) {}; + int main(int argc, char *argv[]) { Context context{argv[0]}; @@ -824,7 +823,7 @@ int main(int argc, char *argv[]) { image.write(xy, make_float4(sqrt(RayTrace(make_float2(xy), time)), 1.0f)); }; is_coroutine = true; - Coroutine coro = [&](Var &frame, Float time) noexcept { + Kernel2D coro = [&](Float time) noexcept { auto xy = make_uint2(coro_id().x % width, coro_id().x / width); device_image->write(xy, make_float4(sqrt(RayTrace(make_float2(xy), time)), 1.0f)); }; diff --git a/src/tests/test_sort.cpp b/src/tests/test_sort.cpp index 9a8a5382c..e43dadb7f 100644 --- a/src/tests/test_sort.cpp +++ b/src/tests/test_sort.cpp @@ -3,7 +3,6 @@ /// reference: [Onesweep: A Faster Least Significant Digit Radix Sort for GPUs] /// https://arxiv.org/abs/2206.01784 #include -#include #include using namespace luisa; using namespace luisa::compute; From f8b13a8b854625631e22fbc7f0610ec714a38edd Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 00:36:06 +0800 Subject: [PATCH 56/67] clean up --- .clang-format | 1 - include/luisa/ast/type.h | 2 - src/ast/ast2json.cpp | 1 - src/ast/type.cpp | 44 - src/backends/metal/metal_codegen_ast.cpp | 3 - src/coro/coro_graph.cpp | 7 +- src/rust/luisa_compute_ir/src/ast2ir.rs | 2 +- .../src/transform/materialize_coro.rs | 56 +- .../src/transform/materialize_coro_v2.rs | 1097 ----------------- .../luisa_compute_ir/src/transform/mod.rs | 5 - src/tests/CMakeLists.txt | 15 +- src/tests/coro/helloworld.cpp | 67 +- src/tests/coro/helloworld_v2.cpp | 53 - src/tests/coro/path_tracing.cpp | 380 ------ ...pp => path_tracing_persistent_threads.cpp} | 0 ..._v2.cpp => path_tracing_state_machine.cpp} | 0 ...ront_v2.cpp => path_tracing_wavefront.cpp} | 0 src/tests/coro/playground.cpp | 61 - src/tests/coro/sdf_renderer.cpp | 67 +- src/tests/coro/sdf_renderer_v2.cpp | 212 ---- src/tests/coro/sdf_renderer_wo_dispatcher.cpp | 229 ---- 21 files changed, 81 insertions(+), 2221 deletions(-) delete mode 100644 src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs delete mode 100644 src/tests/coro/helloworld_v2.cpp delete mode 100644 src/tests/coro/path_tracing.cpp rename src/tests/coro/{path_tracing_persistent_threads_v2.cpp => path_tracing_persistent_threads.cpp} (100%) rename src/tests/coro/{path_tracing_v2.cpp => path_tracing_state_machine.cpp} (100%) rename src/tests/coro/{path_tracing_wavefront_v2.cpp => path_tracing_wavefront.cpp} (100%) delete mode 100644 src/tests/coro/playground.cpp delete mode 100644 src/tests/coro/sdf_renderer_v2.cpp delete mode 100644 src/tests/coro/sdf_renderer_wo_dispatcher.cpp diff --git a/.clang-format b/.clang-format index 125fe989b..714315436 100644 --- a/.clang-format +++ b/.clang-format @@ -109,7 +109,6 @@ ForEachMacros: - LUISA_STRUCT - LUISA_BINDING_GROUP - LUISA_BINDING_GROUP_TEMPLATE - - LUISA_COROFRAME_STRUCT IfMacros: - $if - $elif diff --git a/include/luisa/ast/type.h b/include/luisa/ast/type.h index 2d26f8712..63be7bb80 100644 --- a/include/luisa/ast/type.h +++ b/include/luisa/ast/type.h @@ -310,8 +310,6 @@ class LC_AST_API Type { TEXTURE, BINDLESS_ARRAY, ACCEL, - - COROFRAME, CUSTOM }; diff --git a/src/ast/ast2json.cpp b/src/ast/ast2json.cpp index c1128dc0d..dbb0a3a85 100644 --- a/src/ast/ast2json.cpp +++ b/src/ast/ast2json.cpp @@ -629,7 +629,6 @@ class AST2JSON { t["element"] = _type_index(type->element()); break; } - case Type::Tag::COROFRAME: [[fallthrough]]; case Type::Tag::CUSTOM: { t["id"] = type->description(); break; diff --git a/src/ast/type.cpp b/src/ast/type.cpp index 34bce9897..222895081 100644 --- a/src/ast/type.cpp +++ b/src/ast/type.cpp @@ -39,7 +39,6 @@ struct TypeImpl final : public Type { uint index{}; luisa::string description; luisa::vector members; - luisa::unordered_map member_names; luisa::vector member_attributes; }; @@ -107,7 +106,6 @@ class LC_AST_API TypeRegistry { [[nodiscard]] const Type *decode_type(luisa::string_view desc) noexcept; /// Construct custom type [[nodiscard]] const Type *custom_type(luisa::string_view desc) noexcept; - [[nodiscard]] const Type *coroframe_type(luisa::string_view desc) noexcept; /// Return type count [[nodiscard]] size_t type_count() const noexcept; /// Traverse all types using visitor @@ -165,48 +163,6 @@ const Type *TypeRegistry::custom_type(luisa::string_view name) noexcept { return _register(t); } -const Type *TypeRegistry::coroframe_type(luisa::string_view name) noexcept { - // validate name - LUISA_ASSERT(!name.empty() && - name != "void" && - name != "int" && - name != "uint" && - name != "short" && - name != "ushort" && - name != "long" && - name != "ulong" && - name != "float" && - name != "half" && - name != "double" && - name != "bool" && - !name.starts_with("vector<") && - !name.starts_with("matrix<") && - !name.starts_with("array<") && - !name.starts_with("struct<") && - !name.starts_with("buffer<") && - !name.starts_with("texture<") && - name != "accel" && - name != "bindless_array" && - !isdigit(name.front() /* already checked not empty */), - "Invalid custom type name: {}", name); - LUISA_ASSERT(std::all_of(name.cbegin(), name.cend(), - [](char c) { return isalnum(c) || c == '_'; }), - "Invalid custom type name: {}", name); - std::lock_guard lock{_mutex}; - auto h = _compute_hash(name); - if (auto iter = _type_set.find(TypeDescAndHash{name, h}); - iter != _type_set.end()) { return *iter; } - - auto t = _type_pool.create(); - t->hash = h; - t->tag = Type::Tag::COROFRAME; - t->size = Type::custom_struct_size; - t->alignment = Type::custom_struct_alignment; - t->dimension = 1u; - t->description = name; - return _register(t); -} - size_t TypeRegistry::type_count() const noexcept { std::lock_guard lock{_mutex}; return _types.size(); diff --git a/src/backends/metal/metal_codegen_ast.cpp b/src/backends/metal/metal_codegen_ast.cpp index f3277d1e7..f3336369e 100644 --- a/src/backends/metal/metal_codegen_ast.cpp +++ b/src/backends/metal/metal_codegen_ast.cpp @@ -345,9 +345,6 @@ void MetalCodegenAST::_emit_type_name(const Type *type, Usage usage) noexcept { } break; } - case Type::Tag::COROFRAME: - _emit_type_name(type->corotype()); - break; } } diff --git a/src/coro/coro_graph.cpp b/src/coro/coro_graph.cpp index 5ab11eed6..72980541e 100644 --- a/src/coro/coro_graph.cpp +++ b/src/coro/coro_graph.cpp @@ -131,12 +131,7 @@ namespace detail { static void perform_coroutine_transform(ir::CallableModule *m) noexcept { auto coroutine_pipeline = ir::luisa_compute_ir_transform_pipeline_new(); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "canonicalize_control_flow"); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "demote_locals"); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "defer_load"); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "extract_loop_cond"); - // ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "split_coro"); - ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "materialize_coro_v2"); + ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "materialize_coro"); auto converted_module = ir::luisa_compute_ir_transform_pipeline_transform_callable(coroutine_pipeline, *m); ir::luisa_compute_ir_transform_pipeline_destroy(coroutine_pipeline); *m = converted_module; diff --git a/src/rust/luisa_compute_ir/src/ast2ir.rs b/src/rust/luisa_compute_ir/src/ast2ir.rs index 937fa26ce..2569661e9 100644 --- a/src/rust/luisa_compute_ir/src/ast2ir.rs +++ b/src/rust/luisa_compute_ir/src/ast2ir.rs @@ -102,7 +102,7 @@ impl<'a> AST2IRType<'a> { } "BINDLESS_ARRAY" => Type::void(), "ACCEL" => Type::void(), - "COROFRAME" | "CUSTOM" => Type::opaque(j["id"].as_str().unwrap().into()), + "CUSTOM" => Type::opaque(j["id"].as_str().unwrap().into()), _ => panic!("Invalid type tag: {}", tag), }; self.types.insert(i, t.clone()); diff --git a/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs b/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs index 743ece66a..9c9d41eb8 100644 --- a/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs +++ b/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs @@ -73,35 +73,33 @@ impl<'a> CoroScopeMaterializer<'a> { fn create_args(&self) -> Vec { let mut args = Vec::new(); - for (i, arg) in self.coro.args.iter().enumerate() { - if i == 0 { - // the coro frame - let node = new_node( - &self.coro.pools, - Node::new( - CArc::new(Instruction::Argument { by_value: false }), - self.frame.interface_type.clone(), - ), - ); - args.push(node); - } else { - // normal args - let instr = &arg.get().instruction; - match instr.as_ref() { - Instruction::Buffer - | Instruction::Bindless - | Instruction::Texture2D - | Instruction::Texture3D - | Instruction::Accel - | Instruction::Argument { .. } => { - let node = new_node( - &self.coro.pools, - Node::new(instr.clone(), arg.type_().clone()), - ); - args.push(node); - } - _ => unreachable!("Invalid argument type"), + args.reserve(1 /* frame */+ self.coro.args.len()); + // the coro frame + let node = new_node( + &self.coro.pools, + Node::new( + CArc::new(Instruction::Argument { by_value: false }), + self.frame.interface_type.clone(), + ), + ); + args.push(node); + for arg in self.coro.args.iter() { + // normal args + let instr = &arg.get().instruction; + match instr.as_ref() { + Instruction::Buffer + | Instruction::Bindless + | Instruction::Texture2D + | Instruction::Texture3D + | Instruction::Accel + | Instruction::Argument { .. } => { + let node = new_node( + &self.coro.pools, + Node::new(instr.clone(), arg.type_().clone()), + ); + args.push(node); } + _ => unreachable!("Invalid argument type"), } } args @@ -1003,7 +1001,7 @@ impl<'a> CoroScopeMaterializer<'a> { .args .iter() .cloned() - .zip(self.args.iter().cloned()) + .zip(self.args.iter().skip(1).cloned()) .collect(); let mut entry_builder = IrBuilder::new(self.coro.pools.clone()); let mut ctx = CoroScopeMaterializerCtx { diff --git a/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs b/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs deleted file mode 100644 index 9c9d41eb8..000000000 --- a/src/rust/luisa_compute_ir/src/transform/materialize_coro_v2.rs +++ /dev/null @@ -1,1097 +0,0 @@ -// This file implements the materialization of subroutines in a coroutine. It analyzes the -// input coroutine module, generates the coroutine graph and transition graph, computes the -// coroutine frame layout, and finally materializes the subroutines into callable modules. -// Some corner cases to consider: -// - Some values might be promoted to values in the coroutine frame and should be loaded -// before use. -// - Some values might not dominate their uses any more as they might have been moved into -// a `SkipIfFirst` block. We need to promote them to locals. -// - Some "replayable" values are not included in the coroutine frame, nor defined in the -// subroutine body. We need to replay them. - -use crate::analysis::coro_frame::CoroFrame; -use crate::analysis::coro_graph::{ - CoroGraph, CoroInstrRef, CoroInstruction, CoroScope, CoroScopeRef, -}; -use crate::analysis::coro_transition_graph::CoroTransitionGraph; -use crate::analysis::coro_use_def::CoroUseDefAnalysis; -use crate::analysis::replayable_values::ReplayableValueAnalysis; -use crate::analysis::utility::{AccessChainIndex, AccessTree}; -use crate::ir::{ - collect_nodes, new_node, BasicBlock, CallableModule, CallableModuleRef, Const, - CoroFrameDesignatedField, CurveBasisSet, Func, Instruction, IrBuilder, Module, ModuleFlags, - ModuleKind, Node, NodeRef, Primitive, SwitchCase, Type, -}; -use crate::transform::canonicalize_control_flow::CanonicalizeControlFlow; -use crate::transform::defer_load::DeferLoad; -use crate::transform::demote_locals::DemoteLocals; -use crate::transform::mem2reg::Mem2Reg; -use crate::transform::reg2mem::Reg2Mem; -use crate::transform::Transform; -use crate::{CArc, CBox, CBoxedSlice, Pooled}; -use bitflags::Flags; -use std::collections::{HashMap, HashSet}; - -struct DuplicateNodeCollector<'a> { - frame: &'a CoroFrame<'a>, - scope: CoroScopeRef, -} - -// collect the nodes that violate the SSA form, which should be promoted to locals -impl<'a> DuplicateNodeCollector<'a> { - fn new(frame: &'a CoroFrame<'a>, scope: CoroScopeRef) -> Self { - Self { frame, scope } - } - - fn collect(&self) -> HashSet { - todo!() - } -} - -pub(crate) struct MaterializeCoro; - -struct CoroScopeMaterializer<'a> { - frame: &'a CoroFrame<'a>, - coro: &'a CallableModule, - token: Option, // None for entry - scope: CoroScopeRef, - args: Vec, -} - -impl<'a> CoroScopeMaterializer<'a> { - fn get_frame_node(&self) -> NodeRef { - self.args[0].clone() - } - - fn get_scope(&self) -> &CoroScope { - &self.frame.graph.get_scope(self.scope) - } - - fn get_instr(&self, instr: CoroInstrRef) -> &CoroInstruction { - self.frame.graph.get_instr(instr) - } - - fn create_args(&self) -> Vec { - let mut args = Vec::new(); - args.reserve(1 /* frame */+ self.coro.args.len()); - // the coro frame - let node = new_node( - &self.coro.pools, - Node::new( - CArc::new(Instruction::Argument { by_value: false }), - self.frame.interface_type.clone(), - ), - ); - args.push(node); - for arg in self.coro.args.iter() { - // normal args - let instr = &arg.get().instruction; - match instr.as_ref() { - Instruction::Buffer - | Instruction::Bindless - | Instruction::Texture2D - | Instruction::Texture3D - | Instruction::Accel - | Instruction::Argument { .. } => { - let node = new_node( - &self.coro.pools, - Node::new(instr.clone(), arg.type_().clone()), - ); - args.push(node); - } - _ => unreachable!("Invalid argument type"), - } - } - args - } - - fn new(frame: &'a CoroFrame<'a>, coro: &'a CallableModule, token: Option) -> Self { - let scope = if let Some(token) = token { - frame.graph.tokens[&token] - } else { - frame.graph.entry - }; - let mut m = Self { - frame, - coro, - token, - scope, - args: Vec::new(), - }; - m.args = m.create_args(); - m - } -} - -struct CoroScopeMaterializerCtx { - mappings: HashMap, // mapping from old nodes to new nodes - entry_builder: IrBuilder, // suitable for declaring locals - first_flag: Option, - uses_ray_tracing: bool, - uses_coro_id: bool, - replayable: ReplayableValueAnalysis, -} - -struct CoroScopeMaterializerState { - builder: IrBuilder, // current running build -} - -impl CoroScopeMaterializerState { - fn clone_for_branch_block(&self) -> Self { - Self { - builder: IrBuilder::new(self.builder.pools.clone()), - } - } -} - -impl<'a> CoroScopeMaterializer<'a> { - fn resume(&self, ctx: &mut CoroScopeMaterializerCtx) { - let mappings = self - .frame - .resume(self.scope, self.get_frame_node(), &mut ctx.entry_builder); - for (old_node, new_node) in mappings { - ctx.mappings.insert(old_node, new_node); - } - } - - fn suspend(&self, target: u32, ctx: &mut CoroScopeMaterializerCtx, b: &mut IrBuilder) { - self.frame.suspend( - self.scope, - target, - self.get_frame_node(), - b, - &mut ctx.mappings, - ); - } - - fn terminate(&self, b: &mut IrBuilder) { - self.frame.terminate(self.scope, self.get_frame_node(), b); - } - - fn ref_or_local( - &self, - old_node: NodeRef, - ctx: &mut CoroScopeMaterializerCtx, - state: &mut CoroScopeMaterializerState, - ) -> NodeRef { - if let Some(defined) = ctx.mappings.get(&old_node) { - defined.clone() - } else if old_node.is_gep() { - let (root, chain) = AccessTree::access_chain_from_gep_chain(old_node); - let chain: Vec<_> = chain - .iter() - .map(|node| self.value_or_load(node.clone(), ctx, state)) - .collect(); - let root = self.ref_or_local(root, ctx, state); - let gep = state - .builder - .gep(root, chain.as_slice(), old_node.type_().clone()); - ctx.mappings.insert(old_node, gep.clone()); - gep - } else { - // not defined yet, we'll define it now - let local = ctx.entry_builder.local_zero_init(old_node.type_().clone()); - ctx.mappings.insert(old_node.clone(), local.clone()); - local - } - } - - fn replay_value(&self, old_node: NodeRef, ctx: &mut CoroScopeMaterializerCtx) -> NodeRef { - if let Some(replayed) = ctx.mappings.get(&old_node) { - return replayed.clone(); - } - match old_node.get().instruction.as_ref() { - Instruction::Const(c) => ctx.entry_builder.const_(c.clone()), - Instruction::Call(func, args) => match func { - Func::Unreachable(_) | Func::ZeroInitializer | Func::WarpSize => ctx - .entry_builder - .call(func.clone(), &[], old_node.type_().clone()), - Func::ThreadId | Func::BlockId | Func::WarpLaneId | Func::DispatchSize => { - panic!("{:?} is not available in coroutines", func) - } - Func::CoroId | Func::DispatchId => { - ctx.uses_coro_id = true; - self.frame - .read_coro_id(self.get_frame_node(), &mut ctx.entry_builder) - } - Func::CoroToken => ctx - .entry_builder - .const_(Const::Uint32(self.token.unwrap_or(0))), - Func::Cast - | Func::Bitcast - | Func::Pack - | Func::Unpack - | Func::Add - | Func::Sub - | Func::Mul - | Func::Div - | Func::Rem - | Func::BitAnd - | Func::BitOr - | Func::BitXor - | Func::Shl - | Func::Shr - | Func::RotRight - | Func::RotLeft - | Func::Eq - | Func::Ne - | Func::Lt - | Func::Le - | Func::Gt - | Func::Ge - | Func::MatCompMul - | Func::Neg - | Func::Not - | Func::BitNot - | Func::All - | Func::Any - | Func::Select - | Func::Clamp - | Func::Lerp - | Func::Step - | Func::SmoothStep - | Func::Saturate - | Func::Abs - | Func::Min - | Func::Max - | Func::ReduceSum - | Func::ReduceProd - | Func::ReduceMin - | Func::ReduceMax - | Func::Clz - | Func::Ctz - | Func::PopCount - | Func::Reverse - | Func::IsInf - | Func::IsNan - | Func::Acos - | Func::Acosh - | Func::Asin - | Func::Asinh - | Func::Atan - | Func::Atan2 - | Func::Atanh - | Func::Cos - | Func::Cosh - | Func::Sin - | Func::Sinh - | Func::Tan - | Func::Tanh - | Func::Exp - | Func::Exp2 - | Func::Exp10 - | Func::Log - | Func::Log2 - | Func::Log10 - | Func::Powi - | Func::Powf - | Func::Sqrt - | Func::Rsqrt - | Func::Ceil - | Func::Floor - | Func::Fract - | Func::Trunc - | Func::Round - | Func::Fma - | Func::Copysign - | Func::Cross - | Func::Dot - | Func::OuterProduct - | Func::Length - | Func::LengthSquared - | Func::Normalize - | Func::Faceforward - | Func::Distance - | Func::Reflect - | Func::Determinant - | Func::Transpose - | Func::Inverse - | Func::Vec - | Func::Vec2 - | Func::Vec3 - | Func::Vec4 - | Func::Permute - | Func::InsertElement - | Func::ExtractElement - | Func::GetElementPtr - | Func::Struct - | Func::Array - | Func::Mat - | Func::Mat2 - | Func::Mat3 - | Func::Mat4 => { - let replayed_args: Vec<_> = args - .iter() - .map(|arg| self.replay_value(arg.clone(), ctx)) - .collect(); - ctx.entry_builder.call( - func.clone(), - replayed_args.as_slice(), - old_node.type_().clone(), - ) - } - _ => unreachable!("non-replayable value"), - }, - _ => unreachable!("non-replayable value"), - } - } - - fn try_replay(&self, old_node: NodeRef, ctx: &mut CoroScopeMaterializerCtx) -> Option { - if old_node.is_unreachable() { - ctx.mappings.get(&old_node).cloned() - } else { - if !ctx.replayable.detect(old_node) { - None - } else if let Some(defined) = ctx.mappings.get(&old_node) { - // if already defined, simply return it - Some(defined.clone()) - } else { - let replayed = self.replay_value(old_node, ctx); - ctx.mappings.insert(old_node, replayed.clone()); - Some(replayed) - } - } - } - - fn value_or_load( - &self, - old_node: NodeRef, - ctx: &mut CoroScopeMaterializerCtx, - state: &mut CoroScopeMaterializerState, - ) -> NodeRef { - if let Some(node) = self.try_replay(old_node, ctx) { - node - } else { - let node = self.ref_or_local(old_node, ctx, state); - if node.is_local() || node.is_gep() || node.is_reference_argument() { - state.builder.load(node) - } else { - node - } - } - } - - fn def_or_assign( - &self, - old_node: NodeRef, - new_value: NodeRef, - ctx: &mut CoroScopeMaterializerCtx, - state: &mut CoroScopeMaterializerState, - ) { - let var = self.ref_or_local(old_node, ctx, state); - assert!(var.is_gep() || var.is_local() || var.is_reference_argument()); - state.builder.update(var, new_value); - } - - fn materialize_branch_block( - &self, - block: &Vec, - ctx: &mut CoroScopeMaterializerCtx, - state: &CoroScopeMaterializerState, - ) -> Pooled { - let mut branch_state = state.clone_for_branch_block(); - self.materialize_instructions(block.as_slice(), ctx, &mut branch_state); - branch_state.builder.finish() - } - - fn make_first_flag(&self, ctx: &mut CoroScopeMaterializerCtx) { - assert_eq!(ctx.first_flag, None, "First flag already defined"); - let flag = { - let b = &mut ctx.entry_builder; - b.comment(CBoxedSlice::from("make first flag".as_bytes())); - let v = b.const_(Const::Bool(false)); - b.local(v) - }; - ctx.first_flag = Some(flag); - } - - fn materialize_call( - &self, - ret: NodeRef, - func: Func, - args: &[NodeRef], - ctx: &mut CoroScopeMaterializerCtx, - state: &mut CoroScopeMaterializerState, - ) { - macro_rules! process_return { - ($call: expr) => { - match $call.type_().as_ref() { - Type::Void => { /* nothing */ } - Type::UserData => todo!(), - Type::Opaque(_) => { - // as non-copyable reference - ctx.mappings.insert(ret.clone(), $call); - } - _ => self.def_or_assign(ret.clone(), $call, ctx, state), - } - }; - } - match func { - // callable - Func::Callable(c) => { - let args: Vec<_> = - c.0.args - .iter() - .zip(args.iter()) - .map(|(formal, &given)| { - if formal.is_reference_argument() || formal.type_().is_opaque("") { - self.ref_or_local(given, ctx, state) - } else { - self.value_or_load(given, ctx, state) - } - }) - .collect(); - let call = state.builder.call( - Func::Callable(c.clone()), - args.as_slice(), - ret.type_().clone(), - ); - process_return!(call) - } - Func::External(c) => { - let args: Vec<_> = args - .iter() - .map(|&a| self.value_or_load(a, ctx, state)) - .collect(); - let call = state.builder.call( - Func::External(c.clone()), - args.as_slice(), - ret.type_().clone(), - ); - process_return!(call) - } - // replayable but need special handling - Func::Unreachable(_) => { - let call = state.builder.call(func.clone(), &[], ret.type_().clone()); - match call.type_().as_ref() { - Type::Void => {} - _ => { - ctx.mappings.insert(ret, call); - } - } - } - // always replayable functions, should not appear here - Func::CoroId - | Func::CoroToken - | Func::ZeroInitializer - | Func::ThreadId - | Func::BlockId - | Func::WarpSize - | Func::WarpLaneId - | Func::DispatchId - | Func::DispatchSize => unreachable!(), - // local variable operations - Func::Load => { - let loaded = self.value_or_load(args[0].clone(), ctx, state); - self.def_or_assign(ret, loaded, ctx, state); - } - Func::AddressOf => { - // the first argument should be reference - let var = self.ref_or_local(args[0].clone(), ctx, state); - let addr = state - .builder - .call(func.clone(), &[var], ret.type_().clone()); - self.def_or_assign(ret, addr, ctx, state); - } - Func::GetElementPtr => { - let (root, chain) = AccessTree::access_chain_from_gep_chain(ret); - let root = self.ref_or_local(root, ctx, state); - let chain: Vec<_> = chain - .iter() - .map(|&i| self.value_or_load(i, ctx, state)) - .collect(); - let gep = state - .builder - .gep(root, chain.as_slice(), ret.type_().clone()); - ctx.mappings.insert(ret, gep); - } - // AD functions - Func::PropagateGrad => todo!(), - Func::OutputGrad => todo!(), - Func::RequiresGradient => todo!(), - Func::Backward => todo!(), - Func::Gradient => todo!(), - Func::GradientMarker => todo!(), - Func::AccGrad => todo!(), - Func::Detach => todo!(), - // resource functions, the first argument should always be a reference - Func::RayTracingQueryAll - | Func::RayTracingQueryAny - | Func::RayTracingInstanceTransform - | Func::RayTracingInstanceVisibilityMask - | Func::RayTracingInstanceUserId - | Func::RayTracingSetInstanceTransform - | Func::RayTracingSetInstanceOpacity - | Func::RayTracingSetInstanceVisibility - | Func::RayTracingSetInstanceUserId - | Func::RayTracingTraceClosest - | Func::RayTracingTraceAny - | Func::RayQueryWorldSpaceRay - | Func::RayQueryProceduralCandidateHit - | Func::RayQueryTriangleCandidateHit - | Func::RayQueryCommittedHit - | Func::RayQueryCommitTriangle - | Func::RayQueryCommitProcedural - | Func::RayQueryTerminate - | Func::IndirectDispatchSetCount - | Func::IndirectDispatchSetKernel - | Func::AtomicRef - | Func::AtomicExchange - | Func::AtomicCompareExchange - | Func::AtomicFetchAdd - | Func::AtomicFetchSub - | Func::AtomicFetchAnd - | Func::AtomicFetchOr - | Func::AtomicFetchXor - | Func::AtomicFetchMin - | Func::AtomicFetchMax - | Func::BufferRead - | Func::BufferWrite - | Func::BufferSize - | Func::BufferAddress - | Func::ByteBufferRead - | Func::ByteBufferWrite - | Func::ByteBufferSize - | Func::Texture2dRead - | Func::Texture2dWrite - | Func::Texture2dSize - | Func::Texture3dRead - | Func::Texture3dWrite - | Func::Texture3dSize - | Func::BindlessTexture2dSample - | Func::BindlessTexture2dSampleLevel - | Func::BindlessTexture2dSampleGrad - | Func::BindlessTexture2dSampleGradLevel - | Func::BindlessTexture3dSample - | Func::BindlessTexture3dSampleLevel - | Func::BindlessTexture3dSampleGrad - | Func::BindlessTexture3dSampleGradLevel - | Func::BindlessTexture2dRead - | Func::BindlessTexture3dRead - | Func::BindlessTexture2dReadLevel - | Func::BindlessTexture3dReadLevel - | Func::BindlessTexture2dSize - | Func::BindlessTexture3dSize - | Func::BindlessTexture2dSizeLevel - | Func::BindlessTexture3dSizeLevel - | Func::BindlessBufferRead - | Func::BindlessBufferWrite - | Func::BindlessBufferSize - | Func::BindlessBufferAddress - | Func::BindlessBufferType - | Func::BindlessByteBufferRead => { - let args: Vec<_> = args - .iter() - .enumerate() - .map(|(i, &a)| { - if i == 0 { - // resource - self.ref_or_local(a, ctx, state) - } else { - // value - self.value_or_load(a, ctx, state) - } - }) - .collect(); - let call = state - .builder - .call(func.clone(), args.as_slice(), ret.type_().clone()); - process_return!(call) - } - // functions with all value arguments - Func::Assume - | Func::Assert(_) - | Func::RasterDiscard - | Func::Cast - | Func::Bitcast - | Func::Pack - | Func::Unpack - | Func::Add - | Func::Sub - | Func::Mul - | Func::Div - | Func::Rem - | Func::BitAnd - | Func::BitOr - | Func::BitXor - | Func::Shl - | Func::Shr - | Func::RotRight - | Func::RotLeft - | Func::Eq - | Func::Ne - | Func::Lt - | Func::Le - | Func::Gt - | Func::Ge - | Func::MatCompMul - | Func::Neg - | Func::Not - | Func::BitNot - | Func::All - | Func::Any - | Func::Select - | Func::Clamp - | Func::Lerp - | Func::Step - | Func::SmoothStep - | Func::Saturate - | Func::Abs - | Func::Min - | Func::Max - | Func::ReduceSum - | Func::ReduceProd - | Func::ReduceMin - | Func::ReduceMax - | Func::Clz - | Func::Ctz - | Func::PopCount - | Func::Reverse - | Func::IsInf - | Func::IsNan - | Func::Acos - | Func::Acosh - | Func::Asin - | Func::Asinh - | Func::Atan - | Func::Atan2 - | Func::Atanh - | Func::Cos - | Func::Cosh - | Func::Sin - | Func::Sinh - | Func::Tan - | Func::Tanh - | Func::Exp - | Func::Exp2 - | Func::Exp10 - | Func::Log - | Func::Log2 - | Func::Log10 - | Func::Powi - | Func::Powf - | Func::Sqrt - | Func::Rsqrt - | Func::Ceil - | Func::Floor - | Func::Fract - | Func::Trunc - | Func::Round - | Func::Fma - | Func::Copysign - | Func::Cross - | Func::Dot - | Func::OuterProduct - | Func::Length - | Func::LengthSquared - | Func::Normalize - | Func::Faceforward - | Func::Distance - | Func::Reflect - | Func::Determinant - | Func::Transpose - | Func::Inverse - | Func::WarpIsFirstActiveLane - | Func::WarpFirstActiveLane - | Func::WarpActiveAllEqual - | Func::WarpActiveBitAnd - | Func::WarpActiveBitOr - | Func::WarpActiveBitXor - | Func::WarpActiveCountBits - | Func::WarpActiveMax - | Func::WarpActiveMin - | Func::WarpActiveProduct - | Func::WarpActiveSum - | Func::WarpActiveAll - | Func::WarpActiveAny - | Func::WarpActiveBitMask - | Func::WarpPrefixCountBits - | Func::WarpPrefixSum - | Func::WarpPrefixProduct - | Func::WarpReadLaneAt - | Func::WarpReadFirstLane - | Func::SynchronizeBlock - | Func::Vec - | Func::Vec2 - | Func::Vec3 - | Func::Vec4 - | Func::Permute - | Func::InsertElement - | Func::ExtractElement - | Func::Struct - | Func::Array - | Func::Mat - | Func::Mat2 - | Func::Mat3 - | Func::Mat4 - | Func::ShaderExecutionReorder - | Func::CpuCustomOp(_) => { - let args: Vec<_> = args - .iter() - .map(|&a| self.value_or_load(a, ctx, state)) - .collect(); - let call = state - .builder - .call(func.clone(), args.as_slice(), ret.type_().clone()); - process_return!(call) - } - // other, unused - Func::Unknown0 => todo!(), - Func::Unknown1 => todo!(), - } - } - - fn materialize_simple( - &self, - node: NodeRef, - ctx: &mut CoroScopeMaterializerCtx, - state: &mut CoroScopeMaterializerState, - ) { - if ctx.replayable.detect(node) && !node.is_unreachable() { - self.replay_value(node.clone(), ctx); - return; - } - match node.get().instruction.as_ref() { - Instruction::Local { init } => { - let init = self.value_or_load(init.clone(), ctx, state); - let this = self.ref_or_local(node, ctx, state); - state.builder.update(this, init); - } - Instruction::Update { var, value } => { - let value = self.value_or_load(value.clone(), ctx, state); - self.def_or_assign(var.clone(), value, ctx, state); - } - Instruction::Call(func, args) => { - self.materialize_call(node, func.clone(), args.iter().as_slice(), ctx, state); - } - Instruction::Loop { body, cond } => { - let mut body_state = state.clone_for_branch_block(); - for node in body.iter() { - self.materialize_simple(node, ctx, &mut body_state); - } - let cond = self.value_or_load(cond.clone(), ctx, &mut body_state); - let body = body_state.builder.finish(); - state.builder.loop_(body, cond); - } - Instruction::If { - cond, - true_branch, - false_branch, - } => { - let cond = self.value_or_load(cond.clone(), ctx, state); - let true_branch = self.materialize_branch_in_simple(true_branch, ctx, state); - let false_branch = self.materialize_branch_in_simple(false_branch, ctx, state); - state.builder.if_(cond, true_branch, false_branch); - } - Instruction::Switch { - value, - cases, - default, - } => { - let value = self.value_or_load(value.clone(), ctx, state); - let cases: Vec<_> = cases - .iter() - .map(|case| SwitchCase { - value: case.value, - block: self.materialize_branch_in_simple(&case.block, ctx, state), - }) - .collect(); - let default = self.materialize_branch_in_simple(default, ctx, state); - state.builder.switch(value, cases.as_slice(), default); - } - Instruction::AdScope { - body, - n_forward_grads, - forward, - } => { - let body = self.materialize_branch_in_simple(body, ctx, state); - if *forward { - state.builder.fwd_ad_scope(body, *n_forward_grads); - } else { - state.builder.ad_scope(body); - } - } - Instruction::RayQuery { - ray_query, - on_triangle_hit, - on_procedural_hit, - } => { - let ray_query = self.ref_or_local(ray_query.clone(), ctx, state); - let on_triangle_hit = - self.materialize_branch_in_simple(on_triangle_hit, ctx, state); - let on_procedural_hit = - self.materialize_branch_in_simple(on_procedural_hit, ctx, state); - ctx.uses_ray_tracing = true; - state.builder.ray_query( - ray_query, - on_triangle_hit, - on_procedural_hit, - node.type_().clone(), - ); - } - Instruction::Print { fmt, args } => { - let args: Vec<_> = args - .iter() - .map(|arg| self.value_or_load(arg.clone(), ctx, state)) - .collect(); - state.builder.print(fmt.clone(), args.as_slice()); - } - Instruction::AdDetach(body) => { - let body = self.materialize_branch_in_simple(body.as_ref(), ctx, state); - state.builder.ad_detach(body); - } - Instruction::Comment(msg) => { - state.builder.comment(msg.clone()); - } - Instruction::CoroRegister { .. } => { - // nothing to do - } - _ => unreachable!(), - } - } - - fn materialize_branch_in_simple( - &self, - block: &BasicBlock, - ctx: &mut CoroScopeMaterializerCtx, - state: &CoroScopeMaterializerState, - ) -> Pooled { - let mut branch_state = state.clone_for_branch_block(); - for node in block.iter() { - self.materialize_simple(node, ctx, &mut branch_state); - } - branch_state.builder.finish() - } - - fn materialize_instr( - &self, - instr: &CoroInstruction, - ctx: &mut CoroScopeMaterializerCtx, - state: &mut CoroScopeMaterializerState, - ) { - match instr { - CoroInstruction::Simple(node) => { - self.materialize_simple(node.clone(), ctx, state); - } - CoroInstruction::ConditionStackReplay { items } => { - state.builder.comment(CBoxedSlice::from( - "condition stack replay begin".to_string(), - )); - for item in items.iter() { - macro_rules! decode_value { - ($t:tt, $value: expr) => { - state.builder.const_(Const::$t($value)) - }; - } - let value = match item.node.type_().as_ref() { - Type::Primitive(p) => match p { - Primitive::Bool => decode_value!(Bool, item.value != 0), - Primitive::Int8 => decode_value!(Int8, item.value as i8), - Primitive::Uint8 => decode_value!(Uint8, item.value as u8), - Primitive::Int16 => decode_value!(Int16, item.value as i16), - Primitive::Uint16 => decode_value!(Uint16, item.value as u16), - Primitive::Int32 => decode_value!(Int32, item.value), - Primitive::Uint32 => decode_value!(Uint32, item.value as u32), - Primitive::Int64 => decode_value!(Int64, item.value as i64), - Primitive::Uint64 => decode_value!(Uint64, item.value as u64), - _ => unreachable!(), - }, - _ => unreachable!(), - }; - self.def_or_assign(item.node.clone(), value, ctx, state); - } - state - .builder - .comment(CBoxedSlice::from("condition stack replay end".to_string())); - } - CoroInstruction::MakeFirstFlag => { - self.make_first_flag(ctx); - } - CoroInstruction::SkipIfFirstFlag { body, .. } => { - state - .builder - .comment(CBoxedSlice::from("skip if first flag".to_string())); - let flag = state.builder.load(ctx.first_flag.unwrap().clone()); - let true_branch = self.materialize_branch_block(body, ctx, state); - let false_branch = IrBuilder::new(state.builder.pools.clone()).finish(); - state.builder.if_(flag, true_branch, false_branch); - state - .builder - .comment(CBoxedSlice::from("after skip if first flag".to_string())); - } - CoroInstruction::ClearFirstFlag(_) => { - state - .builder - .comment(CBoxedSlice::from("clear first flag".to_string())); - let v = state.builder.const_(Const::Bool(true)); - state.builder.update(ctx.first_flag.unwrap(), v); - } - CoroInstruction::Loop { body, cond } => { - // note: cond is inside the scope of body, so we have to convert it before pop - let mut body_state = state.clone_for_branch_block(); - self.materialize_instructions(body.as_slice(), ctx, &mut body_state); - let cond = if let CoroInstruction::Simple(cond) = self.get_instr(*cond) { - self.value_or_load(*cond, ctx, &mut body_state) - } else { - unreachable!() - }; - // now we can pop the body and build the instruction - let body = body_state.builder.finish(); - state.builder.loop_(body, cond); - } - CoroInstruction::If { - cond, - true_branch, - false_branch, - } => { - let cond = if let CoroInstruction::Simple(cond) = self.get_instr(*cond) { - self.value_or_load(cond.clone(), ctx, state) - } else { - unreachable!() - }; - let true_branch = self.materialize_branch_block(true_branch, ctx, state); - let false_branch = self.materialize_branch_block(false_branch, ctx, state); - state.builder.if_(cond, true_branch, false_branch); - } - CoroInstruction::Switch { - cond, - cases, - default, - } => { - let cond = if let CoroInstruction::Simple(cond) = self.get_instr(*cond) { - self.value_or_load(cond.clone(), ctx, state) - } else { - unreachable!() - }; - let cases: Vec<_> = cases - .iter() - .map(|case| SwitchCase { - value: case.value, - block: self.materialize_branch_block(&case.body, ctx, state), - }) - .collect(); - let default = self.materialize_branch_block(default, ctx, state); - state.builder.switch(cond, cases.as_slice(), default); - } - CoroInstruction::Suspend { token } => self.suspend(*token, ctx, &mut state.builder), - CoroInstruction::Terminate => self.terminate(&mut state.builder), - _ => unreachable!(), - } - } - - fn materialize_instructions( - &self, - instructions: &[CoroInstrRef], - ctx: &mut CoroScopeMaterializerCtx, - state: &mut CoroScopeMaterializerState, - ) { - for &instr in instructions { - self.materialize_instr(self.get_instr(instr), ctx, state); - } - } - - fn collect_target_tokens(&self, scope: CoroScopeRef) -> Vec { - let node = self.frame.transition_graph.nodes.get(&scope).unwrap(); - node.outlets.keys().cloned().collect::<_>() - } - - fn materialize(&self) -> CallableModule { - let mappings: HashMap<_, _> = self - .coro - .args - .iter() - .cloned() - .zip(self.args.iter().skip(1).cloned()) - .collect(); - let mut entry_builder = IrBuilder::new(self.coro.pools.clone()); - let mut ctx = CoroScopeMaterializerCtx { - mappings, - entry_builder, - first_flag: None, - uses_ray_tracing: false, - uses_coro_id: false, - replayable: ReplayableValueAnalysis::new(false), - }; - // resume states and generate first flag if not entry - if let Some(_) = self.token { - self.resume(&mut ctx); - } - // materialize the body - let mut b = IrBuilder::new_without_bb(self.coro.pools.clone()); - b.set_insert_point(ctx.entry_builder.get_insert_point()); - b.comment(CBoxedSlice::from(format!( - "coro body (token = {})", - self.token.unwrap_or(0) - ))); - let mut state = CoroScopeMaterializerState { builder: b }; - self.materialize_instructions(&self.get_scope().instructions, &mut ctx, &mut state); - let module = Module { - kind: ModuleKind::Function, - entry: ctx.entry_builder.finish(), - flags: ModuleFlags::empty(), - curve_basis_set: if ctx.uses_ray_tracing { - self.coro.module.curve_basis_set - } else { - CurveBasisSet::empty() - }, - pools: self.coro.pools.clone(), - }; - // compute the input/output coro frame fields so that the frontend scheduler can optimize the I/O - let (in_fields, out_fields) = self.frame.collect_io_fields(self.scope, ctx.uses_coro_id); - let designated_filed_offset = self.frame.get_designated_field_offset(); - let designated_fields: Vec<_> = self - .frame - .designated_field_names - .iter() - .enumerate() - .map(|(i, name)| CoroFrameDesignatedField { - name: CBoxedSlice::from(name.as_bytes()), - index: i as u32 + designated_filed_offset, - }) - .collect(); - let target_tokens = self.collect_target_tokens(self.scope); - // create the callable module - CallableModule { - module, - ret_type: Type::void(), - args: CBoxedSlice::new(self.args.clone()), - captures: CBoxedSlice::new(Vec::new()), - subroutines: CBoxedSlice::new(Vec::new()), - subroutine_ids: CBoxedSlice::new(Vec::new()), - coro_target_tokens: CBoxedSlice::new(target_tokens), - coro_frame_input_fields: CBoxedSlice::new(in_fields), - coro_frame_output_fields: CBoxedSlice::new(out_fields), - coro_frame_designated_fields: CBoxedSlice::new(designated_fields), - cpu_custom_ops: CBoxedSlice::new(Vec::new()), - pools: self.coro.pools.clone(), - } - } -} - -impl Transform for MaterializeCoro { - fn transform_callable(&self, callable: CallableModule) -> CallableModule { - let callable = CanonicalizeControlFlow.transform_callable(callable); - // let callable = Mem2Reg.transform_callable(callable); - let callable = DemoteLocals.transform_callable(callable); - let callable = DeferLoad.transform_callable(callable); - let coro_graph = CoroGraph::from(&callable.module); - let coro_use_def = CoroUseDefAnalysis::analyze(&coro_graph); - let coro_transition_graph = CoroTransitionGraph::build(&coro_graph, &coro_use_def); - let coro_frame = CoroFrame::build(&coro_graph, &coro_transition_graph); - coro_frame.dump(); - let mut entry = CoroScopeMaterializer::new(&coro_frame, &callable, None).materialize(); - let subroutines: Vec<_> = coro_graph - .tokens - .keys() - .map(|token| { - let r = - CoroScopeMaterializer::new(&coro_frame, &callable, Some(*token)).materialize(); - CallableModuleRef(CArc::new(r)) - }) - .collect(); - let subroutine_token: Vec<_> = coro_graph.tokens.keys().copied().collect(); - entry.subroutines = CBoxedSlice::new(subroutines); - entry.subroutine_ids = CBoxedSlice::new(subroutine_token); - entry - } -} diff --git a/src/rust/luisa_compute_ir/src/transform/mod.rs b/src/rust/luisa_compute_ir/src/transform/mod.rs index c29ec31e2..e6ed7f19a 100644 --- a/src/rust/luisa_compute_ir/src/transform/mod.rs +++ b/src/rust/luisa_compute_ir/src/transform/mod.rs @@ -20,7 +20,6 @@ pub mod copy_propagation; pub mod defer_load; pub mod inliner; pub mod materialize_coro; -pub mod materialize_coro_v2; pub mod remove_phi; use crate::ir::{self, CallableModule, KernelModule, Module, ModuleFlags}; @@ -137,10 +136,6 @@ pub extern "C" fn luisa_compute_ir_transform_pipeline_add_transform( let transform = materialize_coro::MaterializeCoro; unsafe { (*pipeline).add_transform(Box::new(transform)) }; } - "materialize_coro_v2" => { - let transform = materialize_coro_v2::MaterializeCoro; - unsafe { (*pipeline).add_transform(Box::new(transform)) }; - } "mem2reg" => { let transform = mem2reg::Mem2Reg; unsafe { (*pipeline).add_transform(Box::new(transform)) }; diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 3c1a6ea98..2cb0a8c8d 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -197,13 +197,8 @@ if (LUISA_COMPUTE_TEST_QT_INTEROP) endif () # coroutine tests -#luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) -luisa_compute_add_executable(test_coro_sdf_renderer_v2 coro/sdf_renderer_v2.cpp) -#luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_renderer_wo_dispatcher.cpp) -#luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) -luisa_compute_add_executable(test_coro_path_tracing_v2 coro/path_tracing_v2.cpp) -luisa_compute_add_executable(test_coro_path_tracing_wavefront_v2 coro/path_tracing_wavefront_v2.cpp) -luisa_compute_add_executable(test_coro_path_tracing_persistent_threads_v2 coro/path_tracing_persistent_threads_v2.cpp) -#luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) -luisa_compute_add_executable(test_coro_helloworld_v2 coro/helloworld_v2.cpp) -#luisa_compute_add_executable(test_coro_playground coro/playground.cpp) +luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) +luisa_compute_add_executable(test_coro_path_tracing_state_machine coro/path_tracing_state_machine.cpp) +luisa_compute_add_executable(test_coro_path_tracing_wavefront coro/path_tracing_wavefront.cpp) +luisa_compute_add_executable(test_coro_path_tracing_persistent_threads coro/path_tracing_persistent_threads.cpp) +luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) diff --git a/src/tests/coro/helloworld.cpp b/src/tests/coro/helloworld.cpp index e99349e6d..ae3c70e4d 100644 --- a/src/tests/coro/helloworld.cpp +++ b/src/tests/coro/helloworld.cpp @@ -1,19 +1,14 @@ -#include -#include -#include -#include -#include -#include -#include +#include using namespace luisa; using namespace luisa::compute; -struct alignas(4) CoroFrame { -}; -LUISA_COROFRAME_STRUCT(CoroFrame){}; +namespace luisa::compute::coroutine { + +}// namespace luisa::compute::coroutine int main(int argc, char *argv[]) { + Context context{argv[0]}; if (argc <= 1) { exit(1); } Device device = context.create_device(argv[1]); @@ -22,31 +17,37 @@ int main(int argc, char *argv[]) { Image image{device.create_image(PixelStorage::BYTE4, resolution)}; luisa::vector host_image(image.view().size_bytes()); - Coroutine coro = [&](Var &frame) noexcept { - $suspend("1"); - Var coord = coro_id().xy(); - $suspend("2"); - Var uv = (make_float2(coord) + 0.5f) / make_float2(resolution); - $suspend("3"); - image->write(coord, make_float4(uv, 0.5f, 1.0f)); + Kernel1D test = [] { + coroutine::Generator range = [](UInt n) { + auto x = def(0u); + $while (x < n) { + $yield(x); + x += 1u; + }; + }; + for (auto x : range(100u)) { + device_log("x = {}", x); + } + }; + + coroutine::Coroutine nested2 = [](UInt n) { + $for (i, n) { + device_log("nested2: {} / {}", i, n); + $suspend(); + }; }; - auto type = Type::of(); - auto frame_buffer = device.create_buffer(resolution.x * resolution.y); - - Kernel2D mega_kernel = [&] { - Var frame; - initialize_coroframe(frame, dispatch_id()); - coro(frame); - coro[1](frame); - coro[2](frame); - coro[3](frame); + + coroutine::Coroutine nested1 = [&](UInt n) { + $for (i, n) { + $await nested2(i); + device_log("nested1: {} / {}", i, n); + }; }; - auto shader = device.compile(mega_kernel); - stream << shader().dispatch(resolution) - << synchronize(); + coroutine::Coroutine top_level = [&]() { + $await nested1(10u); + }; - stream << image.copy_to(host_image.data()) - << synchronize(); - stbi_write_png("test_helloworld.png", resolution.x, resolution.y, 4, host_image.data(), 0); + coroutine::StateMachineCoroScheduler sched{device, top_level}; + stream << sched().dispatch(1u) << synchronize(); } diff --git a/src/tests/coro/helloworld_v2.cpp b/src/tests/coro/helloworld_v2.cpp deleted file mode 100644 index ae3c70e4d..000000000 --- a/src/tests/coro/helloworld_v2.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include - -using namespace luisa; -using namespace luisa::compute; - -namespace luisa::compute::coroutine { - -}// namespace luisa::compute::coroutine - -int main(int argc, char *argv[]) { - - Context context{argv[0]}; - if (argc <= 1) { exit(1); } - Device device = context.create_device(argv[1]); - Stream stream = device.create_stream(); - constexpr uint2 resolution = make_uint2(1024, 1024); - Image image{device.create_image(PixelStorage::BYTE4, resolution)}; - luisa::vector host_image(image.view().size_bytes()); - - Kernel1D test = [] { - coroutine::Generator range = [](UInt n) { - auto x = def(0u); - $while (x < n) { - $yield(x); - x += 1u; - }; - }; - for (auto x : range(100u)) { - device_log("x = {}", x); - } - }; - - coroutine::Coroutine nested2 = [](UInt n) { - $for (i, n) { - device_log("nested2: {} / {}", i, n); - $suspend(); - }; - }; - - coroutine::Coroutine nested1 = [&](UInt n) { - $for (i, n) { - $await nested2(i); - device_log("nested1: {} / {}", i, n); - }; - }; - - coroutine::Coroutine top_level = [&]() { - $await nested1(10u); - }; - - coroutine::StateMachineCoroScheduler sched{device, top_level}; - stream << sched().dispatch(1u) << synchronize(); -} diff --git a/src/tests/coro/path_tracing.cpp b/src/tests/coro/path_tracing.cpp deleted file mode 100644 index 472498f4f..000000000 --- a/src/tests/coro/path_tracing.cpp +++ /dev/null @@ -1,380 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../common/cornell_box.h" - -#define TINYOBJLOADER_IMPLEMENTATION -#include "../common/tiny_obj_loader.h" - -using namespace luisa; -using namespace luisa::compute; - -struct Onb { - float3 tangent; - float3 binormal; - float3 normal; -}; - -LUISA_STRUCT(Onb, tangent, binormal, normal) { - [[nodiscard]] Float3 to_world(Expr v) const noexcept { - return v.x * tangent + v.y * binormal + v.z * normal; - } -}; - -struct alignas(4) CoroFrame {}; - -LUISA_COROFRAME_STRUCT(CoroFrame) {}; - -int main(int argc, char *argv[]) { - - log_level_verbose(); - - Context context{argv[0]}; - if (argc <= 1) { - LUISA_INFO("Usage: {} . : cuda, dx, cpu, metal", argv[0]); - exit(1); - } - Device device = context.create_device(argv[1]); - - // load the Cornell Box scene - tinyobj::ObjReaderConfig obj_reader_config; - obj_reader_config.triangulate = true; - obj_reader_config.vertex_color = false; - tinyobj::ObjReader obj_reader; - if (!obj_reader.ParseFromString(obj_string, "", obj_reader_config)) { - luisa::string_view error_message = "unknown error."; - if (auto &&e = obj_reader.Error(); !e.empty()) { error_message = e; } - LUISA_ERROR_WITH_LOCATION("Failed to load OBJ file: {}", error_message); - } - if (auto &&e = obj_reader.Warning(); !e.empty()) { - LUISA_WARNING_WITH_LOCATION("{}", e); - } - - auto &&p = obj_reader.GetAttrib().vertices; - luisa::vector vertices; - vertices.reserve(p.size() / 3u); - for (uint i = 0u; i < p.size(); i += 3u) { - vertices.emplace_back(make_float3( - p[i + 0u], p[i + 1u], p[i + 2u])); - } - LUISA_INFO( - "Loaded mesh with {} shape(s) and {} vertices.", - obj_reader.GetShapes().size(), vertices.size()); - - BindlessArray heap = device.create_bindless_array(); - Stream stream = device.create_stream(StreamTag::GRAPHICS); - Buffer vertex_buffer = device.create_buffer(vertices.size()); - stream << vertex_buffer.copy_from(vertices.data()); - luisa::vector meshes; - luisa::vector> triangle_buffers; - for (auto &&shape : obj_reader.GetShapes()) { - uint index = static_cast(meshes.size()); - std::vector const &t = shape.mesh.indices; - uint triangle_count = t.size() / 3u; - LUISA_INFO( - "Processing shape '{}' at index {} with {} triangle(s).", - shape.name, index, triangle_count); - luisa::vector indices; - indices.reserve(t.size()); - for (tinyobj::index_t i : t) { indices.emplace_back(i.vertex_index); } - Buffer &triangle_buffer = triangle_buffers.emplace_back(device.create_buffer(triangle_count)); - Mesh &mesh = meshes.emplace_back(device.create_mesh(vertex_buffer, triangle_buffer)); - heap.emplace_on_update(index, triangle_buffer); - stream << triangle_buffer.copy_from(indices.data()) - << mesh.build(); - } - - Accel accel = device.create_accel({}); - for (Mesh &m : meshes) { - accel.emplace_back(m, make_float4x4(1.0f)); - } - stream << heap.update() - << accel.build() - << synchronize(); - - Constant materials{ - make_float3(0.725f, 0.710f, 0.680f),// floor - make_float3(0.725f, 0.710f, 0.680f),// ceiling - make_float3(0.725f, 0.710f, 0.680f),// back wall - make_float3(0.140f, 0.450f, 0.091f),// right wall - make_float3(0.630f, 0.065f, 0.050f),// left wall - make_float3(0.725f, 0.710f, 0.680f),// short box - make_float3(0.725f, 0.710f, 0.680f),// tall box - make_float3(0.000f, 0.000f, 0.000f),// light - }; - - Callable linear_to_srgb = [&](Var x) noexcept { - return saturate(select(1.055f * pow(x, 1.0f / 2.4f) - 0.055f, - 12.92f * x, - x <= 0.00031308f)); - }; - - Callable tea = [](UInt v0, UInt v1) noexcept { - UInt s0 = def(0u); - for (uint n = 0u; n < 4u; n++) { - s0 += 0x9e3779b9u; - v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); - v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); - } - return v0; - }; - - Kernel2D make_sampler_kernel = [&](ImageUInt seed_image) noexcept { - UInt2 p = dispatch_id().xy(); - UInt state = tea(p.x, p.y); - seed_image.write(p, make_uint4(state)); - }; - - Callable lcg = [](UInt &state) noexcept { - constexpr uint lcg_a = 1664525u; - constexpr uint lcg_c = 1013904223u; - state = lcg_a * state + lcg_c; - return cast(state & 0x00ffffffu) * - (1.0f / static_cast(0x01000000u)); - }; - - Callable make_onb = [](const Float3 &normal) noexcept { - Float3 binormal = normalize(ite( - abs(normal.x) > abs(normal.z), - make_float3(-normal.y, normal.x, 0.0f), - make_float3(0.0f, -normal.z, normal.y))); - Float3 tangent = normalize(cross(binormal, normal)); - return def(tangent, binormal, normal); - }; - - Callable generate_ray = [](Float2 p) noexcept { - static constexpr float fov = radians(27.8f); - static constexpr float3 origin = make_float3(-0.01f, 0.995f, 5.0f); - Float3 pixel = origin + make_float3(p * tan(0.5f * fov), -1.0f); - Float3 direction = normalize(pixel - origin); - return make_ray(origin, direction); - }; - - Callable cosine_sample_hemisphere = [](Float2 u) noexcept { - Float r = sqrt(u.x); - Float phi = 2.0f * constants::pi * u.y; - return make_float3(r * cos(phi), r * sin(phi), sqrt(1.0f - u.x)); - }; - - Callable balanced_heuristic = [](Float pdf_a, Float pdf_b) noexcept { - return pdf_a / max(pdf_a + pdf_b, 1e-4f); - }; - - auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; - - Coroutine raytracing_coro = [&](Var &, ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { - UInt2 coord = dispatch_id().xy(); - Float frame_size = min(resolution.x, resolution.y).cast(); - UInt state = seed_image.read(coord).x; - Float rx = lcg(state); - Float ry = lcg(state); - Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; - Float3 radiance = def(make_float3(0.0f)); - $suspend("per_spp"); - $for (i, spp_per_dispatch) { - Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); - Float3 beta = def(make_float3(1.0f)); - Float pdf_bsdf = def(0.0f); - constexpr float3 light_position = make_float3(-0.24f, 1.98f, 0.16f); - constexpr float3 light_u = make_float3(-0.24f, 1.98f, -0.22f) - light_position; - constexpr float3 light_v = make_float3(0.23f, 1.98f, 0.16f) - light_position; - constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); - Float light_area = length(cross(light_u, light_v)); - Float3 light_normal = normalize(cross(light_u, light_v)); - $suspend("per_depth"); - $for (depth, 10u) { - // trace - $suspend("before_tracing"); - Var hit = accel.intersect(ray, {}); - reorder_shader_execution(); - $if (hit->miss()) { $break; }; - Var triangle = heap->buffer(hit.inst).read(hit.prim); - Float3 p0 = vertex_buffer->read(triangle.i0); - Float3 p1 = vertex_buffer->read(triangle.i1); - Float3 p2 = vertex_buffer->read(triangle.i2); - Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); - Float3 n = normalize(cross(p1 - p0, p2 - p0)); - $suspend("after_tracing"); - - Float cos_wo = dot(-ray->direction(), n); - $if (cos_wo < 1e-4f) { $break; }; - - // hit light - $if (hit.inst == static_cast(meshes.size() - 1u)) { - $if (depth == 0u) { - radiance += light_emission; - } - $else { - Float pdf_light = length_squared(p - ray->origin()) / (light_area * cos_wo); - Float mis_weight = balanced_heuristic(pdf_bsdf, pdf_light); - radiance += mis_weight * beta * light_emission; - }; - $break; - }; - - // sample light - $suspend("sample_light"); - Float ux_light = lcg(state); - Float uy_light = lcg(state); - Float3 p_light = light_position + ux_light * light_u + uy_light * light_v; - Float3 pp = offset_ray_origin(p, n); - Float3 pp_light = offset_ray_origin(p_light, light_normal); - Float d_light = distance(pp, pp_light); - Float3 wi_light = normalize(pp_light - pp); - Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); - Bool occluded = accel.intersect_any(shadow_ray, {}); - Float cos_wi_light = dot(wi_light, n); - Float cos_light = -dot(light_normal, wi_light); - Float3 albedo = materials.read(hit.inst); - $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { - Float pdf_light = (d_light * d_light) / (light_area * cos_light); - Float pdf_bsdf = cos_wi_light * inv_pi; - Float mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); - Float3 bsdf = albedo * inv_pi * cos_wi_light; - radiance += beta * bsdf * mis_weight * light_emission / max(pdf_light, 1e-4f); - }; - - // sample BSDF - $suspend("sample_bsdf"); - Var onb = make_onb(n); - Float ux = lcg(state); - Float uy = lcg(state); - Float3 wi_local = cosine_sample_hemisphere(make_float2(ux, uy)); - Float cos_wi = abs(wi_local.z); - Float3 new_direction = onb->to_world(wi_local); - ray = make_ray(pp, new_direction); - pdf_bsdf = cos_wi * inv_pi; - beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f - - // rr - $suspend("rr"); - Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); - $if (l == 0.0f) { $break; }; - Float q = max(l, 0.05f); - Float r = lcg(state); - $if (r >= q) { $break; }; - beta *= 1.0f / q; - }; - }; - $suspend("write_film"); - radiance /= static_cast(spp_per_dispatch); - seed_image.write(coord, make_uint4(state)); - $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; - image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); - }; - - Kernel2D mega_kernel = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) { - Var frame; - initialize_coroframe(frame, dispatch_id()); - raytracing_coro(frame, image, seed_image, accel, resolution); - $loop { - auto token = read_promise(frame, "coro_token"); - $switch (token) { - for (auto i = 1u; i <= raytracing_coro.suspend_count(); i++) { - $case (i) { - raytracing_coro[i](frame, image, seed_image, accel, resolution); - }; - } - $default { - $return(); - }; - }; - }; - }; - - Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { - UInt2 p = dispatch_id().xy(); - Float4 accum = accum_image.read(p); - Float3 curr = curr_image.read(p).xyz(); - accum_image.write(p, accum + make_float4(curr, 1.f)); - }; - - Callable aces_tonemapping = [](Float3 x) noexcept { - static constexpr float a = 2.51f; - static constexpr float b = 0.03f; - static constexpr float c = 2.43f; - static constexpr float d = 0.59f; - static constexpr float e = 0.14f; - return clamp((x * (a * x + b)) / (x * (c * x + d) + e), 0.0f, 1.0f); - }; - - Kernel2D clear_kernel = [](ImageFloat image) noexcept { - image.write(dispatch_id().xy(), make_float4(0.0f)); - }; - - Kernel2D hdr2ldr_kernel = [&](ImageFloat hdr_image, ImageFloat ldr_image, Float scale, Bool is_hdr) noexcept { - UInt2 coord = dispatch_id().xy(); - Float4 hdr = hdr_image.read(coord); - Float3 ldr = hdr.xyz() / hdr.w * scale; - $if (!is_hdr) { - ldr = linear_to_srgb(ldr); - }; - ldr_image.write(coord, make_float4(ldr, 1.0f)); - }; - - ShaderOption o{.enable_debug_info = false}; - auto clear_shader = device.compile(clear_kernel, o); - auto hdr2ldr_shader = device.compile(hdr2ldr_kernel, o); - auto accumulate_shader = device.compile(accumulate_kernel, o); - auto raytracing_shader = device.compile(mega_kernel, o); - auto make_sampler_shader = device.compile(make_sampler_kernel, o); - - static constexpr uint2 resolution = make_uint2(1024u); - Image framebuffer = device.create_image(PixelStorage::HALF4, resolution); - Image accum_image = device.create_image(PixelStorage::FLOAT4, resolution); - luisa::vector> host_image(resolution.x * resolution.y); - CommandList cmd_list; - Image seed_image = device.create_image(PixelStorage::INT1, resolution); - cmd_list << clear_shader(accum_image).dispatch(resolution) - << make_sampler_shader(seed_image).dispatch(resolution); - - Window window{"path tracing", resolution}; - Swapchain swap_chain = device.create_swapchain( - stream, - SwapchainOption{ - .display = window.native_display(), - .window = window.native_handle(), - .size = make_uint2(resolution), - .wants_hdr = false, - .wants_vsync = false, - .back_buffer_count = 3, - }); - Image ldr_image = device.create_image(swap_chain.backend_storage(), resolution); - double last_time = 0.0; - uint frame_count = 0u; - Clock clock; - - while (!window.should_close()) { - cmd_list << raytracing_shader(framebuffer, seed_image, accel, resolution) - .dispatch(resolution) - << accumulate_shader(accum_image, framebuffer) - .dispatch(resolution); - cmd_list << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution); - stream << cmd_list.commit() - << swap_chain.present(ldr_image) << synchronize(); - window.poll_events(); - double dt = clock.toc() - last_time; - last_time = clock.toc(); - frame_count += spp_per_dispatch; - LUISA_INFO("spp: {}, time: {} ms, spp/s: {}", - frame_count, dt, spp_per_dispatch / dt * 1000); - } - stream - << ldr_image.copy_to(host_image.data()) - << synchronize(); - - LUISA_INFO("FPS: {}", frame_count / clock.toc() * 1000); - stbi_write_png("test_path_tracing.png", resolution.x, resolution.y, 4, host_image.data(), 0); -} diff --git a/src/tests/coro/path_tracing_persistent_threads_v2.cpp b/src/tests/coro/path_tracing_persistent_threads.cpp similarity index 100% rename from src/tests/coro/path_tracing_persistent_threads_v2.cpp rename to src/tests/coro/path_tracing_persistent_threads.cpp diff --git a/src/tests/coro/path_tracing_v2.cpp b/src/tests/coro/path_tracing_state_machine.cpp similarity index 100% rename from src/tests/coro/path_tracing_v2.cpp rename to src/tests/coro/path_tracing_state_machine.cpp diff --git a/src/tests/coro/path_tracing_wavefront_v2.cpp b/src/tests/coro/path_tracing_wavefront.cpp similarity index 100% rename from src/tests/coro/path_tracing_wavefront_v2.cpp rename to src/tests/coro/path_tracing_wavefront.cpp diff --git a/src/tests/coro/playground.cpp b/src/tests/coro/playground.cpp deleted file mode 100644 index 1b5dc1fc6..000000000 --- a/src/tests/coro/playground.cpp +++ /dev/null @@ -1,61 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -using namespace luisa; -using namespace luisa::compute; - -struct alignas(4) CoroFrame {}; -LUISA_COROFRAME_STRUCT(CoroFrame) {}; - -int main(int argc, char *argv[]) { - Context context{argv[0]}; - if (argc <= 1) { exit(1); } - Device device = context.create_device(argv[1]); - Stream stream = device.create_stream(); - - Coroutine coro = [&](Var &, Int &a) noexcept { - $for (i, 10) { - device_log("before {}: {}", i, a); - auto token = $suspend(std::make_pair(a, "x")); - // device_log("after {}: {}", i, a); - a += 1; - $suspend(); - }; - }; - - Kernel1D mega_kernel = [&] { - device_log("scheduler: init"); - auto frame = make_coroframe(dispatch_id()); - device_log("coro_id = {}", frame->coro_id()); - device_log("scheduler: enter entry"); - auto a = def(0); - coro(frame, a); - device_log("scheduler: exit entry"); - $loop { - auto token = frame->target_token(); - $switch (token) { - for (auto i = 1u; i <= coro.suspend_count(); i++) { - $case (i) { - device_log("scheduler: enter coro {}, a = {}", i, a); - coro[i](frame, a); - device_log("promise x: {}", frame->promise("x")); - device_log("scheduler: exit coro {}, a = {}", i, a); - }; - } - $default { - device_log("scheduler: terminate"); - $return(); - }; - }; - }; - }; - - auto shader = device.compile(mega_kernel); - stream << shader().dispatch(1u) - << synchronize(); -} diff --git a/src/tests/coro/sdf_renderer.cpp b/src/tests/coro/sdf_renderer.cpp index edfb8f79e..cc8aba1f2 100644 --- a/src/tests/coro/sdf_renderer.cpp +++ b/src/tests/coro/sdf_renderer.cpp @@ -1,15 +1,10 @@ #include #include #include -#include using namespace luisa; using namespace luisa::compute; -struct alignas(4) CoroFrame {}; - -LUISA_COROFRAME_STRUCT(CoroFrame) {}; - int main(int argc, char *argv[]) { Context context{argv[0]}; @@ -124,9 +119,9 @@ int main(int argc, char *argv[]) { }; }; - Coroutine coro = [&](Var &, ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) noexcept { + coroutine::Coroutine coro = [&](ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) noexcept { Float2 resolution = make_float2(width, height); - UInt2 coord = make_uint2(coro_id().x / height, coro_id().x % height); + UInt2 coord = dispatch_id().xy(); $if (frame_index == 0u) { seed_image.write(coord, make_uint4(tea(coord.x, coord.y))); @@ -184,26 +179,14 @@ int main(int argc, char *argv[]) { Image seed_image = device.create_image(PixelStorage::INT1, width, height); Image accum_image = device.create_image(PixelStorage::FLOAT4, width, height); - Kernel2D mega_kernel = [&](ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) { - Var frame; - initialize_coroframe(frame, dispatch_id()); - coro(frame, seed_image, accum_image, frame_index); - $loop { - auto token = read_promise(frame, "coro_token"); - $switch (token) { - for (auto i = 1u; i <= coro.suspend_count(); i++) { - $case (i) { - coro[i](frame, seed_image, accum_image, frame_index); - }; - } - $default { - $return(); - }; - }; - }; - }; + coroutine::PersistentThreadsCoroSchedulerConfig config{ + .thread_count = 64_k, + .block_size = 64u, + .fetch_size = 3u, + .shared_memory_soa = false, + .global_memory_ext = false}; + coroutine::PersistentThreadsCoroScheduler scheduler{device, coro, config}; - auto shader = device.compile(mega_kernel); auto clear_shader = device.compile<2>([&] { auto coord = dispatch_id().xy(); accum_image->write(coord, make_float4(make_float3(0.0f), 1.0f)); @@ -211,37 +194,13 @@ int main(int argc, char *argv[]) { }); luisa::vector host_image(accum_image.view().size_bytes()); - //coro::SimpleCoroDispatcher Wdispatcher{&coro, device, resolution.x * resolution.y}; - // coro::WavefrontCoroDispatcherConfig wf_config{ - // .max_instance_count = resolution.x * resolution.y, - // .debug = false, - // }; - // coro::WavefrontCoroDispatcher Wdispatcher{&coro, device, stream, wf_config}; - coro::PersistentCoroDispatcherConfig config{ - .max_thread_count = 256u * 256u, - .block_size = 128u, - .fetch_size = 3u, - .debug = false, - }; - coro::PersistentCoroDispatcher Wdispatcher{&coro, device, stream, config}; - stream << clear_shader().dispatch(resolution) - << synchronize(); - /*for (auto i = 0u; i < 100u; i++) { - stream << shader(seed_image, accum_image, i).dispatch(resolution); - }*/ + stream << clear_shader().dispatch(resolution) << synchronize(); Clock clk; - auto samples = 1000; - //Wdispatcher(seed_image, accum_image, 0u, resolution.x * resolution.y); - //stream << Wdispatcher.await_step(); - //stream << Wdispatcher.await_step(); - //stream << Wdispatcher.await_step(); - //stream << Wdispatcher.await_step(); - //return 0; + auto samples = 1024; for (auto i = 0u; i < samples; ++i) { LUISA_INFO("spp {}", i); - Wdispatcher(seed_image, accum_image, i, resolution.x * resolution.y); - stream << Wdispatcher.await_all(); - // << synchronize(); + stream << scheduler(seed_image, accum_image, i).dispatch(width, height) + << synchronize(); } stream << synchronize(); auto dt = clk.toc(); diff --git a/src/tests/coro/sdf_renderer_v2.cpp b/src/tests/coro/sdf_renderer_v2.cpp deleted file mode 100644 index cc8aba1f2..000000000 --- a/src/tests/coro/sdf_renderer_v2.cpp +++ /dev/null @@ -1,212 +0,0 @@ -#include -#include -#include - -using namespace luisa; -using namespace luisa::compute; - -int main(int argc, char *argv[]) { - - Context context{argv[0]}; - if (argc <= 1) { exit(1); } - - static constexpr uint width = 1280u; - static constexpr uint height = 720u; - - static constexpr int max_ray_depth = 6; - static constexpr float eps = 1e-4f; - static constexpr float inf = 1e10f; - static constexpr float fov = 0.23f; - static constexpr float dist_limit = 100.0f; - static constexpr float3 camera_pos = make_float3(0.0f, 0.32f, 3.7f); - static constexpr float3 light_pos = make_float3(-1.5f, 0.6f, 0.3f); - static constexpr float3 light_normal = make_float3(1.0f, 0.0f, 0.0f); - static constexpr float light_radius = 2.0f; - - Clock clock; - - Callable intersect_light = [](Float3 pos, Float3 d) noexcept { - Float cos_w = dot(-d, light_normal); - Float dist = dot(d, light_pos - pos); - Float D = dist / cos_w; - Float dist_to_center = distance_squared(light_pos, pos + D * d); - Bool valid = cos_w > 0.0f & dist > 0.0f & dist_to_center < light_radius * light_radius; - return ite(valid, D, inf); - }; - - Callable tea = [](UInt v0, UInt v1) noexcept { - Var s0 = 0u; - for (uint n = 0u; n < 4u; n++) { - s0 += 0x9e3779b9u; - v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); - v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); - } - return v0; - }; - - Callable rand = [](UInt &state) noexcept { - constexpr uint lcg_a = 1664525u; - constexpr uint lcg_c = 1013904223u; - state = lcg_a * state + lcg_c; - return cast(state) / cast(std::numeric_limits::max()); - }; - - Callable out_dir = [&rand](Float3 n, UInt &seed) noexcept { - Float3 u = ite( - abs(n.y) < 1.0f - eps, - normalize(cross(n, make_float3(0.0f, 1.0f, 0.0f))), - make_float3(1.f, 0.f, 0.f)); - Float3 v = cross(n, u); - Float phi = 2.0f * pi * rand(seed); - Float ay = sqrt(rand(seed)); - Float ax = sqrt(1.0f - ay * ay); - return ax * (cos(phi) * u + sin(phi) * v) + ay * n; - }; - - Callable make_nested = [](Float f) noexcept { - static constexpr float freq = 40.0f; - f *= freq; - f = ite(f < 0.f, ite(cast(f) % 2 == 0, 1.f - fract(f), fract(f)), f); - return (f - 0.2f) * (1.0f / freq); - }; - - Callable sdf = [&make_nested](Float3 o) noexcept { - Float wall = min(o.y + 0.1f, o.z + 0.4f); - Float sphere = distance(o, make_float3(0.0f, 0.35f, 0.0f)) - 0.36f; - Float3 q = abs(o - make_float3(0.8f, 0.3f, 0.0f)) - 0.3f; - Float box = length(max(q, 0.0f)) + min(max(max(q.x, q.y), q.z), 0.0f); - Float3 O = o - make_float3(-0.8f, 0.3f, 0.0f); - Float2 d = make_float2(length(make_float2(O.x, O.z)) - 0.3f, abs(O.y) - 0.3f); - Float cylinder = min(max(d.x, d.y), 0.0f) + length(max(d, 0.0f)); - Float geometry = make_nested(min(min(sphere, box), cylinder)); - Float g = max(geometry, -(0.32f - (o.y * 0.6f + o.z * 0.8f))); - return min(wall, g); - }; - - Callable ray_march = [&sdf](Float3 p, Float3 d) noexcept { - Float dist = def(0.0f); - $for (j, 100) { - Float s = sdf(p + dist * d); - $if (s <= 1e-6f | dist >= inf) { $break; }; - dist += s; - }; - return min(dist, inf); - }; - - Callable sdf_normal = [&sdf](Float3 p) noexcept { - static constexpr float d = 1e-3f; - Float3 n = def(make_float3()); - Float sdf_center = sdf(p); - for (uint i = 0; i < 3; i++) { - Float3 inc = p; - inc[i] += d; - n[i] = (1.0f / d) * (sdf(inc) - sdf_center); - } - return normalize(n); - }; - - Callable next_hit = [&ray_march, &sdf_normal](Float &closest, Float3 &normal, Float3 &c, Float3 pos, Float3 d) noexcept { - closest = inf; - normal = make_float3(); - c = make_float3(); - Float ray_march_dist = ray_march(pos, d); - $if (ray_march_dist < min(dist_limit, closest)) { - closest = ray_march_dist; - Float3 hit_pos = pos + d * closest; - normal = sdf_normal(hit_pos); - Int t = cast((hit_pos.x + 10.0f) * 1.1f + 0.5f) % 3; - c = make_float3(0.4f) + make_float3(0.3f, 0.2f, 0.3f) * ite(t == make_int3(0, 1, 2), 1.0f, 0.0f); - }; - }; - - coroutine::Coroutine coro = [&](ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) noexcept { - Float2 resolution = make_float2(width, height); - UInt2 coord = dispatch_id().xy(); - - $if (frame_index == 0u) { - seed_image.write(coord, make_uint4(tea(coord.x, coord.y))); - accum_image.write(coord, make_float4(make_float3(0.0f), 1.0f)); - }; - - Float aspect_ratio = resolution.x / resolution.y; - Float3 pos = def(camera_pos); - UInt seed = seed_image.read(coord).x; - Float ux = rand(seed); - Float uy = rand(seed); - Float2 uv = make_float2(coord.x + ux, height - 1u - coord.y + uy); - Float3 d = make_float3( - 2.0f * fov * uv / resolution.y - fov * make_float2(aspect_ratio, 1.0f) - 1e-5f, -1.0f); - d = normalize(d); - Float3 throughput = def(make_float3(1.0f, 1.0f, 1.0f)); - Float hit_light = def(0.0f); - $for (depth, max_ray_depth) { - - $suspend("2"); - Float closest = def(0.0f); - Float3 normal = def(make_float3()); - Float3 c = def(make_float3()); - next_hit(closest, normal, c, pos, d); - Float dist_to_light = intersect_light(pos, d); - $if (dist_to_light < closest) { - // $if (depth == 0) { - // device_log("xxxxxcoord {} {}", coord.x, coord.y); - // }; - hit_light = 1.0f; - $break; - }; - $if (length_squared(normal) == 0.0f) { - // $if (depth == 0) { - // device_log("coord {} {}", coord.x, coord.y); - // }; - $break; - }; - Float3 hit_pos = pos + closest * d; - d = out_dir(normal, seed); - pos = hit_pos + 1e-4f * d; - throughput *= c; - }; - $suspend("3"); - Float3 accum_color = lerp(accum_image.read(coord).xyz(), throughput.xyz() * hit_light, 1.0f / (frame_index + 1.0f)); - accum_image.write(coord, make_float4(accum_color, 1.0f)); - //$suspend("4"); - seed_image.write(coord, make_uint4(seed)); - }; - - auto device = context.create_device(argv[1]); - auto stream = device.create_stream(); - constexpr auto resolution = make_uint2(width, height); - - Image seed_image = device.create_image(PixelStorage::INT1, width, height); - Image accum_image = device.create_image(PixelStorage::FLOAT4, width, height); - - coroutine::PersistentThreadsCoroSchedulerConfig config{ - .thread_count = 64_k, - .block_size = 64u, - .fetch_size = 3u, - .shared_memory_soa = false, - .global_memory_ext = false}; - coroutine::PersistentThreadsCoroScheduler scheduler{device, coro, config}; - - auto clear_shader = device.compile<2>([&] { - auto coord = dispatch_id().xy(); - accum_image->write(coord, make_float4(make_float3(0.0f), 1.0f)); - seed_image->write(coord, make_uint4(coord.y * resolution.y + coord.x)); - }); - - luisa::vector host_image(accum_image.view().size_bytes()); - stream << clear_shader().dispatch(resolution) << synchronize(); - Clock clk; - auto samples = 1024; - for (auto i = 0u; i < samples; ++i) { - LUISA_INFO("spp {}", i); - stream << scheduler(seed_image, accum_image, i).dispatch(width, height) - << synchronize(); - } - stream << synchronize(); - auto dt = clk.toc(); - LUISA_INFO("Time: {} ms ({} spp/s)", dt, samples * 1e3 / dt); - - stream << accum_image.copy_to(host_image.data()) - << synchronize(); - stbi_write_hdr("test_sdf.hdr", resolution.x, resolution.y, 4, reinterpret_cast(host_image.data())); -} diff --git a/src/tests/coro/sdf_renderer_wo_dispatcher.cpp b/src/tests/coro/sdf_renderer_wo_dispatcher.cpp deleted file mode 100644 index 2462810a0..000000000 --- a/src/tests/coro/sdf_renderer_wo_dispatcher.cpp +++ /dev/null @@ -1,229 +0,0 @@ -#include -#include -#include - -using namespace luisa; -using namespace luisa::compute; - -struct alignas(4) CoroFrame {}; - -LUISA_COROFRAME_STRUCT(CoroFrame){}; - -int main(int argc, char *argv[]) { - - Context context{argv[0]}; - if (argc <= 1) { exit(1); } - - static constexpr uint width = 1280u; - static constexpr uint height = 720u; - - static constexpr int max_ray_depth = 6; - static constexpr float eps = 1e-4f; - static constexpr float inf = 1e10f; - static constexpr float fov = 0.23f; - static constexpr float dist_limit = 100.0f; - static constexpr float3 camera_pos = make_float3(0.0f, 0.32f, 3.7f); - static constexpr float3 light_pos = make_float3(-1.5f, 0.6f, 0.3f); - static constexpr float3 light_normal = make_float3(1.0f, 0.0f, 0.0f); - static constexpr float light_radius = 2.0f; - - Clock clock; - - Callable intersect_light = [](Float3 pos, Float3 d) noexcept { - Float cos_w = dot(-d, light_normal); - Float dist = dot(d, light_pos - pos); - Float D = dist / cos_w; - Float dist_to_center = distance_squared(light_pos, pos + D * d); - Bool valid = cos_w > 0.0f & dist > 0.0f & dist_to_center < light_radius * light_radius; - return ite(valid, D, inf); - }; - - Callable tea = [](UInt v0, UInt v1) noexcept { - Var s0 = 0u; - for (uint n = 0u; n < 4u; n++) { - s0 += 0x9e3779b9u; - v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); - v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); - } - return v0; - }; - - Callable rand = [](UInt &state) noexcept { - constexpr uint lcg_a = 1664525u; - constexpr uint lcg_c = 1013904223u; - state = lcg_a * state + lcg_c; - return cast(state) / cast(std::numeric_limits::max()); - }; - - Callable out_dir = [&rand](Float3 n, UInt &seed) noexcept { - Float3 u = ite( - abs(n.y) < 1.0f - eps, - normalize(cross(n, make_float3(0.0f, 1.0f, 0.0f))), - make_float3(1.f, 0.f, 0.f)); - Float3 v = cross(n, u); - Float phi = 2.0f * pi * rand(seed); - Float ay = sqrt(rand(seed)); - Float ax = sqrt(1.0f - ay * ay); - return ax * (cos(phi) * u + sin(phi) * v) + ay * n; - }; - - Callable make_nested = [](Float f) noexcept { - static constexpr float freq = 40.0f; - f *= freq; - f = ite(f < 0.f, ite(cast(f) % 2 == 0, 1.f - fract(f), fract(f)), f); - return (f - 0.2f) * (1.0f / freq); - }; - - Callable sdf = [&make_nested](Float3 o) noexcept { - Float wall = min(o.y + 0.1f, o.z + 0.4f); - Float sphere = distance(o, make_float3(0.0f, 0.35f, 0.0f)) - 0.36f; - Float3 q = abs(o - make_float3(0.8f, 0.3f, 0.0f)) - 0.3f; - Float box = length(max(q, 0.0f)) + min(max(max(q.x, q.y), q.z), 0.0f); - Float3 O = o - make_float3(-0.8f, 0.3f, 0.0f); - Float2 d = make_float2(length(make_float2(O.x, O.z)) - 0.3f, abs(O.y) - 0.3f); - Float cylinder = min(max(d.x, d.y), 0.0f) + length(max(d, 0.0f)); - Float geometry = make_nested(min(min(sphere, box), cylinder)); - Float g = max(geometry, -(0.32f - (o.y * 0.6f + o.z * 0.8f))); - return min(wall, g); - }; - - Callable ray_march = [&sdf](Float3 p, Float3 d) noexcept { - Float dist = def(0.0f); - $for (j, 100) { - Float s = sdf(p + dist * d); - $if (s <= 1e-6f | dist >= inf) { $break; }; - dist += s; - }; - return min(dist, inf); - }; - - Callable sdf_normal = [&sdf](Float3 p) noexcept { - static constexpr float d = 1e-3f; - Float3 n = def(make_float3()); - Float sdf_center = sdf(p); - for (uint i = 0; i < 3; i++) { - Float3 inc = p; - inc[i] += d; - n[i] = (1.0f / d) * (sdf(inc) - sdf_center); - } - return normalize(n); - }; - - Callable next_hit = [&ray_march, &sdf_normal](Float &closest, Float3 &normal, Float3 &c, Float3 pos, Float3 d) noexcept { - closest = inf; - normal = make_float3(); - c = make_float3(); - Float ray_march_dist = ray_march(pos, d); - $if (ray_march_dist < min(dist_limit, closest)) { - closest = ray_march_dist; - Float3 hit_pos = pos + d * closest; - normal = sdf_normal(hit_pos); - Int t = cast((hit_pos.x + 10.0f) * 1.1f + 0.5f) % 3; - c = make_float3(0.4f) + make_float3(0.3f, 0.2f, 0.3f) * ite(t == make_int3(0, 1, 2), 1.0f, 0.0f); - }; - }; - - Coroutine coro = [&](Var &, ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) noexcept { - Float2 resolution = make_float2(width, height); - UInt2 coord = make_uint2(coro_id().x / height, coro_id().x % height); - - $if (frame_index == 0u) { - seed_image.write(coord, make_uint4(tea(coord.x, coord.y))); - accum_image.write(coord, make_float4(make_float3(0.0f), 1.0f)); - }; - - Float aspect_ratio = resolution.x / resolution.y; - Float3 pos = def(camera_pos); - UInt seed = seed_image.read(coord).x; - Float ux = rand(seed); - Float uy = rand(seed); - Float2 uv = make_float2(coord.x + ux, height - 1u - coord.y + uy); - Float3 d = make_float3( - 2.0f * fov * uv / resolution.y - fov * make_float2(aspect_ratio, 1.0f) - 1e-5f, -1.0f); - d = normalize(d); - Float3 throughput = def(make_float3(1.0f, 1.0f, 1.0f)); - Float hit_light = def(0.0f); - $for (depth, max_ray_depth) { - - $suspend("2"); - Float closest = def(0.0f); - Float3 normal = def(make_float3()); - Float3 c = def(make_float3()); - next_hit(closest, normal, c, pos, d); - Float dist_to_light = intersect_light(pos, d); - $if (dist_to_light < closest) { - // $if (depth == 0) { - // device_log("xxxxxcoord {} {}", coord.x, coord.y); - // }; - hit_light = 1.0f; - $break; - }; - $if (length_squared(normal) == 0.0f) { - // $if (depth == 0) { - // device_log("coord {} {}", coord.x, coord.y); - // }; - $break; - }; - Float3 hit_pos = pos + closest * d; - d = out_dir(normal, seed); - pos = hit_pos + 1e-4f * d; - throughput *= c; - }; - $suspend("3"); - - Float3 accum_color = lerp(accum_image.read(coord).xyz(), throughput.xyz() * hit_light, 1.0f / (frame_index + 1.0f)); - accum_image.write(coord, make_float4(accum_color, 1.0f)); - //$suspend("4"); - seed_image.write(coord, make_uint4(seed)); - }; - - auto device = context.create_device(argv[1]); - auto stream = device.create_stream(); - constexpr auto resolution = make_uint2(width, height); - constexpr auto spp = 1000u; - - Image seed_image = device.create_image(PixelStorage::INT1, width, height); - Image accum_image = device.create_image(PixelStorage::FLOAT4, width, height); - - Kernel1D mega_kernel = [&](ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) { - Var frame; - initialize_coroframe(frame, dispatch_id()); - coro(frame, seed_image, accum_image, frame_index); - $loop { - auto token = read_promise(frame, "coro_token"); - $switch (token) { - for (auto i = 1u; i <= coro.suspend_count(); i++) { - $case (i) { - coro[i](frame, seed_image, accum_image, frame_index); - }; - } - $default { - $return(); - }; - }; - }; - }; - - auto shader = device.compile(mega_kernel); - auto clear_shader = device.compile<2>([&] { - auto coord = dispatch_id().xy(); - accum_image->write(coord, make_float4(make_float3(0.0f), 1.0f)); - seed_image->write(coord, make_uint4(coord.y * resolution.y + coord.x)); - }); - - luisa::vector host_image(accum_image.view().size_bytes()); - stream << clear_shader().dispatch(resolution) - << synchronize(); - Clock clk; - for (auto i = 0u; i < spp; i++) { - LUISA_INFO("spp {}", i); - stream << shader(seed_image, accum_image, i).dispatch(resolution.x * resolution.y); - } - stream << synchronize(); - auto dt = clk.toc(); - LUISA_INFO("Time: {} ms ({} spp/s)", dt, 1e6 / dt); - - stream << accum_image.copy_to(host_image.data()) - << synchronize(); - stbi_write_hdr("test_sdf_wo_dispatcher.hdr", resolution.x, resolution.y, 4, reinterpret_cast(host_image.data())); -} From 870ae95564d9b2735670cbe4be734dde92c8ca99 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 00:41:40 +0800 Subject: [PATCH 57/67] fix coro wavefront test --- src/tests/coro/path_tracing_wavefront.cpp | 49 ++--------------------- 1 file changed, 4 insertions(+), 45 deletions(-) diff --git a/src/tests/coro/path_tracing_wavefront.cpp b/src/tests/coro/path_tracing_wavefront.cpp index 3eddcaf76..7203511a1 100644 --- a/src/tests/coro/path_tracing_wavefront.cpp +++ b/src/tests/coro/path_tracing_wavefront.cpp @@ -9,43 +9,6 @@ #define TINYOBJLOADER_IMPLEMENTATION #include "../common/tiny_obj_loader.h" -// namespace luisa::compute::coroutine { -// -// template -// class WavefrontCoroScheduler final : public CoroScheduler { -// -// private: -// Device &_device; -// Coroutine _coro; -// Buffer _frame_buffer; -// luisa>> _shaders; -// -// private: -// void _prepare(uint n) noexcept { -// _frame_buffer = _device.create_buffer(_coro.frame(), n); -// } -// -// private: -// void _dispatch(Stream &stream, uint3 dispatch_size, -// compute::detail::prototype_to_shader_invocation_t... args) noexcept override { -// auto s = luisa::make_ulong3(dispatch_size); -// auto n = s.x * s.y * s.z; -// LUISA_ASSERT(n < std::numeric_limits::max(), "Dispatch size is too large."); -// LUISA_ASSERT(n > 0u, "Dispatch size must be greater than zero."); -// if (!_frame_buffer || _frame_buffer.size() < n) { _prepare(n); } -// // generate -// stream << _shaders[0](args...).dispatch(dispatch_size); -// // loop over the subroutines until we found that all of them are done -// -// } -// -// public: -// explicit WavefrontCoroScheduler(Device &device, Coroutine coro) noexcept -// : _device{device}, _coro{std::move(coro)} {} -// }; -// -// }// namespace luisa::compute::coroutine - using namespace luisa; using namespace luisa::compute; @@ -196,14 +159,10 @@ int main(int argc, char *argv[]) { return pdf_a / max(pdf_a + pdf_b, 1e-4f); }; - auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 16u; + auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { - UInt pixel_count = resolution.x * resolution.y; - UInt pixel_id = coro_id().x % pixel_count; - UInt2 coord = make_uint2(pixel_id % resolution.x, pixel_id / resolution.x); - - // UInt2 coord = dispatch_id().xy(); + UInt2 coord = dispatch_id().xy(); Float frame_size = min(resolution.x, resolution.y).cast(); UInt state = seed_image.read(coord).x; @@ -311,7 +270,7 @@ int main(int argc, char *argv[]) { }; coroutine::WavefrontCoroSchedulerConfig config{ - .thread_count = 256_k, + .thread_count = 4_M, .soa = true, .sort = true, }; @@ -381,7 +340,7 @@ int main(int argc, char *argv[]) { while (!window.should_close()) { stream << scheduler(framebuffer, seed_image, accel, resolution) - .dispatch(resolution.x * resolution.y) + .dispatch(resolution.x, resolution.y, spp_per_dispatch) << accumulate_shader(accum_image, framebuffer) .dispatch(resolution) << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) From bdc3528c435e829dab7d513110dddbf4a3d39254 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 01:28:46 +0800 Subject: [PATCH 58/67] fix build --- include/luisa/coro/coro_frame_soa.h | 14 ++++++-------- include/luisa/dsl/builtin.h | 8 -------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/include/luisa/coro/coro_frame_soa.h b/include/luisa/coro/coro_frame_soa.h index c1bc45d68..69d86ce93 100644 --- a/include/luisa/coro/coro_frame_soa.h +++ b/include/luisa/coro/coro_frame_soa.h @@ -96,7 +96,7 @@ class SOA : public SOABase { auto fields = _desc->type()->members(); _field_offsets->reserve(fields.size()); for (const auto field_type : fields) { - auto alignment = std::max(field_type->alignment(), 4ull); + auto alignment = std::max(field_type->alignment(), 4u); size_bytes = (size_bytes + alignment - 1u) & ~(alignment - 1u); _field_offsets->emplace_back(size_bytes); auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); @@ -168,7 +168,7 @@ struct Expr> : SOABase { auto fb = detail::FunctionBuilder::current(); auto field_type = _desc->type()->members()[field_index]; auto offset = _field_offsets->at(field_index); - auto alignment = std::max(field_type->alignment(), 4ull); + auto alignment = std::max(field_type->alignment(), 4u); auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; auto f = fb->local(field_type); @@ -186,7 +186,7 @@ struct Expr> : SOABase { auto fb = detail::FunctionBuilder::current(); auto field_type = _desc->type()->members()[field_index]; auto offset = _field_offsets->at(field_index); - auto alignment = std::max(field_type->alignment(), 4ull); + auto alignment = std::max(field_type->alignment(), 4u); auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; fb->call(CallOp::BYTE_BUFFER_WRITE, @@ -206,7 +206,7 @@ struct Expr> : SOABase { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto alignment = std::max(field_type->alignment(), 4ull); + auto alignment = std::max(field_type->alignment(), 4u); auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; auto s = fb->call( @@ -228,14 +228,12 @@ struct Expr> : SOABase { if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } auto field_type = fields[i]; auto offset = _field_offsets->at(i); - auto alignment = std::max(field_type->alignment(), 4ull); + auto alignment = std::max(field_type->alignment(), 4u); auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; auto f = fb->member(field_type, frame.expression(), i); fb->call(CallOp::BYTE_BUFFER_WRITE, - {_expression, - detail::extract_expression(offset_var), - f}); + {_expression, detail::extract_expression(offset_var), f}); } } diff --git a/include/luisa/dsl/builtin.h b/include/luisa/dsl/builtin.h index a9382b6cb..622d762cb 100644 --- a/include/luisa/dsl/builtin.h +++ b/include/luisa/dsl/builtin.h @@ -1779,14 +1779,6 @@ template return def(detail::FunctionBuilder::current()->coro_token()); } -template -inline auto read_promise(T &&t, S &&name) noexcept { - return def(detail::FunctionBuilder::current()->read_promise_( - Type::of(), - detail::extract_expression(std::forward(t)), - std::forward(name))); -} - template [[nodiscard]] inline auto grad(const Local &x) noexcept { auto b = detail::FunctionBuilder::current(); From f9c6a8853fee0645fa69e5b35beaae0372937532 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 01:40:39 +0800 Subject: [PATCH 59/67] minor --- include/luisa/coro/schedulers/wavefront.h | 50 +++++++++++------------ src/tests/coro/path_tracing_wavefront.cpp | 10 ++--- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/include/luisa/coro/schedulers/wavefront.h b/include/luisa/coro/schedulers/wavefront.h index 5602b2012..bfca6c361 100644 --- a/include/luisa/coro/schedulers/wavefront.h +++ b/include/luisa/coro/schedulers/wavefront.h @@ -15,9 +15,9 @@ namespace luisa::compute::coroutine { struct WavefrontCoroSchedulerConfig { // uint3 block_size = make_uint3(8u, 8u, 1u); uint thread_count = 2_M; - bool soa = true; - bool sort = true;// use sort for coro token gathering - bool compact = true; + bool global_memory_soa = true; + bool gather_by_sorting = true;// use sort for coro token gathering + bool frame_buffer_compaction = true; uint hint_range = 0xffff'ffff; luisa::vector hint_fields; }; @@ -94,12 +94,12 @@ class WavefrontCoroScheduler : public CoroScheduler { void _create_shader(Device &device, const Coroutine &coroutine, const WavefrontCoroSchedulerConfig &config) noexcept { _config = config; - if (_config.soa) { + if (_config.global_memory_soa) { _frame_soa = device.create_soa(coroutine.shared_frame(), _config.thread_count); } else { _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), _config.thread_count); } - bool use_sort = _config.sort || !_config.hint_fields.empty(); + bool use_sort = _config.gather_by_sorting || !_config.hint_fields.empty(); _max_sub_coro = coroutine.subroutine_count(); _resume_index = device.create_buffer(_config.thread_count); if (use_sort) { @@ -138,7 +138,7 @@ class WavefrontCoroScheduler : public CoroScheduler { $if (index > _config.thread_count) { device_log("Index out of range {}/{}", index, _config.thread_count); }; - if (_config.soa) { + if (_config.global_memory_soa) { return _frame_soa->read_field(index, "target_token") & coro_token_valid_mask; } else { CoroFrame frame = _frame_buffer->read(index); @@ -155,7 +155,7 @@ class WavefrontCoroScheduler : public CoroScheduler { Callable get_coro_hint = [&](UInt index, BufferUInt val) { if (!_config.hint_fields.empty()) { auto id = keep_index(index, val); - if (_config.soa) { + if (_config.global_memory_soa) { return _frame_soa->read_field(id, "coro_hint"); } else { CoroFrame frame = _frame_buffer->read(id); @@ -168,7 +168,7 @@ class WavefrontCoroScheduler : public CoroScheduler { _sort_temp_storage = radix_sort::temp_storage( device, _config.thread_count, std::max(std::min(_config.hint_range, 128u), _max_sub_coro)); } - if (_config.sort) { + if (_config.gather_by_sorting) { _sort_token = radix_sort::instance<>( device, _config.thread_count, _sort_temp_storage, &get_coro_token, &identical, &get_coro_token, 1, _max_sub_coro); @@ -194,12 +194,12 @@ class WavefrontCoroScheduler : public CoroScheduler { $return(); }; UInt frame_id; - if (!_config.compact) { + if (!_config.frame_buffer_compaction) { frame_id = index->read(x); } else { frame_id = offset + x; } - if (!_config.sort) { + if (!_config.gather_by_sorting) { count.atomic(0u).fetch_add(-1u); } auto global_id = st_task_id + x; @@ -210,12 +210,12 @@ class WavefrontCoroScheduler : public CoroScheduler { auto global_id_y = global_id_xy / dispatch_shape.x; CoroFrame frame = coroutine.instantiate(make_uint3(global_id_x, global_id_y, global_id_z)); coroutine.entry()(frame, args...); - if (_config.soa) { + if (_config.global_memory_soa) { _frame_soa->write(frame_id, frame, coroutine.graph()->node(0u).output_fields()); } else { _frame_buffer->write(frame_id, frame); } - if (!_config.sort) { + if (!_config.gather_by_sorting) { auto nxt = frame.target_token & coro_token_valid_mask; count.atomic(nxt).fetch_add(1u); } @@ -233,23 +233,23 @@ class WavefrontCoroScheduler : public CoroScheduler { }; auto frame_id = index.read(x); CoroFrame frame = coroutine.instantiate(); - if (_config.soa) { + if (_config.global_memory_soa) { //frame = frame_buffer.read(frame_id); frame = _frame_soa->read(frame_id, coroutine.graph()->node(i).input_fields()); } else { frame = _frame_buffer->read(frame_id); } - if (!_config.sort) { + if (!_config.gather_by_sorting) { count.atomic(i).fetch_sub(1u); } coroutine[i](frame, args...); - if (_config.soa) { + if (_config.global_memory_soa) { _frame_soa->write(frame_id, frame, coroutine.graph()->node(i).output_fields()); } else { _frame_buffer->write(frame_id, frame); } - if (!_config.sort) { + if (!_config.gather_by_sorting) { auto nxt = frame.target_token & coro_token_valid_mask; $if (nxt < _max_sub_coro) { count.atomic(nxt).fetch_add(1u); @@ -275,7 +275,7 @@ class WavefrontCoroScheduler : public CoroScheduler { Kernel1D _gather_kernel = [&](BufferUInt index, BufferUInt prefix, UInt n) { auto x = dispatch_x(); UInt r_id; - if (_config.soa) { + if (_config.global_memory_soa) { r_id = _frame_soa->read_field(x, "target_token") & coro_token_valid_mask; } else { auto frame = _frame_buffer->read(x); @@ -291,7 +291,7 @@ class WavefrontCoroScheduler : public CoroScheduler { auto x = dispatch_x(); $if (empty_offset + x < n) { UInt token; - if (_config.soa) { + if (_config.global_memory_soa) { token = _frame_soa->read_field(empty_offset + x, "target_token"); } else { CoroFrame frame = _frame_buffer->read(empty_offset + x); @@ -300,20 +300,20 @@ class WavefrontCoroScheduler : public CoroScheduler { $if ((token & coro_token_valid_mask) != 0u) { auto res = _global_buffer->atomic(0u).fetch_add(1u); auto slot = index.read(res); - if (!_config.sort) { + if (!_config.gather_by_sorting) { $while (slot >= empty_offset) { res = _global_buffer->atomic(0u).fetch_add(1u); slot = index.read(res); }; } - if (_config.soa) { + if (_config.global_memory_soa) { auto frame = _frame_soa->read(empty_offset + x); _frame_soa->write(slot, frame); } else { auto frame = _frame_buffer->read(empty_offset + x); _frame_buffer->write(slot, frame); } - if (_config.soa) { + if (_config.global_memory_soa) { _frame_soa->write_field(empty_offset + x, 0u, "target_token"); } else { CoroFrame empty_frame = coroutine.instantiate(); @@ -329,7 +329,7 @@ class WavefrontCoroScheduler : public CoroScheduler { auto x = dispatch_x(); $if (x < n) { CoroFrame frame = coroutine.instantiate(); - if (_config.soa) { + if (_config.global_memory_soa) { _frame_soa->write(x, frame, std::array{0u, 1u}); } else { _frame_buffer->write(x, frame); @@ -363,7 +363,7 @@ class WavefrontCoroScheduler : public CoroScheduler { } } void _await_step(Stream &stream) noexcept { - if (_config.sort) { + if (_config.gather_by_sorting) { auto host_update = [&] { _host_empty = true; for (uint i = 0u; i < _max_sub_coro; i++) { @@ -380,7 +380,7 @@ class WavefrontCoroScheduler : public CoroScheduler { if (_host_count[0] > _config.thread_count * 0.5f && !this->_all_dispatched()) { auto gen_count = std::min(_dispatch_size - _dispatch_counter, _host_count[0]); - if (_host_count[0] != _config.thread_count && _config.compact) { + if (_host_count[0] != _config.thread_count && _config.frame_buffer_compaction) { stream << _clear_shader(_global_buffer, 1).dispatch(1u); stream << _compact_shader(_resume_index, _config.thread_count - _host_count[0], _config.thread_count).dispatch(_host_count[0]); } @@ -413,7 +413,7 @@ class WavefrontCoroScheduler : public CoroScheduler { stream << _gather_shader(_resume_index, _resume_offset, _config.thread_count).dispatch(_config.thread_count); if (_host_count[0] > _config.thread_count / 2 && !_all_dispatched()) { auto gen_count = std::min(_dispatch_size - _dispatch_counter, _host_count[0]); - if (_host_count[0] != _config.thread_count && _config.compact) { + if (_host_count[0] != _config.thread_count && _config.frame_buffer_compaction) { stream << _clear_shader(_global_buffer, 1).dispatch(1u); stream << _compact_shader(_resume_index.view(_host_offset[0], _host_count[0]), diff --git a/src/tests/coro/path_tracing_wavefront.cpp b/src/tests/coro/path_tracing_wavefront.cpp index 7203511a1..c52f68b9e 100644 --- a/src/tests/coro/path_tracing_wavefront.cpp +++ b/src/tests/coro/path_tracing_wavefront.cpp @@ -159,7 +159,7 @@ int main(int argc, char *argv[]) { return pdf_a / max(pdf_a + pdf_b, 1e-4f); }; - auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; + auto spp_per_dispatch = 16u; coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { UInt2 coord = dispatch_id().xy(); @@ -171,7 +171,7 @@ int main(int argc, char *argv[]) { Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; Float3 radiance = def(make_float3(0.0f)); - $suspend("per_dispatch"); + // $suspend("per_dispatch"); // $if (all(coord == make_uint2(50u, 500u))) { // device_log("coord: {}", coord); @@ -253,7 +253,7 @@ int main(int argc, char *argv[]) { beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f // rr - $suspend("rr"); + // $suspend("rr"); Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); $if (l == 0.0f) { $break; }; Float q = max(l, 0.05f); @@ -271,8 +271,8 @@ int main(int argc, char *argv[]) { coroutine::WavefrontCoroSchedulerConfig config{ .thread_count = 4_M, - .soa = true, - .sort = true, + .global_memory_soa = true, + .gather_by_sorting = false, }; coroutine::WavefrontCoroScheduler scheduler{device, coro, config}; // coroutine::PersistentThreadsCoroScheduler scheduler{device, coro}; From deb31e591aa277201446c7109f16d5169b5b649f Mon Sep 17 00:00:00 2001 From: chenxin Date: Wed, 15 May 2024 16:02:22 +0800 Subject: [PATCH 60/67] fix soa buffer type --- include/luisa/coro/coro_frame_soa.h | 2 +- src/runtime/byte_buffer.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/luisa/coro/coro_frame_soa.h b/include/luisa/coro/coro_frame_soa.h index 69d86ce93..1410f4927 100644 --- a/include/luisa/coro/coro_frame_soa.h +++ b/include/luisa/coro/coro_frame_soa.h @@ -104,7 +104,7 @@ class SOA : public SOABase { } size_bytes = (size_bytes + 3u) & ~3u; auto info = device->create_buffer( - Type::of(), + Type::of(), size_bytes, nullptr); _buffer = std::move(ByteBuffer{device, info}); diff --git a/src/runtime/byte_buffer.cpp b/src/runtime/byte_buffer.cpp index 178178f96..10b7911d7 100644 --- a/src/runtime/byte_buffer.cpp +++ b/src/runtime/byte_buffer.cpp @@ -25,7 +25,7 @@ ByteBuffer::ByteBuffer(DeviceInterface *device, size_t size_bytes) noexcept if ((size_bytes & 3) != 0) [[unlikely]] { detail::error_buffer_size_not_aligned(4); } - return device->create_buffer(Type::of(), size_bytes, nullptr); + return device->create_buffer(Type::of(), size_bytes, nullptr); }()} {} ByteBuffer::~ByteBuffer() noexcept { @@ -42,7 +42,7 @@ ByteBuffer Device::create_byte_buffer(size_t byte_size) noexcept { } ByteBuffer Device::import_external_byte_buffer(void *external_memory, size_t byte_size) noexcept { - auto info = impl()->create_buffer(Type::of(), byte_size, external_memory); + auto info = impl()->create_buffer(Type::of(), byte_size, external_memory); return ByteBuffer{impl(), info}; } From fcad6ac3cb94ba2b00d48776a6f28bc038d16114 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 21:52:43 +0800 Subject: [PATCH 61/67] refactor generator impl --- include/luisa/coro/coro_func.h | 116 +++++++++++++++++---------------- include/luisa/dsl/sugar.h | 8 +-- src/coro/coro_func.cpp | 23 +++---- 3 files changed, 73 insertions(+), 74 deletions(-) diff --git a/include/luisa/coro/coro_func.h b/include/luisa/coro/coro_func.h index 5adc60712..f7506cc61 100644 --- a/include/luisa/coro/coro_func.h +++ b/include/luisa/coro/coro_func.h @@ -15,9 +15,6 @@ namespace detail { LC_CORO_API void coroutine_chained_await_impl( CoroFrame &frame, uint node_count, luisa::move_only_function node) noexcept; -LC_CORO_API void coroutine_generator_step_impl( - CoroFrame &frame, uint node_count, bool is_entry, - luisa::move_only_function node) noexcept; }// namespace detail template @@ -145,94 +142,99 @@ class Generator { static_assert(luisa::always_false_v); }; -template -class Generator { +namespace detail { - static_assert(!std::is_same_v, - "Generator function must not return void."); +LC_CORO_API void coroutine_generator_next_impl( + CoroFrame &frame, uint node_count, + const luisa::move_only_function &resume) noexcept; + +template +class GeneratorIter : public concepts::Noncopyable { private: - Coroutine _coro; + uint _n; + CoroFrame _frame; + using Resume = luisa::move_only_function; + Resume resume; -public: - template - requires std::negation_v>> && - std::negation_v>> - Generator(Def &&f) noexcept : _coro{std::forward(f)} {} +private: + template + friend class Generator; + GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept + : _n{n}, _frame{std::move(frame)}, resume{std::move(resume)} {} public: - [[nodiscard]] auto coroutine() const noexcept { return _coro; } + [[nodiscard]] auto set_id(Expr coro_id) && noexcept { + _frame.coro_id = coro_id; + return std::move(*this); + } + [[nodiscard]] Bool has_next() const noexcept { return !_frame.is_terminated(); } + [[nodiscard]] Var next() noexcept { + coroutine_generator_next_impl(_frame, _n, resume); + return _frame.get("__yielded_value"); + } private: - template - class Iterator { + class RangeForIterator { private: - luisa::unique_ptr _frame; - U _f; + GeneratorIter &_g; bool _invoked{false}; LoopStmt *_loop{nullptr}; private: - friend class Generator; - Iterator(luisa::unique_ptr frame, U f) noexcept - : _frame{std::move(frame)}, _f{std::move(f)} {} + friend class GeneratorIter; + explicit RangeForIterator(GeneratorIter &g) noexcept : _g{g} {} public: - Iterator &operator++() noexcept { + RangeForIterator &operator++() noexcept { _invoked = true; - _f(*_frame, false); compute::detail::FunctionBuilder::current()->pop_scope(_loop->body()); return *this; } [[nodiscard]] bool operator==(luisa::default_sentinel_t) const noexcept { return _invoked; } - [[nodiscard]] Var> operator*() noexcept { - _f(*_frame, true); + [[nodiscard]] Var operator*() noexcept { auto fb = compute::detail::FunctionBuilder::current(); _loop = fb->loop_(); fb->push_scope(_loop->body()); - dsl::if_(_frame->is_terminated(), [] { dsl::break_(); }); - return _frame->get("yield_value"); + dsl::if_(!_g.has_next(), [] { dsl::break_(); }); + return _g.next(); } }; -private: - template - [[nodiscard]] auto _make_iterator(U u, luisa::optional> coro_id) const noexcept { - auto frame = luisa::make_unique(coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate()); - return Iterator{std::move(frame), std::move(u)}; - } +public: + [[nodiscard]] auto begin() noexcept { return RangeForIterator{*this}; } + [[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; } +}; + +}// namespace detail + +template +class Generator { + + static_assert(!std::is_same_v, + "Generator function must not return void."); private: - template - class Stepper : public concepts::Noncopyable { - private: - luisa::optional> _coro_id; - const Generator &_g; - U _f; + Coroutine _coro; - private: - friend class Generator; - Stepper(const Generator &g, U f) noexcept : _g{g}, _f{std::move(f)} {} +public: + template + requires std::negation_v>> && + std::negation_v>> + Generator(Def &&f) noexcept : _coro{std::forward(f)} {} - public: - [[nodiscard]] auto set_id(Expr coro_id) && noexcept { - _coro_id.emplace(coro_id); - return std::move(*this); - } - [[nodiscard]] auto begin() noexcept { return _g._make_iterator(std::move(_f), std::move(_coro_id)); } - [[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; } - }; +public: + [[nodiscard]] auto coroutine() const noexcept { return _coro; } public: [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { - auto f = [=, this](CoroFrame &frame, bool is_entry) noexcept { - detail::coroutine_generator_step_impl( - frame, _coro.subroutine_count(), is_entry, - [&](CoroToken token, CoroFrame &ff) noexcept { - _coro.subroutine(token)(ff, args...); - }); + return detail::GeneratorIter{ + _coro.subroutine_count(), + _coro.instantiate(), + [=, this](CoroFrame &frame, CoroToken token) noexcept { + _coro[token](frame, args...); + }, }; - return Stepper{*this, std::move(f)}; } }; diff --git a/include/luisa/dsl/sugar.h b/include/luisa/dsl/sugar.h index fb36b3a9c..c0efbf1f2 100644 --- a/include/luisa/dsl/sugar.h +++ b/include/luisa/dsl/sugar.h @@ -122,10 +122,10 @@ namespace luisa::compute::dsl_detail { #define $promise(...) ::luisa::compute::dsl::promise(__VA_ARGS__) -#define $yield(...) \ - do { \ - ::luisa::compute::dsl::promise("yield_value", __VA_ARGS__); \ - ::luisa::compute::dsl::suspend(); \ +#define $yield(...) \ + do { \ + ::luisa::compute::dsl::promise("__yielded_value", __VA_ARGS__); \ + ::luisa::compute::dsl::suspend(); \ } while (false) #define $await ::luisa::compute::coroutine::detail::CoroAwaitInvoker{} % diff --git a/src/coro/coro_func.cpp b/src/coro/coro_func.cpp index 966719c01..c8fa05541 100644 --- a/src/coro/coro_func.cpp +++ b/src/coro/coro_func.cpp @@ -23,18 +23,15 @@ void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, }; } -void coroutine_generator_step_impl(CoroFrame &frame, uint node_count, bool is_entry, - luisa::move_only_function node) noexcept { - if (is_entry) { - node(coro_token_entry, frame); - } else { - $switch (frame.target_token) { - for (auto i = 1u; i < node_count; i++) { - $case (i) { node(i, frame); }; - } - $default { dsl::unreachable(); }; - }; - } +inline void coroutine_generator_next_impl( + CoroFrame &frame, uint node_count, + const luisa::move_only_function &resume) noexcept { + $switch (frame.target_token) { + for (auto i = 0u; i < node_count; i++) { + $case (i) { resume(frame, i); }; + } + $default { dsl::unreachable(); }; + }; } -}// namespace luisa::compute::coro_v2::detail +}// namespace luisa::compute::coroutine::detail From f57ab51419702755fcdaf1b7f1a8f36a3068791a Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 15 May 2024 22:59:33 +0800 Subject: [PATCH 62/67] refactor --- include/luisa/coro/coro_func.h | 101 +++++++++++++++++++-------------- src/tests/coro/helloworld.cpp | 9 ++- 2 files changed, 66 insertions(+), 44 deletions(-) diff --git a/include/luisa/coro/coro_func.h b/include/luisa/coro/coro_func.h index f7506cc61..fbb820eb8 100644 --- a/include/luisa/coro/coro_func.h +++ b/include/luisa/coro/coro_func.h @@ -22,6 +22,26 @@ class Coroutine { static_assert(luisa::always_false_v); }; +namespace detail { + +class LC_CORO_API CoroAwaiter : public concepts::Noncopyable { + +private: + using Await = luisa::move_only_function; + Await _await; + +private: + template + friend class Coroutine; + explicit CoroAwaiter(Await await) noexcept + : _await{std::move(await)} {} + +public: + void await() && noexcept { _await(); } +}; + +}// namespace detail + template class Coroutine { @@ -95,41 +115,29 @@ class Coroutine { [[nodiscard]] auto subroutine(luisa::string_view name) const noexcept { return (*this)[name]; } private: - template - class Awaiter : public concepts::Noncopyable { - private: - U _f; - luisa::optional> _coro_id; - - private: - friend class Coroutine; - explicit Awaiter(U f) noexcept : _f{std::move(f)} {} - - public: - [[nodiscard]] auto set_id(Expr coro_id) && noexcept { - _coro_id.emplace(coro_id); - return std::move(*this); - } - void await() && noexcept { return _f(std::move(_coro_id)); } - }; - -public: - [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { - auto f = [=, this](luisa::optional> coro_id) noexcept { + [[nodiscard]] auto _await(luisa::optional> coro_id, + compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return detail::CoroAwaiter{[=, coro_id = std::move(coro_id), this]() noexcept { auto frame = coro_id ? instantiate(*coro_id) : instantiate(); detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](CoroToken token, CoroFrame &f) noexcept { subroutine(token)(f, args...); }); - }; - return Awaiter{std::move(f)}; + }}; + } + +public: + [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _await(luisa::nullopt, args...); + } + [[nodiscard]] auto operator()(Expr coro_id, compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _await(luisa::make_optional(coro_id), args...); } }; namespace detail { struct CoroAwaitInvoker { - template - void operator%(A &&awaiter) && noexcept { - std::forward(awaiter).await(); + void operator%(CoroAwaiter &&awaiter) const && noexcept { + std::move(awaiter).await(); } }; }// namespace detail @@ -155,24 +163,21 @@ class GeneratorIter : public concepts::Noncopyable { uint _n; CoroFrame _frame; using Resume = luisa::move_only_function; - Resume resume; + Resume _resume; private: template friend class Generator; GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept - : _n{n}, _frame{std::move(frame)}, resume{std::move(resume)} {} + : _n{n}, _frame{std::move(frame)}, _resume{std::move(resume)} {} public: - [[nodiscard]] auto set_id(Expr coro_id) && noexcept { - _frame.coro_id = coro_id; - return std::move(*this); - } - [[nodiscard]] Bool has_next() const noexcept { return !_frame.is_terminated(); } - [[nodiscard]] Var next() noexcept { - coroutine_generator_next_impl(_frame, _n, resume); - return _frame.get("__yielded_value"); + auto &update() noexcept { + coroutine_generator_next_impl(_frame, _n, _resume); + return *this; } + [[nodiscard]] Bool is_terminated() const noexcept { return _frame.is_terminated(); } + [[nodiscard]] Var value() noexcept { return _frame.get("__yielded_value"); } private: class RangeForIterator { @@ -196,8 +201,9 @@ class GeneratorIter : public concepts::Noncopyable { auto fb = compute::detail::FunctionBuilder::current(); _loop = fb->loop_(); fb->push_scope(_loop->body()); - dsl::if_(!_g.has_next(), [] { dsl::break_(); }); - return _g.next(); + _g.update(); + dsl::if_(_g.is_terminated(), [] { dsl::break_(); }); + return _g.value(); } }; @@ -226,16 +232,25 @@ class Generator { public: [[nodiscard]] auto coroutine() const noexcept { return _coro; } -public: - [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { +private: + [[nodiscard]] auto _iter(luisa::optional> coro_id, + compute::detail::prototype_to_callable_invocation_t... args) const noexcept { return detail::GeneratorIter{ _coro.subroutine_count(), - _coro.instantiate(), - [=, this](CoroFrame &frame, CoroToken token) noexcept { - _coro[token](frame, args...); + coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate(), + [=, this](CoroFrame &f, CoroToken token) noexcept { + _coro[token](f, args...); }, }; } + +public: + [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _iter(luisa::nullopt, args...); + } + [[nodiscard]] auto operator()(Expr coro_id, compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _iter(luisa::make_optional(coro_id), args...); + } }; }// namespace luisa::compute::coroutine diff --git a/src/tests/coro/helloworld.cpp b/src/tests/coro/helloworld.cpp index ae3c70e4d..3509a6e0d 100644 --- a/src/tests/coro/helloworld.cpp +++ b/src/tests/coro/helloworld.cpp @@ -25,10 +25,17 @@ int main(int argc, char *argv[]) { x += 1u; }; }; - for (auto x : range(100u)) { + auto iter = range(10u); + $while (!iter.update().is_terminated()) { + auto x = iter.value(); + device_log("x = {}", x); + }; + for (auto x : range(10u)) { device_log("x = {}", x); } }; + auto shader = device.compile(test); + stream << shader().dispatch(1) << synchronize(); coroutine::Coroutine nested2 = [](UInt n) { $for (i, n) { From 1161b00aff97d2540257e956336e16cd655573e2 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Thu, 16 May 2024 01:01:07 +0800 Subject: [PATCH 63/67] support `Shared::operator[]` for non-SoA --- include/luisa/coro/coro_frame.h | 4 +- include/luisa/coro/coro_frame_desc.h | 1 + include/luisa/coro/coro_frame_smem.h | 57 +++++++++++++++++-- include/luisa/coro/coro_func.h | 13 ++--- .../coro/schedulers/persistent_threads.h | 2 +- src/coro/CMakeLists.txt | 1 + src/coro/coro_frame.cpp | 2 +- src/coro/coro_frame_desc.cpp | 6 ++ src/coro/coro_frame_smem.cpp | 7 +++ src/coro/coro_func.cpp | 2 +- src/coro/schedulers/persistent_threads.cpp | 48 ++++++++++++---- 11 files changed, 111 insertions(+), 32 deletions(-) create mode 100644 src/coro/coro_frame_smem.cpp diff --git a/include/luisa/coro/coro_frame.h b/include/luisa/coro/coro_frame.h index bac5580f8..943622ec6 100644 --- a/include/luisa/coro/coro_frame.h +++ b/include/luisa/coro/coro_frame.h @@ -14,14 +14,14 @@ class LC_CORO_API CoroFrame { private: luisa::shared_ptr _desc; - const RefExpr *_expression; + const Expression *_expression; public: Var &coro_id; Var &target_token; public: - CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept; + CoroFrame(luisa::shared_ptr desc, const Expression *expr) noexcept; CoroFrame(CoroFrame &&another) noexcept; CoroFrame(const CoroFrame &another) noexcept; CoroFrame &operator=(const CoroFrame &rhs) noexcept; diff --git a/include/luisa/coro/coro_frame_desc.h b/include/luisa/coro/coro_frame_desc.h index 0c81cbcf8..82a873e93 100644 --- a/include/luisa/coro/coro_frame_desc.h +++ b/include/luisa/coro/coro_frame_desc.h @@ -26,6 +26,7 @@ class LC_CORO_API CoroFrameDesc : public luisa::enable_shared_from_this create(const Type *type, DesignatedFieldDict m) noexcept; [[nodiscard]] auto type() const noexcept { return _type; } + [[nodiscard]] const Type *field(uint index) const noexcept; [[nodiscard]] auto &designated_fields() const noexcept { return _designated_fields; } [[nodiscard]] uint designated_field(luisa::string_view name) const noexcept; [[nodiscard]] luisa::string dump() const noexcept; diff --git a/include/luisa/coro/coro_frame_smem.h b/include/luisa/coro/coro_frame_smem.h index 42159b6a4..c23057014 100644 --- a/include/luisa/coro/coro_frame_smem.h +++ b/include/luisa/coro/coro_frame_smem.h @@ -4,17 +4,24 @@ #pragma once +#include #include #include namespace luisa::compute { +namespace detail { +[[noreturn]] LC_CORO_API void error_coro_frame_smem_subscript_on_soa() noexcept; +}// namespace detail + template<> class Shared { private: luisa::shared_ptr _desc; luisa::vector _expressions; + // we need to return a `CoroFrame &` for operator[] with proper lifetime + luisa::vector> _temp_frames; size_t _size; private: @@ -39,8 +46,9 @@ class Shared { } public: - Shared(luisa::shared_ptr desc, size_t n, - bool soa = true, luisa::span soa_excluded_fields = {}) noexcept + Shared(luisa::shared_ptr desc, + size_t n, bool soa = false, + luisa::span soa_excluded_fields = {}) noexcept : _desc{std::move(desc)}, _size{n} { _create(soa, soa_excluded_fields); } Shared(Shared &&) noexcept = default; @@ -52,10 +60,11 @@ class Shared { [[nodiscard]] auto is_soa() const noexcept { return _expressions.size() > 1; } [[nodiscard]] auto size() const noexcept { return _size; } -public: +private: /// Read index with active fields template - [[nodiscard]] auto read(I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { + requires is_integral_expr_v + [[nodiscard]] auto _read(I &&index, luisa::optional> active_fields) const noexcept { auto i = def(std::forward(index)); auto fb = detail::FunctionBuilder::current(); auto frame = fb->local(_desc->type()); @@ -71,7 +80,7 @@ class Shared { if (m == 0u) { auto t = Type::of(); auto s = fb->access(Type::array(t, 3u), _expressions[m], i.expression()); - std::array elems; + std::array elems{}; elems[0] = fb->access(t, s, fb->literal(t, 0u)); elems[1] = fb->access(t, s, fb->literal(t, 1u)); elems[2] = fb->access(t, s, fb->literal(t, 2u)); @@ -88,7 +97,8 @@ class Shared { /// Write index with active fields template - void write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields = luisa::nullopt) const noexcept { + requires is_integral_expr_v + void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields) const noexcept { auto i = def(std::forward(index)); auto fb = detail::FunctionBuilder::current(); if (!is_soa()) { @@ -115,6 +125,41 @@ class Shared { } } } + +public: + /// Reference to the i-th element + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame &operator[](I &&index) noexcept { + if (is_soa()) { detail::error_coro_frame_smem_subscript_on_soa(); } + auto fb = detail::FunctionBuilder::current(); + auto ref = fb->access( + _desc->type(), _expressions[0], + detail::extract_expression(std::forward(index))); + auto temp_frame = luisa::make_unique(_desc, ref); + return *_temp_frames.emplace_back(std::move(temp_frame)); + } + + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame read(I &&index) const noexcept { + return _read(std::forward(index), luisa::nullopt); + } + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame read(I &&index, luisa::span active_fields) const noexcept { + return _read(std::forward(index), luisa::make_optional(active_fields)); + } + template + requires is_integral_expr_v + void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { + _write(std::forward(index), frame, luisa::nullopt); + } + template + requires is_integral_expr_v + void write(I &&index, const coroutine::CoroFrame &frame, luisa::span active_fields) const noexcept { + _write(std::forward(index), frame, luisa::make_optional(active_fields)); + } }; }// namespace luisa::compute diff --git a/include/luisa/coro/coro_func.h b/include/luisa/coro/coro_func.h index fbb820eb8..e3d22d4b0 100644 --- a/include/luisa/coro/coro_func.h +++ b/include/luisa/coro/coro_func.h @@ -30,13 +30,9 @@ class LC_CORO_API CoroAwaiter : public concepts::Noncopyable { using Await = luisa::move_only_function; Await _await; -private: - template - friend class Coroutine; +public: explicit CoroAwaiter(Await await) noexcept : _await{std::move(await)} {} - -public: void await() && noexcept { _await(); } }; @@ -165,11 +161,10 @@ class GeneratorIter : public concepts::Noncopyable { using Resume = luisa::move_only_function; Resume _resume; -private: - template - friend class Generator; +public: GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept - : _n{n}, _frame{std::move(frame)}, _resume{std::move(resume)} {} + : _n{n}, _frame{std::move(frame)}, + _resume{std::move(resume)} {} public: auto &update() noexcept { diff --git a/include/luisa/coro/schedulers/persistent_threads.h b/include/luisa/coro/schedulers/persistent_threads.h index 3972b77fc..e0c647980 100644 --- a/include/luisa/coro/schedulers/persistent_threads.h +++ b/include/luisa/coro/schedulers/persistent_threads.h @@ -15,7 +15,7 @@ struct PersistentThreadsCoroSchedulerConfig { uint thread_count = 64_k; uint block_size = 128; uint fetch_size = 16; - bool shared_memory_soa = true; + bool shared_memory_soa = false; bool global_memory_ext = false; }; diff --git a/src/coro/CMakeLists.txt b/src/coro/CMakeLists.txt index 86f1c14e6..5743d6d23 100644 --- a/src/coro/CMakeLists.txt +++ b/src/coro/CMakeLists.txt @@ -2,6 +2,7 @@ set(LUISA_COMPUTE_CORO_SOURCES coro_frame.cpp coro_frame_buffer.cpp coro_frame_desc.cpp + coro_frame_smem.cpp coro_func.cpp coro_graph.cpp schedulers/persistent_threads.cpp diff --git a/src/coro/coro_frame.cpp b/src/coro/coro_frame.cpp index 3142f9d85..d23aeca5a 100644 --- a/src/coro/coro_frame.cpp +++ b/src/coro/coro_frame.cpp @@ -9,7 +9,7 @@ namespace luisa::compute::coroutine { -CoroFrame::CoroFrame(luisa::shared_ptr desc, const RefExpr *expr) noexcept +CoroFrame::CoroFrame(luisa::shared_ptr desc, const Expression *expr) noexcept : _desc{std::move(desc)}, _expression{expr}, coro_id{this->get(0u)}, diff --git a/src/coro/coro_frame_desc.cpp b/src/coro/coro_frame_desc.cpp index b4b1d52f4..c9275a47f 100644 --- a/src/coro/coro_frame_desc.cpp +++ b/src/coro/coro_frame_desc.cpp @@ -28,6 +28,12 @@ luisa::shared_ptr CoroFrameDesc::create(const Type *type, Designa return luisa::make_shared(CoroFrameDesc{type, std::move(m)}); } +const Type *CoroFrameDesc::field(uint index) const noexcept { + auto fields = _type->members(); + LUISA_ASSERT(index < fields.size(), "CoroFrame field index out of range."); + return fields[index]; +} + uint CoroFrameDesc::designated_field(luisa::string_view name) const noexcept { if (name == "coro_id") { return 0u; } if (name == "target_token") { return 1u; } diff --git a/src/coro/coro_frame_smem.cpp b/src/coro/coro_frame_smem.cpp new file mode 100644 index 000000000..3a390d0e3 --- /dev/null +++ b/src/coro/coro_frame_smem.cpp @@ -0,0 +1,7 @@ +#include + +namespace luisa::compute::detail { +void error_coro_frame_smem_subscript_on_soa() noexcept { + LUISA_ERROR_WITH_LOCATION("Subscripting on SOA shared frame is not allowed."); +} +}// namespace luisa::compute::detail diff --git a/src/coro/coro_func.cpp b/src/coro/coro_func.cpp index c8fa05541..fe05e74c0 100644 --- a/src/coro/coro_func.cpp +++ b/src/coro/coro_func.cpp @@ -23,7 +23,7 @@ void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, }; } -inline void coroutine_generator_next_impl( +void coroutine_generator_next_impl( CoroFrame &frame, uint node_count, const luisa::move_only_function &resume) noexcept { $switch (frame.target_token) { diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index 863287c32..9b7ac37c0 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -29,7 +29,11 @@ void persistent_threads_coro_scheduler_main_kernel_impl( for (auto index : dsl::dynamic_range(q_fac)) { auto s = index * config.block_size + thread_x(); all_token[s] = 0u; - // frames.write(s, coro.instantiate(), std::array{0u, 1u}); + if (config.shared_memory_soa) { + frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{1u}); + } else { + frames[s].target_token = 0u; + } } for (auto index : dsl::dynamic_range(g_fac)) { auto s = index * config.block_size + thread_x(); @@ -120,23 +124,37 @@ void persistent_threads_coro_scheduler_main_kernel_impl( auto frame_token = all_token[dst]; $if (coro_token != 0u) { $if (frame_token != 0u) { - auto g_state = global_frames->read(global_id); - global_frames->write(global_id, frames.read(dst)); - frames.write(dst, g_state); + if (config.shared_memory_soa) { + auto g_state = global_frames->read(global_id); + global_frames->write(global_id, frames.read(dst)); + frames.write(dst, g_state); + } else { + auto g_state = global_frames->read(global_id); + global_frames->write(global_id, frames[dst]); + frames[dst] = g_state; + } all_token[shared_queue_size + g_queue_id] = frame_token; all_token[dst] = coro_token; } $else { - auto g_state = global_frames->read(global_id); - frames.write(dst, g_state); + if (config.shared_memory_soa) { + auto g_state = global_frames->read(global_id); + frames.write(dst, g_state); + } else { + frames[dst] = global_frames->read(global_id); + } all_token[shared_queue_size + g_queue_id] = frame_token; all_token[dst] = coro_token; }; } $else { $if (frame_token != 0u) { - auto frame = frames.read(dst); - global_frames->write(global_id, frame); + if (config.shared_memory_soa) { + auto frame = frames.read(dst); + global_frames->write(global_id, frame); + } else { + global_frames->write(global_id, frames[dst]); + } all_token[shared_queue_size + g_queue_id] = frame_token; all_token[dst] = coro_token; }; @@ -183,10 +201,16 @@ void persistent_threads_coro_scheduler_main_kernel_impl( for (auto i = 1u; i < subroutine_count; i++) { $case (i) { work_counter.atomic(i).fetch_sub(1u); - auto frame = frames.read(pid, graph->node(i).input_fields()); - call_subroutine(frame, i); - auto next = frame.target_token & coro_token_valid_mask; - frames.write(pid, frame, graph->node(i).output_fields()); + auto next = def(0u); + if (config.shared_memory_soa) { + auto frame = frames.read(pid, graph->node(i).input_fields()); + call_subroutine(frame, i); + next = frame.target_token & coro_token_valid_mask; + frames.write(pid, frame, graph->node(i).output_fields()); + } else { + call_subroutine(frames[pid], i); + next = frames[pid].target_token & coro_token_valid_mask; + } all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); }; From b2c36209a91e14a990532343b74b16b2df2f4a05 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Thu, 16 May 2024 14:40:25 +0800 Subject: [PATCH 64/67] remove --- scripts/print_hlsl_builtin.lua | 2 +- src/backends/common/hlsl/builtin/coroutine | 9 --------- src/backends/common/hlsl/builtin/coroutine.c | 4 ---- src/backends/common/hlsl/hlsl_codegen_util.cpp | 4 ---- 4 files changed, 1 insertion(+), 18 deletions(-) delete mode 100644 src/backends/common/hlsl/builtin/coroutine delete mode 100644 src/backends/common/hlsl/builtin/coroutine.c diff --git a/scripts/print_hlsl_builtin.lua b/scripts/print_hlsl_builtin.lua index d6f6b0932..6688ef65e 100644 --- a/scripts/print_hlsl_builtin.lua +++ b/scripts/print_hlsl_builtin.lua @@ -5,7 +5,7 @@ local files_list = {'accel_process', 'bindless_upload', 'bc6_encode_block', 'bc6 'bc6_trymode_le10cs', 'bc7_encode_block', 'bc7_header', 'bc7_trymode_02cs', 'bc7_trymode_137cs', 'bc7_trymode_456cs', 'hlsl_header', 'raytracing_header', 'tex2d_bindless', 'tex3d_bindless', 'compute_quad', 'determinant', 'inverse', 'indirect', 'resource_size', 'accel_header', 'copy_sign', - 'bindless_common', 'auto_diff', "reduce", "coroutine"} + 'bindless_common', 'auto_diff', "reduce"} local lib = import("lib") local hlsl_builtin_path = path.join(os.projectdir(), "src/backends/common/hlsl/builtin") diff --git a/src/backends/common/hlsl/builtin/coroutine b/src/backends/common/hlsl/builtin/coroutine deleted file mode 100644 index e9cc25f62..000000000 --- a/src/backends/common/hlsl/builtin/coroutine +++ /dev/null @@ -1,9 +0,0 @@ -template -uint3 CoroId(in T frame){ -return frame.v0.v; -} - -template -uint CoroToken(in T frame){ -return frame.v1; -} \ No newline at end of file diff --git a/src/backends/common/hlsl/builtin/coroutine.c b/src/backends/common/hlsl/builtin/coroutine.c deleted file mode 100644 index 6f4494609..000000000 --- a/src/backends/common/hlsl/builtin/coroutine.c +++ /dev/null @@ -1,4 +0,0 @@ -#include "hlsl_config.h" -LC_HLSL_EXTERN char coroutine[]={116,101,109,112,108,97,116,101,60,116,121,112,101,110,97,109,101,32,84,62,10,117,105,110,116,51,32,67,111,114,111,73,100,40,105,110,32,84,32,102,114,97,109,101,41,123,10,114,101,116,117,114,110,32,102,114,97,109,101,46,118,48,46,118,59,10,125,10,10,116,101,109,112,108,97,116,101,60,116,121,112,101,110,97,109,101,32,84,62,10,117,105,110,116,32,67,111,114,111,84,111,107,101,110,40,105,110,32,84,32,102,114,97,109,101,41,123,10,114,101,116,117,114,110,32,102,114,97,109,101,46,118,49,59,10,125}; -LC_HLSL_EXTERN char *get_coroutine(){return coroutine;} -LC_HLSL_EXTERN int get_coroutine_size(){return 136;} diff --git a/src/backends/common/hlsl/hlsl_codegen_util.cpp b/src/backends/common/hlsl/hlsl_codegen_util.cpp index 95757f6c7..5cba86d17 100644 --- a/src/backends/common/hlsl/hlsl_codegen_util.cpp +++ b/src/backends/common/hlsl/hlsl_codegen_util.cpp @@ -117,10 +117,6 @@ static size_t AddHeader(CallOpSet const &ops, vstd::StringBuilder &builder, bool ops.test(CallOp::MATRIX_COMPONENT_WISE_MULTIPLICATION)) { builder << CodegenUtility::ReadInternalHLSLFile("reduce"); } - if (ops.test(CallOp::CORO_ID) || - ops.test(CallOp::CORO_TOKEN)) { - builder << CodegenUtility::ReadInternalHLSLFile("coroutine"); - } return immutable_size; } }// namespace detail From b0d90b8e078d6ff51fb4e4ed97b55612852bf0e4 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Thu, 16 May 2024 16:41:47 +0800 Subject: [PATCH 65/67] trying to optimize dx --- .../coro/schedulers/persistent_threads.h | 21 ++-- src/coro/schedulers/persistent_threads.cpp | 109 ++++++++---------- 2 files changed, 62 insertions(+), 68 deletions(-) diff --git a/include/luisa/coro/schedulers/persistent_threads.h b/include/luisa/coro/schedulers/persistent_threads.h index e0c647980..7d777e2b4 100644 --- a/include/luisa/coro/schedulers/persistent_threads.h +++ b/include/luisa/coro/schedulers/persistent_threads.h @@ -14,7 +14,7 @@ namespace luisa::compute::coroutine { struct PersistentThreadsCoroSchedulerConfig { uint thread_count = 64_k; uint block_size = 128; - uint fetch_size = 16; + uint fetch_size = 4; bool shared_memory_soa = false; bool global_memory_ext = false; }; @@ -23,7 +23,7 @@ namespace detail { LC_CORO_API void persistent_threads_coro_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, - const CoroGraph *graph, Shared &frames, Expr dispatch_shape, + const CoroGraph *graph, Shared &frames, Expr dispatch_size_prefix_product, Expr> global, const Buffer &global_frames, luisa::move_only_function call_subroutine) noexcept; }// namespace detail @@ -52,7 +52,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto global_ext_size = _config.thread_count * g_fac; _global_frames = device.create_buffer(coro.shared_frame(), global_ext_size); } - Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_shape, Var... args) noexcept { + Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_size_prefix_product, Var... args) noexcept { set_block_size(_config.block_size, 1u, 1u); auto global_queue_size = _config.block_size * g_fac; auto shared_queue_size = _config.block_size * q_fac; @@ -60,7 +60,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { Shared frames{graph->shared_frame(), shared_queue_size, _config.shared_memory_soa}; detail::persistent_threads_coro_scheduler_main_kernel_impl( _config, q_fac, g_fac, shared_queue_size, global_queue_size, graph, - frames, dispatch_shape, global, _global_frames, call_subroutine); + frames, dispatch_size_prefix_product, global, _global_frames, call_subroutine); }; _pt_shader = device.compile(main_kernel); _clear_shader = device.compile<1>([](BufferUInt global) { @@ -78,12 +78,19 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { - stream << _clear_shader(_global).dispatch(1u); + auto dispatch_size_prefix_product = make_uint3( + dispatch_size.x, + dispatch_size.x * dispatch_size.y, + dispatch_size.x * dispatch_size.y * dispatch_size.z); if (_config.global_memory_ext) { auto n = static_cast(_global_frames.size()); - stream << _initialize_shader(n).dispatch(n); + stream << _clear_shader(_global).dispatch(1u) + << _initialize_shader(n).dispatch(n) + << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count); + } else { + stream << _clear_shader(_global).dispatch(1u) + << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count); } - stream << _pt_shader(_global, dispatch_size, args...).dispatch(_config.thread_count); } public: diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index 9b7ac37c0..81302e9b1 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -13,7 +13,7 @@ namespace luisa::compute::coroutine::detail { void persistent_threads_coro_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, - const CoroGraph *graph, Shared &frames, Expr dispatch_shape, + const CoroGraph *graph, Shared &frames, Expr dispatch_size_prefix_product, Expr> global, const Buffer &global_frames, luisa::move_only_function call_subroutine) noexcept { @@ -26,16 +26,17 @@ void persistent_threads_coro_scheduler_main_kernel_impl( shared_queue_size}; Shared workload{2}; Shared work_stat{2};// work_state[0] for max_count, [1] for max_id - for (auto index : dsl::dynamic_range(q_fac)) { + for (auto index = 0u; index < q_fac; index++) { auto s = index * config.block_size + thread_x(); all_token[s] = 0u; if (config.shared_memory_soa) { - frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{1u}); + frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{0u, 1u}); } else { + frames[s].coro_id = make_uint3(); frames[s].target_token = 0u; } } - for (auto index : dsl::dynamic_range(g_fac)) { + for (auto index = 0u; index < g_fac; index++) { auto s = index * config.block_size + thread_x(); all_token[shared_queue_size + s] = 0u; } @@ -58,8 +59,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( rem_local[0] = 0u; sync_block(); auto count = def(0u); - auto count_limit = def(-1); - auto dispatch_size = dispatch_shape.x * dispatch_shape.y * dispatch_shape.z; + auto count_limit = def(std::numeric_limits::max()); $while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) { sync_block();//very important, synchronize for condition rem_local[0] = 0u; @@ -70,8 +70,8 @@ void persistent_threads_coro_scheduler_main_kernel_impl( $if (thread_x() == config.block_size - 1) { $if (workload[0] >= workload[1] & rem_global[0] == 1u) {//fetch new workload workload[0] = global.atomic(0u).fetch_add(config.block_size * config.fetch_size); - workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size); - $if (workload[0] >= dispatch_size) { + workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size_prefix_product.z); + $if (workload[0] >= dispatch_size_prefix_product.z) { rem_global[0] = 0u; }; }; @@ -96,7 +96,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( work_offset[1] = 0; sync_block(); if (!config.global_memory_ext) { - for (auto index : dsl::dynamic_range(q_fac)) {//collect indices + for (auto index = 0u; index < q_fac; index++) {//collect indices auto frame_token = all_token[index * config.block_size + thread_x()]; $if (frame_token == work_stat[1]) { auto id = work_offset.atomic(0).fetch_add(1u); @@ -104,7 +104,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( }; } } else { - for (auto index : dsl::dynamic_range(q_fac)) {//collect switch out indices + for (auto index = 0u; index < q_fac; index++) {//collect switch out indices auto frame_token = all_token[index * config.block_size + thread_x()]; $if (frame_token != work_stat[1]) { auto id = work_offset.atomic(0).fetch_add(1u); @@ -113,7 +113,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( } sync_block(); $if (shared_queue_size - work_offset[0] < config.block_size) {//no enough work - for (auto index : dsl::dynamic_range(g_fac)) { //swap frames + $for (index, 0u, g_fac) { //swap frames auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x(); auto g_queue_id = index * config.block_size + thread_x(); auto coro_token = all_token[shared_queue_size + g_queue_id]; @@ -123,76 +123,62 @@ void persistent_threads_coro_scheduler_main_kernel_impl( auto dst = path_id[id]; auto frame_token = all_token[dst]; $if (coro_token != 0u) { + auto g_state = global_frames->read(global_id); $if (frame_token != 0u) { if (config.shared_memory_soa) { - auto g_state = global_frames->read(global_id); global_frames->write(global_id, frames.read(dst)); - frames.write(dst, g_state); } else { - auto g_state = global_frames->read(global_id); global_frames->write(global_id, frames[dst]); - frames[dst] = g_state; } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - } - $else { - if (config.shared_memory_soa) { - auto g_state = global_frames->read(global_id); - frames.write(dst, g_state); - } else { - frames[dst] = global_frames->read(global_id); - } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; }; + if (config.shared_memory_soa) { + frames.write(dst, g_state); + } else { + frames[dst] = g_state; + } + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; } - $else { - $if (frame_token != 0u) { - if (config.shared_memory_soa) { - auto frame = frames.read(dst); - global_frames->write(global_id, frame); - } else { - global_frames->write(global_id, frames[dst]); - } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - }; + $elif (frame_token != 0u) { + if (config.shared_memory_soa) { + auto frame = frames.read(dst); + global_frames->write(global_id, frame); + } else { + global_frames->write(global_id, frames[dst]); + } + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; }; }; }; - } + }; }; } auto gen_st = workload[0]; sync_block(); - auto pid = def(0u); - if (config.global_memory_ext) { - pid = thread_x(); - } else { - pid = path_id[thread_x()]; - } - auto launch_condition = def(true); - if (!config.global_memory_ext) { - launch_condition = (thread_x() < work_offset[0]); - } else { - launch_condition = (all_token[pid] == work_stat[1]); - } - $if (launch_condition) { + auto pid = config.global_memory_ext ? thread_x() : path_id[thread_x()]; + $if (config.global_memory_ext ? (all_token[pid] == work_stat[1]) : (thread_x() < work_offset[0])) { $switch (all_token[pid]) { $case (0u) { $if (gen_st + thread_x() < workload[1]) { work_counter.atomic(0u).fetch_sub(1u); auto global_index = gen_st + thread_x(); - auto image_size = dispatch_shape.x * dispatch_shape.y; - auto index_z = global_index / image_size; - auto index_xy = global_index % image_size; - auto index_x = index_xy % dispatch_shape.x; - auto index_y = index_xy / dispatch_shape.x; - auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); - call_subroutine(frame, coro_token_entry); - auto next = frame.target_token & coro_token_valid_mask; - frames.write(pid, frame, graph->entry().output_fields()); + auto index_z = global_index / dispatch_size_prefix_product.y; + auto index_xy = global_index % dispatch_size_prefix_product.y; + auto index_x = index_xy % dispatch_size_prefix_product.x; + auto index_y = index_xy / dispatch_size_prefix_product.x; + auto next = def(0u); + if (config.shared_memory_soa) { + auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); + call_subroutine(frame, coro_token_entry); + next = frame.target_token & coro_token_valid_mask; + frames.write(pid, frame, graph->entry().output_fields()); + } else { + frames[pid].coro_id = make_uint3(index_x, index_y, index_z); + frames[pid].target_token = coro_token_entry; + call_subroutine(frames[pid], coro_token_entry); + next = frames[pid].target_token & coro_token_valid_mask; + } all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); workload.atomic(0).fetch_add(1u); @@ -215,6 +201,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( work_counter.atomic(next).fetch_add(1u); }; } + $default { dsl::unreachable(); }; }; }; sync_block(); From 91b47bad9e4896e2ffdf7d8ec7fe43cfe9943d7f Mon Sep 17 00:00:00 2001 From: chenxin Date: Fri, 17 May 2024 16:01:46 +0800 Subject: [PATCH 66/67] fix namespace --- include/luisa/coro/coro_frame.h | 2 +- include/luisa/coro/coro_token.h | 2 +- src/coro/coro_func.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/luisa/coro/coro_frame.h b/include/luisa/coro/coro_frame.h index bac5580f8..d83d5c9d6 100644 --- a/include/luisa/coro/coro_frame.h +++ b/include/luisa/coro/coro_frame.h @@ -61,4 +61,4 @@ class LC_CORO_API CoroFrame { [[nodiscard]] Var is_terminated() const noexcept; }; -}// namespace luisa::compute::coro_v2 +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/coro_token.h b/include/luisa/coro/coro_token.h index 15032d4e1..be5883064 100644 --- a/include/luisa/coro/coro_token.h +++ b/include/luisa/coro/coro_token.h @@ -9,4 +9,4 @@ using CoroToken = unsigned int; constexpr CoroToken coro_token_entry = 0u; constexpr CoroToken coro_token_terminal = 0x8000'0000u; constexpr CoroToken coro_token_valid_mask = 0x7fff'ffffu; -}// namespace luisa::compute::coro_v2 +}// namespace luisa::compute::coroutine diff --git a/src/coro/coro_func.cpp b/src/coro/coro_func.cpp index 966719c01..68d5a496d 100644 --- a/src/coro/coro_func.cpp +++ b/src/coro/coro_func.cpp @@ -37,4 +37,4 @@ void coroutine_generator_step_impl(CoroFrame &frame, uint node_count, bool is_en } } -}// namespace luisa::compute::coro_v2::detail +}// namespace luisa::compute::coroutine::detail From 66883df60238fa64c4bf65eb1cb815bf1fd612b9 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Sat, 18 May 2024 15:57:30 +0800 Subject: [PATCH 67/67] minor --- include/luisa/coro/coro_graph.h | 2 +- include/luisa/coro/radix_sort.h | 4 ++-- src/coro/coro_frame_desc.cpp | 2 +- src/coro/schedulers/persistent_threads.cpp | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/luisa/coro/coro_graph.h b/include/luisa/coro/coro_graph.h index 99a26f066..0c5a37fa2 100644 --- a/include/luisa/coro/coro_graph.h +++ b/include/luisa/coro/coro_graph.h @@ -69,4 +69,4 @@ class LC_CORO_API CoroGraph { [[nodiscard]] luisa::string dump() const noexcept; }; -}// namespace luisa::compute::co +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/radix_sort.h b/include/luisa/coro/radix_sort.h index 0072123ac..f9ea72057 100644 --- a/include/luisa/coro/radix_sort.h +++ b/include/luisa/coro/radix_sort.h @@ -88,7 +88,7 @@ class instance { ///@param high_bit: highest bit of radix sort instance(Device device, uint maxn, temp_storage &temp, Callable *get_key, Callable *get_val, - Callable *get_key_set = nullptr, + Callable *get_key_set = nullptr, uint mode = 0, uint digit = 128, uint low_bit = 0, uint high_bit = 31) : DIGIT{digit}, low_bit{low_bit}, high_bit{high_bit}, MAXN{maxn}, _temp{temp} { LUISA_ASSERT(mode == 0 || mode == 1, "mode should be 0 for radix sort and 1 for bucket sort!"); BIT = 0; @@ -361,4 +361,4 @@ class instance { } }; } -} \ No newline at end of file +}// namespace luisa::compute::radix_sort \ No newline at end of file diff --git a/src/coro/coro_frame_desc.cpp b/src/coro/coro_frame_desc.cpp index c9275a47f..97fa9fcc4 100644 --- a/src/coro/coro_frame_desc.cpp +++ b/src/coro/coro_frame_desc.cpp @@ -56,4 +56,4 @@ luisa::string CoroFrameDesc::dump() const noexcept { return s; } -}// namespace luisa::compute::co +}// namespace luisa::compute::coroutine diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index 81302e9b1..5c4f42a8d 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -113,7 +113,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( } sync_block(); $if (shared_queue_size - work_offset[0] < config.block_size) {//no enough work - $for (index, 0u, g_fac) { //swap frames + $for (index, 0u, g_fac) { //swap frames auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x(); auto g_queue_id = index * config.block_size + thread_x(); auto coro_token = all_token[shared_queue_size + g_queue_id];