Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
trying to optimize dx
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed May 16, 2024
1 parent 1161b00 commit b0d90b8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 68 deletions.
21 changes: 14 additions & 7 deletions include/luisa/coro/schedulers/persistent_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace luisa::compute::coroutine {
/// Tuning knobs for the persistent-threads coroutine scheduler.
/// NOTE(review): the diff view collapsed the old (16) and new (4) values of
/// `fetch_size` into two duplicate member declarations, which cannot compile;
/// resolved here to the post-commit value.
struct PersistentThreadsCoroSchedulerConfig {
    uint thread_count = 64_k;        // total persistent threads launched
    uint block_size = 128;           // threads per block (shared queue stride)
    uint fetch_size = 4;             // workload items fetched per thread per grab
    bool shared_memory_soa = false;  // store shared-memory frames as SoA
    bool global_memory_ext = false;  // spill extra frames to a global buffer
};
Expand All @@ -23,7 +23,7 @@ namespace detail {
LC_CORO_API void persistent_threads_coro_scheduler_main_kernel_impl(
const PersistentThreadsCoroSchedulerConfig &config,
uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size,
const CoroGraph *graph, Shared<CoroFrame> &frames, Expr<uint3> dispatch_shape,
const CoroGraph *graph, Shared<CoroFrame> &frames, Expr<uint3> dispatch_size_prefix_product,
Expr<Buffer<uint>> global, const Buffer<CoroFrame> &global_frames,
luisa::move_only_function<void(CoroFrame &, CoroToken)> call_subroutine) noexcept;
}// namespace detail
Expand Down Expand Up @@ -52,15 +52,15 @@ class PersistentThreadsCoroScheduler : public CoroScheduler<Args...> {
auto global_ext_size = _config.thread_count * g_fac;
_global_frames = device.create_buffer<CoroFrame>(coro.shared_frame(), global_ext_size);
}
Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_shape, Var<Args>... args) noexcept {
Kernel1D main_kernel = [this, q_fac, g_fac, &coro, graph = coro.graph()](BufferUInt global, UInt3 dispatch_size_prefix_product, Var<Args>... args) noexcept {
set_block_size(_config.block_size, 1u, 1u);
auto global_queue_size = _config.block_size * g_fac;
auto shared_queue_size = _config.block_size * q_fac;
auto call_subroutine = [&](CoroFrame &frame, CoroToken token) noexcept { coro[token](frame, args...); };
Shared<CoroFrame> frames{graph->shared_frame(), shared_queue_size, _config.shared_memory_soa};
detail::persistent_threads_coro_scheduler_main_kernel_impl(
_config, q_fac, g_fac, shared_queue_size, global_queue_size, graph,
frames, dispatch_shape, global, _global_frames, call_subroutine);
frames, dispatch_size_prefix_product, global, _global_frames, call_subroutine);
};
_pt_shader = device.compile(main_kernel);
_clear_shader = device.compile<1>([](BufferUInt global) {
Expand All @@ -78,12 +78,19 @@ class PersistentThreadsCoroScheduler : public CoroScheduler<Args...> {

/// Enqueues one scheduler run onto `stream`.
/// Resets the global work counter, (optionally) initializes the global frame
/// spill buffer, then launches the persistent-threads kernel.
/// NOTE(review): the diff view left both the removed pre-commit statements
/// (a separate `_clear_shader` dispatch and a trailing `_pt_shader` dispatch
/// taking the raw `dispatch_size`) and the added post-commit chain in the
/// body, double-dispatching both shaders; resolved to the post-commit form.
void _dispatch(Stream &stream, uint3 dispatch_size,
               compute::detail::prototype_to_shader_invocation_t<Args>... args) noexcept override {
    // Prefix products let the kernel decompose a flat index into (x, y, z)
    // without re-multiplying the extents; .z is the total dispatch size.
    auto dispatch_size_prefix_product = make_uint3(
        dispatch_size.x,
        dispatch_size.x * dispatch_size.y,
        dispatch_size.x * dispatch_size.y * dispatch_size.z);
    if (_config.global_memory_ext) {
        // Global spill buffer in use: its frames must be initialized first.
        auto n = static_cast<uint>(_global_frames.size());
        stream << _clear_shader(_global).dispatch(1u)
               << _initialize_shader(n).dispatch(n)
               << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count);
    } else {
        stream << _clear_shader(_global).dispatch(1u)
               << _pt_shader(_global, dispatch_size_prefix_product, args...).dispatch(_config.thread_count);
    }
}

public:
Expand Down
109 changes: 48 additions & 61 deletions src/coro/schedulers/persistent_threads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace luisa::compute::coroutine::detail {
void persistent_threads_coro_scheduler_main_kernel_impl(
const PersistentThreadsCoroSchedulerConfig &config,
uint q_fac, uint g_fac, uint shared_queue_size, uint global_queue_size,
const CoroGraph *graph, Shared<CoroFrame> &frames, Expr<uint3> dispatch_shape,
const CoroGraph *graph, Shared<CoroFrame> &frames, Expr<uint3> dispatch_size_prefix_product,
Expr<Buffer<uint>> global, const Buffer<CoroFrame> &global_frames,
luisa::move_only_function<void(CoroFrame &, CoroToken)> call_subroutine) noexcept {

Expand All @@ -26,16 +26,17 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
shared_queue_size};
Shared<uint> workload{2};
Shared<uint> work_stat{2};// work_state[0] for max_count, [1] for max_id
for (auto index : dsl::dynamic_range(q_fac)) {
for (auto index = 0u; index < q_fac; index++) {
auto s = index * config.block_size + thread_x();
all_token[s] = 0u;
if (config.shared_memory_soa) {
frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{1u});
frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{0u, 1u});
} else {
frames[s].coro_id = make_uint3();
frames[s].target_token = 0u;
}
}
for (auto index : dsl::dynamic_range(g_fac)) {
for (auto index = 0u; index < g_fac; index++) {
auto s = index * config.block_size + thread_x();
all_token[shared_queue_size + s] = 0u;
}
Expand All @@ -58,8 +59,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
rem_local[0] = 0u;
sync_block();
auto count = def(0u);
auto count_limit = def<uint>(-1);
auto dispatch_size = dispatch_shape.x * dispatch_shape.y * dispatch_shape.z;
auto count_limit = def(std::numeric_limits<uint>::max());
$while ((rem_global[0] != 0u | rem_local[0] != 0u) & (count != count_limit)) {
sync_block();//very important, synchronize for condition
rem_local[0] = 0u;
Expand All @@ -70,8 +70,8 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
$if (thread_x() == config.block_size - 1) {
$if (workload[0] >= workload[1] & rem_global[0] == 1u) {//fetch new workload
workload[0] = global.atomic(0u).fetch_add(config.block_size * config.fetch_size);
workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size);
$if (workload[0] >= dispatch_size) {
workload[1] = min(workload[0] + config.block_size * config.fetch_size, dispatch_size_prefix_product.z);
$if (workload[0] >= dispatch_size_prefix_product.z) {
rem_global[0] = 0u;
};
};
Expand All @@ -96,15 +96,15 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
work_offset[1] = 0;
sync_block();
if (!config.global_memory_ext) {
for (auto index : dsl::dynamic_range(q_fac)) {//collect indices
for (auto index = 0u; index < q_fac; index++) {//collect indices
auto frame_token = all_token[index * config.block_size + thread_x()];
$if (frame_token == work_stat[1]) {
auto id = work_offset.atomic(0).fetch_add(1u);
path_id[id] = index * config.block_size + thread_x();
};
}
} else {
for (auto index : dsl::dynamic_range(q_fac)) {//collect switch out indices
for (auto index = 0u; index < q_fac; index++) {//collect switch out indices
auto frame_token = all_token[index * config.block_size + thread_x()];
$if (frame_token != work_stat[1]) {
auto id = work_offset.atomic(0).fetch_add(1u);
Expand All @@ -113,7 +113,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
}
sync_block();
$if (shared_queue_size - work_offset[0] < config.block_size) {//no enough work
for (auto index : dsl::dynamic_range(g_fac)) { //swap frames
$for (index, 0u, g_fac) { //swap frames
auto global_id = block_x() * global_queue_size + index * config.block_size + thread_x();
auto g_queue_id = index * config.block_size + thread_x();
auto coro_token = all_token[shared_queue_size + g_queue_id];
Expand All @@ -123,76 +123,62 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
auto dst = path_id[id];
auto frame_token = all_token[dst];
$if (coro_token != 0u) {
auto g_state = global_frames->read(global_id);
$if (frame_token != 0u) {
if (config.shared_memory_soa) {
auto g_state = global_frames->read(global_id);
global_frames->write(global_id, frames.read(dst));
frames.write(dst, g_state);
} else {
auto g_state = global_frames->read(global_id);
global_frames->write(global_id, frames[dst]);
frames[dst] = g_state;
}
all_token[shared_queue_size + g_queue_id] = frame_token;
all_token[dst] = coro_token;
}
$else {
if (config.shared_memory_soa) {
auto g_state = global_frames->read(global_id);
frames.write(dst, g_state);
} else {
frames[dst] = global_frames->read(global_id);
}
all_token[shared_queue_size + g_queue_id] = frame_token;
all_token[dst] = coro_token;
};
if (config.shared_memory_soa) {
frames.write(dst, g_state);
} else {
frames[dst] = g_state;
}
all_token[shared_queue_size + g_queue_id] = frame_token;
all_token[dst] = coro_token;
}
$else {
$if (frame_token != 0u) {
if (config.shared_memory_soa) {
auto frame = frames.read(dst);
global_frames->write(global_id, frame);
} else {
global_frames->write(global_id, frames[dst]);
}
all_token[shared_queue_size + g_queue_id] = frame_token;
all_token[dst] = coro_token;
};
$elif (frame_token != 0u) {
if (config.shared_memory_soa) {
auto frame = frames.read(dst);
global_frames->write(global_id, frame);
} else {
global_frames->write(global_id, frames[dst]);
}
all_token[shared_queue_size + g_queue_id] = frame_token;
all_token[dst] = coro_token;
};
};
};
}
};
};
}
auto gen_st = workload[0];
sync_block();
auto pid = def(0u);
if (config.global_memory_ext) {
pid = thread_x();
} else {
pid = path_id[thread_x()];
}
auto launch_condition = def(true);
if (!config.global_memory_ext) {
launch_condition = (thread_x() < work_offset[0]);
} else {
launch_condition = (all_token[pid] == work_stat[1]);
}
$if (launch_condition) {
auto pid = config.global_memory_ext ? thread_x() : path_id[thread_x()];
$if (config.global_memory_ext ? (all_token[pid] == work_stat[1]) : (thread_x() < work_offset[0])) {
$switch (all_token[pid]) {
$case (0u) {
$if (gen_st + thread_x() < workload[1]) {
work_counter.atomic(0u).fetch_sub(1u);
auto global_index = gen_st + thread_x();
auto image_size = dispatch_shape.x * dispatch_shape.y;
auto index_z = global_index / image_size;
auto index_xy = global_index % image_size;
auto index_x = index_xy % dispatch_shape.x;
auto index_y = index_xy / dispatch_shape.x;
auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z));
call_subroutine(frame, coro_token_entry);
auto next = frame.target_token & coro_token_valid_mask;
frames.write(pid, frame, graph->entry().output_fields());
auto index_z = global_index / dispatch_size_prefix_product.y;
auto index_xy = global_index % dispatch_size_prefix_product.y;
auto index_x = index_xy % dispatch_size_prefix_product.x;
auto index_y = index_xy / dispatch_size_prefix_product.x;
auto next = def(0u);
if (config.shared_memory_soa) {
auto frame = CoroFrame::create(graph->shared_frame(), make_uint3(index_x, index_y, index_z));
call_subroutine(frame, coro_token_entry);
next = frame.target_token & coro_token_valid_mask;
frames.write(pid, frame, graph->entry().output_fields());
} else {
frames[pid].coro_id = make_uint3(index_x, index_y, index_z);
frames[pid].target_token = coro_token_entry;
call_subroutine(frames[pid], coro_token_entry);
next = frames[pid].target_token & coro_token_valid_mask;
}
all_token[pid] = next;
work_counter.atomic(next).fetch_add(1u);
workload.atomic(0).fetch_add(1u);
Expand All @@ -215,6 +201,7 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
work_counter.atomic(next).fetch_add(1u);
};
}
$default { dsl::unreachable(); };
};
};
sync_block();
Expand Down

0 comments on commit b0d90b8

Please sign in to comment.