Skip to content

Commit

Permalink
fixes for poll on windows
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Nov 20, 2024
1 parent 2125d1b commit 93fc814
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 77 deletions.
6 changes: 3 additions & 3 deletions include/dpp/dns.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace dpp {
* @brief Represents a cached DNS result.
* Used by the ssl_client class to store cached copies of dns lookups.
*/
struct dns_cache_entry {
struct DPP_EXPORT dns_cache_entry {
/**
* @brief Resolved address metadata
*/
Expand Down Expand Up @@ -93,5 +93,5 @@ namespace dpp {
* @return dns_cache_entry* First IP address associated with the hostname DNS record
* @throw dpp::connection_exception On failure to resolve hostname
*/
const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port);
}
DPP_EXPORT const dns_cache_entry *resolve_hostname(const std::string &hostname, const std::string &port);
}
6 changes: 3 additions & 3 deletions include/dpp/socketengine.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ using socket_error_event = std::function<void(dpp::socket fd, const struct socke
* storm which will consume 100% CPU (e.g. if you request to receive write events all
* the time).
*/
struct socket_events {
struct DPP_EXPORT socket_events {
/**
* @brief File descriptor
*
Expand Down Expand Up @@ -143,7 +143,7 @@ using socket_container = std::unordered_map<dpp::socket, std::unique_ptr<socket_
* out implementation-specific behaviours (e.g. difference between edge and level triggered
* event mechanisms etc).
*/
struct socket_engine_base {
struct DPP_EXPORT socket_engine_base {
/**
* @brief File descriptors, and their states
*/
Expand Down Expand Up @@ -242,7 +242,7 @@ struct socket_engine_base {
* @brief This is implemented by whatever derived form socket_engine takes
* @param creator Creating cluster
*/
std::unique_ptr<socket_engine_base> create_socket_engine(class cluster* creator);
DPP_EXPORT std::unique_ptr<socket_engine_base> create_socket_engine(class cluster *creator);

#ifndef _WIN32
void set_signal_handler(int signal);
Expand Down
4 changes: 2 additions & 2 deletions include/dpp/sslclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ typedef std::function<void()> socket_notification_t;
* @param sfd Socket to close
* @return false on error, true on success
*/
bool close_socket(dpp::socket sfd);
DPP_EXPORT bool close_socket(dpp::socket sfd);

/**
* @brief Set a socket to blocking or non-blocking IO
Expand All @@ -63,7 +63,7 @@ bool close_socket(dpp::socket sfd);
* @param non_blocking should socket be non-blocking?
* @return false on error, true on success
*/
bool set_nonblocking(dpp::socket sockfd, bool non_blocking);
DPP_EXPORT bool set_nonblocking(dpp::socket sockfd, bool non_blocking);

/* You'd think that we would get better performance with a bigger buffer, but SSL frames are 16k each.
* SSL_read in non-blocking mode will only read 16k at a time. There's no point in a bigger buffer as
Expand Down
16 changes: 10 additions & 6 deletions include/dpp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,29 @@
#include <mutex>
#include <functional>

namespace dpp {

using work_unit = std::function<void()>;

/**
* A task within a thread pool. A simple lambda that accepts no parameters and returns void.
*/
struct thread_pool_task {
struct DPP_EXPORT thread_pool_task {
int priority;
work_unit function;
};

struct thread_pool_task_comparator {
bool operator()(const thread_pool_task &a, const thread_pool_task &b) {
return a.priority < b.priority;
};
struct DPP_EXPORT thread_pool_task_comparator {
bool operator()(const thread_pool_task &a, const thread_pool_task &b) {
return a.priority < b.priority;
};
};

/**
* @brief A thread pool contains 1 or more worker threads which accept thread_pool_task lambadas
* into a queue, which is processed in-order by whichever thread is free.
*/
struct thread_pool {
struct DPP_EXPORT thread_pool {
std::vector<std::thread> threads;
std::priority_queue<thread_pool_task, std::vector<thread_pool_task>, thread_pool_task_comparator> tasks;
std::mutex queue_mutex;
Expand All @@ -59,3 +61,5 @@ struct thread_pool {
~thread_pool();
void enqueue(thread_pool_task task);
};

}
2 changes: 1 addition & 1 deletion mlspp/include/namespace.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#pragma once

// Configurable top-level MLS namespace
#define MLS_NAMESPACE ../include/dpp/mlspp/mls
#define MLS_NAMESPACE mls
3 changes: 1 addition & 2 deletions src/dpp/dns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ socket dns_cache_entry::make_connecting_socket() const {
return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
}

const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port)
{
const dns_cache_entry *resolve_hostname(const std::string &hostname, const std::string &port) {
addrinfo hints, *addrs;
dns_cache_t::const_iterator iter;
time_t now = time(nullptr);
Expand Down
8 changes: 4 additions & 4 deletions src/dpp/socketengine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
namespace dpp {

bool socket_engine_base::register_socket(const socket_events &e) {
if (e.fd > INVALID_SOCKET && fds.find(e.fd) == fds.end()) {
if (e.fd != INVALID_SOCKET && fds.find(e.fd) == fds.end()) {
fds.emplace(e.fd, std::make_unique<socket_events>(e));
return true;
}
return false;
}

bool socket_engine_base::update_socket(const socket_events &e) {
if (e.fd > INVALID_SOCKET && fds.find(e.fd) != fds.end()) {
if (e.fd != INVALID_SOCKET && fds.find(e.fd) != fds.end()) {
auto iter = fds.find(e.fd);
*(iter->second) = e;
return true;
Expand All @@ -48,7 +48,7 @@ bool socket_engine_base::update_socket(const socket_events &e) {
}

socket_engine_base::socket_engine_base(cluster* creator) : owner(creator) {
#ifndef WIN32
#ifndef _WIN32
set_signal_handler(SIGALRM);
set_signal_handler(SIGXFSZ);
set_signal_handler(SIGCHLD);
Expand Down Expand Up @@ -108,7 +108,7 @@ bool socket_engine_base::delete_socket(dpp::socket fd) {
}

bool socket_engine_base::remove_socket(dpp::socket fd) {
return false;
return true;
}

}
4 changes: 2 additions & 2 deletions src/dpp/socketengines/epoll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ int modify_event(int epoll_handle, socket_events* eh, int new_events) {
return new_events;
}

struct socket_engine_epoll : public socket_engine_base {
struct DPP_EXPORT socket_engine_epoll : public socket_engine_base {

int epoll_handle{INVALID_SOCKET};
static const int epoll_hint = 128;
Expand Down Expand Up @@ -196,7 +196,7 @@ struct socket_engine_epoll : public socket_engine_base {
}
};

std::unique_ptr<socket_engine_base> create_socket_engine(cluster* creator) {
DPP_EXPORT std::unique_ptr<socket_engine_base> create_socket_engine(cluster *creator) {
return std::make_unique<socket_engine_epoll>(creator);
}

Expand Down
4 changes: 2 additions & 2 deletions src/dpp/socketengines/kqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

namespace dpp {

struct socket_engine_kqueue : public socket_engine_base {
struct DPP_EXPORT socket_engine_kqueue : public socket_engine_base {

int kqueue_handle{INVALID_SOCKET};
std::array<struct kevent, 65536> ke_list;
Expand Down Expand Up @@ -147,7 +147,7 @@ struct socket_engine_kqueue : public socket_engine_base {
}
};

std::unique_ptr<socket_engine_base> create_socket_engine(cluster* creator) {
DPP_EXPORT std::unique_ptr<socket_engine_base> create_socket_engine(cluster *creator) {
return std::make_unique<socket_engine_kqueue>(creator);
}

Expand Down
92 changes: 49 additions & 43 deletions src/dpp/socketengines/poll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

namespace dpp {

struct socket_engine_poll : public socket_engine_base {
struct DPP_EXPORT socket_engine_poll : public socket_engine_base {

/* We store the pollfds as a vector. This means that insertion, deletion and updating
* are comparatively slow O(n), but these operations don't happen too often. Obtaining the
Expand All @@ -53,56 +53,68 @@ struct socket_engine_poll : public socket_engine_base {
* anyway.
*/
std::vector<pollfd> poll_set;
pollfd out_set[FD_SETSIZE]{0};

void process_events() final {
const int poll_delay = 1000;
int i = poll(poll_set.data(), static_cast<unsigned int>(poll_set.size()), poll_delay);
int processed = 0;

for (size_t index = 0; index < poll_set.size() && processed < i; index++) {
const int fd = poll_set[index].fd;
const short revents = poll_set[index].revents;

if (revents > 0) {
processed++;
if (poll_set.empty()) {
/* On many platforms, it is not possible to wait on an empty set */
std::this_thread::sleep_for(std::chrono::milliseconds(10));
} else {
if (poll_set.size() > FD_SETSIZE) {
throw dpp::connection_exception("poll() does not support more than FD_SETSIZE active sockets at once!");
}

auto iter = fds.find(fd);
if (iter == fds.end()) {
continue;
}
socket_events* eh = iter->second.get();
std::copy(poll_set.begin(), poll_set.end(), out_set);

try {
int i = poll(out_set, static_cast<unsigned int>(poll_set.size()), poll_delay);
int processed = 0;

if ((revents & POLLHUP) != 0) {
eh->on_error(fd, *eh, 0);
continue;
for (size_t index = 0; index < poll_set.size() && processed < i; index++) {
const int fd = out_set[index].fd;
const short revents = out_set[index].revents;

if (revents > 0) {
processed++;
}

if ((revents & POLLERR) != 0) {
socklen_t codesize = sizeof(int);
int errcode{};
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char*)&errcode, &codesize) < 0) {
errcode = errno;
}
eh->on_error(fd, *eh, errcode);
auto iter = fds.find(fd);
if (iter == fds.end()) {
continue;
}
socket_events *eh = iter->second.get();

if ((revents & POLLIN) != 0) {
eh->on_read(fd, *eh);
}
try {

if ((revents & POLLOUT) != 0) {
int mask = eh->flags;
mask &= ~WANT_WRITE;
eh->flags = mask;
eh->on_write(fd, *eh);
}
if ((revents & POLLHUP) != 0) {
eh->on_error(fd, *eh, 0);
continue;
}

} catch (const std::exception& e) {
eh->on_error(fd, *eh, 0);
if ((revents & POLLERR) != 0) {
socklen_t codesize = sizeof(int);
int errcode{};
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *) &errcode, &codesize) < 0) {
errcode = errno;
}
eh->on_error(fd, *eh, errcode);
continue;
}

if ((revents & POLLIN) != 0) {
eh->on_read(fd, *eh);
}

if ((revents & POLLOUT) != 0) {
eh->flags &= ~WANT_WRITE;
update_socket(*eh);
eh->on_write(fd, *eh);
}

} catch (const std::exception &e) {
eh->on_error(fd, *eh, 0);
}
}
}
prune();
Expand All @@ -126,9 +138,6 @@ struct socket_engine_poll : public socket_engine_base {
if ((e.flags & WANT_WRITE) != 0) {
fd_info.events |= POLLOUT;
}
if ((e.flags & WANT_ERROR) != 0) {
fd_info.events |= POLLERR;
}
poll_set.push_back(fd_info);
}
return r;
Expand All @@ -149,9 +158,6 @@ struct socket_engine_poll : public socket_engine_base {
if ((e.flags & WANT_WRITE) != 0) {
fd_info.events |= POLLOUT;
}
if ((e.flags & WANT_ERROR) != 0) {
fd_info.events |= POLLERR;
}
break;
}
}
Expand All @@ -176,7 +182,7 @@ struct socket_engine_poll : public socket_engine_base {
}
};

std::unique_ptr<socket_engine_base> create_socket_engine(cluster* creator) {
DPP_EXPORT std::unique_ptr<socket_engine_base> create_socket_engine(cluster* creator) {
return std::make_unique<socket_engine_poll>(creator);
}

Expand Down
12 changes: 7 additions & 5 deletions src/dpp/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <dpp/thread_pool.h>
#include <shared_mutex>

namespace dpp {

thread_pool::thread_pool(size_t num_threads) {
for (size_t i = 0; i < num_threads; ++i) {
threads.emplace_back([this, i]() {
Expand Down Expand Up @@ -51,24 +53,24 @@ thread_pool::thread_pool(size_t num_threads) {
}
}

thread_pool::~thread_pool()
{
thread_pool::~thread_pool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}

cv.notify_all();
for (auto& thread : threads) {
for (auto &thread: threads) {
thread.join();
}
}

void thread_pool::enqueue(thread_pool_task task)
{
void thread_pool::enqueue(thread_pool_task task) {
{
std::unique_lock<std::mutex> lock(queue_mutex);
tasks.emplace(std::move(task));
}
cv.notify_one();
}

}
4 changes: 4 additions & 0 deletions src/dpp/voice/enabled/discover_ip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,11 @@ std::string discord_voice_client::discover_ip() {
return "";
}
address_t bind_port(this->ip, this->port);
#ifndef _WIN32
if (::connect(socket.fd, bind_port.get_socket_address(), bind_port.size()) < 0) {
#else
if (WSAConnect(socket.fd, bind_port.get_socket_address(), bind_port.size(), nullptr, nullptr, nullptr, nullptr) < 0) {
#endif
log(ll_warning, "Could not connect socket for IP discovery");
return "";
}
Expand Down
Loading

0 comments on commit 93fc814

Please sign in to comment.