diff --git a/include/luisa/coro/schedulers/persistent_threads.h b/include/luisa/coro/schedulers/persistent_threads.h index e0c647980..7d777e2b4 100644 --- a/include/luisa/coro/schedulers/persistent_threads.h +++ b/include/luisa/coro/schedulers/persistent_threads.h @@ -14,7 +14,7 @@ namespace luisa::compute::coroutine { struct PersistentThreadsCoroSchedulerConfig { uint thread_count = 64_k; uint block_size = 128; - uint fetch_size = 16; + uint fetch_size = 4; bool shared_memory_soa = false; bool global_memory_ext = false; }; @@ -23,7 +23,7 @@ namespace detail { LC_CORO_API void persistent_threads_coro_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, - const CoroGraph *graph, Shared &frames, Expr dispatch_shape, + const CoroGraph *graph, Shared &frames, Expr dispatch_size_prefix_product, Expr> global, const Buffer &global_frames, luisa::move_only_function call_subroutine) noexcept; }// namespace detail @@ -52,7 +52,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { auto global_ext_size = _config.thread_count * g_fac; _global_frames = device.create_buffer(coro.shared_frame(), global_ext_size); } - Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_shape, Var... args) noexcept { + Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_size_prefix_product, Var... args) noexcept { set_block_size(_config.block_size, 1u, 1u); auto global_queue_size = _config.block_size * g_fac; auto shared_queue_size = _config.block_size * q_fac; @@ -60,7 +60,7 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { Shared frames{graph->shared_frame(), shared_queue_size, _config.shared_memory_soa}; detail::persistent_threads_coro_scheduler_main_kernel_impl( _config, q_fac, g_fac, shared_queue_size, global_queue_size, graph, - frames, dispatch_shape, global, _global_frames, call_subroutine); + frames, dispatch_size_prefix_product, global, _global_frames, call_subroutine); }; _pt_shader = device.compile(main_kernel); _clear_shader = device.compile<1>([](BufferUInt global) { @@ -78,12 +78,19 @@ class PersistentThreadsCoroScheduler : public CoroScheduler { void _dispatch(Stream &stream, uint3 dispatch_size, compute::detail::prototype_to_shader_invocation_t... args) noexcept override { - stream << _clear_shader(_global).dispatch(1u); + auto dispatch_size_prefix_product = make_uint3( + dispatch_size.x, + dispatch_size.x * dispatch_size.y, + dispatch_size.x * dispatch_size.y * dispatch_size.z); if (_config.global_memory_ext) { auto n = static_cast(_global_frames.size()); - stream << _initialize_shader(n).dispatch(n); + stream << _clear_shader(_global).dispatch(1u) + << _initialize_shader(n).dispatch(n) + << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count); + } else { + stream << _clear_shader(_global).dispatch(1u) + << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count); } - stream << _pt_shader(_global, dispatch_size, args...).dispatch(_config.thread_count); } public: diff --git a/src/coro/schedulers/persistent_threads.cpp b/src/coro/schedulers/persistent_threads.cpp index 9b7ac37c0..81302e9b1 100644 --- a/src/coro/schedulers/persistent_threads.cpp +++ b/src/coro/schedulers/persistent_threads.cpp @@ -13,7 +13,7 @@ namespace luisa::compute::coroutine::detail { void persistent_threads_coro_scheduler_main_kernel_impl( const PersistentThreadsCoroSchedulerConfig &config, uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size, - const CoroGraph *graph, Shared &frames, Expr dispatch_shape, + const CoroGraph *graph, Shared &frames, Expr dispatch_size_prefix_product, Expr> global, const Buffer &global_frames, luisa::move_only_function call_subroutine) noexcept { @@ -26,16 +26,17 @@ void persistent_threads_coro_scheduler_main_kernel_impl( shared_queue_size}; Shared workload{2}; Shared work_stat{2};// work_state[0] for max_count, [1] for max_id - for (auto index : dsl::dynamic_range(q_fac)) { + for (auto index = 0u; index < q_fac; index++) { auto s = index * config.block_size + thread_x(); all_token[s] = 0u; if (config.shared_memory_soa) { - frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{1u}); + frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{0u, 1u}); } else { + frames[s].coro_id = make_uint3(); frames[s].target_token = 0u; } } - for (auto index : dsl::dynamic_range(g_fac)) { + for (auto index = 0u; index < g_fac; index++) { auto s = index * config.block_size + thread_x(); all_token[shared_queue_size + s] = 0u; } @@ -58,8 +59,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( rem_local[0] = 0u; sync_block(); auto count = def(0u); - auto count_limit = def(-1); - auto dispatch_size = dispatch_shape.x * dispatch_shape.y * dispatch_shape.z; + auto count_limit = def(std::numeric_limits::max()); $while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) { sync_block();//very important, synchronize for condition rem_local[0] = 0u; @@ -70,8 +70,8 @@ void persistent_threads_coro_scheduler_main_kernel_impl( $if (thread_x() == config.block_size - 1) { $if (workload[0] >= workload[1] & rem_global[0] == 1u) {//fetch new workload workload[0] = global.atomic(0u).fetch_add(config.block_size * config.fetch_size); - workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size); - $if (workload[0] >= dispatch_size) { + workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size_prefix_product.z); + $if (workload[0] >= dispatch_size_prefix_product.z) { rem_global[0] = 0u; }; }; @@ -96,7 +96,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( work_offset[1] = 0; sync_block(); if (!config.global_memory_ext) { - for (auto index : dsl::dynamic_range(q_fac)) {//collect indices + for (auto index = 0u; index < q_fac; index++) {//collect indices auto frame_token = all_token[index * config.block_size + thread_x()]; $if (frame_token == work_stat[1]) { auto id = work_offset.atomic(0).fetch_add(1u); @@ -104,7 +104,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( }; } } else { - for (auto index : dsl::dynamic_range(q_fac)) {//collect switch out indices + for (auto index = 0u; index < q_fac; index++) {//collect switch out indices auto frame_token = all_token[index * config.block_size + thread_x()]; $if (frame_token != work_stat[1]) { auto id = work_offset.atomic(0).fetch_add(1u); @@ -113,7 +113,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( } sync_block(); $if (shared_queue_size - work_offset[0] < config.block_size) {//no enough work - for (auto index : dsl::dynamic_range(g_fac)) { //swap frames + $for (index, 0u, g_fac) { //swap frames auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x(); auto g_queue_id = index * config.block_size + thread_x(); auto coro_token = all_token[shared_queue_size + g_queue_id]; @@ -123,76 +123,62 @@ void persistent_threads_coro_scheduler_main_kernel_impl( auto dst = path_id[id]; auto frame_token = all_token[dst]; $if (coro_token != 0u) { + auto g_state = global_frames->read(global_id); $if (frame_token != 0u) { if (config.shared_memory_soa) { - auto g_state = global_frames->read(global_id); global_frames->write(global_id, frames.read(dst)); - frames.write(dst, g_state); } else { - auto g_state = global_frames->read(global_id); global_frames->write(global_id, frames[dst]); - frames[dst] = g_state; } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - } - $else { - if (config.shared_memory_soa) { - auto g_state = global_frames->read(global_id); - frames.write(dst, g_state); - } else { - frames[dst] = global_frames->read(global_id); - } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; }; + if (config.shared_memory_soa) { + frames.write(dst, g_state); + } else { + frames[dst] = g_state; + } + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; } - $else { - $if (frame_token != 0u) { - if (config.shared_memory_soa) { - auto frame = frames.read(dst); - global_frames->write(global_id, frame); - } else { - global_frames->write(global_id, frames[dst]); - } - all_token[shared_queue_size + g_queue_id] = frame_token; - all_token[dst] = coro_token; - }; + $elif (frame_token != 0u) { + if (config.shared_memory_soa) { + auto frame = frames.read(dst); + global_frames->write(global_id, frame); + } else { + global_frames->write(global_id, frames[dst]); + } + all_token[shared_queue_size + g_queue_id] = frame_token; + all_token[dst] = coro_token; }; }; }; - } + }; }; } auto gen_st = workload[0]; sync_block(); - auto pid = def(0u); - if (config.global_memory_ext) { - pid = thread_x(); - } else { - pid = path_id[thread_x()]; - } - auto launch_condition = def(true); - if (!config.global_memory_ext) { - launch_condition = (thread_x() < work_offset[0]); - } else { - launch_condition = (all_token[pid] == work_stat[1]); - } - $if (launch_condition) { + auto pid = config.global_memory_ext ? thread_x() : path_id[thread_x()]; + $if (config.global_memory_ext ? (all_token[pid] == work_stat[1]) : (thread_x() < work_offset[0])) { $switch (all_token[pid]) { $case (0u) { $if (gen_st + thread_x() < workload[1]) { work_counter.atomic(0u).fetch_sub(1u); auto global_index = gen_st + thread_x(); - auto image_size = dispatch_shape.x * dispatch_shape.y; - auto index_z = global_index / image_size; - auto index_xy = global_index % image_size; - auto index_x = index_xy % dispatch_shape.x; - auto index_y = index_xy / dispatch_shape.x; - auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); - call_subroutine(frame, coro_token_entry); - auto next = frame.target_token & coro_token_valid_mask; - frames.write(pid, frame, graph->entry().output_fields()); + auto index_z = global_index / dispatch_size_prefix_product.y; + auto index_xy = global_index % dispatch_size_prefix_product.y; + auto index_x = index_xy % dispatch_size_prefix_product.x; + auto index_y = index_xy / dispatch_size_prefix_product.x; + auto next = def(0u); + if (config.shared_memory_soa) { + auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z)); + call_subroutine(frame, coro_token_entry); + next = frame.target_token & coro_token_valid_mask; + frames.write(pid, frame, graph->entry().output_fields()); + } else { + frames[pid].coro_id = make_uint3(index_x, index_y, index_z); + frames[pid].target_token = coro_token_entry; + call_subroutine(frames[pid], coro_token_entry); + next = frames[pid].target_token & coro_token_valid_mask; + } all_token[pid] = next; work_counter.atomic(next).fetch_add(1u); workload.atomic(0).fetch_add(1u); @@ -215,6 +201,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl( work_counter.atomic(next).fetch_add(1u); }; } + $default { dsl::unreachable(); }; }; }; sync_block();