From 4965b5c5917d1db022c0bdb927ee945a2c37e1a6 Mon Sep 17 00:00:00 2001 From: "Craig Edwards (Brain)" Date: Wed, 16 Oct 2024 08:30:47 +0100 Subject: [PATCH] refactor: low level socket tidyups and removal of punning (#1282) --- .cspell.json | 3 +- include/dpp/discordvoiceclient.h | 7 ++ include/dpp/dns.h | 37 ++++-- include/dpp/socket.h | 135 ++++++++++++++++++++- src/dpp/dns.cpp | 155 +++++++++++++++---------- src/dpp/socket.cpp | 63 ++++++++++ src/dpp/sslclient.cpp | 6 +- src/dpp/voice/enabled/discover_ip.cpp | 28 +---- src/dpp/voice/enabled/handle_frame.cpp | 19 +-- src/dpp/voice/enabled/read_write.cpp | 17 ++- 10 files changed, 342 insertions(+), 128 deletions(-) create mode 100644 src/dpp/socket.cpp diff --git a/.cspell.json b/.cspell.json index 3cfb280732..0d8e2cd969 100644 --- a/.cspell.json +++ b/.cspell.json @@ -148,7 +148,8 @@ "nullopt", "chrono", "ciphersuite", - "rmap" + "rmap", + "WSAPOLLFD" ], "flagWords": [ "hte" diff --git a/include/dpp/discordvoiceclient.h b/include/dpp/discordvoiceclient.h index 6ce21bf309..658ba6422a 100644 --- a/include/dpp/discordvoiceclient.h +++ b/include/dpp/discordvoiceclient.h @@ -584,6 +584,13 @@ class DPP_EXPORT discord_voice_client : public websocket_client */ dave_version_t dave_version; + /** + * @brief Destination address for where packets go + * on the UDP socket + */ + address_t destination{}; + + /** * @brief Send data to UDP socket immediately. * diff --git a/include/dpp/dns.h b/include/dpp/dns.h index 5a3a566d45..48cebcd563 100644 --- a/include/dpp/dns.h +++ b/include/dpp/dns.h @@ -31,6 +31,8 @@ #include #include #include +#include +#include namespace dpp { @@ -40,23 +42,42 @@ namespace dpp { */ struct dns_cache_entry { /** - * @brief Resolved address information + * @brief Resolved address metadata */ addrinfo addr; /** - * @brief Socket address. - * Discord only supports ipv4, but sockaddr_in6 is larger - * than sockaddr_in, sockaddr_storage will hold either. This - * means that if discord ever do support ipv6 we just flip - * one value in dns.cpp and that should be all that is needed. + * @brief Resolved address as string. + * The metadata is needed to know what type of address it is. + * Do not do silly stuff like just looking to see if '.' is in it! */ - sockaddr_storage ai_addr; + std::string resolved_addr; /** * @brief Time at which this cache entry is invalidated */ time_t expire_timestamp; + + /** + * @brief Get address length + * @return address length + */ + [[nodiscard]] int size() const; + + /** + * @brief Get the address_t that corresponds to this cache entry + * for use when connecting with ::connect() + * @param port Port number to connect to + * @return address_t prefilled with the IP and port number + */ + [[nodiscard]] const address_t get_connecting_address(uint16_t port) const; + + /** + * @brief Allocate a socket file descriptor for the given dns address + * @return File descriptor ready for calling connect(), or INVALID_SOCKET + * on failure. + */ + [[nodiscard]] socket make_connecting_socket() const; }; /** @@ -73,4 +94,4 @@ namespace dpp { * @throw dpp::connection_exception On failure to resolve hostname */ const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port); -} // namespace dpp +} diff --git a/include/dpp/socket.h b/include/dpp/socket.h index d94914b35a..3888b2e022 100644 --- a/include/dpp/socket.h +++ b/include/dpp/socket.h @@ -1,17 +1,53 @@ +/************************************************************************************ + * + * D++, A Lightweight C++ library for Discord + * + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2021 Craig Edwards and D++ contributors + * (https://github.com/brainboxdotcc/DPP/graphs/contributors) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ************************************************************************************/ #pragma once +#include +#ifdef _WIN32 + #include + #include + #include + #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) + #define pollfd WSAPOLLFD +#else + #include + #include + #include +#endif +#include +#include + + namespace dpp { - /** - * @brief Represents a socket file descriptor. - * This is used to ensure parity between windows and unix-like systems. - */ +/** + * @brief Represents a socket file descriptor. + * This is used to ensure parity between windows and unix-like systems. + */ #ifndef _WIN32 using socket = int; #else using socket = SOCKET; #endif -} // namespace dpp #ifndef SOCKET_ERROR /** @@ -26,3 +62,92 @@ namespace dpp */ #define INVALID_SOCKET ~0 #endif + +/** + * @brief Represents an IPv4 address for use with socket functions such as + * bind(). + * + * Avoids type punning with C style casts from sockaddr_in to sockaddr pointers. + */ +class DPP_EXPORT address_t { + /** + * @brief Internal sockaddr struct + */ + sockaddr socket_addr{}; + +public: + + /** + * @brief Create a new address_t + * @param ip IPv4 address + * @param port Port number + * @note Leave both as defaults to create a default bind-to-any setting + */ + address_t(const std::string_view ip = "0.0.0.0", uint16_t port = 0); + + /** + * @brief Get sockaddr + * @return sockaddr pointer + */ + [[nodiscard]] sockaddr *get_socket_address(); + + /** + * @brief Returns size of sockaddr_in + * @return sockaddr_in size + * @note It is important the size this returns is sizeof(sockaddr_in) not + * sizeof(sockaddr), this is NOT a bug but requirement of C socket functions. + */ + [[nodiscard]] size_t size(); + + /** + * @brief Get the port bound to a file descriptor + * @param fd File descriptor + * @return Port number, or 0 if no port bound + */ + [[nodiscard]] uint16_t get_port(socket fd); +}; + +/** + * @brief Allocates a dpp::socket, closing it on destruction + */ +struct DPP_EXPORT raii_socket { + /** + * @brief File descriptor + */ + socket fd; + + /** + * @brief Construct a socket. + * Calls socket() and returns a new file descriptor + */ + raii_socket(); + + /** + * @brief Non-copyable + */ + raii_socket(raii_socket&) = delete; + + /** + * @brief Non-movable + */ + raii_socket(raii_socket&&) = delete; + + /** + * @brief Non-copyable + */ + raii_socket operator=(raii_socket&) = delete; + + /** + * @brief Non-movable + */ + raii_socket operator=(raii_socket&&) = delete; + + /** + * @brief Destructor + * Frees the socket by closing it + */ + ~raii_socket(); +}; + + +} diff --git a/src/dpp/dns.cpp b/src/dpp/dns.cpp index e87fbd45e8..257b9b2315 100644 --- a/src/dpp/dns.cpp +++ b/src/dpp/dns.cpp @@ -39,75 +39,102 @@ namespace dpp /* Cache container */ dns_cache_t dns_cache; - const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port) +/** +* @brief Get address length +* @return address length +*/ +int dns_cache_entry::size() const { + return static_cast(addr.ai_addrlen); +} + +const address_t dns_cache_entry::get_connecting_address(uint16_t port) const { + return address_t(resolved_addr, port); +} + +socket dns_cache_entry::make_connecting_socket() const { + return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); +} + +const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port) +{ + addrinfo hints, *addrs; + dns_cache_t::const_iterator iter; + time_t now = time(nullptr); + int error; + bool exists = false; + + /* Thread safety scope */ { - addrinfo hints, *addrs; - dns_cache_t::const_iterator iter; - time_t now = time(nullptr); - int error; - bool exists = false; - - /* Thread safety scope */ - { - /* Check cache for existing DNS record. This can use a shared lock. */ - std::shared_lock dns_cache_lock(dns_cache_mutex); - iter = dns_cache.find(hostname); - if (iter != dns_cache.end()) { - exists = true; - if (now < iter->second->expire_timestamp) { - /* there is a cached entry that is still valid, return it */ - return iter->second; - } + /* Check cache for existing DNS record. This can use a shared lock. */ + std::shared_lock dns_cache_lock(dns_cache_mutex); + iter = dns_cache.find(hostname); + if (iter != dns_cache.end()) { + exists = true; + if (now < iter->second->expire_timestamp) { + /* there is a cached entry that is still valid, return it */ + return iter->second; } } - if (exists) { - /* there is a cached entry, but it has expired, - * delete and free it, and fall through to a new lookup. - * We must use a unique lock here as we modify the cache. - */ - std::unique_lock dns_cache_lock(dns_cache_mutex); - iter = dns_cache.find(hostname); - if (iter != dns_cache.end()) { /* re-validate iter */ - delete iter->second; - dns_cache.erase(iter); - } - } - - /* The hints indicate what sort of DNS results we are interested in. - * To change this to support IPv6, one change we need to make here is - * to change AF_INET to AF_UNSPEC. Everything else should just work fine. + } + if (exists) { + /* there is a cached entry, but it has expired, + * delete and free it, and fall through to a new lookup. + * We must use a unique lock here as we modify the cache. */ - memset(&hints, 0, sizeof(addrinfo)); - hints.ai_family = AF_INET; // IPv6 explicitly unsupported by Discord - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; - - if ((error = getaddrinfo(hostname.c_str(), port.c_str(), &hints, &addrs))) { - /** - * The -20 makes sure the error codes dont conflict with codes given in the rest of the list - * Because C libraries love to use -1 and below directly as conflicting error codes. - */ - throw dpp::connection_exception((exception_error_code)(error - 20), std::string("getaddrinfo error: ") + gai_strerror(error)); + std::unique_lock dns_cache_lock(dns_cache_mutex); + iter = dns_cache.find(hostname); + if (iter != dns_cache.end()) { /* re-validate iter */ + delete iter->second; + dns_cache.erase(iter); } + } + + /* The hints indicate what sort of DNS results we are interested in. + * To change this to support IPv6, one change we need to make here is + * to change AF_INET to AF_UNSPEC. Everything else should just work fine. + */ + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_family = AF_INET; // IPv6 explicitly unsupported by Discord + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; - /* Thread safety scope */ - { - /* Update cache, requires unique lock */ - std::unique_lock dns_cache_lock(dns_cache_mutex); - dns_cache_entry* cache_entry = new dns_cache_entry(); - - /* The sockaddr struct contains a bunch of raw pointers that we - * must copy to the cache, before freeing it with freeaddrinfo(). - * Icky icky C APIs. - */ - memcpy(&cache_entry->addr, addrs, sizeof(addrinfo)); - memcpy(&cache_entry->ai_addr, addrs->ai_addr, addrs->ai_addrlen); - cache_entry->expire_timestamp = now + one_hour; - dns_cache[hostname] = cache_entry; - - /* Now we're done with this horrible struct, free it and return */ - freeaddrinfo(addrs); - return cache_entry; + if ((error = getaddrinfo(hostname.c_str(), port.c_str(), &hints, &addrs))) { + /** + * The -20 makes sure the error codes dont conflict with codes given in the rest of the list + * Because C libraries love to use -1 and below directly as conflicting error codes. + */ + throw dpp::connection_exception((exception_error_code)(error - 20), std::string("getaddrinfo error: ") + gai_strerror(error)); + } + + /* Thread safety scope */ + { + /* Update cache, requires unique lock */ + std::unique_lock dns_cache_lock(dns_cache_mutex); + dns_cache_entry* cache_entry = new dns_cache_entry(); + + for (struct addrinfo* rp = addrs; rp != nullptr; rp = rp->ai_next) { + /* Discord only support ipv4, so iterate over any ipv6 results */ + if (rp->ai_family != AF_INET) { + continue; + } + /* Save address family and other metadata for later */ + memcpy(&cache_entry->addr, rp, sizeof(addrinfo)); + char buffer[128]; + sockaddr_in in{}; + std::memcpy(&in, rp->ai_addr, sizeof(sockaddr_in)); + if (inet_ntop(rp->ai_family, &in.sin_addr, buffer, sizeof(buffer))) { + cache_entry->resolved_addr = buffer; + } + break; } + + cache_entry->expire_timestamp = now + one_hour; + dns_cache[hostname] = cache_entry; + + /* Now we're done with this horrible struct, free it and return */ + freeaddrinfo(addrs); + return cache_entry; } -} // namespace dpp +} + +} diff --git a/src/dpp/socket.cpp b/src/dpp/socket.cpp new file mode 100644 index 0000000000..9313405b75 --- /dev/null +++ b/src/dpp/socket.cpp @@ -0,0 +1,63 @@ +/************************************************************************************ + * + * D++, A Lightweight C++ library for Discord + * + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2021 Craig Edwards and D++ contributors + * (https://github.com/brainboxdotcc/DPP/graphs/contributors) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ************************************************************************************/ + +#include +#include +#include + +namespace dpp { + +address_t::address_t(const std::string_view ip, uint16_t port) { + sockaddr_in address{}; + address.sin_family = AF_INET; + address.sin_port = htons(port); + address.sin_addr.s_addr = inet_addr(ip.data()); + std::memcpy(&socket_addr, &address, sizeof(address)); +} + +sockaddr* address_t::get_socket_address() { + return &socket_addr; +} + +size_t address_t::size() { + return sizeof(sockaddr_in); +} + +uint16_t address_t::get_port(socket fd) { + socklen_t len = size(); + if (getsockname(fd, &socket_addr, &len) > -1) { + sockaddr_in sin{}; + std::memcpy(&sin, &socket_addr, sizeof(sockaddr_in)); + return ntohs(sin.sin_port); + } + return 0; +} + +raii_socket::raii_socket() : fd(::socket(AF_INET, SOCK_DGRAM, 0)) { +} + +raii_socket::~raii_socket() { + close_socket(fd); +} + + +} \ No newline at end of file diff --git a/src/dpp/sslclient.cpp b/src/dpp/sslclient.cpp index d1f15ac20a..54698eb0e4 100644 --- a/src/dpp/sslclient.cpp +++ b/src/dpp/sslclient.cpp @@ -70,6 +70,7 @@ #include #include #include +#include #include /* Maximum allowed time in milliseconds for socket read/write timeouts and connect() */ @@ -318,10 +319,11 @@ void ssl_client::connect() /* Resolve hostname to IP */ int err = 0; const dns_cache_entry* addr = resolve_hostname(hostname, port); - sfd = ::socket(addr->addr.ai_family, addr->addr.ai_socktype, addr->addr.ai_protocol); + sfd = addr->make_connecting_socket(); + address_t destination = addr->get_connecting_address(from_string(this->port, std::dec)); if (sfd == ERROR_STATUS) { err = errno; - } else if (connect_with_timeout(sfd, (sockaddr*)&addr->ai_addr, (int)addr->addr.ai_addrlen, SOCKET_OP_TIMEOUT) != 0) { + } else if (connect_with_timeout(sfd, destination.get_socket_address(), destination.size(), SOCKET_OP_TIMEOUT) != 0) { close_socket(sfd); sfd = ERROR_STATUS; } diff --git a/src/dpp/voice/enabled/discover_ip.cpp b/src/dpp/voice/enabled/discover_ip.cpp index ea4f310308..341be4a80e 100644 --- a/src/dpp/voice/enabled/discover_ip.cpp +++ b/src/dpp/voice/enabled/discover_ip.cpp @@ -132,20 +132,6 @@ struct ip_discovery_packet { } }; -/** - * @brief Allocates a dpp::socket, closing it on destruction - */ -struct raii_socket { - dpp::socket fd; - - raii_socket() : fd(::socket(AF_INET, SOCK_DGRAM, 0)) { }; - raii_socket(raii_socket&) = delete; - raii_socket(raii_socket&&) = delete; - raii_socket operator=(raii_socket&) = delete; - raii_socket operator=(raii_socket&&) = delete; - ~raii_socket() { close_socket(fd); }; -}; - constexpr int discovery_timeout = 1000; std::string discord_voice_client::discover_ip() { @@ -158,19 +144,13 @@ std::string discord_voice_client::discover_ip() { ip_discovery_packet discovery(this->ssrc); if (socket.fd >= 0) { - sockaddr_in servaddr{}; - servaddr.sin_family = AF_INET; - servaddr.sin_addr.s_addr = htonl(INADDR_ANY); - servaddr.sin_port = htons(0); - if (bind(socket.fd, reinterpret_cast(&servaddr), sizeof(servaddr)) < 0) { + address_t bind_any; + if (bind(socket.fd, bind_any.get_socket_address(), bind_any.size()) < 0) { log(ll_warning, "Could not bind socket for IP discovery"); return ""; } - memset(&servaddr, 0, sizeof(servaddr)); - servaddr.sin_family = AF_INET; - servaddr.sin_port = htons(this->port); - servaddr.sin_addr.s_addr = inet_addr(this->ip.c_str()); - if (::connect(socket.fd, reinterpret_cast(&servaddr), sizeof(sockaddr_in)) < 0) { + address_t bind_port(this->ip, this->port); + if (::connect(socket.fd, bind_port.get_socket_address(), bind_port.size()) < 0) { log(ll_warning, "Could not connect socket for IP discovery"); return ""; } diff --git a/src/dpp/voice/enabled/handle_frame.cpp b/src/dpp/voice/enabled/handle_frame.cpp index 7b85cb7820..07423d3377 100644 --- a/src/dpp/voice/enabled/handle_frame.cpp +++ b/src/dpp/voice/enabled/handle_frame.cpp @@ -437,6 +437,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod this->ip = d["ip"].get(); this->port = d["port"].get(); this->ssrc = d["ssrc"].get(); + destination = address_t(this->ip, this->port); + // Modes for (auto & m : d["modes"]) { this->modes.push_back(m.get()); @@ -446,13 +448,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod dpp::socket newfd = 0; if ((newfd = ::socket(AF_INET, SOCK_DGRAM, 0)) >= 0) { - sockaddr_in servaddr{}; - memset(&servaddr, 0, sizeof(sockaddr_in)); - servaddr.sin_family = AF_INET; - servaddr.sin_addr.s_addr = htonl(INADDR_ANY); - servaddr.sin_port = htons(0); - - if (bind(newfd, reinterpret_cast(&servaddr), sizeof(servaddr)) < 0) { + address_t bind_any; + if (bind(newfd, bind_any.get_socket_address(), bind_any.size()) < 0) { throw dpp::connection_exception(err_bind_failure, "Can't bind() client UDP socket"); } @@ -467,13 +464,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod this->custom_writeable_ready = [this] { write_ready(); }; this->custom_readable_ready = [this] { read_ready(); }; - int bound_port = 0; - sockaddr_in sin{}; - socklen_t len = sizeof(sin); - if (getsockname(this->fd, reinterpret_cast(&sin), &len) > -1) { - bound_port = ntohs(sin.sin_port); - } - + int bound_port = address_t().get_port(this->fd); this->write(json({ { "op", voice_opcode_connection_select_protocol }, { "d", { diff --git a/src/dpp/voice/enabled/read_write.cpp b/src/dpp/voice/enabled/read_write.cpp index 1702e6cf13..c24200b49b 100644 --- a/src/dpp/voice/enabled/read_write.cpp +++ b/src/dpp/voice/enabled/read_write.cpp @@ -43,26 +43,23 @@ dpp::socket discord_voice_client::want_read() { void discord_voice_client::send(const char* packet, size_t len, uint64_t duration) { - std::lock_guard lock(this->stream_mutex); voice_out_packet frame; - frame.packet = std::string(packet, len); + frame.packet.assign(packet, packet + len); frame.duration = duration; - outbuf.emplace_back(frame); + { + std::lock_guard lock(this->stream_mutex); + outbuf.emplace_back(frame); + } } int discord_voice_client::udp_send(const char* data, size_t length) { - sockaddr_in servaddr{}; - memset(&servaddr, 0, sizeof(servaddr)); - servaddr.sin_family = AF_INET; - servaddr.sin_port = htons(this->port); - servaddr.sin_addr.s_addr = inet_addr(this->ip.c_str()); return static_cast(sendto( this->fd, data, static_cast(length), 0, - reinterpret_cast(&servaddr), - static_cast(sizeof(sockaddr_in)) + destination.get_socket_address(), + destination.size() )); }