Skip to content

Commit

Permalink
update compute
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Apr 19, 2024
1 parent c051479 commit 814693d
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/compute
Submodule compute updated 385 files
8 changes: 5 additions & 3 deletions src/films/display.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,11 @@ class DisplayInstance final : public Film::Instance {
if (!_window) {
_window = luisa::make_unique<Window>("Display", size);
auto d = node<Display>();
_swapchain = device.create_swapchain(
_window->native_handle(), *command_buffer.stream(),
size, d->hdr(), d->vsync(), d->back_buffers());
auto option = SwapchainOption{
.display = _window->native_display(),
.window = _window->native_handle(),
size, d->hdr(), d->vsync(), d->back_buffers()};
_swapchain = device.create_swapchain(*command_buffer.stream(), option);
_framebuffer = device.create_image<float>(
_swapchain.backend_storage(), size);
_blit = device.compile<2>([&] {
Expand Down
172 changes: 171 additions & 1 deletion src/util/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,183 @@
// Created by Mike Smith on 2023/5/18.
//

#include <version>
#include <sstream>
#include <thread>
#include <memory>
#include <condition_variable>

#if (!defined(__clang_major__) || __clang_major__ >= 14) && defined(__cpp_lib_barrier)
#define LUISA_COMPUTE_USE_STD_BARRIER
#endif

#ifdef LUISA_COMPUTE_USE_STD_BARRIER
#include <barrier>
#endif

#include <luisa/core/stl/vector.h>
#include <luisa/core/stl/queue.h>
#include <luisa/core/logging.h>
#include <util/thread_pool.h>

namespace luisa::render {

namespace detail {

/// Access the calling thread's "is a pool worker" flag.
/// @return A mutable reference to a thread-local bool that starts out `false`;
///         each thread sees (and may assign to) its own copy.
[[nodiscard]] static bool &is_worker_thread() noexcept {
    thread_local bool flag{false};
    return flag;
}

/// Access the calling thread's worker index slot.
/// @return A mutable reference to a thread-local unsigned integer that starts
///         out as 0; each thread sees (and may assign to) its own copy.
[[nodiscard]] static unsigned int &worker_thread_index() noexcept {
    thread_local unsigned int index{0u};
    return index;
}

/// Report an error when the calling thread is flagged as a pool worker.
/// @param f Name of the ThreadPool member being invoked; only used in the message.
static inline void check_not_in_worker_thread(std::string_view f) noexcept {
    // Common case: the caller is not a worker, nothing to report.
    if (!is_worker_thread()) [[likely]] { return; }
    // Stringify the thread id through a stream so it can be passed as a format argument.
    std::ostringstream thread_id;
    thread_id << std::this_thread::get_id();
    LUISA_ERROR_WITH_LOCATION(
        "Invoking ThreadPool::{}() "
        "from worker thread {}.",
        f, thread_id.str());
}

}// namespace detail

#ifdef LUISA_COMPUTE_USE_STD_BARRIER
/// Thin wrapper that inherits the constructors of std::barrier<>, so the rest of
/// this file can use the single name `Barrier` whichever implementation is active.
struct Barrier : std::barrier<> {
    using std::barrier<>::barrier;
};
#else
// reference: https://github.com/yohhoy/yamc/blob/master/include/yamc_barrier.hpp
/// Fallback reusable barrier built from a mutex and a condition variable.
/// Every `_n`-th call to arrive_and_wait() completes a phase and releases all
/// callers that arrived during that phase.
class Barrier {
private:
    uint _n;                     ///< arrivals required to complete one phase
    uint _counter;               ///< arrivals still outstanding in the current phase
    uint _phase;                 ///< incremented each time a phase completes
    std::condition_variable _cv; ///< signalled when a phase completes
    std::mutex _callback_mutex;  ///< guards _counter and _phase

public:
    /// @param n Number of arrivals that make up one phase.
    explicit Barrier(uint n) noexcept
        : _n{n}, _counter{n}, _phase{0u} {}

    /// Register one arrival and block until the phase it belongs to completes.
    void arrive_and_wait() noexcept {
        std::unique_lock lock{_callback_mutex};
        const auto arrive_phase = _phase;
        if (--_counter == 0u) {
            // Last arrival of this phase: re-arm the counter and release everyone.
            _counter = _n;
            _phase++;
            _cv.notify_all();
        }
        // Wait until the phase counter has moved on. Testing for inequality (rather
        // than `_phase <= arrive_phase`) stays correct when `_phase` wraps around to
        // zero; with `<=` every caller would block forever after the wrap.
        _cv.wait(lock, [&] { return _phase != arrive_phase; });
    }
};
#endif

/// Private state of a ThreadPool: the worker threads, the queue of pending
/// tasks, and the primitives used to coordinate them.
struct ThreadPool::Impl {
    luisa::vector<std::thread> threads;
    luisa::queue<luisa::SharedFunction<void()>> tasks;
    std::mutex mutex;           // guards `tasks` and `should_stop`
    Barrier synchronize_barrier;// sized for every worker plus one extra participant
    Barrier dispatch_barrier;   // sized for the workers only
    std::condition_variable cv; // signalled when tasks arrive or the pool stops
    bool should_stop{false};

    /// Spawn `num_threads` workers, each running the task-consuming loop.
    explicit Impl(size_t num_threads) noexcept
        : synchronize_barrier(num_threads + 1u),
          dispatch_barrier(num_threads) {
        threads.reserve(num_threads);
        for (auto index = 0u; index < num_threads; index++) {
            threads.emplace_back([this, index] { _run_worker(index); });
        }
        LUISA_INFO("Created thread pool with {} thread{}.",
                   num_threads, num_threads == 1u ? "" : "s");
    }

private:
    /// Body of one worker thread: mark the thread as a worker, record its index,
    /// then pop and run tasks until the pool is stopped and the queue is drained.
    void _run_worker(unsigned int index) {
        detail::is_worker_thread() = true;
        detail::worker_thread_index() = index;
        while (true) {
            std::unique_lock guard{mutex};
            cv.wait(guard, [this] { return should_stop || !tasks.empty(); });
            // Only leave once stopping was requested AND nothing is left to run.
            if (tasks.empty() && should_stop) [[unlikely]] { break; }
            auto job = std::move(tasks.front());
            tasks.pop();
            // Run the task without holding the queue lock.
            guard.unlock();
            job();
        }
    }
};

ThreadPool::ThreadPool(size_t num_threads) noexcept
: _impl{luisa::make_unique<Impl>([&]() {
if (num_threads == 0u) {
num_threads = std::max(
std::thread::hardware_concurrency(), 1u);
}
return num_threads;
}())} {
}

void ThreadPool::barrier() noexcept {
detail::check_not_in_worker_thread("barrier");
_dispatch_all([this] { _impl->dispatch_barrier.arrive_and_wait(); });
}

void ThreadPool::synchronize() noexcept {
detail::check_not_in_worker_thread("synchronize");
while (task_count() != 0u) {
_dispatch_all([this] { _impl->synchronize_barrier.arrive_and_wait(); });
_impl->synchronize_barrier.arrive_and_wait();
}
}

/// Append one task to the queue and wake a single waiting worker.
/// @param task Callable to enqueue; it is moved into the queue.
void ThreadPool::_dispatch(luisa::SharedFunction<void()> &&task) noexcept {
    auto &state = *_impl;
    std::unique_lock guard{state.mutex};
    state.tasks.emplace(std::move(task));
    // Release the lock before notifying, as the queue is no longer touched.
    guard.unlock();
    state.cv.notify_one();
}

/// Enqueue `min(worker count, max_threads)` entries of the same task and wake all workers.
/// @param task        Callable to enqueue; the last entry is moved in, the others are copies.
/// @param max_threads Upper bound on the number of entries to enqueue.
void ThreadPool::_dispatch_all(luisa::SharedFunction<void()> &&task, size_t max_threads) noexcept {
    // Compute the entry count once. A count of zero must be handled up front: the
    // previous `min(...) - 1u` loop bound wrapped around for zero and never terminated.
    const auto entries = std::min(_impl->threads.size(), max_threads);
    if (entries == 0u) { return; }
    {
        std::lock_guard lock{_impl->mutex};
        // All but the last entry are copies; the last one takes ownership of `task`.
        for (size_t i = 1u; i < entries; i++) {
            _impl->tasks.emplace(task);
        }
        _impl->tasks.emplace(std::move(task));
    }
    _impl->cv.notify_all();
}

/// Request all workers to stop, wake them, and join every worker thread.
ThreadPool::~ThreadPool() noexcept {
    auto &state = *_impl;
    std::unique_lock guard{state.mutex};
    state.should_stop = true;
    guard.unlock();
    state.cv.notify_all();
    for (auto &worker : state.threads) { worker.join(); }
}

/// @return Number of worker threads owned by this pool.
uint ThreadPool::size() const noexcept {
    const auto worker_count = _impl->threads.size();
    return static_cast<uint>(worker_count);
}
/// @return Whether the calling thread has been flagged as a pool worker
///         (the flag is thread-local and set by the worker loop).
bool ThreadPool::is_worker_thread() noexcept {
    const bool on_worker = detail::is_worker_thread();
    return on_worker;
}
/// Return the index stored in the calling thread's thread-local worker slot.
/// Asserts (via LUISA_ASSERT) that the calling thread is flagged as a worker.
uint ThreadPool::worker_thread_index() noexcept {
LUISA_ASSERT(detail::is_worker_thread(),
"ThreadPool::worker_thread_index() "
"called in non-worker thread.");
return detail::worker_thread_index();
}

/// Return the process-wide ThreadPool, created on first use as a function-local
/// static and sized from `std::thread::hardware_concurrency()`.
ThreadPool &global_thread_pool() noexcept {
    static ThreadPool instance{std::thread::hardware_concurrency()};
    return instance;
}

}// namespace luisa::render
}// namespace luisa
151 changes: 149 additions & 2 deletions src/util/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,157 @@
// Created by Mike Smith on 2023/5/18.
//

#include <core/thread_pool.h>
#pragma once

#include <future>

#include <luisa/core/stl/memory.h>
#include <luisa/core/basic_types.h>
#include <luisa/core/shared_function.h>

namespace luisa::render {

/// Thread pool that runs submitted callables on a fixed set of worker threads.
class ThreadPool {

public:
    struct Impl;

private:
    luisa::unique_ptr<Impl> _impl;
    /// Number of async()/parallel()/async_parallel() submissions that have not
    /// completed yet. Initialized explicitly so the value does not depend on
    /// std::atomic's default constructor (which only zero-initializes since C++20).
    std::atomic_uint _task_count{0u};

private:
    /// Enqueue a single task.
    void _dispatch(luisa::SharedFunction<void()> &&task) noexcept;
    /// Enqueue the task once per worker, limited to at most `max_threads` entries.
    void _dispatch_all(luisa::SharedFunction<void()> &&task, size_t max_threads = std::numeric_limits<size_t>::max()) noexcept;

public:
    /// Create a thread pool with num_threads threads
    explicit ThreadPool(size_t num_threads = 0u) noexcept;
    ~ThreadPool() noexcept;
    ThreadPool(ThreadPool &&) noexcept = delete;
    ThreadPool(const ThreadPool &) noexcept = delete;
    ThreadPool &operator=(ThreadPool &&) noexcept = delete;
    ThreadPool &operator=(const ThreadPool &) noexcept = delete;
    /// Return whether the calling thread is a worker thread
    [[nodiscard]] static bool is_worker_thread() noexcept;
    /// Return the index of the calling worker thread
    [[nodiscard]] static uint worker_thread_index() noexcept;

public:
    /// Barrier all threads
    void barrier() noexcept;
    /// Synchronize all threads
    void synchronize() noexcept;
    /// Return size of threads
    [[nodiscard]] uint size() const noexcept;
    /// Return count of tasks
    [[nodiscard]] uint task_count() const noexcept { return _task_count.load(); }

    /// Run a function async and return future of return value
    template<typename F>
        requires std::is_invocable_v<F>
    auto async(F &&f) noexcept {
        using R = std::invoke_result_t<F>;
        auto promise = luisa::make_unique<std::promise<R>>(
            std::allocator_arg, luisa::allocator{});
        auto future = promise->get_future().share();
        _task_count.fetch_add(1u);
        _dispatch([promise = std::move(promise), future, f = std::forward<F>(f), this]() mutable noexcept {
            if constexpr (std::same_as<R, void>) {
                f();
                promise->set_value();
            } else {
                promise->set_value(f());
            }
            _task_count.fetch_sub(1u);
        });
        return future;
    }

    /// Run a function parallel: invoke f(i) for every i in [0, n) on the workers.
    template<typename F>
        requires std::is_invocable_v<F, uint>
    void parallel(uint n, F &&f) noexcept {
        if (n == 0u) return;
        _task_count.fetch_add(1u);
        // first: next index to hand out; second: number of finished invocations.
        auto counter = luisa::make_unique<std::pair<std::atomic_uint, std::atomic_uint>>(0u, 0u);
        _dispatch_all(
            [counter = std::move(counter), n, f = std::forward<F>(f), this]() mutable noexcept {
                auto i = 0u;
                auto finished = 0u;
                while ((i = counter->first.fetch_add(1u)) < n) {
                    f(i);
                    ++finished;
                }
                // Mark the submission as done only once ALL n invocations have
                // finished, not merely when the index range is exhausted; otherwise
                // synchronize() could return while another worker is still inside f.
                // Copies that ran nothing never report, so exactly one copy does.
                if (finished != 0u &&
                    counter->second.fetch_add(finished) + finished == n) {
                    _task_count.fetch_sub(1u);
                }
            },
            n);
    }

    /// Run a function 2D parallel
    template<typename F>
        requires std::is_invocable_v<F, uint, uint>
    void parallel(uint nx, uint ny, F &&f) noexcept {
        parallel(nx * ny, [nx, f = std::forward<F>(f)](auto i) mutable noexcept {
            f(i % nx, i / nx);
        });
    }

    /// Run a function 3D parallel
    template<typename F>
        requires std::is_invocable_v<F, uint, uint, uint>
    void parallel(uint nx, uint ny, uint nz, F &&f) noexcept {
        parallel(nx * ny * nz, [nx, ny, f = std::forward<F>(f)](auto i) mutable noexcept {
            f(i % nx, i / nx % ny, i / nx / ny);
        });
    }

    /// Run a function parallel and return a future that becomes ready once
    /// f(i) has finished for every i in [0, n) (immediately when n is zero).
    template<typename F>
        requires std::is_invocable_v<F, uint>
    auto async_parallel(uint n, F &&f) noexcept {
        auto promise = luisa::make_unique<std::promise<void>>(
            std::allocator_arg, luisa::allocator{});
        auto future = promise->get_future().share();
        if (n == 0u) {
            promise->set_value();
            return future;
        }
        _task_count.fetch_add(1u);
        // first: next index to hand out; second: number of finished invocations.
        auto counter = luisa::make_unique<std::pair<std::atomic_uint, std::atomic_uint>>(0u, 0u);
        _dispatch_all(
            [counter = std::move(counter), promise = std::move(promise), n, f = std::forward<F>(f), this]() mutable noexcept {
                auto i = 0u;
                auto dispatched_count = 0u;
                while ((i = counter->first.fetch_add(1u)) < n) {
                    f(i);
                    ++dispatched_count;
                }
                // Only a copy that actually ran something may report completion.
                // A copy that starts after everything finished would otherwise also
                // observe `second == n` and set the promise a second time, which
                // throws inside this noexcept lambda and terminates the process.
                if (dispatched_count != 0u &&
                    counter->second.fetch_add(dispatched_count) + dispatched_count == n) {
                    // All n invocations are done: retire the submission, then
                    // publish the result.
                    _task_count.fetch_sub(1u);
                    promise->set_value();
                }
            },
            n);
        return future;
    }

    /// Run a function 2D parallel and return a future for its completion
    template<typename F>
        requires std::is_invocable_v<F, uint, uint>
    auto async_parallel(uint nx, uint ny, F &&f) noexcept {
        return async_parallel(nx * ny, [nx, f = std::forward<F>(f)](auto i) mutable noexcept {
            f(i % nx, i / nx);
        });
    }

    /// Run a function 3D parallel and return a future for its completion
    template<typename F>
        requires std::is_invocable_v<F, uint, uint, uint>
    auto async_parallel(uint nx, uint ny, uint nz, F &&f) noexcept {
        return async_parallel(nx * ny * nz, [nx, ny, f = std::forward<F>(f)](auto i) mutable noexcept {
            f(i % nx, i / nx % ny, i / nx / ny);
        });
    }
};

[[nodiscard]] ThreadPool &global_thread_pool() noexcept;

}// namespace luisa::render
}// namespace luisa

0 comments on commit 814693d

Please sign in to comment.