Skip to content

Commit

Permalink
clean up resolver cache
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Oct 16, 2024
1 parent 4f85429 commit 1102998
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 79 deletions.
30 changes: 16 additions & 14 deletions include/dpp/dns.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,16 @@ namespace dpp {
*/
struct dns_cache_entry {
/**
* @brief Resolved address information
* @brief Resolved address metadata
*/
addrinfo addr;

/**
* @brief Socket address.
* Discord only supports ipv4, but sockaddr_in6 is larger
* than sockaddr_in, sockaddr_storage will hold either. This
* means that if discord ever do support ipv6 we just flip
* one value in dns.cpp and that should be all that is needed.
* @brief Resolved address as string.
* The metadata is needed to know what type of address it is.
* Do not do silly stuff like just looking to see if '.' is in it!
*/
sockaddr_storage ai_addr;
std::string resolved_addr;

/**
* @brief Time at which this cache entry is invalidated
Expand All @@ -64,18 +62,22 @@ namespace dpp {
* @brief Get address length
* @return address length
*/
inline int size() const {
return static_cast<int>(addr.ai_addrlen);
}
[[nodiscard]] int size() const;

/**
* @brief Get the address_t that corresponds to this cache entry
* for use when connecting with ::connect()
* @param port Port number to connect to
* @return address_t prefilled with the IP and port number
*/
[[nodiscard]] const address_t get_connecting_address(uint16_t port) const;

/**
* @brief Allocate a socket file descriptor for the given dns address
* @return File descriptor ready for calling connect(), or INVALID_SOCKET
* on failure.
*/
inline socket make_connecting_socket() const {
return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
}
[[nodiscard]] socket make_connecting_socket() const;
};

/**
Expand All @@ -92,4 +94,4 @@ namespace dpp {
* @throw dpp::connection_exception On failure to resolve hostname
*/
const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port);
} // namespace dpp
}
155 changes: 91 additions & 64 deletions src/dpp/dns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,75 +39,102 @@ namespace dpp
/* Cache container */
dns_cache_t dns_cache;

const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port)
/**
* @brief Get address length
* @return address length
*/
int dns_cache_entry::size() const {
return static_cast<int>(addr.ai_addrlen);
}

const address_t dns_cache_entry::get_connecting_address(uint16_t port) const {
return address_t(resolved_addr, port);
}

socket dns_cache_entry::make_connecting_socket() const {
return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
}

const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port)
{
addrinfo hints, *addrs;
dns_cache_t::const_iterator iter;
time_t now = time(nullptr);
int error;
bool exists = false;

/* Thread safety scope */
{
addrinfo hints, *addrs;
dns_cache_t::const_iterator iter;
time_t now = time(nullptr);
int error;
bool exists = false;

/* Thread safety scope */
{
/* Check cache for existing DNS record. This can use a shared lock. */
std::shared_lock dns_cache_lock(dns_cache_mutex);
iter = dns_cache.find(hostname);
if (iter != dns_cache.end()) {
exists = true;
if (now < iter->second->expire_timestamp) {
/* there is a cached entry that is still valid, return it */
return iter->second;
}
/* Check cache for existing DNS record. This can use a shared lock. */
std::shared_lock dns_cache_lock(dns_cache_mutex);
iter = dns_cache.find(hostname);
if (iter != dns_cache.end()) {
exists = true;
if (now < iter->second->expire_timestamp) {
/* there is a cached entry that is still valid, return it */
return iter->second;
}
}
if (exists) {
/* there is a cached entry, but it has expired,
* delete and free it, and fall through to a new lookup.
* We must use a unique lock here as we modify the cache.
*/
std::unique_lock dns_cache_lock(dns_cache_mutex);
iter = dns_cache.find(hostname);
if (iter != dns_cache.end()) { /* re-validate iter */
delete iter->second;
dns_cache.erase(iter);
}
}

/* The hints indicate what sort of DNS results we are interested in.
* To change this to support IPv6, one change we need to make here is
* to change AF_INET to AF_UNSPEC. Everything else should just work fine.
}
if (exists) {
/* there is a cached entry, but it has expired,
* delete and free it, and fall through to a new lookup.
* We must use a unique lock here as we modify the cache.
*/
memset(&hints, 0, sizeof(addrinfo));
hints.ai_family = AF_INET; // IPv6 explicitly unsupported by Discord
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;

if ((error = getaddrinfo(hostname.c_str(), port.c_str(), &hints, &addrs))) {
/**
* The -20 makes sure the error codes dont conflict with codes given in the rest of the list
* Because C libraries love to use -1 and below directly as conflicting error codes.
*/
throw dpp::connection_exception((exception_error_code)(error - 20), std::string("getaddrinfo error: ") + gai_strerror(error));
std::unique_lock dns_cache_lock(dns_cache_mutex);
iter = dns_cache.find(hostname);
if (iter != dns_cache.end()) { /* re-validate iter */
delete iter->second;
dns_cache.erase(iter);
}
}

/* The hints indicate what sort of DNS results we are interested in.
* To change this to support IPv6, one change we need to make here is
* to change AF_INET to AF_UNSPEC. Everything else should just work fine.
*/
memset(&hints, 0, sizeof(addrinfo));
hints.ai_family = AF_INET; // IPv6 explicitly unsupported by Discord
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;

/* Thread safety scope */
{
/* Update cache, requires unique lock */
std::unique_lock dns_cache_lock(dns_cache_mutex);
dns_cache_entry* cache_entry = new dns_cache_entry();

/* The sockaddr struct contains a bunch of raw pointers that we
* must copy to the cache, before freeing it with freeaddrinfo().
* Icky icky C APIs.
*/
memcpy(&cache_entry->addr, addrs, sizeof(addrinfo));
memcpy(&cache_entry->ai_addr, addrs->ai_addr, addrs->ai_addrlen);
cache_entry->expire_timestamp = now + one_hour;
dns_cache[hostname] = cache_entry;

/* Now we're done with this horrible struct, free it and return */
freeaddrinfo(addrs);
return cache_entry;
if ((error = getaddrinfo(hostname.c_str(), port.c_str(), &hints, &addrs))) {
/**
* The -20 makes sure the error codes dont conflict with codes given in the rest of the list
* Because C libraries love to use -1 and below directly as conflicting error codes.
*/
throw dpp::connection_exception((exception_error_code)(error - 20), std::string("getaddrinfo error: ") + gai_strerror(error));
}

/* Thread safety scope */
{
/* Update cache, requires unique lock */
std::unique_lock dns_cache_lock(dns_cache_mutex);
dns_cache_entry* cache_entry = new dns_cache_entry();

for (struct addrinfo* rp = addrs; rp != nullptr; rp = rp->ai_next) {
/* Discord only support ipv4, so iterate over any ipv6 results */
if (rp->ai_family != AF_INET) {
continue;
}
/* Save address family and other metadata for later */
memcpy(&cache_entry->addr, rp, sizeof(addrinfo));
char buffer[128];
sockaddr_in in{};
std::memcpy(&in, rp->ai_addr, sizeof(sockaddr_in));
if (inet_ntop(rp->ai_family, &in.sin_addr, buffer, sizeof(buffer))) {
cache_entry->resolved_addr = buffer;
}
break;
}

cache_entry->expire_timestamp = now + one_hour;
dns_cache[hostname] = cache_entry;

/* Now we're done with this horrible struct, free it and return */
freeaddrinfo(addrs);
return cache_entry;
}
} // namespace dpp
}

}
4 changes: 3 additions & 1 deletion src/dpp/sslclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
#include <dpp/sslclient.h>
#include <dpp/exception.h>
#include <dpp/utility.h>
#include <dpp/stringops.h>
#include <dpp/dns.h>

/* Maximum allowed time in milliseconds for socket read/write timeouts and connect() */
Expand Down Expand Up @@ -319,9 +320,10 @@ void ssl_client::connect()
int err = 0;
const dns_cache_entry* addr = resolve_hostname(hostname, port);
sfd = addr->make_connecting_socket();
address_t destination = addr->get_connecting_address(from_string<uint16_t>(this->port, std::dec));
if (sfd == ERROR_STATUS) {
err = errno;
} else if (connect_with_timeout(sfd, (sockaddr*)&addr->ai_addr, addr->size(), SOCKET_OP_TIMEOUT) != 0) {
} else if (connect_with_timeout(sfd, destination.get_socket_address(), destination.size(), SOCKET_OP_TIMEOUT) != 0) {
close_socket(sfd);
sfd = ERROR_STATUS;
}
Expand Down

0 comments on commit 1102998

Please sign in to comment.