Skip to content

Commit

Permalink
refactor: low level socket tidyups and removal of punning
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Oct 16, 2024
1 parent b720cc6 commit 4f85429
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 55 deletions.
7 changes: 7 additions & 0 deletions include/dpp/discordvoiceclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,13 @@ class DPP_EXPORT discord_voice_client : public websocket_client
*/
dave_version_t dave_version;

/**
* @brief Destination address for where packets go
* on the UDP socket
*/
address_t destination{};


/**
* @brief Send data to UDP socket immediately.
*
Expand Down
19 changes: 19 additions & 0 deletions include/dpp/dns.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <sys/types.h>
#include <string>
#include <unordered_map>
#include <cstring>
#include <dpp/socket.h>

namespace dpp {

Expand All @@ -57,6 +59,23 @@ namespace dpp {
* @brief Time at which this cache entry is invalidated
*/
time_t expire_timestamp;

/**
* @brief Get address length
* @return address length
*/
inline int size() const {
return static_cast<int>(addr.ai_addrlen);
}

/**
* @brief Allocate a socket file descriptor for the given dns address
* @return File descriptor ready for calling connect(), or INVALID_SOCKET
* on failure.
*/
inline socket make_connecting_socket() const {
return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
}
};

/**
Expand Down
134 changes: 129 additions & 5 deletions include/dpp/socket.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,52 @@
/************************************************************************************
*
* D++, A Lightweight C++ library for Discord
*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2021 Craig Edwards and D++ contributors
* (https://github.com/brainboxdotcc/DPP/graphs/contributors)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
************************************************************************************/
#pragma once

#ifdef _WIN32
#include <WinSock2.h>
#include <WS2tcpip.h>
#include <io.h>
#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout)
#define pollfd WSAPOLLFD
#else
#include <netinet/in.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#endif
#include <string_view>
#include <cstdint>


namespace dpp
{
/**
* @brief Represents a socket file descriptor.
* This is used to ensure parity between windows and unix-like systems.
*/
/**
* @brief Represents a socket file descriptor.
* This is used to ensure parity between windows and unix-like systems.
*/
#ifndef _WIN32
using socket = int;
#else
using socket = SOCKET;
#endif
} // namespace dpp

#ifndef SOCKET_ERROR
/**
Expand All @@ -26,3 +61,92 @@ namespace dpp
*/
#define INVALID_SOCKET ~0
#endif

/**
* @brief Represents an IPv4 address for use with socket functions such as
* bind().
*
* Avoids type punning with C style casts from sockaddr_in to sockaddr pointers.
*/
class address_t {
/**
* @brief Internal sockaddr struct
*/
sockaddr socket_addr{};

public:

/**
* @brief Create a new address_t
* @param ip IPv4 address
* @param port Port number
* @note Leave both as defaults to create a default bind-to-any setting
*/
address_t(const std::string_view ip = "0.0.0.0", uint16_t port = 0);

/**
* @brief Get sockaddr
* @return sockaddr pointer
*/
sockaddr *get_socket_address();

/**
* @brief Returns size of sockaddr_in
* @return sockaddr_in size
* @note It is important the size this returns is sizeof(sockaddr_in) not
* sizeof(sockaddr), this is NOT a bug but requirement of C socket functions.
*/
size_t size();

/**
* @brief Get the port bound to a file descriptor
* @param fd File descriptor
* @return Port number, or 0 if no port bound
*/
uint16_t get_port(socket fd);
};

/**
* @brief Allocates a dpp::socket, closing it on destruction
*/
struct raii_socket {
/**
* @brief File descriptor
*/
socket fd;

/**
* @brief Construct a socket.
* Calls socket() and returns a new file descriptor
*/
raii_socket();

/**
* @brief Non-copyable
*/
raii_socket(raii_socket&) = delete;

/**
* @brief Non-movable
*/
raii_socket(raii_socket&&) = delete;

/**
* @brief Non-copyable
*/
raii_socket operator=(raii_socket&) = delete;

/**
* @brief Non-movable
*/
raii_socket operator=(raii_socket&&) = delete;

/**
* @brief Destructor
* Frees the socket by closing it
*/
~raii_socket();
};


}
63 changes: 63 additions & 0 deletions src/dpp/socket.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/************************************************************************************
*
* D++, A Lightweight C++ library for Discord
*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2021 Craig Edwards and D++ contributors
* (https://github.com/brainboxdotcc/DPP/graphs/contributors)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
************************************************************************************/

#include <dpp/socket.h>
#include <dpp/sslclient.h>
#include <cstring>

namespace dpp {

address_t::address_t(const std::string_view ip, uint16_t port) {
sockaddr_in address{};
address.sin_family = AF_INET;
address.sin_port = htons(port);
address.sin_addr.s_addr = inet_addr(ip.data());
std::memcpy(&socket_addr, &address, sizeof(address));
}

sockaddr* address_t::get_socket_address() {
return &socket_addr;
}

size_t address_t::size() {
return sizeof(sockaddr_in);
}

uint16_t address_t::get_port(socket fd) {
socklen_t len = size();
if (getsockname(fd, &socket_addr, &len) > -1) {
sockaddr_in sin{};
std::memcpy(&sin, &socket_addr, sizeof(sockaddr_in));
return ntohs(sin.sin_port);
}
return 0;
}

raii_socket::raii_socket() : fd(::socket(AF_INET, SOCK_DGRAM, 0)) {
}

raii_socket::~raii_socket() {
close_socket(fd);
}


}
4 changes: 2 additions & 2 deletions src/dpp/sslclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,10 @@ void ssl_client::connect()
/* Resolve hostname to IP */
int err = 0;
const dns_cache_entry* addr = resolve_hostname(hostname, port);
sfd = ::socket(addr->addr.ai_family, addr->addr.ai_socktype, addr->addr.ai_protocol);
sfd = addr->make_connecting_socket();
if (sfd == ERROR_STATUS) {
err = errno;
} else if (connect_with_timeout(sfd, (sockaddr*)&addr->ai_addr, (int)addr->addr.ai_addrlen, SOCKET_OP_TIMEOUT) != 0) {
} else if (connect_with_timeout(sfd, (sockaddr*)&addr->ai_addr, addr->size(), SOCKET_OP_TIMEOUT) != 0) {
close_socket(sfd);
sfd = ERROR_STATUS;
}
Expand Down
28 changes: 4 additions & 24 deletions src/dpp/voice/enabled/discover_ip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,6 @@ struct ip_discovery_packet {
}
};

/**
* @brief Allocates a dpp::socket, closing it on destruction
*/
struct raii_socket {
dpp::socket fd;

raii_socket() : fd(::socket(AF_INET, SOCK_DGRAM, 0)) { };
raii_socket(raii_socket&) = delete;
raii_socket(raii_socket&&) = delete;
raii_socket operator=(raii_socket&) = delete;
raii_socket operator=(raii_socket&&) = delete;
~raii_socket() { close_socket(fd); };
};

constexpr int discovery_timeout = 1000;

std::string discord_voice_client::discover_ip() {
Expand All @@ -158,19 +144,13 @@ std::string discord_voice_client::discover_ip() {
ip_discovery_packet discovery(this->ssrc);

if (socket.fd >= 0) {
sockaddr_in servaddr{};
servaddr.sin_family = AF_INET;
servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
servaddr.sin_port = htons(0);
if (bind(socket.fd, reinterpret_cast<const sockaddr*>(&servaddr), sizeof(servaddr)) < 0) {
address_t bind_any;
if (bind(socket.fd, bind_any.get_socket_address(), bind_any.size()) < 0) {
log(ll_warning, "Could not bind socket for IP discovery");
return "";
}
memset(&servaddr, 0, sizeof(servaddr));
servaddr.sin_family = AF_INET;
servaddr.sin_port = htons(this->port);
servaddr.sin_addr.s_addr = inet_addr(this->ip.c_str());
if (::connect(socket.fd, reinterpret_cast<const sockaddr*>(&servaddr), sizeof(sockaddr_in)) < 0) {
address_t bind_port(this->ip, this->port);
if (::connect(socket.fd, bind_port.get_socket_address(), bind_port.size()) < 0) {
log(ll_warning, "Could not connect socket for IP discovery");
return "";
}
Expand Down
19 changes: 5 additions & 14 deletions src/dpp/voice/enabled/handle_frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
this->ip = d["ip"].get<std::string>();
this->port = d["port"].get<uint16_t>();
this->ssrc = d["ssrc"].get<uint64_t>();
destination = address_t(this->ip, this->port);

// Modes
for (auto & m : d["modes"]) {
this->modes.push_back(m.get<std::string>());
Expand All @@ -446,13 +448,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
dpp::socket newfd = 0;
if ((newfd = ::socket(AF_INET, SOCK_DGRAM, 0)) >= 0) {

sockaddr_in servaddr{};
memset(&servaddr, 0, sizeof(sockaddr_in));
servaddr.sin_family = AF_INET;
servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
servaddr.sin_port = htons(0);

if (bind(newfd, reinterpret_cast<sockaddr*>(&servaddr), sizeof(servaddr)) < 0) {
address_t bind_any;
if (bind(newfd, bind_any.get_socket_address(), bind_any.size()) < 0) {
throw dpp::connection_exception(err_bind_failure, "Can't bind() client UDP socket");
}

Expand All @@ -467,13 +464,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
this->custom_writeable_ready = [this] { write_ready(); };
this->custom_readable_ready = [this] { read_ready(); };

int bound_port = 0;
sockaddr_in sin{};
socklen_t len = sizeof(sin);
if (getsockname(this->fd, reinterpret_cast<sockaddr *>(&sin), &len) > -1) {
bound_port = ntohs(sin.sin_port);
}

int bound_port = address_t().get_port(this->fd);
this->write(json({
{ "op", voice_opcode_connection_select_protocol },
{ "d", {
Expand Down
17 changes: 7 additions & 10 deletions src/dpp/voice/enabled/read_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,23 @@ dpp::socket discord_voice_client::want_read() {


void discord_voice_client::send(const char* packet, size_t len, uint64_t duration) {
std::lock_guard<std::mutex> lock(this->stream_mutex);
voice_out_packet frame;
frame.packet = std::string(packet, len);
frame.packet.assign(packet, packet + len);
frame.duration = duration;
outbuf.emplace_back(frame);
{
std::lock_guard<std::mutex> lock(this->stream_mutex);
outbuf.emplace_back(frame);
}
}

int discord_voice_client::udp_send(const char* data, size_t length) {
sockaddr_in servaddr{};
memset(&servaddr, 0, sizeof(servaddr));
servaddr.sin_family = AF_INET;
servaddr.sin_port = htons(this->port);
servaddr.sin_addr.s_addr = inet_addr(this->ip.c_str());
return static_cast<int>(sendto(
this->fd,
data,
static_cast<int>(length),
0,
reinterpret_cast<const sockaddr*>(&servaddr),
static_cast<int>(sizeof(sockaddr_in))
destination.get_socket_address(),
destination.size()
));
}

Expand Down

0 comments on commit 4f85429

Please sign in to comment.