Skip to content

Commit

Permalink
Add classes for Task and SharedTask in threadpool (#5391)
Browse files Browse the repository at this point in the history
Currently `ThreadPool::Task` is a typedef to `std:: future<Status>`.
This PR:
1) Replaces this with a full class that imposes things like `wait` for
our async processes to go through our ThreadPool's wait method. In that
way, we never use a `std::future`'s `wait` method directly and we avoid
possible deadlocks because only the ThreadPool's `wait` will yield.
2) Adjusts the structure and relationship of tasks and the threadpool so
that the caller can rely on the task's internal functions and doesn't
need to keep track of the relationship between the ThreadPool and each
task manually throughout the codebase.
3) Adds a new first class citizen of our ThreadPool: `SharedTask`, that
is encapsulating `std::shared_future` which allows multiple threads to
wait on an async operation result to become available. This is needed
for the upcoming work on parallelizing IO and compute operations in the
codebase even further.

This is heavily influenced from similar work by @Shelnutt2.

---
TYPE: IMPROVEMENT
DESC: Add classes for Task and SharedTask in threadpool

---------

Co-authored-by: Seth Shelnutt <[email protected]>
  • Loading branch information
ypatia and Shelnutt2 authored Dec 9, 2024
1 parent 0ea3351 commit 2637682
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 32 deletions.
6 changes: 3 additions & 3 deletions tiledb/common/thread_pool/test/unit_thread_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void wait_all(
ThreadPool& pool, bool use_wait, std::vector<ThreadPool::Task>& results) {
if (use_wait) {
for (auto& r : results) {
REQUIRE(pool.wait(r).ok());
REQUIRE(r.wait().ok());
}
} else {
REQUIRE(pool.wait_all(results).ok());
Expand All @@ -117,7 +117,7 @@ Status wait_all_status(
if (use_wait) {
Status ret;
for (auto& r : results) {
auto st = pool.wait(r);
auto st = r.wait();
if (ret.ok() && !st.ok()) {
ret = st;
}
Expand All @@ -139,7 +139,7 @@ uint64_t wait_all_num_status(
int num_ok = 0;
if (use_wait) {
for (auto& r : results) {
num_ok += pool.wait(r).ok() ? 1 : 0;
num_ok += r.wait().ok() ? 1 : 0;
}
} else {
std::vector<Status> statuses = pool.wait_all_status(results);
Expand Down
50 changes: 44 additions & 6 deletions tiledb/common/thread_pool/thread_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void ThreadPool::shutdown() {
threads_.clear();
}

Status ThreadPool::wait_all(std::vector<Task>& tasks) {
Status ThreadPool::wait_all(std::vector<ThreadPoolTask*>& tasks) {
auto statuses = wait_all_status(tasks);
for (auto& st : statuses) {
if (!st.ok()) {
Expand All @@ -131,14 +131,33 @@ Status ThreadPool::wait_all(std::vector<Task>& tasks) {
return Status::Ok();
}

Status ThreadPool::wait_all(std::vector<Task>& tasks) {
std::vector<ThreadPoolTask*> task_ptrs;
for (auto& t : tasks) {
task_ptrs.emplace_back(&t);
}

return wait_all(task_ptrs);
}

Status ThreadPool::wait_all(std::vector<SharedTask>& tasks) {
std::vector<ThreadPoolTask*> task_ptrs;
for (auto& t : tasks) {
task_ptrs.emplace_back(&t);
}

return wait_all(task_ptrs);
}

// Return a vector of Status. If any task returns an error value or throws an
// exception, we save an error code in the corresponding location in the Status
// vector. All tasks are waited on before return. Multiple error statuses may
// be saved. We may call logger here because thread pool will not be used until
// context is fully constructed (which will include logger).
// Unfortunately, C++ does not have the notion of an aggregate exception, so we
// don't throw in the case of errors/exceptions.
std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
std::vector<Status> ThreadPool::wait_all_status(
std::vector<ThreadPoolTask*>& tasks) {
std::vector<Status> statuses(tasks.size());

std::queue<size_t> pending_tasks;
Expand All @@ -154,17 +173,17 @@ std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
pending_tasks.pop();
auto& task = tasks[task_id];

if (!task.valid()) {
if (task && !task->valid()) {
statuses[task_id] = Status_ThreadPoolError("Invalid task future");
LOG_STATUS_NO_RETURN_VALUE(statuses[task_id]);
} else if (
task.wait_for(std::chrono::milliseconds(0)) ==
task->wait_for(std::chrono::milliseconds(0)) ==
std::future_status::ready) {
// Task is completed, get result, handling possible exceptions

Status st = [&task] {
try {
return task.get();
return task->get();
} catch (const std::exception& e) {
return Status_TaskError(
"Caught std::exception: " + std::string(e.what()));
Expand Down Expand Up @@ -205,7 +224,26 @@ std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
return statuses;
}

Status ThreadPool::wait(Task& task) {
std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
std::vector<ThreadPoolTask*> task_ptrs;
for (auto& t : tasks) {
task_ptrs.emplace_back(&t);
}

return wait_all_status(task_ptrs);
}

std::vector<Status> ThreadPool::wait_all_status(
std::vector<SharedTask>& tasks) {
std::vector<ThreadPoolTask*> task_ptrs;
for (auto& t : tasks) {
task_ptrs.emplace_back(&t);
}

return wait_all_status(task_ptrs);
}

Status ThreadPool::wait(ThreadPoolTask& task) {
while (true) {
if (!task.valid()) {
return Status_ThreadPoolError("Invalid task future");
Expand Down
196 changes: 186 additions & 10 deletions tiledb/common/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,175 @@ namespace tiledb::common {

class ThreadPool {
public:
using Task = std::future<Status>;
/**
* @brief Abstract base class for tasks that can run in this threadpool.
*/
class ThreadPoolTask {
public:
ThreadPoolTask() = default;
ThreadPoolTask(ThreadPool* tp)
: tp_(tp){};

virtual ~ThreadPoolTask(){};

protected:
friend class ThreadPool;

/* C.67 A polymorphic class should suppress public copy/move to prevent
* slicing */
ThreadPoolTask(const ThreadPoolTask&) = default;
ThreadPoolTask& operator=(const ThreadPoolTask&) = default;
ThreadPoolTask(ThreadPoolTask&&) = default;
ThreadPoolTask& operator=(ThreadPoolTask&&) = default;

/**
* Pure virtual functions that tasks need to implement so that they can be
* run in the threadpool wait loop
*/
virtual std::future_status wait_for(
const std::chrono::milliseconds timeout_duration) const = 0;
virtual bool valid() const noexcept = 0;
virtual Status get() = 0;

ThreadPool* tp_{nullptr};
};

/**
* @brief Task class encapsulating std::future. Like std::future it's shared
* state can only be get once and thus only one thread. It can only be moved
* and not copied.
*/
class Task : public ThreadPoolTask {
public:
/**
* Default constructor
* @brief Default constructed SharedTask is possible but not valid().
*/
Task() = default;

/**
* Constructor from std::future
*/
Task(std::future<Status>&& f, ThreadPool* tp)
: ThreadPoolTask(tp)
, f_(std::move(f)){};

/**
* Wait in the threadpool for this task to be ready.
*/
Status wait() {
if (tp_ == nullptr) {
throw std::runtime_error("Cannot wait, threadpool is not initialized.");
} else if (!f_.valid()) {
throw std::runtime_error("Cannot wait, task is invalid.");
} else {
return tp_->wait(*this);
}
}

/**
* Is this task valid. Wait can only be called on vaid tasks.
*/
bool valid() const noexcept {
return f_.valid();
}

private:
friend class ThreadPool;

/**
* Wait for input milliseconds for this task to be ready.
*/
std::future_status wait_for(
const std::chrono::milliseconds timeout_duration) const {
return f_.wait_for(timeout_duration);
}

/**
* Get the result of that task. Can only be used once. Only accessible from
* within the threadpool `wait` loop.
*/
Status get() {
return f_.get();
}

/**
* The encapsulated std::shared_future
*/
std::future<Status> f_;
};

/**
* @brief SharedTask class encapsulating std::shared_future. Like
* std::shared_future multiple threads can wait/get on the shared state
* multiple times. It can be both moved and copied.
*/
class SharedTask : public ThreadPoolTask {
public:
/**
* Default constructor
* @brief Default constructed SharedTask is possible but not valid().
*/
SharedTask() = default;

/**
* Constructor from std::future or std::shared_future
*/
SharedTask(auto&& f, ThreadPool* tp)
: ThreadPoolTask(tp)
, f_(std::forward<decltype(f)>(f)){};

/**
* Move constructor from a Task
*/
SharedTask(Task&& t) noexcept
: ThreadPoolTask(t.tp_)
, f_(std::move(t.f_)){};

/**
* Wait in the threadpool for this task to be ready.
*/
Status wait() {
if (tp_ == nullptr) {
throw std::runtime_error("Cannot wait, threadpool is not initialized.");
} else if (!f_.valid()) {
throw std::runtime_error("Cannot wait, shared task is invalid.");
} else {
return tp_->wait(*this);
}
}

/**
* Is this task valid. Wait can only be called on vaid tasks.
*/
bool valid() const noexcept {
return f_.valid();
}

private:
friend class ThreadPool;

/**
* Wait for input milliseconds for this task to be ready.
*/
std::future_status wait_for(
const std::chrono::milliseconds timeout_duration) const {
return f_.wait_for(timeout_duration);
}

/**
* Get the result of that task. Can be called multiple times from multiple
* threads. Only accessible from within the threadpool `wait` loop.
*/
Status get() {
return f_.get();
}

/**
* The encapsulated std::shared_future
*/
std::shared_future<Status> f_;
};

/* ********************************* */
/* CONSTRUCTORS & DESTRUCTORS */
Expand Down Expand Up @@ -108,7 +276,7 @@ class ThreadPool {
return std::apply(std::move(f), std::move(args));
});

std::future<R> future = task->get_future();
Task future(task->get_future(), this);

task_queue_.push(task);

Expand All @@ -127,6 +295,19 @@ class ThreadPool {
return async(std::forward<Fn>(f), std::forward<Args>(args)...);
}

/* Helper functions for lists that consists purely of Tasks */
Status wait_all(std::vector<Task>& tasks);
std::vector<Status> wait_all_status(std::vector<Task>& tasks);

/* Helper functions for lists that consists purely of SharedTasks */
Status wait_all(std::vector<SharedTask>& tasks);
std::vector<Status> wait_all_status(std::vector<SharedTask>& tasks);

/* ********************************* */
/* PRIVATE ATTRIBUTES */
/* ********************************* */

private:
/**
* Wait on all the given tasks to complete. This function is safe to call
* recursively and may execute pending tasks on the calling thread while
Expand All @@ -136,7 +317,7 @@ class ThreadPool {
* @return Status::Ok if all tasks returned Status::Ok, otherwise the first
* error status is returned
*/
Status wait_all(std::vector<Task>& tasks);
Status wait_all(std::vector<ThreadPoolTask*>& tasks);

/**
* Wait on all the given tasks to complete, returning a vector of their return
Expand All @@ -151,7 +332,7 @@ class ThreadPool {
* @param tasks Task list to wait on
* @return Vector of each task's Status.
*/
std::vector<Status> wait_all_status(std::vector<Task>& tasks);
std::vector<Status> wait_all_status(std::vector<ThreadPoolTask*>& tasks);

/**
* Wait on a single tasks to complete. This function is safe to call
Expand All @@ -162,13 +343,8 @@ class ThreadPool {
* @return Status::Ok if the task returned Status::Ok, otherwise the error
* status is returned
*/
Status wait(Task& task);

/* ********************************* */
/* PRIVATE ATTRIBUTES */
/* ********************************* */
Status wait(ThreadPoolTask& task);

private:
/** The worker thread routine */
void worker();

Expand Down
Loading

0 comments on commit 2637682

Please sign in to comment.