diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f14ea3c2b25309..acd8e9aa10c5d5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -118,6 +118,7 @@ add_library(bitcoin_common STATIC EXCLUDE_FROM_ALL common/run_command.cpp common/settings.cpp common/signmessage.cpp + common/sv2_connman.cpp common/sv2_noise.cpp common/sv2_transport.cpp common/system.cpp diff --git a/src/common/sv2_connman.cpp b/src/common/sv2_connman.cpp new file mode 100644 index 00000000000000..978fc1e516a676 --- /dev/null +++ b/src/common/sv2_connman.cpp @@ -0,0 +1,387 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include +#include + +using node::Sv2MsgType; + +Sv2Connman::~Sv2Connman() +{ + AssertLockNotHeld(m_clients_mutex); + + { + LOCK(m_clients_mutex); + for (const auto& client : m_sv2_clients) { + LogTrace(BCLog::SV2, "Disconnecting client id=%zu\n", + client->m_id); + client->m_disconnect_flag = true; + } + DisconnectFlagged(); + } + + Interrupt(); + StopThreads(); +} + +bool Sv2Connman::Start(Sv2EventsInterface* msgproc, std::string host, uint16_t port) +{ + m_msgproc = msgproc; + + try { + auto sock = BindListenPort(host, port); + m_listening_socket = std::move(sock); + } catch (const std::runtime_error& e) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Template Provider failed to bind to port %d: %s\n", port, e.what()); + return false; + } + + m_thread_sv2_handler = std::thread(&util::TraceThread, "sv2connman", [this] { ThreadSv2Handler(); }); + return true; +} + +std::shared_ptr Sv2Connman::BindListenPort(std::string host, uint16_t port) const +{ + const CService addr_bind = LookupNumeric(host, port); + + auto sock = CreateSock(addr_bind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP); + if (!sock) { + throw std::runtime_error("Sv2 Template Provider cannot create socket"); + } + + struct sockaddr_storage sockaddr; + socklen_t len = sizeof(sockaddr); + + if (!addr_bind.GetSockAddr(reinterpret_cast(&sockaddr), &len)) { + throw std::runtime_error("Sv2 Template Provider failed to get socket address"); + } + + if (sock->Bind(reinterpret_cast(&sockaddr), len) == SOCKET_ERROR) { + const int nErr = WSAGetLastError(); + if (nErr == WSAEADDRINUSE) { + throw std::runtime_error(strprintf("Unable to bind to %d on this computer. Another Stratum v2 process is probably already running.\n", port)); + } + + throw std::runtime_error(strprintf("Unable to bind to %d on this computer (bind returned error %s )\n", port, NetworkErrorString(nErr))); + } + + constexpr int max_pending_conns{4096}; + if (sock->Listen(max_pending_conns) == SOCKET_ERROR) { + throw std::runtime_error("Sv2 listening socket has an error listening"); + } + + LogPrintLevel(BCLog::SV2, BCLog::Level::Info, "%s listening on %s:%d\n", SV2_PROTOCOL_NAMES.at(m_subprotocol), host, port); + + return sock; +} + + +void Sv2Connman::DisconnectFlagged() +{ + AssertLockHeld(m_clients_mutex); + + // Remove clients that are flagged for disconnection. + m_sv2_clients.erase( + std::remove_if(m_sv2_clients.begin(), m_sv2_clients.end(), [](const auto &client) { + return client->m_disconnect_flag; + }), m_sv2_clients.end()); +} + +void Sv2Connman::ThreadSv2Handler() EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex) +{ + AssertLockNotHeld(m_clients_mutex); + + while (!m_flag_interrupt_sv2) { + { + LOCK(m_clients_mutex); + DisconnectFlagged(); + } + + // Poll/Select the sockets that need handling. + Sock::EventsPerSock events_per_sock = WITH_LOCK(m_clients_mutex, return GenerateWaitSockets(m_listening_socket, m_sv2_clients)); + + constexpr auto timeout = std::chrono::milliseconds(50); + if (!events_per_sock.begin()->first->WaitMany(timeout, events_per_sock)) { + continue; + } + + // Accept any new connections for sv2 clients. + const auto listening_sock = events_per_sock.find(m_listening_socket); + if (listening_sock != events_per_sock.end() && listening_sock->second.occurred & Sock::RECV) { + struct sockaddr_storage sockaddr; + socklen_t sockaddr_len = sizeof(sockaddr); + + auto sock = m_listening_socket->Accept(reinterpret_cast(&sockaddr), &sockaddr_len); + if (sock) { + Assume(m_certificate); + LOCK(m_clients_mutex); + std::unique_ptr transport = std::make_unique(m_static_key, m_certificate.value()); + size_t id{m_sv2_clients.size() + 1}; + auto client = std::make_unique(id, std::move(sock), std::move(transport)); + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "New client id=%zu connected\n", client->m_id); + m_sv2_clients.emplace_back(std::move(client)); + } + } + + LOCK(m_clients_mutex); + // Process messages from and for connected sv2_clients. + for (auto& client : m_sv2_clients) { + bool has_received_data = false; + bool has_error_occurred = false; + + const auto socket_it = events_per_sock.find(client->m_sock); + if (socket_it != events_per_sock.end()) { + has_received_data = socket_it->second.occurred & Sock::RECV; + has_error_occurred = socket_it->second.occurred & Sock::ERR; + } + + if (has_error_occurred) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Socket receive error, disconnecting client id=%zu\n", + client->m_id); + client->m_disconnect_flag = true; + continue; + } + + // Process message queue and any outbound bytes still held by the transport + auto it = client->m_send_messages.begin(); + std::optional expected_more; + while(true) { + if (it != client->m_send_messages.end()) { + // If possible, move one message from the send queue to the transport. + // This fails when there is an existing message still being sent, + // or when the handshake has not yet completed. + // + // Wrap Sv2NetMsg inside CSerializedNetMsg for transport + CSerializedNetMsg net_msg{*it}; + if (client->m_transport->SetMessageToSend(net_msg)) { + ++it; + } + } + + const auto& [data, more, _m_message_type] = client->m_transport->GetBytesToSend(/*have_next_message=*/it != client->m_send_messages.end()); + size_t total_sent = 0; + + // We rely on the 'more' value returned by GetBytesToSend to correctly predict whether more + // bytes are still to be sent, to correctly set the MSG_MORE flag. As a sanity check, + // verify that the previously returned 'more' was correct. + if (expected_more.has_value()) Assume(!data.empty() == *expected_more); + expected_more = more; + ssize_t sent = 0; + + if (!data.empty()) { + int flags = MSG_NOSIGNAL | MSG_DONTWAIT; +#ifdef MSG_MORE + if (more) { + flags |= MSG_MORE; + } +#endif + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Send %d bytes to client id=%zu\n", + data.size() - total_sent, client->m_id); + sent = client->m_sock->Send(data.data() + total_sent, data.size() - total_sent, flags); + } + if (sent > 0) { + // Notify transport that bytes have been processed. + client->m_transport->MarkBytesSent(sent); + if ((size_t)sent != data.size()) { + // could not send full message; stop sending more + break; + } + } else { + if (sent < 0) { + // error + int nErr = WSAGetLastError(); + if (nErr != WSAEWOULDBLOCK && nErr != WSAEMSGSIZE && nErr != WSAEINTR && nErr != WSAEINPROGRESS) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Socket send error for client id=%zu: %s\n", + client->m_id, NetworkErrorString(nErr)); + client->m_disconnect_flag = true; + } + } + break; + } + } + // Clear messages that have been handed to transport from the queue + client->m_send_messages.erase(client->m_send_messages.begin(), it); + + // Stop processing this client if something went wrong during sending + if (client->m_disconnect_flag) break; + + if (has_received_data) { + uint8_t bytes_received_buf[0x10000]; + + const auto num_bytes_received = client->m_sock->Recv(bytes_received_buf, sizeof(bytes_received_buf), MSG_DONTWAIT); + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Num bytes received from client id=%zu: %d\n", + client->m_id, num_bytes_received); + + if (num_bytes_received <= 0) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Disconnecting client id=%zu\n", + client->m_id); + client->m_disconnect_flag = true; + break; + } + + try + { + auto msg_ = Span(bytes_received_buf, num_bytes_received); + Span msg(reinterpret_cast(msg_.data()), msg_.size()); + while (msg.size() > 0) { + // absorb network data + if (!client->m_transport->ReceivedBytes(msg)) { + // Serious transport problem + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Transport problem, disconnecting client id=%zu\n", + client->m_id); + client->m_disconnect_flag = true; + break; + } + + if (client->m_transport->ReceivedMessageComplete()) { + bool dummy_reject_message = false; + Sv2NetMsg msg = client->m_transport->GetReceivedMessage(std::chrono::milliseconds(0), dummy_reject_message); + ProcessSv2Message(msg, *client.get()); + } + } + } catch (const std::exception& e) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received error when processing client id=%zu message: %s\n", client->m_id, e.what()); + client->m_disconnect_flag = true; + } + } + } + } +} + +Sock::EventsPerSock Sv2Connman::GenerateWaitSockets(const std::shared_ptr& listen_socket, const Clients& sv2_clients) const +{ + Sock::EventsPerSock events_per_sock; + events_per_sock.emplace(listen_socket, Sock::Events(Sock::RECV)); + + for (const auto& client : sv2_clients) { + if (!client->m_disconnect_flag && client->m_sock) { + events_per_sock.emplace(client->m_sock, Sock::Events{Sock::RECV | Sock::ERR}); + } + } + + return events_per_sock; +} + +void Sv2Connman::Interrupt() +{ + m_flag_interrupt_sv2 = true; +} + +void Sv2Connman::StopThreads() +{ + if (m_thread_sv2_handler.joinable()) { + m_thread_sv2_handler.join(); + } +} + +void Sv2Connman::ProcessSv2Message(const Sv2NetMsg& sv2_net_msg, Sv2Client& client) +{ + uint8_t msg_type[1] = {uint8_t(sv2_net_msg.m_msg_type)}; + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Received 0x%s %s from client id=%zu\n", + // After clang-17: + // std::format("{:x}", uint8_t(sv2_net_msg.m_msg_type)), + HexStr(msg_type), + node::SV2_MSG_NAMES.at(sv2_net_msg.m_msg_type), client.m_id); + + DataStream ss (sv2_net_msg.m_msg); + + switch (sv2_net_msg.m_msg_type) + { + case Sv2MsgType::SETUP_CONNECTION: + { + if (client.m_setup_connection_confirmed) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Client client id=%zu connection has already been confirmed\n", + client.m_id); + return; + } + + node::Sv2SetupConnectionMsg setup_conn; + try { + ss >> setup_conn; + } catch (const std::exception& e) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received invalid SetupConnection message from client id=%zu: %s\n", + client.m_id, e.what()); + client.m_disconnect_flag = true; + return; + } + + // Disconnect a client that connects on the wrong subprotocol. + if (setup_conn.m_protocol != m_subprotocol) { + node::Sv2SetupConnectionErrorMsg setup_conn_err{setup_conn.m_flags, std::string{"unsupported-protocol"}}; + + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x02 SetupConnectionError to client id=%zu\n", + client.m_id); + client.m_send_messages.emplace_back(setup_conn_err); + + client.m_disconnect_flag = true; + return; + } + + // Disconnect a client if they are not running a compatible protocol version. + if ((m_protocol_version < setup_conn.m_min_version) || (m_protocol_version > setup_conn.m_max_version)) { + node::Sv2SetupConnectionErrorMsg setup_conn_err{setup_conn.m_flags, std::string{"protocol-version-mismatch"}}; + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x02 SetupConnection.Error to client id=%zu\n", + client.m_id); + client.m_send_messages.emplace_back(setup_conn_err); + + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received a connection from client id=%zu with incompatible protocol_versions: min_version: %d, max_version: %d\n", + client.m_id, setup_conn.m_min_version, setup_conn.m_max_version); + client.m_disconnect_flag = true; + return; + } + + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x01 SetupConnection.Success to client id=%zu\n", + client.m_id); + node::Sv2SetupConnectionSuccessMsg setup_success{m_protocol_version, m_optional_features}; + client.m_send_messages.emplace_back(setup_success); + + client.m_setup_connection_confirmed = true; + + break; + } + case Sv2MsgType::COINBASE_OUTPUT_DATA_SIZE: + { + if (!client.m_setup_connection_confirmed) { + client.m_disconnect_flag = true; + return; + } + + node::Sv2CoinbaseOutputDataSizeMsg coinbase_output_data_size; + try { + ss >> coinbase_output_data_size; + client.m_coinbase_output_data_size_recv = true; + } catch (const std::exception& e) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received invalid CoinbaseOutputDataSize message from client id=%zu: %s\n", + client.m_id, e.what()); + client.m_disconnect_flag = true; + return; + } + + uint32_t max_additional_size = coinbase_output_data_size.m_coinbase_output_max_additional_size; + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "coinbase_output_max_additional_size=%d bytes\n", max_additional_size); + + if (max_additional_size > MAX_BLOCK_WEIGHT) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received impossible CoinbaseOutputDataSize from client id=%zu: %d\n", + client.m_id, max_additional_size); + client.m_disconnect_flag = true; + return; + } + + client.m_coinbase_tx_outputs_size = coinbase_output_data_size.m_coinbase_output_max_additional_size; + + break; + } + default: { + uint8_t msg_type[1]{uint8_t(sv2_net_msg.m_msg_type)}; + LogPrintLevel(BCLog::SV2, BCLog::Level::Warning, "Received unknown message type 0x%s from client id=%zu\n", + HexStr(msg_type), client.m_id); + break; + } + } + + m_msgproc->ReceivedMessage(client, sv2_net_msg.m_msg_type); +} diff --git a/src/common/sv2_connman.h b/src/common/sv2_connman.h new file mode 100644 index 00000000000000..01dd06ccbbd43c --- /dev/null +++ b/src/common/sv2_connman.h @@ -0,0 +1,233 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_COMMON_SV2_CONNMAN_H +#define BITCOIN_COMMON_SV2_CONNMAN_H + +#include +#include +#include + +namespace { + /* + * Supported Stratum v2 subprotocols + */ + static constexpr uint8_t TP_SUBPROTOCOL{0x02}; + + static const std::map SV2_PROTOCOL_NAMES{ + {0x02, "Template Provider"}, + }; +} + +struct Sv2Client +{ + /* Ephemeral identifier for debugging purposes */ + size_t m_id; + + /** + * Receiving and sending socket for the connected client + */ + std::shared_ptr m_sock; + + /** + * Transport + */ + std::unique_ptr m_transport; + + /** + * Whether the client has confirmed the connection with a successful SetupConnection. + */ + bool m_setup_connection_confirmed = false; + + /** + * Whether the client is a candidate for disconnection. + */ + bool m_disconnect_flag = false; + + /** Queue of messages to be sent */ + std::deque m_send_messages; + + /** + * Whether the client has received CoinbaseOutputDataSize message. + */ + bool m_coinbase_output_data_size_recv = false; + + /** + * Specific additional coinbase tx output size required for the client. + */ + unsigned int m_coinbase_tx_outputs_size; + + explicit Sv2Client(size_t id, std::shared_ptr sock, std::unique_ptr transport) : + m_id{id}, m_sock{std::move(sock)}, m_transport{std::move(transport)} {}; + + bool IsFullyConnected() + { + return !m_disconnect_flag && m_setup_connection_confirmed; + } + + Sv2Client(Sv2Client&) = delete; + Sv2Client& operator=(const Sv2Client&) = delete; +}; + +/** + * Interface for sv2 message handling + */ +class Sv2EventsInterface +{ +public: + /** + * Generic notification that a message was received. Does not include the + * message itself. + * + * @param[in] client The client which we have received messages from. + * @param[in] msg_type the message type + */ + virtual void ReceivedMessage(Sv2Client& client, node::Sv2MsgType msg_type) = 0; + + virtual ~Sv2EventsInterface() = default; +}; + +/* + * Handle Stratum v2 connections, similar to CConnman. + * Currently only supports inbound connections. + */ +class Sv2Connman +{ +private: + /** Interface to pass events up */ + Sv2EventsInterface* m_msgproc; + + /** + * The current protocol version of stratum v2 supported by the server. Not to be confused + * with byte value of identitying the stratum v2 subprotocol. + */ + const uint16_t m_protocol_version = 2; + + /** + * The currently supported optional features. + */ + const uint16_t m_optional_features = 0; + + /** + * The subprotocol used in setup connection messages. + * An Sv2Connman only recognizes its own subprotocol. + */ + const uint8_t m_subprotocol; + + /** + * The main listening socket for new stratum v2 connections. + */ + std::shared_ptr m_listening_socket; + + CKey m_static_key; + + XOnlyPubKey m_authority_pubkey; + + std::optional m_certificate; + + /** + * A list of all connected stratum v2 clients. + */ + using Clients = std::vector>; + Clients m_sv2_clients GUARDED_BY(m_clients_mutex); + + /** + * The main thread for connection handling. + */ + std::thread m_thread_sv2_handler; + + /** + * Signal for handling interrupts and stopping the template provider event loop. + */ + std::atomic m_flag_interrupt_sv2{false}; + CThreadInterrupt m_interrupt_sv2; + + /** + * Creates a socket and binds the port for new stratum v2 connections. + * @throws std::runtime_error if port is unable to bind. + */ + [[nodiscard]] std::shared_ptr BindListenPort(std::string host, uint16_t port) const; + + void DisconnectFlagged() EXCLUSIVE_LOCKS_REQUIRED(m_clients_mutex); + + /** + * The main thread for the template provider, contains an event loop handling + * all tasks for the template provider. + */ + void ThreadSv2Handler(); + + /** + * Generates the socket events for each Sv2Client socket and the main listening socket. + */ + [[nodiscard]] Sock::EventsPerSock GenerateWaitSockets(const std::shared_ptr& listen_socket, const Clients& sv2_clients) const; + + /** + * Encrypt the header and message payload and send it. + * @throws std::runtime_error if encrypting the message fails. + */ + bool EncryptAndSendMessage(Sv2Client& client, node::Sv2NetMsg& net_msg); + + /** + * A helper method to read and decrypt multiple Sv2NetMsgs. + */ + std::vector ReadAndDecryptSv2NetMsgs(Sv2Client& client, Span buffer); + +public: + Sv2Connman(uint8_t subprotocol, CKey static_key, XOnlyPubKey authority_pubkey, Sv2SignatureNoiseMessage certificate) : + m_subprotocol(subprotocol), m_static_key(static_key), m_authority_pubkey(authority_pubkey), m_certificate(certificate) {}; + + ~Sv2Connman(); + + Mutex m_clients_mutex; + + /** + * Starts the Stratum v2 server and thread. + * returns false if port is unable to bind. + */ + [[nodiscard]] bool Start(Sv2EventsInterface* msgproc, std::string host, uint16_t port); + + /** + * Triggered on interrupt signals to stop the main event loop in ThreadSv2Handler(). + */ + void Interrupt(); + + /** + * Tear down of the connman thread and any other necessary tear down. + */ + void StopThreads(); + + /** + * Main handler for all received stratum v2 messages. + */ + void ProcessSv2Message(const node::Sv2NetMsg& sv2_header, Sv2Client& client); + + using Sv2ClientFn = std::function; + /** Perform a function on each fully connected client. */ + void ForEachClient(const Sv2ClientFn& func) EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex) + { + LOCK(m_clients_mutex); + for (const auto& client : m_sv2_clients) { + if (client->IsFullyConnected()) func(*client); + } + }; + + /** Number of clients that are not marked for disconnection, used for tests. */ + size_t ConnectedClients() EXCLUSIVE_LOCKS_REQUIRED(m_clients_mutex) + { + return std::count_if(m_sv2_clients.begin(), m_sv2_clients.end(), [](const auto& c) { + return !c->m_disconnect_flag; + }); + } + + /** Number of clients with m_setup_connection_confirmed, used for tests. */ + size_t FullyConnectedClients() EXCLUSIVE_LOCKS_REQUIRED(m_clients_mutex) + { + return std::count_if(m_sv2_clients.begin(), m_sv2_clients.end(), [](const auto& c) { + return c->IsFullyConnected(); + }); + } + +}; + +#endif // BITCOIN_COMMON_SV2_CONNMAN_H diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 06c5f47ef3a9ea..8f760b45eba812 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -119,6 +119,7 @@ add_executable(test_bitcoin sock_tests.cpp span_tests.cpp streams_tests.cpp + sv2_connman_tests.cpp sv2_messages_tests.cpp sv2_noise_tests.cpp sv2_transport_tests.cpp diff --git a/src/test/sv2_connman_tests.cpp b/src/test/sv2_connman_tests.cpp new file mode 100644 index 00000000000000..790f895ee5efae --- /dev/null +++ b/src/test/sv2_connman_tests.cpp @@ -0,0 +1,202 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +BOOST_FIXTURE_TEST_SUITE(sv2_connman_tests, TestChain100Setup) + +/** + * A class for testing the Sv2Connman. Each ConnTester encapsulates a + * Sv2Connman (the one being tested) as well as a Sv2Cipher + * to act as the other side. + */ +class ConnTester : Sv2EventsInterface { +private: + std::unique_ptr m_peer_transport; //!< Transport for peer + // Sockets that will be returned by the Sv2Connman's listening socket Accept() method. + std::shared_ptr m_sv2connman_accepted_sockets{std::make_shared()}; + + std::shared_ptr m_current_client_pipes; + + XOnlyPubKey m_connman_authority_pubkey; + +public: + std::unique_ptr m_connman; //!< Sv2Connman being tested + + ConnTester() + { + CreateSock = [this](int, int, int) -> std::unique_ptr { + // This will be the bind/listen socket from m_connman. It will + // create other sockets via its Accept() method. + return std::make_unique(std::make_shared(), m_sv2connman_accepted_sockets); + }; + + CKey static_key; + static_key.MakeNewKey(true); + auto authority_key{GenerateRandomKey()}; + m_connman_authority_pubkey = XOnlyPubKey(authority_key.GetPubKey()); + + // Generate and sign certificate + auto now{GetTime()}; + uint16_t version = 0; + // Start validity a little bit in the past to account for clock difference + uint32_t valid_from = static_cast(std::chrono::duration_cast(now).count()) - 3600; + uint32_t valid_to = std::numeric_limits::max(); // 2106 + Sv2SignatureNoiseMessage certificate{version, valid_from, valid_to, XOnlyPubKey(static_key.GetPubKey()), authority_key}; + + m_connman = std::make_unique(TP_SUBPROTOCOL, static_key, m_connman_authority_pubkey, certificate); + + BOOST_REQUIRE(m_connman->Start(this, "127.0.0.1", 18447)); + } + + ~ConnTester() + { + CreateSock = CreateSockOS; + } + + void SendPeerBytes() + { + const auto& [data, more, _m_message_type] = m_peer_transport->GetBytesToSend(/*have_next_message=*/false); + BOOST_REQUIRE(data.size() > 0); + // Schedule data to be returned by the next Recv() call from + // Sv2Connman on the socket it has accepted. + m_current_client_pipes->recv.PushBytes(data.data(), data.size()); + m_peer_transport->MarkBytesSent(data.size()); + } + + // Have the peer receive and process bytes: + size_t PeerReceiveBytes() + { + uint8_t buf[0x10000]; + // Get the data that has been written to the accepted socket with Send() by Sv2Connman. + // Wait until the bytes appear in the "send" pipe. + ssize_t n; + for (;;) { + n = m_current_client_pipes->send.GetBytes(buf, sizeof(buf), 0); + if (n != -1 || errno != EAGAIN) { + break; + } + UninterruptibleSleep(50ms); + } + + // Inform client's transport that some bytes have been received (sent by Sv2Connman). + if (n > 0) { + Span s(buf, n); + BOOST_REQUIRE(m_peer_transport->ReceivedBytes(s)); + } + + return n; + } + + /* Create a new client and perform handshake */ + void handshake() + { + m_peer_transport.reset(); + + auto peer_static_key{GenerateRandomKey()}; + m_peer_transport = std::make_unique(std::move(peer_static_key), m_connman_authority_pubkey); + + // Have Sv2Connman's listen socket's Accept() simulate a newly arrived connection. + m_current_client_pipes = std::make_shared(); + m_sv2connman_accepted_sockets->Push( + std::make_unique(m_current_client_pipes, std::make_shared())); + + // Flush transport for handshake part 1 + SendPeerBytes(); + + // Read handshake part 2 from transport + BOOST_REQUIRE_EQUAL(PeerReceiveBytes(), Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + + BOOST_REQUIRE(IsConnected()); + } + + void receiveMessage(Sv2NetMsg& msg) + { + // Client encrypts message and puts it on the transport: + CSerializedNetMsg net_msg{std::move(msg)}; + BOOST_REQUIRE(m_peer_transport->SetMessageToSend(net_msg)); + SendPeerBytes(); + } + + bool IsConnected() + { + LOCK(m_connman->m_clients_mutex); + return m_connman->ConnectedClients() > 0; + } + + bool IsFullyConnected() + { + LOCK(m_connman->m_clients_mutex); + return m_connman->FullyConnectedClients() > 0; + } + + Sv2NetMsg SetupConnectionMsg() + { + std::vector bytes{ + 0x02, // protocol + 0x02, 0x00, // min_version + 0x02, 0x00, // max_version + 0x01, 0x00, 0x00, 0x00, // flags + 0x07, 0x30, 0x2e, 0x30, 0x2e, 0x30, 0x2e, 0x30, // endpoint_host + 0x61, 0x21, // endpoint_port + 0x07, 0x42, 0x69, 0x74, 0x6d, 0x61, 0x69, 0x6e, // vendor + 0x08, 0x53, 0x39, 0x69, 0x20, 0x31, 0x33, 0x2e, 0x35, // hardware_version + 0x1c, 0x62, 0x72, 0x61, 0x69, 0x69, 0x6e, 0x73, 0x2d, 0x6f, 0x73, 0x2d, 0x32, 0x30, + 0x31, 0x38, 0x2d, 0x30, 0x39, 0x2d, 0x32, 0x32, 0x2d, 0x31, 0x2d, 0x68, 0x61, 0x73, + 0x68, // firmware + 0x10, 0x73, 0x6f, 0x6d, 0x65, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2d, 0x75, + 0x75, 0x69, 0x64, // device_id + }; + + return node::Sv2NetMsg{node::Sv2MsgType::SETUP_CONNECTION, std::move(bytes)}; + } + + void ReceivedMessage(Sv2Client& client, node::Sv2MsgType msg_type) override { + } + +}; + +BOOST_AUTO_TEST_CASE(client_tests) +{ + ConnTester tester{}; + + BOOST_REQUIRE(!tester.IsConnected()); + tester.handshake(); + BOOST_REQUIRE(!tester.IsFullyConnected()); + + // After the handshake the client must send a SetupConnection message to the + // Template Provider. + + // An empty SetupConnection message should cause disconnection + node::Sv2NetMsg sv2_msg{node::Sv2MsgType::SETUP_CONNECTION, {}}; + tester.receiveMessage(sv2_msg); + BOOST_REQUIRE_EQUAL(tester.PeerReceiveBytes(), 0); + + BOOST_REQUIRE(!tester.IsConnected()); + + BOOST_TEST_MESSAGE("Reconnect after empty message"); + + // Reconnect + tester.handshake(); + BOOST_TEST_MESSAGE("Handshake done, send SetupConnectionMsg"); + + node::Sv2NetMsg setup{tester.SetupConnectionMsg()}; + tester.receiveMessage(setup); + // SetupConnection.Success is 6 bytes + BOOST_REQUIRE_EQUAL(tester.PeerReceiveBytes(), SV2_HEADER_ENCRYPTED_SIZE + 6 + Poly1305::TAGLEN); + BOOST_REQUIRE(tester.IsFullyConnected()); + + std::vector coinbase_output_max_additional_size_bytes{ + 0x01, 0x00, 0x00, 0x00 + }; + node::Sv2NetMsg msg{node::Sv2MsgType::COINBASE_OUTPUT_DATA_SIZE, std::move(coinbase_output_max_additional_size_bytes)}; + // No reply expected, not yet implemented + tester.receiveMessage(msg); +} + +BOOST_AUTO_TEST_SUITE_END()