Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
support Shared<CoroFrame>::operator[] for non-SoA
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed May 15, 2024
1 parent f57ab51 commit 1161b00
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 32 deletions.
4 changes: 2 additions & 2 deletions include/luisa/coro/coro_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ class LC_CORO_API CoroFrame {

private:
luisa::shared_ptr<const CoroFrameDesc> _desc;
const RefExpr *_expression;
const Expression *_expression;

public:
Var<uint3> &coro_id;
Var<uint> &target_token;

public:
CoroFrame(luisa::shared_ptr<const CoroFrameDesc> desc, const RefExpr *expr) noexcept;
CoroFrame(luisa::shared_ptr<const CoroFrameDesc> desc, const Expression *expr) noexcept;
CoroFrame(CoroFrame &&another) noexcept;
CoroFrame(const CoroFrame &another) noexcept;
CoroFrame &operator=(const CoroFrame &rhs) noexcept;
Expand Down
1 change: 1 addition & 0 deletions include/luisa/coro/coro_frame_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class LC_CORO_API CoroFrameDesc : public luisa::enable_shared_from_this<CoroFram
public:
[[nodiscard]] static luisa::shared_ptr<CoroFrameDesc> create(const Type *type, DesignatedFieldDict m) noexcept;
[[nodiscard]] auto type() const noexcept { return _type; }
[[nodiscard]] const Type *field(uint index) const noexcept;
[[nodiscard]] auto &designated_fields() const noexcept { return _designated_fields; }
[[nodiscard]] uint designated_field(luisa::string_view name) const noexcept;
[[nodiscard]] luisa::string dump() const noexcept;
Expand Down
57 changes: 51 additions & 6 deletions include/luisa/coro/coro_frame_smem.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,24 @@

#pragma once

#include <luisa/core/stl/optional.h>
#include <luisa/dsl/shared.h>
#include <luisa/coro/coro_frame.h>

namespace luisa::compute {

namespace detail {
[[noreturn]] LC_CORO_API void error_coro_frame_smem_subscript_on_soa() noexcept;
}// namespace detail

template<>
class Shared<coroutine::CoroFrame> {

private:
luisa::shared_ptr<const coroutine::CoroFrameDesc> _desc;
luisa::vector<const RefExpr *> _expressions;
// we need to return a `CoroFrame &` for operator[] with proper lifetime
luisa::vector<luisa::unique_ptr<coroutine::CoroFrame>> _temp_frames;
size_t _size;

private:
Expand All @@ -39,8 +46,9 @@ class Shared<coroutine::CoroFrame> {
}

public:
Shared(luisa::shared_ptr<const coroutine::CoroFrameDesc> desc, size_t n,
bool soa = true, luisa::span<const uint> soa_excluded_fields = {}) noexcept
Shared(luisa::shared_ptr<const coroutine::CoroFrameDesc> desc,
size_t n, bool soa = false,
luisa::span<const uint> soa_excluded_fields = {}) noexcept
: _desc{std::move(desc)}, _size{n} { _create(soa, soa_excluded_fields); }

Shared(Shared &&) noexcept = default;
Expand All @@ -52,10 +60,11 @@ class Shared<coroutine::CoroFrame> {
[[nodiscard]] auto is_soa() const noexcept { return _expressions.size() > 1; }
[[nodiscard]] auto size() const noexcept { return _size; }

public:
private:
/// Read index with active fields
template<typename I>
[[nodiscard]] auto read(I &&index, luisa::optional<luisa::span<const uint>> active_fields = luisa::nullopt) const noexcept {
requires is_integral_expr_v<I>
[[nodiscard]] auto _read(I &&index, luisa::optional<luisa::span<const uint>> active_fields) const noexcept {
auto i = def(std::forward<I>(index));
auto fb = detail::FunctionBuilder::current();
auto frame = fb->local(_desc->type());
Expand All @@ -71,7 +80,7 @@ class Shared<coroutine::CoroFrame> {
if (m == 0u) {
auto t = Type::of<uint>();
auto s = fb->access(Type::array(t, 3u), _expressions[m], i.expression());
std::array<const Expression *, 3u> elems;
std::array<const Expression *, 3u> elems{};
elems[0] = fb->access(t, s, fb->literal(t, 0u));
elems[1] = fb->access(t, s, fb->literal(t, 1u));
elems[2] = fb->access(t, s, fb->literal(t, 2u));
Expand All @@ -88,7 +97,8 @@ class Shared<coroutine::CoroFrame> {

/// Write index with active fields
template<typename I>
void write(I &&index, const coroutine::CoroFrame &frame, luisa::optional<luisa::span<const uint>> active_fields = luisa::nullopt) const noexcept {
requires is_integral_expr_v<I>
void _write(I &&index, const coroutine::CoroFrame &frame, luisa::optional<luisa::span<const uint>> active_fields) const noexcept {
auto i = def(std::forward<I>(index));
auto fb = detail::FunctionBuilder::current();
if (!is_soa()) {
Expand All @@ -115,6 +125,41 @@ class Shared<coroutine::CoroFrame> {
}
}
}

public:
/// Subscript access to the element at `index` (non-SoA layout only).
///
/// Builds an access expression into `_expressions[0]` at `index` and wraps it
/// in a heap-allocated `CoroFrame` that is stored in `_temp_frames`; the
/// returned reference points at that stored wrapper rather than at a local.
/// When `is_soa()` is true, `detail::error_coro_frame_smem_subscript_on_soa()`
/// is invoked instead (declared `[[noreturn]]`).
template<typename I>
    requires is_integral_expr_v<I>
[[nodiscard]] coroutine::CoroFrame &operator[](I &&index) noexcept {
    if (is_soa()) { detail::error_coro_frame_smem_subscript_on_soa(); }
    auto builder = detail::FunctionBuilder::current();
    auto element = builder->access(
        _desc->type(), _expressions[0],
        detail::extract_expression(std::forward<I>(index)));
    // Park the wrapper in _temp_frames so the reference handed out below
    // refers to an object owned by this Shared instance.
    _temp_frames.emplace_back(luisa::make_unique<coroutine::CoroFrame>(_desc, element));
    return *_temp_frames.back();
}

/// Read the frame stored at `index`; forwards to `_read` with an empty
/// (nullopt) active-field selection.
template<typename I>
    requires is_integral_expr_v<I>
[[nodiscard]] coroutine::CoroFrame read(I &&index) const noexcept {
    luisa::optional<luisa::span<const uint>> no_selection = luisa::nullopt;
    return _read(std::forward<I>(index), no_selection);
}

/// Read the frame stored at `index`; forwards to `_read` with
/// `active_fields` as the active-field selection.
template<typename I>
    requires is_integral_expr_v<I>
[[nodiscard]] coroutine::CoroFrame read(I &&index, luisa::span<const uint> active_fields) const noexcept {
    luisa::optional<luisa::span<const uint>> selection{active_fields};
    return _read(std::forward<I>(index), selection);
}
/// Store `frame` at `index`; forwards to `_write` with an empty (nullopt)
/// active-field selection.
template<typename I>
    requires is_integral_expr_v<I>
void write(I &&index, const coroutine::CoroFrame &frame) const noexcept {
    luisa::optional<luisa::span<const uint>> no_selection = luisa::nullopt;
    _write(std::forward<I>(index), frame, no_selection);
}

/// Store `frame` at `index`; forwards to `_write` with `active_fields` as
/// the active-field selection.
template<typename I>
    requires is_integral_expr_v<I>
void write(I &&index, const coroutine::CoroFrame &frame, luisa::span<const uint> active_fields) const noexcept {
    luisa::optional<luisa::span<const uint>> selection{active_fields};
    _write(std::forward<I>(index), frame, selection);
}
};

}// namespace luisa::compute
13 changes: 4 additions & 9 deletions include/luisa/coro/coro_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,9 @@ class LC_CORO_API CoroAwaiter : public concepts::Noncopyable {
using Await = luisa::move_only_function<void()>;
Await _await;

private:
template<typename T>
friend class Coroutine;
public:
explicit CoroAwaiter(Await await) noexcept
: _await{std::move(await)} {}

public:
void await() && noexcept { _await(); }
};

Expand Down Expand Up @@ -165,11 +161,10 @@ class GeneratorIter : public concepts::Noncopyable {
using Resume = luisa::move_only_function<void(CoroFrame &, CoroToken)>;
Resume _resume;

private:
template<typename U>
friend class Generator;
public:
GeneratorIter(uint n, CoroFrame frame, Resume resume) noexcept
: _n{n}, _frame{std::move(frame)}, _resume{std::move(resume)} {}
: _n{n}, _frame{std::move(frame)},
_resume{std::move(resume)} {}

public:
auto &update() noexcept {
Expand Down
2 changes: 1 addition & 1 deletion include/luisa/coro/schedulers/persistent_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct PersistentThreadsCoroSchedulerConfig {
uint thread_count = 64_k;
uint block_size = 128;
uint fetch_size = 16;
bool shared_memory_soa = true;
bool shared_memory_soa = false;
bool global_memory_ext = false;
};

Expand Down
1 change: 1 addition & 0 deletions src/coro/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ set(LUISA_COMPUTE_CORO_SOURCES
coro_frame.cpp
coro_frame_buffer.cpp
coro_frame_desc.cpp
coro_frame_smem.cpp
coro_func.cpp
coro_graph.cpp
schedulers/persistent_threads.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/coro/coro_frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace luisa::compute::coroutine {

CoroFrame::CoroFrame(luisa::shared_ptr<const CoroFrameDesc> desc, const RefExpr *expr) noexcept
CoroFrame::CoroFrame(luisa::shared_ptr<const CoroFrameDesc> desc, const Expression *expr) noexcept
: _desc{std::move(desc)},
_expression{expr},
coro_id{this->get<uint3>(0u)},
Expand Down
6 changes: 6 additions & 0 deletions src/coro/coro_frame_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ luisa::shared_ptr<CoroFrameDesc> CoroFrameDesc::create(const Type *type, Designa
return luisa::make_shared<CoroFrameDesc>(CoroFrameDesc{type, std::move(m)});
}

/// Look up the type of one member of the frame's underlying type.
/// @param index Position of the member within `_type->members()`.
/// @return The member type stored at `index`.
/// `index` is range-checked with LUISA_ASSERT before the lookup.
const Type *CoroFrameDesc::field(uint index) const noexcept {
// Member list of the type wrapped by this descriptor.
auto fields = _type->members();
// NOTE(review): LUISA_ASSERT is defined outside this file; if it embeds the
// condition text in its message, renaming `fields` would change that output.
LUISA_ASSERT(index < fields.size(), "CoroFrame field index out of range.");
return fields[index];
}

uint CoroFrameDesc::designated_field(luisa::string_view name) const noexcept {
if (name == "coro_id") { return 0u; }
if (name == "target_token") { return 1u; }
Expand Down
7 changes: 7 additions & 0 deletions src/coro/coro_frame_smem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include <luisa/core/logging.h>

namespace luisa::compute::detail {
/// Reports the error raised when `Shared<CoroFrame>::operator[]` is used on
/// an SoA-layout shared frame.
///
/// The declaration in coro_frame_smem.h is marked `[[noreturn]]`, and this
/// translation unit does not include that header, so the attribute is
/// repeated here to keep every declaration of the function in agreement.
/// NOTE(review): the header declaration also carries LC_CORO_API, which this
/// definition does not see — confirm the symbol is exported on DLL builds
/// (including <luisa/coro/coro_frame_smem.h> here would cover both).
/// NOTE(review): assumes LUISA_ERROR_WITH_LOCATION never returns — confirm.
[[noreturn]] void error_coro_frame_smem_subscript_on_soa() noexcept {
    LUISA_ERROR_WITH_LOCATION("Subscripting on SOA shared frame is not allowed.");
}
}// namespace luisa::compute::detail
2 changes: 1 addition & 1 deletion src/coro/coro_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void coroutine_chained_await_impl(CoroFrame &frame, uint node_count,
};
}

inline void coroutine_generator_next_impl(
void coroutine_generator_next_impl(
CoroFrame &frame, uint node_count,
const luisa::move_only_function<void(CoroFrame &, CoroToken)> &resume) noexcept {
$switch (frame.target_token) {
Expand Down
48 changes: 36 additions & 12 deletions src/coro/schedulers/persistent_threads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
for (auto index : dsl::dynamic_range(q_fac)) {
auto s = index * config.block_size + thread_x();
all_token[s] = 0u;
// frames.write(s, coro.instantiate(), std::array{0u, 1u});
if (config.shared_memory_soa) {
frames.write(s, CoroFrame::create(graph->shared_frame()), std::array{1u});
} else {
frames[s].target_token = 0u;
}
}
for (auto index : dsl::dynamic_range(g_fac)) {
auto s = index * config.block_size + thread_x();
Expand Down Expand Up @@ -120,23 +124,37 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
auto frame_token = all_token[dst];
$if (coro_token != 0u) {
$if (frame_token != 0u) {
auto g_state = global_frames->read(global_id);
global_frames->write(global_id, frames.read(dst));
frames.write(dst, g_state);
if (config.shared_memory_soa) {
auto g_state = global_frames->read(global_id);
global_frames->write(global_id, frames.read(dst));
frames.write(dst, g_state);
} else {
auto g_state = global_frames->read(global_id);
global_frames->write(global_id, frames[dst]);
frames[dst] = g_state;
}
all_token[shared_queue_size + g_queue_id] = frame_token;
all_token[dst] = coro_token;
}
$else {
auto g_state = global_frames->read(global_id);
frames.write(dst, g_state);
if (config.shared_memory_soa) {
auto g_state = global_frames->read(global_id);
frames.write(dst, g_state);
} else {
frames[dst] = global_frames->read(global_id);
}
all_token[shared_queue_size + g_queue_id] = frame_token;
all_token[dst] = coro_token;
};
}
$else {
$if (frame_token != 0u) {
auto frame = frames.read(dst);
global_frames->write(global_id, frame);
if (config.shared_memory_soa) {
auto frame = frames.read(dst);
global_frames->write(global_id, frame);
} else {
global_frames->write(global_id, frames[dst]);
}
all_token[shared_queue_size + g_queue_id] = frame_token;
all_token[dst] = coro_token;
};
Expand Down Expand Up @@ -183,10 +201,16 @@ void persistent_threads_coro_scheduler_main_kernel_impl(
for (auto i = 1u; i < subroutine_count; i++) {
$case (i) {
work_counter.atomic(i).fetch_sub(1u);
auto frame = frames.read(pid, graph->node(i).input_fields());
call_subroutine(frame, i);
auto next = frame.target_token & coro_token_valid_mask;
frames.write(pid, frame, graph->node(i).output_fields());
auto next = def(0u);
if (config.shared_memory_soa) {
auto frame = frames.read(pid, graph->node(i).input_fields());
call_subroutine(frame, i);
next = frame.target_token & coro_token_valid_mask;
frames.write(pid, frame, graph->node(i).output_fields());
} else {
call_subroutine(frames[pid], i);
next = frames[pid].target_token & coro_token_valid_mask;
}
all_token[pid] = next;
work_counter.atomic(next).fetch_add(1u);
};
Expand Down

0 comments on commit 1161b00

Please sign in to comment.