From 8542b08f2d861701906f0d29f1d7b425358ba4a5 Mon Sep 17 00:00:00 2001 From: Craig Edwards Date: Mon, 18 Nov 2024 15:38:40 +0000 Subject: [PATCH] fix: add auto retry to failed connect, fixes failed unit tests --- include/dpp/httpsclient.h | 12 +-- include/dpp/sslclient.h | 31 +++++++ src/dpp/events/message_create.cpp | 12 +-- src/dpp/httpsclient.cpp | 34 +++++--- src/dpp/message.cpp | 9 +- src/dpp/sslclient.cpp | 136 ++++++++++++++++++------------ src/unittest/test.cpp | 46 +++++----- 7 files changed, 177 insertions(+), 103 deletions(-) diff --git a/include/dpp/httpsclient.h b/include/dpp/httpsclient.h index fdb2336aea..46e3b73901 100644 --- a/include/dpp/httpsclient.h +++ b/include/dpp/httpsclient.h @@ -137,10 +137,6 @@ using https_client_completion_event = std::function; * @note plaintext HTTP without SSL is also supported via a "downgrade" setting */ class DPP_EXPORT https_client : public ssl_client { - /** - * @brief Current connection state - */ - http_state state; /** * @brief The type of the request, e.g. GET, POST @@ -241,7 +237,12 @@ class DPP_EXPORT https_client : public ssl_client { * @brief Function to call when HTTP request is completed */ https_client_completion_event completed; - + + /** + * @brief Current connection state + */ + http_state state; + /** * @brief Connect to a specific HTTP(S) server and complete a request. * @@ -361,7 +362,6 @@ class DPP_EXPORT https_client : public ssl_client { * @return Split URL */ static http_connect_info get_host_info(std::string url); - }; } diff --git a/include/dpp/sslclient.h b/include/dpp/sslclient.h index 7d97e26945..5266e70ab0 100644 --- a/include/dpp/sslclient.h +++ b/include/dpp/sslclient.h @@ -147,6 +147,16 @@ class DPP_EXPORT ssl_client */ time_t last_tick; + /** + * @brief Start time of connection + */ + time_t start; + + /** + * @brief How many times we retried connect() + */ + uint8_t connect_retries{0}; + /** * @brief Hostname connected to */ @@ -177,11 +187,24 @@ class DPP_EXPORT ssl_client */ bool connected{false}; + /** + * @brief True if tcp connect() succeeded + */ + bool tcp_connect_done{false}; + /** * @brief Timer handle for one second timer */ timer timer_handle; + /** + * @brief Unique ID of socket used as a nonce + * You can use this to identify requests vs reply + * if you want. D++ itself only sets this, and does + * not use it in any logic. It starts at 1 and increments + * for each request made. + */ + uint64_t unique_id; /** * @brief Called every second @@ -207,6 +230,14 @@ class DPP_EXPORT ssl_client */ uint64_t get_bytes_in(); + /** + * @brief Every request made has a unique ID. This increments + * for every request, starting at 1. You can use this for statistics, + * or to associate requests and replies in external event loops. + * @return Unique ID + */ + uint64_t get_unique_id() const; + /** * @brief Get SSL cipher name * @return std::string ssl cipher name diff --git a/src/dpp/events/message_create.cpp b/src/dpp/events/message_create.cpp index e017438ab8..9c79d11d57 100644 --- a/src/dpp/events/message_create.cpp +++ b/src/dpp/events/message_create.cpp @@ -22,7 +22,6 @@ #include #include #include -#include #include @@ -39,11 +38,12 @@ namespace dpp::events { void message_create::handle(discord_client* client, json &j, const std::string &raw) { if (!client->creator->on_message_create.empty()) { - json d = j["d"]; - dpp::message_create_t msg(client, raw); - msg.msg.fill_from_json(&d, client->creator->cache_policy); - msg.msg.owner = client->creator; - client->creator->queue_work(1, [client, msg]() { + json js = j; + client->creator->queue_work(1, [client, js, raw]() { + json d = js["d"]; + dpp::message_create_t msg(client, raw); + msg.msg = message(client->owner).fill_from_json(&d, client->creator->cache_policy); + msg.msg.owner = client->creator; client->creator->on_message_create.call(msg); }); } diff --git a/src/dpp/httpsclient.cpp b/src/dpp/httpsclient.cpp index 2428e030de..30835f3980 100644 --- a/src/dpp/httpsclient.cpp +++ b/src/dpp/httpsclient.cpp @@ -32,7 +32,6 @@ namespace dpp { https_client::https_client(cluster* creator, const std::string &hostname, uint16_t port, const std::string &urlpath, const std::string &verb, const std::string &req_body, const http_headers& extra_headers, bool plaintext_connection, uint16_t request_timeout, const std::string &protocol, https_client_completion_event done) : ssl_client(creator, hostname, std::to_string(port), plaintext_connection, false), - state(HTTPS_HEADERS), request_type(verb), path(urlpath), request_body(req_body), @@ -42,7 +41,8 @@ https_client::https_client(cluster* creator, const std::string &hostname, uint16 http_protocol(protocol), timeout(time(nullptr) + request_timeout), timed_out(false), - completed(done) + completed(done), + state(HTTPS_HEADERS) { nonblocking = false; https_client::connect(); @@ -157,6 +157,10 @@ bool https_client::handle_buffer(std::string &buffer) switch (state) { case HTTPS_HEADERS: if (buffer.find("\r\n\r\n") != std::string::npos) { + + /* Add 10 seconds to retrieve body */ + timeout += 10; + /* Got all headers, proceed to new state */ std::string unparsed = buffer; @@ -211,6 +215,10 @@ bool https_client::handle_buffer(std::string &buffer) state_changed = true; continue; } + if (!buffer.empty()) { + /* Got a bit of body content in the same read as the headers */ + continue; + } return true; } else { /* Non-HTTP-like response with invalid headers. Go no further. */ @@ -242,11 +250,11 @@ bool https_client::handle_buffer(std::string &buffer) case HTTPS_CHUNK_TRAILER: if (buffer.length() >= 2 && buffer.substr(0, 2) == "\r\n") { if (state == HTTPS_CHUNK_LAST) { - state = HTTPS_DONE; if (completed) { completed(this); completed = {}; } + state = HTTPS_DONE; this->close(); return false; } else { @@ -281,11 +289,11 @@ bool https_client::handle_buffer(std::string &buffer) body += buffer; buffer.clear(); if (content_length == ULLONG_MAX || body.length() >= content_length) { - state = HTTPS_DONE; if (completed) { completed(this); completed = {}; } + state = HTTPS_DONE; this->close(); return false; } @@ -317,24 +325,26 @@ http_state https_client::get_state() { } void https_client::one_second_timer() { - if ((this->sfd == SOCKET_ERROR || time(nullptr) >= timeout) && this->state != HTTPS_DONE) { - /* if and only if response is timed out */ - if (this->sfd != SOCKET_ERROR) { - timed_out = true; - } - keepalive = false; + if (!tcp_connect_done && time(nullptr) >= timeout) { + timed_out = true; + this->close(); + } else if (tcp_connect_done && !connected && time(nullptr) >= timeout && this->state != HTTPS_DONE) { this->close(); + timed_out = true; + } else if (time(nullptr) >= timeout && this->state != HTTPS_DONE) { + this->close(); + timed_out = true; } } void https_client::close() { if (state != HTTPS_DONE) { - state = HTTPS_DONE; - ssl_client::close(); if (completed) { completed(this); completed = {}; } + state = HTTPS_DONE; + ssl_client::close(); } } diff --git a/src/dpp/message.cpp b/src/dpp/message.cpp index 76c14495c4..9afa80009a 100644 --- a/src/dpp/message.cpp +++ b/src/dpp/message.cpp @@ -663,8 +663,9 @@ embed::embed() : timestamp(0) { } message::message() : managed(0), channel_id(0), guild_id(0), sent(0), edited(0), webhook_id(0), interaction_metadata{}, - owner(nullptr), type(mt_default), flags(0), pinned(false), tts(false), mention_everyone(false) + type(mt_default), flags(0), pinned(false), tts(false), mention_everyone(false) { + owner = nullptr; message_reference.channel_id = 0; message_reference.guild_id = 0; message_reference.message_id = 0; @@ -1059,8 +1060,10 @@ attachment::attachment(struct message* o, json *j) : attachment(o) { void attachment::download(http_completion_event callback) const { /* Download attachment if there is one attached to this object */ - if (owner == nullptr || owner->owner == nullptr) { - throw dpp::logic_exception(err_no_owning_message, "attachment has no owning message/cluster"); + if (owner == nullptr) { + throw dpp::logic_exception(err_no_owning_message, "attachment has no owning message"); + } else if (owner->owner == nullptr) { + throw dpp::logic_exception(err_no_owning_message, "attachment has no owning cluster"); } if (callback && this->id && !this->url.empty()) { owner->owner->request(this->url, dpp::m_get, callback); diff --git a/src/dpp/sslclient.cpp b/src/dpp/sslclient.cpp index 3d21ade6f2..934e1e0103 100644 --- a/src/dpp/sslclient.cpp +++ b/src/dpp/sslclient.cpp @@ -54,7 +54,6 @@ #include #include #include -#include #include #include #include @@ -64,6 +63,8 @@ constexpr uint16_t SOCKET_OP_TIMEOUT{5000}; namespace dpp { +uint64_t last_unique_id{1}; + /** * @brief This is an opaque class containing openssl library specific structures. * We define it this way so that the public facing D++ library doesn't require @@ -182,17 +183,23 @@ void set_signal_handler(int signal) } #endif +uint64_t ssl_client::get_unique_id() const { + return unique_id; +} + ssl_client::ssl_client(cluster* creator, const std::string &_hostname, const std::string &_port, bool plaintext_downgrade, bool reuse) : nonblocking(false), sfd(INVALID_SOCKET), ssl(nullptr), last_tick(time(nullptr)), + start(time(nullptr)), hostname(_hostname), port(_port), bytes_out(0), bytes_in(0), plaintext(plaintext_downgrade), timer_handle(0), + unique_id(last_unique_id++), keepalive(reuse), owner(creator) { @@ -254,6 +261,9 @@ void ssl_client::complete_handshake(const socket_events* ev) switch (code) { case SSL_ERROR_NONE: { connected = true; + socket_events se{*ev}; + se.flags = dpp::WANT_READ | dpp::WANT_WRITE | dpp::WANT_ERROR; + owner->socketengine->update_socket(se); break; } case SSL_ERROR_WANT_WRITE: { @@ -297,7 +307,7 @@ void ssl_client::on_read(socket fd, const struct socket_events& ev) { } bytes_in += r; } else if (!plaintext && connected) { - int r = SSL_read(ssl->ssl,server_to_client_buffer,DPP_BUFSIZE); + int r = SSL_read(ssl->ssl, server_to_client_buffer, DPP_BUFSIZE); int e = SSL_get_error(ssl->ssl,r); switch (e) { @@ -307,6 +317,7 @@ void ssl_client::on_read(socket fd, const struct socket_events& ev) { buffer.append(server_to_client_buffer, r); if (!this->handle_buffer(buffer)) { + this->close(); return; } else { socket_events se{ev}; @@ -320,16 +331,15 @@ void ssl_client::on_read(socket fd, const struct socket_events& ev) { /* End of data */ SSL_shutdown(ssl->ssl); return; - break; case SSL_ERROR_WANT_READ: { socket_events se{ev}; se.flags = WANT_READ | WANT_ERROR; owner->socketengine->update_socket(se); break; } - /* We get a WANT_WRITE if we're trying to rehandshake, and we block on a write during that rehandshake. - * We need to wait on the socket to be writeable but initiate the read when it is - */ + /* We get a WANT_WRITE if we're trying to rehandshake, and we block on a write during that rehandshake. + * We need to wait on the socket to be writeable but initiate the read when it is + */ case SSL_ERROR_WANT_WRITE: { socket_events se{ev}; se.flags = WANT_READ | WANT_WRITE | WANT_ERROR; @@ -362,9 +372,49 @@ void ssl_client::on_read(socket fd, const struct socket_events& ev) { void ssl_client::on_write(socket fd, const struct socket_events& e) { + if (!tcp_connect_done) { + tcp_connect_done = true; + } if (!connected && plaintext) { /* Plaintext sockets connect immediately on first write event */ connected = true; + } else if (!connected) { + /* SSL handshake and session setup */ + + /* Each thread needs a context, but we don't need to make a new one for each connection */ + if (!openssl_context) { + /* We're good to go - hand the fd over to openssl */ + const SSL_METHOD *method = TLS_client_method(); /* Create new client-method instance */ + + /* Create SSL context */ + openssl_context.reset(SSL_CTX_new(method)); + if (!openssl_context) { + throw dpp::connection_exception(err_ssl_context, "Failed to create SSL client context!"); + } + + /* Do not allow SSL 3.0, TLS 1.0 or 1.1 + * https://www.packetlabs.net/posts/tls-1-1-no-longer-secure/ + */ + if (!SSL_CTX_set_min_proto_version(openssl_context.get(), TLS1_2_VERSION)) { + throw dpp::connection_exception(err_ssl_version, "Failed to set minimum SSL version!"); + } + } + if (!ssl->ssl) { + /* Create SSL session */ + ssl->ssl = SSL_new(openssl_context.get()); + if (ssl->ssl == nullptr) { + throw dpp::connection_exception(err_ssl_new, "SSL_new failed!"); + } + + SSL_set_fd(ssl->ssl, (int) sfd); + SSL_set_connect_state(ssl->ssl); + + /* Server name identification (SNI) */ + SSL_set_tlsext_host_name(ssl->ssl, hostname.c_str()); + } + + /* If this completes, we fall straight through into if (connected) */ + complete_handshake(&e); } if (connected) { @@ -426,44 +476,6 @@ void ssl_client::on_write(socket fd, const struct socket_events& e) { } } } - } else { - if (!plaintext) { - /* Each thread needs a context, but we don't need to make a new one for each connection */ - if (!openssl_context) { - /* We're good to go - hand the fd over to openssl */ - const SSL_METHOD *method = TLS_client_method(); /* Create new client-method instance */ - - /* Create SSL context */ - openssl_context.reset(SSL_CTX_new(method)); - if (!openssl_context) { - throw dpp::connection_exception(err_ssl_context, "Failed to create SSL client context!"); - } - - /* Do not allow SSL 3.0, TLS 1.0 or 1.1 - * https://www.packetlabs.net/posts/tls-1-1-no-longer-secure/ - */ - if (!SSL_CTX_set_min_proto_version(openssl_context.get(), TLS1_2_VERSION)) { - throw dpp::connection_exception(err_ssl_version, "Failed to set minimum SSL version!"); - } - } - if (!ssl->ssl) { - /* Create SSL session */ - ssl->ssl = SSL_new(openssl_context.get()); - if (ssl->ssl == nullptr) { - throw dpp::connection_exception(err_ssl_new, "SSL_new failed!"); - } - - SSL_set_fd(ssl->ssl, (int) sfd); - SSL_set_connect_state(ssl->ssl); - - /* Server name identification (SNI) */ - SSL_set_tlsext_host_name(ssl->ssl, hostname.c_str()); - } - } - } - - if (!connected && !plaintext) { - complete_handshake(&e); } } @@ -473,16 +485,34 @@ void ssl_client::on_error(socket fd, const struct socket_events&, int error_code void ssl_client::read_loop() { - dpp::socket_events events( - sfd, - WANT_READ | WANT_WRITE | WANT_ERROR, - [this](socket fd, const struct socket_events& e) { on_read(fd, e); }, - [this](socket fd, const struct socket_events& e) { on_write(fd, e); }, - [this](socket fd, const struct socket_events& e, int error_code) { on_error(fd, e, error_code); } - ); - owner->socketengine->register_socket(events); - timer_handle = owner->start_timer([this](auto handle) { + auto setup_events = [this]() { + dpp::socket_events events( + sfd, + WANT_READ | WANT_WRITE | WANT_ERROR, + [this](socket fd, const struct socket_events &e) { on_read(fd, e); }, + [this](socket fd, const struct socket_events &e) { on_write(fd, e); }, + [this](socket fd, const struct socket_events &e, int error_code) { on_error(fd, e, error_code); } + ); + owner->socketengine->register_socket(events); + }; + setup_events(); + timer_handle = owner->start_timer([this, setup_events](auto handle) { one_second_timer(); + if (!tcp_connect_done && time(nullptr) > start + 2 && connect_retries < 3) { + /* Retry failed connect(). This can happen even in the best situation with bullet-proof hosting. + * Previously with blocking connect() there was some leniency in this, but now we have to do this + * ourselves. + * + * Retry up to 3 times, 2 seconds between retries. After this, give up and let timeout code + * take the wheel (will likely end with an exception). + */ + close_socket(sfd); + owner->socketengine->delete_socket(sfd); + ssl_client::connect(); + setup_events(); + start = time(nullptr) + 2; + connect_retries++; + } }, 1); } diff --git a/src/unittest/test.cpp b/src/unittest/test.cpp index f9cbe25e9f..b2cafa8cc2 100644 --- a/src/unittest/test.cpp +++ b/src/unittest/test.cpp @@ -1144,8 +1144,8 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b if (g) { set_test(CACHE, true); set_test(VOICECONN, false); - dpp::discord_client* s = bot.get_shard(0); - s->connect_voice(g->id, TEST_VC_ID, false, false); + //dpp::discord_client* s = bot.get_shard(0); + //s->connect_voice(g->id, TEST_VC_ID, false, false); } else { set_test(CACHE, false); @@ -1602,7 +1602,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b }); bool message_tested = false; - bot.on_message_create([&](const dpp::message_create_t & event) { + bot.on_message_create([&message_tested,&bot,&message_helper,&thread_helper](const dpp::message_create_t & event) { if (event.msg.author.id == bot.me.id) { if (event.msg.content == "test message" && !message_tested) { message_tested = true; @@ -1652,7 +1652,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b } }); - bot.on_message_reaction_add([&](const dpp::message_reaction_add_t & event) { + bot.on_message_reaction_add([&bot,&thread_helper](const dpp::message_reaction_add_t & event) { if (event.reacting_user.id == bot.me.id) { if (event.reacting_emoji.name == "😄") { set_test(REACTEVENT, true); @@ -1664,7 +1664,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b } }); - bot.on_message_reaction_remove([&](const dpp::message_reaction_remove_t & event) { + bot.on_message_reaction_remove([&bot,&thread_helper](const dpp::message_reaction_remove_t & event) { if (event.reacting_user_id == bot.me.id) { if (event.channel_id == thread_helper.thread_id && event.reacting_emoji.name == dpp::unicode_emoji::thread) { set_test(THREAD_MESSAGE_REACT_REMOVE_EVENT, true); @@ -1673,7 +1673,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b } }); - bot.on_message_delete([&](const dpp::message_delete_t & event) { + bot.on_message_delete([&bot,&thread_helper](const dpp::message_delete_t & event) { if (event.channel_id == thread_helper.thread_id) { set_test(THREAD_MESSAGE_DELETE_EVENT, true); thread_helper.notify_event_tested(thread_test_helper::MESSAGE_DELETE); @@ -1681,7 +1681,7 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b }); bool message_edit_tested = false; - bot.on_message_update([&](const dpp::message_update_t &event) { + bot.on_message_update([&bot,&thread_helper,&message_edit_tested](const dpp::message_update_t &event) { if (event.msg.author == bot.me.id) { if (event.msg.content == "test edit" && !message_edit_tested) { message_edit_tested = true; @@ -1974,22 +1974,22 @@ Markdown lol \\|\\|spoiler\\|\\| \\~\\~strikethrough\\~\\~ \\`small \\*code\\* b // testing all user flags from https://discord.com/developers/docs/resources/user#user-object-user-flags // they're manually set here because the dpp::user_flags don't match to the discord API, so we can't use them to compare with the raw flags! if ( - u.is_discord_employee() == ((raw_flags & (1 << 0)) != 0) && - u.is_partnered_owner() == ((raw_flags & (1 << 1)) != 0) && - u.has_hypesquad_events() == ((raw_flags & (1 << 2)) != 0) && - u.is_bughunter_1() == ((raw_flags & (1 << 3)) != 0) && - u.is_house_bravery() == ((raw_flags & (1 << 6)) != 0) && - u.is_house_brilliance() == ((raw_flags & (1 << 7)) != 0) && - u.is_house_balance() == ((raw_flags & (1 << 8)) != 0) && - u.is_early_supporter() == ((raw_flags & (1 << 9)) != 0) && - u.is_team_user() == ((raw_flags & (1 << 10)) != 0) && - u.is_bughunter_2() == ((raw_flags & (1 << 14)) != 0) && - u.is_verified_bot() == ((raw_flags & (1 << 16)) != 0) && - u.is_verified_bot_dev() == ((raw_flags & (1 << 17)) != 0) && - u.is_certified_moderator() == ((raw_flags & (1 << 18)) != 0) && - u.is_bot_http_interactions() == ((raw_flags & (1 << 19)) != 0) && - u.is_active_developer() == ((raw_flags & (1 << 22)) != 0) - ) { + u.is_discord_employee() == ((raw_flags & (1 << 0)) != 0) && + u.is_partnered_owner() == ((raw_flags & (1 << 1)) != 0) && + u.has_hypesquad_events() == ((raw_flags & (1 << 2)) != 0) && + u.is_bughunter_1() == ((raw_flags & (1 << 3)) != 0) && + u.is_house_bravery() == ((raw_flags & (1 << 6)) != 0) && + u.is_house_brilliance() == ((raw_flags & (1 << 7)) != 0) && + u.is_house_balance() == ((raw_flags & (1 << 8)) != 0) && + u.is_early_supporter() == ((raw_flags & (1 << 9)) != 0) && + u.is_team_user() == ((raw_flags & (1 << 10)) != 0) && + u.is_bughunter_2() == ((raw_flags & (1 << 14)) != 0) && + u.is_verified_bot() == ((raw_flags & (1 << 16)) != 0) && + u.is_verified_bot_dev() == ((raw_flags & (1 << 17)) != 0) && + u.is_certified_moderator() == ((raw_flags & (1 << 18)) != 0) && + u.is_bot_http_interactions() == ((raw_flags & (1 << 19)) != 0) && + u.is_active_developer() == ((raw_flags & (1 << 22)) != 0) + ) { set_test(USER_GET_FLAGS, true); } else { set_test(USER_GET_FLAGS, false);