diff --git a/.clang-format b/.clang-format index 125fe989b..714315436 100644 --- a/.clang-format +++ b/.clang-format @@ -109,7 +109,6 @@ ForEachMacros: - LUISA_STRUCT - LUISA_BINDING_GROUP - LUISA_BINDING_GROUP_TEMPLATE - - LUISA_COROFRAME_STRUCT IfMacros: - $if - $elif diff --git a/include/luisa/ast/function_builder.h b/include/luisa/ast/function_builder.h index a1e89e1b8..a0fd25297 100644 --- a/include/luisa/ast/function_builder.h +++ b/include/luisa/ast/function_builder.h @@ -242,8 +242,6 @@ class LC_AST_API FunctionBuilder : public luisa::enable_shared_from_this &coro_tokens() const noexcept { return _coro_tokens; } - void coroframe_replace(const Type *type) noexcept; - // build primitives /// Define a kernel function with given definition template static auto define_kernel(Def &&def) { @@ -407,12 +405,8 @@ class LC_AST_API FunctionBuilder : public luisa::enable_shared_from_this::value; template struct is_custom_struct : std::false_type {}; -template -struct is_coroframe_struct : std::false_type {}; - template constexpr auto is_custom_struct_v = is_custom_struct::value; -template -constexpr auto is_coroframe_struct_v = is_coroframe_struct::value; - namespace detail { template @@ -316,8 +310,6 @@ class LC_AST_API Type { TEXTURE, BINDLESS_ARRAY, ACCEL, - - COROFRAME, CUSTOM }; @@ -379,9 +371,6 @@ class LC_AST_API Type { /// Return custom type with the specified name [[nodiscard]] static const Type *custom(luisa::string_view name) noexcept; - /// Return custom type for coroframe - [[nodiscard]] static const Type *coroframe(luisa::string_view name) noexcept; - /// Construct Type object from description /// @param description Type description in the following syntax: \n /// TYPE := DATA | RESOURCE | CUSTOM \n @@ -427,14 +416,6 @@ class LC_AST_API Type { [[nodiscard]] luisa::span members() const noexcept; [[nodiscard]] luisa::span member_attributes() const noexcept; [[nodiscard]] const Type *element() const noexcept; - [[nodiscard]] const Type *corotype() const noexcept; - ///change the corresponding Type to another Type in registry - void update_from(const Type *type); - /// add a member with name and type - size_t add_member(const luisa::string &name) noexcept; - void set_member_name(size_t index, luisa::string name) noexcept; - /// get member index with string - [[nodiscard]] size_t member(luisa::string_view name) const noexcept; /// Scalar = bool || float || int || uint [[nodiscard]] bool is_scalar() const noexcept; [[nodiscard]] bool is_bool() const noexcept; @@ -470,8 +451,6 @@ class LC_AST_API Type { [[nodiscard]] bool is_accel() const noexcept; [[nodiscard]] bool is_resource() const noexcept; [[nodiscard]] bool is_custom() const noexcept; - [[nodiscard]] bool is_coroframe() const noexcept; - [[nodiscard]] bool is_materialized_coroframe() const noexcept; }; }// namespace luisa::compute diff --git a/include/luisa/ast/type_registry.h b/include/luisa/ast/type_registry.h index d6f820859..b5f7976b8 100644 --- a/include/luisa/ast/type_registry.h +++ b/include/luisa/ast/type_registry.h @@ -256,9 +256,6 @@ const Type *Type::of() noexcept { if constexpr (is_custom_struct_v) { static thread_local auto t = Type::custom(desc); return t; - } else if constexpr (is_coroframe_struct_v) { - static thread_local auto t = Type::coroframe(desc); - return t; } else { static thread_local auto t = Type::from(desc); return t; @@ -275,8 +272,6 @@ template struct is_valid_reflection, std::integer_sequence> { static_assert(alignof(S) >= 4u, "Structs must be aligned to at least 4 bytes."); - static_assert(std::negation_v...>>, - "Structs cannot contain CoroFrame Type"); private: [[nodiscard]] constexpr static auto _check() noexcept { @@ -369,17 +364,3 @@ constexpr auto is_valid_reflection_v = is_valid_reflection::value; return name; \ } \ }; - -#define LUISA_COROFRAME_STRUCT_REFLECT(S, name) \ - template<> \ - struct canonical_layout { \ - using type = std::tuple; \ - }; \ - template<> \ - struct luisa::compute::is_coroframe_struct : std::true_type {}; \ - template<> \ - struct luisa::compute::detail::TypeDesc { \ - static constexpr luisa::string_view description() noexcept { \ - return name; \ - } \ - }; diff --git a/include/luisa/coro/coro_dispatcher.h b/include/luisa/coro/coro_dispatcher.h deleted file mode 100644 index 89789a8bf..000000000 --- a/include/luisa/coro/coro_dispatcher.h +++ /dev/null @@ -1,979 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -namespace luisa::compute { - -class Stream; -inline namespace coro { -template -struct prototype_to_coro_dispatcher { - using type = T; -}; - -template -struct prototype_to_coro_dispatcher> { - using type = BufferView; -}; - -template -struct prototype_to_coro_dispatcher> { - using type = ImageView; -}; - -template -struct prototype_to_coro_dispatcher> { - using type = VolumeView; -}; - -template -struct prototype_to_coro_dispatcher> { - using type = SOAView; -}; - -template -using prototype_to_coro_dispatcher_t = typename prototype_to_coro_dispatcher::type; - -template -class CoroAwait; -template -class CoroDispatcherBase { - static_assert(always_false_v); -}; -const uint token_mask = 0x7fffffff; -template -class CoroDispatcherBase : public concepts::Noncopyable { - using FuncType = void(FrameRef, Args...); - friend class CoroAwait; -public: - - luisa::queue> _dispatcher; -protected: - std::tuple...> _args; - Coroutine *_coro; - virtual void _await_step(Stream &stream) noexcept = 0; - virtual void _await_all(Stream &stream) noexcept; - uint _dispatch_size; - Device _device; - ///helper function for calling a shader by omitting suffix coroutine args - template - detail::ShaderInvoke call_shader(Shader &shader, prototype_to_coro_dispatcher_t... prefix_args); - -protected: -public: - CoroDispatcherBase(Coroutine *coro_ptr, Device &device) noexcept - : _coro{std::move(coro_ptr)}, _device{device}, _dispatch_size{} {} - virtual ~CoroDispatcherBase() noexcept = default; - [[nodiscard]] virtual bool all_dispatched() const noexcept = 0; - [[nodiscard]] virtual bool all_done() const noexcept = 0; - [[nodiscard]] CoroAwait await_step() noexcept; - [[nodiscard]] CoroAwait await_all() noexcept; - virtual void operator()(prototype_to_coro_dispatcher_t... args, uint dispatch_size) noexcept {//initialize the dispatcher - _args = std::make_tuple(std::forward>(args)...); - _dispatch_size = dispatch_size; - } -}; - -template -class SimpleCoroDispatcher : public CoroDispatcherBase { -private: - using FrameType = std::remove_reference_t; - Shader1D _shader; - bool _done; - void _await_step(Stream &stream) noexcept override { - stream << this->template call_shader<1, uint>(_shader, this->_dispatch_size).dispatch(this->_dispatch_size); - _done = true; - stream << synchronize(); - } -public: - SimpleCoroDispatcher(Coroutine *coroutine, Device &device, - uint max_frame_count) : CoroDispatcherBase{coroutine, device}, - _done{false} { - uint max_sub_coro = coroutine->suspend_count() + 1; - Kernel1D run_kernel = [&](UInt n, Var... args) { - set_block_size(128, 1, 1); - auto x = dispatch_x(); - $if (x < n) { - Var frame; - initialize_coroframe(frame, def(x, 0, 0)); - (*coroutine)(frame, args...); - $loop { - auto token = read_promise(frame, "coro_token"); - $if (token == 0x8000'0000) { - $break; - }; - $switch (token) { - for (auto i = 1u; i < max_sub_coro; ++i) { - $case (i) { - (*coroutine)[i](frame, args...); - }; - } - }; - }; - }; - }; - _shader = device.compile(run_kernel); - } - void operator()(prototype_to_coro_dispatcher_t... args, uint dispatch_size) noexcept { - CoroDispatcherBase::operator()(args..., dispatch_size); - _done = false; - } - bool all_dispatched() const noexcept override { - return _done; - } - bool all_done() const noexcept override { - return _done; - } -}; - -struct WavefrontCoroDispatcherConfig { - uint max_instance_count = 2_M; - bool soa = true; - bool sort = true;//use sort for coro token gathering - bool compact = true; - bool debug = false; - uint hint_range=0xffff'ffff; - luisa::vector hint_fields; -}; - -template -class WavefrontCoroDispatcher : public CoroDispatcherBase { - -public: - using Config = WavefrontCoroDispatcherConfig; - -private: - using FrameType = std::remove_reference_t; - using Container = compute::SOA; - bool is_soa = true; - bool sort_base_gather = false; - bool use_compact = true; - Shader1D, Buffer, uint, Container, uint, uint, Args...> _gen_shader; - luisa::vector, Buffer, Container, uint, Args...>> _resume_shaders; - Shader1D, Buffer, uint> _count_prefix_shader; - Shader1D, Buffer, Container, uint> _gather_shader; - Shader1D, Container, uint> _initialize_shader; - Shader1D, Container, uint, uint> _compact_shader; - Shader1D, uint> _clear_shader; - Container _frame; - compute::Buffer _resume_index; - compute::Buffer _resume_count; - ///offset calculate from count, will be end after gathering - compute::Buffer _resume_offset; - compute::Buffer _global_buffer; - compute::Buffer _debug_buffer; - luisa::vector _host_count; - luisa::vector _host_offset; - bool _host_empty; - uint _dispatch_counter; - uint _max_sub_coro; - uint _max_frame_count; - bool _debug; - void _await_step(Stream &stream) noexcept override; - radix_sort::temp_storage _sort_temp_storage; - radix_sort::instance<> _sort_token; - radix_sort::instance> _sort_hint; - luisa::vector _have_hint; - compute::Buffer _temp_key[2]; - compute::Buffer _temp_index; - Stream &_stream; -public: - bool all_dispatched() const noexcept; - bool all_done() const noexcept; - - WavefrontCoroDispatcher(Coroutine *coroutine, - Device &device, Stream &stream, - const WavefrontCoroDispatcherConfig &config) noexcept - : CoroDispatcherBase{coroutine, device}, - is_soa{config.soa}, - sort_base_gather{config.sort}, - use_compact{config.compact}, - _max_frame_count{config.max_instance_count}, - _stream{stream}, _debug{config.debug}, - _frame{device.create_soa(config.max_instance_count)} { - /*if (device.backend_name() != "cuda") {//only cuda can sort - sort_base_gather = false; - hint_token = {}; - LUISA_INFO("Using wavefront dispatcher without cuda, the sorting will be disabled!"); - }*/ - bool use_sort = sort_base_gather||!config.hint_fields.empty(); - uint max_sub_coro = coroutine->suspend_count() + 1; - _max_sub_coro = max_sub_coro; - _resume_index = device.create_buffer(_max_frame_count); - if (use_sort) { - _temp_index = device.create_buffer(_max_frame_count); - _temp_key[0] = device.create_buffer(_max_frame_count); - _temp_key[1] = device.create_buffer(_max_frame_count); - } - _resume_count = device.create_buffer(max_sub_coro); - _resume_offset = device.create_buffer(max_sub_coro); - _global_buffer = device.create_buffer(1); - if (_debug) { - _debug_buffer = device.create_buffer(max_sub_coro); - } - _host_empty = true; - _dispatch_counter = 0; - _host_offset.resize(max_sub_coro); - _host_count.resize(max_sub_coro); - _have_hint.resize(max_sub_coro, false); - for (auto &token : config.hint_fields) { - auto id = coroutine->coro_tokens().find(token); - if (id != coroutine->coro_tokens().end()) { - LUISA_ASSERT(id->secondsecond,max_sub_coro); - _have_hint[id->second] = true; - } - else - LUISA_WARNING("coroutine token {} not found, hint disabled", token); - } - for (auto i = 0u; i < max_sub_coro; i++) { - if (i) { - _host_count[i] = 0; - _host_offset[i] = _max_frame_count; - } else { - _host_count[i] = _max_frame_count; - _host_offset[i] = 0; - } - } - Callable get_coro_token = [&](UInt index) { - $if (index > _max_frame_count) { - device_log("index {} out of range {}", index, _max_frame_count); - }; - auto frame = _frame->read(index); - return read_promise(frame, "coro_token") & token_mask; - }; - Callable identical = [&](UInt index) { - return index; - }; - - Callable keep_index = [&](UInt index,BufferUInt val) { - return val.read(index); - }; - Callable get_coro_hint = [&](UInt index,BufferUInt val) { - if (!config.hint_fields.empty()) { - auto id = keep_index(index,val); - auto frame = _frame->read(id); - auto x=read_promise(frame, "coro_hint"); - return x; - } else { - return def(0u); - } - }; - if (use_sort) { - _sort_temp_storage = radix_sort::temp_storage(device, _max_frame_count, std::max(std::min(config.hint_range,128u), max_sub_coro)); - } - if (sort_base_gather) { - _sort_token = radix_sort::instance<>(device, _max_frame_count, _sort_temp_storage, - &get_coro_token, &identical, &get_coro_token, 1, max_sub_coro); - } - if (!config.hint_fields.empty()) { - if(config.hint_range<=128){ - _sort_hint = radix_sort::instance>(device, _max_frame_count, _sort_temp_storage, - &get_coro_hint, &keep_index, &get_coro_hint, 1, config.hint_range); - } - else{ - auto highbit=0; - while((config.hint_range>>highbit)!=1){ - highbit++; - } - _sort_hint = radix_sort::instance>(device, _max_frame_count, _sort_temp_storage, - &get_coro_hint, &keep_index, &get_coro_hint, 0, 128, 0, highbit); - } - } - - Kernel1D gen_kernel = [&](BufferUInt index, BufferUInt count, UInt offset, Var frame_buffer, UInt st_task_id, UInt n, Var... args) { - auto x = dispatch_x(); - $if (x >= n) { - $return(); - }; - UInt frame_id; - if (!use_compact) { - frame_id = index->read(x); - } else { - frame_id = offset + x; - } - Var frame; - if (_debug) { - frame = frame_buffer.read(frame_id); - $if ((read_promise(frame, "coro_token") & token_mask) != 0u) { - device_log("wrong gen for frame {} at kernel {}when dispatch {}", frame_id, read_promise(frame, "coro_token"), dispatch_x()); - }; - } - initialize_coroframe(frame, def(st_task_id + x, 0, 0)); - if (!sort_base_gather) { - count.atomic(0).fetch_add(-1u); - } - - (*coroutine)(frame, args...); - if (is_soa) { - frame_buffer.write(frame_id, frame, coroutine->graph().node(0u)->output_state_members); - } else { - frame_buffer.write(frame_id, frame); - } - if (!sort_base_gather) { - auto nxt = read_promise(frame, "coro_token") & token_mask; - count.atomic(nxt).fetch_add(1u); - } - }; - ShaderOption o{}; - _gen_shader = device.compile(gen_kernel, o); - _gen_shader.set_name("gen"); - _resume_shaders.resize(max_sub_coro); - for (int i = 1; i < max_sub_coro; ++i) { - Kernel1D resume_kernel = [&](BufferUInt index, BufferUInt count, Var frame_buffer, UInt n, Var... args) { - auto x = dispatch_x(); - $if (x >= n) { - $return(); - }; - auto frame_id = index.read(x); - Var frame; - if (is_soa) { - //frame = frame_buffer.read(frame_id); - frame = frame_buffer.read(frame_id, coroutine->graph().node(i)->input_state_members); - } else { - frame = frame_buffer.read(frame_id); - } - if (!sort_base_gather) { - count.atomic(i).fetch_add(-1u); - } - if (_debug) { - auto token = frame_buffer.read(frame_id); - auto check = read_promise(token, "coro_token") & token_mask; - /*$if (check != i) { - device_log("wrong launch for frame {} at kernel {} as kernel {} when dispatch {}", frame_id, check, i, dispatch_x()); - };*/ - } - /*if(_have_hint[i]) { - auto token = frame_buffer.read(frame_id); - device_log("dispatch:{}, index:{}, hint:{}",dispatch_x(),frame_id, read_promise(token,"coro_hint")&token_mask); - }*/ - (*coroutine)[i](frame, args...); - if (is_soa) { - frame_buffer.write(frame_id, frame, coroutine->graph().node(i)->output_state_members); - } else { - frame_buffer.write(frame_id, frame); - } - //if (debug) - // device_log("resume kernel {} : id {} goto kernel {}", i, frame_id, nxt); - - if (!sort_base_gather) { - auto nxt = read_promise(frame, "coro_token") & token_mask; - $switch (nxt) { - for (int i = 0; i < max_sub_coro; ++i) { - $case (i) { - count.atomic(i).fetch_add(1u); - }; - } - }; - //count.atomic(nxt).fetch_add(1u); - } - }; - if (_debug) o.name = "resume" + std::to_string(i); - _resume_shaders[i] = device.compile(resume_kernel, o); - _resume_shaders[i].set_name("resume" + std::to_string(i)); - } - Kernel1D _prefix_kernel = [&](BufferUInt count, BufferUInt prefix, UInt n) { - $if (dispatch_x() == 0) { - auto pre = def(0u); - $for (i, 0u, _max_sub_coro) { - auto val = count.read(i); - prefix.write(i, pre); - pre = pre + val; - if (_debug) { - _debug_buffer->write(i, pre); - } - }; - }; - }; - _count_prefix_shader = device.compile(_prefix_kernel); - Kernel1D _collect_kernel = [&](BufferUInt index, BufferUInt prefix, Var frame_buffer, UInt n) { - auto x = dispatch_x(); - auto frame = frame_buffer.read(x); - auto r_id = read_promise(frame, "coro_token") & token_mask; - auto q_id = prefix.atomic(r_id).fetch_add(1u); - if (_debug) { - $if (q_id >= _debug_buffer->read(r_id)) { - device_log("collect: indices overflow!!!! frame_id:{}>buffer_offset:{}", q_id, _debug_buffer->read(r_id)); - }; - /* - $if (q_id == _debug_buffer->read(r_id) - 1) { - device_log("finish gather: kernel {}, tot :{}", r_id, _debug_buffer->read(r_id)); - }; - */ - } - index.write(q_id, x); - }; - _gather_shader = device.compile(_collect_kernel); - Kernel1D _compact_kernel_2 = [&](BufferUInt index, Var frame_buffer, UInt empty_offset, UInt n) { - //_global_buffer->write(0u, 0u); - auto x = dispatch_x(); - $if (empty_offset + x < n) { - auto token = frame_buffer.read_coro_token(empty_offset + x); - $if ((token & token_mask) != 0u) { - - auto res = _global_buffer->atomic(0).fetch_add(1u); - auto slot = index.read(res); - if (!sort_base_gather) { - $while (slot >= empty_offset) { - res = _global_buffer->atomic(0).fetch_add(1u); - slot = index.read(res); - }; - } - /*if (_debug) { - $if (slot >= empty_offset) { - device_log("compact: new slot is in empty set!!!! slot:{}>empty_offset:{}", slot, empty_offset); - }; - }*/ - if (_debug) { - auto empty = frame_buffer.read(slot); - $if ((read_promise(empty, "coro_token") & token_mask) != 0u) { - device_log("wrong compact for frame {} at kernel {} when dispatch {}", slot, (read_promise(empty, "coro_token") & token_mask), dispatch_x()); - }; - } - auto frame = frame_buffer.read(empty_offset + x); - frame_buffer.write(slot, frame); - Var empty_frame; - initialize_coroframe(empty_frame, def(0, 0, 0)); - if (is_soa) { - frame_buffer.write(empty_offset + x, empty_frame, {1}); - } else { - frame_buffer.write(empty_offset + x, empty_frame); - } - }; - }; - }; - _compact_shader = device.compile(_compact_kernel_2); - _compact_shader.set_name("compact"); - - Kernel1D _initialize_kernel = [&](BufferUInt count, Var frame_buffer, UInt n) { - auto x = dispatch_x(); - $if (x < n) { - auto frame = frame_buffer.read(x); - initialize_coroframe(frame, def(0, 0, 0)); - }; - $if (x < max_sub_coro) { - count.write(x, ite(x == 0, _max_frame_count, 0u)); - }; - }; - Kernel1D clear = [&](BufferUInt buffer, UInt n) { - auto x = dispatch_x(); - $if (x < n) { - buffer.write(x, 0u); - }; - }; - _clear_shader = device.compile(clear); - _initialize_shader = device.compile(_initialize_kernel); - stream << _initialize_shader(_resume_count, _frame, _max_frame_count).dispatch(_max_frame_count); - } - void operator()(prototype_to_coro_dispatcher_t... args, uint dispatch_size) noexcept override { - CoroDispatcherBase::operator()(args..., dispatch_size); - _dispatch_counter = 0; - _host_empty = true; - for (auto i = 0u; i < _max_sub_coro; i++) { - if (i) { - _host_count[i] = 0; - _host_offset[i] = _max_frame_count; - } else { - _host_count[i] = _max_frame_count; - _host_offset[i] = 0; - } - } - _stream << _initialize_shader(_resume_count, _frame, _max_frame_count).dispatch(_max_frame_count); - } -}; - -struct PersistentCoroDispatcherConfig { - uint max_thread_count = 64_k; - uint block_size = 128; - uint fetch_size = 16; - bool global = false; - bool debug = false; -}; - -template -class PersistentCoroDispatcher : public CoroDispatcherBase { - -public: - using Config = PersistentCoroDispatcherConfig; - -private: - using FrameType = std::remove_reference_t; - Shader1D, uint, Args...> _pt_shader; - Shader1D> _clear_shader; - Buffer _global; - Buffer _global_frame; - Shader1D, uint> _initialize_shader; - uint _max_sub_coro; - uint _max_thread_count; - uint _block_size; - uint _global_size; - bool _dispatched; - bool _done; - void _await_step(Stream &stream) noexcept; - void _await_all(Stream &stream) noexcept; - bool _debug; - Stream &_stream; -public: - bool all_dispatched() const noexcept; - bool all_done() const noexcept; - PersistentCoroDispatcher(Coroutine *coroutine, - Device &device, Stream &stream, - const PersistentCoroDispatcherConfig &config) noexcept - : CoroDispatcherBase{coroutine, device}, - _max_thread_count{(config.max_thread_count+config.block_size-1)/config.block_size*config.block_size}, - _block_size{config.block_size}, - _debug{config.debug}, _stream{stream} { - auto use_global=config.global; - _global = device.create_buffer(1); - auto q_fac = 1u; - uint max_sub_coro = coroutine->suspend_count() + 1; - auto g_fac=(uint)std::max((int)(max_sub_coro-q_fac),0); - auto global_queue_size= config.block_size * g_fac; - _global_size=0; - if(use_global) { - _global_frame = device.create_buffer(_max_thread_count*g_fac); - _global_size = _max_thread_count*g_fac; - } - _max_sub_coro = max_sub_coro; - _dispatched = false; - _done = false; - Kernel1D main_kernel = [&](BufferUInt global, UInt dispatch_size, Var... args) { - set_block_size(config.block_size, 1, 1); - auto shared_queue_size = config.block_size * q_fac; - Shared frames{shared_queue_size}; - Shared path_id{shared_queue_size}; - Shared work_counter{max_sub_coro}; - Shared work_offset{2u}; - Shared all_token{use_global?(shared_queue_size+global_queue_size):shared_queue_size}; - Shared workload{2}; - Shared work_stat{2};//0 max_count,1 max_id - //Shared tag_counter{use_tag_sort ? pipeline().surfaces().size() : 0}; - //Shared tag_offset{pipeline().surfaces().size()}; - $for (index, 0u, q_fac) { - all_token[index * config.block_size + thread_x()] = 0u; - initialize_coroframe(frames[index * config.block_size + thread_x()], def(0, 0, 0)); - }; - $for (index, 0u, g_fac) { - all_token[shared_queue_size+index * config.block_size + thread_x()] = 0u; - }; - $if (thread_x() < max_sub_coro) { - $if (thread_x() == 0) { - work_counter[thread_x()] = use_global?(shared_queue_size+global_queue_size):shared_queue_size; - } - $else { - work_counter[thread_x()] = 0u; - }; - }; - workload[0] = 0; - workload[1] = 0; - Shared rem_global{1}; - Shared rem_local{1}; - rem_global[0] = 1u; - rem_local[0] = 0u; - sync_block(); - auto count = def(0u); - auto count_limit = def(-1); - $while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) { - sync_block();//very important, synchronize for condition - - rem_local[0] = 0u; - count += 1; - work_stat[0] = 0; - work_stat[1] = -1; - /* $if(thread_x() < (uint)KERNEL_COUNT) {//clear counter - work_counter[thread_x()] = 0u; - }; - sync_block(); - $for(index, 0u, q_factor) { - for (auto i = 0u; i < KERNEL_COUNT; ++i) {//count the kernels - auto state = path_state[index*block_size+thread_x()]; - $if(state.kernel_index == i) { - if (i != (uint)INVALID) { - rem_local[0] = 1u; - } else { - $if(workload[0] < workload[1]) { - rem_local[0] = 1u; - }; - } - work_counter.atomic(i).fetch_add(1u); - }; - } - };*/ - sync_block(); - $if (thread_x() == config.block_size - 1) { - $if ((workload[0] >= workload[1]) & (rem_global[0] == 1u)) {//fetch new workload - workload[0] = global.atomic(0u).fetch_add(config.block_size * config.fetch_size); - if (_debug) - device_log("block {}, fetch workload: {}", block_x(), workload[0]); - workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size); - $if (workload[0] >= dispatch_size) { - rem_global[0] = 0u; - }; - }; - }; - sync_block(); - $if (thread_x() < max_sub_coro) {//get max - $if ((workload[0] < workload[1]) | (thread_x() != 0u)) { - $if (work_counter[thread_x()] != 0) { - rem_local[0] = 1u; - work_stat.atomic(0).fetch_max(work_counter[thread_x()]); - }; - }; - }; - sync_block(); - $if (thread_x() < max_sub_coro) {//get argmax - $if ((work_stat[0] == work_counter[thread_x()]) & ((workload[0] < workload[1]) | (thread_x() != 0u))) { - work_stat[1] = thread_x(); - }; - }; - sync_block(); - work_offset[0] = 0; - work_offset[1] = 0; - sync_block(); - if(!use_global) { - $for (index, 0u, q_fac) {//collect indices - auto frame_token = all_token[index * config.block_size + thread_x()]; - $if (frame_token == work_stat[1]) { - auto id = work_offset.atomic(0).fetch_add(1u); - path_id[id] = index * config.block_size + thread_x(); - }; - }; - }else{ - $for(index, 0u, q_fac) {//collect switch out indices - auto frame_token = all_token[index * config.block_size + thread_x()]; - $if (frame_token != work_stat[1]) { - auto id = work_offset.atomic(0).fetch_add(1u); - path_id[id] = index * config.block_size + thread_x(); - }; - }; - sync_block(); - $if(shared_queue_size - work_offset[0] < config.block_size) {//no enough work - $for(index, 0u, g_fac) { //swap frames - auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x(); - auto g_queue_id=index * config.block_size + thread_x(); - auto coro_token=all_token[shared_queue_size+g_queue_id]; - $if(coro_token == work_stat[1]) { - auto id = work_offset.atomic(1).fetch_add(1u); - $if(id < work_offset[0]) { - auto dst = path_id[id]; - auto frame_token = all_token[dst]; - $if(coro_token != 0u) { - $if(frame_token != 0u) { - auto g_state = _global_frame->read(global_id); - _global_frame->write(global_id,frames[dst]); - frames[dst]=g_state; - all_token[shared_queue_size+g_queue_id]=frame_token; - all_token[dst]=coro_token; - - } - $else { - auto g_state = _global_frame->read(global_id); - frames[dst]=g_state; - all_token[shared_queue_size+g_queue_id]=frame_token; - all_token[dst]=coro_token; - }; - } - $else { - $if(frame_token != 0u) { - _global_frame->write(global_id,frames[dst]); - all_token[shared_queue_size+g_queue_id]=frame_token; - all_token[dst]=coro_token; - }; - }; - }; - }; - }; - }; - } - auto gen_st = workload[0]; - sync_block(); - auto pid = def(0u); - if (use_global) { - pid = thread_x(); - } else { - pid = path_id[thread_x()]; - } - auto launch_condition = def(true); - if (!use_global) { - launch_condition = (thread_x() < work_offset[0]); - } - else{ - launch_condition = (all_token[pid]==work_stat[1]); - } - $if (launch_condition) { - $switch (all_token[pid]) { - $case (0u) { - $if (gen_st + thread_x() < workload[1]) { - work_counter.atomic(0u).fetch_sub(1u); - auto work_id = gen_st + thread_x(); - initialize_coroframe(frames[pid], def(gen_st + thread_x(), 0, 0)); - (*coroutine)(frames[pid], args...);//only work when kernel 0s are continue - auto nxt = read_promise(frames[pid], "coro_token") & token_mask; - all_token[pid]=nxt; - work_counter.atomic(nxt).fetch_add(1u); - workload.atomic(0).fetch_add(1u); - }; - }; - for (auto i = 1u; i < max_sub_coro; ++i) { - $case (i) { - work_counter.atomic(i).fetch_sub(1u); - (*coroutine)[i](frames[pid], args...); - auto nxt = read_promise(frames[pid], "coro_token") & token_mask; - all_token[pid]=nxt; - work_counter.atomic(nxt).fetch_add(1u); - }; - } - }; - }; - sync_block(); - }; - $if (count >= count_limit) { - device_log("block_id{},thread_id {}, loop not break! local:{}, global:{}", block_x(), thread_x(), rem_local[0], rem_global[0]); - $if (thread_x() < max_sub_coro) { - device_log("work rem: id {}, size {}", thread_x(), work_counter[thread_x()]); - }; - }; - }; - _pt_shader = device.compile(main_kernel); - Kernel1D clear = [&](BufferUInt global) { - global->write(dispatch_x(), 0u); - }; - Kernel1D initialize_frame = [&](Var> frame_buffer, UInt n) { - auto x = dispatch_x(); - $if (x < n) { - auto frame = def(); - initialize_coroframe(frame, def(0, 0, 0)); - frame_buffer.write(x,frame); - }; - }; - _clear_shader = device.compile(clear); - _initialize_shader = device.compile(initialize_frame); - stream << _clear_shader(_global).dispatch(1u); - } - void operator()(prototype_to_coro_dispatcher_t... args, uint dispatch_size) noexcept override { - CoroDispatcherBase::operator()(args..., dispatch_size); - _dispatched = false; - _done = false; - _stream << _clear_shader(_global).dispatch(1u); - if(_global_size) { - - _stream << _initialize_shader(_global_frame, _global_size).dispatch(_global_size); - } - } -}; - -// a simple wrap that helps submit a coroutine dispatch to the stream -enum struct CmdTag { - AWAIT_STEP, - AWAIT_ALL -}; -template -class CoroAwait { - - friend class CoroDispatcherBase; - -public: - - -private: - CmdTag _tag; - CoroDispatcherBase *_dispatcher; - -private: - CoroAwait(CmdTag tag, - CoroDispatcherBase *dispatcher) noexcept - : _tag{tag}, _dispatcher{dispatcher} {} -public: - void operator()(Stream &stream) && noexcept; -}; - -}// namespace coro - -template -struct luisa::compute::detail::is_stream_event_impl> : std::true_type {}; -}// namespace luisa::compute - -namespace luisa::compute::inline coro { -template -void CoroDispatcherBase::_await_all(Stream &stream) noexcept { - while (!this->all_done()) { - this->_await_step(stream); - } -} -template -CoroAwait CoroDispatcherBase::await_step() noexcept { - return CoroAwait{CmdTag::AWAIT_STEP, this}; -} -template -CoroAwait CoroDispatcherBase::await_all() noexcept { - return CoroAwait{CmdTag::AWAIT_ALL, this}; -} -template -template -detail::ShaderInvoke CoroDispatcherBase::call_shader(Shader &shader, - prototype_to_coro_dispatcher_t... prefix_args) { - auto invoke = shader.partial_invoke(prefix_args...); - std::apply([&invoke](prototype_to_coro_dispatcher_t... args) { - static_cast((invoke << ... << args)); - }, - _args); - return invoke; -} -template -void CoroAwait::operator()(Stream &stream) && noexcept { - switch (_tag) { - case CmdTag::AWAIT_STEP: this->_dispatcher->_await_step(stream); break; - case CmdTag::AWAIT_ALL: this->_dispatcher->_await_all(stream); break; - } -} -/*template -void WavefrontCoroDispatcher::_await_step(Stream &stream) noexcept { - -}*/ -template -void WavefrontCoroDispatcher::_await_step(Stream &stream) noexcept { - if (sort_base_gather) { - auto host_update = [&] { - _host_empty = true; - auto sum = 0u; - for (uint i = 0; i < _max_sub_coro; i++) { - _host_count[i] = (i + 1 == _max_sub_coro ? _max_frame_count : _host_offset[i + 1]) - _host_offset[i]; - _host_empty = _host_empty && (i == 0 || _host_count[i] == 0); - } - }; - _sort_token.sort(stream, _temp_key[0], _resume_index, _temp_key[1], _resume_index, _max_frame_count); - - stream << _sort_temp_storage.hist_buffer.view(0, _max_sub_coro).copy_to(_host_offset.data()) - << host_update - << synchronize(); - if (_debug) - for (int i = 0; i < _max_sub_coro; ++i) { - LUISA_INFO("kernel {}: total {}", i, _host_count[i]); - } - if (_host_count[0] > _max_frame_count * (0.5) && !all_dispatched()) { - auto gen_count = std::min(this->_dispatch_size - this->_dispatch_counter, _host_count[0]); - if (_debug) { - LUISA_INFO("Gen {} new frame", gen_count); - } - if (_host_count[0] != _max_frame_count && use_compact) { - stream << _clear_shader(_global_buffer, 1).dispatch(1u); - stream << _compact_shader(_resume_index, _frame, _max_frame_count - _host_count[0], _max_frame_count).dispatch(_host_count[0]); - } - stream << this->template call_shader<1, Buffer, Buffer, uint, Container, uint, uint>( - _gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), _resume_count, _max_frame_count - _host_count[0], _frame, _dispatch_counter, _max_frame_count) - .dispatch(gen_count); - _dispatch_counter += gen_count; - _host_empty = false; - } else { - for (uint i = 1; i < _max_sub_coro; i++) { - if (_debug) - LUISA_INFO("launch kernel {} for count {}", i, _host_count[i]); - if (_host_count[i] > 0) { - if (_have_hint[i]) { - BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; - BufferView _key[2] = {_temp_key[1].view(_host_offset[i], _host_count[i]), _temp_key[0].view(_host_offset[i], _host_count[i])}; - uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i],_resume_index.view(_host_offset[i], _host_count[i])); - stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _index[out], - _resume_count, _frame, _max_frame_count) - .dispatch(_host_count[i]); - } else { - - stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _frame, _max_frame_count) - .dispatch(_host_count[i]); - } - } - } - } - stream << synchronize(); - } else { - if (_debug) - for (int i = 0; i < _max_sub_coro; ++i) { - LUISA_INFO("kernel {}: total {}", i, _host_count[i]); - } - stream << _count_prefix_shader(_resume_count, _resume_offset, _max_sub_coro).dispatch(1u); - stream << _gather_shader(_resume_index, _resume_offset, _frame, _max_frame_count).dispatch(_max_frame_count); - if (_host_count[0] > _max_frame_count / 2 && !all_dispatched()) { - auto gen_count = std::min(this->_dispatch_size - this->_dispatch_counter, _host_count[0]); - if (_debug) { - LUISA_INFO("Gen {} new frame", gen_count); - } - if (_host_count[0] != _max_frame_count && use_compact) { - stream << _clear_shader(_global_buffer, 1).dispatch(1u); - stream - << _compact_shader(_resume_index.view(_host_offset[0], _host_count[0]), _frame, _max_frame_count - _host_count[0], _max_frame_count).dispatch(_host_count[0]); - } - stream << this->template call_shader<1, Buffer, Buffer, uint, Container, uint, uint>( - _gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), _resume_count, _max_frame_count - _host_count[0], _frame, _dispatch_counter, _max_frame_count) - .dispatch(gen_count); - _dispatch_counter += gen_count; - _host_empty = false; - } else { - for (uint i = 1; i < _max_sub_coro; i++) { - if (_debug) - LUISA_INFO("launch kernel {} for count {}", i, _host_count[i]); - if (_host_count[i] > 0) { - if (_have_hint[i]) { - BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; - BufferView _key[2] = {_temp_key[0].view(_host_offset[i], _host_count[i]), _temp_key[1].view(_host_offset[i], _host_count[i])}; - uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i],_resume_index.view(_host_offset[i], _host_count[i])); - stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _index[out], - _resume_count, _frame, _max_frame_count) - .dispatch(_host_count[i]); - } else { - stream << this->template call_shader<1, Buffer, Buffer, Container, uint>(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), - _resume_count, _frame, _max_frame_count) - .dispatch(_host_count[i]); - } - } - } - } - auto host_update = [&] { - _host_empty = true; - auto sum = 0u; - for (uint i = 0; i < _max_sub_coro; i++) { - _host_offset[i] = sum; - sum += _host_count[i]; - _host_empty = _host_empty && (i == 0 || _host_count[i] == 0); - } - }; - stream << _resume_count.view(0, _max_sub_coro).copy_to(_host_count.data()) - << host_update; - stream << synchronize(); - } -} -template -bool WavefrontCoroDispatcher::all_dispatched() const noexcept { - return this->_dispatch_size == this->_dispatch_counter; -}; -template -bool WavefrontCoroDispatcher::all_done() const noexcept { - return all_dispatched() && _host_empty; -}; - -template -void PersistentCoroDispatcher::_await_step(Stream &stream) noexcept { - LUISA_ERROR("PersistentCoroDispatcher can only be used with await_all!"); -} -template -void PersistentCoroDispatcher::_await_all(Stream &stream) noexcept { - _dispatched = true; - stream << this->template call_shader<1, Buffer, uint>(_pt_shader, _global, this->_dispatch_size).dispatch(_max_thread_count) - << [&] { _done = true; }; - stream << synchronize(); -} - -template -bool PersistentCoroDispatcher::all_dispatched() const noexcept { - return _dispatched; -} -template -bool PersistentCoroDispatcher::all_done() const noexcept { - return _done; -} -}// namespace luisa::compute::inline coro diff --git a/include/luisa/coro/coro_frame.h b/include/luisa/coro/coro_frame.h new file mode 100644 index 000000000..c1439fbde --- /dev/null +++ b/include/luisa/coro/coro_frame.h @@ -0,0 +1,64 @@ +// +// Created by mike on 5/7/24. +// + +#pragma once + +#include +#include +#include + +namespace luisa::compute::coroutine { + +class LC_CORO_API CoroFrame { + +private: + luisa::shared_ptr _desc; + const Expression *_expression; + +public: + Var &coro_id; + Var &target_token; + +public: + CoroFrame(luisa::shared_ptr desc, const Expression *expr) noexcept; + CoroFrame(CoroFrame &&another) noexcept; + CoroFrame(const CoroFrame &another) noexcept; + CoroFrame &operator=(const CoroFrame &rhs) noexcept; + CoroFrame &operator=(CoroFrame &&rhs) noexcept; + +public: + [[nodiscard]] static CoroFrame create(luisa::shared_ptr desc) noexcept; + [[nodiscard]] static CoroFrame create(luisa::shared_ptr desc, Expr coro_id) noexcept; + +private: + void _check_member_index(uint index) const noexcept; + +public: + [[nodiscard]] auto description() const noexcept { return _desc.get(); } + [[nodiscard]] auto expression() const noexcept { return _expression; } + +public: + template + [[nodiscard]] Var &get(uint index) noexcept { + _check_member_index(index); + auto fb = luisa::compute::detail::FunctionBuilder::current(); + auto member = fb->member(_desc->type()->members()[index], _expression, index); + return *fb->create_temporary>(member); + } + template + [[nodiscard]] const Var &get(uint index) const noexcept { + return const_cast(this)->get(index); + } + template + [[nodiscard]] Var &get(luisa::string_view name) noexcept { + return get(_desc->designated_field(name)); + } + template + [[nodiscard]] const Var &get(luisa::string_view name) const noexcept { + return const_cast(this)->get(name); + } + [[nodiscard]] Var is_terminated() const noexcept; +}; + +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/coro_frame_buffer.h b/include/luisa/coro/coro_frame_buffer.h new file mode 100644 index 000000000..d2d15c9ac --- /dev/null +++ b/include/luisa/coro/coro_frame_buffer.h @@ -0,0 +1,307 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include +#include + +namespace luisa::compute { + +namespace detail { + +[[noreturn]] LC_CORO_API void error_coro_frame_buffer_invalid_element_size(size_t stride, size_t expected) noexcept; + +template +class BufferExprProxy; + +}// namespace detail + +template<> +class BufferView { + + friend class lc::validation::Stream; + +private: + luisa::shared_ptr _desc; + void *_native_handle; + uint64_t _handle; + size_t _offset_bytes; + size_t _size; + size_t _total_size; + +private: + friend class Buffer; + friend class SparseBuffer; + + template + friend class BufferView; + +public: + BufferView(luisa::shared_ptr desc, + void *native_handle, uint64_t handle, + size_t offset_bytes, size_t size, size_t total_size) noexcept + : _desc{std::move(desc)}, _native_handle{native_handle}, _handle{handle}, + _offset_bytes{offset_bytes}, _size{size}, _total_size{total_size} { + if (auto a = _desc->type()->alignment(); _offset_bytes % a != 0u) [[unlikely]] { + detail::error_buffer_invalid_alignment(_offset_bytes, a); + } + } + + template typename B> + requires(is_buffer_v>) + BufferView(const B &buffer) noexcept : BufferView{buffer.view()} {} + + BufferView() noexcept : BufferView{nullptr, nullptr, invalid_resource_handle, 0, 0, 0} {} + [[nodiscard]] explicit operator bool() const noexcept { return _handle != invalid_resource_handle; } + + // properties + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto handle() const noexcept { return _handle; } + [[nodiscard]] auto native_handle() const noexcept { return _native_handle; } + [[nodiscard]] auto stride() const noexcept { return _desc->type()->size(); } + [[nodiscard]] auto size() const noexcept { return _size; } + [[nodiscard]] auto offset() const noexcept { return _offset_bytes / stride(); } + [[nodiscard]] auto offset_bytes() const noexcept { return _offset_bytes; } + [[nodiscard]] auto size_bytes() const noexcept { return _size * stride(); } + + [[nodiscard]] auto original() const noexcept { + return BufferView{_desc, _native_handle, _handle, 0u, _total_size, _total_size}; + } + [[nodiscard]] auto subview(size_t offset_elements, size_t size_elements) const noexcept { + if (size_elements + offset_elements > _size) [[unlikely]] { + detail::error_buffer_subview_overflow(offset_elements, size_elements, _size); + } + return BufferView{_desc, _native_handle, _handle, + _offset_bytes + offset_elements * stride(), + size_elements, _total_size}; + } + // reinterpret cast buffer to another type U + template + requires(!is_custom_struct_v) + [[nodiscard]] auto as() const noexcept { + if (this->size_bytes() < sizeof(U)) [[unlikely]] { + detail::error_buffer_reinterpret_size_too_small(sizeof(U), this->size_bytes()); + } + auto total_size_bytes = _total_size * stride(); + return BufferView{_native_handle, _handle, sizeof(U), _offset_bytes, + this->size_bytes() / sizeof(U), total_size_bytes / sizeof(U)}; + } + // commands + // copy buffer's data to pointer + [[nodiscard]] auto copy_to(void *data) const noexcept { + return luisa::make_unique(_handle, offset_bytes(), size_bytes(), data); + } + // copy pointer's data to buffer + [[nodiscard]] auto copy_from(const void *data) noexcept { + return luisa::make_unique(this->handle(), this->offset_bytes(), this->size_bytes(), data); + } + // copy source buffer's data to buffer + [[nodiscard]] auto copy_from(BufferView source) noexcept { + if (source.size() != this->size()) [[unlikely]] { + detail::error_buffer_copy_sizes_mismatch(source.size(), this->size()); + } + return luisa::make_unique( + source.handle(), this->handle(), + source.offset_bytes(), this->offset_bytes(), + this->size_bytes()); + } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast> *>(this); + } +}; + +template<> +class Buffer final : public Resource { + +private: + luisa::shared_ptr _desc; + size_t _size{}; + +private: + friend class Device; + friend class ResourceGenerator; + friend class DxCudaInterop; + friend class PinnedMemoryExt; + + Buffer(DeviceInterface *device, + const BufferCreationInfo &info, + luisa::shared_ptr desc) noexcept + : Resource{device, Tag::BUFFER, info}, + _desc{std::move(desc)}, + _size{info.total_size_bytes / info.element_stride} { + if (info.element_stride != _desc->type()->size()) [[unlikely]] { + detail::error_coro_frame_buffer_invalid_element_size( + info.element_stride, _desc->type()->size()); + } + } + + Buffer(DeviceInterface *device, + const luisa::shared_ptr &desc, + size_t size) noexcept + : Buffer{device, + [&] { + if (size == 0) [[unlikely]] { + detail::error_buffer_size_is_zero(); + } + return device->create_buffer(desc->type(), size, nullptr); + }(), + desc} {} + + Buffer(DeviceInterface *device, + const coroutine::CoroFrameDesc *desc, + size_t size) noexcept + : Buffer{device, desc->shared_from_this(), size} {} + +public: + Buffer() noexcept = default; + ~Buffer() noexcept override { + if (*this) { device()->destroy_buffer(handle()); } + } + Buffer(Buffer &&) noexcept = default; + Buffer(Buffer const &) noexcept = delete; + Buffer &operator=(Buffer &&rhs) noexcept { + _move_from(std::move(rhs)); + return *this; + } + Buffer &operator=(Buffer const &) noexcept = delete; + using Resource::operator bool; + // properties + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto size() const noexcept { + _check_is_valid(); + return _size; + } + [[nodiscard]] auto stride() const noexcept { + _check_is_valid(); + return _desc->type()->size(); + } + [[nodiscard]] auto size_bytes() const noexcept { + _check_is_valid(); + return _size * stride(); + } + [[nodiscard]] auto view() const noexcept { + _check_is_valid(); + return BufferView{_desc, this->native_handle(), this->handle(), 0u, _size, _size}; + } + [[nodiscard]] auto view(size_t offset, size_t count) const noexcept { + return view().subview(offset, count); + } + // commands + // copy buffer's data to pointer + [[nodiscard]] auto copy_to(void *data) const noexcept { + return this->view().copy_to(data); + } + // copy pointer's data to buffer + [[nodiscard]] auto copy_from(const void *data) noexcept { + return this->view().copy_from(data); + } + // copy source buffer's data to buffer + [[nodiscard]] auto copy_from(BufferView source) noexcept { + return this->view().copy_from(source); + } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + _check_is_valid(); + return reinterpret_cast> *>(this); + } +}; + +template<> +struct Expr> { + +private: + luisa::shared_ptr _desc; + const RefExpr *_expression{nullptr}; + +public: + explicit Expr(luisa::shared_ptr desc, + const RefExpr *expr) noexcept + : _desc{std::move(desc)}, _expression{expr} {} + + Expr(BufferView buffer) noexcept + : _desc{buffer.desc()->shared_from_this()}, + _expression{detail::FunctionBuilder::current()->buffer_binding( + Type::buffer(buffer.desc()->type()), buffer.handle(), + buffer.offset_bytes(), buffer.size_bytes())} {} + + /// Contruct from Buffer. Will call buffer_binding() to bind buffer + Expr(const Buffer &buffer) noexcept + : Expr{BufferView{buffer}} {} + + /// Construct from Var>. + Expr(const Var> &buffer) noexcept + : Expr{buffer.desc()->shared_from_this(), buffer.expression()} {} + + /// Construct from Var>. + Expr(const Var> &buffer) noexcept + : Expr{buffer.desc()->shared_from_this(), buffer.expression()} {} + + [[nodiscard]] const coroutine::CoroFrameDesc *desc() const noexcept { return _desc.get(); } + [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } + + /// Read buffer at index + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index) const noexcept { + auto f = detail::FunctionBuilder::current(); + auto expr = f->call(_desc->type(), CallOp::BUFFER_READ, + {_expression, detail::extract_expression(std::forward(index))}); + auto frame = f->local(_desc->type()); + f->assign(frame, expr); + return coroutine::CoroFrame{_desc, frame}; + } + + template + requires is_integral_expr_v + void write(I &&index, const coroutine::CoroFrame &value) const noexcept { + detail::FunctionBuilder::current()->call( + CallOp::BUFFER_WRITE, + {_expression, + detail::extract_expression(std::forward(index)), + value.expression()}); + } + + [[nodiscard]] Expr device_address() const noexcept { + return def(detail::FunctionBuilder::current()->call( + Type::of(), CallOp::BUFFER_ADDRESS, {_expression})); + } + +public: + [[nodiscard]] auto operator->() const noexcept { return this; } +}; + +namespace detail { + +template typename B> +class BufferExprProxy> { + +private: + using CoroFrameBuffer = B; + CoroFrameBuffer _buffer; + +public: + LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(BufferExprProxy) + +public: + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index) const noexcept { + return Expr{_buffer}.read(std::forward(index)); + } + template + requires is_integral_expr_v + void write(I &&index, V &&value) const noexcept { + Expr{_buffer}.write(std::forward(index), std::forward(value)); + } + [[nodiscard]] Expr device_address() const noexcept { + return Expr{_buffer}.device_address(); + } +}; + +}// namespace detail + +}// namespace luisa::compute diff --git a/include/luisa/coro/coro_frame_desc.h b/include/luisa/coro/coro_frame_desc.h new file mode 100644 index 000000000..82a873e93 --- /dev/null +++ b/include/luisa/coro/coro_frame_desc.h @@ -0,0 +1,35 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include +#include +#include + +namespace luisa::compute::coroutine { + +class LC_CORO_API CoroFrameDesc : public luisa::enable_shared_from_this { + +public: + using DesignatedFieldDict = luisa::unordered_map; + +private: + const Type *_type{nullptr}; + DesignatedFieldDict _designated_fields; + +private: + CoroFrameDesc(const Type *type, DesignatedFieldDict m) noexcept; + +public: + [[nodiscard]] static luisa::shared_ptr create(const Type *type, DesignatedFieldDict m) noexcept; + [[nodiscard]] auto type() const noexcept { return _type; } + [[nodiscard]] const Type *field(uint index) const noexcept; + [[nodiscard]] auto &designated_fields() const noexcept { return _designated_fields; } + [[nodiscard]] uint designated_field(luisa::string_view name) const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; +}; + +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/coro_frame_smem.h b/include/luisa/coro/coro_frame_smem.h new file mode 100644 index 000000000..c23057014 --- /dev/null +++ b/include/luisa/coro/coro_frame_smem.h @@ -0,0 +1,165 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include +#include + +namespace luisa::compute { + +namespace detail { +[[noreturn]] LC_CORO_API void error_coro_frame_smem_subscript_on_soa() noexcept; +}// namespace detail + +template<> +class Shared { + +private: + luisa::shared_ptr _desc; + luisa::vector _expressions; + // we need to return a `CoroFrame &` for operator[] with proper lifetime + luisa::vector> _temp_frames; + size_t _size; + +private: + void _create(bool soa, luisa::span soa_excluded) noexcept { + auto fb = detail::FunctionBuilder::current(); + if (!soa) { + auto s = fb->shared(Type::array(_desc->type(), _size)); + _expressions.emplace_back(s); + } else { + auto fields = _desc->type()->members(); + _expressions.reserve(fields.size()); + for (auto i = 0u; i < fields.size(); i++) { + if (std::find(soa_excluded.begin(), soa_excluded.end(), i) != soa_excluded.end()) { + _expressions.emplace_back(nullptr); + } else { + auto type = i == 0u ? Type::array(Type::of(), 3u) : fields[i]; + auto s = fb->shared(Type::array(type, _size)); + _expressions.emplace_back(s); + } + } + } + } + +public: + Shared(luisa::shared_ptr desc, + size_t n, bool soa = false, + luisa::span soa_excluded_fields = {}) noexcept + : _desc{std::move(desc)}, _size{n} { _create(soa, soa_excluded_fields); } + + Shared(Shared &&) noexcept = default; + Shared(const Shared &) noexcept = delete; + Shared &operator=(Shared &&) noexcept = delete; + Shared &operator=(const Shared &) noexcept = delete; + + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto is_soa() const noexcept { return _expressions.size() > 1; } + [[nodiscard]] auto size() const noexcept { return _size; } + +private: + /// Read index with active fields + template + requires is_integral_expr_v + [[nodiscard]] auto _read(I &&index, luisa::optional> active_fields) const noexcept { + auto i = def(std::forward(index)); + auto fb = detail::FunctionBuilder::current(); + auto frame = fb->local(_desc->type()); + if (!is_soa()) { + auto expr = fb->access(_desc->type(), _expressions[0], i.expression()); + fb->assign(frame, expr); + } else { + auto fields = _desc->type()->members(); + for (auto m = 0u; m < fields.size(); m++) { + if (_expressions[m] == nullptr) { continue; } + if (active_fields && std::find(active_fields->begin(), active_fields->end(), m) == active_fields->end()) { continue; } + auto f = fb->member(fields[m], frame, m); + if (m == 0u) { + auto t = Type::of(); + auto s = fb->access(Type::array(t, 3u), _expressions[m], i.expression()); + std::array elems{}; + elems[0] = fb->access(t, s, fb->literal(t, 0u)); + elems[1] = fb->access(t, s, fb->literal(t, 1u)); + elems[2] = fb->access(t, s, fb->literal(t, 2u)); + auto v = fb->call(Type::of(), CallOp::MAKE_UINT3, elems); + fb->assign(f, v); + } else { + auto s = fb->access(fields[m], _expressions[m], i.expression()); + fb->assign(f, s); + } + } + } + return coroutine::CoroFrame{_desc, frame}; + } + + /// Write index with active fields + template + requires is_integral_expr_v + void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional> active_fields) const noexcept { + auto i = def(std::forward(index)); + auto fb = detail::FunctionBuilder::current(); + if (!is_soa()) { + auto ref = fb->access(_desc->type(), _expressions[0], i.expression()); + fb->assign(ref, frame.expression()); + } else { + auto fields = _desc->type()->members(); + for (auto m = 0u; m < fields.size(); m++) { + if (_expressions[m] == nullptr) { continue; } + if (active_fields && std::find(active_fields->begin(), active_fields->end(), m) == active_fields->end()) { continue; } + auto f = fb->member(fields[m], frame.expression(), m); + if (m == 0u) { + auto t = Type::of(); + auto s = fb->access(Type::array(t, 3u), _expressions[m], i.expression()); + for (auto j = 0u; j < 3u; j++) { + auto fj = fb->swizzle(t, f, 1u, j); + auto sj = fb->access(t, s, fb->literal(t, j)); + fb->assign(sj, fj); + } + } else { + auto s = fb->access(fields[m], _expressions[m], i.expression()); + fb->assign(s, f); + } + } + } + } + +public: + /// Reference to the i-th element + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame &operator[](I &&index) noexcept { + if (is_soa()) { detail::error_coro_frame_smem_subscript_on_soa(); } + auto fb = detail::FunctionBuilder::current(); + auto ref = fb->access( + _desc->type(), _expressions[0], + detail::extract_expression(std::forward(index))); + auto temp_frame = luisa::make_unique(_desc, ref); + return *_temp_frames.emplace_back(std::move(temp_frame)); + } + + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame read(I &&index) const noexcept { + return _read(std::forward(index), luisa::nullopt); + } + template + requires is_integral_expr_v + [[nodiscard]] coroutine::CoroFrame read(I &&index, luisa::span active_fields) const noexcept { + return _read(std::forward(index), luisa::make_optional(active_fields)); + } + template + requires is_integral_expr_v + void write(I &&index, const coroutine::CoroFrame &frame) const noexcept { + _write(std::forward(index), frame, luisa::nullopt); + } + template + requires is_integral_expr_v + void write(I &&index, const coroutine::CoroFrame &frame, luisa::span active_fields) const noexcept { + _write(std::forward(index), frame, luisa::make_optional(active_fields)); + } +}; + +}// namespace luisa::compute diff --git a/include/luisa/coro/coro_frame_soa.h b/include/luisa/coro/coro_frame_soa.h new file mode 100644 index 000000000..1410f4927 --- /dev/null +++ b/include/luisa/coro/coro_frame_soa.h @@ -0,0 +1,334 @@ +// +// Created by ChenXin on 2024/5/12. +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace luisa::compute { + +namespace detail { +template +class SOAExprProxy; +}// namespace detail + +template +class SOA; + +struct SOABase { +protected: + luisa::shared_ptr _desc; + luisa::shared_ptr> _field_offsets; + size_t _offset_elements{0u}, _size_elements{0u}; + +public: + SOABase() noexcept = default; + SOABase(SOABase &&) noexcept = default; + SOABase(const SOABase &) noexcept = default; + SOABase(luisa::shared_ptr desc, + luisa::shared_ptr> field_offsets, + size_t offset_elements, size_t size_elements) noexcept + : _desc{std::move(desc)}, + _field_offsets{std::move(field_offsets)}, + _offset_elements{offset_elements}, _size_elements{size_elements} {} +}; + +template<> +class SOAView : public SOABase { +private: + ByteBufferView _buffer_view; + +public: + SOAView() noexcept = default; + SOAView(luisa::shared_ptr desc, ByteBufferView buffer_view, + luisa::shared_ptr> field_offsets, + size_t offset_elements, size_t size_elements) noexcept + : SOABase{std::move(desc), std::move(field_offsets), offset_elements, size_elements}, + _buffer_view{buffer_view} { + LUISA_ASSERT((_buffer_view.offset() == 0u) && + (_buffer_view.size_bytes() == _buffer_view.total_size()), + "Invalid buffer view for SOA."); + } + template + requires std::same_as> + SOAView(const T &soa) noexcept + : SOAView{soa.view()} {} + SOAView(const SOAView &) noexcept = default; + SOAView(SOAView &&) noexcept = default; + ~SOAView() noexcept = default; + + [[nodiscard]] auto subview(uint offset_elements, uint size_elements) noexcept { + return SOAView{_desc, _buffer_view.subview(offset_elements, size_elements), + _field_offsets, _offset_elements, _size_elements}; + } + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto handle() const noexcept { return _buffer_view.handle(); } + [[nodiscard]] auto offset_elements() const noexcept { return _offset_elements; } + [[nodiscard]] auto size_elements() const noexcept { return _size_elements; } + [[nodiscard]] auto buffer_offset() const noexcept { return _buffer_view.offset(); } + [[nodiscard]] auto size_bytes() const noexcept { return _buffer_view.size_bytes(); } + [[nodiscard]] auto field_offsets() const noexcept { return _field_offsets; } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast> *>(this); + } +}; + +template<> +class SOA : public SOABase { + +private: + ByteBuffer _buffer; + +public: + SOA(DeviceInterface *device, luisa::shared_ptr desc, size_t n) noexcept + : SOABase{std::move(desc), luisa::make_shared>(), + 0u, n} { + size_t size_bytes = 0u; + auto fields = _desc->type()->members(); + _field_offsets->reserve(fields.size()); + for (const auto field_type : fields) { + auto alignment = std::max(field_type->alignment(), 4u); + size_bytes = (size_bytes + alignment - 1u) & ~(alignment - 1u); + _field_offsets->emplace_back(size_bytes); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + size_bytes += aligned_size * _size_elements; + } + size_bytes = (size_bytes + 3u) & ~3u; + auto info = device->create_buffer( + Type::of(), + size_bytes, + nullptr); + _buffer = std::move(ByteBuffer{device, info}); + } + SOA() noexcept = default; + SOA(const SOA &) = delete; + SOA(SOA &&) noexcept = default; + SOA &operator=(const SOA &) = delete; + SOA &operator=(SOA &&x) noexcept { + _desc = std::move(x._desc); + _size_elements = x._size_elements; + _field_offsets = std::move(x._field_offsets); + _buffer = std::move(x._buffer); + return *this; + } + ~SOA() noexcept = default; + + // properties + [[nodiscard]] auto desc() const noexcept { return _desc.get(); } + [[nodiscard]] auto view() const noexcept { + return SOAView{ + _desc, + _buffer.view(), + _field_offsets, + 0u, + _size_elements}; + } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast> *>(this); + } +}; + +/// Class of Expr> +template<> +struct Expr> : SOABase { + +private: + const RefExpr *_expression{nullptr}; + +public: + /// Construct from SOAView. Will call buffer_binding() to bind buffer + Expr(const SOAView &soa_view) noexcept + : SOABase{soa_view.desc()->shared_from_this(), soa_view.field_offsets(), + soa_view.offset_elements(), soa_view.size_elements()}, + _expression{detail::FunctionBuilder::current()->buffer_binding( + Type::of(), soa_view.handle(), + soa_view.buffer_offset(), soa_view.size_bytes())} {} + + /// Construct from SOA. Will call buffer_binding() to bind buffer + Expr(const SOA &soa) noexcept + : Expr{SOAView{soa}} {} + + /// Return RefExpr + [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } + + /// Read field with field_index at index + template + requires is_integral_expr_v + [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { + auto fb = detail::FunctionBuilder::current(); + auto field_type = _desc->type()->members()[field_index]; + auto offset = _field_offsets->at(field_index); + auto alignment = std::max(field_type->alignment(), 4u); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; + auto f = fb->local(field_type); + auto s = fb->call( + field_type, CallOp::BYTE_BUFFER_READ, + {_expression, detail::extract_expression(offset_var)}); + fb->assign(f, s); + return Var(f); + } + + /// Write field with field_index at index + template + requires is_integral_expr_v + void write_field(I &&index, V &&value, uint field_index) const noexcept { + auto fb = detail::FunctionBuilder::current(); + auto field_type = _desc->type()->members()[field_index]; + auto offset = _field_offsets->at(field_index); + auto alignment = std::max(field_type->alignment(), 4u); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; + fb->call(CallOp::BYTE_BUFFER_WRITE, + {_expression, + detail::extract_expression(offset_var), + detail::extract_expression(value)}); + } + + /// Read index with active fields + template + [[nodiscard]] auto read( + I &&index, luisa::optional> active_fields = luisa::nullopt) const noexcept { + auto fb = detail::FunctionBuilder::current(); + auto frame = fb->local(_desc->type()); + auto fields = _desc->type()->members(); + for (auto i = 0u; i < fields.size(); i++) { + if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } + auto field_type = fields[i]; + auto offset = _field_offsets->at(i); + auto alignment = std::max(field_type->alignment(), 4u); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; + auto s = fb->call( + field_type, CallOp::BYTE_BUFFER_READ, + {_expression, detail::extract_expression(offset_var)}); + auto f = fb->member(field_type, frame, i); + fb->assign(f, s); + } + return coroutine::CoroFrame{_desc, frame}; + } + + /// Write index with active fields + template + void write(I &&index, const coroutine::CoroFrame &frame, + luisa::optional> active_fields = luisa::nullopt) const noexcept { + auto fb = detail::FunctionBuilder::current(); + auto fields = _desc->type()->members(); + for (auto i = 0u; i < fields.size(); i++) { + if (active_fields && std::find(active_fields->begin(), active_fields->end(), i) == active_fields->end()) { continue; } + auto field_type = fields[i]; + auto offset = _field_offsets->at(i); + auto alignment = std::max(field_type->alignment(), 4u); + auto aligned_size = (field_type->size() + alignment - 1u) & ~(alignment - 1u); + auto offset_var = offset + (_offset_elements + ULong(index)) * aligned_size; + auto f = fb->member(field_type, frame.expression(), i); + fb->call(CallOp::BYTE_BUFFER_WRITE, + {_expression, detail::extract_expression(offset_var), f}); + } + } + + [[nodiscard]] Expr device_address() const noexcept { + return def(detail::FunctionBuilder::current()->call( + Type::of(), CallOp::BUFFER_ADDRESS, {_expression})); + } + + /// Self-pointer to unify the interfaces of the captured Buffer and Expr> + [[nodiscard]] auto operator->() const noexcept { return this; } +}; + +/// Class of Expr> +template<> +struct Expr> : public Expr> { + using Expr>::Expr; +}; + +namespace detail { + +template typename B> +class SOAExprProxy> { +private: + using SOAOrView = B; + SOAOrView _soa; + +public: + LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(SOAExprProxy) + +public: + /// Read field with field_index at index + template + requires is_integral_expr_v + [[nodiscard]] Var read_field(I &&index, uint field_index) const noexcept { + return Expr{_soa}.template read_field(std::forward(index), field_index); + } + /// Read field named with "name" at index + template + requires is_integral_expr_v + [[nodiscard]] Var read_field(I &&index, luisa::string_view name) const noexcept { + return read_field(std::forward(index), _soa.desc()->designated_field(name)); + } + + /// Write field with field_index at index + template + requires is_integral_expr_v + void write_field(I &&index, V &&value, uint field_index) const noexcept { + Expr{_soa}.write_field(std::forward(index), + std::forward(value), + field_index); + } + /// Write field named with "name" at index + template + requires is_integral_expr_v + void write_field(I &&index, V &&value, luisa::string_view name) const noexcept { + write_field(std::forward(index), + std::forward(value), + _soa.desc()->designated_field(name)); + } + + /// Read index + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index) const noexcept { + return Expr{_soa}.read(std::forward(index), luisa::nullopt); + } + /// Read index with active fields + template + requires is_integral_expr_v + [[nodiscard]] auto read(I &&index, luisa::span active_fields) const noexcept { + return Expr{_soa}.read(std::forward(index), luisa::make_optional(active_fields)); + } + + /// Write index + template + requires is_integral_expr_v + void write(I &&index, V &&value) const noexcept { + Expr{_soa}.write(std::forward(index), + std::forward(value), + luisa::nullopt); + } + /// Write index with active fields + template + requires is_integral_expr_v + void write(I &&index, V &&value, luisa::span active_fields) const noexcept { + Expr{_soa}.write(std::forward(index), + std::forward(value), + luisa::make_optional(active_fields)); + } + + [[nodiscard]] Expr device_address() const noexcept { + return Expr{_soa}.device_address(); + } +}; + +}// namespace detail + +}// namespace luisa::compute \ No newline at end of file diff --git a/include/luisa/coro/coro_func.h b/include/luisa/coro/coro_func.h new file mode 100644 index 000000000..e3d22d4b0 --- /dev/null +++ b/include/luisa/coro/coro_func.h @@ -0,0 +1,251 @@ +// +// Created by Mike on 2024/5/8. +// + +#pragma once + +#include +#include +#include +#include + +namespace luisa::compute::coroutine { + +namespace detail { +LC_CORO_API void coroutine_chained_await_impl( + CoroFrame &frame, uint node_count, + luisa::move_only_function node) noexcept; +}// namespace detail + +template +class Coroutine { + static_assert(luisa::always_false_v); +}; + +namespace detail { + +class LC_CORO_API CoroAwaiter : public concepts::Noncopyable { + +private: + using Await = luisa::move_only_function; + Await _await; + +public: + explicit CoroAwaiter(Await await) noexcept + : _await{std::move(await)} {} + void await() && noexcept { _await(); } +}; + +}// namespace detail + +template +class Coroutine { + +public: + static_assert(std::is_same_v, + "Coroutine function must return void."); + + class Subroutine { + + private: + Function f; + + private: + friend class Coroutine; + explicit Subroutine(Function function) noexcept : f{function} {} + + public: + void operator()(CoroFrame &frame, compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + compute::detail::CallableInvoke invoke; + invoke << frame.expression(); + static_cast((invoke << ... << args)); + compute::detail::FunctionBuilder::current()->call(f, invoke.args()); + } + }; + +private: + luisa::shared_ptr _graph; + +public: + explicit Coroutine(luisa::shared_ptr graph) noexcept + : _graph{std::move(graph)} { + // TODO: check arguments + } + + template + requires std::negation_v>> && + std::negation_v>> + Coroutine(Def &&f) noexcept { + auto coro = compute::detail::FunctionBuilder::define_coroutine([&f] { + static_assert(std::is_invocable_r_v...>); + auto create = [](auto &&def, std::index_sequence) noexcept { + using arg_tuple = std::tuple; + using var_tuple = std::tuple>...>; + using tag_tuple = std::tuple...>; + auto args = compute::detail::create_argument_definitions(std::tuple<>{}); + static_assert(std::tuple_size_v == sizeof...(Args)); + return luisa::invoke(std::forward(def), + static_cast> &&>(std::get(args))...); + }; + create(std::forward(f), std::index_sequence_for{}); + compute::detail::FunctionBuilder::current()->return_(nullptr);// to check if any previous $return called with non-void types + }); + _graph = CoroGraph::create(coro->function()); + } + +public: + [[nodiscard]] auto graph() const noexcept { return _graph.get(); } + [[nodiscard]] auto &shared_graph() const noexcept { return _graph; } + [[nodiscard]] auto frame() const noexcept { return _graph->frame(); } + [[nodiscard]] auto &shared_frame() const noexcept { return _graph->shared_frame(); } + +public: + [[nodiscard]] auto instantiate() const noexcept { return CoroFrame::create(_graph->shared_frame()); } + [[nodiscard]] auto instantiate(Expr coro_id) const noexcept { return CoroFrame::create(_graph->shared_frame(), coro_id); } + [[nodiscard]] auto subroutine_count() const noexcept { return static_cast(_graph->nodes().size()); } + [[nodiscard]] auto operator[](CoroToken token) const noexcept { return Subroutine{_graph->node(token).cc()}; } + [[nodiscard]] auto operator[](luisa::string_view name) const noexcept { return Subroutine{_graph->node(name).cc()}; } + [[nodiscard]] auto entry() const noexcept { return (*this)[coro_token_entry]; } + [[nodiscard]] auto subroutine(CoroToken token) const noexcept { return (*this)[token]; } + [[nodiscard]] auto subroutine(luisa::string_view name) const noexcept { return (*this)[name]; } + +private: + [[nodiscard]] auto _await(luisa::optional> coro_id, + compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return detail::CoroAwaiter{[=, coro_id = std::move(coro_id), this]() noexcept { + auto frame = coro_id ? instantiate(*coro_id) : instantiate(); + detail::coroutine_chained_await_impl(frame, subroutine_count(), [&](CoroToken token, CoroFrame &f) noexcept { + subroutine(token)(f, args...); + }); + }}; + } + +public: + [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _await(luisa::nullopt, args...); + } + [[nodiscard]] auto operator()(Expr coro_id, compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _await(luisa::make_optional(coro_id), args...); + } +}; + +namespace detail { +struct CoroAwaitInvoker { + void operator%(CoroAwaiter &&awaiter) const && noexcept { + std::move(awaiter).await(); + } +}; +}// namespace detail + +template +Coroutine(T &&) -> Coroutine>>; + +template +class Generator { + static_assert(luisa::always_false_v); +}; + +namespace detail { + +LC_CORO_API void coroutine_generator_next_impl( + CoroFrame &frame, uint node_count, + const luisa::move_only_function &resume) noexcept; + +template +class GeneratorIter : public concepts::Noncopyable { + +private: + uint _n; + CoroFrame _frame; + using Resume = luisa::move_only_function; + Resume _resume; + +public: + GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept + : _n{n}, _frame{std::move(frame)}, + _resume{std::move(resume)} {} + +public: + auto &update() noexcept { + coroutine_generator_next_impl(_frame, _n, _resume); + return *this; + } + [[nodiscard]] Bool is_terminated() const noexcept { return _frame.is_terminated(); } + [[nodiscard]] Var value() noexcept { return _frame.get("__yielded_value"); } + +private: + class RangeForIterator { + private: + GeneratorIter &_g; + bool _invoked{false}; + LoopStmt *_loop{nullptr}; + + private: + friend class GeneratorIter; + explicit RangeForIterator(GeneratorIter &g) noexcept : _g{g} {} + + public: + RangeForIterator &operator++() noexcept { + _invoked = true; + compute::detail::FunctionBuilder::current()->pop_scope(_loop->body()); + return *this; + } + [[nodiscard]] bool operator==(luisa::default_sentinel_t) const noexcept { return _invoked; } + [[nodiscard]] Var operator*() noexcept { + auto fb = compute::detail::FunctionBuilder::current(); + _loop = fb->loop_(); + fb->push_scope(_loop->body()); + _g.update(); + dsl::if_(_g.is_terminated(), [] { dsl::break_(); }); + return _g.value(); + } + }; + +public: + [[nodiscard]] auto begin() noexcept { return RangeForIterator{*this}; } + [[nodiscard]] auto end() const noexcept { return luisa::default_sentinel; } +}; + +}// namespace detail + +template +class Generator { + + static_assert(!std::is_same_v, + "Generator function must not return void."); + +private: + Coroutine _coro; + +public: + template + requires std::negation_v>> && + std::negation_v>> + Generator(Def &&f) noexcept : _coro{std::forward(f)} {} + +public: + [[nodiscard]] auto coroutine() const noexcept { return _coro; } + +private: + [[nodiscard]] auto _iter(luisa::optional> coro_id, + compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return detail::GeneratorIter{ + _coro.subroutine_count(), + coro_id ? _coro.instantiate(*coro_id) : _coro.instantiate(), + [=, this](CoroFrame &f, CoroToken token) noexcept { + _coro[token](f, args...); + }, + }; + } + +public: + [[nodiscard]] auto operator()(compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _iter(luisa::nullopt, args...); + } + [[nodiscard]] auto operator()(Expr coro_id, compute::detail::prototype_to_callable_invocation_t... args) const noexcept { + return _iter(luisa::make_optional(coro_id), args...); + } +}; + +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/coro_graph.h b/include/luisa/coro/coro_graph.h index 07b2ed697..0c5a37fa2 100644 --- a/include/luisa/coro/coro_graph.h +++ b/include/luisa/coro/coro_graph.h @@ -1,32 +1,72 @@ +// +// Created by Mike on 2024/5/8. +// + #pragma once #include +#include #include -#include +#include +#include + +namespace luisa::compute::coroutine { -namespace luisa::compute::inline coro { +class CoroFrameDesc; class LC_CORO_API CoroGraph { +public: + using CC = luisa::shared_ptr;// current continuation function + +public: + class LC_CORO_API Node { + + private: + luisa::vector _input_fields; + luisa::vector _output_fields; + luisa::vector _targets; + CC _cc; + + public: + Node(luisa::vector input_fields, + luisa::vector output_fields, + luisa::vector targets, + CC current_continuation) noexcept; + ~Node() noexcept; + + public: + [[nodiscard]] auto input_fields() const noexcept { return luisa::span{_input_fields}; } + [[nodiscard]] auto output_fields() const noexcept { return luisa::span{_output_fields}; } + [[nodiscard]] auto targets() const noexcept { return luisa::span{_targets}; } + [[nodiscard]] Function cc() const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; + }; + private: - luisa::unordered_map _nodes; - uint _entry; - const Type *_state_type; - luisa::unordered_map _designated_state_members; + luisa::shared_ptr _frame; + luisa::unordered_map _nodes; + luisa::unordered_map _named_tokens; + +public: + CoroGraph(luisa::shared_ptr frame_desc, + luisa::unordered_map nodes, + luisa::unordered_map named_tokens) noexcept; + ~CoroGraph() noexcept; public: - // for construction only - CoroGraph(uint entry, const Type *state_type) noexcept; - [[nodiscard]] CoroNode *add_node(uint token, CoroNode::Func f) noexcept; - void designate_state_member(luisa::string name, uint index) noexcept; + // create a coroutine graph from a coroutine function definition + [[nodiscard]] static luisa::shared_ptr create(Function coroutine) noexcept; public: - [[nodiscard]] const CoroNode *entry() const noexcept; - [[nodiscard]] const CoroNode *node(uint token) const noexcept; + [[nodiscard]] auto frame() const noexcept { return _frame.get(); } + [[nodiscard]] auto &shared_frame() const noexcept { return _frame; } [[nodiscard]] auto &nodes() const noexcept { return _nodes; } - [[nodiscard]] auto state_type() const noexcept { return _state_type; } - [[nodiscard]] auto &designated_state_members() const noexcept { return _designated_state_members; } - [[nodiscard]] uint designated_state_member(luisa::string_view name) const noexcept; + [[nodiscard]] auto &named_tokens() const noexcept { return _named_tokens; } + [[nodiscard]] const Node &entry() const noexcept; + [[nodiscard]] const Node &node(CoroToken index) const noexcept; + [[nodiscard]] const Node &node(luisa::string_view name) const noexcept; + [[nodiscard]] luisa::string dump() const noexcept; }; -}// namespace luisa::compute::inline coro +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/coro_node.h b/include/luisa/coro/coro_node.h deleted file mode 100644 index cb76522c9..000000000 --- a/include/luisa/coro/coro_node.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace luisa::compute::inline coro { - -class CoroGraph; - -class LC_CORO_API CoroNode { - - friend class CoroGraph; - -public: - using Func = luisa::shared_ptr; - -private: - const CoroGraph *_graph; - Func _function; - -protected: - CoroNode(const CoroGraph *graph, Func function) noexcept; - -public: - luisa::vector input_state_members; - luisa::vector output_state_members; - -public: - [[nodiscard]] auto graph() const noexcept { return _graph; } - [[nodiscard]] auto function() const noexcept { return _function->function(); } -}; - -}// namespace luisa::compute::inline coro diff --git a/include/luisa/coro/coro_scheduler.h b/include/luisa/coro/coro_scheduler.h new file mode 100644 index 000000000..82b9444d0 --- /dev/null +++ b/include/luisa/coro/coro_scheduler.h @@ -0,0 +1,90 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include + +namespace luisa::compute::coroutine { + +template +class CoroScheduler; + +namespace detail { + +template +class CoroSchedulerInvoke; + +class CoroSchedulerDispatch : public concepts::Noncopyable { + +private: + luisa::move_only_function _impl; + +private: + template + friend class CoroSchedulerInvoke; + explicit CoroSchedulerDispatch(luisa::move_only_function impl) noexcept + : _impl{std::move(impl)} {} + +public: + void operator()(Stream &stream) && noexcept { _impl(stream); } +}; + +template +class CoroSchedulerInvoke : public concepts::Noncopyable { + +private: + using Scheduler = CoroScheduler; + Scheduler *_scheduler; + std::tuple...> _args; + +private: + friend Scheduler; + explicit CoroSchedulerInvoke(Scheduler *scheduler, compute::detail::prototype_to_shader_invocation_t... args) noexcept + : _scheduler{scheduler}, _args{args...} {} + +public: + [[nodiscard]] auto dispatch(uint3 size) && noexcept { + return CoroSchedulerDispatch{[s = _scheduler, args = std::move(_args), size](Stream &stream) noexcept { + std::apply( + [s, size, &stream](A &&...a) noexcept { + s->_dispatch(stream, size, std::forward(a)...); + }, + args); + }}; + } + [[nodiscard]] auto dispatch(uint nx, uint ny, uint nz) && noexcept { + return std::move(*this).dispatch(luisa::make_uint3(nx, ny, nz)); + } + [[nodiscard]] auto dispatch(uint nx, uint ny) && noexcept { + return std::move(*this).dispatch(luisa::make_uint3(nx, ny, 1u)); + } + [[nodiscard]] auto dispatch(uint2 size) && noexcept { + return std::move(*this).dispatch(luisa::make_uint3(size, 1u)); + } + [[nodiscard]] auto dispatch(uint nx) && noexcept { + return std::move(*this).dispatch(luisa::make_uint3(nx, 1u, 1u)); + } +}; + +}// namespace detail + +template +class CoroScheduler { + +private: + friend class detail::CoroSchedulerInvoke; + virtual void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept = 0; + +public: + virtual ~CoroScheduler() noexcept = default; + [[nodiscard]] auto operator()(compute::detail::prototype_to_shader_invocation_t... args) noexcept { + return detail::CoroSchedulerInvoke{this, args...}; + } +}; + +}// namespace luisa::compute::coroutine + +LUISA_MARK_STREAM_EVENT_TYPE(luisa::compute::coroutine::detail::CoroSchedulerDispatch) diff --git a/include/luisa/coro/coro_token.h b/include/luisa/coro/coro_token.h new file mode 100644 index 000000000..be5883064 --- /dev/null +++ b/include/luisa/coro/coro_token.h @@ -0,0 +1,12 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +namespace luisa::compute::coroutine { +using CoroToken = unsigned int; +constexpr CoroToken coro_token_entry = 0u; +constexpr CoroToken coro_token_terminal = 0x8000'0000u; +constexpr CoroToken coro_token_valid_mask = 0x7fff'ffffu; +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/coro_transition.h b/include/luisa/coro/coro_transition.h deleted file mode 100644 index f63815461..000000000 --- a/include/luisa/coro/coro_transition.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include -#include - -namespace luisa::compute::inline coro { - -struct CoroTransition { - uint destination; -}; - -}// namespace luisa::compute::inline coro diff --git a/include/luisa/coro/radix_sort.h b/include/luisa/coro/radix_sort.h index 96d8ed046..f9ea72057 100644 --- a/include/luisa/coro/radix_sort.h +++ b/include/luisa/coro/radix_sort.h @@ -6,7 +6,7 @@ #include #include #include -#include +#include namespace luisa::compute { namespace radix_sort { const uint HIST_BLOCK_SIZE = 128; @@ -88,7 +88,7 @@ class instance { ///@param high_bit: highest bit of radix sort instance(Device device, uint maxn, temp_storage &temp, Callable *get_key, Callable *get_val, - Callable *get_key_set = nullptr, + Callable *get_key_set = nullptr, uint mode = 0, uint digit = 128, uint low_bit = 0, uint high_bit = 31) : DIGIT{digit}, low_bit{low_bit}, high_bit{high_bit}, MAXN{maxn}, _temp{temp} { LUISA_ASSERT(mode == 0 || mode == 1, "mode should be 0 for radix sort and 1 for bucket sort!"); BIT = 0; @@ -361,4 +361,4 @@ class instance { } }; } -} \ No newline at end of file +}// namespace luisa::compute::radix_sort \ No newline at end of file diff --git a/include/luisa/coro/schedulers/persistent_threads.h b/include/luisa/coro/schedulers/persistent_threads.h new file mode 100644 index 000000000..7d777e2b4 --- /dev/null +++ b/include/luisa/coro/schedulers/persistent_threads.h @@ -0,0 +1,116 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include +#include +#include + +namespace luisa::compute::coroutine { + +struct PersistentThreadsCoroSchedulerConfig { + uint thread_count = 64_k; + uint block_size = 128; + uint fetch_size = 4; + bool shared_memory_soa = false; + bool global_memory_ext = false; +}; + +namespace detail { +LC_CORO_API void persistent_threads_coro_scheduler_main_kernel_impl( + const PersistentThreadsCoroSchedulerConfig &config, + uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, + const CoroGraph *graph, Shared &frames, Expr dispatch_size_prefix_product, + Expr> global, const Buffer &global_frames, + luisa::move_only_function call_subroutine) noexcept; +}// namespace detail + +template +class PersistentThreadsCoroScheduler : public CoroScheduler { + +public: + using Coro = Coroutine; + using Config = PersistentThreadsCoroSchedulerConfig; + +private: + Config _config; + Shader1D, uint3, Args...> _pt_shader; + Shader1D> _clear_shader; + Buffer _global; + Buffer _global_frames; + Shader1D _initialize_shader; + +private: + void _prepare(Device &device, const Coro &coro) noexcept { + _global = device.create_buffer(1); + auto q_fac = 1u; + auto g_fac = coro.subroutine_count() - q_fac; + if (_config.global_memory_ext) { + auto global_ext_size = _config.thread_count * g_fac; + _global_frames = device.create_buffer(coro.shared_frame(), global_ext_size); + } + Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_size_prefix_product, Var... args) noexcept { + set_block_size(_config.block_size, 1u, 1u); + auto global_queue_size = _config.block_size * g_fac; + auto shared_queue_size = _config.block_size * q_fac; + auto call_subroutine = [&](CoroFrame &frame, CoroToken token) noexcept { coro[token](frame, args...); }; + Shared frames{graph->shared_frame(), shared_queue_size, _config.shared_memory_soa}; + detail::persistent_threads_coro_scheduler_main_kernel_impl( + _config, q_fac, g_fac, shared_queue_size, global_queue_size, graph, + frames, dispatch_size_prefix_product, global, _global_frames, call_subroutine); + }; + _pt_shader = device.compile(main_kernel); + _clear_shader = device.compile<1>([](BufferUInt global) { + global->write(dispatch_x(), 0u); + }); + if (_config.global_memory_ext) { + _initialize_shader = device.compile<1>([&](UInt n) noexcept { + auto x = dispatch_x(); + $if (x < n) { + _global_frames->write(x, coro.instantiate()); + }; + }); + } + } + + void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept override { + auto dispatch_size_prefix_product = make_uint3( + dispatch_size.x, + dispatch_size.x * dispatch_size.y, + dispatch_size.x * dispatch_size.y * dispatch_size.z); + if (_config.global_memory_ext) { + auto n = static_cast(_global_frames.size()); + stream << _clear_shader(_global).dispatch(1u) + << _initialize_shader(n).dispatch(n) + << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count); + } else { + stream << _clear_shader(_global).dispatch(1u) + << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count); + } + } + +public: + PersistentThreadsCoroScheduler(Device &device, const Coro &coro, const Config &config) noexcept + : _config{config} { + _config.thread_count = luisa::align(_config.thread_count, _config.block_size); + _prepare(device, coro); + } + PersistentThreadsCoroScheduler(Device &device, const Coro &coro) noexcept + : PersistentThreadsCoroScheduler{device, coro, Config{}} {} +}; + +// User-defined CTAD guides +template +PersistentThreadsCoroScheduler(Device &device, const Coroutine &coro, + const PersistentThreadsCoroSchedulerConfig &config) noexcept + -> PersistentThreadsCoroScheduler; + +template +PersistentThreadsCoroScheduler(Device &device, const Coroutine &coro) noexcept + -> PersistentThreadsCoroScheduler; + +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/schedulers/state_machine.h b/include/luisa/coro/schedulers/state_machine.h new file mode 100644 index 000000000..a2045b69f --- /dev/null +++ b/include/luisa/coro/schedulers/state_machine.h @@ -0,0 +1,86 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include +#include +#include + +namespace luisa::compute::coroutine { + +namespace detail { +LC_CORO_API void coro_scheduler_state_machine_impl( + CoroFrame &frame, uint state_count, + luisa::move_only_function node) noexcept; +LC_CORO_API void coro_scheduler_state_machine_smem_impl( + Shared &smem, const CoroGraph *graph, + luisa::move_only_function node) noexcept; +}// namespace detail + +struct StateMachineCoroSchedulerConfig { + uint3 block_size = luisa::make_uint3(128, 1, 1); + bool shared_memory = false; + bool shared_memory_soa = true; +}; + +template +class StateMachineCoroScheduler : public CoroScheduler { + +public: + using Coro = Coroutine; + using Config = StateMachineCoroSchedulerConfig; + +private: + Shader3D _shader; + +private: + void _create_shader(Device &device, const Coro &coroutine, const Config &config) noexcept { + Kernel3D kernel = [&coroutine, &config](Var... args) noexcept { + set_block_size(config.block_size); + if (config.shared_memory) { + auto n = config.block_size.x * config.block_size.y * config.block_size.z; + Shared sm{coroutine.shared_frame(), n, config.shared_memory_soa, std::array{0u, 1u}}; + detail::coro_scheduler_state_machine_smem_impl( + sm, coroutine.graph(), + [&](CoroToken token, CoroFrame &frame) noexcept { + coroutine.subroutine(token)(frame, args...); + }); + } else { + auto frame = coroutine.instantiate(dispatch_id()); + detail::coro_scheduler_state_machine_impl( + frame, coroutine.subroutine_count(), + [&](CoroToken token) noexcept { + coroutine.subroutine(token)(frame, args...); + }); + } + }; + _shader = device.compile(kernel); + } + + void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept override { + stream << _shader(args...).dispatch(dispatch_size); + } + +public: + StateMachineCoroScheduler(Device &device, const Coro &coro, const Config &config) noexcept { + _create_shader(device, coro, config); + } + StateMachineCoroScheduler(Device &device, const Coro &coro) noexcept + : StateMachineCoroScheduler{device, coro, Config{}} {} +}; + +// User-defined CTAD guides +template +StateMachineCoroScheduler(Device &, const Coroutine &) + -> StateMachineCoroScheduler; + +template +StateMachineCoroScheduler(Device &, const Coroutine &, + const StateMachineCoroSchedulerConfig &) + -> StateMachineCoroScheduler; + +}// namespace luisa::compute::coroutine diff --git a/include/luisa/coro/schedulers/wavefront.h b/include/luisa/coro/schedulers/wavefront.h new file mode 100644 index 000000000..bfca6c361 --- /dev/null +++ b/include/luisa/coro/schedulers/wavefront.h @@ -0,0 +1,478 @@ +// +// Created by Mike on 2024/5/10. +// + +#pragma once + +#include +#include +#include +#include +#include + +namespace luisa::compute::coroutine { + +struct WavefrontCoroSchedulerConfig { + // uint3 block_size = make_uint3(8u, 8u, 1u); + uint thread_count = 2_M; + bool global_memory_soa = true; + bool gather_by_sorting = true;// use sort for coro token gathering + bool frame_buffer_compaction = true; + uint hint_range = 0xffff'ffff; + luisa::vector hint_fields; +}; + +template +class WavefrontCoroScheduler : public CoroScheduler { + +public: + using Config = WavefrontCoroSchedulerConfig; + +private: + Config _config; + using ArgPack = std::tuple...>; + luisa::optional _args; + SOA _frame_soa; + Buffer _frame_buffer; + Shader1D, Buffer, uint3, uint, uint, uint, Args...> _gen_shader; + luisa::vector, Buffer, uint, Args...>> _resume_shaders; + Shader1D, Buffer, uint> _count_prefix_shader; + Shader1D, Buffer, uint> _gather_shader; + Shader1D, uint> _initialize_shader; + Shader1D, uint, uint> _compact_shader; + Shader1D, uint> _clear_shader; + Buffer _resume_index; + Buffer _resume_count; + ///offset calculate from count, will be end after gathering + Buffer _resume_offset; + Buffer _global_buffer; + luisa::vector _host_count; + luisa::vector _host_offset; + uint3 _dispatch_shape; + bool _host_empty; + uint _dispatch_counter; + uint _max_sub_coro; + uint _dispatch_size; + radix_sort::temp_storage _sort_temp_storage; + radix_sort::instance<> _sort_token; + radix_sort::instance> _sort_hint; + luisa::vector _have_hint; + Buffer _temp_key[2]; + Buffer _temp_index; + +private: + void _dispatch(Stream &stream, uint3 dispatch_size, + compute::detail::prototype_to_shader_invocation_t... args) noexcept override { + _args.emplace(std::forward>(args)...); + _dispatch_shape = dispatch_size; + _dispatch_size = dispatch_size.x * dispatch_size.y * dispatch_size.z;// TODO + _dispatch_counter = 0; + _host_empty = true; + for (auto i = 0u; i < _max_sub_coro; i++) { + if (i) { + _host_count[i] = 0; + _host_offset[i] = _config.thread_count; + } else { + _host_count[i] = _config.thread_count; + _host_offset[i] = 0; + } + } + stream << _initialize_shader(_resume_count, _config.thread_count).dispatch(_config.thread_count); + this->_await_all(stream); + } + + template + [[nodiscard]] auto _invoke(ShaderTytpe &shader, PrefixArgsType &&...prefix_args) const noexcept { + return std::apply( + [&](PostfixArgsType &&...postfix_args) { + return shader(std::forward(prefix_args)..., + std::forward(postfix_args)...); + }, + _args.value()); + } + + void _create_shader(Device &device, const Coroutine &coroutine, + const WavefrontCoroSchedulerConfig &config) noexcept { + _config = config; + if (_config.global_memory_soa) { + _frame_soa = device.create_soa(coroutine.shared_frame(), _config.thread_count); + } else { + _frame_buffer = device.create_coro_frame_buffer(coroutine.shared_frame(), _config.thread_count); + } + bool use_sort = _config.gather_by_sorting || !_config.hint_fields.empty(); + _max_sub_coro = coroutine.subroutine_count(); + _resume_index = device.create_buffer(_config.thread_count); + if (use_sort) { + _temp_index = device.create_buffer(_config.thread_count); + _temp_key[0] = device.create_buffer(_config.thread_count); + _temp_key[1] = device.create_buffer(_config.thread_count); + } + _resume_count = device.create_buffer(_max_sub_coro); + _resume_offset = device.create_buffer(_max_sub_coro); + _global_buffer = device.create_buffer(1); + _host_empty = true; + _dispatch_counter = 0; + _host_offset.resize(_max_sub_coro); + _host_count.resize(_max_sub_coro); + _have_hint.resize(_max_sub_coro, false); + for (auto &token : _config.hint_fields) { + auto id = coroutine.graph()->named_tokens().find(token); + if (id != coroutine.graph()->named_tokens().end()) { + LUISA_ASSERT(id->second < _max_sub_coro, + "coroutine token {} of id {} out of range {}", token, id->second, _max_sub_coro); + _have_hint[id->second] = true; + } else { + LUISA_WARNING("coroutine token {} not found, hint disabled", token); + } + } + for (auto i = 0u; i < _max_sub_coro; i++) { + if (i) { + _host_count[i] = 0; + _host_offset[i] = _config.thread_count; + } else { + _host_count[i] = _config.thread_count; + _host_offset[i] = 0; + } + } + Callable get_coro_token = [&](UInt index) { + $if (index > _config.thread_count) { + device_log("Index out of range {}/{}", index, _config.thread_count); + }; + if (_config.global_memory_soa) { + return _frame_soa->read_field(index, "target_token") & coro_token_valid_mask; + } else { + CoroFrame frame = _frame_buffer->read(index); + return frame.target_token & coro_token_valid_mask; + } + }; + Callable identical = [](UInt index) { + return index; + }; + + Callable keep_index = [](UInt index, BufferUInt val) { + return val.read(index); + }; + Callable get_coro_hint = [&](UInt index, BufferUInt val) { + if (!_config.hint_fields.empty()) { + auto id = keep_index(index, val); + if (_config.global_memory_soa) { + return _frame_soa->read_field(id, "coro_hint"); + } else { + CoroFrame frame = _frame_buffer->read(id); + return frame.get("coro_hint"); + } + } + return def(0u); + }; + if (use_sort) { + _sort_temp_storage = radix_sort::temp_storage( + device, _config.thread_count, std::max(std::min(_config.hint_range, 128u), _max_sub_coro)); + } + if (_config.gather_by_sorting) { + _sort_token = radix_sort::instance<>( + device, _config.thread_count, _sort_temp_storage, &get_coro_token, &identical, + &get_coro_token, 1, _max_sub_coro); + } + if (!_config.hint_fields.empty()) { + if (_config.hint_range <= 128) { + _sort_hint = radix_sort::instance>( + device, _config.thread_count, _sort_temp_storage, &get_coro_hint, &keep_index, + &get_coro_hint, 1, _config.hint_range); + } else { + auto highbit = 0; + while ((_config.hint_range >> highbit) != 1) { + highbit++; + } + _sort_hint = radix_sort::instance>( + device, _config.thread_count, _sort_temp_storage, &get_coro_hint, &keep_index, + &get_coro_hint, 0, 128, 0, highbit); + } + } + Kernel1D gen_kernel = [&](BufferUInt index, BufferUInt count, UInt3 dispatch_shape, UInt offset, UInt st_task_id, UInt n, Var... args) { + auto x = dispatch_x(); + $if (x >= n) { + $return(); + }; + UInt frame_id; + if (!_config.frame_buffer_compaction) { + frame_id = index->read(x); + } else { + frame_id = offset + x; + } + if (!_config.gather_by_sorting) { + count.atomic(0u).fetch_add(-1u); + } + auto global_id = st_task_id + x; + auto image_size = dispatch_shape.x * dispatch_shape.y; + auto global_id_z = global_id / image_size; + auto global_id_xy = global_id % image_size; + auto global_id_x = global_id_xy % dispatch_shape.x; + auto global_id_y = global_id_xy / dispatch_shape.x; + CoroFrame frame = coroutine.instantiate(make_uint3(global_id_x, global_id_y, global_id_z)); + coroutine.entry()(frame, args...); + if (_config.global_memory_soa) { + _frame_soa->write(frame_id, frame, coroutine.graph()->node(0u).output_fields()); + } else { + _frame_buffer->write(frame_id, frame); + } + if (!_config.gather_by_sorting) { + auto nxt = frame.target_token & coro_token_valid_mask; + count.atomic(nxt).fetch_add(1u); + } + }; + ShaderOption o{}; + _gen_shader = device.compile(gen_kernel, o); + _gen_shader.set_name("gen"); + _resume_shaders.resize(_max_sub_coro); + + for (auto i = 1u; i < _max_sub_coro; ++i) { + Kernel1D resume_kernel = [&](BufferUInt index, BufferUInt count, UInt n, Var... args) { + auto x = dispatch_x(); + $if (x >= n) { + $return(); + }; + auto frame_id = index.read(x); + CoroFrame frame = coroutine.instantiate(); + if (_config.global_memory_soa) { + //frame = frame_buffer.read(frame_id); + frame = _frame_soa->read(frame_id, coroutine.graph()->node(i).input_fields()); + } else { + frame = _frame_buffer->read(frame_id); + } + if (!_config.gather_by_sorting) { + count.atomic(i).fetch_sub(1u); + } + coroutine[i](frame, args...); + if (_config.global_memory_soa) { + _frame_soa->write(frame_id, frame, coroutine.graph()->node(i).output_fields()); + } else { + _frame_buffer->write(frame_id, frame); + } + + if (!_config.gather_by_sorting) { + auto nxt = frame.target_token & coro_token_valid_mask; + $if (nxt < _max_sub_coro) { + count.atomic(nxt).fetch_add(1u); + }; + } + }; + _resume_shaders[i] = device.compile(resume_kernel, o); + _resume_shaders[i].set_name("resume" + std::to_string(i)); + } + + Kernel1D _prefix_kernel = [&](BufferUInt count, BufferUInt prefix, UInt n) { + $if (dispatch_x() == 0) { + auto pre = def(0u); + for (auto i = 0u; i < _max_sub_coro; ++i) { + auto val = count.read(i); + prefix.write(i, pre); + pre = pre + val; + } + }; + }; + _count_prefix_shader = device.compile(_prefix_kernel); + + Kernel1D _gather_kernel = [&](BufferUInt index, BufferUInt prefix, UInt n) { + auto x = dispatch_x(); + UInt r_id; + if (_config.global_memory_soa) { + r_id = _frame_soa->read_field(x, "target_token") & coro_token_valid_mask; + } else { + auto frame = _frame_buffer->read(x); + r_id = frame.target_token & coro_token_valid_mask; + } + auto q_id = prefix.atomic(r_id).fetch_add(1u); + index.write(q_id, x); + }; + _gather_shader = device.compile(_gather_kernel); + + Kernel1D _compact_kernel_2 = [&](BufferUInt index, UInt empty_offset, UInt n) { + //_global_buffer->write(0u, 0u); + auto x = dispatch_x(); + $if (empty_offset + x < n) { + UInt token; + if (_config.global_memory_soa) { + token = _frame_soa->read_field(empty_offset + x, "target_token"); + } else { + CoroFrame frame = _frame_buffer->read(empty_offset + x); + token = frame.target_token; + } + $if ((token & coro_token_valid_mask) != 0u) { + auto res = _global_buffer->atomic(0u).fetch_add(1u); + auto slot = index.read(res); + if (!_config.gather_by_sorting) { + $while (slot >= empty_offset) { + res = _global_buffer->atomic(0u).fetch_add(1u); + slot = index.read(res); + }; + } + if (_config.global_memory_soa) { + auto frame = _frame_soa->read(empty_offset + x); + _frame_soa->write(slot, frame); + } else { + auto frame = _frame_buffer->read(empty_offset + x); + _frame_buffer->write(slot, frame); + } + if (_config.global_memory_soa) { + _frame_soa->write_field(empty_offset + x, 0u, "target_token"); + } else { + CoroFrame empty_frame = coroutine.instantiate(); + _frame_buffer->write(empty_offset + x, empty_frame); + } + }; + }; + }; + _compact_shader = device.compile(_compact_kernel_2); + _compact_shader.set_name("compact"); + + Kernel1D _initialize_kernel = [&](BufferUInt count, UInt n) { + auto x = dispatch_x(); + $if (x < n) { + CoroFrame frame = coroutine.instantiate(); + if (_config.global_memory_soa) { + _frame_soa->write(x, frame, std::array{0u, 1u}); + } else { + _frame_buffer->write(x, frame); + } + }; + $if (x < _max_sub_coro) { + count.write(x, ite(x == 0u, _config.thread_count, 0u)); + }; + }; + _initialize_shader = device.compile(_initialize_kernel); + + Kernel1D clear = [&](BufferUInt buffer, UInt n) { + auto x = dispatch_x(); + $if (x < n) { + buffer.write(x, 0u); + }; + }; + _clear_shader = device.compile(clear); + } + + [[nodiscard]] bool _all_dispatched() const noexcept { + return _dispatch_counter == _dispatch_size; + } + [[nodiscard]] bool _all_done() const noexcept { + return this->_all_dispatched() && _host_empty; + } + + void _await_all(Stream &stream) noexcept { + while (!this->_all_done()) { + this->_await_step(stream); + } + } + void _await_step(Stream &stream) noexcept { + if (_config.gather_by_sorting) { + auto host_update = [&] { + _host_empty = true; + for (uint i = 0u; i < _max_sub_coro; i++) { + _host_count[i] = (i + 1u == _max_sub_coro ? _config.thread_count : _host_offset[i + 1u]) - _host_offset[i]; + _host_empty = _host_empty && (i == 0u || _host_count[i] == 0u); + } + }; + _sort_token.sort(stream, _temp_key[0], _resume_index, _temp_key[1], + _resume_index, _config.thread_count); + + stream << _sort_temp_storage.hist_buffer.view(0u, _max_sub_coro).copy_to(_host_offset.data()) + << host_update + << synchronize(); + + if (_host_count[0] > _config.thread_count * 0.5f && !this->_all_dispatched()) { + auto gen_count = std::min(_dispatch_size - _dispatch_counter, _host_count[0]); + if (_host_count[0] != _config.thread_count && _config.frame_buffer_compaction) { + stream << _clear_shader(_global_buffer, 1).dispatch(1u); + stream << _compact_shader(_resume_index, _config.thread_count - _host_count[0], _config.thread_count).dispatch(_host_count[0]); + } + stream << _invoke(_gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), + _resume_count, _dispatch_shape, _config.thread_count - _host_count[0], _dispatch_counter, + _config.thread_count) + .dispatch(gen_count); + _dispatch_counter += gen_count; + _host_empty = false; + } else { + for (uint i = 1; i < _max_sub_coro; i++) { + if (_host_count[i] > 0) { + if (_have_hint[i]) { + BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; + BufferView _key[2] = {_temp_key[1].view(_host_offset[i], _host_count[i]), _temp_key[0].view(_host_offset[i], _host_count[i])}; + uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); + stream << _invoke(_resume_shaders[i], _index[out], _resume_count, _config.thread_count) + .dispatch(_host_count[i]); + } else { + stream << _invoke(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), + _resume_count, _config.thread_count) + .dispatch(_host_count[i]); + } + } + } + } + stream << synchronize(); + } else { + stream << _count_prefix_shader(_resume_count, _resume_offset, _max_sub_coro).dispatch(1u); + stream << _gather_shader(_resume_index, _resume_offset, _config.thread_count).dispatch(_config.thread_count); + if (_host_count[0] > _config.thread_count / 2 && !_all_dispatched()) { + auto gen_count = std::min(_dispatch_size - _dispatch_counter, _host_count[0]); + if (_host_count[0] != _config.thread_count && _config.frame_buffer_compaction) { + stream << _clear_shader(_global_buffer, 1).dispatch(1u); + stream + << _compact_shader(_resume_index.view(_host_offset[0], _host_count[0]), + _config.thread_count - _host_count[0], _config.thread_count) + .dispatch(_host_count[0]); + } + stream << _invoke(_gen_shader, _resume_index.view(_host_offset[0], _host_count[0]), + _resume_count, _dispatch_shape, _config.thread_count - _host_count[0], _dispatch_counter, _config.thread_count) + .dispatch(gen_count); + _dispatch_counter += gen_count; + _host_empty = false; + } else { + for (uint i = 1; i < _max_sub_coro; i++) { + if (_host_count[i] > 0) { + if (_have_hint[i]) { + BufferView _index[2] = {_resume_index.view(_host_offset[i], _host_count[i]), _temp_index.view(_host_offset[i], _host_count[i])}; + BufferView _key[2] = {_temp_key[0].view(_host_offset[i], _host_count[i]), _temp_key[1].view(_host_offset[i], _host_count[i])}; + uint out = _sort_hint.sort_switch(stream, _key, _index, _host_count[i], _resume_index.view(_host_offset[i], _host_count[i])); + stream << _invoke(_resume_shaders[i], _index[out], _resume_count, _config.thread_count) + .dispatch(_host_count[i]); + } else { + stream << _invoke(_resume_shaders[i], _resume_index.view(_host_offset[i], _host_count[i]), + _resume_count, _config.thread_count) + .dispatch(_host_count[i]); + } + } + } + } + auto host_update = [&] { + _host_empty = true; + auto sum = 0u; + for (uint i = 0; i < _max_sub_coro; i++) { + _host_offset[i] = sum; + sum += _host_count[i]; + _host_empty = _host_empty && (i == 0 || _host_count[i] == 0); + } + }; + stream << _resume_count.view(0, _max_sub_coro).copy_to(_host_count.data()) + << host_update; + stream << synchronize(); + } + } + +public: + WavefrontCoroScheduler(Device &device, const Coroutine &coro, + const WavefrontCoroSchedulerConfig &config) noexcept { + _create_shader(device, coro, config); + } + WavefrontCoroScheduler(Device &device, const Coroutine &coro) noexcept + : WavefrontCoroScheduler{device, coro, WavefrontCoroSchedulerConfig{}} {} +}; + +template +WavefrontCoroScheduler(Device &, const Coroutine &) + -> WavefrontCoroScheduler; + +template +WavefrontCoroScheduler(Device &, const Coroutine &, + const WavefrontCoroSchedulerConfig &) + -> WavefrontCoroScheduler; + +}// namespace luisa::compute::coroutine \ No newline at end of file diff --git a/include/luisa/coro/shader_scheduler.h b/include/luisa/coro/shader_scheduler.h deleted file mode 100644 index e69de29bb..000000000 diff --git a/include/luisa/dsl/builtin.h b/include/luisa/dsl/builtin.h index 00b5e4560..622d762cb 100644 --- a/include/luisa/dsl/builtin.h +++ b/include/luisa/dsl/builtin.h @@ -1771,24 +1771,6 @@ template {LUISA_EXPR(x)})); } -/// Coroutines -/// coroutine initialization -template - requires is_coroframe_struct_v> && - is_dsl_v && - is_same_expr_v -inline void initialize_coroframe(F &&frame, V &&coro_id) noexcept { - detail::FunctionBuilder::current()->initialize_coroframe(LUISA_EXPR(frame), LUISA_EXPR(coro_id)); -} - -template - requires is_coroframe_struct_v> -[[nodiscard]] inline auto make_coroframe(Expr coro_id) noexcept { - Var f; - initialize_coroframe(f, coro_id); - return f; -} - [[nodiscard]] inline auto coro_id() noexcept { return def(detail::FunctionBuilder::current()->coro_id()); } @@ -1797,14 +1779,6 @@ template return def(detail::FunctionBuilder::current()->coro_token()); } -template -inline auto read_promise(T &&t, S &&name) noexcept { - return def(detail::FunctionBuilder::current()->read_promise_( - Type::of(), - detail::extract_expression(std::forward(t)), - std::forward(name))); -} - template [[nodiscard]] inline auto grad(const Local &x) noexcept { auto b = detail::FunctionBuilder::current(); diff --git a/include/luisa/dsl/func.h b/include/luisa/dsl/func.h index bd51738be..840332e64 100644 --- a/include/luisa/dsl/func.h +++ b/include/luisa/dsl/func.h @@ -15,7 +15,6 @@ #include #include #include -#include namespace luisa::compute { @@ -150,9 +149,6 @@ class FunctionBuilder; [[nodiscard]] LC_DSL_API luisa::shared_ptr transform_function(Function callable) noexcept; -[[nodiscard]] LC_DSL_API luisa::shared_ptr -transform_coroutine(Type *corotype, coro::CoroGraph &graph, luisa::unordered_map> &sub_builders, Function callable) noexcept; - }// namespace detail template @@ -306,6 +302,14 @@ class CallableInvoke { public: CallableInvoke() noexcept = default; + /// Add an argument + CallableInvoke &operator<<(const Expression *expr) noexcept { + if (_arg_count == max_argument_count) [[unlikely]] { + _error_too_many_arguments(); + } + _args[_arg_count++] = expr; + return *this; + } /// Add an argument. template CallableInvoke &operator<<(Expr arg) noexcept { @@ -314,10 +318,7 @@ class CallableInvoke { } else if constexpr (is_soa_expr_v) { callable_encode_soa(*this, arg); } else { - if (_arg_count == max_argument_count) [[unlikely]] { - _error_too_many_arguments(); - } - _args[_arg_count++] = arg.expression(); + *this << arg.expression(); } return *this; } @@ -330,18 +331,7 @@ class CallableInvoke { return luisa::span{_args.data(), _arg_count}; } }; -class CoroutineInvoke : public ShaderInvokeBase { - CoroutineInvoke(uint64_t handle, size_t arg_count, size_t uniform_size) noexcept - : ShaderInvokeBase{handle, arg_count, uniform_size} {} - [[nodiscard]] auto dispatch(uint size_x) && noexcept { - return this->_parallelize(uint3{size_x, 1u, 1u}); - } - [[nodiscard]] auto dispatch(const IndirectDispatchBuffer &indirect_buffer, - uint32_t offset = 0, - uint32_t max_dispatch_size = std::numeric_limits::max()) && noexcept { - return this->_parallelize(indirect_buffer, offset, max_dispatch_size).build(); - } -}; + }// namespace detail /// Callable class. Callable is not allowed, unless T is a function type. @@ -357,8 +347,6 @@ struct is_callable> : std::true_type {}; template class Callable { friend class CallableLibrary; - template - friend class Coroutine; static_assert( std::negation_v, @@ -372,7 +360,9 @@ class Callable { public: explicit Callable(luisa::shared_ptr builder) noexcept - : _builder{std::move(builder)} {} + : _builder{std::move(builder)} { + // TODO: check arguments + } /** * @brief Construct a Callable object. * @@ -512,96 +502,6 @@ class ExternalCallable { } }; -/// Coroutine class. Coroutine is not allowed, unless T is a function type. -template -class Coroutine { - static_assert(always_false_v); -}; - -template -struct is_callable> : std::true_type {}; -template -class Coroutine { - static_assert(std::negation_v...>>); - -private: - using CoroID = uint; - static_assert(std::is_lvalue_reference_v); - static_assert(is_coroframe_struct_v>>); - luisa::shared_ptr _builder; - luisa::unordered_map> _sub_callables; - uint _uniform_size; - luisa::unordered_map _coro_tokens; - coro::CoroGraph _coro_graph; -public: - template - requires std::negation_v>> && - std::negation_v>> - Coroutine(Def &&f) noexcept : _coro_graph{0, Type::of>()} { - auto ast = detail::FunctionBuilder::define_coroutine([&f] { - static_assert(std::is_invocable_v, detail::prototype_to_creation_t...>); - auto create = [](auto &&def, std::index_sequence) noexcept { - using arg_tuple = std::tuple; - using var_tuple = std::tuple>, - Var>...>; - using tag_tuple = std::tuple, detail::prototype_to_creation_tag_t...>; - - auto args = detail::create_argument_definitions(std::tuple<>{}); - static_assert(std::tuple_size_v == 1 + sizeof...(Args)); - return luisa::invoke(std::forward(def), - static_cast> &&>(std::get(args))...); - }; - //static_assert(std::is_same_v, "coroutine should return void"); - create(std::forward(f), std::index_sequence_for{}); - detail::FunctionBuilder::current()->return_(nullptr);// to check if any previous $return called with non-void types - }); - _coro_tokens = ast->coro_tokens(); - Type *frame = const_cast(Type::of>()); - auto sub_builder = luisa::unordered_map>{}; - _builder = detail::transform_coroutine(frame, _coro_graph, sub_builder, ast->function()); - _sub_callables.clear(); - _uniform_size = ShaderDispatchCmdEncoder::compute_uniform_size(_builder->function().unbound_arguments()); - for (auto v : sub_builder) { - _sub_callables.insert(std::make_pair(v.first, Callable(v.second))); - } - } - - /// Get the underlying AST - [[nodiscard]] auto function() const noexcept { return Function{_builder.get()}; } - [[nodiscard]] auto const &function_builder() const & noexcept { return _builder; } - [[nodiscard]] auto &&function_builder() && noexcept { return std::move(_builder); } - [[nodiscard]] auto const suspend_count() noexcept { return _coro_tokens.size(); } - [[nodiscard]] auto const & coro_tokens() const & noexcept { return _coro_tokens; } - [[nodiscard]] auto const &graph() noexcept { return _coro_graph; } - //Call from start of coroutine - auto operator()(detail::prototype_to_callable_invocation_t type, - detail::prototype_to_callable_invocation_t... args) const noexcept { - - detail::CallableInvoke invoke; - static_cast((invoke << type)); - static_cast((invoke << ... << args)); - detail::FunctionBuilder::current()->call( - _builder->function(), invoke.args()); - } - auto operator()(detail::prototype_to_shader_invocation_t... args) const noexcept { - using invoke_type = detail::CoroutineInvoke; - auto arg_count = (0u + ... + detail::shader_argument_encode_count::value); - invoke_type invoke{arg_count, _uniform_size}; - static_cast((invoke << ... << args)); - return invoke; - } - //Get Callable from certain suspend point - auto operator[](CoroID index) const noexcept { - auto builder = _sub_callables.find(index); - LUISA_ASSERT(builder != _sub_callables.end(), "coroutine index out of range"); - return builder->second; - } - auto operator[](luisa::string_view index) const noexcept { - auto coro_token = function_builder()->coro_tokens().at(index); - return (*this)[coro_token]; - } -}; namespace detail { template @@ -695,11 +595,6 @@ struct dsl_function> { using type = T; }; -template -struct dsl_function> { - using type = T; -}; - template using dsl_function_t = typename dsl_function::type; @@ -717,9 +612,6 @@ Kernel3D(T &&) -> Kernel3D>>; template Callable(T &&) -> Callable>>; -template -Coroutine(T &&) -> Coroutine>>; - namespace detail { struct CallableOutliner { diff --git a/include/luisa/dsl/resource.h b/include/luisa/dsl/resource.h index 7dae9aecc..6efa9cafa 100644 --- a/include/luisa/dsl/resource.h +++ b/include/luisa/dsl/resource.h @@ -96,20 +96,43 @@ struct Expr> [[nodiscard]] auto operator->() const noexcept { return this; } }; +/// Same as Expr> +template +struct Expr> : public Expr> { + using Expr>::Expr; +}; + template<> struct Expr { private: const RefExpr *_expression{nullptr}; + public: /// Construct from RefExpr explicit Expr(const RefExpr *expr) noexcept : _expression{expr} {} - /// Construct from BufferView. Will call buffer_binding() to bind buffer - Expr(const ByteBuffer &buffer) noexcept + /// Construct from ByteBufferView. Will call buffer_binding() to bind buffer + Expr(const ByteBufferView &buffer) noexcept : _expression{detail::FunctionBuilder::current()->buffer_binding( Type::of(), buffer.handle(), - 0u, buffer.size_bytes())} {} + buffer.offset(), buffer.size_bytes())} {} + + /// Contruct from ByteBuffer. Will call buffer_binding() to bind buffer + Expr(const ByteBuffer &buffer) noexcept + : Expr{ByteBufferView{buffer}} {} + + /// Construct from Var. + template + requires std::same_as + Expr(const Var &buffer) noexcept + : Expr{buffer.expression()} {} + + /// Construct from Var. + template + requires std::same_as + Expr(const Var &buffer) noexcept + : Expr{buffer.expression()} {} /// Return RefExpr [[nodiscard]] const RefExpr *expression() const noexcept { return _expression; } @@ -139,12 +162,6 @@ struct Expr { } }; -/// Same as Expr> -template -struct Expr> : public Expr> { - using Expr>::Expr; -}; - /// Class of Expr> template struct Expr> { @@ -531,28 +548,29 @@ class BufferExprProxy { } }; +template class ByteBufferExprProxy { private: - ByteBuffer _buffer; + BufferOrView _buffer; public: LUISA_RESOURCE_PROXY_AVOID_CONSTRUCTION(ByteBufferExprProxy) public: - template + template requires is_integral_expr_v [[nodiscard]] auto read(I &&index) const noexcept { - return Expr{_buffer}.template read(std::forward(index)); + return Expr{_buffer}.template read(std::forward(index)); } template requires is_integral_expr_v void write(I &&index, V &&value) const noexcept { - Expr{_buffer}.write(std::forward(index), - std::forward(value)); + Expr{_buffer}.write(std::forward(index), + std::forward(value)); } [[nodiscard]] Expr device_address() const noexcept { - return Expr{_buffer}.device_address(); + return Expr{_buffer}.device_address(); } }; @@ -644,6 +662,14 @@ struct Var> : public Expr> { Var(const Var &) noexcept = delete; }; +template +struct Var> : public Expr> { + explicit Var(detail::ArgumentCreation) noexcept + : Expr>{detail::FunctionBuilder::current()->buffer(Type::of>())} {} + Var(Var &&) noexcept = default; + Var(const Var &) noexcept = delete; +}; + template<> struct Var : public Expr { explicit Var(detail::ArgumentCreation) noexcept @@ -652,10 +678,10 @@ struct Var : public Expr { Var(const Var &) noexcept = delete; }; -template -struct Var> : public Expr> { +template<> +struct Var : public Expr { explicit Var(detail::ArgumentCreation) noexcept - : Expr>{detail::FunctionBuilder::current()->buffer(Type::of>())} {} + : Expr{detail::FunctionBuilder::current()->buffer(Type::of())} {} Var(Var &&) noexcept = default; Var(const Var &) noexcept = delete; }; @@ -704,6 +730,8 @@ struct Var : public Expr { template using BufferVar = Var>; +using ByteBufferVar = Var; + template using ImageVar = Var>; diff --git a/include/luisa/dsl/stmt.h b/include/luisa/dsl/stmt.h index d24e6dd5b..0b166ad9d 100644 --- a/include/luisa/dsl/stmt.h +++ b/include/luisa/dsl/stmt.h @@ -494,6 +494,13 @@ inline auto suspend(S &&first, Args &&...args) noexcept { return compute::dsl::suspend(""sv, std::forward(first), std::forward(args)...); } +template +inline void promise(luisa::string name, T &&value) noexcept { + detail::FunctionBuilder::current()->bind_promise_( + detail::extract_expression(std::forward(value)), + std::move(name)); +} + inline void return_() noexcept { detail::FunctionBuilder::current()->return_(); } diff --git a/include/luisa/dsl/struct.h b/include/luisa/dsl/struct.h index 54bb27922..264064654 100644 --- a/include/luisa/dsl/struct.h +++ b/include/luisa/dsl/struct.h @@ -333,366 +333,3 @@ struct luisa_compute_extension {}; } \ template<> \ struct luisa_compute_extension final : luisa::compute::detail::Ref -#define LUISA_COROFRAME_STRUCT_EXT(S) \ - template<> \ - struct luisa_compute_extension; \ - namespace luisa::compute { \ - template<> \ - struct Expr { \ - private: \ - using this_type = S; \ - const Expression *_expression; \ - \ - public: \ - explicit Expr(const Expression *e) noexcept \ - : _expression{e} {} \ - [[nodiscard]] auto expression() const noexcept { return this->_expression; } \ - Expr(Expr &&another) noexcept = default; \ - Expr(const Expr &another) noexcept = default; \ - Expr &operator=(Expr) noexcept = delete; \ - template \ - [[nodiscard]] auto promise(luisa::string_view name) const noexcept { \ - return luisa::compute::dsl::def( \ - luisa::compute::detail::FunctionBuilder::current()->read_promise_( \ - luisa::compute::Type::of(), this->_expression, name)); \ - } \ - [[nodiscard]] auto target_token() const noexcept { \ - return this->promise("coro_token"); \ - } \ - [[nodiscard]] auto coro_id() const noexcept { \ - return this->promise("coro_id"); \ - } \ - }; \ - namespace detail { \ - template<> \ - struct Ref { \ - private: \ - using this_type = S; \ - const Expression *_expression; \ - \ - public: \ - explicit Ref(const Expression *e) noexcept \ - : _expression{e} {} \ - [[nodiscard]] auto expression() const noexcept { return this->_expression; } \ - Ref(Ref &&another) noexcept = default; \ - Ref(const Ref &another) noexcept = default; \ - [[nodiscard]] operator Expr() const noexcept { \ - return Expr{this->expression()}; \ - } \ - template \ - void operator=(Rhs &&rhs) & noexcept { dsl::assign(*this, std::forward(rhs)); } \ - void operator=(Ref rhs) & noexcept { (*this) = Expr{rhs}; } \ - [[nodiscard]] auto operator->() noexcept { \ - return reinterpret_cast *>(this); \ - } \ - [[nodiscard]] auto operator->() const noexcept { \ - return reinterpret_cast *>(this); \ - } \ - template \ - [[nodiscard]] auto promise(luisa::string_view name) const noexcept { \ - return luisa::compute::dsl::def( \ - luisa::compute::detail::FunctionBuilder::current()->read_promise_( \ - luisa::compute::Type::of(), this->_expression, name)); \ - } \ - [[nodiscard]] auto target_token() const noexcept { \ - return this->promise("coro_token"); \ - } \ - [[nodiscard]] auto coro_id() const noexcept { \ - return this->promise("coro_id"); \ - } \ - }; \ - } \ - } - -#define LUISA_DERIVE_COROFRAME_SOA(S) \ - namespace luisa::compute { \ - template<> \ - class SOA; \ - template<> \ - class SOAView; \ - \ - template<> \ - struct Expr> { \ - private: \ - Expr _buffer; \ - Expr _soa_offset; \ - Expr _soa_size; \ - Expr _element_offset; \ - luisa::vector> _member_offsets; \ - public: \ - Expr(Expr buffer, \ - Expr soa_offset, \ - Expr soa_size, \ - Expr element_offset) noexcept \ - : _buffer{buffer}, \ - _soa_offset{soa_offset}, \ - _soa_size{soa_size}, \ - _element_offset{element_offset} { \ - auto type = Type::of() -> corotype(); \ - auto tot_size = dsl::def(soa_offset); \ - for (auto &mem : type->members()) { \ - _member_offsets.push_back(dsl::def(tot_size)); \ - uint stride = ((mem->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto member_size = soa_size * stride; \ - member_size = align_to_soa_cache_line(member_size); \ - tot_size += member_size; \ - } \ - } \ - [[nodiscard]] auto buffer() const noexcept { return _buffer; } \ - [[nodiscard]] auto soa_offset() const noexcept { return _soa_offset; } \ - [[nodiscard]] auto soa_size() const noexcept { return _soa_size; } \ - [[nodiscard]] auto element_offset() const noexcept { return _element_offset; } \ - private: \ - using this_type = S; \ - \ - public: \ - \ - Expr(SOAView soa) noexcept; \ - \ - Expr(const SOA &soa) noexcept; \ - \ - template \ - [[nodiscard]] auto read_coro_id(I &&index) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - constexpr auto i = 0u; \ - auto type = Type::of(); \ - uint stride = ((type->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto id = dsl::def(std::forward(index)); \ - auto data = builder->call( \ - type, CallOp::BYTE_BUFFER_READ, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint))}); \ - auto v = builder->local(type); \ - builder->assign(v, data); \ - return Var{v}; \ - } \ - \ - template \ - [[nodiscard]] auto read_coro_token(I &&index) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - constexpr auto i = 1u; \ - auto type = Type::of(); \ - uint stride = ((type->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto id = dsl::def(std::forward(index)); \ - auto data = builder->call( \ - type, CallOp::BYTE_BUFFER_READ, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint))}); \ - auto v = builder->local(type); \ - builder->assign(v, data); \ - return Var{v}; \ - } \ - \ - template \ - [[nodiscard]] auto read(I &&index) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - auto type = Type::of() -> corotype(); \ - auto ret = builder->local(Type::of()); \ - auto member_index = 0u; \ - auto id = dsl::def(std::forward(index)); \ - for (auto i = 0u; i < type->members().size(); ++i) { \ - auto &mem = type->members()[i]; \ - uint stride = ((mem->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto data = builder->call( \ - mem, CallOp::BYTE_BUFFER_READ, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint))}); \ - builder->assign(builder->member(mem, ret, i), data); \ - member_index += mem->size(); \ - } \ - return Var{ret}; \ - } \ - template \ - [[nodiscard]] auto read(I &&index, luisa::vector members) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - auto type = Type::of() -> corotype(); \ - auto ret = builder->local(Type::of()); \ - auto member_index = 0u; \ - auto id = dsl::def(std::forward(index)); \ - for (auto r_mem = 0u; r_mem < members.size(); ++r_mem) { \ - auto i = members[r_mem]; \ - auto &mem = type->members()[i]; \ - uint stride = ((mem->size() + sizeof(uint) - 1u) / sizeof(uint)); \ - auto data = builder->call( \ - mem, CallOp::BYTE_BUFFER_READ, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint))}); \ - builder->assign(builder->member(mem, ret, i), data); \ - member_index += mem->size(); \ - } \ - return Var{ret}; \ - } \ - template \ - void write(I &&index, Expr value) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - auto type = Type::of() -> corotype(); \ - auto member_index = 0u; \ - auto id = dsl::def(std::forward(index)); \ - for (auto i = 0u; i < type->members().size(); ++i) { \ - auto &mem = type->members()[i]; \ - auto stride = (((uint)mem->size() + (uint)sizeof(uint) - 1u) / (uint)sizeof(uint)); \ - builder->call(CallOp::BYTE_BUFFER_WRITE, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint)), \ - builder->member(mem, value.expression(), i)}); \ - member_index += mem->size(); \ - } \ - } \ - template \ - void write(I &&index, Expr value, luisa::vector members) const noexcept { \ - auto builder = detail::FunctionBuilder::current(); \ - auto type = Type::of() -> corotype(); \ - auto member_index = 0u; \ - auto id = dsl::def(std::forward(index)); \ - for (auto w_mem = 0u; w_mem < members.size(); ++w_mem) { \ - auto i = members[w_mem]; \ - auto &mem = type->members()[i]; \ - auto stride = (((uint)mem->size() + (uint)sizeof(uint) - 1u) / (uint)sizeof(uint)); \ - builder->call(CallOp::BYTE_BUFFER_WRITE, \ - {_buffer.expression(), \ - detail::extract_expression((_member_offsets[i] + \ - (id + _element_offset) * stride) * \ - (uint)sizeof(uint)), \ - builder->member(mem, value.expression(), i)}); \ - member_index += mem->size(); \ - } \ - } \ - [[nodiscard]] auto operator->() const noexcept { return this; } \ - }; \ - template<> \ - class SOAView { \ - protected: \ - ByteBuffer *_bufferview{nullptr}; \ - private: \ - using View = SOAView; \ - \ - private: \ - uint _soa_offset{}; \ - uint _soa_size{}; \ - uint _elem_offset{}; \ - uint _elem_size{}; \ - public: \ - static constexpr auto element_stride = 0; \ - SOAView() noexcept = default; \ - SOAView(ByteBuffer *buffer, \ - size_t soa_offset, \ - size_t soa_size, \ - size_t elem_offset, \ - size_t elem_size) noexcept \ - : _bufferview{buffer}, \ - _soa_offset{static_cast(soa_offset)}, \ - _soa_size{static_cast(soa_size)}, \ - _elem_offset{static_cast(elem_offset)}, \ - _elem_size{static_cast(elem_size)} {} \ - [[nodiscard]] auto &buffer() const noexcept { return *_bufferview; } \ - [[nodiscard]] auto bufferview() const noexcept { return _bufferview; } \ - [[nodiscard]] auto soa_offset() const noexcept { return _soa_offset; } \ - [[nodiscard]] auto soa_size() const noexcept { return _soa_size; } \ - [[nodiscard]] auto element_offset() const noexcept { return _elem_offset; } \ - [[nodiscard]] auto element_size() const noexcept { return _elem_size; } \ - [[nodiscard]] auto operator->() const noexcept { return Expr{*this}; } \ - [[nodiscard]] auto subview(size_t offset, size_t size) const noexcept { \ - if (!(offset + size <= this->element_size())) [[unlikely]] { \ - LUISA_ERROR_WITH_LOCATION("SOAView::subview out of range."); \ - } \ - return View{this->bufferview(), \ - this->soa_offset(), \ - this->soa_size(), \ - this->element_offset() + offset, \ - size}; \ - } \ - \ - public: \ - [[nodiscard]] static auto compute_soa_size(auto soa_size) noexcept { \ - auto type = Type::of() -> corotype(); \ - size_t tot_size = 0u; \ - for (auto mem : type->members()) { \ - tot_size += align_to_soa_cache_line(soa_size * \ - ((mem->size() + (uint)sizeof(uint) - 1u) / (uint)sizeof(uint))); \ - } \ - return tot_size; \ - } \ - }; \ - template<> \ - class SOA : public SOAView { \ - \ - private: \ - ByteBuffer _buffer; \ - \ - private: \ - SOA(ByteBuffer buffer, size_t size) noexcept \ - : _buffer{std::move(buffer)}, \ - SOAView{&buffer, 0u, size, 0u, size} { \ - this->_bufferview = &_buffer; \ - } \ - \ - public: \ - SOA(SOA &&x) noexcept { \ - *this = std::move(x); \ - this->_bufferview = &_buffer; \ - } \ - SOA &operator=(SOA &&x) noexcept { \ - *this = std::move(x); \ - this->_bufferview = &_buffer; \ - return *this; \ - } \ - SOA() noexcept = default; \ - SOA(Device &device, size_t elem_count) noexcept \ - : SOA{device.create_byte_buffer(SOAView::compute_soa_size(elem_count) * sizeof(uint)), \ - elem_count} {} \ - [[nodiscard]] auto view() const noexcept { return SOAView{*this}; } \ - }; \ - Expr>::Expr(SOAView soa) noexcept \ - : Expr{*soa.bufferview(), soa.soa_offset(), soa.soa_size(), soa.element_offset()} {} \ - \ - Expr>::Expr(const SOA &soa) noexcept \ - : Expr{soa.view()} {} \ - \ - template<> \ - struct Var> : public Expr> { \ - \ - private: \ - using Base = Expr>; \ - \ - Var(Expr buffer, \ - Expr soa_offset, \ - Expr soa_size, \ - Expr elem_offset) noexcept \ - : Base{buffer, soa_offset, soa_size, elem_offset} {} \ - \ - Var(Expr buffer, \ - Expr soa_offset, \ - Expr soa_size) noexcept \ - : Var{buffer, soa_offset, soa_size, \ - Var{detail::ArgumentCreation{}}} {} \ - \ - Var(Expr buffer, \ - Expr soa_offset) noexcept \ - : Var{buffer, soa_offset, \ - Var{detail::ArgumentCreation{}}} {} \ - \ - Var(Expr buffer) noexcept \ - : Var{buffer, \ - Var{detail::ArgumentCreation{}}} {} \ - \ - public: \ - Var(detail::ArgumentCreation) noexcept \ - : Var{Var{detail::ArgumentCreation{}}} {} \ - }; \ - }// namespace luisa::compute -#define LUISA_COROFRAME_STRUCT(S) \ - LUISA_COROFRAME_STRUCT_REFLECT(S, #S); \ - LUISA_COROFRAME_STRUCT_EXT(S) \ - LUISA_DERIVE_COROFRAME_SOA(S) \ - template<> \ - struct luisa_compute_extension final : luisa::compute::detail::Ref \ No newline at end of file diff --git a/include/luisa/dsl/sugar.h b/include/luisa/dsl/sugar.h index 45bce5d26..c0efbf1f2 100644 --- a/include/luisa/dsl/sugar.h +++ b/include/luisa/dsl/sugar.h @@ -117,9 +117,19 @@ namespace luisa::compute::dsl_detail { ::luisa::compute::detail::comment( \ ::luisa::compute::dsl_detail::format_source_location( \ __FILE__, __LINE__)); \ - return ::luisa::compute::suspend(__VA_ARGS__); \ + return ::luisa::compute::dsl::suspend(__VA_ARGS__); \ }()) +#define $promise(...) ::luisa::compute::dsl::promise(__VA_ARGS__) + +#define $yield(...) \ + do { \ + ::luisa::compute::dsl::promise("__yielded_value", __VA_ARGS__); \ + ::luisa::compute::dsl::suspend(); \ + } while (false) + +#define $await ::luisa::compute::coroutine::detail::CoroAwaitInvoker{} % + #define $loop \ ::luisa::compute::detail::LoopStmtBuilder::create_with_comment( \ ::luisa::compute::dsl_detail::format_source_location(__FILE__, __LINE__)) % \ diff --git a/include/luisa/luisa-compute.h b/include/luisa/luisa-compute.h index 975781c26..1eb0a4745 100644 --- a/include/luisa/luisa-compute.h +++ b/include/luisa/luisa-compute.h @@ -46,8 +46,23 @@ #include #include #include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + #ifdef LUISA_ENABLE_DSL #include #include @@ -108,6 +123,7 @@ #include #include #include +#include #include #include #include diff --git a/include/luisa/runtime/byte_buffer.h b/include/luisa/runtime/byte_buffer.h index 146f70009..68f34c61b 100644 --- a/include/luisa/runtime/byte_buffer.h +++ b/include/luisa/runtime/byte_buffer.h @@ -5,9 +5,20 @@ namespace luisa::compute { namespace detail { +LC_RUNTIME_API void error_buffer_size_not_aligned(size_t align) noexcept; +template class ByteBufferExprProxy; }// namespace detail +template +class SOA; + +namespace coroutine { +class CoroFrame; +}// namespace coroutine + +class ByteBufferView; + class LC_RUNTIME_API ByteBuffer final : public Resource { private: @@ -16,6 +27,7 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { private: friend class Device; friend class ResourceGenerator; + friend class SOA; ByteBuffer(DeviceInterface *device, const BufferCreationInfo &info) noexcept; ByteBuffer(DeviceInterface *device, size_t size_bytes) noexcept; @@ -30,6 +42,7 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { return *this; } ByteBuffer &operator=(ByteBuffer const &) noexcept = delete; + [[nodiscard]] ByteBufferView view() const noexcept; using Resource::operator bool; [[nodiscard]] auto copy_to(void *data) const noexcept { _check_is_valid(); @@ -70,14 +83,71 @@ class LC_RUNTIME_API ByteBuffer final : public Resource { // DSL interface [[nodiscard]] auto operator->() const noexcept { _check_is_valid(); - return reinterpret_cast(this); + return reinterpret_cast *>(this); + } +}; + +class ByteBufferView { + +private: + void *_native_handle; + uint64_t _handle; + size_t _offset_bytes; + size_t _size; + size_t _total_size; + +private: + friend class ByteBuffer; + +public: + ByteBufferView(void *native_handle, uint64_t handle, + size_t offset_bytes, + size_t size, size_t total_size) noexcept + : _native_handle{native_handle}, _handle{handle}, _offset_bytes{offset_bytes}, + _size{size}, _total_size{total_size} {} + + ByteBufferView(const ByteBuffer &buffer) noexcept : ByteBufferView{buffer.view()} {} + ByteBufferView(const ByteBufferView &) noexcept = default; + ByteBufferView(ByteBufferView &&) noexcept = default; + ByteBufferView() noexcept : ByteBufferView{nullptr, invalid_resource_handle, 0u, 0u, 0u} {} + [[nodiscard]] explicit operator bool() const noexcept { return _handle != invalid_resource_handle; } + + [[nodiscard]] auto handle() const noexcept { return _handle; } + [[nodiscard]] auto native_handle() const noexcept { return _native_handle; } + [[nodiscard]] auto offset() const noexcept { return _offset_bytes; } + [[nodiscard]] auto size_bytes() const noexcept { return _size; } + [[nodiscard]] auto total_size() const noexcept { return _total_size; } + [[nodiscard]] auto original() const noexcept { + return ByteBufferView{_native_handle, _handle, 0u, _total_size, _total_size}; + } + [[nodiscard]] auto subview(size_t offset_elements, size_t size_elements) const noexcept { + if (size_elements + offset_elements > _size) [[unlikely]] { + detail::error_buffer_subview_overflow(offset_elements, size_elements, _size); + } + return ByteBufferView{_native_handle, _handle, _offset_bytes + offset_elements, size_elements, _total_size}; + } + // reinterpret cast buffer to another type U + template + requires(!is_custom_struct_v) + [[nodiscard]] auto as() const noexcept { + if (this->size_bytes() < sizeof(U)) [[unlikely]] { + detail::error_buffer_reinterpret_size_too_small(sizeof(U), this->size_bytes()); + } + auto total_size_bytes = _total_size; + return BufferView{_native_handle, _handle, sizeof(U), _offset_bytes, + this->size_bytes() / sizeof(U), total_size_bytes / sizeof(U)}; + } + // DSL interface + [[nodiscard]] auto operator->() const noexcept { + return reinterpret_cast *>(this); } }; namespace detail { -LC_RUNTIME_API void error_buffer_size_not_aligned(size_t align) noexcept; + template<> struct is_buffer_impl : std::true_type {}; + }// namespace detail }// namespace luisa::compute diff --git a/include/luisa/runtime/device.h b/include/luisa/runtime/device.h index 4c966f5d9..1d24f1861 100644 --- a/include/luisa/runtime/device.h +++ b/include/luisa/runtime/device.h @@ -7,6 +7,8 @@ #include #include +#include + namespace luisa { class BinaryIO; }// namespace luisa @@ -29,6 +31,11 @@ class SparseBufferHeap; class SparseTextureHeap; class ByteBuffer; +namespace coroutine { +class CoroFrameDesc; +class CoroFrame; +}// namespace coroutine + template class SOA; @@ -222,11 +229,25 @@ class LC_RUNTIME_API Device { [[nodiscard]] ByteBuffer import_external_byte_buffer(void *external_memory, size_t byte_size) noexcept; template - requires(!is_custom_struct_v)//backend-specific type not allowed + requires(!is_custom_struct_v) /* backend-specific types not allowed */ && + (!std::same_as, coroutine::CoroFrame>) [[nodiscard]] auto create_buffer(size_t size) noexcept { return _create>(size); } + template + requires std::same_as, luisa::shared_ptr> || + std::same_as, const coroutine::CoroFrameDesc *> + [[nodiscard]] auto create_coro_frame_buffer(Desc &&desc, size_t size) noexcept { + return _create>(std::forward(desc), size); + } + + template + requires std::same_as, coroutine::CoroFrame> + [[nodiscard]] auto create_buffer(Desc &&desc, size_t size) noexcept { + return create_coro_frame_buffer(std::forward(desc), size); + } + template requires(!is_custom_struct_v)//backend-specific type not allowed [[nodiscard]] auto import_external_buffer(void *external_memory, size_t elem_count) noexcept { @@ -239,6 +260,14 @@ class LC_RUNTIME_API Device { return SOA{*this, size}; } + template + requires(std::same_as, luisa::shared_ptr> || + std::same_as, const coroutine::CoroFrameDesc *>) && + std::same_as + [[nodiscard]] auto create_soa(Desc &&desc, size_t size) noexcept { + return SOA{impl(), std::forward(desc), size}; + } + template requires(!is_custom_struct_v)//backend-specific type not allowed [[nodiscard]] auto create_sparse_buffer(size_t size) noexcept { diff --git a/include/luisa/runtime/shader.h b/include/luisa/runtime/shader.h index 51cb0490a..1d3537c08 100644 --- a/include/luisa/runtime/shader.h +++ b/include/luisa/runtime/shader.h @@ -322,15 +322,6 @@ class Shader final : public ShaderBase { return invoke; } - template - [[nodiscard]] auto partial_invoke(Pre... args) const noexcept { - _check_is_valid(); - using invoke_type = detail::ShaderInvoke; - auto arg_count = (0u + ... + detail::shader_argument_encode_count::value); - invoke_type invoke{handle(), arg_count, _uniform_size}; - static_cast((invoke << ... << args)); - return invoke; - } [[nodiscard]] uint3 block_size() const noexcept { _check_is_valid(); return make_uint3(_block_size[0], _block_size[1], _block_size[2]); diff --git a/include/luisa/rust/ir.hpp b/include/luisa/rust/ir.hpp index 82fb07636..557654e7d 100644 --- a/include/luisa/rust/ir.hpp +++ b/include/luisa/rust/ir.hpp @@ -314,6 +314,7 @@ struct CallableModule { CBoxedSlice captures; CBoxedSlice subroutines; CBoxedSlice subroutine_ids; + CBoxedSlice coro_target_tokens; CBoxedSlice coro_frame_input_fields; CBoxedSlice coro_frame_output_fields; CBoxedSlice coro_frame_designated_fields; diff --git a/scripts/print_hlsl_builtin.lua b/scripts/print_hlsl_builtin.lua index d6f6b0932..6688ef65e 100644 --- a/scripts/print_hlsl_builtin.lua +++ b/scripts/print_hlsl_builtin.lua @@ -5,7 +5,7 @@ local files_list = {'accel_process', 'bindless_upload', 'bc6_encode_block', 'bc6 'bc6_trymode_le10cs', 'bc7_encode_block', 'bc7_header', 'bc7_trymode_02cs', 'bc7_trymode_137cs', 'bc7_trymode_456cs', 'hlsl_header', 'raytracing_header', 'tex2d_bindless', 'tex3d_bindless', 'compute_quad', 'determinant', 'inverse', 'indirect', 'resource_size', 'accel_header', 'copy_sign', - 'bindless_common', 'auto_diff', "reduce", "coroutine"} + 'bindless_common', 'auto_diff', "reduce"} local lib = import("lib") local hlsl_builtin_path = path.join(os.projectdir(), "src/backends/common/hlsl/builtin") diff --git a/src/ast/ast2json.cpp b/src/ast/ast2json.cpp index cb3e34d17..dbb0a3a85 100644 --- a/src/ast/ast2json.cpp +++ b/src/ast/ast2json.cpp @@ -588,9 +588,6 @@ class AST2JSON { private: [[nodiscard]] uint _type_index(const Type *type) noexcept { - if (type != nullptr && type->is_materialized_coroframe()) { - return _type_index(type->corotype()); - } if (auto iter = _type_to_index.find(type); iter != _type_to_index.end()) { return iter->second; @@ -632,7 +629,6 @@ class AST2JSON { t["element"] = _type_index(type->element()); break; } - case Type::Tag::COROFRAME: [[fallthrough]]; case Type::Tag::CUSTOM: { t["id"] = type->description(); break; diff --git a/src/ast/function_builder.cpp b/src/ast/function_builder.cpp index a42371253..085a5079c 100644 --- a/src/ast/function_builder.cpp +++ b/src/ast/function_builder.cpp @@ -133,24 +133,6 @@ void FunctionBuilder::bind_promise_(const Expression *expr, luisa::string name) _create_and_append_statement(expr, std::move(name)); } -const MemberExpr *FunctionBuilder::read_promise_(const Type *type, const Expression *expr, luisa::string_view name) noexcept { - LUISA_ASSERT(expr->type()->is_coroframe(), "Promise reading is only allowed for CoroFrame type"); - auto var = expr->type()->member(name); - if (var != -1) { - auto mem_type = expr->type()->corotype()->members()[var]; - LUISA_ASSERT(*mem_type == *type, - "Promise '{}' type mismatch: expected {}, got {}.", - name, type->description(), mem_type->description()); - return _create_expression(expr->type()->corotype()->members()[var], expr, var); - } - return nullptr; -} -void FunctionBuilder::initialize_coroframe(const luisa::compute::Expression *expr, const luisa::compute::Expression *coro_id) noexcept { - auto member_coro_id = member(Type::of(), expr, 0u); - auto member_coro_token = member(Type::of(), expr, 1u); - assign(member_coro_id, coro_id); - assign(member_coro_token, literal(Type::of(), 0u)); -} const CallExpr *FunctionBuilder::coro_id() noexcept { check_is_coroutine(); return call(Type::of(), CallOp::CORO_ID, {}); @@ -898,11 +880,6 @@ bool FunctionBuilder::requires_autodiff() const noexcept { return _propagated_builtin_callables.uses_autodiff(); } -void FunctionBuilder::coroframe_replace(const Type *type) noexcept { - LUISA_ASSERT(_arguments.size() > 0, "Lack of parameter for coroutine generated callables!"); - _arguments[0]._type = type; -} - bool FunctionBuilder::requires_printing() const noexcept { return _requires_printing; } diff --git a/src/ast/type.cpp b/src/ast/type.cpp index ceed43f72..222895081 100644 --- a/src/ast/type.cpp +++ b/src/ast/type.cpp @@ -39,7 +39,6 @@ struct TypeImpl final : public Type { uint index{}; luisa::string description; luisa::vector members; - luisa::unordered_map member_names; luisa::vector member_attributes; }; @@ -107,7 +106,6 @@ class LC_AST_API TypeRegistry { [[nodiscard]] const Type *decode_type(luisa::string_view desc) noexcept; /// Construct custom type [[nodiscard]] const Type *custom_type(luisa::string_view desc) noexcept; - [[nodiscard]] const Type *coroframe_type(luisa::string_view desc) noexcept; /// Return type count [[nodiscard]] size_t type_count() const noexcept; /// Traverse all types using visitor @@ -165,85 +163,6 @@ const Type *TypeRegistry::custom_type(luisa::string_view name) noexcept { return _register(t); } -const Type *TypeRegistry::coroframe_type(luisa::string_view name) noexcept { - // validate name - LUISA_ASSERT(!name.empty() && - name != "void" && - name != "int" && - name != "uint" && - name != "short" && - name != "ushort" && - name != "long" && - name != "ulong" && - name != "float" && - name != "half" && - name != "double" && - name != "bool" && - !name.starts_with("vector<") && - !name.starts_with("matrix<") && - !name.starts_with("array<") && - !name.starts_with("struct<") && - !name.starts_with("buffer<") && - !name.starts_with("texture<") && - name != "accel" && - name != "bindless_array" && - !isdigit(name.front() /* already checked not empty */), - "Invalid custom type name: {}", name); - LUISA_ASSERT(std::all_of(name.cbegin(), name.cend(), - [](char c) { return isalnum(c) || c == '_'; }), - "Invalid custom type name: {}", name); - std::lock_guard lock{_mutex}; - auto h = _compute_hash(name); - if (auto iter = _type_set.find(TypeDescAndHash{name, h}); - iter != _type_set.end()) { return *iter; } - - auto t = _type_pool.create(); - t->hash = h; - t->tag = Type::Tag::COROFRAME; - t->size = Type::custom_struct_size; - t->alignment = Type::custom_struct_alignment; - t->dimension = 1u; - t->description = name; - return _register(t); -} - -void _update(Type *dst, const Type *src) { - std::scoped_lock lock{TypeRegistry::instance().mutex()}; - auto dst_inst = static_cast(dst); - auto src_inst = static_cast(src); - dst_inst->alignment = src_inst->alignment; - dst_inst->size = src_inst->size; - LUISA_ASSERT(dst_inst->members.empty(), "{} used as coroframe type " - "for second time!", - dst_inst->description); - dst_inst->members.push_back(src); - //dst_inst->description = src_inst->description; -} - -size_t _add_member(Type *type, const luisa::string &name) { - std::scoped_lock lock{TypeRegistry::instance().mutex()}; - auto inst = static_cast(type); - size_t id = inst->member_names.size(); - auto ret = inst->member_names.insert(std::make_pair(name, id)); - return ret.second ? id : -1; -} - -void _set_member_name(Type *type, size_t index, luisa::string name) { - std::scoped_lock lock{TypeRegistry::instance().mutex()}; - auto inst = static_cast(type); - auto old_iter = inst->member_names.find(name); - if (old_iter != inst->member_names.end() && - old_iter->second != index) [[unlikely]] { - LUISA_ERROR_WITH_LOCATION( - "Duplicate member name at index {}: {}.", - index, name); - } - LUISA_ASSERT(index < inst->members.size() || index < inst->corotype()->members().size(), - "Invalid member index: {} (total = {}).", - index, std::max(inst->members.size(), inst->corotype()->members().size())); - inst->member_names.insert_or_assign(std::move(name), index); -} - size_t TypeRegistry::type_count() const noexcept { std::lock_guard lock{_mutex}; return _types.size(); @@ -533,7 +452,7 @@ const TypeImpl *TypeRegistry::_decode(luisa::string_view desc) noexcept { }// namespace detail luisa::span Type::members() const noexcept { - LUISA_ASSERT(is_structure() || (is_coroframe()), + LUISA_ASSERT(is_structure(), "Calling members() on a non-structure type {}.", description()); return static_cast(this)->members; @@ -554,16 +473,6 @@ const Type *Type::element() const noexcept { return static_cast(this)->members.front(); } -const Type *Type::corotype() const noexcept { - LUISA_ASSERT(is_coroframe(), - "Calling corotype() on a non-coroframe type {}.", - description()); - LUISA_ASSERT(!static_cast(this)->members.empty(), - "Calling corotype() on a coroframe before analyze.\n" - "Define Coroutine with this coroframe to specify the backend type!"); - return static_cast(this)->members.front(); -} - const Type *Type::from(std::string_view description) noexcept { return detail::TypeRegistry::instance().decode_type(description); } @@ -610,10 +519,8 @@ uint64_t Type::hash() const noexcept { } size_t Type::size() const noexcept { - LUISA_ASSERT(!(is_coroframe() && members().empty()), - "Cannot find size of {}. " - "Usages of CoroFrame types should be " - "after the coroutine definition!", + LUISA_ASSERT(!is_custom(), + "Unknown size for custom type {}.", description()); return static_cast(this)->size; } @@ -689,11 +596,6 @@ bool Type::is_texture() const noexcept { return tag() == Tag::TEXTURE; } bool Type::is_bindless_array() const noexcept { return tag() == Tag::BINDLESS_ARRAY; } bool Type::is_accel() const noexcept { return tag() == Tag::ACCEL; } bool Type::is_custom() const noexcept { return tag() == Tag::CUSTOM; } -bool Type::is_coroframe() const noexcept { return tag() == Tag::COROFRAME; } -bool Type::is_materialized_coroframe() const noexcept { - return tag() == Tag::COROFRAME && - !reinterpret_cast(this)->members.empty(); -} const Type *Type::array(const Type *elem, size_t n) noexcept { return from(luisa::format("array<{},{}>", elem->description(), n)); @@ -801,38 +703,6 @@ const Type *Type::custom(luisa::string_view name) noexcept { return detail::TypeRegistry::instance().custom_type(name); } -const Type *Type::coroframe(luisa::string_view name) noexcept { - return detail::TypeRegistry::instance().coroframe_type(name); -} - -void Type::update_from(const Type *type) { - detail::_update(this, type); -} - -size_t Type::add_member(const luisa::string &name) noexcept { - LUISA_ASSERT(name != "coro_id" && name != "coro_token", - "{} is a reserved name for coroframe type.", name); - return detail::_add_member(this, name); -} - -void Type::set_member_name(size_t index, luisa::string name) noexcept { - LUISA_ASSERT(name != "coro_id" && name != "coro_token", - "{} is a reserved name for coroframe type.", name); - detail::_set_member_name(this, index, std::move(name)); -} - -size_t Type::member(luisa::string_view name) const noexcept { - if (name == "coro_id") return 0; - if (name == "coro_token") return 1u; - auto &map = static_cast(this)->member_names; - auto it = map.find(name); - if (it == map.end()) { - return -1; - } else { - return it->second; - } -} - bool Type::is_bool() const noexcept { return tag() == Tag::BOOL; } bool Type::is_int32() const noexcept { return tag() == Tag::INT32; } bool Type::is_uint32() const noexcept { return tag() == Tag::UINT32; } diff --git a/src/backends/common/hlsl/builtin/coroutine b/src/backends/common/hlsl/builtin/coroutine deleted file mode 100644 index e9cc25f62..000000000 --- a/src/backends/common/hlsl/builtin/coroutine +++ /dev/null @@ -1,9 +0,0 @@ -template -uint3 CoroId(in T frame){ -return frame.v0.v; -} - -template -uint CoroToken(in T frame){ -return frame.v1; -} \ No newline at end of file diff --git a/src/backends/common/hlsl/builtin/coroutine.c b/src/backends/common/hlsl/builtin/coroutine.c deleted file mode 100644 index 6f4494609..000000000 --- a/src/backends/common/hlsl/builtin/coroutine.c +++ /dev/null @@ -1,4 +0,0 @@ -#include "hlsl_config.h" -LC_HLSL_EXTERN char coroutine[]={116,101,109,112,108,97,116,101,60,116,121,112,101,110,97,109,101,32,84,62,10,117,105,110,116,51,32,67,111,114,111,73,100,40,105,110,32,84,32,102,114,97,109,101,41,123,10,114,101,116,117,114,110,32,102,114,97,109,101,46,118,48,46,118,59,10,125,10,10,116,101,109,112,108,97,116,101,60,116,121,112,101,110,97,109,101,32,84,62,10,117,105,110,116,32,67,111,114,111,84,111,107,101,110,40,105,110,32,84,32,102,114,97,109,101,41,123,10,114,101,116,117,114,110,32,102,114,97,109,101,46,118,49,59,10,125}; -LC_HLSL_EXTERN char *get_coroutine(){return coroutine;} -LC_HLSL_EXTERN int get_coroutine_size(){return 136;} diff --git a/src/backends/common/hlsl/hlsl_codegen_util.cpp b/src/backends/common/hlsl/hlsl_codegen_util.cpp index 624c06f1d..5cba86d17 100644 --- a/src/backends/common/hlsl/hlsl_codegen_util.cpp +++ b/src/backends/common/hlsl/hlsl_codegen_util.cpp @@ -117,10 +117,6 @@ static size_t AddHeader(CallOpSet const &ops, vstd::StringBuilder &builder, bool ops.test(CallOp::MATRIX_COMPONENT_WISE_MULTIPLICATION)) { builder << CodegenUtility::ReadInternalHLSLFile("reduce"); } - if (ops.test(CallOp::CORO_ID) || - ops.test(CallOp::CORO_TOKEN)) { - builder << CodegenUtility::ReadInternalHLSLFile("coroutine"); - } return immutable_size; } }// namespace detail @@ -381,10 +377,6 @@ void CodegenUtility::GetTypeName(Type const &type, vstd::StringBuilder &str, Usa case Type::Tag::CUSTOM: { str << '_' << type.description(); } break; - case Type::Tag::COROFRAME: { - auto customType = opt->CreateStruct(type.corotype()); - str << customType; - } break; default: LUISA_ERROR_WITH_LOCATION("Bad."); break; @@ -2106,7 +2098,7 @@ uint4 dsp_c; LUISA_ERROR("Arguments binding size: {} exceeds 64 32-bit units not supported by hardware device. Try to use bindless instead.", bind_count); } else if (bind_count > 16) [[unlikely]] { if (!rootsig_exceed_warned.exchange(true)) { - LUISA_WARNING("Arguments binding size exceeds 16 32-bit unit (max 64 allowed). This may cause extra performance cost, try to use bindless instead."); + LUISA_WARNING("Arguments binding size: {} exceeds 16 32-bit unit (max 64 allowed). This may cause extra performance cost, try to use bindless instead.", bind_count); } } return { diff --git a/src/backends/common/shader_print_formatter.h b/src/backends/common/shader_print_formatter.h index 96b39508b..a849d0661 100644 --- a/src/backends/common/shader_print_formatter.h +++ b/src/backends/common/shader_print_formatter.h @@ -117,8 +117,7 @@ class ShaderPrintFormatter { } s.push_back('>'); commit_s(); - } else if (arg->is_structure() || arg->is_coroframe()) { - if (arg->is_coroframe()) { arg = arg->corotype(); } + } else if (arg->is_structure()) { s.push_back('{'); commit_s(); for (auto i = 0u; i < arg->members().size(); i++) { diff --git a/src/backends/cuda/cuda_codegen_ast.cpp b/src/backends/cuda/cuda_codegen_ast.cpp index bccfe92b9..0cfd52b7b 100644 --- a/src/backends/cuda/cuda_codegen_ast.cpp +++ b/src/backends/cuda/cuda_codegen_ast.cpp @@ -1528,8 +1528,6 @@ static void collect_types_in_function(Function f, for (auto m : t->members()) { self(self, m); } - } else if (t->is_coroframe()) { - self(self, t->corotype()); } } }; @@ -1763,10 +1761,6 @@ void CUDACodegenAST::_emit_type_name(const Type *type, bool hack_float_to_int) n } break; } - case Type::Tag::COROFRAME: { - _emit_type_name(type->corotype()); - break; - } default: break; } } diff --git a/src/backends/metal/metal_codegen_ast.cpp b/src/backends/metal/metal_codegen_ast.cpp index 6fc31145c..f3336369e 100644 --- a/src/backends/metal/metal_codegen_ast.cpp +++ b/src/backends/metal/metal_codegen_ast.cpp @@ -137,8 +137,6 @@ static void collect_types_in_function(Function f, for (auto m : t->members()) { self(self, m); } - } else if (t->is_coroframe()) { - self(self, t->corotype()); } } }; @@ -347,9 +345,6 @@ void MetalCodegenAST::_emit_type_name(const Type *type, Usage usage) noexcept { } break; } - case Type::Tag::COROFRAME: - _emit_type_name(type->corotype()); - break; } } diff --git a/src/coro/CMakeLists.txt b/src/coro/CMakeLists.txt index 5c6645170..5743d6d23 100644 --- a/src/coro/CMakeLists.txt +++ b/src/coro/CMakeLists.txt @@ -1,11 +1,20 @@ set(LUISA_COMPUTE_CORO_SOURCES - shader_scheduler.cpp - coro_dispatcher.cpp + coro_frame.cpp + coro_frame_buffer.cpp + coro_frame_desc.cpp + coro_frame_smem.cpp + coro_func.cpp coro_graph.cpp - coro_node.cpp) + schedulers/persistent_threads.cpp + schedulers/state_machine.cpp + schedulers/wavefront.cpp) add_library(luisa-compute-coro SHARED ${LUISA_COMPUTE_CORO_SOURCES}) -target_link_libraries(luisa-compute-coro PUBLIC luisa-compute-ast luisa-compute-ir) +target_link_libraries(luisa-compute-coro PUBLIC + luisa-compute-ast + luisa-compute-ir + luisa-compute-dsl + luisa-compute-runtime) target_precompile_headers(luisa-compute-coro PRIVATE pch.h) target_compile_definitions(luisa-compute-coro PRIVATE LC_CORO_EXPORT_DLL=1) diff --git a/src/coro/coro_dispatcher.cpp b/src/coro/coro_dispatcher.cpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/coro/coro_frame.cpp b/src/coro/coro_frame.cpp new file mode 100644 index 000000000..d23aeca5a --- /dev/null +++ b/src/coro/coro_frame.cpp @@ -0,0 +1,71 @@ +// +// Created by mike on 5/7/24. +// + +#include +#include +#include +#include + +namespace luisa::compute::coroutine { + +CoroFrame::CoroFrame(luisa::shared_ptr desc, const Expression *expr) noexcept + : _desc{std::move(desc)}, + _expression{expr}, + coro_id{this->get(0u)}, + target_token{this->get(1u)} { + LUISA_ASSERT(expr != nullptr, "CoroFrame expression must not be null."); + LUISA_ASSERT(expr->type() == _desc->type(), "CoroFrame expression type mismatch."); +} + +CoroFrame::CoroFrame(CoroFrame &&another) noexcept + : CoroFrame{std::move(another._desc), another._expression} { + another._expression = nullptr; +} + +CoroFrame::CoroFrame(const CoroFrame &another) noexcept + : CoroFrame{another._desc, [e = another.expression()]() noexcept { + auto fb = detail::FunctionBuilder::current(); + auto copy = fb->local(e->type()); + fb->assign(copy, e); + return copy; + }()} {} + +CoroFrame &CoroFrame::operator=(const CoroFrame &rhs) noexcept { + if (this == std::addressof(rhs)) { return *this; } + LUISA_ASSERT(this->description() == rhs.description(), "CoroFrame description mismatch."); + auto fb = detail::FunctionBuilder::current(); + fb->assign(_expression, rhs.expression()); + return *this; +} + +CoroFrame &CoroFrame::operator=(CoroFrame &&rhs) noexcept { + return *this = static_cast(rhs); +} + +CoroFrame CoroFrame::create(luisa::shared_ptr desc) noexcept { + auto fb = detail::FunctionBuilder::current(); + // create an variable for the coro frame + auto expr = fb->local(desc->type()); + // initialize the coro frame members + auto zero_init = fb->call(desc->type(), CallOp::ZERO, {}); + fb->assign(expr, zero_init); + return CoroFrame{std::move(desc), expr}; +} + +CoroFrame CoroFrame::create(luisa::shared_ptr desc, Expr coro_id) noexcept { + auto frame = create(std::move(desc)); + frame.coro_id = coro_id; + return frame; +} + +void CoroFrame::_check_member_index(uint index) const noexcept { + LUISA_ASSERT(index < _desc->type()->members().size(), + "CoroFrame member index out of range."); +} + +Var CoroFrame::is_terminated() const noexcept { + return (target_token & coro_token_terminal) != 0u; +} + +}// namespace luisa::compute::coroutine diff --git a/src/coro/coro_frame_buffer.cpp b/src/coro/coro_frame_buffer.cpp new file mode 100644 index 000000000..9dc4d0d7f --- /dev/null +++ b/src/coro/coro_frame_buffer.cpp @@ -0,0 +1,14 @@ +// +// Created by Mike on 2024/5/10. +// + +#include +#include + +namespace luisa::compute::detail { +void error_coro_frame_buffer_invalid_element_size(size_t stride, size_t expected) noexcept { + LUISA_ERROR( + "Invalid coroutine frame buffer view element size {} (expected {}).", + stride, expected); +} +}// namespace luisa::compute::detail diff --git a/src/coro/coro_frame_desc.cpp b/src/coro/coro_frame_desc.cpp new file mode 100644 index 000000000..97fa9fcc4 --- /dev/null +++ b/src/coro/coro_frame_desc.cpp @@ -0,0 +1,59 @@ +// +// Created by Mike on 2024/5/10. +// + +#include +#include + +namespace luisa::compute::coroutine { + +CoroFrameDesc::CoroFrameDesc(const Type *type, DesignatedFieldDict m) noexcept + : _type{type}, _designated_fields{std::move(m)} { + LUISA_ASSERT(_type != nullptr, "CoroFrame underlying type must not be null."); + LUISA_ASSERT(_type->is_structure(), "CoroFrame underlying type must be a structure."); + LUISA_ASSERT(_type->members().size() >= 2u, "CoroFrame underlying type must have at least 2 members (coro_id and target_token)."); + LUISA_ASSERT(_type->members()[0]->is_uint32_vector() && _type->members()[0]->dimension() == 3u, "CoroFrame member 0 (coro_id) must be uint3."); + LUISA_ASSERT(_type->members()[1]->is_uint32(), "CoroFrame member 1 (target_token) must be uint."); + auto member_count = _type->members().size(); + for (auto &&[name, index] : _designated_fields) { + LUISA_ASSERT(name != "coro_id", "CoroFrame designated member name 'coro_id' is reserved."); + LUISA_ASSERT(name != "target_token", "CoroFrame designated member name 'target_token' is reserved."); + LUISA_ASSERT(index != 0, "CoroFrame designated member index 0 is reserved for coro_id."); + LUISA_ASSERT(index != 1, "CoroFrame designated member index 1 is reserved for target_token."); + LUISA_ASSERT(index < member_count, "CoroFrame designated member index out of range."); + } +} + +luisa::shared_ptr CoroFrameDesc::create(const Type *type, DesignatedFieldDict m) noexcept { + return luisa::make_shared(CoroFrameDesc{type, std::move(m)}); +} + +const Type *CoroFrameDesc::field(uint index) const noexcept { + auto fields = _type->members(); + LUISA_ASSERT(index < fields.size(), "CoroFrame field index out of range."); + return fields[index]; +} + +uint CoroFrameDesc::designated_field(luisa::string_view name) const noexcept { + if (name == "coro_id") { return 0u; } + if (name == "target_token") { return 1u; } + auto iter = _designated_fields.find(name); + LUISA_ASSERT(iter != _designated_fields.end(), "CoroFrame designated member not found."); + return iter->second; +} + +luisa::string CoroFrameDesc::dump() const noexcept { + luisa::string s; + for (auto i = 0u; i < _type->members().size(); i++) { + s.append(luisa::format(" Field {}: {}\n", i, _type->members()[i]->description())); + } + if (!_designated_fields.empty()) { + s.append("Designated Fields:\n"); + for (auto &&[name, index] : _designated_fields) { + s.append(luisa::format(" {} -> \"{}\"\n", index, name)); + } + } + return s; +} + +}// namespace luisa::compute::coroutine diff --git a/src/coro/coro_frame_smem.cpp b/src/coro/coro_frame_smem.cpp new file mode 100644 index 000000000..3a390d0e3 --- /dev/null +++ b/src/coro/coro_frame_smem.cpp @@ -0,0 +1,7 @@ +#include + +namespace luisa::compute::detail { +void error_coro_frame_smem_subscript_on_soa() noexcept { + LUISA_ERROR_WITH_LOCATION("Subscripting on SOA shared frame is not allowed."); +} +}// namespace luisa::compute::detail diff --git a/src/coro/coro_func.cpp b/src/coro/coro_func.cpp new file mode 100644 index 000000000..fe05e74c0 --- /dev/null +++ b/src/coro/coro_func.cpp @@ -0,0 +1,37 @@ +// +// Created by Mike on 2024/5/8. +// + +#include +#include + +namespace luisa::compute::coroutine::detail { + +void coroutine_chained_await_impl(CoroFrame &frame, uint node_count, + luisa::move_only_function node) noexcept { + node(coro_token_entry, frame); + $while (!frame.is_terminated()) { + $suspend(); + $switch (frame.target_token) { + for (auto i = 1u; i < node_count; i++) { + $case (i) { + node(i, frame); + }; + } + $default { dsl::unreachable(); }; + }; + }; +} + +void coroutine_generator_next_impl( + CoroFrame &frame, uint node_count, + const luisa::move_only_function &resume) noexcept { + $switch (frame.target_token) { + for (auto i = 0u; i < node_count; i++) { + $case (i) { resume(frame, i); }; + } + $default { dsl::unreachable(); }; + }; +} + +}// namespace luisa::compute::coroutine::detail diff --git a/src/coro/coro_graph.cpp b/src/coro/coro_graph.cpp index e57371f17..72980541e 100644 --- a/src/coro/coro_graph.cpp +++ b/src/coro/coro_graph.cpp @@ -1,39 +1,250 @@ +// +// Created by Mike on 2024/5/8. +// + #include +#include +#include +#include #include -namespace luisa::compute::inline coro { +#ifdef LUISA_ENABLE_IR +#include +#endif + +namespace luisa::compute::coroutine { + +CoroGraph::Node::Node(luisa::vector input_fields, + luisa::vector output_fields, + luisa::vector targets, + CC current_continuation) noexcept + : _input_fields{std::move(input_fields)}, + _output_fields{std::move(output_fields)}, + _targets{std::move(targets)}, + _cc{std::move(current_continuation)} {} -const CoroNode *CoroGraph::entry() const noexcept { - return this->node(_entry); +CoroGraph::Node::~Node() noexcept = default; + +Function CoroGraph::Node::cc() const noexcept { return _cc->function(); } + +luisa::string CoroGraph::Node::dump() const noexcept { + luisa::string s; + s.append(" Input Fields: ["); + for (auto i : _input_fields) { + s.append(luisa::format("{}, ", i)); + } + if (!_input_fields.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + s.append(" Output Fields: ["); + for (auto i : _output_fields) { + s.append(luisa::format("{}, ", i)); + } + if (!_output_fields.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + s.append(" Transition Targets: ["); + for (auto i : _targets) { + s.append(luisa::format("{}, ", i)); + } + if (!_targets.empty()) { + s.pop_back(); + s.pop_back(); + } + s.append("]\n"); + return s; } -const CoroNode *CoroGraph::node(uint token) const noexcept { +CoroGraph::CoroGraph(luisa::shared_ptr frame_desc, + luisa::unordered_map nodes, + luisa::unordered_map named_tokens) noexcept + : _frame{std::move(frame_desc)}, + _nodes{std::move(nodes)}, + _named_tokens{std::move(named_tokens)} {} + +CoroGraph::~CoroGraph() noexcept = default; + +const CoroGraph::Node &CoroGraph::entry() const noexcept { + return node(coro_token_entry); +} + +const CoroGraph::Node &CoroGraph::node(CoroToken token) const noexcept { auto iter = _nodes.find(token); - LUISA_ASSERT(iter != _nodes.cend(), - "Coroutine node (token = {}) not found.", token); - return &(iter->second); + LUISA_ASSERT(iter != _nodes.end(), + "Coroutine node with token {} not found.", + token); + return iter->second; } -CoroNode *CoroGraph::add_node(uint token, CoroNode::Func f) noexcept { - auto node = CoroNode{this, std::move(f)}; - auto [iter, success] = _nodes.emplace(token, std::move(node)); - LUISA_ASSERT(success, "Coroutine node (token = {}) already exists.", token); - return &(iter->second); +const CoroGraph::Node &CoroGraph::node(luisa::string_view name) const noexcept { + auto iter = _named_tokens.find(name); + LUISA_ASSERT(iter != _named_tokens.end(), + "Coroutine node with name '{}' not found.", + name); + return node(iter->second); } -uint CoroGraph::designated_state_member(luisa::string_view name) const noexcept { - auto iter = _designated_state_members.find(name); - LUISA_ASSERT(iter != _designated_state_members.cend(), - "State member '{}' not designated.", name); - return iter->second; +luisa::string CoroGraph::dump() const noexcept { + luisa::string s; + s.append("Arguments:\n"); + auto args = entry().cc().arguments(); + for (auto i = 0u; i < args.size(); i++) { + s.append(luisa::format(" Argument {}: ", i)); + s.append(args[i].type()->description()); + if (args[i].is_reference()) { s.append(" &"); } + s.append("\n"); + } + s.append("Frame:\n").append(_frame->dump()); + for (auto &&[token, node] : _nodes) { + if (token == coro_token_entry) { + s.append("Entry:\n"); + } else { + s.append(luisa::format("Node {}:\n", token)); + } + s.append(node.dump()); + } + if (!_named_tokens.empty()) { + s.append("Named Tokens:\n"); + for (auto &&[name, token] : _named_tokens) { + s.append(luisa::format(" {} -> \"{}\"\n", token, name)); + } + } + return s; } -CoroGraph::CoroGraph(uint entry, const Type *state_type) noexcept - : _entry{entry}, _state_type{state_type} {} +#ifndef LUISA_ENABLE_IR +luisa::shared_ptr CoroGraph::create(Function coroutine) noexcept { + LUISA_ERROR_WITH_LOCATION( + "Coroutine requires IR support but " + "LuisaCompute is built without the IR module. " + "This might be caused by missing Rust. " + "Please install the Rust toolchain and " + "recompile LuisaCompute to get the IR module."); +} +#else + +namespace detail { + +static void perform_coroutine_transform(ir::CallableModule *m) noexcept { + auto coroutine_pipeline = ir::luisa_compute_ir_transform_pipeline_new(); + ir::luisa_compute_ir_transform_pipeline_add_transform(coroutine_pipeline, "materialize_coro"); + auto converted_module = ir::luisa_compute_ir_transform_pipeline_transform_callable(coroutine_pipeline, *m); + ir::luisa_compute_ir_transform_pipeline_destroy(coroutine_pipeline); + *m = converted_module; +} -void CoroGraph::designate_state_member(luisa::string name, uint index) noexcept { - auto [iter, success] = _designated_state_members.emplace(name, index); - LUISA_ASSERT(success, "State member '{}' already designated.", name); +[[nodiscard]] static auto make_subroutine_wrapper(Function coroutine, Function cc) noexcept { + using FB = luisa::compute::detail::FunctionBuilder; + return FB::define_callable([&] { + luisa::vector args; + args.reserve(1u /* frame */ + coroutine.arguments().size()); + LUISA_ASSERT(coroutine.arguments().size() == coroutine.bound_arguments().size(), + "Invalid capture list size (expected {}, got {}).", + coroutine.arguments().size(), coroutine.bound_arguments().size()); + auto fb = FB::current(); + args.emplace_back(fb->reference(cc.arguments().front().type())); + for (auto arg_i = 0u; arg_i < coroutine.arguments().size(); arg_i++) { + auto def_arg = coroutine.arguments()[arg_i]; + auto internal_arg = luisa::visit( + [&](T b) noexcept -> const Expression * { + if constexpr (std::is_same_v) { + return fb->buffer_binding(def_arg.type(), b.handle, b.offset, b.size); + } else if constexpr (std::is_same_v) { + return fb->texture_binding(def_arg.type(), b.handle, b.level); + } else if constexpr (std::is_same_v) { + return fb->bindless_array_binding(b.handle); + } else if constexpr (std::is_same_v) { + return fb->accel_binding(b.handle); + } else { + static_assert(std::is_same_v); + switch (def_arg.tag()) { + case Variable::Tag::REFERENCE: return fb->reference(def_arg.type()); + case Variable::Tag::BUFFER: return fb->buffer(def_arg.type()); + case Variable::Tag::TEXTURE: return fb->texture(def_arg.type()); + case Variable::Tag::BINDLESS_ARRAY: return fb->bindless_array(); + case Variable::Tag::ACCEL: return fb->accel(); + default: /* value argument */ return fb->argument(def_arg.type()); + } + } + }, + coroutine.bound_arguments()[arg_i]); + args.emplace_back(internal_arg); + } + LUISA_ASSERT(cc.return_type() == nullptr, + "Coroutine subroutines should not have return type."); + fb->call(cc, args); + }); } -}// namespace luisa::compute::inline coro +}// namespace detail + +luisa::shared_ptr CoroGraph::create(Function coroutine) noexcept { + LUISA_VERBOSE_WITH_LOCATION("Performing Coroutine transform " + "on function with hash {:016x}.", + coroutine.hash()); + + // convert the coroutine function to IR, transform it, and then convert back + auto m = AST2IR::build_coroutine(coroutine); + detail::perform_coroutine_transform(m->get()); + auto entry = IR2AST::build(m->get()); + + // create the coroutine frame descriptor + auto frame = [m, entry] { + auto underlying = entry->arguments().front().type(); + CoroFrameDesc::DesignatedFieldDict members; + for (auto &&field : luisa::span{m->get()->coro_frame_designated_fields.ptr, + m->get()->coro_frame_designated_fields.len}) { + auto name = luisa::string_view{reinterpret_cast(field.name.ptr), field.name.len}; + if (!name.empty() && name.back() == '\0') { name = name.substr(0, name.size() - 1); } + auto [_, success] = members.try_emplace(name, field.index); + LUISA_ASSERT(success, "Duplicated designated field name '{}' at field {}.", name, field.index); + } + return CoroFrameDesc::create(underlying, std::move(members)); + }(); + + // extract the subroutines + auto subroutines = m->get()->subroutines; + auto subroutine_ids = m->get()->subroutine_ids; + LUISA_ASSERT(subroutines.len == subroutine_ids.len, + "Subroutine count mismatch: {} vs {}.", + subroutines.len, subroutine_ids.len); + luisa::unordered_map nodes; + nodes.reserve(subroutines.len + 1u); + auto convert_fields = [](ir::CBoxedSlice slice) noexcept { + luisa::vector fields; + fields.reserve(slice.len); + for (auto i = 0u; i < slice.len; i++) { fields.emplace_back(slice.ptr[i]); } + return fields; + }; + // add the entry node + nodes.emplace( + coro_token_entry, + Node{convert_fields(m->get()->coro_frame_input_fields), + convert_fields(m->get()->coro_frame_output_fields), + convert_fields(m->get()->coro_target_tokens), + detail::make_subroutine_wrapper(coroutine, entry->function())}); + // add subroutine nodes + for (auto i = 0u; i < subroutines.len; i++) { + auto s = subroutines.ptr[i]._0.get(); + auto subroutine = IR2AST::build(s); + auto [_, success] = nodes.try_emplace( + subroutine_ids.ptr[i], + Node{convert_fields(s->coro_frame_input_fields), + convert_fields(s->coro_frame_output_fields), + convert_fields(s->coro_target_tokens), + detail::make_subroutine_wrapper(coroutine, subroutine->function())}); + LUISA_ASSERT(success, "Duplicated subroutine token {}.", subroutine_ids.ptr[i]); + } + // create the graph + return luisa::make_shared( + std::move(frame), std::move(nodes), + coroutine.builder()->coro_tokens()); +} + +#endif + +}// namespace luisa::compute::coroutine diff --git a/src/coro/coro_node.cpp b/src/coro/coro_node.cpp deleted file mode 100644 index 499c55d0a..000000000 --- a/src/coro/coro_node.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include - -namespace luisa::compute::inline coro { - -CoroNode::CoroNode(const CoroGraph *graph, - CoroNode::Func function) noexcept - : _graph{graph}, _function{std::move(function)} {} -}// namespace luisa::compute::inline coro diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp new file mode 100644 index 000000000..5c4f42a8d --- /dev/null +++ b/src/coro/schedulers/persistent_threads.cpp @@ -0,0 +1,219 @@ +// +// Created by Mike on 2024/5/10. +// + +#include +#include +#include +#include +#include + +namespace luisa::compute::coroutine::detail { + +void persistent_threads_coro_scheduler_main_kernel_impl( + const PersistentThreadsCoroSchedulerConfig &config, + uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, + const CoroGraph *graph, Shared &frames, Expr dispatch_size_prefix_product, + Expr> global, const Buffer &global_frames, + luisa::move_only_function call_subroutine) noexcept { + + auto subroutine_count = static_cast(graph->nodes().size()); + Shared path_id{shared_queue_size}; + Shared work_counter{subroutine_count}; + Shared work_offset{2u}; + Shared all_token{config.global_memory_ext ? + shared_queue_size + global_queue_size : + shared_queue_size}; + Shared workload{2}; + Shared work_stat{2};// work_state[0] for max_count, [1] for max_id + for (auto index = 0u; index < q_fac; index++) { + auto s = index * config.block_size + thread_x(); + all_token[s] = 0u; + if (config.shared_memory_soa) { + frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{0u, 1u}); + } else { + frames[s].coro_id = make_uint3(); + frames[s].target_token = 0u; + } + } + for (auto index = 0u; index < g_fac; index++) { + auto s = index * config.block_size + thread_x(); + all_token[shared_queue_size + s] = 0u; + } + $if (thread_x() < subroutine_count) { + $if (thread_x() == 0u) { + work_counter[thread_x()] = + config.global_memory_ext ? + shared_queue_size + global_queue_size : + shared_queue_size; + } + $else { + work_counter[thread_x()] = 0u; + }; + }; + workload[0] = 0u; + workload[1] = 0u; + Shared rem_global{1}; + Shared rem_local{1}; + rem_global[0] = 1u; + rem_local[0] = 0u; + sync_block(); + auto count = def(0u); + auto count_limit = def(std::numeric_limits::max()); + $while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) { + sync_block();//very important, synchronize for condition + rem_local[0] = 0u; + count += 1; + work_stat[0] = 0; + work_stat[1] = -1; + sync_block(); + $if (thread_x() == config.block_size - 1) { + $if (workload[0] >= workload[1] & rem_global[0] == 1u) {//fetch new workload + workload[0] = global.atomic(0u).fetch_add(config.block_size * config.fetch_size); + workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size_prefix_product.z); + $if (workload[0] >= dispatch_size_prefix_product.z) { + rem_global[0] = 0u; + }; + }; + }; + sync_block(); + $if (thread_x() < subroutine_count) {//get max + $if (workload[0] < workload[1] | thread_x() != 0u) { + $if (work_counter[thread_x()] != 0) { + rem_local[0] = 1u; + work_stat.atomic(0).fetch_max(work_counter[thread_x()]); + }; + }; + }; + sync_block(); + $if (thread_x() < subroutine_count) {//get argmax + $if (work_stat[0] == work_counter[thread_x()] & (workload[0] < workload[1] | thread_x() != 0u)) { + work_stat[1] = thread_x(); + }; + }; + sync_block(); + work_offset[0] = 0; + work_offset[1] = 0; + sync_block(); + if (!config.global_memory_ext) { + for (auto index = 0u; index < q_fac; index++) {//collect indices + auto frame_token = all_token[index * config.block_size + thread_x()]; + $if (frame_token == work_stat[1]) { + auto id = work_offset.atomic(0).fetch_add(1u); + path_id[id] = index * config.block_size + thread_x(); + }; + } + } else { + for (auto index = 0u; index < q_fac; index++) {//collect switch out indices + auto frame_token = all_token[index * config.block_size + thread_x()]; + $if (frame_token != work_stat[1]) { + auto id = work_offset.atomic(0).fetch_add(1u); + path_id[id] = index * config.block_size + thread_x(); + }; + } + sync_block(); + $if (shared_queue_size - work_offset[0] < config.block_size) {//no enough work + $for (index, 0u, g_fac) { //swap frames + auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x(); + auto g_queue_id = index * config.block_size + thread_x(); + auto coro_token = all_token[shared_queue_size + g_queue_id]; + $if (coro_token == work_stat[1]) { + auto id = work_offset.atomic(1).fetch_add(1u); + $if (id < work_offset[0]) { + auto dst = path_id[id]; + auto frame_token = all_token[dst]; + $if (coro_token != 0u) { + auto g_state = global_frames->read(global_id); + $if (frame_token != 0u) { + if (config.shared_memory_soa) { + global_frames->write(global_id, frames.read(dst)); + } else { + global_frames->write(global_id, frames[dst]); + } + }; + if (config.shared_memory_soa) { + frames.write(dst, g_state); + } else { + frames[dst] = g_state; + } + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; + } + $elif (frame_token != 0u) { + if (config.shared_memory_soa) { + auto frame = frames.read(dst); + global_frames->write(global_id, frame); + } else { + global_frames->write(global_id, frames[dst]); + } + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; + }; + }; + }; + }; + }; + } + auto gen_st = workload[0]; + sync_block(); + auto pid = config.global_memory_ext ? thread_x() : path_id[thread_x()]; + $if (config.global_memory_ext ? (all_token[pid] == work_stat[1]) : (thread_x() < work_offset[0])) { + $switch (all_token[pid]) { + $case (0u) { + $if (gen_st + thread_x() < workload[1]) { + work_counter.atomic(0u).fetch_sub(1u); + auto global_index = gen_st + thread_x(); + auto index_z = global_index / dispatch_size_prefix_product.y; + auto index_xy = global_index % dispatch_size_prefix_product.y; + auto index_x = index_xy % dispatch_size_prefix_product.x; + auto index_y = index_xy / dispatch_size_prefix_product.x; + auto next = def(0u); + if (config.shared_memory_soa) { + auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); + call_subroutine(frame, coro_token_entry); + next = frame.target_token & coro_token_valid_mask; + frames.write(pid, frame, graph->entry().output_fields()); + } else { + frames[pid].coro_id = make_uint3(index_x, index_y, index_z); + frames[pid].target_token = coro_token_entry; + call_subroutine(frames[pid], coro_token_entry); + next = frames[pid].target_token & coro_token_valid_mask; + } + all_token[pid] = next; + work_counter.atomic(next).fetch_add(1u); + workload.atomic(0).fetch_add(1u); + }; + }; + for (auto i = 1u; i < subroutine_count; i++) { + $case (i) { + work_counter.atomic(i).fetch_sub(1u); + auto next = def(0u); + if (config.shared_memory_soa) { + auto frame = frames.read(pid, graph->node(i).input_fields()); + call_subroutine(frame, i); + next = frame.target_token & coro_token_valid_mask; + frames.write(pid, frame, graph->node(i).output_fields()); + } else { + call_subroutine(frames[pid], i); + next = frames[pid].target_token & coro_token_valid_mask; + } + all_token[pid] = next; + work_counter.atomic(next).fetch_add(1u); + }; + } + $default { dsl::unreachable(); }; + }; + }; + sync_block(); + }; +#ifndef NDEBUG + $if (count >= count_limit) { + device_log("block_id{},thread_id {}, loop not break! local:{}, global:{}", block_x(), thread_x(), rem_local[0], rem_global[0]); + $if (thread_x() < subroutine_count) { + device_log("work rem: id {}, size {}", thread_x(), work_counter[thread_x()]); + }; + }; +#endif +} + +}// namespace luisa::compute::coroutine::detail diff --git a/src/coro/schedulers/state_machine.cpp b/src/coro/schedulers/state_machine.cpp new file mode 100644 index 000000000..2a389c0c2 --- /dev/null +++ b/src/coro/schedulers/state_machine.cpp @@ -0,0 +1,47 @@ +// +// Created by Mike on 2024/5/10. +// + +#include +#include + +namespace luisa::compute::coroutine::detail { + +void coro_scheduler_state_machine_impl(CoroFrame &frame, uint state_count, + luisa::move_only_function node) noexcept { + node(coro_token_entry); + $loop { + $switch (frame.target_token) { + for (auto i = 1u; i < state_count; i++) { + $case (i) { node(i); }; + } + $default { $return(); }; + }; + }; +} + +void coro_scheduler_state_machine_smem_impl(Shared &smem, const CoroGraph *graph, + luisa::move_only_function node) noexcept { + auto target_token = def(0u); + auto frame = CoroFrame::create(graph->shared_frame(), dispatch_id()); + node(coro_token_entry, frame); + target_token = frame.target_token; + auto tid = thread_z() * block_size().x * block_size().y + thread_y() * block_size().x + thread_x(); + smem.write(tid, frame, graph->entry().output_fields()); + $loop { + $switch (target_token) { + for (auto i = 1u; i < graph->nodes().size(); i++) { + $case (i) { + auto f = smem.read(tid, graph->node(i).input_fields()); + f.coro_id = dispatch_id(); + node(i, f); + target_token = f.target_token; + smem.write(tid, f, graph->node(i).output_fields()); + }; + } + $default { $return(); }; + }; + }; +} + +}// namespace luisa::compute::coroutine::detail diff --git a/src/coro/schedulers/wavefront.cpp b/src/coro/schedulers/wavefront.cpp new file mode 100644 index 000000000..cc81a9278 --- /dev/null +++ b/src/coro/schedulers/wavefront.cpp @@ -0,0 +1,3 @@ +// +// Created by Mike on 2024/5/10. +// diff --git a/src/coro/shader_scheduler.cpp b/src/coro/shader_scheduler.cpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/dsl/CMakeLists.txt b/src/dsl/CMakeLists.txt index e5bf95b0d..cffae679a 100644 --- a/src/dsl/CMakeLists.txt +++ b/src/dsl/CMakeLists.txt @@ -23,7 +23,7 @@ if (LUISA_COMPUTE_ENABLE_DSL) ${LUISA_COMPUTE_DSL_RASTER_SOURCES}) add_library(luisa-compute-dsl SHARED ${LUISA_COMPUTE_DSL_SOURCES}) - target_link_libraries(luisa-compute-dsl PUBLIC luisa-compute-ast luisa-compute-runtime luisa-compute-coro) + target_link_libraries(luisa-compute-dsl PUBLIC luisa-compute-ast luisa-compute-runtime) target_compile_definitions(luisa-compute-dsl PRIVATE LC_DSL_EXPORT_DLL=1 PUBLIC LUISA_ENABLE_DSL=1) diff --git a/src/dsl/func.cpp b/src/dsl/func.cpp index 1ac3a99ec..21ab66c91 100644 --- a/src/dsl/func.cpp +++ b/src/dsl/func.cpp @@ -70,115 +70,4 @@ transform_function(Function function) noexcept { return function.shared_builder(); } -luisa::shared_ptr transform_coroutine( - Type *corotype, - coro::CoroGraph &graph, - luisa::unordered_map> &sub_builders, - Function function) noexcept { - if (true) { -#ifndef LUISA_ENABLE_IR - LUISA_ERROR_WITH_LOCATION( - "Coroutine requires IR support but " - "LuisaCompute is built without the IR module. " - "This might be caused by missing Rust. " - "Please install the Rust toolchain and " - "recompile LuisaCompute to get the IR module."); -#else - LUISA_VERBOSE_WITH_LOCATION("Performing Coroutine transform " - "on function with hash {:016x}.", - function.hash()); - - auto make_wrapper = [&](const FunctionBuilder *sub) noexcept { - return FunctionBuilder::define_callable([&] { - luisa::vector args; - args.reserve(function.arguments().size()); - LUISA_ASSERT(function.arguments().size() == function.bound_arguments().size(), - "Invalid capture list size (expected {}, got {}).", - function.arguments().size(), function.bound_arguments().size()); - auto fb = FunctionBuilder::current(); - for (auto arg_i = 0u; arg_i < function.arguments().size(); arg_i++) { - auto def_arg = function.arguments()[arg_i]; - auto internal_arg = luisa::visit( - [&](auto b) noexcept -> const Expression * { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return fb->buffer_binding(def_arg.type(), b.handle, b.offset, b.size); - } else if constexpr (std::is_same_v) { - return fb->texture_binding(def_arg.type(), b.handle, b.level); - } else if constexpr (std::is_same_v) { - return fb->bindless_array_binding(b.handle); - } else if constexpr (std::is_same_v) { - return fb->accel_binding(b.handle); - } else { - static_assert(std::is_same_v); - switch (def_arg.tag()) { - case Variable::Tag::REFERENCE: return fb->reference(def_arg.type()); - case Variable::Tag::BUFFER: return fb->buffer(def_arg.type()); - case Variable::Tag::TEXTURE: return fb->texture(def_arg.type()); - case Variable::Tag::BINDLESS_ARRAY: return fb->bindless_array(); - case Variable::Tag::ACCEL: return fb->accel(); - default: /* value argument */ return fb->argument(def_arg.type()); - } - } - }, - function.bound_arguments()[arg_i]); - args.emplace_back(internal_arg); - } - LUISA_ASSERT(sub->return_type() == nullptr, - "Coroutine subroutines should not have return type."); - fb->call(sub->function(), args); - }); - }; - - graph = coro::CoroGraph(0, corotype); - //idea: send in function-> module with .subroutine-> seperate transform to callable-> register to coroutine - luisa::shared_ptr converted; - auto m = AST2IR::build_coroutine(function); - perform_coroutine_transform(m->get()); - converted = IR2AST::build(m->get()); - auto subroutines = m->get()->subroutines; - auto subroutine_ids = m->get()->subroutine_ids; - auto coroframe = corotype; - auto coroframe_new = converted->arguments()[0].type(); - coroframe->update_from(coroframe_new); - for (auto &&field : luisa::span{m->get()->coro_frame_designated_fields.ptr, - m->get()->coro_frame_designated_fields.len}) { - auto name = luisa::string_view{reinterpret_cast(field.name.ptr), field.name.len}; - if (!name.empty() && name.back() == '\0') { name = name.substr(0, name.size() - 1); } - graph.designate_state_member(luisa::string{name}, field.index); - coroframe->set_member_name(field.index, luisa::string{name}); - } - const_cast(converted.get())->coroframe_replace(corotype); - for (int i = 0; i < subroutines.len; ++i) { - auto sub = IR2AST::build(subroutines.ptr[i]._0.get()); - auto wrapper = make_wrapper(sub.get()); - sub_builders.insert(std::make_pair(subroutine_ids.ptr[i], wrapper)); - auto node = graph.add_node(subroutine_ids.ptr[i], wrapper); - for (int mem = 0u; mem < subroutines.ptr[i]._0->coro_frame_input_fields.len; ++mem) { - auto field = subroutines.ptr[i]._0->coro_frame_input_fields.ptr[mem]; - node->input_state_members.push_back(field); - } - for (int mem = 0u; mem < subroutines.ptr[i]._0->coro_frame_output_fields.len; ++mem) { - auto field = subroutines.ptr[i]._0->coro_frame_output_fields.ptr[mem]; - node->output_state_members.push_back(field); - } - } - auto node = graph.add_node(0, make_wrapper(converted.get())); - for (int mem = 0u; mem < m->get()->coro_frame_input_fields.len; ++mem) { - auto field = m->get()->coro_frame_input_fields.ptr[mem]; - node->input_state_members.push_back(field); - } - for (int mem = 0u; mem < m->get()->coro_frame_output_fields.len; ++mem) { - auto field = m->get()->coro_frame_output_fields.ptr[mem]; - node->output_state_members.push_back(field); - } - LUISA_VERBOSE_WITH_LOCATION("Converted IR to AST for " - "kernel with hash {:016x}. " - "Coroutine transform is done.", - function.hash()); - return make_wrapper(converted.get()); -#endif - } - return function.shared_builder(); -} }// namespace luisa::compute::detail diff --git a/src/ext/imgui b/src/ext/imgui index e391fe2e6..8b2c6dd42 160000 --- a/src/ext/imgui +++ b/src/ext/imgui @@ -1 +1 @@ -Subproject commit e391fe2e66eb1c96b1624ae8444dc64c23146ef4 +Subproject commit 8b2c6dd42fb02dc95a581bb8c7db0199064cfced diff --git a/src/ext/magic_enum b/src/ext/magic_enum index 7afc57b19..f34f967c4 160000 --- a/src/ext/magic_enum +++ b/src/ext/magic_enum @@ -1 +1 @@ -Subproject commit 7afc57b194dd08631d5e96e42b217bb52933828f +Subproject commit f34f967c4e70ec60d1c561fbd2fcabaeedce5957 diff --git a/src/ir/ast2ir.cpp b/src/ir/ast2ir.cpp index 77651fb59..e76d8c573 100644 --- a/src/ir/ast2ir.cpp +++ b/src/ir/ast2ir.cpp @@ -33,7 +33,7 @@ namespace luisa::compute { [](ir::CArc *p) noexcept { p->release(); luisa::delete_with_allocator(p); - }}; + }}; } [[nodiscard]] luisa::shared_ptr> AST2IR::build_coroutine(Function function) noexcept { diff --git a/src/runtime/byte_buffer.cpp b/src/runtime/byte_buffer.cpp index 33307129e..10b7911d7 100644 --- a/src/runtime/byte_buffer.cpp +++ b/src/runtime/byte_buffer.cpp @@ -25,24 +25,24 @@ ByteBuffer::ByteBuffer(DeviceInterface *device, size_t size_bytes) noexcept if ((size_bytes & 3) != 0) [[unlikely]] { detail::error_buffer_size_not_aligned(4); } - return device->create_buffer( - Type::of(), - (size_bytes + sizeof(uint) - 1u) / sizeof(uint), - nullptr); + return device->create_buffer(Type::of(), size_bytes, nullptr); }()} {} ByteBuffer::~ByteBuffer() noexcept { if (*this) { device()->destroy_buffer(handle()); } } +ByteBufferView ByteBuffer::view() const noexcept { + _check_is_valid(); + return ByteBufferView{this->native_handle(), this->handle(), 0u, size_bytes(), size_bytes()}; +} + ByteBuffer Device::create_byte_buffer(size_t byte_size) noexcept { return ByteBuffer{impl(), byte_size}; } ByteBuffer Device::import_external_byte_buffer(void *external_memory, size_t byte_size) noexcept { - auto info = impl()->create_buffer(Type::of(), - (byte_size + sizeof(uint) - 1u) / sizeof(uint), - external_memory); + auto info = impl()->create_buffer(Type::of(), byte_size, external_memory); return ByteBuffer{impl(), info}; } diff --git a/src/rust/luisa_compute_ir/src/ast2ir.rs b/src/rust/luisa_compute_ir/src/ast2ir.rs index bc486f294..2569661e9 100644 --- a/src/rust/luisa_compute_ir/src/ast2ir.rs +++ b/src/rust/luisa_compute_ir/src/ast2ir.rs @@ -102,7 +102,7 @@ impl<'a> AST2IRType<'a> { } "BINDLESS_ARRAY" => Type::void(), "ACCEL" => Type::void(), - "COROFRAME" | "CUSTOM" => Type::opaque(j["id"].as_str().unwrap().into()), + "CUSTOM" => Type::opaque(j["id"].as_str().unwrap().into()), _ => panic!("Invalid type tag: {}", tag), }; self.types.insert(i, t.clone()); @@ -2322,6 +2322,7 @@ impl<'a: 'b, 'b> AST2IR<'a, 'b> { captures: CBoxedSlice::new(Vec::new()), subroutine_ids: CBoxedSlice::new(Vec::new()), subroutines: CBoxedSlice::new(Vec::new()), + coro_target_tokens: CBoxedSlice::new(Vec::new()), coro_frame_input_fields: CBoxedSlice::new(Vec::new()), coro_frame_output_fields: CBoxedSlice::new(Vec::new()), coro_frame_designated_fields: CBoxedSlice::new(Vec::new()), diff --git a/src/rust/luisa_compute_ir/src/ir.rs b/src/rust/luisa_compute_ir/src/ir.rs index 38cf9e788..f18f0ed1f 100644 --- a/src/rust/luisa_compute_ir/src/ir.rs +++ b/src/rust/luisa_compute_ir/src/ir.rs @@ -2012,6 +2012,7 @@ pub struct CallableModule { pub captures: CBoxedSlice, pub subroutines: CBoxedSlice, pub subroutine_ids: CBoxedSlice, + pub coro_target_tokens: CBoxedSlice, pub coro_frame_input_fields: CBoxedSlice, pub coro_frame_output_fields: CBoxedSlice, pub coro_frame_designated_fields: CBoxedSlice, @@ -2266,6 +2267,7 @@ impl ModuleDuplicator { captures: dup_captures, subroutines: callable.subroutines.clone(), subroutine_ids: callable.subroutine_ids.clone(), + coro_target_tokens: callable.coro_target_tokens.clone(), coro_frame_input_fields: callable.coro_frame_input_fields.clone(), coro_frame_output_fields: callable.coro_frame_output_fields.clone(), coro_frame_designated_fields: callable.coro_frame_designated_fields.clone(), diff --git a/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs b/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs index b5c758804..9c9d41eb8 100644 --- a/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs +++ b/src/rust/luisa_compute_ir/src/transform/materialize_coro.rs @@ -73,35 +73,33 @@ impl<'a> CoroScopeMaterializer<'a> { fn create_args(&self) -> Vec { let mut args = Vec::new(); - for (i, arg) in self.coro.args.iter().enumerate() { - if i == 0 { - // the coro frame - let node = new_node( - &self.coro.pools, - Node::new( - CArc::new(Instruction::Argument { by_value: false }), - self.frame.interface_type.clone(), - ), - ); - args.push(node); - } else { - // normal args - let instr = &arg.get().instruction; - match instr.as_ref() { - Instruction::Buffer - | Instruction::Bindless - | Instruction::Texture2D - | Instruction::Texture3D - | Instruction::Accel - | Instruction::Argument { .. } => { - let node = new_node( - &self.coro.pools, - Node::new(instr.clone(), arg.type_().clone()), - ); - args.push(node); - } - _ => unreachable!("Invalid argument type"), + args.reserve(1 /* frame */+ self.coro.args.len()); + // the coro frame + let node = new_node( + &self.coro.pools, + Node::new( + CArc::new(Instruction::Argument { by_value: false }), + self.frame.interface_type.clone(), + ), + ); + args.push(node); + for arg in self.coro.args.iter() { + // normal args + let instr = &arg.get().instruction; + match instr.as_ref() { + Instruction::Buffer + | Instruction::Bindless + | Instruction::Texture2D + | Instruction::Texture3D + | Instruction::Accel + | Instruction::Argument { .. } => { + let node = new_node( + &self.coro.pools, + Node::new(instr.clone(), arg.type_().clone()), + ); + args.push(node); } + _ => unreachable!("Invalid argument type"), } } args @@ -455,9 +453,11 @@ impl<'a> CoroScopeMaterializer<'a> { .iter() .map(|&a| self.value_or_load(a, ctx, state)) .collect(); - let call = state - .builder - .call(Func::External(c.clone()), args.as_slice(), ret.type_().clone()); + let call = state.builder.call( + Func::External(c.clone()), + args.as_slice(), + ret.type_().clone(), + ); process_return!(call) } // replayable but need special handling @@ -990,13 +990,18 @@ impl<'a> CoroScopeMaterializer<'a> { } } + fn collect_target_tokens(&self, scope: CoroScopeRef) -> Vec { + let node = self.frame.transition_graph.nodes.get(&scope).unwrap(); + node.outlets.keys().cloned().collect::<_>() + } + fn materialize(&self) -> CallableModule { let mappings: HashMap<_, _> = self .coro .args .iter() .cloned() - .zip(self.args.iter().cloned()) + .zip(self.args.iter().skip(1).cloned()) .collect(); let mut entry_builder = IrBuilder::new(self.coro.pools.clone()); let mut ctx = CoroScopeMaterializerCtx { @@ -1044,6 +1049,7 @@ impl<'a> CoroScopeMaterializer<'a> { index: i as u32 + designated_filed_offset, }) .collect(); + let target_tokens = self.collect_target_tokens(self.scope); // create the callable module CallableModule { module, @@ -1052,6 +1058,7 @@ impl<'a> CoroScopeMaterializer<'a> { captures: CBoxedSlice::new(Vec::new()), subroutines: CBoxedSlice::new(Vec::new()), subroutine_ids: CBoxedSlice::new(Vec::new()), + coro_target_tokens: CBoxedSlice::new(target_tokens), coro_frame_input_fields: CBoxedSlice::new(in_fields), coro_frame_output_fields: CBoxedSlice::new(out_fields), coro_frame_designated_fields: CBoxedSlice::new(designated_fields), diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 9ece0d5b7..2cb0a8c8d 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -43,7 +43,7 @@ if (LUISA_COMPUTE_ENABLE_RUST) luisa_compute_add_executable(test_ast2ir_ir2ast test_ast2ir_ir2ast.cpp) if (LUISA_COMPUTE_ENABLE_GUI) - luisa_compute_add_executable(test_skyline test_skyline.cpp) +# luisa_compute_add_executable(test_skyline test_skyline.cpp) luisa_compute_add_executable(test_kernel_ir test_kernel_ir.cpp) luisa_compute_add_executable(test_sdf_renderer_ir test_sdf_renderer_ir.cpp) luisa_compute_add_executable(test_path_tracing_ir test_path_tracing_ir.cpp) @@ -198,7 +198,7 @@ endif () # coroutine tests luisa_compute_add_executable(test_coro_sdf_renderer coro/sdf_renderer.cpp) -luisa_compute_add_executable(test_coro_sdf_renderer_wo_dispatcher coro/sdf_renderer_wo_dispatcher.cpp) -luisa_compute_add_executable(test_coro_path_tracing coro/path_tracing.cpp) +luisa_compute_add_executable(test_coro_path_tracing_state_machine coro/path_tracing_state_machine.cpp) +luisa_compute_add_executable(test_coro_path_tracing_wavefront coro/path_tracing_wavefront.cpp) +luisa_compute_add_executable(test_coro_path_tracing_persistent_threads coro/path_tracing_persistent_threads.cpp) luisa_compute_add_executable(test_coro_helloworld coro/helloworld.cpp) -luisa_compute_add_executable(test_coro_playground coro/playground.cpp) diff --git a/src/tests/coro/helloworld.cpp b/src/tests/coro/helloworld.cpp index e99349e6d..3509a6e0d 100644 --- a/src/tests/coro/helloworld.cpp +++ b/src/tests/coro/helloworld.cpp @@ -1,19 +1,14 @@ -#include -#include -#include -#include -#include -#include -#include +#include using namespace luisa; using namespace luisa::compute; -struct alignas(4) CoroFrame { -}; -LUISA_COROFRAME_STRUCT(CoroFrame){}; +namespace luisa::compute::coroutine { + +}// namespace luisa::compute::coroutine int main(int argc, char *argv[]) { + Context context{argv[0]}; if (argc <= 1) { exit(1); } Device device = context.create_device(argv[1]); @@ -22,31 +17,44 @@ int main(int argc, char *argv[]) { Image image{device.create_image(PixelStorage::BYTE4, resolution)}; luisa::vector host_image(image.view().size_bytes()); - Coroutine coro = [&](Var &frame) noexcept { - $suspend("1"); - Var coord = coro_id().xy(); - $suspend("2"); - Var uv = (make_float2(coord) + 0.5f) / make_float2(resolution); - $suspend("3"); - image->write(coord, make_float4(uv, 0.5f, 1.0f)); + Kernel1D test = [] { + coroutine::Generator range = [](UInt n) { + auto x = def(0u); + $while (x < n) { + $yield(x); + x += 1u; + }; + }; + auto iter = range(10u); + $while (!iter.update().is_terminated()) { + auto x = iter.value(); + device_log("x = {}", x); + }; + for (auto x : range(10u)) { + device_log("x = {}", x); + } + }; + auto shader = device.compile(test); + stream << shader().dispatch(1) << synchronize(); + + coroutine::Coroutine nested2 = [](UInt n) { + $for (i, n) { + device_log("nested2: {} / {}", i, n); + $suspend(); + }; }; - auto type = Type::of(); - auto frame_buffer = device.create_buffer(resolution.x * resolution.y); - - Kernel2D mega_kernel = [&] { - Var frame; - initialize_coroframe(frame, dispatch_id()); - coro(frame); - coro[1](frame); - coro[2](frame); - coro[3](frame); + + coroutine::Coroutine nested1 = [&](UInt n) { + $for (i, n) { + $await nested2(i); + device_log("nested1: {} / {}", i, n); + }; }; - auto shader = device.compile(mega_kernel); - stream << shader().dispatch(resolution) - << synchronize(); + coroutine::Coroutine top_level = [&]() { + $await nested1(10u); + }; - stream << image.copy_to(host_image.data()) - << synchronize(); - stbi_write_png("test_helloworld.png", resolution.x, resolution.y, 4, host_image.data(), 0); + coroutine::StateMachineCoroScheduler sched{device, top_level}; + stream << sched().dispatch(1u) << synchronize(); } diff --git a/src/tests/coro/path_tracing.cpp b/src/tests/coro/path_tracing_persistent_threads.cpp similarity index 87% rename from src/tests/coro/path_tracing.cpp rename to src/tests/coro/path_tracing_persistent_threads.cpp index 472498f4f..d60356a9f 100644 --- a/src/tests/coro/path_tracing.cpp +++ b/src/tests/coro/path_tracing_persistent_threads.cpp @@ -1,18 +1,9 @@ #include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include -#include -#include -#include +#include #include "../common/cornell_box.h" #define TINYOBJLOADER_IMPLEMENTATION @@ -33,10 +24,6 @@ LUISA_STRUCT(Onb, tangent, binormal, normal) { } }; -struct alignas(4) CoroFrame {}; - -LUISA_COROFRAME_STRUCT(CoroFrame) {}; - int main(int argc, char *argv[]) { log_level_verbose(); @@ -174,7 +161,7 @@ int main(int argc, char *argv[]) { auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; - Coroutine raytracing_coro = [&](Var &, ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { UInt2 coord = dispatch_id().xy(); Float frame_size = min(resolution.x, resolution.y).cast(); UInt state = seed_image.read(coord).x; @@ -275,24 +262,8 @@ int main(int argc, char *argv[]) { image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); }; - Kernel2D mega_kernel = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) { - Var frame; - initialize_coroframe(frame, dispatch_id()); - raytracing_coro(frame, image, seed_image, accel, resolution); - $loop { - auto token = read_promise(frame, "coro_token"); - $switch (token) { - for (auto i = 1u; i <= raytracing_coro.suspend_count(); i++) { - $case (i) { - raytracing_coro[i](frame, image, seed_image, accel, resolution); - }; - } - $default { - $return(); - }; - }; - }; - }; + coroutine::PersistentThreadsCoroSchedulerConfig config{}; + coroutine::PersistentThreadsCoroScheduler scheduler{device, coro, config}; Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { UInt2 p = dispatch_id().xy(); @@ -328,17 +299,16 @@ int main(int argc, char *argv[]) { auto clear_shader = device.compile(clear_kernel, o); auto hdr2ldr_shader = device.compile(hdr2ldr_kernel, o); auto accumulate_shader = device.compile(accumulate_kernel, o); - auto raytracing_shader = device.compile(mega_kernel, o); auto make_sampler_shader = device.compile(make_sampler_kernel, o); static constexpr uint2 resolution = make_uint2(1024u); Image framebuffer = device.create_image(PixelStorage::HALF4, resolution); Image accum_image = device.create_image(PixelStorage::FLOAT4, resolution); luisa::vector> host_image(resolution.x * resolution.y); - CommandList cmd_list; + Image seed_image = device.create_image(PixelStorage::INT1, resolution); - cmd_list << clear_shader(accum_image).dispatch(resolution) - << make_sampler_shader(seed_image).dispatch(resolution); + stream << clear_shader(accum_image).dispatch(resolution) + << make_sampler_shader(seed_image).dispatch(resolution); Window window{"path tracing", resolution}; Swapchain swap_chain = device.create_swapchain( @@ -357,13 +327,13 @@ int main(int argc, char *argv[]) { Clock clock; while (!window.should_close()) { - cmd_list << raytracing_shader(framebuffer, seed_image, accel, resolution) - .dispatch(resolution) - << accumulate_shader(accum_image, framebuffer) - .dispatch(resolution); - cmd_list << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution); - stream << cmd_list.commit() - << swap_chain.present(ldr_image) << synchronize(); + stream << scheduler(framebuffer, seed_image, accel, resolution) + .dispatch(resolution) + << accumulate_shader(accum_image, framebuffer) + .dispatch(resolution) + << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) + << swap_chain.present(ldr_image) + << synchronize(); window.poll_events(); double dt = clock.toc() - last_time; last_time = clock.toc(); diff --git a/src/tests/coro/path_tracing_state_machine.cpp b/src/tests/coro/path_tracing_state_machine.cpp new file mode 100644 index 000000000..a5821c266 --- /dev/null +++ b/src/tests/coro/path_tracing_state_machine.cpp @@ -0,0 +1,352 @@ +#include + +#include +#include + +#include +#include "../common/cornell_box.h" + +#define TINYOBJLOADER_IMPLEMENTATION +#include "../common/tiny_obj_loader.h" + +using namespace luisa; +using namespace luisa::compute; + +struct Onb { + float3 tangent; + float3 binormal; + float3 normal; +}; + +LUISA_STRUCT(Onb, tangent, binormal, normal) { + [[nodiscard]] Float3 to_world(Expr v) const noexcept { + return v.x * tangent + v.y * binormal + v.z * normal; + } +}; + +int main(int argc, char *argv[]) { + + log_level_verbose(); + + Context context{argv[0]}; + if (argc <= 1) { + LUISA_INFO("Usage: {} . : cuda, dx, cpu, metal", argv[0]); + exit(1); + } + Device device = context.create_device(argv[1]); + + // load the Cornell Box scene + tinyobj::ObjReaderConfig obj_reader_config; + obj_reader_config.triangulate = true; + obj_reader_config.vertex_color = false; + tinyobj::ObjReader obj_reader; + if (!obj_reader.ParseFromString(obj_string, "", obj_reader_config)) { + luisa::string_view error_message = "unknown error."; + if (auto &&e = obj_reader.Error(); !e.empty()) { error_message = e; } + LUISA_ERROR_WITH_LOCATION("Failed to load OBJ file: {}", error_message); + } + if (auto &&e = obj_reader.Warning(); !e.empty()) { + LUISA_WARNING_WITH_LOCATION("{}", e); + } + + auto &&p = obj_reader.GetAttrib().vertices; + luisa::vector vertices; + vertices.reserve(p.size() / 3u); + for (uint i = 0u; i < p.size(); i += 3u) { + vertices.emplace_back(make_float3( + p[i + 0u], p[i + 1u], p[i + 2u])); + } + LUISA_INFO( + "Loaded mesh with {} shape(s) and {} vertices.", + obj_reader.GetShapes().size(), vertices.size()); + + BindlessArray heap = device.create_bindless_array(); + Stream stream = device.create_stream(StreamTag::GRAPHICS); + Buffer vertex_buffer = device.create_buffer(vertices.size()); + stream << vertex_buffer.copy_from(vertices.data()); + luisa::vector meshes; + luisa::vector> triangle_buffers; + for (auto &&shape : obj_reader.GetShapes()) { + uint index = static_cast(meshes.size()); + std::vector const &t = shape.mesh.indices; + uint triangle_count = t.size() / 3u; + LUISA_INFO( + "Processing shape '{}' at index {} with {} triangle(s).", + shape.name, index, triangle_count); + luisa::vector indices; + indices.reserve(t.size()); + for (tinyobj::index_t i : t) { indices.emplace_back(i.vertex_index); } + Buffer &triangle_buffer = triangle_buffers.emplace_back(device.create_buffer(triangle_count)); + Mesh &mesh = meshes.emplace_back(device.create_mesh(vertex_buffer, triangle_buffer)); + heap.emplace_on_update(index, triangle_buffer); + stream << triangle_buffer.copy_from(indices.data()) + << mesh.build(); + } + + Accel accel = device.create_accel({}); + for (Mesh &m : meshes) { + accel.emplace_back(m, make_float4x4(1.0f)); + } + stream << heap.update() + << accel.build() + << synchronize(); + + Constant materials{ + make_float3(0.725f, 0.710f, 0.680f),// floor + make_float3(0.725f, 0.710f, 0.680f),// ceiling + make_float3(0.725f, 0.710f, 0.680f),// back wall + make_float3(0.140f, 0.450f, 0.091f),// right wall + make_float3(0.630f, 0.065f, 0.050f),// left wall + make_float3(0.725f, 0.710f, 0.680f),// short box + make_float3(0.725f, 0.710f, 0.680f),// tall box + make_float3(0.000f, 0.000f, 0.000f),// light + }; + + Callable linear_to_srgb = [&](Var x) noexcept { + return saturate(select(1.055f * pow(x, 1.0f / 2.4f) - 0.055f, + 12.92f * x, + x <= 0.00031308f)); + }; + + Callable tea = [](UInt v0, UInt v1) noexcept { + UInt s0 = def(0u); + for (uint n = 0u; n < 4u; n++) { + s0 += 0x9e3779b9u; + v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); + v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); + } + return v0; + }; + + Kernel2D make_sampler_kernel = [&](ImageUInt seed_image) noexcept { + UInt2 p = dispatch_id().xy(); + UInt state = tea(p.x, p.y); + seed_image.write(p, make_uint4(state)); + }; + + Callable lcg = [](UInt &state) noexcept { + constexpr uint lcg_a = 1664525u; + constexpr uint lcg_c = 1013904223u; + state = lcg_a * state + lcg_c; + return cast(state & 0x00ffffffu) * + (1.0f / static_cast(0x01000000u)); + }; + + Callable make_onb = [](const Float3 &normal) noexcept { + Float3 binormal = normalize(ite( + abs(normal.x) > abs(normal.z), + make_float3(-normal.y, normal.x, 0.0f), + make_float3(0.0f, -normal.z, normal.y))); + Float3 tangent = normalize(cross(binormal, normal)); + return def(tangent, binormal, normal); + }; + + Callable generate_ray = [](Float2 p) noexcept { + static constexpr float fov = radians(27.8f); + static constexpr float3 origin = make_float3(-0.01f, 0.995f, 5.0f); + Float3 pixel = origin + make_float3(p * tan(0.5f * fov), -1.0f); + Float3 direction = normalize(pixel - origin); + return make_ray(origin, direction); + }; + + Callable cosine_sample_hemisphere = [](Float2 u) noexcept { + Float r = sqrt(u.x); + Float phi = 2.0f * constants::pi * u.y; + return make_float3(r * cos(phi), r * sin(phi), sqrt(1.0f - u.x)); + }; + + Callable balanced_heuristic = [](Float pdf_a, Float pdf_b) noexcept { + return pdf_a / max(pdf_a + pdf_b, 1e-4f); + }; + + auto spp_per_dispatch = device.backend_name() == "metal" || device.backend_name() == "cpu" ? 1u : 64u; + + coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + UInt2 coord = dispatch_id().xy(); + Float frame_size = min(resolution.x, resolution.y).cast(); + UInt state = seed_image.read(coord).x; + Float rx = lcg(state); + Float ry = lcg(state); + Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; + Float3 radiance = def(make_float3(0.0f)); + $suspend("per_spp"); + $for (i, spp_per_dispatch) { + Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); + Float3 beta = def(make_float3(1.0f)); + Float pdf_bsdf = def(0.0f); + constexpr float3 light_position = make_float3(-0.24f, 1.98f, 0.16f); + constexpr float3 light_u = make_float3(-0.24f, 1.98f, -0.22f) - light_position; + constexpr float3 light_v = make_float3(0.23f, 1.98f, 0.16f) - light_position; + constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); + Float light_area = length(cross(light_u, light_v)); + Float3 light_normal = normalize(cross(light_u, light_v)); + $suspend("per_depth"); + $for (depth, 10u) { + // trace + $suspend("before_tracing"); + Var hit = accel.intersect(ray, {}); + reorder_shader_execution(); + $if (hit->miss()) { $break; }; + Var triangle = heap->buffer(hit.inst).read(hit.prim); + Float3 p0 = vertex_buffer->read(triangle.i0); + Float3 p1 = vertex_buffer->read(triangle.i1); + Float3 p2 = vertex_buffer->read(triangle.i2); + Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); + Float3 n = normalize(cross(p1 - p0, p2 - p0)); + $suspend("after_tracing"); + + Float cos_wo = dot(-ray->direction(), n); + $if (cos_wo < 1e-4f) { $break; }; + + // hit light + $if (hit.inst == static_cast(meshes.size() - 1u)) { + $if (depth == 0u) { + radiance += light_emission; + } + $else { + Float pdf_light = length_squared(p - ray->origin()) / (light_area * cos_wo); + Float mis_weight = balanced_heuristic(pdf_bsdf, pdf_light); + radiance += mis_weight * beta * light_emission; + }; + $break; + }; + + // sample light + $suspend("sample_light"); + Float ux_light = lcg(state); + Float uy_light = lcg(state); + Float3 p_light = light_position + ux_light * light_u + uy_light * light_v; + Float3 pp = offset_ray_origin(p, n); + Float3 pp_light = offset_ray_origin(p_light, light_normal); + Float d_light = distance(pp, pp_light); + Float3 wi_light = normalize(pp_light - pp); + Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); + Bool occluded = accel.intersect_any(shadow_ray, {}); + Float cos_wi_light = dot(wi_light, n); + Float cos_light = -dot(light_normal, wi_light); + Float3 albedo = materials.read(hit.inst); + $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { + Float pdf_light = (d_light * d_light) / (light_area * cos_light); + Float pdf_bsdf = cos_wi_light * inv_pi; + Float mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); + Float3 bsdf = albedo * inv_pi * cos_wi_light; + radiance += beta * bsdf * mis_weight * light_emission / max(pdf_light, 1e-4f); + }; + + // sample BSDF + $suspend("sample_bsdf"); + Var onb = make_onb(n); + Float ux = lcg(state); + Float uy = lcg(state); + Float3 wi_local = cosine_sample_hemisphere(make_float2(ux, uy)); + Float cos_wi = abs(wi_local.z); + Float3 new_direction = onb->to_world(wi_local); + ray = make_ray(pp, new_direction); + pdf_bsdf = cos_wi * inv_pi; + beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f + + // rr + $suspend("rr"); + Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); + $if (l == 0.0f) { $break; }; + Float q = max(l, 0.05f); + Float r = lcg(state); + $if (r >= q) { $break; }; + beta *= 1.0f / q; + }; + }; + $suspend("write_film"); + radiance /= static_cast(spp_per_dispatch); + seed_image.write(coord, make_uint4(state)); + $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; + image.write(dispatch_id().xy(), make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f)); + }; + + coroutine::StateMachineCoroSchedulerConfig config{.block_size = make_uint3(8u, 8u, 1u), + .shared_memory = false, + .shared_memory_soa = true}; + coroutine::StateMachineCoroScheduler scheduler{device, coro, config}; + + Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { + UInt2 p = dispatch_id().xy(); + Float4 accum = accum_image.read(p); + Float3 curr = curr_image.read(p).xyz(); + accum_image.write(p, accum + make_float4(curr, 1.f)); + }; + + Callable aces_tonemapping = [](Float3 x) noexcept { + static constexpr float a = 2.51f; + static constexpr float b = 0.03f; + static constexpr float c = 2.43f; + static constexpr float d = 0.59f; + static constexpr float e = 0.14f; + return clamp((x * (a * x + b)) / (x * (c * x + d) + e), 0.0f, 1.0f); + }; + + Kernel2D clear_kernel = [](ImageFloat image) noexcept { + image.write(dispatch_id().xy(), make_float4(0.0f)); + }; + + Kernel2D hdr2ldr_kernel = [&](ImageFloat hdr_image, ImageFloat ldr_image, Float scale, Bool is_hdr) noexcept { + UInt2 coord = dispatch_id().xy(); + Float4 hdr = hdr_image.read(coord); + Float3 ldr = hdr.xyz() / hdr.w * scale; + $if (!is_hdr) { + ldr = linear_to_srgb(ldr); + }; + ldr_image.write(coord, make_float4(ldr, 1.0f)); + }; + + ShaderOption o{.enable_debug_info = false}; + auto clear_shader = device.compile(clear_kernel, o); + auto hdr2ldr_shader = device.compile(hdr2ldr_kernel, o); + auto accumulate_shader = device.compile(accumulate_kernel, o); + auto make_sampler_shader = device.compile(make_sampler_kernel, o); + + static constexpr uint2 resolution = make_uint2(1024u); + Image framebuffer = device.create_image(PixelStorage::HALF4, resolution); + Image accum_image = device.create_image(PixelStorage::FLOAT4, resolution); + luisa::vector> host_image(resolution.x * resolution.y); + + Image seed_image = device.create_image(PixelStorage::INT1, resolution); + stream << clear_shader(accum_image).dispatch(resolution) + << make_sampler_shader(seed_image).dispatch(resolution); + + Window window{"path tracing", resolution}; + Swapchain swap_chain = device.create_swapchain( + stream, + SwapchainOption{ + .display = window.native_display(), + .window = window.native_handle(), + .size = make_uint2(resolution), + .wants_hdr = false, + .wants_vsync = false, + .back_buffer_count = 3, + }); + Image ldr_image = device.create_image(swap_chain.backend_storage(), resolution); + double last_time = 0.0; + uint frame_count = 0u; + Clock clock; + + while (!window.should_close()) { + stream << scheduler(framebuffer, seed_image, accel, resolution) + .dispatch(resolution) + << accumulate_shader(accum_image, framebuffer) + .dispatch(resolution) + << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) + << swap_chain.present(ldr_image) + << synchronize(); + window.poll_events(); + double dt = clock.toc() - last_time; + last_time = clock.toc(); + frame_count += spp_per_dispatch; + LUISA_INFO("spp: {}, time: {} ms, spp/s: {}", + frame_count, dt, spp_per_dispatch / dt * 1000); + } + stream + << ldr_image.copy_to(host_image.data()) + << synchronize(); + + LUISA_INFO("FPS: {}", frame_count / clock.toc() * 1000); + stbi_write_png("test_path_tracing.png", resolution.x, resolution.y, 4, host_image.data(), 0); +} diff --git a/src/tests/coro/path_tracing_wavefront.cpp b/src/tests/coro/path_tracing_wavefront.cpp new file mode 100644 index 000000000..c52f68b9e --- /dev/null +++ b/src/tests/coro/path_tracing_wavefront.cpp @@ -0,0 +1,362 @@ +#include + +#include +#include + +#include +#include "../common/cornell_box.h" + +#define TINYOBJLOADER_IMPLEMENTATION +#include "../common/tiny_obj_loader.h" + +using namespace luisa; +using namespace luisa::compute; + +struct Onb { + float3 tangent; + float3 binormal; + float3 normal; +}; + +LUISA_STRUCT(Onb, tangent, binormal, normal) { + [[nodiscard]] Float3 to_world(Expr v) const noexcept { + return v.x * tangent + v.y * binormal + v.z * normal; + } +}; + +int main(int argc, char *argv[]) { + + log_level_verbose(); + + Context context{argv[0]}; + if (argc <= 1) { + LUISA_INFO("Usage: {} . : cuda, dx, cpu, metal", argv[0]); + exit(1); + } + Device device = context.create_device(argv[1]); + + // load the Cornell Box scene + tinyobj::ObjReaderConfig obj_reader_config; + obj_reader_config.triangulate = true; + obj_reader_config.vertex_color = false; + tinyobj::ObjReader obj_reader; + if (!obj_reader.ParseFromString(obj_string, "", obj_reader_config)) { + luisa::string_view error_message = "unknown error."; + if (auto &&e = obj_reader.Error(); !e.empty()) { error_message = e; } + LUISA_ERROR_WITH_LOCATION("Failed to load OBJ file: {}", error_message); + } + if (auto &&e = obj_reader.Warning(); !e.empty()) { + LUISA_WARNING_WITH_LOCATION("{}", e); + } + + auto &&p = obj_reader.GetAttrib().vertices; + luisa::vector vertices; + vertices.reserve(p.size() / 3u); + for (uint i = 0u; i < p.size(); i += 3u) { + vertices.emplace_back(make_float3( + p[i + 0u], p[i + 1u], p[i + 2u])); + } + LUISA_INFO( + "Loaded mesh with {} shape(s) and {} vertices.", + obj_reader.GetShapes().size(), vertices.size()); + + BindlessArray heap = device.create_bindless_array(); + Stream stream = device.create_stream(StreamTag::GRAPHICS); + Buffer vertex_buffer = device.create_buffer(vertices.size()); + stream << vertex_buffer.copy_from(vertices.data()); + luisa::vector meshes; + luisa::vector> triangle_buffers; + for (auto &&shape : obj_reader.GetShapes()) { + uint index = static_cast(meshes.size()); + std::vector const &t = shape.mesh.indices; + uint triangle_count = t.size() / 3u; + LUISA_INFO( + "Processing shape '{}' at index {} with {} triangle(s).", + shape.name, index, triangle_count); + luisa::vector indices; + indices.reserve(t.size()); + for (tinyobj::index_t i : t) { indices.emplace_back(i.vertex_index); } + Buffer &triangle_buffer = triangle_buffers.emplace_back(device.create_buffer(triangle_count)); + Mesh &mesh = meshes.emplace_back(device.create_mesh(vertex_buffer, triangle_buffer)); + heap.emplace_on_update(index, triangle_buffer); + stream << triangle_buffer.copy_from(indices.data()) + << mesh.build(); + } + + Accel accel = device.create_accel({}); + for (Mesh &m : meshes) { + accel.emplace_back(m, make_float4x4(1.0f)); + } + stream << heap.update() + << accel.build() + << synchronize(); + + Constant materials{ + make_float3(0.725f, 0.710f, 0.680f),// floor + make_float3(0.725f, 0.710f, 0.680f),// ceiling + make_float3(0.725f, 0.710f, 0.680f),// back wall + make_float3(0.140f, 0.450f, 0.091f),// right wall + make_float3(0.630f, 0.065f, 0.050f),// left wall + make_float3(0.725f, 0.710f, 0.680f),// short box + make_float3(0.725f, 0.710f, 0.680f),// tall box + make_float3(0.000f, 0.000f, 0.000f),// light + }; + + Callable linear_to_srgb = [&](Var x) noexcept { + return saturate(select(1.055f * pow(x, 1.0f / 2.4f) - 0.055f, + 12.92f * x, + x <= 0.00031308f)); + }; + + Callable tea = [](UInt v0, UInt v1) noexcept { + UInt s0 = def(0u); + for (uint n = 0u; n < 4u; n++) { + s0 += 0x9e3779b9u; + v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); + v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); + } + return v0; + }; + + Kernel2D make_sampler_kernel = [&](ImageUInt seed_image) noexcept { + UInt2 p = dispatch_id().xy(); + UInt state = tea(p.x, p.y); + seed_image.write(p, make_uint4(state)); + }; + + Callable lcg = [](UInt &state) noexcept { + constexpr uint lcg_a = 1664525u; + constexpr uint lcg_c = 1013904223u; + state = lcg_a * state + lcg_c; + return cast(state & 0x00ffffffu) * + (1.0f / static_cast(0x01000000u)); + }; + + Callable make_onb = [](const Float3 &normal) noexcept { + Float3 binormal = normalize(ite( + abs(normal.x) > abs(normal.z), + make_float3(-normal.y, normal.x, 0.0f), + make_float3(0.0f, -normal.z, normal.y))); + Float3 tangent = normalize(cross(binormal, normal)); + return def(tangent, binormal, normal); + }; + + Callable generate_ray = [](Float2 p) noexcept { + static constexpr float fov = radians(27.8f); + static constexpr float3 origin = make_float3(-0.01f, 0.995f, 5.0f); + Float3 pixel = origin + make_float3(p * tan(0.5f * fov), -1.0f); + Float3 direction = normalize(pixel - origin); + return make_ray(origin, direction); + }; + + Callable cosine_sample_hemisphere = [](Float2 u) noexcept { + Float r = sqrt(u.x); + Float phi = 2.0f * constants::pi * u.y; + return make_float3(r * cos(phi), r * sin(phi), sqrt(1.0f - u.x)); + }; + + Callable balanced_heuristic = [](Float pdf_a, Float pdf_b) noexcept { + return pdf_a / max(pdf_a + pdf_b, 1e-4f); + }; + + auto spp_per_dispatch = 16u; + + coroutine::Coroutine coro = [&](ImageFloat image, ImageUInt seed_image, AccelVar accel, UInt2 resolution) noexcept { + UInt2 coord = dispatch_id().xy(); + + Float frame_size = min(resolution.x, resolution.y).cast(); + UInt state = seed_image.read(coord).x; + Float rx = lcg(state); + Float ry = lcg(state); + Float2 pixel = (make_float2(coord) + make_float2(rx, ry)) / frame_size * 2.0f - 1.0f; + Float3 radiance = def(make_float3(0.0f)); + + // $suspend("per_dispatch"); + + // $if (all(coord == make_uint2(50u, 500u))) { + // device_log("coord: {}", coord); + // }; + + Var ray = generate_ray(pixel * make_float2(1.0f, -1.0f)); + Float3 beta = def(make_float3(1.0f)); + Float pdf_bsdf = def(0.0f); + constexpr float3 light_position = make_float3(-0.24f, 1.98f, 0.16f); + constexpr float3 light_u = make_float3(-0.24f, 1.98f, -0.22f) - light_position; + constexpr float3 light_v = make_float3(0.23f, 1.98f, 0.16f) - light_position; + constexpr float3 light_emission = make_float3(17.0f, 12.0f, 4.0f); + Float light_area = length(cross(light_u, light_v)); + Float3 light_normal = normalize(cross(light_u, light_v)); + + $for (depth, 10u) { + // trace + $suspend("intersect"); + Var hit = accel.intersect(ray, {}); + reorder_shader_execution(); + $if (hit->miss()) { $break; }; + Var triangle = heap->buffer(hit.inst).read(hit.prim); + Float3 p0 = vertex_buffer->read(triangle.i0); + Float3 p1 = vertex_buffer->read(triangle.i1); + Float3 p2 = vertex_buffer->read(triangle.i2); + Float3 p = triangle_interpolate(hit.bary, p0, p1, p2); + Float3 n = normalize(cross(p1 - p0, p2 - p0)); + + Float cos_wo = dot(-ray->direction(), n); + $if (cos_wo < 1e-4f) { $break; }; + + // hit light + $if (hit.inst == static_cast(meshes.size() - 1u)) { + $if (depth == 0u) { + radiance += light_emission; + } + $else { + Float pdf_light = length_squared(p - ray->origin()) / (light_area * cos_wo); + Float mis_weight = balanced_heuristic(pdf_bsdf, pdf_light); + radiance += mis_weight * beta * light_emission; + }; + $break; + }; + + // sample light + $suspend("sample_light"); + Float ux_light = lcg(state); + Float uy_light = lcg(state); + Float3 p_light = light_position + ux_light * light_u + uy_light * light_v; + Float3 pp = offset_ray_origin(p, n); + Float3 pp_light = offset_ray_origin(p_light, light_normal); + Float d_light = distance(pp, pp_light); + Float3 wi_light = normalize(pp_light - pp); + + Var shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.f, d_light); + Bool occluded = accel.intersect_any(shadow_ray, {}); + Float cos_wi_light = dot(wi_light, n); + Float cos_light = -dot(light_normal, wi_light); + + $suspend("evaluate_surface"); + Float3 albedo = materials.read(hit.inst); + $if (!occluded & cos_wi_light > 1e-4f & cos_light > 1e-4f) { + Float pdf_light = (d_light * d_light) / (light_area * cos_light); + Float pdf_bsdf = cos_wi_light * inv_pi; + Float mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); + Float3 bsdf = albedo * inv_pi * cos_wi_light; + radiance += beta * bsdf * mis_weight * light_emission / max(pdf_light, 1e-4f); + }; + + // sample BSDF + Var onb = make_onb(n); + Float ux = lcg(state); + Float uy = lcg(state); + Float3 wi_local = cosine_sample_hemisphere(make_float2(ux, uy)); + Float cos_wi = abs(wi_local.z); + Float3 new_direction = onb->to_world(wi_local); + ray = make_ray(pp, new_direction); + pdf_bsdf = cos_wi * inv_pi; + beta *= albedo;// * cos_wi * inv_pi / pdf_bsdf => * 1.f + + // rr + // $suspend("rr"); + Float l = dot(make_float3(0.212671f, 0.715160f, 0.072169f), beta); + $if (l == 0.0f) { $break; }; + Float q = max(l, 0.05f); + Float r = lcg(state); + $if (r >= q) { $break; }; + beta *= 1.0f / q; + }; + + $suspend("write_film"); + seed_image.write(coord, make_uint4(state)); + $if (any(dsl::isnan(radiance))) { radiance = make_float3(0.0f); }; + auto color = make_float4(clamp(radiance, 0.0f, 30.0f), 1.0f); + image.write(coord, color); + }; + + coroutine::WavefrontCoroSchedulerConfig config{ + .thread_count = 4_M, + .global_memory_soa = true, + .gather_by_sorting = false, + }; + coroutine::WavefrontCoroScheduler scheduler{device, coro, config}; + // coroutine::PersistentThreadsCoroScheduler scheduler{device, coro}; + + Kernel2D accumulate_kernel = [&](ImageFloat accum_image, ImageFloat curr_image) noexcept { + UInt2 p = dispatch_id().xy(); + Float4 accum = accum_image.read(p); + Float3 curr = curr_image.read(p).xyz(); + accum_image.write(p, accum + make_float4(curr, 1.f)); + }; + + Callable aces_tonemapping = [](Float3 x) noexcept { + static constexpr float a = 2.51f; + static constexpr float b = 0.03f; + static constexpr float c = 2.43f; + static constexpr float d = 0.59f; + static constexpr float e = 0.14f; + return clamp((x * (a * x + b)) / (x * (c * x + d) + e), 0.0f, 1.0f); + }; + + Kernel2D clear_kernel = [](ImageFloat image) noexcept { + image.write(dispatch_id().xy(), make_float4(0.0f)); + }; + + Kernel2D hdr2ldr_kernel = [&](ImageFloat hdr_image, ImageFloat ldr_image, Float scale, Bool is_hdr) noexcept { + UInt2 coord = dispatch_id().xy(); + Float4 hdr = hdr_image.read(coord); + Float3 ldr = hdr.xyz() / hdr.w * scale; + $if (!is_hdr) { + ldr = linear_to_srgb(ldr); + }; + ldr_image.write(coord, make_float4(ldr, 1.0f)); + }; + + ShaderOption o{.enable_debug_info = false}; + auto clear_shader = device.compile(clear_kernel, o); + auto hdr2ldr_shader = device.compile(hdr2ldr_kernel, o); + auto accumulate_shader = device.compile(accumulate_kernel, o); + auto make_sampler_shader = device.compile(make_sampler_kernel, o); + + static constexpr uint2 resolution = make_uint2(1024u); + Image framebuffer = device.create_image(PixelStorage::HALF4, resolution); + Image accum_image = device.create_image(PixelStorage::FLOAT4, resolution); + luisa::vector> host_image(resolution.x * resolution.y); + + Image seed_image = device.create_image(PixelStorage::INT1, resolution); + stream << clear_shader(accum_image).dispatch(resolution) + << make_sampler_shader(seed_image).dispatch(resolution); + + Window window{"path tracing", resolution}; + Swapchain swap_chain = device.create_swapchain( + stream, + SwapchainOption{ + .display = window.native_display(), + .window = window.native_handle(), + .size = make_uint2(resolution), + .wants_hdr = false, + .wants_vsync = false, + .back_buffer_count = 3, + }); + Image ldr_image = device.create_image(swap_chain.backend_storage(), resolution); + double last_time = 0.0; + uint frame_count = 0u; + Clock clock; + + while (!window.should_close()) { + stream << scheduler(framebuffer, seed_image, accel, resolution) + .dispatch(resolution.x, resolution.y, spp_per_dispatch) + << accumulate_shader(accum_image, framebuffer) + .dispatch(resolution) + << hdr2ldr_shader(accum_image, ldr_image, 1.0f, swap_chain.backend_storage() != PixelStorage::BYTE4).dispatch(resolution) + << swap_chain.present(ldr_image) + << synchronize(); + window.poll_events(); + double dt = clock.toc() - last_time; + last_time = clock.toc(); + frame_count += spp_per_dispatch; + LUISA_INFO("spp: {}, time: {} ms, spp/s: {}", + frame_count, dt, spp_per_dispatch / dt * 1000); + } + stream + << ldr_image.copy_to(host_image.data()) + << synchronize(); + + LUISA_INFO("FPS: {}", frame_count / clock.toc() * 1000); + stbi_write_png("test_path_tracing.png", resolution.x, resolution.y, 4, host_image.data(), 0); +} diff --git a/src/tests/coro/playground.cpp b/src/tests/coro/playground.cpp deleted file mode 100644 index 1b5dc1fc6..000000000 --- a/src/tests/coro/playground.cpp +++ /dev/null @@ -1,61 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -using namespace luisa; -using namespace luisa::compute; - -struct alignas(4) CoroFrame {}; -LUISA_COROFRAME_STRUCT(CoroFrame) {}; - -int main(int argc, char *argv[]) { - Context context{argv[0]}; - if (argc <= 1) { exit(1); } - Device device = context.create_device(argv[1]); - Stream stream = device.create_stream(); - - Coroutine coro = [&](Var &, Int &a) noexcept { - $for (i, 10) { - device_log("before {}: {}", i, a); - auto token = $suspend(std::make_pair(a, "x")); - // device_log("after {}: {}", i, a); - a += 1; - $suspend(); - }; - }; - - Kernel1D mega_kernel = [&] { - device_log("scheduler: init"); - auto frame = make_coroframe(dispatch_id()); - device_log("coro_id = {}", frame->coro_id()); - device_log("scheduler: enter entry"); - auto a = def(0); - coro(frame, a); - device_log("scheduler: exit entry"); - $loop { - auto token = frame->target_token(); - $switch (token) { - for (auto i = 1u; i <= coro.suspend_count(); i++) { - $case (i) { - device_log("scheduler: enter coro {}, a = {}", i, a); - coro[i](frame, a); - device_log("promise x: {}", frame->promise("x")); - device_log("scheduler: exit coro {}, a = {}", i, a); - }; - } - $default { - device_log("scheduler: terminate"); - $return(); - }; - }; - }; - }; - - auto shader = device.compile(mega_kernel); - stream << shader().dispatch(1u) - << synchronize(); -} diff --git a/src/tests/coro/sdf_renderer.cpp b/src/tests/coro/sdf_renderer.cpp index edfb8f79e..cc8aba1f2 100644 --- a/src/tests/coro/sdf_renderer.cpp +++ b/src/tests/coro/sdf_renderer.cpp @@ -1,15 +1,10 @@ #include #include #include -#include using namespace luisa; using namespace luisa::compute; -struct alignas(4) CoroFrame {}; - -LUISA_COROFRAME_STRUCT(CoroFrame) {}; - int main(int argc, char *argv[]) { Context context{argv[0]}; @@ -124,9 +119,9 @@ int main(int argc, char *argv[]) { }; }; - Coroutine coro = [&](Var &, ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) noexcept { + coroutine::Coroutine coro = [&](ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) noexcept { Float2 resolution = make_float2(width, height); - UInt2 coord = make_uint2(coro_id().x / height, coro_id().x % height); + UInt2 coord = dispatch_id().xy(); $if (frame_index == 0u) { seed_image.write(coord, make_uint4(tea(coord.x, coord.y))); @@ -184,26 +179,14 @@ int main(int argc, char *argv[]) { Image seed_image = device.create_image(PixelStorage::INT1, width, height); Image accum_image = device.create_image(PixelStorage::FLOAT4, width, height); - Kernel2D mega_kernel = [&](ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) { - Var frame; - initialize_coroframe(frame, dispatch_id()); - coro(frame, seed_image, accum_image, frame_index); - $loop { - auto token = read_promise(frame, "coro_token"); - $switch (token) { - for (auto i = 1u; i <= coro.suspend_count(); i++) { - $case (i) { - coro[i](frame, seed_image, accum_image, frame_index); - }; - } - $default { - $return(); - }; - }; - }; - }; + coroutine::PersistentThreadsCoroSchedulerConfig config{ + .thread_count = 64_k, + .block_size = 64u, + .fetch_size = 3u, + .shared_memory_soa = false, + .global_memory_ext = false}; + coroutine::PersistentThreadsCoroScheduler scheduler{device, coro, config}; - auto shader = device.compile(mega_kernel); auto clear_shader = device.compile<2>([&] { auto coord = dispatch_id().xy(); accum_image->write(coord, make_float4(make_float3(0.0f), 1.0f)); @@ -211,37 +194,13 @@ int main(int argc, char *argv[]) { }); luisa::vector host_image(accum_image.view().size_bytes()); - //coro::SimpleCoroDispatcher Wdispatcher{&coro, device, resolution.x * resolution.y}; - // coro::WavefrontCoroDispatcherConfig wf_config{ - // .max_instance_count = resolution.x * resolution.y, - // .debug = false, - // }; - // coro::WavefrontCoroDispatcher Wdispatcher{&coro, device, stream, wf_config}; - coro::PersistentCoroDispatcherConfig config{ - .max_thread_count = 256u * 256u, - .block_size = 128u, - .fetch_size = 3u, - .debug = false, - }; - coro::PersistentCoroDispatcher Wdispatcher{&coro, device, stream, config}; - stream << clear_shader().dispatch(resolution) - << synchronize(); - /*for (auto i = 0u; i < 100u; i++) { - stream << shader(seed_image, accum_image, i).dispatch(resolution); - }*/ + stream << clear_shader().dispatch(resolution) << synchronize(); Clock clk; - auto samples = 1000; - //Wdispatcher(seed_image, accum_image, 0u, resolution.x * resolution.y); - //stream << Wdispatcher.await_step(); - //stream << Wdispatcher.await_step(); - //stream << Wdispatcher.await_step(); - //stream << Wdispatcher.await_step(); - //return 0; + auto samples = 1024; for (auto i = 0u; i < samples; ++i) { LUISA_INFO("spp {}", i); - Wdispatcher(seed_image, accum_image, i, resolution.x * resolution.y); - stream << Wdispatcher.await_all(); - // << synchronize(); + stream << scheduler(seed_image, accum_image, i).dispatch(width, height) + << synchronize(); } stream << synchronize(); auto dt = clk.toc(); diff --git a/src/tests/coro/sdf_renderer_wo_dispatcher.cpp b/src/tests/coro/sdf_renderer_wo_dispatcher.cpp deleted file mode 100644 index 2462810a0..000000000 --- a/src/tests/coro/sdf_renderer_wo_dispatcher.cpp +++ /dev/null @@ -1,229 +0,0 @@ -#include -#include -#include - -using namespace luisa; -using namespace luisa::compute; - -struct alignas(4) CoroFrame {}; - -LUISA_COROFRAME_STRUCT(CoroFrame){}; - -int main(int argc, char *argv[]) { - - Context context{argv[0]}; - if (argc <= 1) { exit(1); } - - static constexpr uint width = 1280u; - static constexpr uint height = 720u; - - static constexpr int max_ray_depth = 6; - static constexpr float eps = 1e-4f; - static constexpr float inf = 1e10f; - static constexpr float fov = 0.23f; - static constexpr float dist_limit = 100.0f; - static constexpr float3 camera_pos = make_float3(0.0f, 0.32f, 3.7f); - static constexpr float3 light_pos = make_float3(-1.5f, 0.6f, 0.3f); - static constexpr float3 light_normal = make_float3(1.0f, 0.0f, 0.0f); - static constexpr float light_radius = 2.0f; - - Clock clock; - - Callable intersect_light = [](Float3 pos, Float3 d) noexcept { - Float cos_w = dot(-d, light_normal); - Float dist = dot(d, light_pos - pos); - Float D = dist / cos_w; - Float dist_to_center = distance_squared(light_pos, pos + D * d); - Bool valid = cos_w > 0.0f & dist > 0.0f & dist_to_center < light_radius * light_radius; - return ite(valid, D, inf); - }; - - Callable tea = [](UInt v0, UInt v1) noexcept { - Var s0 = 0u; - for (uint n = 0u; n < 4u; n++) { - s0 += 0x9e3779b9u; - v0 += ((v1 << 4) + 0xa341316cu) ^ (v1 + s0) ^ ((v1 >> 5u) + 0xc8013ea4u); - v1 += ((v0 << 4) + 0xad90777du) ^ (v0 + s0) ^ ((v0 >> 5u) + 0x7e95761eu); - } - return v0; - }; - - Callable rand = [](UInt &state) noexcept { - constexpr uint lcg_a = 1664525u; - constexpr uint lcg_c = 1013904223u; - state = lcg_a * state + lcg_c; - return cast(state) / cast(std::numeric_limits::max()); - }; - - Callable out_dir = [&rand](Float3 n, UInt &seed) noexcept { - Float3 u = ite( - abs(n.y) < 1.0f - eps, - normalize(cross(n, make_float3(0.0f, 1.0f, 0.0f))), - make_float3(1.f, 0.f, 0.f)); - Float3 v = cross(n, u); - Float phi = 2.0f * pi * rand(seed); - Float ay = sqrt(rand(seed)); - Float ax = sqrt(1.0f - ay * ay); - return ax * (cos(phi) * u + sin(phi) * v) + ay * n; - }; - - Callable make_nested = [](Float f) noexcept { - static constexpr float freq = 40.0f; - f *= freq; - f = ite(f < 0.f, ite(cast(f) % 2 == 0, 1.f - fract(f), fract(f)), f); - return (f - 0.2f) * (1.0f / freq); - }; - - Callable sdf = [&make_nested](Float3 o) noexcept { - Float wall = min(o.y + 0.1f, o.z + 0.4f); - Float sphere = distance(o, make_float3(0.0f, 0.35f, 0.0f)) - 0.36f; - Float3 q = abs(o - make_float3(0.8f, 0.3f, 0.0f)) - 0.3f; - Float box = length(max(q, 0.0f)) + min(max(max(q.x, q.y), q.z), 0.0f); - Float3 O = o - make_float3(-0.8f, 0.3f, 0.0f); - Float2 d = make_float2(length(make_float2(O.x, O.z)) - 0.3f, abs(O.y) - 0.3f); - Float cylinder = min(max(d.x, d.y), 0.0f) + length(max(d, 0.0f)); - Float geometry = make_nested(min(min(sphere, box), cylinder)); - Float g = max(geometry, -(0.32f - (o.y * 0.6f + o.z * 0.8f))); - return min(wall, g); - }; - - Callable ray_march = [&sdf](Float3 p, Float3 d) noexcept { - Float dist = def(0.0f); - $for (j, 100) { - Float s = sdf(p + dist * d); - $if (s <= 1e-6f | dist >= inf) { $break; }; - dist += s; - }; - return min(dist, inf); - }; - - Callable sdf_normal = [&sdf](Float3 p) noexcept { - static constexpr float d = 1e-3f; - Float3 n = def(make_float3()); - Float sdf_center = sdf(p); - for (uint i = 0; i < 3; i++) { - Float3 inc = p; - inc[i] += d; - n[i] = (1.0f / d) * (sdf(inc) - sdf_center); - } - return normalize(n); - }; - - Callable next_hit = [&ray_march, &sdf_normal](Float &closest, Float3 &normal, Float3 &c, Float3 pos, Float3 d) noexcept { - closest = inf; - normal = make_float3(); - c = make_float3(); - Float ray_march_dist = ray_march(pos, d); - $if (ray_march_dist < min(dist_limit, closest)) { - closest = ray_march_dist; - Float3 hit_pos = pos + d * closest; - normal = sdf_normal(hit_pos); - Int t = cast((hit_pos.x + 10.0f) * 1.1f + 0.5f) % 3; - c = make_float3(0.4f) + make_float3(0.3f, 0.2f, 0.3f) * ite(t == make_int3(0, 1, 2), 1.0f, 0.0f); - }; - }; - - Coroutine coro = [&](Var &, ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) noexcept { - Float2 resolution = make_float2(width, height); - UInt2 coord = make_uint2(coro_id().x / height, coro_id().x % height); - - $if (frame_index == 0u) { - seed_image.write(coord, make_uint4(tea(coord.x, coord.y))); - accum_image.write(coord, make_float4(make_float3(0.0f), 1.0f)); - }; - - Float aspect_ratio = resolution.x / resolution.y; - Float3 pos = def(camera_pos); - UInt seed = seed_image.read(coord).x; - Float ux = rand(seed); - Float uy = rand(seed); - Float2 uv = make_float2(coord.x + ux, height - 1u - coord.y + uy); - Float3 d = make_float3( - 2.0f * fov * uv / resolution.y - fov * make_float2(aspect_ratio, 1.0f) - 1e-5f, -1.0f); - d = normalize(d); - Float3 throughput = def(make_float3(1.0f, 1.0f, 1.0f)); - Float hit_light = def(0.0f); - $for (depth, max_ray_depth) { - - $suspend("2"); - Float closest = def(0.0f); - Float3 normal = def(make_float3()); - Float3 c = def(make_float3()); - next_hit(closest, normal, c, pos, d); - Float dist_to_light = intersect_light(pos, d); - $if (dist_to_light < closest) { - // $if (depth == 0) { - // device_log("xxxxxcoord {} {}", coord.x, coord.y); - // }; - hit_light = 1.0f; - $break; - }; - $if (length_squared(normal) == 0.0f) { - // $if (depth == 0) { - // device_log("coord {} {}", coord.x, coord.y); - // }; - $break; - }; - Float3 hit_pos = pos + closest * d; - d = out_dir(normal, seed); - pos = hit_pos + 1e-4f * d; - throughput *= c; - }; - $suspend("3"); - - Float3 accum_color = lerp(accum_image.read(coord).xyz(), throughput.xyz() * hit_light, 1.0f / (frame_index + 1.0f)); - accum_image.write(coord, make_float4(accum_color, 1.0f)); - //$suspend("4"); - seed_image.write(coord, make_uint4(seed)); - }; - - auto device = context.create_device(argv[1]); - auto stream = device.create_stream(); - constexpr auto resolution = make_uint2(width, height); - constexpr auto spp = 1000u; - - Image seed_image = device.create_image(PixelStorage::INT1, width, height); - Image accum_image = device.create_image(PixelStorage::FLOAT4, width, height); - - Kernel1D mega_kernel = [&](ImageUInt seed_image, ImageFloat accum_image, UInt frame_index) { - Var frame; - initialize_coroframe(frame, dispatch_id()); - coro(frame, seed_image, accum_image, frame_index); - $loop { - auto token = read_promise(frame, "coro_token"); - $switch (token) { - for (auto i = 1u; i <= coro.suspend_count(); i++) { - $case (i) { - coro[i](frame, seed_image, accum_image, frame_index); - }; - } - $default { - $return(); - }; - }; - }; - }; - - auto shader = device.compile(mega_kernel); - auto clear_shader = device.compile<2>([&] { - auto coord = dispatch_id().xy(); - accum_image->write(coord, make_float4(make_float3(0.0f), 1.0f)); - seed_image->write(coord, make_uint4(coord.y * resolution.y + coord.x)); - }); - - luisa::vector host_image(accum_image.view().size_bytes()); - stream << clear_shader().dispatch(resolution) - << synchronize(); - Clock clk; - for (auto i = 0u; i < spp; i++) { - LUISA_INFO("spp {}", i); - stream << shader(seed_image, accum_image, i).dispatch(resolution.x * resolution.y); - } - stream << synchronize(); - auto dt = clk.toc(); - LUISA_INFO("Time: {} ms ({} spp/s)", dt, 1e6 / dt); - - stream << accum_image.copy_to(host_image.data()) - << synchronize(); - stbi_write_hdr("test_sdf_wo_dispatcher.hdr", resolution.x, resolution.y, 4, reinterpret_cast(host_image.data())); -} diff --git a/src/tests/test_skyline.cpp b/src/tests/test_skyline.cpp index 72ca9dbd7..1eab730d9 100644 --- a/src/tests/test_skyline.cpp +++ b/src/tests/test_skyline.cpp @@ -8,13 +8,12 @@ #include #include #include -#include + using namespace luisa; using namespace luisa::compute; -struct alignas(4) Skyline { -}; + const bool SHOW = true; -LUISA_COROFRAME_STRUCT(Skyline) {}; + int main(int argc, char *argv[]) { Context context{argv[0]}; @@ -824,7 +823,7 @@ int main(int argc, char *argv[]) { image.write(xy, make_float4(sqrt(RayTrace(make_float2(xy), time)), 1.0f)); }; is_coroutine = true; - Coroutine coro = [&](Var &frame, Float time) noexcept { + Kernel2D coro = [&](Float time) noexcept { auto xy = make_uint2(coro_id().x % width, coro_id().x / width); device_image->write(xy, make_float4(sqrt(RayTrace(make_float2(xy), time)), 1.0f)); }; diff --git a/src/tests/test_sort.cpp b/src/tests/test_sort.cpp index 9a8a5382c..e43dadb7f 100644 --- a/src/tests/test_sort.cpp +++ b/src/tests/test_sort.cpp @@ -3,7 +3,6 @@ /// reference: [Onesweep: A Faster Least Significant Digit Radix Sort for GPUs] /// https://arxiv.org/abs/2206.01784 #include -#include #include using namespace luisa; using namespace luisa::compute;