diff --git a/include/dpp/discordclient.h b/include/dpp/discordclient.h index d57c2d775b..bb7c4a0974 100644 --- a/include/dpp/discordclient.h +++ b/include/dpp/discordclient.h @@ -179,6 +179,13 @@ class DPP_EXPORT discord_client : public websocket_client */ void disconnect_voice_internal(snowflake guild_id, bool send_json = true); + /** + * @brief Start connecting the websocket + * + * Called from the constructor, or during reconnection + */ + void start_connecting(); + private: /** @@ -191,17 +198,6 @@ class DPP_EXPORT discord_client : public websocket_client */ std::deque message_queue; - /** - * @brief Thread this shard is executing on - */ - std::thread* runner; - - /** - * @brief Run shard loop under a thread. - * Calls discord_client::run() from within a std::thread. - */ - void thread_run(); - /** * @brief If true, stream compression is enabled */ @@ -499,6 +495,11 @@ class DPP_EXPORT discord_client : public websocket_client */ void run(); + /** + * @brief Called when the HTTP socket is closed + */ + virtual void on_disconnect(); + /** * @brief Connect to a voice channel * diff --git a/include/dpp/thread_pool.h b/include/dpp/thread_pool.h index ef15ab33ea..1f40c125ec 100644 --- a/include/dpp/thread_pool.h +++ b/include/dpp/thread_pool.h @@ -41,7 +41,7 @@ struct DPP_EXPORT thread_pool_task { }; struct DPP_EXPORT thread_pool_task_comparator { - bool operator()(const thread_pool_task &a, const thread_pool_task &b) { + bool operator()(const thread_pool_task &a, const thread_pool_task &b) const { return a.priority < b.priority; }; }; @@ -57,7 +57,12 @@ struct DPP_EXPORT thread_pool { std::condition_variable cv; bool stop{false}; - explicit thread_pool(size_t num_threads = std::thread::hardware_concurrency()); + /** + * @brief Create a new priority thread pool + * @param creator creating cluster (for logging) + * @param num_threads number of threads in the pool + */ + explicit thread_pool(class cluster* creator, size_t num_threads = std::thread::hardware_concurrency()); ~thread_pool(); void enqueue(thread_pool_task task); }; diff --git a/include/dpp/wsclient.h b/include/dpp/wsclient.h index 73f8eb9321..c7fc99d4ee 100644 --- a/include/dpp/wsclient.h +++ b/include/dpp/wsclient.h @@ -235,6 +235,11 @@ class DPP_EXPORT websocket_client : public ssl_client { * This indicates graceful close. */ void send_close_packet(); + + /** + * @brief Called on HTTP socket closure + */ + virtual void on_disconnect(); }; } diff --git a/src/dpp/cluster.cpp b/src/dpp/cluster.cpp index c09616b9b1..604deff496 100644 --- a/src/dpp/cluster.cpp +++ b/src/dpp/cluster.cpp @@ -89,7 +89,7 @@ cluster::cluster(const std::string &_token, uint32_t _intents, uint32_t _shards, numshards(_shards), cluster_id(_cluster_id), maxclusters(_maxclusters), rest_ping(0.0), cache_policy(policy), ws_mode(ws_json) { socketengine = create_socket_engine(this); - pool = std::make_unique(request_threads); + pool = std::make_unique(this, request_threads); /* Instantiate REST request queues */ try { rest = new request_queue(this, request_threads); diff --git a/src/dpp/discordclient.cpp b/src/dpp/discordclient.cpp index 15b8fd24f2..50ba878ced 100644 --- a/src/dpp/discordclient.cpp +++ b/src/dpp/discordclient.cpp @@ -65,8 +65,7 @@ thread_local static std::string last_ping_message; discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint32_t _max_shards, const std::string &_token, uint32_t _intents, bool comp, websocket_protocol_t ws_proto) : websocket_client(_cluster, _cluster->default_gateway, "443", comp ? (ws_proto == ws_json ? PATH_COMPRESSED_JSON : PATH_COMPRESSED_ETF) : (ws_proto == ws_json ? PATH_UNCOMPRESSED_JSON : PATH_UNCOMPRESSED_ETF)), - terminating(false), - runner(nullptr), + terminating(false), compressed(comp), decomp_buffer(nullptr), zlib(nullptr), @@ -90,6 +89,10 @@ discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint3 protocol(ws_proto), resume_gateway_url(_cluster->default_gateway) { + start_connecting(); +} + +void discord_client::start_connecting() { try { zlib = new zlibcontext(); etf = new etf_parser(); @@ -112,10 +115,6 @@ discord_client::discord_client(dpp::cluster* _cluster, uint32_t _shard_id, uint3 void discord_client::cleanup() { terminating = true; - if (runner) { - runner->join(); - delete runner; - } delete etf; delete zlib; } @@ -125,6 +124,19 @@ discord_client::~discord_client() cleanup(); } +void discord_client::on_disconnect() +{ + set_resume_hostname(); + log(dpp::ll_debug, "Lost connection to websocket on shard " + std::to_string(shard_id) + ", reconnecting in 5 seconds..."); + owner->start_timer([this](auto handle) { + owner->stop_timer(handle); + cleanup(); + terminating = false; + start_connecting(); + run(); + }, 5); +} + uint64_t discord_client::get_decompressed_bytes_in() { return decompressed_total; @@ -159,20 +171,12 @@ void discord_client::set_resume_hostname() hostname = resume_gateway_url; } -void discord_client::thread_run() -{ -} - void discord_client::run() { - // TODO: This only runs once. Replace the reconnect mechanics. - // To make this work, we will need to intercept errors. setup_zlib(); ready = false; message_queue.clear(); ssl_client::read_loop(); - //ssl_client::close(); - //end_zlib(); } bool discord_client::handle_frame(const std::string &buffer, ws_opcode opcode) diff --git a/src/dpp/thread_pool.cpp b/src/dpp/thread_pool.cpp index 2ee5624117..c7a64bc176 100644 --- a/src/dpp/thread_pool.cpp +++ b/src/dpp/thread_pool.cpp @@ -23,12 +23,13 @@ #include #include #include +#include namespace dpp { -thread_pool::thread_pool(size_t num_threads) { +thread_pool::thread_pool(cluster* creator, size_t num_threads) { for (size_t i = 0; i < num_threads; ++i) { - threads.emplace_back([this, i]() { + threads.emplace_back([this, i, creator]() { dpp::utility::set_thread_name("pool/exec/" + std::to_string(i)); while (true) { thread_pool_task task; @@ -47,7 +48,15 @@ thread_pool::thread_pool(size_t num_threads) { tasks.pop(); } - task.function(); + try { + task.function(); + } + catch (const std::exception &e) { + creator->log(ll_warning, "Uncaught exception in thread pool: " + std::string(e.what())); + } + catch (...) { + creator->log(ll_warning, "Uncaught exception in thread pool, but not derived from std::exception!"); + } } }); } diff --git a/src/dpp/wsclient.cpp b/src/dpp/wsclient.cpp index e5ebe762de..328b5d8d8e 100644 --- a/src/dpp/wsclient.cpp +++ b/src/dpp/wsclient.cpp @@ -326,8 +326,13 @@ void websocket_client::error(uint32_t errorcode) { } +void websocket_client::on_disconnect() +{ +} + void websocket_client::close() { + this->on_disconnect(); this->state = HTTP_HEADERS; ssl_client::close(); } diff --git a/src/soaktest/soak.cpp b/src/soaktest/soak.cpp index de82ece80b..cbc462c194 100644 --- a/src/soaktest/soak.cpp +++ b/src/soaktest/soak.cpp @@ -35,12 +35,6 @@ int main() { }); soak_test.start(dpp::st_return); - dpp::https_client c2(&soak_test, "github.com", 80, "/", "GET", "", {}, true, 2, "1.1", [](dpp::https_client* c2) { - std::string hdr2 = c2->get_header("location"); - std::string content2 = c2->get_content(); - std::cout << "hdr2 == " << hdr2 << " ? https://github.com/ status = " << c2->get_status() << "\n"; - }); - while (true) { std::this_thread::sleep_for(60s); dpp::discord_client* dc = soak_test.get_shard(0);