diff --git a/include/dpp/dns.h b/include/dpp/dns.h index 65e5fb8e7b..48cebcd563 100644 --- a/include/dpp/dns.h +++ b/include/dpp/dns.h @@ -42,18 +42,16 @@ namespace dpp { */ struct dns_cache_entry { /** - * @brief Resolved address information + * @brief Resolved address metadata */ addrinfo addr; /** - * @brief Socket address. - * Discord only supports ipv4, but sockaddr_in6 is larger - * than sockaddr_in, sockaddr_storage will hold either. This - * means that if discord ever do support ipv6 we just flip - * one value in dns.cpp and that should be all that is needed. + * @brief Resolved address as string. + * The metadata is needed to know what type of address it is. + * Do not do silly stuff like just looking to see if '.' is in it! */ - sockaddr_storage ai_addr; + std::string resolved_addr; /** * @brief Time at which this cache entry is invalidated @@ -64,18 +62,22 @@ namespace dpp { * @brief Get address length * @return address length */ - inline int size() const { - return static_cast(addr.ai_addrlen); - } + [[nodiscard]] int size() const; + + /** + * @brief Get the address_t that corresponds to this cache entry + * for use when connecting with ::connect() + * @param port Port number to connect to + * @return address_t prefilled with the IP and port number + */ + [[nodiscard]] const address_t get_connecting_address(uint16_t port) const; /** * @brief Allocate a socket file descriptor for the given dns address * @return File descriptor ready for calling connect(), or INVALID_SOCKET * on failure. */ - inline socket make_connecting_socket() const { - return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); - } + [[nodiscard]] socket make_connecting_socket() const; }; /** @@ -92,4 +94,4 @@ namespace dpp { * @throw dpp::connection_exception On failure to resolve hostname */ const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port); -} // namespace dpp +} diff --git a/src/dpp/dns.cpp b/src/dpp/dns.cpp index e87fbd45e8..257b9b2315 100644 --- a/src/dpp/dns.cpp +++ b/src/dpp/dns.cpp @@ -39,75 +39,102 @@ namespace dpp /* Cache container */ dns_cache_t dns_cache; - const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port) +/** +* @brief Get address length +* @return address length +*/ +int dns_cache_entry::size() const { + return static_cast(addr.ai_addrlen); +} + +const address_t dns_cache_entry::get_connecting_address(uint16_t port) const { + return address_t(resolved_addr, port); +} + +socket dns_cache_entry::make_connecting_socket() const { + return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); +} + +const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port) +{ + addrinfo hints, *addrs; + dns_cache_t::const_iterator iter; + time_t now = time(nullptr); + int error; + bool exists = false; + + /* Thread safety scope */ { - addrinfo hints, *addrs; - dns_cache_t::const_iterator iter; - time_t now = time(nullptr); - int error; - bool exists = false; - - /* Thread safety scope */ - { - /* Check cache for existing DNS record. This can use a shared lock. */ - std::shared_lock dns_cache_lock(dns_cache_mutex); - iter = dns_cache.find(hostname); - if (iter != dns_cache.end()) { - exists = true; - if (now < iter->second->expire_timestamp) { - /* there is a cached entry that is still valid, return it */ - return iter->second; - } + /* Check cache for existing DNS record. This can use a shared lock. */ + std::shared_lock dns_cache_lock(dns_cache_mutex); + iter = dns_cache.find(hostname); + if (iter != dns_cache.end()) { + exists = true; + if (now < iter->second->expire_timestamp) { + /* there is a cached entry that is still valid, return it */ + return iter->second; } } - if (exists) { - /* there is a cached entry, but it has expired, - * delete and free it, and fall through to a new lookup. - * We must use a unique lock here as we modify the cache. - */ - std::unique_lock dns_cache_lock(dns_cache_mutex); - iter = dns_cache.find(hostname); - if (iter != dns_cache.end()) { /* re-validate iter */ - delete iter->second; - dns_cache.erase(iter); - } - } - - /* The hints indicate what sort of DNS results we are interested in. - * To change this to support IPv6, one change we need to make here is - * to change AF_INET to AF_UNSPEC. Everything else should just work fine. + } + if (exists) { + /* there is a cached entry, but it has expired, + * delete and free it, and fall through to a new lookup. + * We must use a unique lock here as we modify the cache. */ - memset(&hints, 0, sizeof(addrinfo)); - hints.ai_family = AF_INET; // IPv6 explicitly unsupported by Discord - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; - - if ((error = getaddrinfo(hostname.c_str(), port.c_str(), &hints, &addrs))) { - /** - * The -20 makes sure the error codes dont conflict with codes given in the rest of the list - * Because C libraries love to use -1 and below directly as conflicting error codes. - */ - throw dpp::connection_exception((exception_error_code)(error - 20), std::string("getaddrinfo error: ") + gai_strerror(error)); + std::unique_lock dns_cache_lock(dns_cache_mutex); + iter = dns_cache.find(hostname); + if (iter != dns_cache.end()) { /* re-validate iter */ + delete iter->second; + dns_cache.erase(iter); } + } + + /* The hints indicate what sort of DNS results we are interested in. + * To change this to support IPv6, one change we need to make here is + * to change AF_INET to AF_UNSPEC. Everything else should just work fine. + */ + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_family = AF_INET; // IPv6 explicitly unsupported by Discord + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; - /* Thread safety scope */ - { - /* Update cache, requires unique lock */ - std::unique_lock dns_cache_lock(dns_cache_mutex); - dns_cache_entry* cache_entry = new dns_cache_entry(); - - /* The sockaddr struct contains a bunch of raw pointers that we - * must copy to the cache, before freeing it with freeaddrinfo(). - * Icky icky C APIs. - */ - memcpy(&cache_entry->addr, addrs, sizeof(addrinfo)); - memcpy(&cache_entry->ai_addr, addrs->ai_addr, addrs->ai_addrlen); - cache_entry->expire_timestamp = now + one_hour; - dns_cache[hostname] = cache_entry; - - /* Now we're done with this horrible struct, free it and return */ - freeaddrinfo(addrs); - return cache_entry; + if ((error = getaddrinfo(hostname.c_str(), port.c_str(), &hints, &addrs))) { + /** + * The -20 makes sure the error codes dont conflict with codes given in the rest of the list + * Because C libraries love to use -1 and below directly as conflicting error codes. + */ + throw dpp::connection_exception((exception_error_code)(error - 20), std::string("getaddrinfo error: ") + gai_strerror(error)); + } + + /* Thread safety scope */ + { + /* Update cache, requires unique lock */ + std::unique_lock dns_cache_lock(dns_cache_mutex); + dns_cache_entry* cache_entry = new dns_cache_entry(); + + for (struct addrinfo* rp = addrs; rp != nullptr; rp = rp->ai_next) { + /* Discord only support ipv4, so iterate over any ipv6 results */ + if (rp->ai_family != AF_INET) { + continue; + } + /* Save address family and other metadata for later */ + memcpy(&cache_entry->addr, rp, sizeof(addrinfo)); + char buffer[128]; + sockaddr_in in{}; + std::memcpy(&in, rp->ai_addr, sizeof(sockaddr_in)); + if (inet_ntop(rp->ai_family, &in.sin_addr, buffer, sizeof(buffer))) { + cache_entry->resolved_addr = buffer; + } + break; } + + cache_entry->expire_timestamp = now + one_hour; + dns_cache[hostname] = cache_entry; + + /* Now we're done with this horrible struct, free it and return */ + freeaddrinfo(addrs); + return cache_entry; } -} // namespace dpp +} + +} diff --git a/src/dpp/sslclient.cpp b/src/dpp/sslclient.cpp index 498da077b0..54698eb0e4 100644 --- a/src/dpp/sslclient.cpp +++ b/src/dpp/sslclient.cpp @@ -70,6 +70,7 @@ #include #include #include +#include #include /* Maximum allowed time in milliseconds for socket read/write timeouts and connect() */ @@ -319,9 +320,10 @@ void ssl_client::connect() int err = 0; const dns_cache_entry* addr = resolve_hostname(hostname, port); sfd = addr->make_connecting_socket(); + address_t destination = addr->get_connecting_address(from_string(this->port, std::dec)); if (sfd == ERROR_STATUS) { err = errno; - } else if (connect_with_timeout(sfd, (sockaddr*)&addr->ai_addr, addr->size(), SOCKET_OP_TIMEOUT) != 0) { + } else if (connect_with_timeout(sfd, destination.get_socket_address(), destination.size(), SOCKET_OP_TIMEOUT) != 0) { close_socket(sfd); sfd = ERROR_STATUS; }