Skip to content

Commit

Permalink
Merge branch 'change-thread-shedule-with-mutil-list-per-worker' into …
Browse files Browse the repository at this point in the history
…change-thread-shedule
  • Loading branch information
QlQlqiqi committed Jun 26, 2024
2 parents 7020519 + c21fd6e commit cedbbca
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 45 deletions.
31 changes: 22 additions & 9 deletions src/net/include/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
#include <pthread.h>
#include <atomic>
#include <string>
#include <vector>

#include "net/include/net_define.h"
#include "net/include/random.h"
#include "pstd/include/pstd_mutex.h"

namespace net {

using TaskFunc = void (*)(void*);
Expand All @@ -30,7 +31,13 @@ class ThreadPool : public pstd::noncopyable {
public:
class Worker {
public:
explicit Worker(ThreadPool* tp) : start_(false), thread_pool_(tp){};
// Startup argument passed to a worker thread: an opaque pointer plus the
// index of the worker it belongs to.
struct Arg {
  void* arg;  // opaque pointer; Worker's constructor stores its ThreadPool* here
  int idx;    // worker index

  Arg(void* p, int i) : arg{p}, idx{i} {}
};

// Builds a not-yet-started worker bound to pool `tp`. `idx` (default 0) is
// kept in idx_ and bundled with `tp` into arg_, the object that start()
// passes to pthread_create as the thread argument.
explicit Worker(ThreadPool* tp, int idx = 0)
    : start_(false),
      thread_pool_(tp),
      idx_(idx),
      arg_(tp, idx) {}
static void* WorkerMain(void* arg);

int start();
Expand All @@ -41,6 +48,8 @@ class ThreadPool : public pstd::noncopyable {
std::atomic<bool> start_;
ThreadPool* const thread_pool_;
std::string worker_name_;
int idx_;
Arg arg_;
};

explicit ThreadPool(size_t worker_num, size_t max_queue_size, std::string thread_pool_name = "ThreadPool");
Expand All @@ -60,7 +69,7 @@ class ThreadPool : public pstd::noncopyable {
std::string thread_pool_name();

private:
void runInThread();
void runInThread(const int idx = 0);

public:
struct AdaptationContext {
Expand Down Expand Up @@ -96,12 +105,16 @@ class ThreadPool : public pstd::noncopyable {
// it's okay for other platforms to be no-ops
}

Node* CreateMissingNewerLinks(Node* head);
Node* CreateMissingNewerLinks(Node* head, int* cnt);
bool LinkOne(Node* node, std::atomic<Node*>* newest_node);

std::atomic<Node*> newest_node_;
uint16_t task_idx_;

const uint8_t nworkers_per_link_ = 2; // number of workers per link
const uint8_t nlinks_; // number of links: worker_num / nworkers_per_link_, rounded up
std::vector<std::atomic<Node*>> newest_node_;
std::atomic<int> node_cnt_; // for task
std::atomic<Node*> time_newest_node_;
std::vector<std::atomic<Node*>> time_newest_node_;
std::atomic<int> time_node_cnt_; // for time task

const int queue_slow_size_; // default value: min(worker_num_ * 10, max_queue_size_)
Expand All @@ -112,14 +125,14 @@ class ThreadPool : public pstd::noncopyable {

AdaptationContext adp_ctx;

size_t worker_num_;
const size_t worker_num_;
std::string thread_pool_name_;
std::vector<Worker*> workers_;
std::atomic<bool> running_;
std::atomic<bool> should_stop_;

pstd::Mutex mu_;
pstd::CondVar rsignal_;
std::vector<pstd::Mutex> mu_;
std::vector<pstd::CondVar> rsignal_;
};

} // namespace net
Expand Down
96 changes: 60 additions & 36 deletions src/net/src/thread_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

namespace net {

void* ThreadPool::Worker::WorkerMain(void* arg) {
auto tp = static_cast<ThreadPool*>(arg);
tp->runInThread();
// pthread entry point for a worker. `p` is the Worker's Arg, which carries
// the owning ThreadPool (as void*) and the worker index; the thread runs the
// pool's loop for that index and always returns nullptr.
void* ThreadPool::Worker::WorkerMain(void* p) {
  auto worker_arg = static_cast<Arg*>(p);
  static_cast<ThreadPool*>(worker_arg->arg)->runInThread(worker_arg->idx);
  return nullptr;
}

int ThreadPool::Worker::start() {
if (!start_.load()) {
if (pthread_create(&thread_id_, nullptr, &WorkerMain, thread_pool_) != 0) {
if (pthread_create(&thread_id_, nullptr, &WorkerMain, &arg_) != 0) {
return -1;
} else {
start_.store(true);
Expand All @@ -44,9 +45,11 @@ int ThreadPool::Worker::stop() {
}

ThreadPool::ThreadPool(size_t worker_num, size_t max_queue_size, std::string thread_pool_name)
: newest_node_(nullptr),
: nlinks_((worker_num + nworkers_per_link_ - 1) / nworkers_per_link_),
// : nlinks_(worker_num),
newest_node_(nlinks_),
node_cnt_(0),
time_newest_node_(nullptr),
time_newest_node_(nlinks_),
time_node_cnt_(0),
queue_slow_size_(std::min(worker_num * 10, max_queue_size)),
max_queue_size_(max_queue_size),
Expand All @@ -56,15 +59,22 @@ ThreadPool::ThreadPool(size_t worker_num, size_t max_queue_size, std::string thr
worker_num_(worker_num),
thread_pool_name_(std::move(thread_pool_name)),
running_(false),
should_stop_(false) {}
should_stop_(false),
mu_(nlinks_),
rsignal_(nlinks_) {
for (size_t i = 0; i < nlinks_; ++i) {
newest_node_[i] = nullptr;
time_newest_node_[i] = nullptr;
}
}

// Stops the pool on destruction by delegating to stop_thread_pool(); the
// status it returns is deliberately discarded.
ThreadPool::~ThreadPool() {
  static_cast<void>(stop_thread_pool());
}

int ThreadPool::start_thread_pool() {
if (!running_.load()) {
should_stop_.store(false);
for (size_t i = 0; i < worker_num_; ++i) {
workers_.push_back(new Worker(this));
for (size_t i = 0; i < nlinks_; ++i) {
workers_.push_back(new Worker(this, i));
int res = workers_[i]->start();
if (res != 0) {
return kCreateThreadError;
Expand All @@ -79,7 +89,9 @@ int ThreadPool::stop_thread_pool() {
int res = 0;
if (running_.load()) {
should_stop_.store(true);
rsignal_.notify_all();
for (auto& r : rsignal_) {
r.notify_all();
}
for (const auto worker : workers_) {
res = worker->stop();
if (res != 0) {
Expand Down Expand Up @@ -107,12 +119,13 @@ void ThreadPool::Schedule(TaskFunc func, void* arg) {
if (node_cnt_.load(std::memory_order_relaxed) >= queue_slow_size_) {
std::this_thread::yield();
}
// std::unique_lock lock(mu_);

if (LIKELY(!should_stop())) {
auto node = new Node(func, arg);
LinkOne(node, &newest_node_);
auto idx = ++task_idx_;
LinkOne(node, &newest_node_[idx % nlinks_]);
node_cnt_++;
rsignal_.notify_one();
rsignal_[idx % nlinks_].notify_one();
}
}

Expand All @@ -124,12 +137,12 @@ void ThreadPool::DelaySchedule(uint64_t timeout, TaskFunc func, void* arg) {
uint64_t unow = std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count();
uint64_t exec_time = unow + timeout * 1000;

// std::unique_lock lock(mu_);
if (LIKELY(!should_stop())) {
auto idx = ++task_idx_;
auto node = new Node(exec_time, func, arg);
LinkOne(node, &time_newest_node_);
LinkOne(node, &newest_node_[idx % nlinks_]);
time_node_cnt_++;
rsignal_.notify_all();
rsignal_[idx % nlinks_].notify_all();
}
}

Expand All @@ -143,15 +156,21 @@ void ThreadPool::cur_time_queue_size(size_t* qsize) { *qsize = time_node_cnt_.lo

std::string ThreadPool::thread_pool_name() { return thread_pool_name_; }

void ThreadPool::runInThread() {
void ThreadPool::runInThread(const int idx) {
Node* tmp = nullptr;
Node* last = nullptr;
Node* time_last = nullptr;

auto& newest_node = newest_node_[idx % nlinks_];
auto& time_newest_node = time_newest_node_[idx % nlinks_];
auto& mu = mu_[idx % nlinks_];
auto& rsignal = rsignal_[idx % nlinks_];

while (LIKELY(!should_stop())) {
std::unique_lock lock(mu_);
rsignal_.wait(lock, [this]() {
return newest_node_.load(std::memory_order_relaxed) != nullptr ||
UNLIKELY(time_newest_node_.load(std::memory_order_relaxed) != nullptr) || UNLIKELY(should_stop());
std::unique_lock lock(mu);
rsignal.wait(lock, [this, &newest_node, &time_newest_node]() {
return newest_node.load(std::memory_order_relaxed) != nullptr ||
UNLIKELY(time_newest_node.load(std::memory_order_relaxed) != nullptr) || UNLIKELY(should_stop());
});
lock.unlock();

Expand All @@ -160,26 +179,26 @@ void ThreadPool::runInThread() {
break;
}

last = newest_node_.exchange(nullptr);
time_last = time_newest_node_.exchange(nullptr);
last = newest_node.exchange(nullptr);
time_last = time_newest_node.exchange(nullptr);
if (last == nullptr && LIKELY(time_last == nullptr)) {
// 1. loop for short time
for (uint32_t tries = 0; tries < 200; ++tries) {
if (newest_node_.load(std::memory_order_acquire) != nullptr) {
last = newest_node_.exchange(nullptr);
if (newest_node.load(std::memory_order_acquire) != nullptr) {
last = newest_node.exchange(nullptr);
if (last != nullptr) {
goto exec;
}
}
if (UNLIKELY(time_newest_node_.load(std::memory_order_acquire) != nullptr)) {
time_last = time_newest_node_.exchange(nullptr);
if (UNLIKELY(time_newest_node.load(std::memory_order_acquire) != nullptr)) {
time_last = time_newest_node.exchange(nullptr);
if (time_last != nullptr) {
goto exec;
}
}
AsmVolatilePause();
}

// 2. loop for a little short time again
const size_t kMaxSlowYieldsWhileSpinning = 3;
auto& yield_credit = adp_ctx.value;
Expand All @@ -198,16 +217,16 @@ void ThreadPool::runInThread() {
while ((iter_begin - spin_begin) <= std::chrono::microseconds(max_yield_usec_)) {
std::this_thread::yield();

if (newest_node_.load(std::memory_order_acquire) != nullptr) {
last = newest_node_.exchange(nullptr);
if (newest_node.load(std::memory_order_acquire) != nullptr) {
last = newest_node.exchange(nullptr);
if (last != nullptr) {
would_spin_again = true;
// success
break;
}
}
if (UNLIKELY(time_newest_node_.load(std::memory_order_acquire) != nullptr)) {
time_last = time_newest_node_.exchange(nullptr);
if (UNLIKELY(time_newest_node.load(std::memory_order_acquire) != nullptr)) {
time_last = time_newest_node.exchange(nullptr);
if (time_last != nullptr) {
would_spin_again = true;
// success
Expand Down Expand Up @@ -243,7 +262,9 @@ void ThreadPool::runInThread() {
exec:
// do all normal tasks older than this task pointed last
if (LIKELY(last != nullptr)) {
auto first = CreateMissingNewerLinks(last);
int cnt = 1;
auto first = CreateMissingNewerLinks(last, &cnt);
// node_cnt_ -= cnt;
assert(!first->is_time_task);
do {
first->Exec();
Expand All @@ -256,7 +277,8 @@ void ThreadPool::runInThread() {

// do all time tasks older than this task pointed time_last
if (UNLIKELY(time_last != nullptr)) {
auto time_first = CreateMissingNewerLinks(time_last);
int cnt = 1;
auto time_first = CreateMissingNewerLinks(time_last, &cnt);
do {
// time task may block normal task
auto now = std::chrono::system_clock::now();
Expand All @@ -268,7 +290,7 @@ void ThreadPool::runInThread() {
time_first->Exec();
} else {
lock.lock();
rsignal_.wait_for(lock, std::chrono::microseconds(exec_time - unow));
rsignal.wait_for(lock, std::chrono::microseconds(exec_time - unow));
lock.unlock();
time_first->Exec();
}
Expand All @@ -282,14 +304,16 @@ void ThreadPool::runInThread() {
}
}

ThreadPool::Node* ThreadPool::CreateMissingNewerLinks(Node* head) {
ThreadPool::Node* ThreadPool::CreateMissingNewerLinks(Node* head, int* cnt) {
assert(head != nullptr);
assert(cnt != nullptr && *cnt == 1);
Node* next = nullptr;
while (true) {
next = head->link_older;
if (next == nullptr) {
return head;
}
++(*cnt);
next->link_newer = head;
head = next;
}
Expand Down

0 comments on commit cedbbca

Please sign in to comment.