Skip to content

Commit

Permalink
feat: socket engine thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Nov 11, 2024
1 parent b2f9ecd commit 22bcdda
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 79 deletions.
29 changes: 20 additions & 9 deletions include/dpp/socketengine.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <cstdint>
#include <unordered_map>
#include <memory>
#include <functional>
#include <dpp/thread_pool.h>

namespace dpp {

Expand All @@ -33,37 +35,46 @@ enum socket_event_flags : uint8_t {
WANT_ERROR = 4,
};

using socket_read_event = auto (*)(dpp::socket fd, const struct socket_events&) -> void;
using socket_write_event = auto (*)(dpp::socket fd, const struct socket_events&) -> void;
using socket_error_event = auto (*)(dpp::socket fd, const struct socket_events&, int error_code) -> void;
using socket_read_event = std::function<void(dpp::socket fd, const struct socket_events&)>;
using socket_write_event = std::function<void(dpp::socket fd, const struct socket_events&)>;
using socket_error_event = std::function<void(dpp::socket fd, const struct socket_events&, int error_code)>;

struct socket_events {
dpp::socket fd{INVALID_SOCKET};
uint8_t flags{0};
socket_read_event on_read{};
socket_write_event on_write{};
socket_error_event on_error{};
socket_events(dpp::socket socket_fd, uint8_t _flags, const socket_read_event& read_event, const socket_write_event& write_event = {}, const socket_error_event& error_event = {})
: fd(socket_fd), flags(_flags), on_read(read_event), on_write(write_event), on_error(error_event) { }

};

using socket_container = std::unordered_map<dpp::socket, std::unique_ptr<socket_events>>;

struct socket_engine_base {
socket_container fds;
std::unique_ptr<thread_pool> pool;

socket_engine_base();
socket_engine_base(const socket_engine_base&) = default;
socket_engine_base(socket_engine_base&&) = default;
socket_engine_base& operator=(const socket_engine_base&) = default;
socket_engine_base& operator=(socket_engine_base&&) = default;
socket_engine_base(const socket_engine_base&) = delete;
socket_engine_base(socket_engine_base&&) = delete;
socket_engine_base& operator=(const socket_engine_base&) = delete;
socket_engine_base& operator=(socket_engine_base&&) = delete;

virtual ~socket_engine_base() = default;

virtual void process_events() = 0;
virtual bool register_socket(dpp::socket fd, const socket_events& e);
virtual bool update_socket(dpp::socket fd, const socket_events& e);
virtual bool register_socket(const socket_events& e);
virtual bool update_socket(const socket_events& e);
virtual bool remove_socket(dpp::socket fd);
};

/* This is implemented by whatever derived form socket_engine takes */
std::unique_ptr<socket_engine_base> create_socket_engine();

#ifndef _WIN32
void set_signal_handler(int signal);
#endif

};
42 changes: 22 additions & 20 deletions src/dpp/socketengine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,21 @@
#include <dpp/exception.h>
#include <csignal>
#include <memory>
#include <sslclient.h>

namespace dpp {

bool socket_engine_base::register_socket(dpp::socket fd, const socket_events &e) {
if (fd > INVALID_SOCKET && fds.find(fd) == fds.end()) {
fds.emplace(fd, std::make_unique<socket_events>(e));
bool socket_engine_base::register_socket(const socket_events &e) {
if (e.fd > INVALID_SOCKET && fds.find(e.fd) == fds.end()) {
fds.emplace(e.fd, std::make_unique<socket_events>(e));
return true;
}
return false;
}

bool socket_engine_base::update_socket(dpp::socket fd, const socket_events &e) {
if (fd > INVALID_SOCKET && fds.find(fd) != fds.end()) {
auto iter = fds.find(fd);
bool socket_engine_base::update_socket(const socket_events &e) {
if (e.fd > INVALID_SOCKET && fds.find(e.fd) != fds.end()) {
auto iter = fds.find(e.fd);
*(iter->second) = e;
return true;
}
Expand All @@ -53,19 +54,20 @@ bool socket_engine_base::remove_socket(dpp::socket fd) {
}

socket_engine_base::socket_engine_base() {
#ifndef WIN32
signal(SIGALRM, SIG_IGN);
signal(SIGHUP, SIG_IGN);
signal(SIGPIPE, SIG_IGN);
signal(SIGCHLD, SIG_IGN);
signal(SIGXFSZ, SIG_IGN);
#else
// Set up winsock.
WSADATA wsadata;
if (WSAStartup(MAKEWORD(2, 2), &wsadata)) {
throw dpp::connection_exception(err_connect_failure, "WSAStartup failure");
}
#endif
#ifndef WIN32
set_signal_handler(SIGALRM);
set_signal_handler(SIGHUP);
set_signal_handler(SIGPIPE);
set_signal_handler(SIGCHLD);
set_signal_handler(SIGXFSZ);
#else
// Set up winsock.
WSADATA wsadata;
if (WSAStartup(MAKEWORD(2, 2), &wsadata)) {
throw dpp::connection_exception(err_connect_failure, "WSAStartup failure");
}
#endif
pool = std::make_unique<thread_pool>();
}

};
}
92 changes: 59 additions & 33 deletions src/dpp/socketengines/epoll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,39 @@
#include <sys/socket.h>
#include <unistd.h>
#include <vector>
#include <iostream>

namespace dpp {

int modify_event(int epoll_handle, socket_events* eh, int new_events) {
if (new_events != eh->flags) {
struct epoll_event new_ev{};
new_ev.events = EPOLLET;
if ((new_events & WANT_READ) != 0) {
new_ev.events |= EPOLLIN;
}
if ((new_events & WANT_WRITE) != 0) {
new_ev.events |= EPOLLOUT;
}
if ((new_events & WANT_ERROR) != 0) {
new_ev.events |= EPOLLERR;
}
new_ev.data.ptr = static_cast<void *>(eh);
epoll_ctl(epoll_handle, EPOLL_CTL_MOD, eh->fd, &new_ev);
}
return new_events;
}

struct socket_engine_epoll : public socket_engine_base {

int epoll_handle{INVALID_SOCKET};
static const int epoll_hint = 128;
std::vector<struct epoll_event> events;

socket_engine_epoll(const socket_engine_epoll&) = default;
socket_engine_epoll(socket_engine_epoll&&) = default;
socket_engine_epoll& operator=(const socket_engine_epoll&) = default;
socket_engine_epoll& operator=(socket_engine_epoll&&) = default;
socket_engine_epoll(const socket_engine_epoll&) = delete;
socket_engine_epoll(socket_engine_epoll&&) = delete;
socket_engine_epoll& operator=(const socket_engine_epoll&) = delete;
socket_engine_epoll& operator=(socket_engine_epoll&&) = delete;

socket_engine_epoll() : epoll_handle(epoll_create(socket_engine_epoll::epoll_hint)) {
events.resize(socket_engine_epoll::epoll_hint);
Expand All @@ -61,50 +81,55 @@ struct socket_engine_epoll : public socket_engine_base {
epoll_event ev = events[j];

auto* const eh = static_cast<socket_events*>(ev.data.ptr);
const int fd = ev.data.fd;
const int fd = eh->fd;
if (fd == INVALID_SOCKET) {
continue;
}

if ((ev.events & EPOLLHUP) != 0U) {
eh->on_error(fd, *eh, 0);
eh->flags = modify_event(epoll_handle, eh, eh->flags & ~WANT_ERROR);
pool->enqueue([this, eh, fd]() {
eh->on_error(fd, *eh, 0);
eh->flags = modify_event(epoll_handle, eh, eh->flags | WANT_ERROR);
});
continue;
}

if ((ev.events & EPOLLERR) != 0U) {
/* Get error number */
socklen_t codesize = sizeof(int);
int errcode{};
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &errcode, &codesize) < 0) {
errcode = errno;
}
eh->on_error(fd, *eh, errcode);
eh->flags = modify_event(epoll_handle, eh, eh->flags & ~WANT_ERROR);
pool->enqueue([this, eh, fd]() {
socklen_t codesize = sizeof(int);
int errcode{};
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &errcode, &codesize) < 0) {
errcode = errno;
}
eh->on_error(fd, *eh, errcode);
eh->flags = modify_event(epoll_handle, eh, eh->flags | WANT_ERROR);
});
continue;
}

if ((ev.events & EPOLLOUT) != 0U) {
int new_events = eh->flags & ~WANT_WRITE;
if (new_events != eh->flags) {
ev.events = new_events;
ev.data.ptr = static_cast<void *>(eh);
epoll_ctl(epoll_handle, EPOLL_CTL_MOD, fd, &ev);
}
eh->flags = new_events;
}
if (ev.events & EPOLLIN) {
eh->on_read(fd, *eh);
eh->flags = modify_event(epoll_handle, eh, eh->flags & ~WANT_WRITE);
pool->enqueue([eh, fd]() {eh->on_write(fd, *eh); });
}
if (ev.events & EPOLLOUT) {
eh->on_write(fd, *eh);

if ((ev.events & EPOLLIN) != 0U) {
eh->flags = modify_event(epoll_handle, eh, eh->flags & ~WANT_READ);
pool->enqueue([this, eh, fd]() {
eh->on_read(fd, *eh);
eh->flags = modify_event(epoll_handle, eh, eh->flags | WANT_READ);
});
}
}

}

bool register_socket(dpp::socket fd, const socket_events& e) final {
bool r = socket_engine_base::register_socket(fd, e);
bool register_socket(const socket_events& e) final {
bool r = socket_engine_base::register_socket(e);
if (r) {
struct epoll_event ev{};
ev.events = EPOLLET;
if ((e.flags & WANT_READ) != 0) {
ev.events |= EPOLLIN;
}
Expand All @@ -114,8 +139,8 @@ struct socket_engine_epoll : public socket_engine_base {
if ((e.flags & WANT_ERROR) != 0) {
ev.events |= EPOLLERR;
}
ev.data.ptr = fds.find(fd)->second.get();
int i = epoll_ctl(epoll_handle, EPOLL_CTL_ADD, fd, &ev);
ev.data.ptr = fds.find(e.fd)->second.get();
int i = epoll_ctl(epoll_handle, EPOLL_CTL_ADD, e.fd, &ev);
if (i < 0) {
throw dpp::connection_exception("Failed to register socket to epoll_ctl()");
}
Expand All @@ -126,10 +151,11 @@ struct socket_engine_epoll : public socket_engine_base {
return r;
}

bool update_socket(dpp::socket fd, const socket_events& e) final {
bool r = socket_engine_base::update_socket(fd, e);
bool update_socket(const socket_events& e) final {
bool r = socket_engine_base::update_socket(e);
if (r) {
struct epoll_event ev{};
ev.events = EPOLLET;
if ((e.flags & WANT_READ) != 0) {
ev.events |= EPOLLIN;
}
Expand All @@ -139,8 +165,8 @@ struct socket_engine_epoll : public socket_engine_base {
if ((e.flags & WANT_ERROR) != 0) {
ev.events |= EPOLLERR;
}
ev.data.ptr = fds.find(fd)->second.get();
int i = epoll_ctl(epoll_handle, EPOLL_CTL_MOD, fd, &ev);
ev.data.ptr = fds.find(e.fd)->second.get();
int i = epoll_ctl(epoll_handle, EPOLL_CTL_MOD, e.fd, &ev);
if (i < 0) {
throw dpp::connection_exception("Failed to modify socket with epoll_ctl()");
}
Expand Down
20 changes: 10 additions & 10 deletions src/dpp/socketengines/kqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,31 @@ struct socket_engine_kqueue : public socket_engine_base {
}
}

bool register_socket(dpp::socket fd, const socket_events& e) final {
bool r = socket_engine_base::register_socket(fd, e);
bool register_socket(const socket_events& e) final {
bool r = socket_engine_base::register_socket(e);
if (r) {
struct kevent* ke = get_change_kevent();
socket_events* se = fds.find(fd)->second.get();
socket_events* se = fds.find(e.fd)->second.get();
if ((se->flags & WANT_READ) != 0) {
EV_SET(ke, fd, EVFILT_READ, EV_ADD, 0, 0, static_cast<CAST_TYPE>(se));
EV_SET(ke, e.fd, EVFILT_READ, EV_ADD, 0, 0, static_cast<CAST_TYPE>(se));
}
set_event_write_flags(fd, se, 0, e.flags);
set_event_write_flags(e.fd, se, 0, e.flags);
if (fds.size() * 2 > ke_list.size()) {
ke_list.resize(fds.size() * 2);
}
}
return r;
}

bool update_socket(dpp::socket fd, const socket_events& e) final {
bool r = socket_engine_base::update_socket(fd, e);
bool update_socket(const socket_events& e) final {
bool r = socket_engine_base::update_socket(e);
if (r) {
struct kevent* ke = get_change_kevent();
socket_events* se = fds.find(fd)->second.get();
socket_events* se = fds.find(e.fd)->second.get();
if ((se->flags & WANT_READ) != 0) {
EV_SET(ke, fd, EVFILT_READ, EV_ADD, 0, 0, static_cast<CAST_TYPE>(se));
EV_SET(ke, e.fd, EVFILT_READ, EV_ADD, 0, 0, static_cast<CAST_TYPE>(se));
}
set_event_write_flags(fd, se, 0, e.flags);
set_event_write_flags(e.fd, se, 0, e.flags);
if (fds.size() * 2 > ke_list.size()) {
ke_list.resize(fds.size() * 2);
}
Expand Down
12 changes: 6 additions & 6 deletions src/dpp/socketengines/poll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ struct socket_engine_poll : public socket_engine_base {
}
}

bool register_socket(dpp::socket fd, const socket_events& e) final {
bool r = socket_engine_base::register_socket(fd, e);
bool register_socket(const socket_events& e) final {
bool r = socket_engine_base::register_socket(e);
if (r) {
pollfd fd_info{};
fd_info.fd = fd;
fd_info.fd = e.fd;
fd_info.events = 0;
if ((e.flags & WANT_READ) != 0) {
fd_info.events |= POLLIN;
Expand All @@ -121,12 +121,12 @@ struct socket_engine_poll : public socket_engine_base {
return r;
}

bool update_socket(dpp::socket fd, const socket_events& e) final {
bool r = socket_engine_base::update_socket(fd, e);
bool update_socket(const socket_events& e) final {
bool r = socket_engine_base::update_socket(e);
if (r) {
/* We know this will succeed */
for (pollfd& fd_info : poll_set) {
if (fd_info.fd != fd) {
if (fd_info.fd != e.fd) {
continue;
}
fd_info.events = 0;
Expand Down
Loading

0 comments on commit 22bcdda

Please sign in to comment.