From 814693d9c81aa82734324f031bd1907b348ba798 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Sat, 20 Apr 2024 01:59:00 +0800 Subject: [PATCH] update compute --- src/compute | 2 +- src/films/display.cpp | 8 +- src/util/thread_pool.cpp | 172 ++++++++++++++++++++++++++++++++++++++- src/util/thread_pool.h | 151 +++++++++++++++++++++++++++++++++- 4 files changed, 326 insertions(+), 7 deletions(-) diff --git a/src/compute b/src/compute index 79f47953..1b31b9a2 160000 --- a/src/compute +++ b/src/compute @@ -1 +1 @@ -Subproject commit 79f479532914beaea4f05fd5de55d43b2f11c766 +Subproject commit 1b31b9a24b8ed7f3618b03faa07efa523e19a2d0 diff --git a/src/films/display.cpp b/src/films/display.cpp index 8f437bef..6a4ece6a 100644 --- a/src/films/display.cpp +++ b/src/films/display.cpp @@ -148,9 +148,11 @@ class DisplayInstance final : public Film::Instance { if (!_window) { _window = luisa::make_unique("Display", size); auto d = node(); - _swapchain = device.create_swapchain( - _window->native_handle(), *command_buffer.stream(), - size, d->hdr(), d->vsync(), d->back_buffers()); + auto option = SwapchainOption{ + .display = _window->native_display(), + .window = _window->native_handle(), + size, d->hdr(), d->vsync(), d->back_buffers()}; + _swapchain = device.create_swapchain(*command_buffer.stream(), option); _framebuffer = device.create_image( _swapchain.backend_storage(), size); _blit = device.compile<2>([&] { diff --git a/src/util/thread_pool.cpp b/src/util/thread_pool.cpp index a4006c22..1b90e7d0 100644 --- a/src/util/thread_pool.cpp +++ b/src/util/thread_pool.cpp @@ -2,13 +2,183 @@ // Created by Mike Smith on 2023/5/18. 
// +#include +#include +#include +#include +#include + +#if (!defined(__clang_major__) || __clang_major__ >= 14) && defined(__cpp_lib_barrier) +#define LUISA_COMPUTE_USE_STD_BARRIER +#endif + +#ifdef LUISA_COMPUTE_USE_STD_BARRIER +#include +#endif + +#include +#include +#include #include namespace luisa::render { +namespace detail { + +[[nodiscard]] static auto &is_worker_thread() noexcept { + static thread_local auto is_worker = false; + return is_worker; +} + +[[nodiscard]] static auto &worker_thread_index() noexcept { + static thread_local auto id = 0u; + return id; +} + +static inline void check_not_in_worker_thread(std::string_view f) noexcept { + if (is_worker_thread()) [[unlikely]] { + std::ostringstream oss; + oss << std::this_thread::get_id(); + LUISA_ERROR_WITH_LOCATION( + "Invoking ThreadPool::{}() " + "from worker thread {}.", + f, oss.str()); + } +} + +}// namespace detail + +#ifdef LUISA_COMPUTE_USE_STD_BARRIER +struct Barrier : std::barrier<> { + using std::barrier<>::barrier; +}; +#else +// reference: https://github.com/yohhoy/yamc/blob/master/include/yamc_barrier.hpp +class Barrier { +private: + uint _n; + uint _counter; + uint _phase; + std::condition_variable _cv; + std::mutex _callback_mutex; + +public: + explicit Barrier(uint n) noexcept + : _n{n}, _counter{n}, _phase{0u} {} + void arrive_and_wait() noexcept { + std::unique_lock lock{_callback_mutex}; + auto arrive_phase = _phase; + if (--_counter == 0u) { + _counter = _n; + _phase++; + _cv.notify_all(); + } + while (_phase <= arrive_phase) { + _cv.wait(lock); + } + } +}; +#endif + +struct ThreadPool::Impl { + luisa::vector threads; + luisa::queue> tasks; + std::mutex mutex; + Barrier synchronize_barrier; + Barrier dispatch_barrier; + std::condition_variable cv; + bool should_stop{false}; + explicit Impl(size_t num_threads) noexcept + : synchronize_barrier(num_threads + 1u), + dispatch_barrier(num_threads) { + threads.reserve(num_threads); + for (auto i = 0u; i < num_threads; i++) { + 
threads.emplace_back([this, i] {
+                detail::is_worker_thread() = true;
+                detail::worker_thread_index() = i;
+                for (;;) {
+                    std::unique_lock lock{mutex};
+                    cv.wait(lock, [this] { return !tasks.empty() || should_stop; });
+                    if (should_stop && tasks.empty()) [[unlikely]] { break; }
+                    auto task = std::move(tasks.front());
+                    tasks.pop();
+                    lock.unlock();
+                    task();
+                }
+            });
+        }
+        LUISA_INFO("Created thread pool with {} thread{}.",
+                   num_threads, num_threads == 1u ? "" : "s");
+    }
+};
+
+ThreadPool::ThreadPool(size_t num_threads) noexcept
+    : _impl{luisa::make_unique<Impl>([&]() {
+          if (num_threads == 0u) {
+              num_threads = std::max(
+                  std::thread::hardware_concurrency(), 1u);
+          }
+          return num_threads;
+      }())} {
+}
+
+void ThreadPool::barrier() noexcept {
+    detail::check_not_in_worker_thread("barrier");
+    _dispatch_all([this] { _impl->dispatch_barrier.arrive_and_wait(); });
+}
+
+void ThreadPool::synchronize() noexcept {
+    detail::check_not_in_worker_thread("synchronize");
+    while (task_count() != 0u) {
+        _dispatch_all([this] { _impl->synchronize_barrier.arrive_and_wait(); });
+        _impl->synchronize_barrier.arrive_and_wait();
+    }
+}
+
+void ThreadPool::_dispatch(luisa::SharedFunction<void()> &&task) noexcept {
+    {
+        std::lock_guard lock{_impl->mutex};
+        _impl->tasks.emplace(std::move(task));
+    }
+    _impl->cv.notify_one();
+}
+
+void ThreadPool::_dispatch_all(luisa::SharedFunction<void()> &&task, size_t max_threads) noexcept {
+    {
+        std::lock_guard lock{_impl->mutex};
+        for (auto i = 0u; i < std::min(_impl->threads.size(), max_threads) - 1u; i++) {
+            _impl->tasks.emplace(task);
+        }
+        _impl->tasks.emplace(std::move(task));
+    }
+    _impl->cv.notify_all();
+}
+
+ThreadPool::~ThreadPool() noexcept {
+    {
+        std::lock_guard lock{_impl->mutex};
+        _impl->should_stop = true;
+    }
+    _impl->cv.notify_all();
+    for (auto &&t : _impl->threads) { t.join(); }
+}
+
+uint ThreadPool::size() const noexcept {
+    return static_cast<uint>(_impl->threads.size());
+}
+bool ThreadPool::is_worker_thread() noexcept {
+    return
detail::is_worker_thread();
+}
+uint ThreadPool::worker_thread_index() noexcept {
+    LUISA_ASSERT(detail::is_worker_thread(),
+                 "ThreadPool::worker_thread_index() "
+                 "called in non-worker thread.");
+    return detail::worker_thread_index();
+}
+
 ThreadPool &global_thread_pool() noexcept {
     static ThreadPool pool{std::thread::hardware_concurrency()};
     return pool;
 }
 
 }// namespace luisa::render
diff --git a/src/util/thread_pool.h b/src/util/thread_pool.h
index 448530a7..93ed9faa 100644
--- a/src/util/thread_pool.h
+++ b/src/util/thread_pool.h
@@ -2,10 +2,157 @@
 // Created by Mike Smith on 2023/5/18.
 //
 
-#include
+#pragma once
+
+#include
+
+#include
+#include
+#include
 
 namespace luisa::render {
 
+/// Thread pool class
+class ThreadPool {
+
+public:
+    struct Impl;
+
+private:
+    luisa::unique_ptr<Impl> _impl;
+    std::atomic_uint _task_count;
+
+private:
+    void _dispatch(luisa::SharedFunction<void()> &&task) noexcept;
+    void _dispatch_all(luisa::SharedFunction<void()> &&task, size_t max_threads = std::numeric_limits<size_t>::max()) noexcept;
+
+public:
+    /// Create a thread pool with num_threads threads
+    explicit ThreadPool(size_t num_threads = 0u) noexcept;
+    ~ThreadPool() noexcept;
+    ThreadPool(ThreadPool &&) noexcept = delete;
+    ThreadPool(const ThreadPool &) noexcept = delete;
+    ThreadPool &operator=(ThreadPool &&) noexcept = delete;
+    ThreadPool &operator=(const ThreadPool &) noexcept = delete;
+    /// Check whether the calling thread is a pool worker thread
+    [[nodiscard]] static bool is_worker_thread() noexcept;
+    /// Index of the calling worker thread (must be called from a worker)
+    [[nodiscard]] static uint worker_thread_index() noexcept;
+
+public:
+    /// Barrier all threads
+    void barrier() noexcept;
+    /// Synchronize all threads
+    void synchronize() noexcept;
+    /// Return size of threads
+    [[nodiscard]] uint size() const noexcept;
+    /// Return count of tasks
+    [[nodiscard]] uint task_count() const noexcept { return _task_count.load(); }
+
+    /// Run a function async and return future of return value
+    template<typename F>
+        requires std::is_invocable_v<F>
+    auto async(F &&f)
noexcept {
+        using R = std::invoke_result_t<F>;
+        auto promise = luisa::make_unique<std::promise<R>>(
+            std::allocator_arg, luisa::allocator{});
+        auto future = promise->get_future().share();
+        _task_count.fetch_add(1u);
+        _dispatch([promise = std::move(promise), future, f = std::forward<F>(f), this]() mutable noexcept {
+            if constexpr (std::same_as<R, void>) {
+                f();
+                promise->set_value();
+            } else {
+                promise->set_value(f());
+            }
+            _task_count.fetch_sub(1u);
+        });
+        return future;
+    }
+
+    /// Run a function parallel
+    template<typename F>
+        requires std::is_invocable_v<F, uint>
+    void parallel(uint n, F &&f) noexcept {
+        if (n == 0u) return;
+        _task_count.fetch_add(1u);
+        auto counter = luisa::make_unique<std::atomic_uint>(0u);
+        _dispatch_all(
+            [counter = std::move(counter), n, f = std::forward<F>(f), this]() mutable noexcept {
+                auto i = 0u;
+                while ((i = counter->fetch_add(1u)) < n) { f(i); }
+                if (i == n) { _task_count.fetch_sub(1u); }
+            },
+            n);
+    }
+
+    /// Run a function 2D parallel
+    template<typename F>
+        requires std::is_invocable_v<F, uint, uint>
+    void parallel(uint nx, uint ny, F &&f) noexcept {
+        parallel(nx * ny, [nx, f = std::forward<F>(f)](auto i) mutable noexcept {
+            f(i % nx, i / nx);
+        });
+    }
+
+    /// Run a function 3D parallel
+    template<typename F>
+        requires std::is_invocable_v<F, uint, uint, uint>
+    void parallel(uint nx, uint ny, uint nz, F &&f) noexcept {
+        parallel(nx * ny * nz, [nx, ny, f = std::forward<F>(f)](auto i) mutable noexcept {
+            f(i % nx, i / nx % ny, i / nx / ny);
+        });
+    }
+
+    template<typename F>
+        requires std::is_invocable_v<F, uint>
+    auto async_parallel(uint n, F &&f) noexcept {
+        auto promise = luisa::make_unique<std::promise<void>>(
+            std::allocator_arg, luisa::allocator{});
+        auto future = promise->get_future().share();
+        if (n == 0u) {
+            promise->set_value();
+            return future;
+        }
+        _task_count.fetch_add(1u);
+        auto counter = luisa::make_unique<std::pair<std::atomic_uint, std::atomic_uint>>(0u, 0u);
+        _dispatch_all(
+            [counter = std::move(counter), promise = std::move(promise), n, f = std::forward<F>(f), this]() mutable noexcept {
+                auto i = 0u;
+                auto dispatched_count = 0u;
+                while ((i = counter->first.fetch_add(1u)) < n) {
+                    f(i);
+                    ++dispatched_count;
+                }
+                if (i == n) {
+                    _task_count.fetch_sub(1u);
+                }
+                if (counter->second.fetch_add(dispatched_count) + dispatched_count == n) {
+                    promise->set_value();
+                }
+            },
+            n);
+        return future;
+    }
+
+    /// Run a function 2D parallel
+    template<typename F>
+        requires std::is_invocable_v<F, uint, uint>
+    auto async_parallel(uint nx, uint ny, F &&f) noexcept {
+        return async_parallel(nx * ny, [nx, f = std::forward<F>(f)](auto i) mutable noexcept {
+            f(i % nx, i / nx);
+        });
+    }
+
+    /// Run a function 3D parallel
+    template<typename F>
+        requires std::is_invocable_v<F, uint, uint, uint>
+    auto async_parallel(uint nx, uint ny, uint nz, F &&f) noexcept {
+        return async_parallel(nx * ny * nz, [nx, ny, f = std::forward<F>(f)](auto i) mutable noexcept {
+            f(i % nx, i / nx % ny, i / nx / ny);
+        });
+    }
+};
+
 [[nodiscard]] ThreadPool &global_thread_pool() noexcept;
 
-}// namespace luisa::render
\ No newline at end of file
+}// namespace luisa::render