diff --git a/.cspell.json b/.cspell.json index 3cfb280732..0d8e2cd969 100644 --- a/.cspell.json +++ b/.cspell.json @@ -148,7 +148,8 @@ "nullopt", "chrono", "ciphersuite", - "rmap" + "rmap", + "WSAPOLLFD" ], "flagWords": [ "hte" diff --git a/include/dpp/discordvoiceclient.h b/include/dpp/discordvoiceclient.h index 6ce21bf309..658ba6422a 100644 --- a/include/dpp/discordvoiceclient.h +++ b/include/dpp/discordvoiceclient.h @@ -584,6 +584,13 @@ class DPP_EXPORT discord_voice_client : public websocket_client */ dave_version_t dave_version; + /** + * @brief Destination address for where packets go + * on the UDP socket + */ + address_t destination{}; + + /** * @brief Send data to UDP socket immediately. * diff --git a/include/dpp/dns.h b/include/dpp/dns.h index 5a3a566d45..48cebcd563 100644 --- a/include/dpp/dns.h +++ b/include/dpp/dns.h @@ -31,6 +31,8 @@ #include #include #include +#include +#include namespace dpp { @@ -40,23 +42,42 @@ namespace dpp { */ struct dns_cache_entry { /** - * @brief Resolved address information + * @brief Resolved address metadata */ addrinfo addr; /** - * @brief Socket address. - * Discord only supports ipv4, but sockaddr_in6 is larger - * than sockaddr_in, sockaddr_storage will hold either. This - * means that if discord ever do support ipv6 we just flip - * one value in dns.cpp and that should be all that is needed. + * @brief Resolved address as string. + * The metadata is needed to know what type of address it is. + * Do not do silly stuff like just looking to see if '.' is in it! */ - sockaddr_storage ai_addr; + std::string resolved_addr; /** * @brief Time at which this cache entry is invalidated */ time_t expire_timestamp; + + /** + * @brief Get address length + * @return address length + */ + [[nodiscard]] int size() const; + + /** + * @brief Get the address_t that corresponds to this cache entry + * for use when connecting with ::connect() + * @param port Port number to connect to + * @return address_t prefilled with the IP and port number + */ + [[nodiscard]] const address_t get_connecting_address(uint16_t port) const; + + /** + * @brief Allocate a socket file descriptor for the given dns address + * @return File descriptor ready for calling connect(), or INVALID_SOCKET + * on failure. + */ + [[nodiscard]] socket make_connecting_socket() const; }; /** @@ -73,4 +94,4 @@ namespace dpp { * @throw dpp::connection_exception On failure to resolve hostname */ const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port); -} // namespace dpp +} diff --git a/include/dpp/socket.h b/include/dpp/socket.h index d94914b35a..3888b2e022 100644 --- a/include/dpp/socket.h +++ b/include/dpp/socket.h @@ -1,17 +1,53 @@ +/************************************************************************************ + * + * D++, A Lightweight C++ library for Discord + * + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2021 Craig Edwards and D++ contributors + * (https://github.com/brainboxdotcc/DPP/graphs/contributors) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ************************************************************************************/ #pragma once +#include +#ifdef _WIN32 + #include + #include + #include + #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) + #define pollfd WSAPOLLFD +#else + #include + #include + #include +#endif +#include +#include + + namespace dpp { - /** - * @brief Represents a socket file descriptor. - * This is used to ensure parity between windows and unix-like systems. - */ +/** + * @brief Represents a socket file descriptor. + * This is used to ensure parity between windows and unix-like systems. + */ #ifndef _WIN32 using socket = int; #else using socket = SOCKET; #endif -} // namespace dpp #ifndef SOCKET_ERROR /** @@ -26,3 +62,92 @@ namespace dpp */ #define INVALID_SOCKET ~0 #endif + +/** + * @brief Represents an IPv4 address for use with socket functions such as + * bind(). + * + * Avoids type punning with C style casts from sockaddr_in to sockaddr pointers. + */ +class DPP_EXPORT address_t { + /** + * @brief Internal sockaddr struct + */ + sockaddr socket_addr{}; + +public: + + /** + * @brief Create a new address_t + * @param ip IPv4 address + * @param port Port number + * @note Leave both as defaults to create a default bind-to-any setting + */ + address_t(const std::string_view ip = "0.0.0.0", uint16_t port = 0); + + /** + * @brief Get sockaddr + * @return sockaddr pointer + */ + [[nodiscard]] sockaddr *get_socket_address(); + + /** + * @brief Returns size of sockaddr_in + * @return sockaddr_in size + * @note It is important the size this returns is sizeof(sockaddr_in) not + * sizeof(sockaddr), this is NOT a bug but requirement of C socket functions. + */ + [[nodiscard]] size_t size(); + + /** + * @brief Get the port bound to a file descriptor + * @param fd File descriptor + * @return Port number, or 0 if no port bound + */ + [[nodiscard]] uint16_t get_port(socket fd); +}; + +/** + * @brief Allocates a dpp::socket, closing it on destruction + */ +struct DPP_EXPORT raii_socket { + /** + * @brief File descriptor + */ + socket fd; + + /** + * @brief Construct a socket. + * Calls socket() and returns a new file descriptor + */ + raii_socket(); + + /** + * @brief Non-copyable + */ + raii_socket(raii_socket&) = delete; + + /** + * @brief Non-movable + */ + raii_socket(raii_socket&&) = delete; + + /** + * @brief Non-copyable + */ + raii_socket operator=(raii_socket&) = delete; + + /** + * @brief Non-movable + */ + raii_socket operator=(raii_socket&&) = delete; + + /** + * @brief Destructor + * Frees the socket by closing it + */ + ~raii_socket(); +}; + + +} diff --git a/src/dpp/dave/array_view.h b/src/dpp/dave/array_view.h index 98b6025b34..e1e9c18c80 100755 --- a/src/dpp/dave/array_view.h +++ b/src/dpp/dave/array_view.h @@ -45,45 +45,50 @@ template class array_view { * @param data data pointer to array * @param size size of array */ - array_view(T* data, size_t size) - : data_(data) - , size_(size) - { + array_view(T* data, size_t size) : array(data), array_size(size) { } /** * @brief Get size of view * @return size */ - size_t size() const { return size_; } + size_t size() const { + return array_size; + } /** * @brief Get data of view from first element * @return data */ - T* data() const { return data_; } + T* data() const { + return array; + } /** * @brief Get start of view, first element * @return first element */ - T* begin() const { return data_; } + T* begin() const { + return array; + } /** * @brief Get ending iterator of view, 1+last element * @return end of view */ - T* end() const { return data_ + size_; } + T* end() const { + return array + array_size; + } private: /** * @brief array data */ - T* data_ = nullptr; + T* array = nullptr; /** * @brief Array size */ - size_t size_ = 0; + size_t array_size = 0; }; /** @@ -111,5 +116,4 @@ inline array_view make_array_view(std::vector& data) return array_view(data.data(), data.size()); } -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/cipher_interface.cpp b/src/dpp/dave/cipher_interface.cpp index 3dae2807cc..58784c51bc 100755 --- a/src/dpp/dave/cipher_interface.cpp +++ b/src/dpp/dave/cipher_interface.cpp @@ -28,11 +28,10 @@ namespace dpp::dave { -std::unique_ptr create_cipher(dpp::cluster& cl, const encryption_key& encryptionKey) +std::unique_ptr create_cipher(dpp::cluster& cl, const encryption_key& key) { - auto cipher = std::make_unique(cl, encryptionKey); + auto cipher = std::make_unique(cl, key); return cipher->is_valid() ? std::move(cipher) : nullptr; } -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/cipher_interface.h b/src/dpp/dave/cipher_interface.h index 35b849c279..44004883e8 100755 --- a/src/dpp/dave/cipher_interface.h +++ b/src/dpp/dave/cipher_interface.h @@ -64,25 +64,25 @@ class cipher_interface { // NOLINT /** * @brief Encrypt audio or video - * @param ciphertextBufferOut Output buffer of ciphertext - * @param plaintextBuffer Input buffer for plaintext - * @param nonceBuffer Input nonce/IV - * @param additionalData Additional data for GCM AEAD encryption - * @param tagBufferOut AEAD Tag for verification + * @param ciphertext_buffer_out Output buffer of ciphertext + * @param plaintext_buffer Input buffer for plaintext + * @param nonce_buffer Input nonce/IV + * @param additional_data Additional data for GCM AEAD encryption + * @param tag_buffer_out AEAD Tag for verification * @return true if encryption succeeded, false if it failed */ - virtual bool encrypt(byte_view ciphertextBufferOut, const_byte_view plaintextBuffer, const_byte_view nonceBuffer, const_byte_view additionalData, byte_view tagBufferOut) = 0; + virtual bool encrypt(byte_view ciphertext_buffer_out, const_byte_view plaintext_buffer, const_byte_view nonce_buffer, const_byte_view additional_data, byte_view tag_buffer_out) = 0; /** * @brief Decrypt audio or video - * @param plaintextBufferOut Output buffer for plaintext - * @param ciphertextBuffer Input buffer for ciphetext - * @param tagBuffer AEAD Tag for verification - * @param nonceBuffer Nonce/IV - * @param additionalData Additional data for GCM AEAD encryption + * @param plaintext_buffer_out Output buffer for plaintext + * @param ciphertext_buffer Input buffer for ciphetext + * @param tag_buffer AEAD Tag for verification + * @param nonce_buffer Nonce/IV + * @param additional_data Additional data for GCM AEAD encryption * @return true if decryption succeeded, false if it failed */ - virtual bool decrypt(byte_view plaintextBufferOut, const_byte_view ciphertextBuffer, const_byte_view tagBuffer, const_byte_view nonceBuffer, const_byte_view additionalData) = 0; + virtual bool decrypt(byte_view plaintext_buffer_out, const_byte_view ciphertext_buffer, const_byte_view tag_buffer, const_byte_view nonce_buffer, const_byte_view additional_data) = 0; protected: @@ -94,10 +94,9 @@ class cipher_interface { // NOLINT /** * @brief Factory function to create new cipher interface of the best supported type for DAVE - * @param encryptionKey encryption key + * @param key encryption key * @return an instance of a class derived from cipher_interface */ -std::unique_ptr create_cipher(dpp::cluster& cl, const encryption_key& encryptionKey); - -} // namespace dpp::dave +std::unique_ptr create_cipher(dpp::cluster& cl, const encryption_key& key); +} diff --git a/src/dpp/dave/clock.h b/src/dpp/dave/clock.h index c8a2ab6872..db93afd914 100755 --- a/src/dpp/dave/clock.h +++ b/src/dpp/dave/clock.h @@ -74,5 +74,4 @@ class clock : public clock_interface { } }; -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/codec_utils.cpp b/src/dpp/dave/codec_utils.cpp index ea4269c050..3cae886066 100755 --- a/src/dpp/dave/codec_utils.cpp +++ b/src/dpp/dave/codec_utils.cpp @@ -31,98 +31,97 @@ namespace dpp::dave::codec_utils { -unencrypted_frame_header_size BytesCoveringH264PPS(const uint8_t* payload, const uint64_t sizeRemaining) -{ +unencrypted_frame_header_size bytes_covering_h264_pps(const uint8_t* payload, const uint64_t size_remaining) { // the payload starts with three exponential golomb encoded values // (first_mb_in_slice, sps_id, pps_id) // the depacketizer needs the pps_id unencrypted // and the payload has RBSP encoding that we need to work around - constexpr uint8_t kEmulationPreventionByte = 0x03; + constexpr uint8_t emulation_prevention_byte = 0x03; - uint64_t payloadBitIndex = 0; - auto zeroBitCount = 0; - auto parsedExpGolombValues = 0; + uint64_t payload_bit_index = 0; + auto zero_bit_count = 0; + auto parsed_exp_golomb_values = 0; - while (payloadBitIndex < sizeRemaining * 8 && parsedExpGolombValues < 3) { - auto bitIndex = payloadBitIndex % 8; - auto byteIndex = payloadBitIndex / 8; - auto payloadByte = payload[byteIndex]; + while (payload_bit_index < size_remaining * 8 && parsed_exp_golomb_values < 3) { + auto bit_index = payload_bit_index % 8; + auto byte_index = payload_bit_index / 8; + auto payload_byte = payload[byte_index]; // if we're starting a new byte // check if this is an emulation prevention byte // which we skip over - if (bitIndex == 0) { - if (byteIndex >= 2 && payloadByte == kEmulationPreventionByte && - payload[byteIndex - 1] == 0 && payload[byteIndex - 2] == 0) { - payloadBitIndex += 8; + if (bit_index == 0) { + if (byte_index >= 2 && payload_byte == emulation_prevention_byte && payload[byte_index - 1] == 0 && payload[byte_index - 2] == 0) { + payload_bit_index += 8; continue; } } - if ((payloadByte & (1 << (7 - bitIndex))) == 0) { + if ((payload_byte & (1 << (7 - bit_index))) == 0) { // still in the run of leading zero bits - ++zeroBitCount; - ++payloadBitIndex; + ++zero_bit_count; + ++payload_bit_index; - if (zeroBitCount >= 32) { + if (zero_bit_count >= 32) { throw dpp::length_exception("Unexpectedly large exponential golomb encoded value"); } - } - else { + } else { // we hit a one // skip forward the number of bits dictated by the leading number of zeroes - parsedExpGolombValues += 1; - payloadBitIndex += 1 + zeroBitCount; - zeroBitCount = 0; + parsed_exp_golomb_values += 1; + payload_bit_index += 1 + zero_bit_count; + zero_bit_count = 0; } } // return the number of bytes that covers the last exp golomb encoded value - return (payloadBitIndex / 8) + 1; + auto result = (payload_bit_index / 8) + 1; + if (result > std::numeric_limits::max()) { + // bytes covering H264 PPS result cannot fit into unencrypted frame header size + return 0; + } + return static_cast(result); } -const uint8_t kH26XNaluLongStartCode[] = {0, 0, 0, 1}; -constexpr uint8_t kH26XNaluShortStartSequenceSize = 3; +const uint8_t nalu_long_start_code[] = {0, 0, 0, 1}; +constexpr uint8_t nalu_short_start_sequence_size = 3; -using IndexStartCodeSizePair = std::pair; +using index_start_code_size_pair = std::pair; -std::optional FindNextH26XNaluIndex(const uint8_t* buffer, const size_t bufferSize, const size_t searchStartIndex = 0) +std::optional next_h26x_nalu_index(const uint8_t* buffer, const size_t buffer_size, const size_t search_start_index = 0) { - constexpr uint8_t kH26XStartCodeHighestPossibleValue = 1; - constexpr uint8_t kH26XStartCodeEndByteValue = 1; - constexpr uint8_t kH26XStartCodeLeadingBytesValue = 0; + constexpr uint8_t start_code_highest_possible_value = 1; + constexpr uint8_t start_code_end_byte_value = 1; + constexpr uint8_t start_code_leading_bytes_value = 0; - if (bufferSize < kH26XNaluShortStartSequenceSize) { + if (buffer_size < nalu_short_start_sequence_size) { return std::nullopt; } // look for NAL unit 3 or 4 byte start code - for (size_t i = searchStartIndex; i < bufferSize - kH26XNaluShortStartSequenceSize;) { - if (buffer[i + 2] > kH26XStartCodeHighestPossibleValue) { + for (size_t i = search_start_index; i < buffer_size - nalu_short_start_sequence_size;) { + if (buffer[i + 2] > start_code_highest_possible_value) { // third byte is not 0 or 1, can't be a start code - i += kH26XNaluShortStartSequenceSize; - } - else if (buffer[i + 2] == kH26XStartCodeEndByteValue) { + i += nalu_short_start_sequence_size; + } else if (buffer[i + 2] == start_code_end_byte_value) { // third byte matches the start code end byte, might be a start code sequence - if (buffer[i + 1] == kH26XStartCodeLeadingBytesValue && - buffer[i] == kH26XStartCodeLeadingBytesValue) { + if (buffer[i + 1] == start_code_leading_bytes_value && buffer[i] == start_code_leading_bytes_value) { // confirmed start sequence {0, 0, 1} - auto nalUnitStartIndex = i + kH26XNaluShortStartSequenceSize; + auto nal_unit_start_index = i + nalu_short_start_sequence_size; - if (i >= 1 && buffer[i - 1] == kH26XStartCodeLeadingBytesValue) { + if (i >= 1 && buffer[i - 1] == start_code_leading_bytes_value) { // 4 byte start code - return std::optional({nalUnitStartIndex, 4}); + return std::optional({nal_unit_start_index, 4}); } else { // 3 byte start code - return std::optional({nalUnitStartIndex, 3}); + return std::optional({nal_unit_start_index, 3}); } } - i += kH26XNaluShortStartSequenceSize; - } - else { + i += nalu_short_start_sequence_size; + } else { // third byte is 0, might be a four byte start code ++i; } @@ -139,8 +138,8 @@ bool process_frame_opus(outbound_frame_processor& processor, array_view frame) { - constexpr uint8_t kVP8KeyFrameUnencryptedBytes = 10; - constexpr uint8_t kVP8DeltaFrameUnencryptedBytes = 1; + constexpr uint8_t key_frame_unencrypted_bytes = 10; + constexpr uint8_t delta_frame_unencrypted_bytes = 1; // parse the VP8 payload header to determine if it's a key frame // https://datatracker.ietf.org/doc/html/rfc7741#section-4.3 @@ -155,17 +154,15 @@ bool process_frame_vp8(outbound_frame_processor& processor, array_viewfirst < frame.size() - 1) { - auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; + auto nalu_index_pair = next_h26x_nalu_index(frame.data(), frame.size()); + while (nalu_index_pair && nalu_index_pair->first < frame.size() - 1) { + auto [nal_unit_start_index, start_code_size] = *nalu_index_pair; - auto nalType = frame.data()[nalUnitStartIndex] & kH264NalHeaderTypeMask; + auto nal_type = frame.data()[nal_unit_start_index] & nal_header_type_mask; // copy the start code and then the NAL unit // Because WebRTC will convert them all start codes to 4-byte on the receiver side // always write a long start code and then the NAL unit - processor.add_unencrypted_bytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); + processor.add_unencrypted_bytes(nalu_long_start_code, sizeof(nalu_long_start_code)); - auto nextNaluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); - auto nextNaluStart = nextNaluIndexPair.has_value() ? nextNaluIndexPair->first - nextNaluIndexPair->second : frame.size(); + auto next_nalu_index_pair = next_h26x_nalu_index(frame.data(), frame.size(), nal_unit_start_index); + auto next_nalu_start = next_nalu_index_pair.has_value() ? next_nalu_index_pair->first - next_nalu_index_pair->second : frame.size(); - if (nalType == kH264NalTypeSlice || nalType == kH264NalTypeIdr) { + if (nal_type == nal_type_slice || nal_type == nal_type_idr) { // once we've hit a slice or an IDR // we just need to cover getting to the PPS ID - auto nalUnitPayloadStart = nalUnitStartIndex + kH264NalUnitHeaderSize; - auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart, frame.size() - nalUnitPayloadStart); + auto nal_unit_payload_start = nal_unit_start_index + nal_unit_header_size; + auto nal_unit_pps_bytes = bytes_covering_h264_pps(frame.data() + nal_unit_payload_start, frame.size() - nal_unit_payload_start); - processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH264NalUnitHeaderSize + nalUnitPPSBytes); - processor.add_encrypted_bytes( - frame.data() + nalUnitStartIndex + kH264NalUnitHeaderSize + nalUnitPPSBytes, - nextNaluStart - nalUnitStartIndex - kH264NalUnitHeaderSize - nalUnitPPSBytes); - } - else { + processor.add_unencrypted_bytes(frame.data() + nal_unit_start_index, nal_unit_header_size + nal_unit_pps_bytes); + processor.add_encrypted_bytes(frame.data() + nal_unit_start_index + nal_unit_header_size + nal_unit_pps_bytes, next_nalu_start - nal_unit_start_index - nal_unit_header_size - nal_unit_pps_bytes); + } else { // copy the whole NAL unit - processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex); + processor.add_unencrypted_bytes(frame.data() + nal_unit_start_index, next_nalu_start - nal_unit_start_index); } - - naluIndexPair = nextNaluIndexPair; + nalu_index_pair = next_nalu_index_pair; } return true; @@ -241,43 +233,43 @@ bool process_frame_h265(outbound_frame_processor& processor, array_viewfirst < frame.size() - 1) { - auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; + auto nalu_index = next_h26x_nalu_index(frame.data(), frame.size()); + while (nalu_index && nalu_index->first < frame.size() - 1) { + auto [nal_unit_start_index, start_code_size] = *nalu_index; - uint8_t nalType = (frame.data()[nalUnitStartIndex] & kH265NalHeaderTypeMask) >> 1; + uint8_t nal_type = (frame.data()[nal_unit_start_index] & nal_header_type_mask) >> 1; // copy the start code and then the NAL unit // Because WebRTC will convert them all start codes to 4-byte on the receiver side // always write a long start code and then the NAL unit - processor.add_unencrypted_bytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); + processor.add_unencrypted_bytes(nalu_long_start_code, sizeof(nalu_long_start_code)); - auto nextNaluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); - auto nextNaluStart = nextNaluIndexPair.has_value() ? nextNaluIndexPair->first - nextNaluIndexPair->second : frame.size(); + auto next_nalu_index_pair = next_h26x_nalu_index(frame.data(), frame.size(), nal_unit_start_index); + auto next_nalu_start = next_nalu_index_pair.has_value() ? next_nalu_index_pair->first - next_nalu_index_pair->second : frame.size(); - if (nalType < kH265NalTypeVclCutoff) { + if (nal_type < nal_type_vcl_cutoff) { // found a VCL NAL, encrypt the payload only - processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH265NalUnitHeaderSize); - processor.add_encrypted_bytes(frame.data() + nalUnitStartIndex + kH265NalUnitHeaderSize, nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize); + processor.add_unencrypted_bytes(frame.data() + nal_unit_start_index, nal_unit_header_size); + processor.add_encrypted_bytes(frame.data() + nal_unit_start_index + nal_unit_header_size, next_nalu_start - nal_unit_start_index - nal_unit_header_size); } else { // copy the whole NAL unit - processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex); + processor.add_unencrypted_bytes(frame.data() + nal_unit_start_index, next_nalu_start - nal_unit_start_index); } - naluIndexPair = nextNaluIndexPair; + nalu_index = next_nalu_index_pair; } return true; @@ -285,28 +277,28 @@ bool process_frame_h265(outbound_frame_processor& processor, array_view frame) { - constexpr uint8_t kAv1ObuHeaderHasExtensionMask = 0b0'0000'100; - constexpr uint8_t kAv1ObuHeaderHasSizeMask = 0b0'0000'010; - constexpr uint8_t kAv1ObuHeaderTypeMask = 0b0'1111'000; - constexpr uint8_t kObuTypeTemporalDelimiter = 2; - constexpr uint8_t kObuTypeTileList = 8; - constexpr uint8_t kObuTypePadding = 15; - constexpr uint8_t kObuExtensionSizeBytes = 1; + constexpr uint8_t obu_header_has_extension_mask = 0b0'0000'100; + constexpr uint8_t obu_header_has_size_mask = 0b0'0000'010; + constexpr uint8_t obu_header_type_mask = 0b0'1111'000; + constexpr uint8_t obu_type_temporal_delimiter = 2; + constexpr uint8_t obu_type_tile_list = 8; + constexpr uint8_t obu_type_padding = 15; + constexpr uint8_t obu_extension_size_bytes = 1; size_t i = 0; while (i < frame.size()) { // Read the OBU header. - size_t obuHeaderIndex = i; - uint8_t obuHeader = frame.data()[obuHeaderIndex]; - i += sizeof(obuHeader); + size_t obu_header_index = i; + uint8_t obu_header = frame.data()[obu_header_index]; + i += sizeof(obu_header); - bool obuHasExtension = obuHeader & kAv1ObuHeaderHasExtensionMask; - bool obuHasSize = obuHeader & kAv1ObuHeaderHasSizeMask; - int obuType = (obuHeader & kAv1ObuHeaderTypeMask) >> 3; + bool obu_has_extension = obu_header & obu_header_has_extension_mask; + bool obu_has_size = obu_header & obu_header_has_size_mask; + int obu_type = (obu_header & obu_header_type_mask) >> 3; - if (obuHasExtension) { + if (obu_has_extension) { // Skip extension byte - i += kObuExtensionSizeBytes; + i += obu_extension_size_bytes; } if (i >= frame.size()) { @@ -314,12 +306,12 @@ bool process_frame_av1(outbound_frame_processor& processor, array_view frame.size()) { + if (i + obu_payload_size > frame.size()) { // Malformed frame throw dpp::logic_exception("Malformed AV1 frame: payload overflows frame"); } - i += obuPayloadSize; + i += obu_payload_size; // We only copy the OBUs that will not get dropped by the packetizer - if (obuType != kObuTypeTemporalDelimiter && obuType != kObuTypeTileList && - obuType != kObuTypePadding) { + if (obu_type != obu_type_temporal_delimiter && obu_type != obu_type_tile_list && obu_type != obu_type_padding) { // if this is the last OBU, we may need to flip the "has size" bit // which allows us to append necessary protocol data to the frame - bool rewrittenWithoutSize = false; + bool rewritten_without_size = false; - if (i == frame.size() && obuHasSize) { + if (i == frame.size() && obu_has_size) { // Flip the "has size" bit - obuHeader &= ~kAv1ObuHeaderHasSizeMask; - rewrittenWithoutSize = true; + obu_header &= ~obu_header_has_size_mask; + rewritten_without_size = true; } // write the OBU header unencrypted - processor.add_unencrypted_bytes(&obuHeader, sizeof(obuHeader)); - if (obuHasExtension) { + processor.add_unencrypted_bytes(&obu_header, sizeof(obu_header)); + if (obu_has_extension) { // write the extension byte unencrypted - processor.add_unencrypted_bytes(frame.data() + obuHeaderIndex + sizeof(obuHeader), kObuExtensionSizeBytes); + processor.add_unencrypted_bytes(frame.data() + obu_header_index + sizeof(obu_header), obu_extension_size_bytes); } // write the OBU payload size unencrypted if it was present and we didn't rewrite // without it - if (obuHasSize && !rewrittenWithoutSize) { + if (obu_has_size && !rewritten_without_size) { // The AMD AV1 encoder may pad LEB128 encoded sizes with a zero byte which the // webrtc packetizer removes. To prevent the packetizer from changing the frame, // we sanitize the size by re-writing it ourselves uint8_t leb128Buffer[LEB128_MAX_SIZE]; - size_t additionalBytesToWrite = write_leb128(obuPayloadSize, leb128Buffer); + size_t additionalBytesToWrite = write_leb128(obu_payload_size, leb128Buffer); processor.add_unencrypted_bytes(leb128Buffer, additionalBytesToWrite); } // add the OBU payload, encrypted - processor.add_encrypted_bytes(frame.data() + obuPayloadIndex, obuPayloadSize); + processor.add_encrypted_bytes(frame.data() + obu_payload_index, obu_payload_size); } } @@ -386,42 +377,42 @@ bool validate_encrypted_frame(outbound_frame_processor& processor, array_view frame); -} // namespace dpp::dave::codec_utils - - +} diff --git a/src/dpp/dave/cryptor_manager.cpp b/src/dpp/dave/cryptor_manager.cpp index e5b6af321b..050d3be3cc 100755 --- a/src/dpp/dave/cryptor_manager.cpp +++ b/src/dpp/dave/cryptor_manager.cpp @@ -44,58 +44,52 @@ key_generation compute_wrapped_generation(key_generation oldest, key_generation big_nonce compute_wrapped_big_nonce(key_generation generation, truncated_sync_nonce nonce) { // Remove the generation bits from the nonce - auto maskedNonce = nonce & ((1 << RATCHET_GENERATION_SHIFT_BITS) - 1); + auto masked_nonce = nonce & ((1 << RATCHET_GENERATION_SHIFT_BITS) - 1); // Add the wrapped generation bits back in - return static_cast(generation) << RATCHET_GENERATION_SHIFT_BITS | maskedNonce; + return static_cast(generation) << RATCHET_GENERATION_SHIFT_BITS | masked_nonce; } -aead_cipher_manager::aead_cipher_manager(dpp::cluster& cl, const clock_interface& clock, std::unique_ptr keyRatchet) - : clock_(clock) - , keyRatchet_(std::move(keyRatchet)) - , ratchetCreation_(clock.now()) - , ratchetExpiry_(time_point::max()) - , creator(cl) -{ +aead_cipher_manager::aead_cipher_manager(dpp::cluster& cl, const clock_interface& clock, std::unique_ptr key_ratchet) + : current_clock(clock), current_key_ratchet(std::move(key_ratchet)), ratchet_creation(clock.now()), ratchet_expiry(time_point::max()), creator(cl) { } bool aead_cipher_manager::can_process_nonce(key_generation generation, truncated_sync_nonce nonce) const { - if (!newestProcessedNonce_) { + if (!newest_processed_nonce) { return true; } - auto bigNonce = compute_wrapped_big_nonce(generation, nonce); - return bigNonce > *newestProcessedNonce_ || - std::find(missingNonces_.rbegin(), missingNonces_.rend(), bigNonce) != missingNonces_.rend(); + auto wrapped_big_nonce = compute_wrapped_big_nonce(generation, nonce); + return wrapped_big_nonce > *newest_processed_nonce || std::find(missing_nonces.rbegin(), missing_nonces.rend(), wrapped_big_nonce) != missing_nonces.rend(); } cipher_interface* aead_cipher_manager::get_cipher(key_generation generation) { cleanup_expired_ciphers(); - if (generation < oldestGeneration_) { - creator.log(dpp::ll_trace, "Received frame with old generation: " + std::to_string(generation) + ", oldest generation: " + std::to_string(oldestGeneration_)); + if (generation < oldest_generation) { + creator.log(dpp::ll_trace, "Received frame with old generation: " + std::to_string(generation) + ", oldest generation: " + std::to_string(oldest_generation)); return nullptr; } - if (generation > newestGeneration_ + MAX_GENERATION_GAP) { - creator.log(dpp::ll_trace, "Received frame with future generation: " + std::to_string(generation) + ", newest generation: " + std::to_string(newestGeneration_)); + if (generation > newest_generation + MAX_GENERATION_GAP) { + creator.log(dpp::ll_trace, "Received frame with future generation: " + std::to_string(generation) + ", newest generation: " + std::to_string(newest_generation)); return nullptr; } - auto ratchetLifetimeSec = - std::chrono::duration_cast(clock_.now() - ratchetCreation_).count(); - auto maxLifetimeFrames = MAX_FRAMES_PER_SECOND * ratchetLifetimeSec; - auto maxLifetimeGenerations = maxLifetimeFrames >> RATCHET_GENERATION_SHIFT_BITS; - if (generation > maxLifetimeGenerations) { - creator.log(dpp::ll_debug, "Received frame with generation " + std::to_string(generation) + " beyond ratchet max lifetime generations: " + std::to_string(maxLifetimeGenerations) + ", ratchet lifetime: " + std::to_string(ratchetLifetimeSec) + "s"); + auto ratchet_lifetime_sec = + std::chrono::duration_cast(current_clock.now() - ratchet_creation).count(); + auto max_lifetime_frames = MAX_FRAMES_PER_SECOND * ratchet_lifetime_sec; + auto max_lifetime_generations = max_lifetime_frames >> RATCHET_GENERATION_SHIFT_BITS; + if (generation > max_lifetime_generations) { + creator.log(dpp::ll_debug, "Received frame with generation " + std::to_string(generation) + " beyond ratchet max lifetime generations: " + std::to_string(max_lifetime_generations) + ", ratchet lifetime: " + std::to_string(ratchet_lifetime_sec) + "s"); return nullptr; } - auto it = cryptors_.find(generation); - if (it == cryptors_.end()) { + auto it = cryptor_generations.find(generation); + if (it == cryptor_generations.end()) { // We don't have a cryptor for this generation, create one - std::tie(it, std::ignore) = cryptors_.emplace(generation, make_expiring_cipher(generation)); + std::tie(it, std::ignore) = cryptor_generations.emplace(generation, make_expiring_cipher(generation)); } // Return a non-owning pointer to the cryptor @@ -105,95 +99,94 @@ cipher_interface* aead_cipher_manager::get_cipher(key_generation generation) void aead_cipher_manager::report_cipher_success(key_generation generation, truncated_sync_nonce nonce) { - auto bigNonce = compute_wrapped_big_nonce(generation, nonce); + auto wrapped_big_nonce = compute_wrapped_big_nonce(generation, nonce); // Add any missing nonces to the queue - if (!newestProcessedNonce_) { - newestProcessedNonce_ = bigNonce; + if (!newest_processed_nonce) { + newest_processed_nonce = wrapped_big_nonce; } - else if (bigNonce > *newestProcessedNonce_) { - auto oldestMissingNonce = bigNonce > MAX_MISSING_NONCES ? bigNonce - MAX_MISSING_NONCES : 0; + else if (wrapped_big_nonce > *newest_processed_nonce) { + auto oldest_missing_nonce = wrapped_big_nonce > MAX_MISSING_NONCES ? wrapped_big_nonce - MAX_MISSING_NONCES : 0; - while (!missingNonces_.empty() && missingNonces_.front() < oldestMissingNonce) { - missingNonces_.pop_front(); + while (!missing_nonces.empty() && missing_nonces.front() < oldest_missing_nonce) { + missing_nonces.pop_front(); } // If we're missing a lot, we don't want to add everything since newestProcessedNonce_ - auto missingRangeStart = std::max(oldestMissingNonce, *newestProcessedNonce_ + 1); - for (auto i = missingRangeStart; i < bigNonce; ++i) { - missingNonces_.push_back(i); + auto missing_range_start = std::max(oldest_missing_nonce, *newest_processed_nonce + 1); + for (auto i = missing_range_start; i < wrapped_big_nonce; ++i) { + missing_nonces.push_back(i); } // Update the newest processed nonce - newestProcessedNonce_ = bigNonce; + newest_processed_nonce = wrapped_big_nonce; } else { - auto it = std::find(missingNonces_.begin(), missingNonces_.end(), bigNonce); - if (it != missingNonces_.end()) { - missingNonces_.erase(it); + auto it = std::find(missing_nonces.begin(), missing_nonces.end(), wrapped_big_nonce); + if (it != missing_nonces.end()) { + missing_nonces.erase(it); } } - if (generation <= newestGeneration_ || cryptors_.find(generation) == cryptors_.end()) { + if (generation <= newest_generation || cryptor_generations.find(generation) == cryptor_generations.end()) { return; } creator.log(dpp::ll_trace, "Reporting cryptor success, generation: " + std::to_string(generation)); - newestGeneration_ = generation; + newest_generation = generation; // Update the expiry time for all old cryptors - const auto expiryTime = clock_.now() + CIPHER_EXPIRY; - for (auto& [gen, cryptor] : cryptors_) { - if (gen < newestGeneration_) { + const auto expiry_time = current_clock.now() + CIPHER_EXPIRY; + for (auto& [gen, cryptor] : cryptor_generations) { + if (gen < newest_generation) { creator.log(dpp::ll_trace, "Updating expiry for cryptor, generation: " + std::to_string(gen)); - cryptor.expiry = std::min(cryptor.expiry, expiryTime); + cryptor.expiry = std::min(cryptor.expiry, expiry_time); } } } key_generation aead_cipher_manager::compute_wrapped_generation(key_generation generation) { - return ::dpp::dave::compute_wrapped_generation(oldestGeneration_, generation); + return ::dpp::dave::compute_wrapped_generation(oldest_generation, generation); } aead_cipher_manager::expiring_cipher aead_cipher_manager::make_expiring_cipher(key_generation generation) { // Get the new key from the ratchet - auto encryptionKey = keyRatchet_->get_key(generation); - auto expiryTime = time_point::max(); + auto key = current_key_ratchet->get_key(generation); + auto expiry_time = time_point::max(); // If we got frames out of order, we might have to create a cryptor for an old generation // In that case, create it with a non-infinite expiry time as we have already transitioned // to a newer generation - if (generation < newestGeneration_) { + if (generation < newest_generation) { creator.log(dpp::ll_debug, "Creating cryptor for old generation: " + std::to_string(generation)); - expiryTime = clock_.now() + CIPHER_EXPIRY; + expiry_time = current_clock.now() + CIPHER_EXPIRY; } else { creator.log(dpp::ll_debug, "Creating cryptor for new generation: " + std::to_string(generation)); } - return {create_cipher(creator, encryptionKey), expiryTime}; + return {create_cipher(creator, key), expiry_time}; } void aead_cipher_manager::cleanup_expired_ciphers() { - for (auto it = cryptors_.begin(); it != cryptors_.end();) { + for (auto it = cryptor_generations.begin(); it != cryptor_generations.end();) { auto& [generation, cryptor] = *it; - bool expired = cryptor.expiry < clock_.now(); + bool expired = cryptor.expiry < current_clock.now(); if (expired) { creator.log(dpp::ll_trace, "Removing expired cryptor, generation: " + std::to_string(generation)); } - it = expired ? cryptors_.erase(it) : ++it; + it = expired ? cryptor_generations.erase(it) : ++it; } - while (oldestGeneration_ < newestGeneration_ && cryptors_.find(oldestGeneration_) == cryptors_.end()) { - creator.log(dpp::ll_trace, "Deleting key for old generation: " + std::to_string(oldestGeneration_)); - keyRatchet_->delete_key(oldestGeneration_); - ++oldestGeneration_; + while (oldest_generation < newest_generation && cryptor_generations.find(oldest_generation) == cryptor_generations.end()) { + creator.log(dpp::ll_trace, "Deleting key for old generation: " + std::to_string(oldest_generation)); + current_key_ratchet->delete_key(oldest_generation); + ++oldest_generation; } } -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/cryptor_manager.h b/src/dpp/dave/cryptor_manager.h index 774e87f23e..731c145234 100755 --- a/src/dpp/dave/cryptor_manager.h +++ b/src/dpp/dave/cryptor_manager.h @@ -77,16 +77,16 @@ class aead_cipher_manager { * @brief Constructor * @param cl Creating cluster * @param clock chrono clock - * @param keyRatchet key ratchet for cipher + * @param key_ratchet key ratchet for cipher */ - aead_cipher_manager(dpp::cluster& cl, const clock_interface& clock, std::unique_ptr keyRatchet); + aead_cipher_manager(dpp::cluster& cl, const clock_interface& clock, std::unique_ptr key_ratchet); /** * @brief Update cipher expiry * @param expiry expiry time */ void update_expiry(time_point expiry) { - ratchetExpiry_ = expiry; + ratchet_expiry = expiry; } /** @@ -94,7 +94,7 @@ class aead_cipher_manager { * @return true if expired */ bool is_expired() const { - return clock_.now() > ratchetExpiry_; + return current_clock.now() > ratchet_expiry; } /** @@ -149,17 +149,50 @@ class aead_cipher_manager { */ void cleanup_expired_ciphers(); - const clock_interface& clock_; - std::unique_ptr keyRatchet_; - std::unordered_map cryptors_; + /** + * @brief chrono clock + */ + const clock_interface& current_clock; + + /** + * @brief key ratchet for cryptor + */ + std::unique_ptr current_key_ratchet; + + /** + * @brief Cryptor for each generation with expiry + */ + std::unordered_map cryptor_generations; + + /** + * @brief Time ratchet was created + */ + time_point ratchet_creation; + + /** + * @brief Time ratchet expired + */ + time_point ratchet_expiry; + + /** + * @brief Oldest generation for ratchet + */ + key_generation oldest_generation{0}; + + /** + * @brief Newest generation for ratchet + */ + key_generation newest_generation{0}; - time_point ratchetCreation_; - time_point ratchetExpiry_; - key_generation oldestGeneration_{0}; - key_generation newestGeneration_{0}; + /** + * @brief Newest nonce + */ + std::optional newest_processed_nonce; - std::optional newestProcessedNonce_; - std::deque missingNonces_; + /** + * @brief List of missing nonces from sequence + */ + std::deque missing_nonces; /** * @brief DPP Cluster, used for logging diff --git a/src/dpp/dave/decryptor.cpp b/src/dpp/dave/decryptor.cpp index 54bbd721d7..847b647eca 100755 --- a/src/dpp/dave/decryptor.cpp +++ b/src/dpp/dave/decryptor.cpp @@ -34,56 +34,51 @@ using namespace std::chrono_literals; namespace dpp::dave { -constexpr auto kStatsInterval = 10s; - -void decryptor::transition_to_key_ratchet(std::unique_ptr keyRatchet, duration transitionExpiry) +void decryptor::transition_to_key_ratchet(std::unique_ptr key_ratchet, duration transition_expiry) { - if (keyRatchet) { - creator.log(dpp::ll_trace, "Transitioning to new key ratchet, expiry: " + std::to_string(transitionExpiry.count())); + if (key_ratchet) { + creator.log(dpp::ll_trace, "Transitioning to new key ratchet, expiry: " + std::to_string(transition_expiry.count())); } // Update the expiry time for all existing cryptor managers - update_cryptor_manager_expiry(transitionExpiry); + update_cryptor_manager_expiry(transition_expiry); - if (keyRatchet) { - cryptorManagers_.emplace_back(creator, clock_, std::move(keyRatchet)); + if (key_ratchet) { + cryptor_managers.emplace_back(creator, current_clock, std::move(key_ratchet)); } } -void decryptor::transition_to_passthrough_mode(bool passthroughMode, duration transitionExpiry) +void decryptor::transition_to_passthrough_mode(bool passthrough_mode, duration transition_expiry) { - if (passthroughMode) { - allowPassThroughUntil_ = time_point::max(); + if (passthrough_mode) { + allow_pass_through_until = time_point::max(); } else { // Update the pass through mode expiry - auto maxExpiry = clock_.now() + transitionExpiry; - allowPassThroughUntil_ = std::min(allowPassThroughUntil_, maxExpiry); + auto max_expiry = current_clock.now() + transition_expiry; + allow_pass_through_until = std::min(allow_pass_through_until, max_expiry); } } -size_t decryptor::decrypt(media_type mediaType, - array_view encryptedFrame, - array_view frame) +size_t decryptor::decrypt(media_type this_media_type, array_view encrypted_frame, array_view frame) { - if (mediaType != media_audio && mediaType != media_video) { - creator.log(dpp::ll_trace, "decrypt failed, invalid media type: " + std::to_string(static_cast(mediaType))); + if (this_media_type != media_audio && this_media_type != media_video) { + creator.log(dpp::ll_trace, "decrypt failed, invalid media type: " + std::to_string(static_cast(this_media_type))); return 0; } - auto start = clock_.now(); + auto start = current_clock.now(); - auto localFrame = get_or_create_frame_processor(); - scope_exit cleanup([&] { return_frame_processor(std::move(localFrame)); }); + auto local_frame = get_or_create_frame_processor(); + scope_exit cleanup([&] { return_frame_processor(std::move(local_frame)); }); // Skip decrypting for silence frames - if (mediaType == media_audio && encryptedFrame.size() == OPUS_SILENCE_PACKET.size() && - std::memcmp(encryptedFrame.data(), OPUS_SILENCE_PACKET.data(), OPUS_SILENCE_PACKET.size()) == 0) { - creator.log(dpp::ll_trace, "decrypt skipping silence of size: " + std::to_string(encryptedFrame.size())); - if (encryptedFrame.data() != frame.data()) { - std::memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); + if (this_media_type == media_audio && encrypted_frame.size() == OPUS_SILENCE_PACKET.size() && std::memcmp(encrypted_frame.data(), OPUS_SILENCE_PACKET.data(), OPUS_SILENCE_PACKET.size()) == 0) { + creator.log(dpp::ll_trace, "decrypt skipping silence of size: " + std::to_string(encrypted_frame.size())); + if (encrypted_frame.data() != frame.data()) { + std::memcpy(frame.data(), encrypted_frame.data(), encrypted_frame.size()); } - return encryptedFrame.size(); + return encrypted_frame.size(); } // Remove any expired cryptor manager @@ -92,81 +87,75 @@ size_t decryptor::decrypt(media_type mediaType, // Process the incoming frame // This will check whether it looks like a valid encrypted frame // and if so it will parse it into its different components - localFrame->parse_frame(encryptedFrame); + local_frame->parse_frame(encrypted_frame); // If the frame is not encrypted and we can pass it through, do it - bool canUsePassThrough = allowPassThroughUntil_ > start; - if (!localFrame->is_encrypted() && canUsePassThrough) { - if (encryptedFrame.data() != frame.data()) { - std::memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); + bool can_use_pass_through = allow_pass_through_until > start; + if (!local_frame->is_encrypted() && can_use_pass_through) { + if (encrypted_frame.data() != frame.data()) { + std::memcpy(frame.data(), encrypted_frame.data(), encrypted_frame.size()); } - stats_[mediaType].passthroughs++; - return encryptedFrame.size(); + stats[this_media_type].passthroughs++; + return encrypted_frame.size(); } - // If the frame is not encrypted and we can't pass it through, fail - if (!localFrame->is_encrypted()) { + // If the frame is not encrypted, and we can't pass it through, fail + if (!local_frame->is_encrypted()) { creator.log(dpp::ll_warning, "decrypt failed, frame is not encrypted and pass through is disabled"); - stats_[mediaType].decrypt_failure++; + stats[this_media_type].decrypt_failure++; return 0; } // Try and decrypt with each valid cryptor // reverse iterate to try the newest cryptors first bool success = false; - for (auto it = cryptorManagers_.rbegin(); it != cryptorManagers_.rend(); ++it) { + for (auto it = cryptor_managers.rbegin(); it != cryptor_managers.rend(); ++it) { auto& cryptorManager = *it; - success = decrypt_impl(cryptorManager, mediaType, *localFrame, frame); + success = decrypt_impl(cryptorManager, this_media_type, *local_frame, frame); if (success) { break; } } - size_t bytesWritten = 0; + size_t bytes_written = 0; if (success) { - stats_[mediaType].decrypt_success++; - bytesWritten = localFrame->reconstruct_frame(frame); + stats[this_media_type].decrypt_success++; + bytes_written = local_frame->reconstruct_frame(frame); } else { - stats_[mediaType].decrypt_failure++; - creator.log(dpp::ll_warning, "decrypt failed, no valid cryptor found, type: " + std::string(mediaType ? "video" : "audio") + - ", encrypted frame size: " + std::to_string(encryptedFrame.size()) + + stats[this_media_type].decrypt_failure++; + creator.log(dpp::ll_warning, + "decrypt failed, no valid cryptor found, type: " + std::string(this_media_type ? "video" : "audio") + + ", encrypted frame size: " + std::to_string(encrypted_frame.size()) + ", plaintext frame size: " + std::to_string(frame.size()) + - ", number of cryptor managers: " + std::to_string(cryptorManagers_.size()) + - ", pass through enabled: " + std::string(canUsePassThrough ? "yes" : "no") + ", number of cryptor managers: " + std::to_string(cryptor_managers.size()) + + ", pass through enabled: " + std::string(can_use_pass_through ? "yes" : "no") ); } - auto end = clock_.now(); - stats_[mediaType].decrypt_duration += std::chrono::duration_cast(end - start).count(); + auto end = current_clock.now(); + stats[this_media_type].decrypt_duration += std::chrono::duration_cast(end - start).count(); - return bytesWritten; + return bytes_written; } -bool decryptor::decrypt_impl(aead_cipher_manager& cipher_manager, - media_type mediaType, - inbound_frame_processor& encryptedFrame, - array_view frame) +bool decryptor::decrypt_impl(aead_cipher_manager& cipher_manager, media_type this_media_type, inbound_frame_processor& encrypted_frame, array_view frame) { - auto tag = encryptedFrame.get_tag(); - auto truncatedNonce = encryptedFrame.get_truncated_nonce(); - - auto authenticatedData = encryptedFrame.get_authenticated_data(); - auto ciphertext = encryptedFrame.get_ciphertext(); - auto plaintext = encryptedFrame.get_plaintext(); + auto tag = encrypted_frame.get_tag(); + auto truncated_nonce = encrypted_frame.get_truncated_nonce(); + auto authenticated_data = encrypted_frame.get_authenticated_data(); + auto ciphertext_buffer = encrypted_frame.get_ciphertext(); + auto plaintext = encrypted_frame.get_plaintext(); // expand the truncated nonce to the full sized one needed for decryption - auto nonceBuffer = std::array(); - memcpy(nonceBuffer.data() + AES_GCM_128_TRUNCATED_SYNC_NONCE_OFFSET, - &truncatedNonce, - AES_GCM_128_TRUNCATED_SYNC_NONCE_BYTES); + auto nonce_buffer = std::array(); + memcpy(nonce_buffer.data() + AES_GCM_128_TRUNCATED_SYNC_NONCE_OFFSET, &truncated_nonce, AES_GCM_128_TRUNCATED_SYNC_NONCE_BYTES); - auto nonceBufferView = make_array_view(nonceBuffer.data(), nonceBuffer.size()); + auto nonce_buffer_view = make_array_view(nonce_buffer.data(), nonce_buffer.size()); - auto generation = - cipher_manager.compute_wrapped_generation(truncatedNonce >> RATCHET_GENERATION_SHIFT_BITS); + auto generation = cipher_manager.compute_wrapped_generation(truncated_nonce >> RATCHET_GENERATION_SHIFT_BITS); - if (!cipher_manager.can_process_nonce(generation, truncatedNonce)) { + if (!cipher_manager.can_process_nonce(generation, truncated_nonce)) { creator.log(dpp::ll_trace, "decrypt failed, cannot process nonce"); return false; } @@ -180,53 +169,52 @@ bool decryptor::decrypt_impl(aead_cipher_manager& cipher_manager, } // perform the decryption - bool success = cipher->decrypt(plaintext, ciphertext, tag, nonceBufferView, authenticatedData); - stats_[mediaType].decrypt_attempts++; + bool success = cipher->decrypt(plaintext, ciphertext_buffer, tag, nonce_buffer_view, authenticated_data); + stats[this_media_type].decrypt_attempts++; if (success) { - cipher_manager.report_cipher_success(generation, truncatedNonce); + cipher_manager.report_cipher_success(generation, truncated_nonce); } return success; } -size_t decryptor::get_max_plaintext_byte_size(media_type mediaType, size_t encryptedFrameSize) +size_t decryptor::get_max_plaintext_byte_size(media_type this_media_type, size_t encrypted_frame_size) { - return encryptedFrameSize; + return encrypted_frame_size; } void decryptor::update_cryptor_manager_expiry(duration expiry) { - auto maxExpiryTime = clock_.now() + expiry; - for (auto& cryptorManager : cryptorManagers_) { - cryptorManager.update_expiry(maxExpiryTime); + auto max_expiry_time = current_clock.now() + expiry; + for (auto& cryptorManager : cryptor_managers) { + cryptorManager.update_expiry(max_expiry_time); } } void decryptor::cleanup_expired_cryptor_managers() { - while (!cryptorManagers_.empty() && cryptorManagers_.front().is_expired()) { + while (!cryptor_managers.empty() && cryptor_managers.front().is_expired()) { creator.log(dpp::ll_trace, "Removing expired cryptor manager"); - cryptorManagers_.pop_front(); + cryptor_managers.pop_front(); } } std::unique_ptr decryptor::get_or_create_frame_processor() { - std::lock_guard lock(frameProcessorsMutex_); - if (frameProcessors_.empty()) { + std::lock_guard lock(frame_processors_mutex); + if (frame_processors.empty()) { return std::make_unique(creator); } - auto frameProcessor = std::move(frameProcessors_.back()); - frameProcessors_.pop_back(); - return frameProcessor; + auto frame_processor = std::move(frame_processors.back()); + frame_processors.pop_back(); + return frame_processor; } -void decryptor::return_frame_processor(std::unique_ptr frameProcessor) +void decryptor::return_frame_processor(std::unique_ptr frame_processor) { - std::lock_guard lock(frameProcessorsMutex_); - frameProcessors_.push_back(std::move(frameProcessor)); + std::lock_guard lock(frame_processors_mutex); + frame_processors.push_back(std::move(frame_processor)); } -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/decryptor.h b/src/dpp/dave/decryptor.h index 804c41c826..e093dc46fb 100755 --- a/src/dpp/dave/decryptor.h +++ b/src/dpp/dave/decryptor.h @@ -94,50 +94,48 @@ class decryptor { * of the session. Once you have a key ratchet, you can derive the key, and decrypt that * user's audio/video. * - * @param keyRatchet Key ratchet - * @param transitionExpiry Transition expiry. Old keys last this long before being withdrawn + * @param key_ratchet Key ratchet + * @param transition_expiry Transition expiry. Old keys last this long before being withdrawn * in preference of this new one. */ - void transition_to_key_ratchet(std::unique_ptr keyRatchet, - duration transitionExpiry = DEFAULT_TRANSITION_EXPIRY); + void transition_to_key_ratchet(std::unique_ptr key_ratchet, duration transition_expiry = DEFAULT_TRANSITION_EXPIRY); /** * @brief Transition to passthrough mode * * Passthrough mode occurs when a non-DAVE user connects to the VC. * - * @param passthroughMode True to enable passthrough mode - * @param transitionExpiry Expiry for the transition + * @param passthrough_mode True to enable passthrough mode + * @param transition_expiry Expiry for the transition */ - void transition_to_passthrough_mode(bool passthroughMode, - duration transitionExpiry = DEFAULT_TRANSITION_EXPIRY); + void transition_to_passthrough_mode(bool passthrough_mode, duration transition_expiry = DEFAULT_TRANSITION_EXPIRY); /** * @brief Decrypt a frame * - * @param mediaType type of media, audio or video - * @param encryptedFrame encrypted frame bytes + * @param this_media_type type of media, audio or video + * @param encrypted_frame encrypted frame bytes * @param frame plaintext output * @return size of decrypted frame, or 0 if failure */ - size_t decrypt(media_type mediaType, - array_view encryptedFrame, - array_view frame); + size_t decrypt(media_type this_media_type, array_view encrypted_frame, array_view frame); /** * @brief Get maximum possible decrypted size of frame from an encrypted frame - * @param mediaType type of media - * @param encryptedFrameSize encrypted frame size + * @param this_media_type type of media + * @param encrypted_frame_size encrypted frame size * @return size of plaintext buffer required */ - size_t get_max_plaintext_byte_size(media_type mediaType, size_t encryptedFrameSize); + size_t get_max_plaintext_byte_size(media_type this_media_type, size_t encrypted_frame_size); /** * @brief Get decryption stats - * @param mediaType media type, audio or video + * @param this_media_type media type, audio or video * @return decryption stats */ - decryption_stats get_stats(media_type mediaType) const { return stats_[mediaType]; } + decryption_stats get_stats(media_type this_media_type) const { + return stats[this_media_type]; + } private: /** @@ -149,12 +147,12 @@ class decryptor { * @brief Decryption implementation * * @param cipher_manager cipher manager - * @param mediaType media time, audio or video - * @param encryptedFrame encrypted frame data + * @param this_media_type media type, audio or video + * @param encrypted_frame encrypted frame data * @param frame decrypted frame data * @return True if decryption succeeded */ - bool decrypt_impl(aead_cipher_manager& cipher_manager, media_type mediaType, inbound_frame_processor& encryptedFrame, array_view frame); + bool decrypt_impl(aead_cipher_manager& cipher_manager, media_type this_media_type, inbound_frame_processor& encrypted_frame, array_view frame); /** * @brief Update expiry for an instance of the manager @@ -175,20 +173,44 @@ class decryptor { /** * Return frame processor - * @param frameProcessor frame processor + * @param frame_processor frame processor */ - void return_frame_processor(std::unique_ptr frameProcessor); + void return_frame_processor(std::unique_ptr frame_processor); - clock clock_; - std::deque cryptorManagers_; + /** + * @brief Chrono clock + */ + clock current_clock; - std::mutex frameProcessorsMutex_; - std::vector> frameProcessors_; + /** + * @brief Cryptor manager list + */ + std::deque cryptor_managers; - time_point allowPassThroughUntil_{time_point::min()}; + /** + * @brief Mutex for thread safety of frame processor list + */ + std::mutex frame_processors_mutex; - time_point lastStatsTime_{time_point::min()}; - std::array stats_; + /** + * @brief List of frame processors + */ + std::vector> frame_processors; + + /** + * @brief Passthrough expiry time + */ + time_point allow_pass_through_until{time_point::min()}; + + /** + * @brief Last stats generation time + */ + time_point last_stats_time{time_point::min()}; + + /** + * @brief Stats for audio and video decryption + */ + std::array stats; /** * @brief DPP Cluster, used for logging @@ -196,5 +218,4 @@ class decryptor { dpp::cluster& creator; }; -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/encryptor.cpp b/src/dpp/dave/encryptor.cpp index 879976942f..782d02984a 100755 --- a/src/dpp/dave/encryptor.cpp +++ b/src/dpp/dave/encryptor.cpp @@ -39,39 +39,39 @@ using namespace std::chrono_literals; namespace dpp::dave { -void encryptor::set_key_ratchet(std::unique_ptr keyRatchet) +void encryptor::set_key_ratchet(std::unique_ptr key_ratchet) { - std::lock_guard lock(keyGenMutex_); - keyRatchet_ = std::move(keyRatchet); - cryptor_ = nullptr; - currentKeyGeneration_ = 0; - truncatedNonce_ = 0; + std::lock_guard lock(key_gen_mutex); + ratchet = std::move(key_ratchet); + cryptor = nullptr; + current_key_generation = 0; + truncated_nonce = 0; } -void encryptor::set_passthrough_mode(bool passthroughMode) +void encryptor::set_passthrough_mode(bool passthrough_mode) { - passthroughMode_ = passthroughMode; - update_current_protocol_version(passthroughMode ? 0 : max_protocol_version()); + passthrough_mode_enable = passthrough_mode; + update_current_protocol_version(passthrough_mode ? 0 : max_protocol_version()); } -encryptor::result_code encryptor::encrypt(media_type mediaType, uint32_t ssrc, array_view frame, array_view encryptedFrame, size_t* bytesWritten) { - if (mediaType != media_audio && mediaType != media_video) { - creator.log(dpp::ll_warning, "encrypt failed, invalid media type: " + std::to_string(static_cast(mediaType))); +encryptor::result_code encryptor::encrypt(media_type this_media_type, uint32_t ssrc, array_view frame, array_view encrypted_frame, size_t* bytes_written) { + if (this_media_type != media_audio && this_media_type != media_video) { + creator.log(dpp::ll_warning, "encrypt failed, invalid media type: " + std::to_string(static_cast(this_media_type))); return result_code::rc_encryption_failure; } - if (passthroughMode_) { + if (passthrough_mode_enable) { // Pass frame through without encrypting - std::memcpy(encryptedFrame.data(), frame.data(), frame.size()); - *bytesWritten = frame.size(); - stats_[mediaType].passthroughs++; + std::memcpy(encrypted_frame.data(), frame.data(), frame.size()); + *bytes_written = frame.size(); + stats[this_media_type].passthroughs++; return result_code::rc_success; } { - std::lock_guard lock(keyGenMutex_); - if (!keyRatchet_) { - stats_[mediaType].encrypt_failure++; + std::lock_guard lock(key_gen_mutex); + if (!ratchet) { + stats[this_media_type].encrypt_failure++; return result_code::rc_encryption_failure; } } @@ -82,27 +82,27 @@ encryptor::result_code encryptor::encrypt(media_type mediaType, uint32_t ssrc, a // write the codec identifier auto codec = codec_for_ssrc(ssrc); - auto frameProcessor = get_or_create_frame_processor(); - scope_exit cleanup([&] { return_frame_processor(std::move(frameProcessor)); }); + auto frame_processor = get_or_create_frame_processor(); + scope_exit cleanup([&] { return_frame_processor(std::move(frame_processor)); }); - frameProcessor->process_frame(frame, codec); + frame_processor->process_frame(frame, codec); - const auto& unencryptedBytes = frameProcessor->get_unencrypted_bytes(); - const auto& encryptedBytes = frameProcessor->get_encrypted_bytes(); - auto& ciphertextBytes = frameProcessor->get_ciphertext_bytes(); + const auto& unencrypted_bytes = frame_processor->get_unencrypted_bytes(); + const auto& encrypted_bytes = frame_processor->get_encrypted_bytes(); + auto& ciphertext_bytes = frame_processor->get_ciphertext_bytes(); - const auto& unencryptedRanges = frameProcessor->get_unencrypted_ranges(); - auto unencryptedRangesSize = unencrypted_ranges_size(unencryptedRanges); + const auto& unencrypted_ranges = frame_processor->get_unencrypted_ranges(); + auto ranges_size = unencrypted_ranges_size(unencrypted_ranges); - auto additionalData = make_array_view(unencryptedBytes.data(), unencryptedBytes.size()); - auto plaintextBuffer = make_array_view(encryptedBytes.data(), encryptedBytes.size()); - auto ciphertextBuffer = make_array_view(ciphertextBytes.data(), ciphertextBytes.size()); + auto additional_data = make_array_view(unencrypted_bytes.data(), unencrypted_bytes.size()); + auto plaintext_buffer = make_array_view(encrypted_bytes.data(), encrypted_bytes.size()); + auto ciphertext_buffer = make_array_view(ciphertext_bytes.data(), ciphertext_bytes.size()); - auto frameSize = encryptedBytes.size() + unencryptedBytes.size(); - auto tagBuffer = make_array_view(encryptedFrame.data() + frameSize, AES_GCM_127_TRUNCATED_TAG_BYTES); + auto frame_size = encrypted_bytes.size() + unencrypted_bytes.size(); + auto tag_buffer = make_array_view(encrypted_frame.data() + frame_size, AES_GCM_127_TRUNCATED_TAG_BYTES); - auto nonceBuffer = std::array(); - auto nonceBufferView = make_array_view(nonceBuffer.data(), nonceBuffer.size()); + auto nonce_buffer = std::array(); + auto nonce_buffer_view = make_array_view(nonce_buffer.data(), nonce_buffer.size()); constexpr auto MAX_CIPHERTEXT_VALIDATION_RETRIES = 10; @@ -117,72 +117,70 @@ encryptor::result_code encryptor::encrypt(media_type mediaType, uint32_t ssrc, a // which can remove start codes from the last 1 or 2 bytes of the nonce // and the two bytes of the unencrypted header bytes for (auto attempt = 1; attempt <= MAX_CIPHERTEXT_VALIDATION_RETRIES; ++attempt) { - auto [cryptor, truncatedNonce] = get_next_cryptor_and_nonce(); + auto [curr_cryptor, truncatedNonce] = get_next_cryptor_and_nonce(); - if (!cryptor) { + if (!curr_cryptor) { result = result_code::rc_encryption_failure; break; } // write the truncated nonce to our temporary full nonce array // (since the encryption call expects a full size nonce) - std::memcpy(nonceBuffer.data() + AES_GCM_128_TRUNCATED_SYNC_NONCE_OFFSET, - &truncatedNonce, - AES_GCM_128_TRUNCATED_SYNC_NONCE_BYTES); + std::memcpy(nonce_buffer.data() + AES_GCM_128_TRUNCATED_SYNC_NONCE_OFFSET, &truncatedNonce, AES_GCM_128_TRUNCATED_SYNC_NONCE_BYTES); // encrypt the plaintext, adding the unencrypted header to the tag - bool success = cryptor->encrypt( - ciphertextBuffer, plaintextBuffer, nonceBufferView, additionalData, tagBuffer); + bool success = curr_cryptor->encrypt(ciphertext_buffer, plaintext_buffer, nonce_buffer_view, additional_data, tag_buffer); - stats_[mediaType].encrypt_attempts++; - stats_[mediaType].encrypt_max_attempts = - std::max(stats_[mediaType].encrypt_max_attempts, (uint64_t)attempt); + stats[this_media_type].encrypt_attempts++; + stats[this_media_type].encrypt_max_attempts = + std::max(stats[this_media_type].encrypt_max_attempts, (uint64_t)attempt); if (!success) { result = result_code::rc_encryption_failure; break; } - auto reconstructedFrameSize = frameProcessor->reconstruct_frame(encryptedFrame); + auto reconstructed_frame_size = frame_processor->reconstruct_frame(encrypted_frame); - auto nonceSize = leb128_size(truncatedNonce); + auto size = leb128_size(truncatedNonce); - auto truncatedNonceBuffer = make_array_view(tagBuffer.end(), nonceSize); - auto unencryptedRangesBuffer = - make_array_view(truncatedNonceBuffer.end(), unencryptedRangesSize); - auto supplementalBytesBuffer = - make_array_view(unencryptedRangesBuffer.end(), sizeof(supplemental_bytes_size)); - auto markerBytesBuffer = make_array_view(supplementalBytesBuffer.end(), sizeof(magic_marker)); + auto truncated_nonce_buffer = make_array_view(tag_buffer.end(), size); + auto unencrypted_ranges_buffer = make_array_view(truncated_nonce_buffer.end(), ranges_size); + auto supplemental_bytes_buffer = make_array_view(unencrypted_ranges_buffer.end(), sizeof(supplemental_bytes_size)); + auto marker_bytes_buffer = make_array_view(supplemental_bytes_buffer.end(), sizeof(magic_marker)); // write the nonce - auto res = write_leb128(truncatedNonce, truncatedNonceBuffer.begin()); - if (res != nonceSize) { + auto res = write_leb128(truncatedNonce, truncated_nonce_buffer.begin()); + if (res != size) { result = result_code::rc_encryption_failure; break; } // write the unencrypted ranges - res = serialize_unencrypted_ranges( - unencryptedRanges, unencryptedRangesBuffer.begin(), unencryptedRangesBuffer.size()); - if (res != unencryptedRangesSize) { + res = serialize_unencrypted_ranges(unencrypted_ranges, unencrypted_ranges_buffer.begin(), unencrypted_ranges_buffer.size()); + if (res != ranges_size) { result = result_code::rc_encryption_failure; break; } // write the supplemental bytes size - supplemental_bytes_size supplementalBytes = - SUPPLEMENTAL_BYTES + nonceSize + unencryptedRangesSize; - std::memcpy(supplementalBytesBuffer.data(), &supplementalBytes, sizeof(supplemental_bytes_size)); + uint64_t supplemental_bytes_large = SUPPLEMENTAL_BYTES + size + ranges_size; + + if (supplemental_bytes_large > std::numeric_limits::max()) { + result = rc_encryption_failure; + break; + } + + supplemental_bytes_size supplemental_bytes = supplemental_bytes_large; + std::memcpy(supplemental_bytes_buffer.data(), &supplemental_bytes, sizeof(supplemental_bytes_size)); // write the marker bytes, ends the frame - std::memcpy(markerBytesBuffer.data(), &MARKER_BYTES, sizeof(magic_marker)); + std::memcpy(marker_bytes_buffer.data(), &MARKER_BYTES, sizeof(magic_marker)); - auto encryptedFrameBytes = reconstructedFrameSize + AES_GCM_127_TRUNCATED_TAG_BYTES + - nonceSize + unencryptedRangesSize + sizeof(supplemental_bytes_size) + sizeof(magic_marker); + auto encrypted_frame_bytes = reconstructed_frame_size + AES_GCM_127_TRUNCATED_TAG_BYTES + size + ranges_size + sizeof(supplemental_bytes_size) + sizeof(magic_marker); - if (codec_utils::validate_encrypted_frame( - *frameProcessor, make_array_view(encryptedFrame.data(), encryptedFrameBytes))) { - *bytesWritten = encryptedFrameBytes; + if (codec_utils::validate_encrypted_frame(*frame_processor, make_array_view(encrypted_frame.data(), encrypted_frame_bytes))) { + *bytes_written = encrypted_frame_bytes; break; } else if (attempt >= MAX_CIPHERTEXT_VALIDATION_RETRIES) { @@ -192,47 +190,48 @@ encryptor::result_code encryptor::encrypt(media_type mediaType, uint32_t ssrc, a } auto now = std::chrono::steady_clock::now(); - stats_[mediaType].encrypt_duration += - std::chrono::duration_cast(now - start).count(); + stats[this_media_type].encrypt_duration += std::chrono::duration_cast(now - start).count(); if (result == result_code::rc_success) { - stats_[mediaType].encrypt_success++; + stats[this_media_type].encrypt_success++; } else { - stats_[mediaType].encrypt_failure++; + stats[this_media_type].encrypt_failure++; } return result; } -size_t encryptor::get_max_ciphertext_byte_size(media_type mediaType, size_t frameSize) +size_t encryptor::get_max_ciphertext_byte_size(media_type this_media_type, size_t frame_size) { - return frameSize + SUPPLEMENTAL_BYTES + TRANSFORM_PADDING_BYTES; + return frame_size + SUPPLEMENTAL_BYTES + TRANSFORM_PADDING_BYTES; } -void encryptor::assign_ssrc_to_codec(uint32_t ssrc, codec codecType) +void encryptor::assign_ssrc_to_codec(uint32_t ssrc, codec codec_type) { - auto existingCodecIt = std::find_if( - ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { - return pair.first == ssrc; - }); + auto existing_codec_it = std::find_if( + ssrc_codec_pairs.begin(), ssrc_codec_pairs.end(), [ssrc](const ssrc_codec_pair& pair) { + return pair.first == ssrc; + } + ); - if (existingCodecIt == ssrcCodecPairs_.end()) { - ssrcCodecPairs_.emplace_back(ssrc, codecType); + if (existing_codec_it == ssrc_codec_pairs.end()) { + ssrc_codec_pairs.emplace_back(ssrc, codec_type); } else { - existingCodecIt->second = codecType; + existing_codec_it->second = codec_type; } } codec encryptor::codec_for_ssrc(uint32_t ssrc) { - auto existingCodecIt = std::find_if( - ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { - return pair.first == ssrc; - }); + auto existing_codec_it = std::find_if( + ssrc_codec_pairs.begin(), ssrc_codec_pairs.end(), [ssrc](const ssrc_codec_pair& pair) { + return pair.first == ssrc; + } + ); - if (existingCodecIt != ssrcCodecPairs_.end()) { - return existingCodecIt->second; + if (existing_codec_it != ssrc_codec_pairs.end()) { + return existing_codec_it->second; } else { return codec::cd_opus; @@ -241,51 +240,50 @@ codec encryptor::codec_for_ssrc(uint32_t ssrc) std::unique_ptr encryptor::get_or_create_frame_processor() { - std::lock_guard lock(frameProcessorsMutex_); - if (frameProcessors_.empty()) { + std::lock_guard lock(frame_processors_mutex); + if (frame_processors.empty()) { return std::make_unique(creator); } - auto frameProcessor = std::move(frameProcessors_.back()); - frameProcessors_.pop_back(); - return frameProcessor; + auto frame_processor = std::move(frame_processors.back()); + frame_processors.pop_back(); + return frame_processor; } void encryptor::return_frame_processor(std::unique_ptr frameProcessor) { - std::lock_guard lock(frameProcessorsMutex_); - frameProcessors_.push_back(std::move(frameProcessor)); + std::lock_guard lock(frame_processors_mutex); + frame_processors.push_back(std::move(frameProcessor)); } encryptor::cryptor_and_nonce encryptor::get_next_cryptor_and_nonce() { - std::lock_guard lock(keyGenMutex_); - if (!keyRatchet_) { + std::lock_guard lock(key_gen_mutex); + if (!ratchet) { return {nullptr, 0}; } - auto generation = compute_wrapped_generation(currentKeyGeneration_, ++truncatedNonce_ >> RATCHET_GENERATION_SHIFT_BITS); + auto generation = compute_wrapped_generation(current_key_generation, ++truncated_nonce >> RATCHET_GENERATION_SHIFT_BITS); - if (generation != currentKeyGeneration_ || !cryptor_) { - currentKeyGeneration_ = generation; + if (generation != current_key_generation || !cryptor) { + current_key_generation = generation; - auto encryptionKey = keyRatchet_->get_key(currentKeyGeneration_); - cryptor_ = create_cipher(creator, encryptionKey); + auto key = ratchet->get_key(current_key_generation); + cryptor = create_cipher(creator, key); } - return {cryptor_, truncatedNonce_}; + return {cryptor, truncated_nonce}; } void encryptor::update_current_protocol_version(protocol_version version) { - if (version == currentProtocolVersion_) { + if (version == current_protocol_version) { return; } - currentProtocolVersion_ = version; - if (protocolVersionChangedCallback_) { - protocolVersionChangedCallback_(); + current_protocol_version = version; + if (changed_callback) { + changed_callback(); } } -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/encryptor.h b/src/dpp/dave/encryptor.h index 978238d744..6b327c83ff 100755 --- a/src/dpp/dave/encryptor.h +++ b/src/dpp/dave/encryptor.h @@ -77,9 +77,9 @@ struct encryption_stats { class encryptor { public: /** - * @brief Constructor - * @param cl Creator - */ + * @brief Constructor + * @param cl Creator + */ encryptor(dpp::cluster& cl) : creator(cl) { }; /** @@ -98,22 +98,22 @@ class encryptor { /** * @brief Set key ratchet for encryptor, this should be the bot's ratchet. - * @param keyRatchet Bot's key ratchet + * @param key_ratchet Bot's key ratchet */ - void set_key_ratchet(std::unique_ptr keyRatchet); + void set_key_ratchet(std::unique_ptr key_ratchet); /** * @brief Set encryption to passthrough mode - * @param passthroughMode true to enable passthrough mode, false to disable + * @param passthrough_mode true to enable passthrough mode, false to disable */ - void set_passthrough_mode(bool passthroughMode); + void set_passthrough_mode(bool passthrough_mode); /** * @brief True if key ratchet assigned * @return key ratchet is assigned */ bool has_key_ratchet() const { - return keyRatchet_ != nullptr; + return ratchet != nullptr; } /** @@ -121,16 +121,16 @@ class encryptor { * @return is in passthrough mode */ bool is_passthrough_mode() const { - return passthroughMode_; + return passthrough_mode_enable; } /** * @brief Assign SSRC to codec * @note This is unused - all SSRC are assumed to be OPUS for bots at present. * @param ssrc RTP SSRC - * @param codecType Codec type + * @param codec_type Codec type */ - void assign_ssrc_to_codec(uint32_t ssrc, codec codecType); + void assign_ssrc_to_codec(uint32_t ssrc, codec codec_type); /** * @brief Get codec for RTP SSRC @@ -142,34 +142,30 @@ class encryptor { /** * @brief Encrypt plaintext opus frames - * @param mediaType media type, should always be audio + * @param this_media_type media type, should always be audio * @param ssrc RTP SSRC * @param frame Frame plaintext - * @param encryptedFrame Encrypted frame - * @param bytesWritten Number of bytes written to the encrypted buffer + * @param encrypted_frame Encrypted frame + * @param bytes_written Number of bytes written to the encrypted buffer * @return Status code for encryption */ - encryptor::result_code encrypt(media_type mediaType, - uint32_t ssrc, - array_view frame, - array_view encryptedFrame, - size_t* bytesWritten); + encryptor::result_code encrypt(media_type this_media_type, uint32_t ssrc, array_view frame, array_view encrypted_frame, size_t* bytes_written); /** * @brief Get maximum possible ciphertext size for a plaintext buffer - * @param mediaType media type, should always be audio for bots - * @param frameSize frame size of plaintext buffer + * @param this_media_type media type, should always be audio for bots + * @param frame_size frame size of plaintext buffer * @return size of ciphertext buffer to allocate */ - size_t get_max_ciphertext_byte_size(media_type mediaType, size_t frameSize); + size_t get_max_ciphertext_byte_size(media_type this_media_type, size_t frame_size); /** * @brief Get encryption stats - * @param mediaType media type + * @param this_media_type media type * @return encryption stats */ - encryption_stats get_stats(media_type mediaType) const { - return stats_[mediaType]; + encryption_stats get_stats(media_type this_media_type) const { + return stats[this_media_type]; } /** @@ -182,7 +178,7 @@ class encryptor { * @param callback Callback to set */ void set_protocol_version_changed_callback(protocol_version_changed_callback callback) { - protocolVersionChangedCallback_ = std::move(callback); + changed_callback = std::move(callback); } /** @@ -190,7 +186,7 @@ class encryptor { * @return protocol version */ protocol_version get_protocol_version() const { - return currentProtocolVersion_; + return current_protocol_version; } private: @@ -223,26 +219,80 @@ class encryptor { */ void update_current_protocol_version(protocol_version version); - std::atomic_bool passthroughMode_{false}; + /** + * @brief True if passthrough is enabled + */ + std::atomic_bool passthrough_mode_enable{false}; - std::mutex keyGenMutex_; - std::unique_ptr keyRatchet_; - std::shared_ptr cryptor_; - key_generation currentKeyGeneration_{0}; - truncated_sync_nonce truncatedNonce_{0}; + /** + * @brief Key generation mutex for thread safety + */ + std::mutex key_gen_mutex; + + /** + * @brief Current encryption (send) ratchet + */ + std::unique_ptr ratchet; + + /** + * @brief Current encryption cipher + */ + std::shared_ptr cryptor; - std::mutex frameProcessorsMutex_; - std::vector> frameProcessors_; + /** + * @brief Current key generation number + */ + key_generation current_key_generation{0}; - using SsrcCodecPair = std::pair; - std::vector ssrcCodecPairs_; + /** + * @brief Current truncated sync nonce + */ + truncated_sync_nonce truncated_nonce{0}; - using TimePoint = std::chrono::time_point; - TimePoint lastStatsTime_{TimePoint::min()}; - std::array stats_; + /** + * @brief Frame processor list mutex + */ + std::mutex frame_processors_mutex; - protocol_version_changed_callback protocolVersionChangedCallback_; - protocol_version currentProtocolVersion_{max_protocol_version()}; + /** + * @brief List of outbound frame processors + */ + std::vector> frame_processors; + + /** + * @brief A pair of 32 bit SSRC and codec in use for that SSRC + */ + using ssrc_codec_pair = std::pair; + + /** + * @brief List of codec pairs for SSRCs + */ + std::vector ssrc_codec_pairs; + + /** + * @brief A chrono time point + */ + using time_point = std::chrono::time_point; + + /** + * @brief Last time stats were updated + */ + time_point last_stats_time{time_point::min()}; + + /** + * @brief Stores audio/video encryption stats + */ + std::array stats; + + /** + * @brief Callback for version change, if any + */ + protocol_version_changed_callback changed_callback; + + /** + * Current protocol version supported + */ + protocol_version current_protocol_version{max_protocol_version()}; /** * @brief DPP Cluster, used for logging @@ -250,5 +300,4 @@ class encryptor { dpp::cluster& creator; }; -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/frame_processors.cpp b/src/dpp/dave/frame_processors.cpp index 15177a877e..a8b7aa185e 100755 --- a/src/dpp/dave/frame_processors.cpp +++ b/src/dpp/dave/frame_processors.cpp @@ -59,7 +59,7 @@ namespace dpp::dave { } #endif -std::pair OverflowAdd(size_t a, size_t b) +std::pair overflow_add(size_t a, size_t b) { size_t res; #if defined(_MSC_VER) && defined(_M_X64) @@ -74,78 +74,74 @@ std::pair OverflowAdd(size_t a, size_t b) return {didOverflow, res}; } -uint8_t unencrypted_ranges_size(const ranges& unencryptedRanges) +uint8_t unencrypted_ranges_size(const ranges& unencrypted_ranges) { size_t size = 0; - for (const auto& range : unencryptedRanges) { + for (const auto& range : unencrypted_ranges) { size += leb128_size(range.offset); size += leb128_size(range.size); } return static_cast(size); } -uint8_t serialize_unencrypted_ranges(const ranges& unencryptedRanges, - uint8_t* buffer, - size_t bufferSize) +uint8_t serialize_unencrypted_ranges(const ranges& unencrypted_ranges, uint8_t* buffer, size_t buffer_size) { - auto writeAt = buffer; - auto end = buffer + bufferSize; - for (const auto& range : unencryptedRanges) { - auto rangeSize = leb128_size(range.offset) + leb128_size(range.size); - if (rangeSize > static_cast(end - writeAt)) { + auto write_at = buffer; + auto end = buffer + buffer_size; + for (const auto& range : unencrypted_ranges) { + auto range_size = leb128_size(range.offset) + leb128_size(range.size); + if (range_size > static_cast(end - write_at)) { break; } - writeAt += write_leb128(range.offset, writeAt); - writeAt += write_leb128(range.size, writeAt); + write_at += write_leb128(range.offset, write_at); + write_at += write_leb128(range.size, write_at); } - return writeAt - buffer; + return static_cast(write_at - buffer); } -uint8_t deserialize_unencrypted_ranges(const uint8_t*& readAt, - const size_t bufferSize, - ranges& unencryptedRanges) +uint8_t deserialize_unencrypted_ranges(const uint8_t*& read_at, const uint8_t buffer_size, ranges& unencrypted_ranges) { - auto start = readAt; - auto end = readAt + bufferSize; - while (readAt < end) { - size_t offset = read_leb128(readAt, end); - if (readAt == nullptr) { + auto start = read_at; + auto end = read_at + buffer_size; + while (read_at < end) { + size_t offset = read_leb128(read_at, end); + if (read_at == nullptr) { break; } - size_t size = read_leb128(readAt, end); - if (readAt == nullptr) { + size_t size = read_leb128(read_at, end); + if (read_at == nullptr) { break; } - unencryptedRanges.push_back({offset, size}); + unencrypted_ranges.push_back({offset, size}); } - if (readAt != end) { - unencryptedRanges.clear(); - readAt = nullptr; + if (read_at != end) { + unencrypted_ranges.clear(); + read_at = nullptr; return 0; } - return readAt - start; + return static_cast(read_at - start); } -bool validate_unencrypted_ranges(const ranges& unencryptedRanges, size_t frameSize) +bool validate_unencrypted_ranges(const ranges& unencrypted_ranges, size_t frame_size) { - if (unencryptedRanges.empty()) { + if (unencrypted_ranges.empty()) { return true; } // validate that the ranges are in order and don't overlap - for (auto i = 0u; i < unencryptedRanges.size(); ++i) { - auto current = unencryptedRanges[i]; + for (auto i = 0u; i < unencrypted_ranges.size(); ++i) { + auto current = unencrypted_ranges[i]; // The current range should not overflow into the next range // or if it is the last range, the end of the frame - auto maxEnd = - i + 1 < unencryptedRanges.size() ? unencryptedRanges[i + 1].offset : frameSize; + auto max_end = + i + 1 < unencrypted_ranges.size() ? unencrypted_ranges[i + 1].offset : frame_size; - auto [didOverflow, currentEnd] = OverflowAdd(current.offset, current.size); - if (didOverflow || currentEnd > maxEnd) { + auto [did_overflow, current_end] = overflow_add(current.offset, current.size); + if (did_overflow || current_end > max_end) { return false; } } @@ -153,189 +149,185 @@ bool validate_unencrypted_ranges(const ranges& unencryptedRanges, size_t frameSi return true; } -size_t Reconstruct(ranges ranges, - const std::vector& rangeBytes, - const std::vector& otherBytes, - const array_view& output) +size_t do_reconstruct(ranges ranges, const std::vector& range_bytes, const std::vector& other_bytes, const array_view& output) { - size_t frameIndex = 0; - size_t rangeBytesIndex = 0; - size_t otherBytesIndex = 0; - - const auto CopyRangeBytes = [&](size_t size) { - std::memcpy(output.data() + frameIndex, rangeBytes.data() + rangeBytesIndex, size); - rangeBytesIndex += size; - frameIndex += size; + size_t frame_index = 0; + size_t range_bytes_index = 0; + size_t other_bytes_index = 0; + + const auto copy_range_bytes = [&](size_t size) { + std::memcpy(output.data() + frame_index, range_bytes.data() + range_bytes_index, size); + range_bytes_index += size; + frame_index += size; }; - const auto CopyOtherBytes = [&](size_t size) { - std::memcpy(output.data() + frameIndex, otherBytes.data() + otherBytesIndex, size); - otherBytesIndex += size; - frameIndex += size; + const auto copy_other_bytes = [&](size_t size) { + std::memcpy(output.data() + frame_index, other_bytes.data() + other_bytes_index, size); + other_bytes_index += size; + frame_index += size; }; for (const auto& range : ranges) { - if (range.offset > frameIndex) { - CopyOtherBytes(range.offset - frameIndex); + if (range.offset > frame_index) { + copy_other_bytes(range.offset - frame_index); } - CopyRangeBytes(range.size); + copy_range_bytes(range.size); } - if (otherBytesIndex < otherBytes.size()) { - CopyOtherBytes(otherBytes.size() - otherBytesIndex); + if (other_bytes_index < other_bytes.size()) { + copy_other_bytes(other_bytes.size() - other_bytes_index); } - return frameIndex; + return frame_index; } void inbound_frame_processor::clear() { - isEncrypted_ = false; - originalSize_ = 0; - truncatedNonce_ = std::numeric_limits::max(); - unencryptedRanges_.clear(); - authenticated_.clear(); - ciphertext_.clear(); - plaintext_.clear(); + encrypted = false; + original_size = 0; + truncated_nonce = std::numeric_limits::max(); + unencrypted_ranges.clear(); + authenticated.clear(); + ciphertext.clear(); + plaintext.clear(); } void inbound_frame_processor::parse_frame(array_view frame) { clear(); - constexpr auto MinSupplementalBytesSize = - AES_GCM_127_TRUNCATED_TAG_BYTES + sizeof(supplemental_bytes_size) + sizeof(magic_marker); - if (frame.size() < MinSupplementalBytesSize) { + constexpr auto min_supplemental_bytes_size = AES_GCM_127_TRUNCATED_TAG_BYTES + sizeof(supplemental_bytes_size) + sizeof(magic_marker); + if (frame.size() < min_supplemental_bytes_size) { creator.log(dpp::ll_warning, "Encrypted frame is too small to contain min supplemental bytes"); return; } // Check the frame ends with the magic marker - auto magicMarkerBuffer = frame.end() - sizeof(magic_marker); - if (memcmp(magicMarkerBuffer, &MARKER_BYTES, sizeof(magic_marker)) != 0) { + auto magic_marker_buffer = frame.end() - sizeof(magic_marker); + if (memcmp(magic_marker_buffer, &MARKER_BYTES, sizeof(magic_marker)) != 0) { return; } // Read the supplemental bytes size - supplemental_bytes_size supplementalBytesSize; - auto supplementalBytesSizeBuffer = magicMarkerBuffer - sizeof(supplemental_bytes_size); - memcpy(&supplementalBytesSize, supplementalBytesSizeBuffer, sizeof(supplemental_bytes_size)); + supplemental_bytes_size bytes_size; + auto bytes_size_buffer = magic_marker_buffer - sizeof(supplemental_bytes_size); + memcpy(&bytes_size, bytes_size_buffer, sizeof(supplemental_bytes_size)); // Check the frame is large enough to contain the supplemental bytes - if (frame.size() < supplementalBytesSize) { + if (frame.size() < bytes_size) { creator.log(dpp::ll_warning, "Encrypted frame is too small to contain supplemental bytes"); return; } // Check that supplemental bytes size is large enough to contain the supplemental bytes - if (supplementalBytesSize < MinSupplementalBytesSize) { + if (bytes_size < min_supplemental_bytes_size) { creator.log(dpp::ll_warning, "Supplemental bytes size is too small to contain supplemental bytes"); return; } - auto supplementalBytesBuffer = frame.end() - supplementalBytesSize; + auto supplemental_bytes_buffer = frame.end() - bytes_size; // Read the tag - tag_ = make_array_view(supplementalBytesBuffer, AES_GCM_127_TRUNCATED_TAG_BYTES); + tag = make_array_view(supplemental_bytes_buffer, AES_GCM_127_TRUNCATED_TAG_BYTES); // Read the nonce - auto nonceBuffer = supplementalBytesBuffer + AES_GCM_127_TRUNCATED_TAG_BYTES; - auto readAt = nonceBuffer; - auto end = supplementalBytesSizeBuffer; - truncatedNonce_ = read_leb128(readAt, end); - if (readAt == nullptr) { + auto nonce_buffer = supplemental_bytes_buffer + AES_GCM_127_TRUNCATED_TAG_BYTES; + auto read_at = nonce_buffer; + auto end = bytes_size_buffer; + truncated_nonce = static_cast(read_leb128(read_at, end)); + if (read_at == nullptr) { creator.log(dpp::ll_warning, "Failed to read truncated nonce"); return; } // Read the unencrypted ranges - auto unencryptedRangesSize = end - readAt; - deserialize_unencrypted_ranges(readAt, unencryptedRangesSize, unencryptedRanges_); - if (readAt == nullptr) { + auto ranges_size = static_cast(end - read_at); + deserialize_unencrypted_ranges(read_at, ranges_size, unencrypted_ranges); + if (read_at == nullptr) { creator.log(dpp::ll_warning, "Failed to read unencrypted ranges"); return; } - if (!validate_unencrypted_ranges(unencryptedRanges_, frame.size())) { + if (!validate_unencrypted_ranges(unencrypted_ranges, frame.size())) { creator.log(dpp::ll_warning, "Invalid unencrypted ranges"); return; } // This is overly aggressive but will keep reallocations to a minimum - authenticated_.reserve(frame.size()); - ciphertext_.reserve(frame.size()); - plaintext_.reserve(frame.size()); + authenticated.reserve(frame.size()); + ciphertext.reserve(frame.size()); + plaintext.reserve(frame.size()); - originalSize_ = frame.size(); + original_size = frame.size(); // Split the frame into authenticated and ciphertext bytes - size_t frameIndex = 0; - for (const auto& range : unencryptedRanges_) { - auto encryptedBytes = range.offset - frameIndex; - if (encryptedBytes > 0) { - add_ciphertext_bytes(frame.data() + frameIndex, encryptedBytes); + size_t frame_index = 0; + for (const auto& range : unencrypted_ranges) { + auto encrypted_bytes = range.offset - frame_index; + if (encrypted_bytes > 0) { + add_ciphertext_bytes(frame.data() + frame_index, encrypted_bytes); } add_authenticated_bytes(frame.data() + range.offset, range.size); - frameIndex = range.offset + range.size; + frame_index = range.offset + range.size; } - auto actualFrameSize = frame.size() - supplementalBytesSize; - if (frameIndex < actualFrameSize) { - add_ciphertext_bytes(frame.data() + frameIndex, actualFrameSize - frameIndex); + auto actual_frame_size = frame.size() - bytes_size; + if (frame_index < actual_frame_size) { + add_ciphertext_bytes(frame.data() + frame_index, actual_frame_size - frame_index); } // Make sure the plaintext buffer is the same size as the ciphertext buffer - plaintext_.resize(ciphertext_.size()); + plaintext.resize(ciphertext.size()); // We've successfully parsed the frame // Mark the frame as encrypted - isEncrypted_ = true; + encrypted = true; } size_t inbound_frame_processor::reconstruct_frame(array_view frame) const { - if (!isEncrypted_) { + if (!encrypted) { creator.log(dpp::ll_warning, "Cannot reconstruct an invalid encrypted frame"); return 0; } - if (authenticated_.size() + plaintext_.size() > frame.size()) { + if (authenticated.size() + plaintext.size() > frame.size()) { creator.log(dpp::ll_warning, "Frame is too small to contain the decrypted frame"); return 0; } - return Reconstruct(unencryptedRanges_, authenticated_, plaintext_, frame); + return do_reconstruct(unencrypted_ranges, authenticated, plaintext, frame); } void inbound_frame_processor::add_authenticated_bytes(const uint8_t* data, size_t size) { - authenticated_.resize(authenticated_.size() + size); - memcpy(authenticated_.data() + authenticated_.size() - size, data, size); + authenticated.resize(authenticated.size() + size); + memcpy(authenticated.data() + authenticated.size() - size, data, size); } void inbound_frame_processor::add_ciphertext_bytes(const uint8_t* data, size_t size) { - ciphertext_.resize(ciphertext_.size() + size); - memcpy(ciphertext_.data() + ciphertext_.size() - size, data, size); + ciphertext.resize(ciphertext.size() + size); + memcpy(ciphertext.data() + ciphertext.size() - size, data, size); } void outbound_frame_processor::reset() { - codec_ = codec::cd_unknown; - frameIndex_ = 0; - unencryptedBytes_.clear(); - encryptedBytes_.clear(); - unencryptedRanges_.clear(); + frame_codec = codec::cd_unknown; + frame_index = 0; + unencrypted_bytes.clear(); + encrypted_bytes.clear(); + unencrypted_ranges.clear(); } void outbound_frame_processor::process_frame(array_view frame, codec codec) { reset(); - codec_ = codec; - unencryptedBytes_.reserve(frame.size()); - encryptedBytes_.reserve(frame.size()); + frame_codec = codec; + unencrypted_bytes.reserve(frame.size()); + encrypted_bytes.reserve(frame.size()); bool success = false; switch (codec) { @@ -362,49 +354,47 @@ void outbound_frame_processor::process_frame(array_view frame, co } if (!success) { - frameIndex_ = 0; - unencryptedBytes_.clear(); - encryptedBytes_.clear(); - unencryptedRanges_.clear(); + frame_index = 0; + unencrypted_bytes.clear(); + encrypted_bytes.clear(); + unencrypted_ranges.clear(); add_encrypted_bytes(frame.data(), frame.size()); } - ciphertextBytes_.resize(encryptedBytes_.size()); + ciphertext_bytes.resize(encrypted_bytes.size()); } size_t outbound_frame_processor::reconstruct_frame(array_view frame) { - if (unencryptedBytes_.size() + ciphertextBytes_.size() > frame.size()) { + if (unencrypted_bytes.size() + ciphertext_bytes.size() > frame.size()) { creator.log(dpp::ll_warning, "Frame is too small to contain the encrypted frame"); return 0; } - return Reconstruct(unencryptedRanges_, unencryptedBytes_, ciphertextBytes_, frame); + return do_reconstruct(unencrypted_ranges, unencrypted_bytes, ciphertext_bytes, frame); } void outbound_frame_processor::add_unencrypted_bytes(const uint8_t* bytes, size_t size) { - if (!unencryptedRanges_.empty() && - unencryptedRanges_.back().offset + unencryptedRanges_.back().size == frameIndex_) { + if (!unencrypted_ranges.empty() && + unencrypted_ranges.back().offset + unencrypted_ranges.back().size == frame_index) { // extend the last range - unencryptedRanges_.back().size += size; - } - else { + unencrypted_ranges.back().size += size; + } else { // add a new range (offset, size) - unencryptedRanges_.push_back({frameIndex_, size}); + unencrypted_ranges.push_back({frame_index, size}); } - unencryptedBytes_.resize(unencryptedBytes_.size() + size); - memcpy(unencryptedBytes_.data() + unencryptedBytes_.size() - size, bytes, size); - frameIndex_ += size; + unencrypted_bytes.resize(unencrypted_bytes.size() + size); + memcpy(unencrypted_bytes.data() + unencrypted_bytes.size() - size, bytes, size); + frame_index += size; } void outbound_frame_processor::add_encrypted_bytes(const uint8_t* bytes, size_t size) { - encryptedBytes_.resize(encryptedBytes_.size() + size); - memcpy(encryptedBytes_.data() + encryptedBytes_.size() - size, bytes, size); - frameIndex_ += size; + encrypted_bytes.resize(encrypted_bytes.size() + size); + memcpy(encrypted_bytes.data() + encrypted_bytes.size() - size, bytes, size); + frame_index += size; } -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/frame_processors.h b/src/dpp/dave/frame_processors.h index 26e58d0b42..0ec4b8e255 100755 --- a/src/dpp/dave/frame_processors.h +++ b/src/dpp/dave/frame_processors.h @@ -52,36 +52,36 @@ using ranges = std::vector; /** * @brief Get total size of unencrypted ranges - * @param unencryptedRanges unencrypted ranges + * @param unencrypted_ranges unencrypted ranges * @return size */ -uint8_t unencrypted_ranges_size(const ranges& unencryptedRanges); +uint8_t unencrypted_ranges_size(const ranges& unencrypted_ranges); /** * @brief Serialise unencrypted ranges - * @param unencryptedRanges unencrypted ranges + * @param unencrypted_ranges unencrypted ranges * @param buffer buffer to serialise to - * @param bufferSize size of buffer + * @param buffer_size size of buffer * @return size of ranges written */ -uint8_t serialize_unencrypted_ranges(const ranges& unencryptedRanges, uint8_t* buffer, size_t bufferSize); +uint8_t serialize_unencrypted_ranges(const ranges& unencrypted_ranges, uint8_t* buffer, size_t buffer_size); /** * @brief Deserialise unencrypted ranges - * @param buffer buffer to write to - * @param bufferSize buffer size - * @param unencryptedRanges unencrypted ranges to write to + * @param read_at buffer to write to + * @param buffer_size buffer size + * @param unencrypted_ranges unencrypted ranges to write to * @return size of unencrypted ranges written */ -uint8_t deserialize_unencrypted_ranges(const uint8_t*& buffer, const size_t bufferSize, ranges& unencryptedRanges); +uint8_t deserialize_unencrypted_ranges(const uint8_t*& read_at, const uint8_t buffer_size, ranges& unencrypted_ranges); /** * @brief Validate unencrypted ranges - * @param unencryptedRanges unencrypted ranges - * @param frameSize frame size + * @param unencrypted_ranges unencrypted ranges + * @param frame_size frame size * @return true if validated */ -bool validate_unencrypted_ranges(const ranges& unencryptedRanges, size_t frameSize); +bool validate_unencrypted_ranges(const ranges& unencrypted_ranges, size_t frame_size); /** * @brief Processes inbound frames from the decryptor @@ -112,7 +112,7 @@ class inbound_frame_processor { * @return is encrypted */ [[nodiscard]] bool is_encrypted() const { - return isEncrypted_; + return encrypted; } /** @@ -120,7 +120,7 @@ class inbound_frame_processor { * @return Original frame size */ [[nodiscard]] size_t size() const { - return originalSize_; + return original_size; } /** @@ -133,7 +133,7 @@ class inbound_frame_processor { * @return AEAD tag */ [[nodiscard]] array_view get_tag() const { - return tag_; + return tag; } /** @@ -141,7 +141,7 @@ class inbound_frame_processor { * @return truncated sync nonce */ [[nodiscard]] truncated_sync_nonce get_truncated_nonce() const { - return truncatedNonce_; + return truncated_nonce; } /** @@ -149,7 +149,7 @@ class inbound_frame_processor { * @return AEAD auth data */ [[nodiscard]] array_view get_authenticated_data() const { - return make_array_view(authenticated_.data(), authenticated_.size()); + return make_array_view(authenticated.data(), authenticated.size()); } /** @@ -157,14 +157,16 @@ class inbound_frame_processor { * @return Ciphertext view */ [[nodiscard]] array_view get_ciphertext() const { - return make_array_view(ciphertext_.data(), ciphertext_.size()); + return make_array_view(ciphertext.data(), ciphertext.size()); } /** * @brief Get plain text * @return Plain text view */ - [[nodiscard]] array_view get_plaintext() { return make_array_view(plaintext_); } + [[nodiscard]] array_view get_plaintext() { + return make_array_view(plaintext); + } private: /** @@ -181,14 +183,45 @@ class inbound_frame_processor { */ void add_ciphertext_bytes(const uint8_t* data, size_t size); - bool isEncrypted_{false}; - size_t originalSize_{0}; - array_view tag_; - truncated_sync_nonce truncatedNonce_; - ranges unencryptedRanges_; - std::vector authenticated_; - std::vector ciphertext_; - std::vector plaintext_; + /** + * @brief True if frames are encrypted + */ + bool encrypted{false}; + + /** + * @brief Original size + */ + size_t original_size{0}; + + /** + * @brief AEAD tag + */ + array_view tag; + + /** + * @brief Truncated nonce + */ + truncated_sync_nonce truncated_nonce; + + /** + * @brief Unencrypted parts of the frames + */ + ranges unencrypted_ranges; + + /** + * @brief additional authenticated data + */ + std::vector authenticated; + + /** + * @brief Ciphertext + */ + std::vector ciphertext; + + /** + * @brief Plaintext + */ + std::vector plaintext; /** * @brief DPP Cluster, used for logging @@ -215,7 +248,7 @@ class outbound_frame_processor { void process_frame(array_view frame, codec codec); /** - * @brief Reconstruct frame + * @brief do_reconstruct frame * @param frame frame data * @return size of reconstructed frame */ @@ -226,7 +259,7 @@ class outbound_frame_processor { * @return codec */ [[nodiscard]] codec get_codec() const { - return codec_; + return frame_codec; } /** @@ -234,7 +267,7 @@ class outbound_frame_processor { * @return unencrypted bytes */ [[nodiscard]] const std::vector& get_unencrypted_bytes() const { - return unencryptedBytes_; + return unencrypted_bytes; } /** @@ -242,7 +275,7 @@ class outbound_frame_processor { * @return Encrypted bytes */ [[nodiscard]] const std::vector& get_encrypted_bytes() const { - return encryptedBytes_; + return encrypted_bytes; } /** @@ -250,7 +283,7 @@ class outbound_frame_processor { * @return ciphertext bytes */ [[nodiscard]] std::vector& get_ciphertext_bytes() { - return ciphertextBytes_; + return ciphertext_bytes; } /** @@ -258,7 +291,7 @@ class outbound_frame_processor { * @return unencrypted bytes */ [[nodiscard]] const ranges& get_unencrypted_ranges() const { - return unencryptedRanges_; + return unencrypted_ranges; } /** @@ -281,12 +314,35 @@ class outbound_frame_processor { void add_encrypted_bytes(const uint8_t* bytes, size_t size); private: - codec codec_{codec::cd_unknown}; - size_t frameIndex_{0}; - std::vector unencryptedBytes_; - std::vector encryptedBytes_; - std::vector ciphertextBytes_; - ranges unencryptedRanges_; + /** + * @brief Codec used to decrypt + */ + codec frame_codec{codec::cd_unknown}; + + /** + * @brief Frame index + */ + size_t frame_index{0}; + + /** + * @brief Unencrypted bytes + */ + std::vector unencrypted_bytes; + + /** + * @brief Encrypted bytes + */ + std::vector encrypted_bytes; + + /** + * @brief Ciphertext bytes + */ + std::vector ciphertext_bytes; + + /** + * @brief Unencrypted ranges that need to be kept plaintext to allow for RTP routing + */ + ranges unencrypted_ranges; /** * @brief DPP Cluster, used for logging @@ -294,5 +350,4 @@ class outbound_frame_processor { dpp::cluster& creator; }; -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/key_ratchet.h b/src/dpp/dave/key_ratchet.h index 3847f5f757..a5ed643364 100755 --- a/src/dpp/dave/key_ratchet.h +++ b/src/dpp/dave/key_ratchet.h @@ -60,5 +60,4 @@ class key_ratchet_interface { // NOLINT virtual void delete_key(key_generation generation) noexcept = 0; }; -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/leb128.cpp b/src/dpp/dave/leb128.cpp index 8bbe12543f..f43de97772 100755 --- a/src/dpp/dave/leb128.cpp +++ b/src/dpp/dave/leb128.cpp @@ -42,29 +42,29 @@ size_t leb128_size(uint64_t value) return size + 1; } -uint64_t read_leb128(const uint8_t*& readAt, const uint8_t* end) +uint64_t read_leb128(const uint8_t*& read_at, const uint8_t* end) { uint64_t value = 0; - int fillBits = 0; - while (readAt != end && fillBits < 64 - 7) { - uint8_t leb128Byte = *readAt; - value |= uint64_t{leb128Byte & 0x7Fu} << fillBits; - ++readAt; - fillBits += 7; - if ((leb128Byte & 0x80) == 0) { + int fill_bits = 0; + while (read_at != end && fill_bits < 64 - 7) { + uint8_t leb_128_byte = *read_at; + value |= uint64_t{leb_128_byte & 0x7Fu} << fill_bits; + ++read_at; + fill_bits += 7; + if ((leb_128_byte & 0x80) == 0) { return value; } } // Read 9 bytes and didn't find the terminator byte. Check if 10th byte // is that terminator, however to fit result into uint64_t it may carry only // single bit. - if (readAt != end && *readAt <= 1) { - value |= uint64_t{*readAt} << fillBits; - ++readAt; + if (read_at != end && *read_at <= 1) { + value |= uint64_t{*read_at} << fill_bits; + ++read_at; return value; } // Failed to find terminator leb128 byte. - readAt = nullptr; + read_at = nullptr; return 0; } @@ -76,10 +76,9 @@ size_t write_leb128(uint64_t value, uint8_t* buffer) ++size; value >>= 7; } - buffer[size] = value; + buffer[size] = static_cast(value); ++size; return size; } -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/leb128.h b/src/dpp/dave/leb128.h index 292df06864..29b1298416 100755 --- a/src/dpp/dave/leb128.h +++ b/src/dpp/dave/leb128.h @@ -50,11 +50,11 @@ size_t leb128_size(uint64_t value); /** * @brief Reads leb128 encoded value and advance read_at by number of bytes consumed. * Sets read_at to nullptr on error. - * @param readAt start position + * @param read_at start position * @param end end position * @return decoded value */ -uint64_t read_leb128(const uint8_t*& readAt, const uint8_t* end); +uint64_t read_leb128(const uint8_t*& read_at, const uint8_t* end); /** * @brief Encodes `value` in leb128 format. Assumes buffer has size of @@ -65,5 +65,4 @@ uint64_t read_leb128(const uint8_t*& readAt, const uint8_t* end); */ size_t write_leb128(uint64_t value, uint8_t* buffer); -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/mls_key_ratchet.cpp b/src/dpp/dave/mls_key_ratchet.cpp index 1f8b1149dc..b693210b5b 100755 --- a/src/dpp/dave/mls_key_ratchet.cpp +++ b/src/dpp/dave/mls_key_ratchet.cpp @@ -27,19 +27,17 @@ namespace dpp::dave { -mls_key_ratchet::mls_key_ratchet(dpp::cluster& cl, ::mlspp::CipherSuite suite, bytes baseSecret) noexcept - : hashRatchet_(suite, std::move(baseSecret)), creator(cl) -{ +mls_key_ratchet::mls_key_ratchet(dpp::cluster& cl, ::mlspp::CipherSuite suite, bytes base_secret) noexcept : ratchet(suite, std::move(base_secret)), creator(cl) { } mls_key_ratchet::~mls_key_ratchet() noexcept = default; encryption_key mls_key_ratchet::get_key(key_generation generation) noexcept { - creator.log(dpp::ll_debug, "Retrieving key for generation " + std::to_string(generation) + " from HashRatchet"); + creator.log(dpp::ll_debug, "Retrieving key for generation " + std::to_string(generation) + " from hash ratchet"); try { - auto keyAndNonce = hashRatchet_.get(generation); - return std::move(keyAndNonce.key.as_vec()); + auto key_and_nonce = ratchet.get(generation); + return std::move(key_and_nonce.key.as_vec()); } catch (const std::exception& e) { creator.log(dpp::ll_warning, "Failed to retrieve key for generation " + std::to_string(generation) + ": " + std::string(e.what())); @@ -49,8 +47,8 @@ encryption_key mls_key_ratchet::get_key(key_generation generation) noexcept void mls_key_ratchet::delete_key(key_generation generation) noexcept { - hashRatchet_.erase(generation); + ratchet.erase(generation); } -} // namespace dpp::dave +} diff --git a/src/dpp/dave/mls_key_ratchet.h b/src/dpp/dave/mls_key_ratchet.h index 0a4d053819..3f61c18a72 100755 --- a/src/dpp/dave/mls_key_ratchet.h +++ b/src/dpp/dave/mls_key_ratchet.h @@ -41,9 +41,9 @@ class mls_key_ratchet : public key_ratchet_interface { // NOLINT /** * @brief Constructor * @param suite MLS ciphersuite to use - * @param baseSecret base secret + * @param base_secret base secret */ - mls_key_ratchet(dpp::cluster& cl, ::mlspp::CipherSuite suite, bytes baseSecret) noexcept; + mls_key_ratchet(dpp::cluster& cl, ::mlspp::CipherSuite suite, bytes base_secret) noexcept; /** * @brief Destructor @@ -67,7 +67,7 @@ class mls_key_ratchet : public key_ratchet_interface { // NOLINT /** * @brief MLS hash ratchet */ - ::mlspp::HashRatchet hashRatchet_; + ::mlspp::HashRatchet ratchet; /** * @brief DPP Cluster, used for logging @@ -75,5 +75,5 @@ class mls_key_ratchet : public key_ratchet_interface { // NOLINT dpp::cluster& creator; }; -} // namespace dpp::dave +} diff --git a/src/dpp/dave/openssl_aead_cipher.cpp b/src/dpp/dave/openssl_aead_cipher.cpp index fcb2cad352..175710e598 100755 --- a/src/dpp/dave/openssl_aead_cipher.cpp +++ b/src/dpp/dave/openssl_aead_cipher.cpp @@ -24,26 +24,25 @@ #include #include #include -#include #include "common.h" namespace dpp::dave { -openssl_aead_cipher::openssl_aead_cipher(dpp::cluster& _creator, const encryption_key& encryptionKey) : +openssl_aead_cipher::openssl_aead_cipher(dpp::cluster& _creator, const encryption_key& key) : cipher_interface(_creator), - cipherCtx_(EVP_CIPHER_CTX_new()), - key_(std::vector(encryptionKey.data(), encryptionKey.data() + encryptionKey.size())) { + ssl_context(EVP_CIPHER_CTX_new()), + aes_key(std::vector(key.data(), key.data() + key.size())) { } openssl_aead_cipher::~openssl_aead_cipher() { - EVP_CIPHER_CTX_free(cipherCtx_); + EVP_CIPHER_CTX_free(ssl_context); } -bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view plaintextBuffer, const_byte_view nonceBuffer, const_byte_view additionalData, byte_view tagBufferOut) { +bool openssl_aead_cipher::encrypt(byte_view ciphertext_buffer_out, const_byte_view plaintext_buffer, const_byte_view nonce_buffer, const_byte_view additional_data, byte_view tag_buffer_out) { int len{}; - if (EVP_EncryptInit_ex(cipherCtx_, EVP_aes_128_gcm(), nullptr, nullptr, nullptr) == 0) { + if (EVP_EncryptInit_ex(ssl_context, EVP_aes_128_gcm(), nullptr, nullptr, nullptr) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -51,13 +50,13 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view /* * Set IV length */ - if (EVP_CIPHER_CTX_ctrl(cipherCtx_, EVP_CTRL_GCM_SET_IVLEN, AES_GCM_128_NONCE_BYTES, nullptr) == 0) { + if (EVP_CIPHER_CTX_ctrl(ssl_context, EVP_CTRL_GCM_SET_IVLEN, AES_GCM_128_NONCE_BYTES, nullptr) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Initialise key and IV */ - if (EVP_EncryptInit_ex(cipherCtx_, nullptr, nullptr, key_.data(), nonceBuffer.data()) == 0) { + if (EVP_EncryptInit_ex(ssl_context, nullptr, nullptr, aes_key.data(), nonce_buffer.data()) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -66,7 +65,7 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view * Provide any AAD data. This can be called zero or more times as * required */ - if (EVP_EncryptUpdate(cipherCtx_, nullptr, &len, additionalData.data(), (int)additionalData.size()) == 0) { + if (EVP_EncryptUpdate(ssl_context, nullptr, &len, additional_data.data(), (int)additional_data.size()) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -75,7 +74,7 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view * Provide the message to be encrypted, and obtain the encrypted output. * EVP_EncryptUpdate can be called multiple times if necessary */ - if (EVP_EncryptUpdate(cipherCtx_, ciphertextBufferOut.data(), &len, plaintextBuffer.data(), (int)plaintextBuffer.size()) == 0) { + if (EVP_EncryptUpdate(ssl_context, ciphertext_buffer_out.data(), &len, plaintext_buffer.data(), (int)plaintext_buffer.size()) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -84,13 +83,13 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view * Finalise the encryption. Normally ciphertext bytes may be written at * this stage, but this does not occur in GCM mode */ - if (EVP_EncryptFinal_ex(cipherCtx_, ciphertextBufferOut.data() + len, &len) == 0) { + if (EVP_EncryptFinal_ex(ssl_context, ciphertext_buffer_out.data() + len, &len) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Get the tag */ - if (EVP_CIPHER_CTX_ctrl(cipherCtx_, EVP_CTRL_GCM_GET_TAG, AES_GCM_127_TRUNCATED_TAG_BYTES, tagBufferOut.data()) == 0) { + if (EVP_CIPHER_CTX_ctrl(ssl_context, EVP_CTRL_GCM_GET_TAG, AES_GCM_127_TRUNCATED_TAG_BYTES, tag_buffer_out.data()) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -98,24 +97,24 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view return true; } -bool openssl_aead_cipher::decrypt(byte_view plaintextBufferOut, const_byte_view ciphertextBuffer, const_byte_view tagBuffer, const_byte_view nonceBuffer, const_byte_view additionalData) { +bool openssl_aead_cipher::decrypt(byte_view plaintext_buffer_out, const_byte_view ciphertext_buffer, const_byte_view tag_buffer, const_byte_view nonce_buffer, const_byte_view additional_data) { int len = 0; /* Initialise the decryption operation. */ - if (EVP_DecryptInit_ex(cipherCtx_, EVP_aes_128_gcm(), nullptr, nullptr, nullptr) == 0) { + if (EVP_DecryptInit_ex(ssl_context, EVP_aes_128_gcm(), nullptr, nullptr, nullptr) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Set IV length. Not necessary if this is 12 bytes (96 bits) */ - if (EVP_CIPHER_CTX_ctrl(cipherCtx_, EVP_CTRL_GCM_SET_IVLEN, AES_GCM_128_NONCE_BYTES, nullptr) == 0) { + if (EVP_CIPHER_CTX_ctrl(ssl_context, EVP_CTRL_GCM_SET_IVLEN, AES_GCM_128_NONCE_BYTES, nullptr) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Initialise key and IV */ - if (EVP_DecryptInit_ex(cipherCtx_, nullptr, nullptr, key_.data(), nonceBuffer.data()) == 0) { + if (EVP_DecryptInit_ex(ssl_context, nullptr, nullptr, aes_key.data(), nonce_buffer.data()) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -124,7 +123,7 @@ bool openssl_aead_cipher::decrypt(byte_view plaintextBufferOut, const_byte_view * Provide any AAD data. This can be called zero or more times as * required */ - if (EVP_DecryptUpdate(cipherCtx_, nullptr, &len, additionalData.data(), (int)additionalData.size()) == 0) { + if (EVP_DecryptUpdate(ssl_context, nullptr, &len, additional_data.data(), (int)additional_data.size()) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -133,13 +132,13 @@ bool openssl_aead_cipher::decrypt(byte_view plaintextBufferOut, const_byte_view * Provide the message to be decrypted, and obtain the plaintext output. * EVP_DecryptUpdate can be called multiple times if necessary */ - if (EVP_DecryptUpdate(cipherCtx_, plaintextBufferOut.data(), &len, ciphertextBuffer.data(), (int)ciphertextBuffer.size()) == 0) { + if (EVP_DecryptUpdate(ssl_context, plaintext_buffer_out.data(), &len, ciphertext_buffer.data(), (int)ciphertext_buffer.size()) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Set expected tag value. Works in OpenSSL 1.0.1d and later */ - if (EVP_CIPHER_CTX_ctrl(cipherCtx_, EVP_CTRL_GCM_SET_TAG, AES_GCM_127_TRUNCATED_TAG_BYTES, (void*)tagBuffer.data()) == 0) { + if (EVP_CIPHER_CTX_ctrl(ssl_context, EVP_CTRL_GCM_SET_TAG, AES_GCM_127_TRUNCATED_TAG_BYTES, (void*)tag_buffer.data()) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -148,7 +147,7 @@ bool openssl_aead_cipher::decrypt(byte_view plaintextBufferOut, const_byte_view * Finalise the decryption. A positive return value indicates success, * anything else is a failure - the plaintext is not trustworthy. */ - if (EVP_DecryptFinal_ex(cipherCtx_, plaintextBufferOut.data() + len, &len) == 0) { + if (EVP_DecryptFinal_ex(ssl_context, plaintext_buffer_out.data() + len, &len) == 0) { creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } diff --git a/src/dpp/dave/openssl_aead_cipher.h b/src/dpp/dave/openssl_aead_cipher.h index b69ae2c322..c535eef108 100755 --- a/src/dpp/dave/openssl_aead_cipher.h +++ b/src/dpp/dave/openssl_aead_cipher.h @@ -44,9 +44,9 @@ class openssl_aead_cipher : public cipher_interface { // NOLINT /** * @brief constructor * @param _creator Creator - * @param encryptionKey encryption key + * @param key encryption key */ - openssl_aead_cipher(dpp::cluster& _creator, const encryption_key& encryptionKey); + openssl_aead_cipher(dpp::cluster& _creator, const encryption_key& key); /** * @brief Destructor @@ -58,41 +58,41 @@ class openssl_aead_cipher : public cipher_interface { // NOLINT * @return True if valid */ [[nodiscard]] bool inline is_valid() const { - return cipherCtx_ != nullptr; + return ssl_context != nullptr; } /** * @brief Encrypt plaintext to ciphertext and authenticate it with tag/AAD - * @param ciphertextBufferOut ciphertext - * @param plaintextBuffer plaintext - * @param nonceBuffer nonce/IV - * @param additionalData additional authenticated data - * @param tagBufferOut tag + * @param ciphertext_buffer_out ciphertext + * @param plaintext_buffer plaintext + * @param nonce_buffer nonce/IV + * @param additional_data additional authenticated data + * @param tag_buffer_out tag * @return True if encryption succeeded */ - bool encrypt(byte_view ciphertextBufferOut, const_byte_view plaintextBuffer, const_byte_view nonceBuffer, const_byte_view additionalData, byte_view tagBufferOut) override; + bool encrypt(byte_view ciphertext_buffer_out, const_byte_view plaintext_buffer, const_byte_view nonce_buffer, const_byte_view additional_data, byte_view tag_buffer_out) override; /** * @brief Decrypt ciphertext to plaintext if it authenticates with tag/AAD - * @param plaintextBufferOut plaintext - * @param ciphertextBuffer ciphertext - * @param tagBuffer tag - * @param nonceBuffer nonce/IV - * @param additionalData additional authenticated data + * @param plaintext_buffer_out plaintext + * @param ciphertext_buffer ciphertext + * @param tag_buffer tag + * @param nonce_buffer nonce/IV + * @param additional_data additional authenticated data * @return True if decryption succeeded */ - bool decrypt(byte_view plaintextBufferOut, const_byte_view ciphertextBuffer, const_byte_view tagBuffer, const_byte_view nonceBuffer, const_byte_view additionalData) override; + bool decrypt(byte_view plaintext_buffer_out, const_byte_view ciphertext_buffer, const_byte_view tag_buffer, const_byte_view nonce_buffer, const_byte_view additional_data) override; private: /** * @brief Using EVP_CIPHER_CTX instead of EVP_AEAD_CTX */ - EVP_CIPHER_CTX* cipherCtx_; + EVP_CIPHER_CTX* ssl_context; /** * @brief Encryption/decryption key */ - std::vector key_; + std::vector aes_key; }; } // namespace dpp::dave diff --git a/src/dpp/dave/parameters.cpp b/src/dpp/dave/parameters.cpp index 965ba80b39..bfaa44f7c2 100755 --- a/src/dpp/dave/parameters.cpp +++ b/src/dpp/dave/parameters.cpp @@ -61,17 +61,13 @@ ::mlspp::ExtensionList leaf_node_extensions_for_protocol_version(protocol_versio return ::mlspp::ExtensionList{}; } -::mlspp::ExtensionList group_extensions_for_protocol_version( - protocol_version version, - const ::mlspp::ExternalSender& externalSender) noexcept +::mlspp::ExtensionList group_extensions_for_protocol_version(protocol_version version, const ::mlspp::ExternalSender& external_sender) noexcept { - auto extensionList = ::mlspp::ExtensionList{}; - - extensionList.add(::mlspp::ExternalSendersExtension{{ - {externalSender.signature_key, externalSender.credential}, + auto extension_list = ::mlspp::ExtensionList{}; + extension_list.add(::mlspp::ExternalSendersExtension{{ + {external_sender.signature_key, external_sender.credential}, }}); - - return extensionList; + return extension_list; } } // namespace dpp::dave::mls diff --git a/src/dpp/dave/parameters.h b/src/dpp/dave/parameters.h index a5a8d4a813..ca8590d42d 100755 --- a/src/dpp/dave/parameters.h +++ b/src/dpp/dave/parameters.h @@ -77,9 +77,9 @@ ::mlspp::ExtensionList leaf_node_extensions_for_protocol_version(protocol_versio /** * @brief Get group extensions for protocol version * @param version protocol bersion - * @param externalSender external sender + * @param external_sender external sender * @return extension list */ -::mlspp::ExtensionList group_extensions_for_protocol_version(protocol_version version, const ::mlspp::ExternalSender& externalSender) noexcept; +::mlspp::ExtensionList group_extensions_for_protocol_version(protocol_version version, const ::mlspp::ExternalSender& external_sender) noexcept; -} // namespace dpp::dave::mls +} diff --git a/src/dpp/dave/persisted_key_pair.cpp b/src/dpp/dave/persisted_key_pair.cpp index 77ea4f415e..2ed4afc87a 100755 --- a/src/dpp/dave/persisted_key_pair.cpp +++ b/src/dpp/dave/persisted_key_pair.cpp @@ -36,11 +36,10 @@ #include #include "parameters.h" -static const std::string SelfSignatureLabel = "DiscordSelfSignature"; +static const std::string self_signature_label = "DiscordSelfSignature"; -static std::string MakeKeyID(const std::string& sessionID, ::mlspp::CipherSuite suite) -{ - return sessionID + "-" + std::to_string((uint16_t)suite.cipher_suite()) + "-" + std::to_string(dpp::dave::mls::KeyVersion); +static std::string make_key_id(const std::string& session_id, ::mlspp::CipherSuite suite) { + return session_id + "-" + std::to_string((uint16_t)suite.cipher_suite()) + "-" + std::to_string(dpp::dave::mls::KeyVersion); } static std::mutex mtx; @@ -48,11 +47,10 @@ static std::map> map; namespace dpp::dave::mls { -static std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, ::mlspp::CipherSuite suite) -{ +static std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& session_id, ::mlspp::CipherSuite suite) { std::lock_guard lk(mtx); - std::string id = MakeKeyID(sessionID, suite); + std::string id = make_key_id(session_id, suite); if (auto it = map.find(id); it != map.end()) { return it->second; @@ -70,34 +68,32 @@ static std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp: return ret; } -std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, protocol_version version) +std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& session_id, protocol_version version) { - return get_persisted_key_pair(creator, ctx, sessionID, ciphersuite_for_protocol_version(version)); + return get_persisted_key_pair(creator, ctx, session_id, ciphersuite_for_protocol_version(version)); } -KeyAndSelfSignature get_persisted_public_key(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, signature_version version) -{ +KeyAndSelfSignature get_persisted_public_key(dpp::cluster& creator, key_pair_context_type ctx, const std::string& session_id, signature_version version) { auto suite = ciphersuite_for_signature_version(version); - auto pair = get_persisted_key_pair(creator, ctx, sessionID, suite); + auto pair = get_persisted_key_pair(creator, ctx, session_id, suite); if (!pair) { return {}; } - bytes sign_data = from_ascii(sessionID + ":") + pair->public_key.data; + bytes sign_data = from_ascii(session_id + ":") + pair->public_key.data; return { pair->public_key.data.as_vec(), - std::move(pair->sign(suite, SelfSignatureLabel, sign_data).as_vec()), + std::move(pair->sign(suite, self_signature_label, sign_data).as_vec()), }; } -bool delete_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, signature_version version) -{ - std::string id = MakeKeyID(sessionID, ciphersuite_for_signature_version(version)); +bool delete_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& session_id, signature_version version) { + std::string id = make_key_id(session_id, ciphersuite_for_signature_version(version)); std::lock_guard lk(mtx); map.erase(id); return ::dpp::dave::mls::detail::delete_generic_persisted_key_pair(creator, ctx, id); } -} // namespace dpp::dave::mls +} diff --git a/src/dpp/dave/persisted_key_pair.h b/src/dpp/dave/persisted_key_pair.h index 7c13c6578b..ceb5d185b1 100755 --- a/src/dpp/dave/persisted_key_pair.h +++ b/src/dpp/dave/persisted_key_pair.h @@ -55,11 +55,11 @@ using key_pair_context_type = const char *; /** * @brief Get persisted key pair * @param ctx context (pass nullptr to generate transient key) - * @param sessionID session id (pass empty string to generate transient key) + * @param session_id session id (pass empty string to generate transient key) * @param version Protocol version * @return MLS signature private key */ -std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, protocol_version version); +std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& session_id, protocol_version version); /** * @brief self signed signature and key @@ -78,20 +78,20 @@ struct KeyAndSelfSignature { /** * @brief Get persisted public key * @param ctx context (set to nullptr to get transient key) - * @param sessionID session id (set to empty string to get transient key) + * @param session_id session id (set to empty string to get transient key) * @param version protocol version * @return Key and self signature */ -KeyAndSelfSignature get_persisted_public_key(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, signature_version version); +KeyAndSelfSignature get_persisted_public_key(dpp::cluster& creator, key_pair_context_type ctx, const std::string& session_id, signature_version version); /** * @brief Delete persisted key pair * @param ctx context - * @param sessionID session ID + * @param session_id session ID * @param version protocol version * @return true if deleted */ -bool delete_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, signature_version version); +bool delete_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& session_id, signature_version version); /** * @brief Key version for DAVE @@ -115,9 +115,6 @@ namespace detail { * @return true if deleted */ bool delete_generic_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& id); -} // namespace dpp::dave::mls::detail - -} // namespace dpp::dave::mls - - +} +} diff --git a/src/dpp/dave/persisted_key_pair_generic.cpp b/src/dpp/dave/persisted_key_pair_generic.cpp index c0a473acc7..1bc315eda6 100755 --- a/src/dpp/dave/persisted_key_pair_generic.cpp +++ b/src/dpp/dave/persisted_key_pair_generic.cpp @@ -41,10 +41,9 @@ #include "parameters.h" #include "persisted_key_pair.h" -static const std::string_view KeyStorageDir = "Discord Key Storage"; +static const std::string_view key_storage_dir = "Discord Key Storage"; -static std::filesystem::path GetKeyStorageDirectory() -{ +static std::filesystem::path get_key_storage_directory() { std::filesystem::path dir; #if defined(__ANDROID__) @@ -56,26 +55,26 @@ static std::filesystem::path GetKeyStorageDirectory() std::getline(idFile, appId, '\0'); dir /= appId; } -#else // __ANDROID__ -#if defined(_WIN32) - if (const wchar_t* appdata = _wgetenv(L"LOCALAPPDATA")) { - dir = std::filesystem::path(appdata); - } -#else // _WIN32 - if (const char* xdg = getenv("XDG_CONFIG_HOME")) { - dir = std::filesystem::path(xdg); - } - else if (const char* home = getenv("HOME")) { - dir = std::filesystem::path(home); - dir /= ".config"; - } -#endif // !_WIN32 +#else + #if defined(_WIN32) + if (const wchar_t* appdata = _wgetenv(L"LOCALAPPDATA")) { + dir = std::filesystem::path(appdata); + } + #else + if (const char* xdg = getenv("XDG_CONFIG_HOME")) { + dir = std::filesystem::path(xdg); + } + else if (const char* home = getenv("HOME")) { + dir = std::filesystem::path(home); + dir /= ".config"; + } + #endif else { return dir; } -#endif // !__ANDROID__ +#endif - return dir / KeyStorageDir; + return dir / key_storage_dir; } namespace dpp::dave::mls::detail { @@ -84,7 +83,7 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(dpp { ::mlspp::SignaturePrivateKey ret; std::string curstr; - std::filesystem::path dir = GetKeyStorageDirectory(); + std::filesystem::path dir = get_key_storage_directory(); if (dir.empty()) { creator.log(dpp::ll_warning, "Failed to determine key storage directory in get_persisted_key_pair"); @@ -107,9 +106,9 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(dpp return nullptr; } - std::stringstream s; - s << ifs.rdbuf(); - curstr = s.str(); + std::stringstream s; + s << ifs.rdbuf(); + curstr = s.str(); if (!ifs) { creator.log(dpp::ll_warning, "Failed to read key in get_persisted_key_pair"); return nullptr; @@ -134,9 +133,7 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(dpp #ifdef _WIN32 int fd = _wopen(tmpfile.c_str(), _O_WRONLY | _O_CREAT | _O_TRUNC, _S_IREAD | _S_IWRITE); #else - int fd = open(tmpfile.c_str(), - O_WRONLY | O_CLOEXEC | O_NOFOLLOW | O_CREAT | O_TRUNC, - S_IRUSR | S_IWUSR); + int fd = open(tmpfile.c_str(), O_WRONLY | O_CLOEXEC | O_NOFOLLOW | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR); #endif if (fd < 0) { creator.log(dpp::ll_warning, "Failed to open output file in get_persisted_key_pair: " + std::to_string(errno) + " (" + tmpfile.generic_string() + ")"); @@ -144,13 +141,13 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(dpp } #ifdef _WIN32 - int wret = _write(fd, newstr.c_str(), newstr.size()); + int written = _write(fd, newstr.c_str(), static_cast(newstr.size())); _close(fd); #else - ssize_t wret = write(fd, newstr.c_str(), newstr.size()); + ssize_t written = write(fd, newstr.c_str(), newstr.size()); close(fd); #endif - if (wret < 0 || (size_t)wret != newstr.size()) { + if (written < 0 || (size_t)written != newstr.size()) { creator.log(dpp::ll_warning, "Failed to write output file in get_persisted_key_pair: " + std::to_string(errno)); return nullptr; } @@ -172,7 +169,7 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(dpp bool delete_generic_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& id) { std::error_code errc; - std::filesystem::path dir = GetKeyStorageDirectory(); + std::filesystem::path dir = get_key_storage_directory(); if (dir.empty()) { creator.log(dpp::ll_warning, "Failed to determine key storage directory in get_persisted_key_pair"); return false; @@ -182,7 +179,4 @@ bool delete_generic_persisted_key_pair(dpp::cluster& creator, key_pair_context_t return std::filesystem::remove(file, errc); } -} // namespace dpp::dave::mls::detail - - - +} diff --git a/src/dpp/dave/scope_exit.h b/src/dpp/dave/scope_exit.h index 219c51dc02..6191e2e499 100755 --- a/src/dpp/dave/scope_exit.h +++ b/src/dpp/dave/scope_exit.h @@ -40,27 +40,23 @@ class [[nodiscard]] scope_exit final { * @tparam Cleanup lambda type * @param cleanup lambda */ - template explicit scope_exit(Cleanup&& cleanup) - : cleanup_{std::forward(cleanup)} - { + template explicit scope_exit(Cleanup&& cleanup) : exit_function{std::forward(cleanup)} { } /** * @brief Move constructor * @param rhs other object */ - scope_exit(scope_exit&& rhs) : cleanup_{std::move(rhs.cleanup_)} - { - rhs.cleanup_ = nullptr; + scope_exit(scope_exit&& rhs) : exit_function{std::move(rhs.exit_function)} { + rhs.exit_function = nullptr; } /** * @brief Calls lambda */ - ~scope_exit() - { - if (cleanup_) { - cleanup_(); + ~scope_exit() { + if (exit_function) { + exit_function(); } } @@ -69,17 +65,18 @@ class [[nodiscard]] scope_exit final { * @param rhs other object * @return self */ - scope_exit& operator=(scope_exit&& rhs) - { - cleanup_ = std::move(rhs.cleanup_); - rhs.cleanup_ = nullptr; + scope_exit& operator=(scope_exit&& rhs) { + exit_function = std::move(rhs.exit_function); + rhs.exit_function = nullptr; return *this; } /** * @brief Clear the lambda so it isn't called */ - void dismiss() { cleanup_ = nullptr; } + void dismiss() { + exit_function = nullptr; + } private: /** @@ -96,8 +93,7 @@ class [[nodiscard]] scope_exit final { /** * @brief Lambda to call */ - std::function cleanup_; + std::function exit_function; }; -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/session.cpp b/src/dpp/dave/session.cpp index 5aa0859856..0bb7f10568 100755 --- a/src/dpp/dave/session.cpp +++ b/src/dpp/dave/session.cpp @@ -23,12 +23,10 @@ * ************************************************************************************/ #include "session.h" - #include #include #include #include - #include #include #include @@ -38,12 +36,11 @@ #include "parameters.h" #include "persisted_key_pair.h" #include "util.h" - #include "openssl/evp.h" -#define TRACK_MLS_ERROR(reason) \ - if (onMLSFailureCallback_) { \ - onMLSFailureCallback_(__FUNCTION__, reason); \ +#define TRACK_MLS_ERROR(reason) \ + if (failure_callback) { \ + failure_callback(__FUNCTION__, reason); \ } namespace dpp::dave::mls { @@ -53,24 +50,24 @@ struct queued_proposal { ::mlspp::bytes_ns::bytes ref; }; -session::session(dpp::cluster& cluster, key_pair_context_type context, const std::string& authSessionId, mls_failure_callback callback) noexcept - : signingKeyId_(authSessionId), keyPairContext_(context), onMLSFailureCallback_(std::move(callback)), creator(cluster) +session::session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept + : signing_key_id(auth_session_id), key_pair_context(context), failure_callback(std::move(callback)), creator(cluster) { creator.log(dpp::ll_debug, "Creating a new MLS session"); } session::~session() noexcept = default; -void session::init(protocol_version version, uint64_t groupId, std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept { +void session::init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept { reset(); - selfUserId_ = selfUserId; + bot_user_id = self_user_id; - creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + std::to_string(groupId)); - protocolVersion_ = version; - groupId_ = std::move(big_endian_bytes_from(groupId).as_vec()); + creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + std::to_string(group_id)); + session_protocol_version = version; + session_group_id = std::move(big_endian_bytes_from(group_id).as_vec()); - init_leaf_node(selfUserId, transientKey); + init_leaf_node(self_user_id, transient_key); create_pending_group(); } @@ -80,43 +77,43 @@ void session::reset() noexcept { clear_pending_state(); - currentState_.reset(); - outboundCachedGroupState_.reset(); + current_state.reset(); + outbound_cached_group_state.reset(); - protocolVersion_ = 0; - groupId_.clear(); + session_protocol_version = 0; + session_group_id.clear(); } void session::set_protocol_version(protocol_version version) noexcept { - if (version != protocolVersion_) { + if (version != session_protocol_version) { // when we need to retain backwards compatibility // there may be some changes to the MLS objects required here // until then we can just update the stored version - protocolVersion_ = version; + session_protocol_version = version; } } std::vector session::get_last_epoch_authenticator() const noexcept { - if (!currentState_) { + if (!current_state) { creator.log(dpp::ll_debug, "Cannot get epoch authenticator without an established MLS group"); return {}; } - return std::move(currentState_->epoch_authenticator().as_vec()); + return std::move(current_state->epoch_authenticator().as_vec()); } -void session::set_external_sender(const std::vector &externalSenderPackage) noexcept +void session::set_external_sender(const std::vector &external_sender_package) noexcept try { - if (currentState_) { + if (current_state) { creator.log(dpp::ll_warning, "Cannot set external sender after joining/creating an MLS group"); return; } creator.log(dpp::ll_debug, "Unmarshalling MLS external sender"); - externalSender_ = std::make_unique<::mlspp::ExternalSender>( - ::mlspp::tls::get<::mlspp::ExternalSender>(externalSenderPackage)); + mls_external_sender = std::make_unique<::mlspp::ExternalSender>( + ::mlspp::tls::get<::mlspp::ExternalSender>(external_sender_package)); - if (!groupId_.empty()) { + if (!session_group_id.empty()) { create_pending_group(); } } @@ -126,41 +123,41 @@ catch (const std::exception& e) { return; } -std::optional> session::process_proposals(std::vector proposals, std::set const& recognizedUserIDs) noexcept +std::optional> session::process_proposals(std::vector proposals, std::set const& recognised_user_ids) noexcept try { - if (!pendingGroupState_ && !currentState_) { + if (!pending_group_state && !current_state) { creator.log(dpp::ll_debug, "Cannot process proposals without any pending or established MLS group state"); return std::nullopt; } - if (!stateWithProposals_) { - stateWithProposals_ = std::make_unique<::mlspp::State>( - pendingGroupState_ ? *pendingGroupState_ : *currentState_); + if (!state_with_proposals) { + state_with_proposals = std::make_unique<::mlspp::State>( + pending_group_state ? *pending_group_state : *current_state); } creator.log(dpp::ll_debug, "Processing MLS proposals message of " + std::to_string(proposals.size()) + " bytes"); - ::mlspp::tls::istream inStream(proposals); + ::mlspp::tls::istream in_stream(proposals); - bool isRevoke = false; - inStream >> isRevoke; + bool is_revoke = false; + in_stream >> is_revoke; - if (isRevoke) { + if (is_revoke) { creator.log(dpp::ll_trace, "Revoking from proposals"); } - const auto suite = stateWithProposals_->cipher_suite(); + const auto suite = state_with_proposals->cipher_suite(); - if (isRevoke) { + if (is_revoke) { std::vector<::mlspp::bytes_ns::bytes> refs; - inStream >> refs; + in_stream >> refs; for (const auto& ref : refs) { bool found = false; - for (auto it = proposalQueue_.begin(); it != proposalQueue_.end(); it++) { + for (auto it = proposal_queue.begin(); it != proposal_queue.end(); it++) { if (it->ref == ref) { found = true; - proposalQueue_.erase(it); + proposal_queue.erase(it); break; } } @@ -172,71 +169,68 @@ try { } } - stateWithProposals_ = std::make_unique<::mlspp::State>( - pendingGroupState_ ? *pendingGroupState_ : *currentState_); + state_with_proposals = std::make_unique<::mlspp::State>( + pending_group_state ? *pending_group_state : *current_state); - for (auto& prop : proposalQueue_) { + for (auto& prop : proposal_queue) { // success will queue the proposal, failure will throw - stateWithProposals_->handle(prop.content); + state_with_proposals->handle(prop.content); } - } - else { + } else { std::vector<::mlspp::MLSMessage> messages; - inStream >> messages; + in_stream >> messages; - for (const auto& proposalMessage : messages) { - auto validatedMessage = stateWithProposals_->unwrap(proposalMessage); + for (const auto& proposal_message : messages) { + auto validated_content = state_with_proposals->unwrap(proposal_message); - if (!validate_proposal_message(validatedMessage.authenticated_content(), - *stateWithProposals_, - recognizedUserIDs)) { + if (!validate_proposal_message(validated_content.authenticated_content(), + *state_with_proposals, + recognised_user_ids)) { return std::nullopt; } // success will queue the proposal, failure will throw - stateWithProposals_->handle(validatedMessage); + state_with_proposals->handle(validated_content); - auto ref = suite.ref(validatedMessage.authenticated_content()); + auto ref = suite.ref(validated_content.authenticated_content()); - proposalQueue_.push_back({ - std::move(validatedMessage), + proposal_queue.push_back({ + std::move(validated_content), std::move(ref), }); } } // generate a commit - auto commitSecret = ::mlspp::hpke::random_bytes(suite.secret_size()); + auto commit_secret = ::mlspp::hpke::random_bytes(suite.secret_size()); - auto commitOpts = ::mlspp::CommitOpts{ - {}, // no extra proposals - true, // inline tree in welcome - false, // do not force path - {} // default leaf node options + auto commit_options = ::mlspp::CommitOpts{ + {}, // no extra proposals + true, // inline tree in welcome + false, // do not force path + {} // default leaf node options }; - auto [commitMessage, welcomeMessage, newState] = - stateWithProposals_->commit(commitSecret, commitOpts, {}); + auto [commit_message, welcome_message, new_state] = state_with_proposals->commit(commit_secret, commit_options, {}); creator.log(dpp::ll_debug, "Prepared commit/welcome/next state for MLS group from received proposals"); // combine the commit and welcome messages into a single buffer - auto outStream = ::mlspp::tls::ostream(); - outStream << commitMessage; + auto out_stream = ::mlspp::tls::ostream(); + out_stream << commit_message; // keep a copy of the commit, we can check incoming pending group commit later for a match - pendingGroupCommit_ = std::make_unique<::mlspp::MLSMessage>(std::move(commitMessage)); + pending_group_commit = std::make_unique<::mlspp::MLSMessage>(std::move(commit_message)); // if there were any add proposals in this commit, then we also include the welcome message - if (welcomeMessage.secrets.size() > 0) { - outStream << welcomeMessage; + if (welcome_message.secrets.size() > 0) { + out_stream << welcome_message; } // cache the outbound state in case we're the winning sender - outboundCachedGroupState_ = std::make_unique<::mlspp::State>(std::move(newState)); - + outbound_cached_group_state = std::make_unique<::mlspp::State>(std::move(new_state)); - return outStream.bytes(); + return out_stream.bytes(); } catch (const std::exception& e) { creator.log(dpp::ll_warning, "Failed to parse MLS proposals: " + std::string(e.what())); @@ -244,15 +238,15 @@ catch (const std::exception& e) { return std::nullopt; } -bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognizedUserIDs) const +bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognised_user_ids) const { - std::string uid = user_credential_to_string(cred, protocolVersion_); + std::string uid = user_credential_to_string(cred, session_protocol_version); if (uid.empty()) { creator.log(dpp::ll_warning, "Attempted to verify credential of unexpected type"); return false; } - if (recognizedUserIDs.find(uid) == recognizedUserIDs.end()) { + if (recognised_user_ids.find(uid) == recognised_user_ids.end()) { creator.log(dpp::ll_warning, "Attempted to verify credential for unrecognized user ID: " + uid); return false; } @@ -260,15 +254,15 @@ bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognizedUserIDs) const { +bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set const& recognised_user_ids) const { if (message.wire_format != ::mlspp::WireFormat::mls_public_message) { creator.log(dpp::ll_warning, "MLS proposal message must be PublicMessage"); TRACK_MLS_ERROR("Invalid proposal wire format"); return false; } - if (message.content.epoch != targetState.epoch()) { - creator.log(dpp::ll_warning, "MLS proposal message must be for current epoch (" + std::to_string(message.content.epoch) + " != " + std::to_string(targetState.epoch()) + ")"); + if (message.content.epoch != target_state.epoch()) { + creator.log(dpp::ll_warning, "MLS proposal message must be for current epoch (" + std::to_string(message.content.epoch) + " != " + std::to_string(target_state.epoch()) + ")"); TRACK_MLS_ERROR("Proposal epoch mismatch"); return false; } @@ -290,7 +284,7 @@ bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& mes case ::mlspp::ProposalType::add: { const auto& credential = ::mlspp::tls::var::get<::mlspp::Add>(proposal.content).key_package.leaf_node.credential; - if (!is_recognized_user_id(credential, recognizedUserIDs)) { + if (!is_recognized_user_id(credential, recognised_user_ids)) { creator.log(dpp::ll_warning, "MLS proposal must be for recognised user"); TRACK_MLS_ERROR("Unexpected user ID in add proposal"); return false; @@ -311,11 +305,11 @@ bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& mes bool session::can_process_commit(const ::mlspp::MLSMessage& commit) noexcept { - if (!stateWithProposals_) { + if (!state_with_proposals) { return false; } - if (commit.group_id() != groupId_) { + if (commit.group_id() != session_group_id) { creator.log(dpp::ll_warning, "MLS commit message was for unexpected group"); return false; } @@ -327,34 +321,32 @@ roster_variant session::process_commit(std::vector commit) noexcept try { creator.log(dpp::ll_debug, "Processing commit"); - auto commitMessage = ::mlspp::tls::get<::mlspp::MLSMessage>(commit); + auto commit_message = ::mlspp::tls::get<::mlspp::MLSMessage>(commit); - if (!can_process_commit(commitMessage)) { + if (!can_process_commit(commit_message)) { creator.log(dpp::ll_warning, "process_commit called with unprocessable MLS commit"); return ignored_t{}; } // in case we're the sender of this commit // we need to pull the cached state from our outbound cache - std::optional<::mlspp::State> optionalCachedState = std::nullopt; - if (outboundCachedGroupState_) { - optionalCachedState = *(outboundCachedGroupState_.get()); + std::optional<::mlspp::State> optional_cached_state = std::nullopt; + if (outbound_cached_group_state) { + optional_cached_state = *(outbound_cached_group_state.get()); } - auto newState = stateWithProposals_->handle(commitMessage, optionalCachedState); - - if (!newState) { + auto new_state = state_with_proposals->handle(commit_message, optional_cached_state); + if (!new_state) { creator.log(dpp::ll_warning, "MLS commit handling did not produce a new state"); return failed_t{}; } - creator.log(dpp::ll_debug, "Successfully processed MLS commit, updating state; our leaf index is " + std::to_string(newState->index().val) + "; current epoch is " + std::to_string(newState->epoch())); + creator.log(dpp::ll_debug, "Successfully processed MLS commit, updating state; our leaf index is " + std::to_string(new_state->index().val) + "; current epoch is " + std::to_string(new_state->epoch())); - roster_map ret = replace_state(std::make_unique<::mlspp::State>(std::move(*newState))); + roster_map ret = replace_state(std::make_unique<::mlspp::State>(std::move(*new_state))); // reset the outbound cached group since we handled the commit for this epoch - outboundCachedGroupState_.reset(); - + outbound_cached_group_state.reset(); clear_pending_state(); return ret; @@ -365,46 +357,46 @@ catch (const std::exception& e) { return failed_t{}; } -std::optional session::process_welcome(std::vector welcome, std::set const& recognizedUserIDs) noexcept +std::optional session::process_welcome(std::vector welcome, std::set const& recognised_user_ids) noexcept try { if (!has_cryptographic_state_for_welcome()) { creator.log(dpp::ll_warning, "Missing local crypto state necessary to process MLS welcome"); return std::nullopt; } - if (!externalSender_) { + if (!mls_external_sender) { creator.log(dpp::ll_warning, "Cannot process MLS welcome without an external sender"); return std::nullopt; } - if (currentState_) { + if (current_state) { creator.log(dpp::ll_warning, "Cannot process MLS welcome after joining/creating an MLS group"); return std::nullopt; } // unmarshal the incoming welcome - auto unmarshalledWelcome = ::mlspp::tls::get<::mlspp::Welcome>(welcome); + auto unmarshalled_welcome = ::mlspp::tls::get<::mlspp::Welcome>(welcome); // construct the state from the unmarshalled welcome - auto newState = std::make_unique<::mlspp::State>( - *joinInitPrivateKey_, - *selfHPKEPrivateKey_, - *selfSigPrivateKey_, - *joinKeyPackage_, - unmarshalledWelcome, - std::nullopt, - std::map<::mlspp::bytes_ns::bytes, ::mlspp::bytes_ns::bytes>()); + auto new_state = std::make_unique<::mlspp::State>( + *join_init_private_key, + *hpke_private_key, + *signature_private_key, + *join_key_package, + unmarshalled_welcome, + std::nullopt, + std::map<::mlspp::bytes_ns::bytes, ::mlspp::bytes_ns::bytes>()); // perform application-level verification of the new state - if (!verify_welcome_state(*newState, recognizedUserIDs)) { + if (!verify_welcome_state(*new_state, recognised_user_ids)) { creator.log(dpp::ll_warning, "Group received in MLS welcome is not valid"); return std::nullopt; } - creator.log(dpp::ll_debug, "Successfully welcomed to MLS Group, our leaf index is " + std::to_string(newState->index().val) + "; current epoch is " + std::to_string(newState->epoch())); + creator.log(dpp::ll_debug, "Successfully welcomed to MLS Group, our leaf index is " + std::to_string(new_state->index().val) + "; current epoch is " + std::to_string(new_state->epoch())); // make the verified state our new (and only) state - roster_map ret = replace_state(std::move(newState)); + roster_map ret = replace_state(std::move(new_state)); // clear out any pending state for creating/joining a group clear_pending_state(); @@ -419,27 +411,22 @@ catch (const std::exception& e) { roster_map session::replace_state(std::unique_ptr<::mlspp::State>&& state) { - roster_map newRoster; + roster_map new_roster; for (const ::mlspp::LeafNode& node : state->roster()) { if (node.credential.type() != ::mlspp::CredentialType::basic) { continue; } const auto& cred = node.credential.template get<::mlspp::BasicCredential>(); - - newRoster[from_big_endian_bytes(cred.identity)] = node.signature_key.data.as_vec(); + new_roster[from_big_endian_bytes(cred.identity)] = node.signature_key.data.as_vec(); } - roster_map changeMap; + roster_map change_map; - std::set_difference(newRoster.begin(), - newRoster.end(), - roster_.begin(), - roster_.end(), - std::inserter(changeMap, changeMap.end())); + std::set_difference(new_roster.begin(), new_roster.end(), roster.begin(), roster.end(), std::inserter(change_map, change_map.end())); - struct MissingItemWrapper { - roster_map& changeMap_; + struct missing_item_wrapper { + roster_map& map; using iterator = roster_map::iterator; using const_iterator = roster_map::const_iterator; @@ -447,37 +434,36 @@ roster_map session::replace_state(std::unique_ptr<::mlspp::State>&& state) iterator insert(const_iterator it, const value_type& value) { - return changeMap_.try_emplace(it, value.first, std::vector{}); + return map.try_emplace(it, value.first, std::vector{}); } - iterator begin() { return changeMap_.begin(); } + iterator begin() { return map.begin(); } - iterator end() { return changeMap_.end(); } + iterator end() { return map.end(); } }; - MissingItemWrapper wrapper{changeMap}; + missing_item_wrapper wrapper{change_map}; - std::set_difference(roster_.begin(), - roster_.end(), - newRoster.begin(), - newRoster.end(), - std::inserter(wrapper, wrapper.end())); + std::set_difference(roster.begin(), + roster.end(), + new_roster.begin(), + new_roster.end(), + std::inserter(wrapper, wrapper.end())); - roster_ = std::move(newRoster); - currentState_ = std::move(state); + roster = std::move(new_roster); + current_state = std::move(state); - return changeMap; + return change_map; } bool session::has_cryptographic_state_for_welcome() const noexcept { - return joinKeyPackage_ && joinInitPrivateKey_ && selfSigPrivateKey_ && selfHPKEPrivateKey_; + return join_key_package && join_init_private_key && signature_private_key && hpke_private_key; } -bool session::verify_welcome_state(::mlspp::State const& state, - std::set const& recognizedUserIDs) const +bool session::verify_welcome_state(::mlspp::State const& state, std::set const& recognised_user_ids) const { - if (!externalSender_) { + if (!mls_external_sender) { creator.log(dpp::ll_warning, "Cannot verify MLS welcome without an external sender"); TRACK_MLS_ERROR("Missing external sender when processing Welcome"); return false; @@ -496,7 +482,7 @@ bool session::verify_welcome_state(::mlspp::State const& state, return false; } - if (ext->senders.front() != *externalSender_) { + if (ext->senders.front() != *mls_external_sender) { creator.log(dpp::ll_warning, "MLS welcome lists unexpected external sender"); TRACK_MLS_ERROR("Welcome message lists unexpected external sender"); return false; @@ -508,7 +494,7 @@ bool session::verify_welcome_state(::mlspp::State const& state, // before all in-flight proposals were handled. for (const auto& leaf : state.roster()) { - if (!is_recognized_user_id(leaf.credential, recognizedUserIDs)) { + if (!is_recognized_user_id(leaf.credential, recognised_user_ids)) { creator.log(dpp::ll_warning, "MLS welcome lists unrecognized user ID"); } } @@ -516,33 +502,33 @@ bool session::verify_welcome_state(::mlspp::State const& state, return true; } -void session::init_leaf_node(std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept +void session::init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept try { - auto ciphersuite = ciphersuite_for_protocol_version(protocolVersion_); + auto ciphersuite = ciphersuite_for_protocol_version(session_protocol_version); - if (!transientKey) { - if (!signingKeyId_.empty()) { - transientKey = get_persisted_key_pair(creator, keyPairContext_, signingKeyId_, protocolVersion_); - if (!transientKey) { + if (!transient_key) { + if (!signing_key_id.empty()) { + transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id, session_protocol_version); + if (!transient_key) { creator.log(dpp::ll_warning, "Did not receive MLS signature private key from get_persisted_key_pair; aborting"); return; } } else { - transientKey = std::make_shared<::mlspp::SignaturePrivateKey>( + transient_key = std::make_shared<::mlspp::SignaturePrivateKey>( ::mlspp::SignaturePrivateKey::generate(ciphersuite)); } } - selfSigPrivateKey_ = transientKey; - - auto selfCredential = create_user_credential(selfUserId, protocolVersion_); + signature_private_key = transient_key; - selfHPKEPrivateKey_ = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); - - selfLeafNode_ = std::make_unique<::mlspp::LeafNode>(ciphersuite, selfHPKEPrivateKey_->public_key, selfSigPrivateKey_->public_key, std::move(selfCredential), - leaf_node_capabilities_for_protocol_version(protocolVersion_), ::mlspp::Lifetime::create_default(), - leaf_node_extensions_for_protocol_version(protocolVersion_), *selfSigPrivateKey_); + auto self_credential = create_user_credential(self_user_id, session_protocol_version); + hpke_private_key = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); + self_leaf_node = std::make_unique<::mlspp::LeafNode>( + ciphersuite, hpke_private_key->public_key, signature_private_key->public_key, std::move(self_credential), + leaf_node_capabilities_for_protocol_version(session_protocol_version), ::mlspp::Lifetime::create_default(), + leaf_node_extensions_for_protocol_version(session_protocol_version), *signature_private_key + ); creator.log(dpp::ll_debug, "Created MLS leaf node"); } @@ -553,22 +539,14 @@ catch (const std::exception& e) { void session::reset_join_key_package() noexcept try { - if (!selfLeafNode_) { + if (!self_leaf_node) { creator.log(dpp::ll_warning, "Cannot initialize join key package without a leaf node"); return; } - auto ciphersuite = ciphersuite_for_protocol_version(protocolVersion_); - - joinInitPrivateKey_ = - std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); - - joinKeyPackage_ = - std::make_unique<::mlspp::KeyPackage>(ciphersuite, - joinInitPrivateKey_->public_key, - *selfLeafNode_, - leaf_node_extensions_for_protocol_version(protocolVersion_), - *selfSigPrivateKey_); + auto ciphersuite = ciphersuite_for_protocol_version(session_protocol_version); + join_init_private_key = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); + join_key_package = std::make_unique<::mlspp::KeyPackage>(ciphersuite, join_init_private_key->public_key, *self_leaf_node, leaf_node_extensions_for_protocol_version(session_protocol_version), *signature_private_key); } catch (const std::exception& e) { creator.log(dpp::ll_warning, "Failed to initialize join key package: " + std::string(e.what())); @@ -577,33 +555,32 @@ catch (const std::exception& e) { void session::create_pending_group() noexcept try { - if (groupId_.empty()) { + if (session_group_id.empty()) { creator.log(dpp::ll_warning, "Cannot create MLS group without a group ID"); return; } - if (!externalSender_) { - creator.log(dpp::ll_warning, "Cannot create MLS group without ExternalSender"); + if (!mls_external_sender) { + creator.log(dpp::ll_debug, "Cannot create MLS group without external sender"); return; } - if (!selfLeafNode_) { + if (!self_leaf_node) { creator.log(dpp::ll_warning, "Cannot create MLS group without self leaf node"); return; } creator.log(dpp::ll_debug, "Creating a pending MLS group"); - auto ciphersuite = ciphersuite_for_protocol_version(protocolVersion_); - - pendingGroupState_ = std::make_unique<::mlspp::State>( - groupId_, + auto ciphersuite = ciphersuite_for_protocol_version(session_protocol_version); + pending_group_state = std::make_unique<::mlspp::State>( + session_group_id, ciphersuite, - *selfHPKEPrivateKey_, - *selfSigPrivateKey_, - *selfLeafNode_, - group_extensions_for_protocol_version(protocolVersion_, *externalSender_)); - + *hpke_private_key, + *signature_private_key, + *self_leaf_node, + group_extensions_for_protocol_version(session_protocol_version, *mls_external_sender) + ); creator.log(dpp::ll_debug, "Created a pending MLS group"); } catch (const std::exception& e) { @@ -618,12 +595,12 @@ try { // so every time the client asks for a key package we create a new one reset_join_key_package(); - if (!joinKeyPackage_) { + if (!join_key_package) { creator.log(dpp::ll_warning, "Cannot marshal an uninitialized key package"); return {}; } - return ::mlspp::tls::marshal(*joinKeyPackage_); + return ::mlspp::tls::marshal(*join_key_package); } catch (const std::exception& e) { creator.log(dpp::ll_warning, "Failed to marshal join key package: " + std::string(e.what())); @@ -631,41 +608,39 @@ catch (const std::exception& e) { return {}; } -std::unique_ptr session::get_key_ratchet(std::string const& userId) const noexcept +std::unique_ptr session::get_key_ratchet(std::string const& user_id) const noexcept { - if (!currentState_) { + if (!current_state) { creator.log(dpp::ll_warning, "Cannot get key ratchet without an established MLS group"); return nullptr; } // change the string user ID to a little endian 64 bit user ID - auto u64userId = strtoull(userId.c_str(), nullptr, 10); - auto userIdBytes = ::mlspp::bytes_ns::bytes(sizeof(u64userId)); - memcpy(userIdBytes.data(), &u64userId, sizeof(u64userId)); + // TODO: Make this use dpp::snowflake + auto u64_user_id = strtoull(user_id.c_str(), nullptr, 10); + auto user_id_bytes = ::mlspp::bytes_ns::bytes(sizeof(u64_user_id)); + memcpy(user_id_bytes.data(), &u64_user_id, sizeof(u64_user_id)); // generate the base secret for the hash ratchet - auto baseSecret = - currentState_->do_export(session::USER_MEDIA_KEY_BASE_LABEL, userIdBytes, AES_GCM_128_KEY_BYTES); + auto secret = current_state->do_export(session::USER_MEDIA_KEY_BASE_LABEL, user_id_bytes, AES_GCM_128_KEY_BYTES); - // this assumes the MLS ciphersuite produces a kAesGcm128KeyBytes sized key + // this assumes the MLS ciphersuite produces an AES_GCM_128_KEY_BYTES sized key // would need to be updated to a different ciphersuite if there's a future mismatch - return std::make_unique(creator, currentState_->cipher_suite(), std::move(baseSecret)); + return std::make_unique(creator, current_state->cipher_suite(), std::move(secret)); } -void session::get_pairwise_fingerprint(uint16_t version, - std::string const& userId, - pairwise_fingerprint_callback callback) const noexcept +void session::get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept try { - if (!currentState_ || !selfSigPrivateKey_) { + if (!current_state || !signature_private_key) { throw std::invalid_argument("No established MLS group"); } - uint64_t u64RemoteUserId = strtoull(userId.c_str(), nullptr, 10); - uint64_t u64SelfUserId = strtoull(selfUserId_.c_str(), nullptr, 10); + uint64_t remote_user_id = strtoull(user_id.c_str(), nullptr, 10); + uint64_t self_user_id = strtoull(bot_user_id.c_str(), nullptr, 10); - auto it = roster_.find(u64RemoteUserId); - if (it == roster_.end()) { - throw std::invalid_argument("Unknown user ID: " + userId); + auto it = roster.find(remote_user_id); + if (it == roster.end()) { + throw std::invalid_argument("Unknown user ID: " + user_id); } ::mlspp::tls::ostream toHash1; @@ -673,11 +648,11 @@ try { toHash1 << version; toHash1.write_raw(it->second); - toHash1 << u64RemoteUserId; + toHash1 << remote_user_id; toHash2 << version; - toHash2.write_raw(selfSigPrivateKey_->public_key.data); - toHash2 << u64SelfUserId; + toHash2.write_raw(signature_private_key->public_key.data); + toHash2 << self_user_id; std::vector> keyData = { toHash1.bytes(), @@ -725,8 +700,7 @@ try { if (ret == 1) { callback(out); - } - else { + } else { callback({}); } }).detach(); @@ -738,20 +712,14 @@ catch (const std::exception& e) { void session::clear_pending_state() { - pendingGroupState_.reset(); - pendingGroupCommit_.reset(); - - joinInitPrivateKey_.reset(); - joinKeyPackage_.reset(); - - selfHPKEPrivateKey_.reset(); - - selfLeafNode_.reset(); - - stateWithProposals_.reset(); - proposalQueue_.clear(); + pending_group_state.reset(); + pending_group_commit.reset(); + join_init_private_key.reset(); + join_key_package.reset(); + hpke_private_key.reset(); + self_leaf_node.reset(); + state_with_proposals.reset(); + proposal_queue.clear(); } -} // namespace dpp::dave::mls - - +} diff --git a/src/dpp/dave/session.h b/src/dpp/dave/session.h index da488bb895..36e5dabb17 100755 --- a/src/dpp/dave/session.h +++ b/src/dpp/dave/session.h @@ -47,7 +47,7 @@ namespace mlspp { struct MLSMessage; struct SignaturePrivateKey; class State; -} // namespace mlspp +} namespace dpp { class cluster; @@ -70,10 +70,10 @@ class session { // NOLINT /** * @brief Constructor * @param context key pair context (set to nullptr to use a transient key pair) - * @param authSessionId auth session id (set to empty string to use a transient key pair) + * @param auth_session_id auth session id (set to empty string to use a transient key pair) * @param callback callback for failure */ - session(dpp::cluster& cluster, key_pair_context_type context, const std::string& authSessionId, mls_failure_callback callback) noexcept; + session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept; /** * @brief Destructor @@ -86,11 +86,11 @@ class session { // NOLINT * whilst preserving other state set by the constructor. * * @param version protocol version - * @param groupId group id (channel id) - * @param selfUserId bot's user id - * @param transientKey transient private key + * @param group_id group id (channel id) + * @param self_user_id bot's user id + * @param transient_key transient private key */ - void init(protocol_version version, uint64_t groupId, std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; + void init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept; /** * @brief Reset the session to defaults @@ -108,7 +108,7 @@ class session { // NOLINT * @return protocol version */ [[nodiscard]] protocol_version get_protocol_version() const noexcept { - return protocolVersion_; + return session_protocol_version; } /** @@ -119,17 +119,17 @@ class session { // NOLINT /** * @brief Set external sender from external sender opcode - * @param externalSenderPackage external sender package + * @param external_sender_package external sender package */ - void set_external_sender(std::vector const& externalSenderPackage) noexcept; + void set_external_sender(std::vector const& external_sender_package) noexcept; /** * @brief Process proposals from proposals opcode * @param proposals proposals blob from websocket - * @param recognizedUserIDs list of recognised user IDs + * @param recognised_user_ids list of recognised user IDs * @return optional vector to send in reply as commit welcome */ - std::optional> process_proposals(std::vector proposals, std::set const& recognizedUserIDs) noexcept; + std::optional> process_proposals(std::vector proposals, std::set const& recognised_user_ids) noexcept; /** * @brief Process commit message from discord websocket @@ -141,10 +141,10 @@ class session { // NOLINT /** * @brief Process welcome blob * @param welcome welcome blob from discord - * @param recognizedUserIDs Recognised user ID list + * @param recognised_user_ids Recognised user ID list * @return roster list of people in the vc */ - std::optional process_welcome(std::vector welcome, std::set const& recognizedUserIDs) noexcept; + std::optional process_welcome(std::vector welcome, std::set const& recognised_user_ids) noexcept; /** * @brief Get the bot user's key package for sending to websocket @@ -154,10 +154,10 @@ class session { // NOLINT /** * @brief Get key ratchet for a user (including the bot) - * @param userId User id to get ratchet for + * @param user_id User id to get ratchet for * @return The user's key ratchet for use in an encryptor or decryptor */ - [[nodiscard]] std::unique_ptr get_key_ratchet(std::string const& userId) const noexcept; + [[nodiscard]] std::unique_ptr get_key_ratchet(std::string const& user_id) const noexcept; /** * @brief callback for completion of pairwise fingerprint @@ -169,18 +169,18 @@ class session { // NOLINT * @warning This uses SCRYPT and is extremely resource intensive. It will spawn a thread * which will call your callback on completion. * @param version Should always be 0x00 - * @param userId User ID to get fingerprint for + * @param user_id User ID to get fingerprint for * @param callback Callback for completion */ - void get_pairwise_fingerprint(uint16_t version, std::string const& userId, pairwise_fingerprint_callback callback) const noexcept; + void get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept; private: /** * @brief Initialise leaf node - * @param selfUserId Bot user id - * @param transientKey Transient key + * @param self_user_id Bot user id + * @param transient_key Transient key */ - void init_leaf_node(std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; + void init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept; /** * @brief Reset join key @@ -201,27 +201,27 @@ class session { // NOLINT /** * @brief Check if user ID is valid * @param cred MLS credential - * @param recognizedUserIDs list of recognised user IDs + * @param recognised_user_ids list of recognised user IDs * @return */ - [[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognizedUserIDs) const; + [[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognised_user_ids) const; /** * @brief Validate proposals message * @param message authenticated content message - * @param targetState new state - * @param recognizedUserIDs recognised list of user IDs + * @param target_state new state + * @param recognised_user_ids recognised list of user IDs * @return true if validated */ - [[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& targetState, std::set const& recognizedUserIDs) const; + [[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set const& recognised_user_ids) const; /** * @brief Verify that welcome state is valid * @param state current state - * @param recognizedUserIDs list of recognised user IDs + * @param recognised_user_ids list of recognised user IDs * @return */ - [[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set const& recognizedUserIDs) const; + [[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set const& recognised_user_ids) const; /** * @brief Check if can process a commit now @@ -247,33 +247,100 @@ class session { // NOLINT */ inline static const std::string USER_MEDIA_KEY_BASE_LABEL = "Discord Secure Frames v0"; - protocol_version protocolVersion_; - std::vector groupId_; - std::string signingKeyId_; - std::string selfUserId_; - key_pair_context_type keyPairContext_{nullptr}; + /** + * @brief DAVE protocol version for the session + */ + protocol_version session_protocol_version; + + /** + * @brief Session group ID (voice channel id) + */ + std::vector session_group_id; + + /** + * @brief Signing key id + */ + std::string signing_key_id; + + /** + * @brief The bot's user snowflake ID + */ + std::string bot_user_id; + + /** + * @brief The bot's key pair context + */ + key_pair_context_type key_pair_context{nullptr}; + + /** + * @brief Our leaf node in the ratchet tree + */ + std::unique_ptr<::mlspp::LeafNode> self_leaf_node; + + /** + * @brief The bots signature private key + */ + std::shared_ptr<::mlspp::SignaturePrivateKey> signature_private_key; + + /** + * @brief HPKE private key + */ + std::unique_ptr<::mlspp::HPKEPrivateKey> hpke_private_key; + + /** + * @brief Private key for join initialisation + */ + std::unique_ptr<::mlspp::HPKEPrivateKey> join_init_private_key; + + /** + * @brief Join key package + */ + std::unique_ptr<::mlspp::KeyPackage> join_key_package; - std::unique_ptr<::mlspp::LeafNode> selfLeafNode_; - std::shared_ptr<::mlspp::SignaturePrivateKey> selfSigPrivateKey_; - std::unique_ptr<::mlspp::HPKEPrivateKey> selfHPKEPrivateKey_; + /** + * @brief MLS External sender (the discord voice gateway server) + */ + std::unique_ptr<::mlspp::ExternalSender> mls_external_sender; - std::unique_ptr<::mlspp::HPKEPrivateKey> joinInitPrivateKey_; - std::unique_ptr<::mlspp::KeyPackage> joinKeyPackage_; + /** + * @brief Pending MLS group state + */ + std::unique_ptr<::mlspp::State> pending_group_state; + + /** + * @brief Pending MLS group commit + */ + std::unique_ptr<::mlspp::MLSMessage> pending_group_commit; - std::unique_ptr<::mlspp::ExternalSender> externalSender_; + /** + * @brief Outbound cached group state + */ + std::unique_ptr<::mlspp::State> outbound_cached_group_state; - std::unique_ptr<::mlspp::State> pendingGroupState_; - std::unique_ptr<::mlspp::MLSMessage> pendingGroupCommit_; + /** + * @brief Current MLS state + */ + std::unique_ptr<::mlspp::State> current_state; - std::unique_ptr<::mlspp::State> outboundCachedGroupState_; + /** + * @brief Participant roster, all users who are in the VC with dave enabled + */ + roster_map roster; - std::unique_ptr<::mlspp::State> currentState_; - roster_map roster_; + /** + * @brief Current state containing proposals + */ + std::unique_ptr<::mlspp::State> state_with_proposals; - std::unique_ptr<::mlspp::State> stateWithProposals_; - std::list proposalQueue_; + /** + * @brief Queue of proposals to process + */ + std::list proposal_queue; - mls_failure_callback onMLSFailureCallback_{}; + /** + * @brief Function to call on failure, if any + */ + mls_failure_callback failure_callback{}; /** * @brief DPP Cluster, used for logging @@ -281,6 +348,4 @@ class session { // NOLINT dpp::cluster& creator; }; -} // namespace dpp::dave::mls - - +} diff --git a/src/dpp/dave/user_credential.cpp b/src/dpp/dave/user_credential.cpp index f5fd9a6493..49cfefcaef 100755 --- a/src/dpp/dave/user_credential.cpp +++ b/src/dpp/dave/user_credential.cpp @@ -28,12 +28,11 @@ namespace dpp::dave::mls { -::mlspp::Credential create_user_credential(const std::string& userId, protocol_version version) { +::mlspp::Credential create_user_credential(const std::string& user_id, protocol_version version) { // convert the string user ID to a big endian uint64_t - auto userID = std::stoull(userId); - auto credentialBytes = big_endian_bytes_from(userID); - - return ::mlspp::Credential::basic(credentialBytes); + auto id = std::stoull(user_id); + auto credential_bytes = big_endian_bytes_from(id); + return ::mlspp::Credential::basic(credential_bytes); } std::string user_credential_to_string(const ::mlspp::Credential& cred, protocol_version version) { @@ -42,12 +41,10 @@ std::string user_credential_to_string(const ::mlspp::Credential& cred, protocol_ } const auto& basic = cred.template get<::mlspp::BasicCredential>(); - - auto uidVal = from_big_endian_bytes(basic.identity); - - return std::to_string(uidVal); + auto uid_val = from_big_endian_bytes(basic.identity); + return std::to_string(uid_val); } -} // namespace dpp::dave::mls +} diff --git a/src/dpp/dave/user_credential.h b/src/dpp/dave/user_credential.h index 1afa2b7c3b..3fecf70bdd 100755 --- a/src/dpp/dave/user_credential.h +++ b/src/dpp/dave/user_credential.h @@ -32,11 +32,11 @@ namespace dpp::dave::mls { /** * @brief Create user credentials - * @param userId user id + * @param user_id user id * @param version protocol version * @return */ -::mlspp::Credential create_user_credential(const std::string& userId, protocol_version version); +::mlspp::Credential create_user_credential(const std::string& user_id, protocol_version version); /** * @brief Convert user credentials to string @@ -46,6 +46,5 @@ ::mlspp::Credential create_user_credential(const std::string& userId, protocol_v */ std::string user_credential_to_string(const ::mlspp::Credential& cred, protocol_version version); -} // namespace dpp::dave::mls - +} diff --git a/src/dpp/dave/util.cpp b/src/dpp/dave/util.cpp index 1abdcc2faf..d8e8929b72 100755 --- a/src/dpp/dave/util.cpp +++ b/src/dpp/dave/util.cpp @@ -49,6 +49,4 @@ uint64_t from_big_endian_bytes(const ::mlspp::bytes_ns::bytes& buffer) noexcept return val; } -} // namespace dpp::dave::mls - - +} diff --git a/src/dpp/dave/util.h b/src/dpp/dave/util.h index f351f45feb..3ecdf4dae7 100755 --- a/src/dpp/dave/util.h +++ b/src/dpp/dave/util.h @@ -43,6 +43,4 @@ ::mlspp::bytes_ns::bytes big_endian_bytes_from(uint64_t value) noexcept; */ uint64_t from_big_endian_bytes(const ::mlspp::bytes_ns::bytes& value) noexcept; -} // namespace dpp::dave::mls - - +} diff --git a/src/dpp/dave/version.cpp b/src/dpp/dave/version.cpp index 34a5ac8269..8fe703a41b 100755 --- a/src/dpp/dave/version.cpp +++ b/src/dpp/dave/version.cpp @@ -28,10 +28,8 @@ namespace dpp::dave { constexpr protocol_version current_dave_protocol_version = 1; -protocol_version max_protocol_version() -{ +protocol_version max_protocol_version() { return current_dave_protocol_version; } -} // namespace dpp::dave - +} diff --git a/src/dpp/dave/version.h b/src/dpp/dave/version.h index 6c554c2ef0..972940a769 100755 --- a/src/dpp/dave/version.h +++ b/src/dpp/dave/version.h @@ -44,5 +44,4 @@ using signature_version = uint8_t; */ protocol_version max_protocol_version(); -} // namespace dpp::dave - +} diff --git a/src/dpp/dns.cpp b/src/dpp/dns.cpp index e87fbd45e8..257b9b2315 100644 --- a/src/dpp/dns.cpp +++ b/src/dpp/dns.cpp @@ -39,75 +39,102 @@ namespace dpp /* Cache container */ dns_cache_t dns_cache; - const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port) +/** +* @brief Get address length +* @return address length +*/ +int dns_cache_entry::size() const { + return static_cast(addr.ai_addrlen); +} + +const address_t dns_cache_entry::get_connecting_address(uint16_t port) const { + return address_t(resolved_addr, port); +} + +socket dns_cache_entry::make_connecting_socket() const { + return ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); +} + +const dns_cache_entry* resolve_hostname(const std::string& hostname, const std::string& port) +{ + addrinfo hints, *addrs; + dns_cache_t::const_iterator iter; + time_t now = time(nullptr); + int error; + bool exists = false; + + /* Thread safety scope */ { - addrinfo hints, *addrs; - dns_cache_t::const_iterator iter; - time_t now = time(nullptr); - int error; - bool exists = false; - - /* Thread safety scope */ - { - /* Check cache for existing DNS record. This can use a shared lock. */ - std::shared_lock dns_cache_lock(dns_cache_mutex); - iter = dns_cache.find(hostname); - if (iter != dns_cache.end()) { - exists = true; - if (now < iter->second->expire_timestamp) { - /* there is a cached entry that is still valid, return it */ - return iter->second; - } + /* Check cache for existing DNS record. This can use a shared lock. */ + std::shared_lock dns_cache_lock(dns_cache_mutex); + iter = dns_cache.find(hostname); + if (iter != dns_cache.end()) { + exists = true; + if (now < iter->second->expire_timestamp) { + /* there is a cached entry that is still valid, return it */ + return iter->second; } } - if (exists) { - /* there is a cached entry, but it has expired, - * delete and free it, and fall through to a new lookup. - * We must use a unique lock here as we modify the cache. - */ - std::unique_lock dns_cache_lock(dns_cache_mutex); - iter = dns_cache.find(hostname); - if (iter != dns_cache.end()) { /* re-validate iter */ - delete iter->second; - dns_cache.erase(iter); - } - } - - /* The hints indicate what sort of DNS results we are interested in. - * To change this to support IPv6, one change we need to make here is - * to change AF_INET to AF_UNSPEC. Everything else should just work fine. + } + if (exists) { + /* there is a cached entry, but it has expired, + * delete and free it, and fall through to a new lookup. + * We must use a unique lock here as we modify the cache. */ - memset(&hints, 0, sizeof(addrinfo)); - hints.ai_family = AF_INET; // IPv6 explicitly unsupported by Discord - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; - - if ((error = getaddrinfo(hostname.c_str(), port.c_str(), &hints, &addrs))) { - /** - * The -20 makes sure the error codes dont conflict with codes given in the rest of the list - * Because C libraries love to use -1 and below directly as conflicting error codes. - */ - throw dpp::connection_exception((exception_error_code)(error - 20), std::string("getaddrinfo error: ") + gai_strerror(error)); + std::unique_lock dns_cache_lock(dns_cache_mutex); + iter = dns_cache.find(hostname); + if (iter != dns_cache.end()) { /* re-validate iter */ + delete iter->second; + dns_cache.erase(iter); } + } + + /* The hints indicate what sort of DNS results we are interested in. + * To change this to support IPv6, one change we need to make here is + * to change AF_INET to AF_UNSPEC. Everything else should just work fine. + */ + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_family = AF_INET; // IPv6 explicitly unsupported by Discord + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; - /* Thread safety scope */ - { - /* Update cache, requires unique lock */ - std::unique_lock dns_cache_lock(dns_cache_mutex); - dns_cache_entry* cache_entry = new dns_cache_entry(); - - /* The sockaddr struct contains a bunch of raw pointers that we - * must copy to the cache, before freeing it with freeaddrinfo(). - * Icky icky C APIs. - */ - memcpy(&cache_entry->addr, addrs, sizeof(addrinfo)); - memcpy(&cache_entry->ai_addr, addrs->ai_addr, addrs->ai_addrlen); - cache_entry->expire_timestamp = now + one_hour; - dns_cache[hostname] = cache_entry; - - /* Now we're done with this horrible struct, free it and return */ - freeaddrinfo(addrs); - return cache_entry; + if ((error = getaddrinfo(hostname.c_str(), port.c_str(), &hints, &addrs))) { + /** + * The -20 makes sure the error codes dont conflict with codes given in the rest of the list + * Because C libraries love to use -1 and below directly as conflicting error codes. + */ + throw dpp::connection_exception((exception_error_code)(error - 20), std::string("getaddrinfo error: ") + gai_strerror(error)); + } + + /* Thread safety scope */ + { + /* Update cache, requires unique lock */ + std::unique_lock dns_cache_lock(dns_cache_mutex); + dns_cache_entry* cache_entry = new dns_cache_entry(); + + for (struct addrinfo* rp = addrs; rp != nullptr; rp = rp->ai_next) { + /* Discord only support ipv4, so iterate over any ipv6 results */ + if (rp->ai_family != AF_INET) { + continue; + } + /* Save address family and other metadata for later */ + memcpy(&cache_entry->addr, rp, sizeof(addrinfo)); + char buffer[128]; + sockaddr_in in{}; + std::memcpy(&in, rp->ai_addr, sizeof(sockaddr_in)); + if (inet_ntop(rp->ai_family, &in.sin_addr, buffer, sizeof(buffer))) { + cache_entry->resolved_addr = buffer; + } + break; } + + cache_entry->expire_timestamp = now + one_hour; + dns_cache[hostname] = cache_entry; + + /* Now we're done with this horrible struct, free it and return */ + freeaddrinfo(addrs); + return cache_entry; } -} // namespace dpp +} + +} diff --git a/src/dpp/socket.cpp b/src/dpp/socket.cpp new file mode 100644 index 0000000000..9313405b75 --- /dev/null +++ b/src/dpp/socket.cpp @@ -0,0 +1,63 @@ +/************************************************************************************ + * + * D++, A Lightweight C++ library for Discord + * + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2021 Craig Edwards and D++ contributors + * (https://github.com/brainboxdotcc/DPP/graphs/contributors) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ************************************************************************************/ + +#include +#include +#include + +namespace dpp { + +address_t::address_t(const std::string_view ip, uint16_t port) { + sockaddr_in address{}; + address.sin_family = AF_INET; + address.sin_port = htons(port); + address.sin_addr.s_addr = inet_addr(ip.data()); + std::memcpy(&socket_addr, &address, sizeof(address)); +} + +sockaddr* address_t::get_socket_address() { + return &socket_addr; +} + +size_t address_t::size() { + return sizeof(sockaddr_in); +} + +uint16_t address_t::get_port(socket fd) { + socklen_t len = size(); + if (getsockname(fd, &socket_addr, &len) > -1) { + sockaddr_in sin{}; + std::memcpy(&sin, &socket_addr, sizeof(sockaddr_in)); + return ntohs(sin.sin_port); + } + return 0; +} + +raii_socket::raii_socket() : fd(::socket(AF_INET, SOCK_DGRAM, 0)) { +} + +raii_socket::~raii_socket() { + close_socket(fd); +} + + +} \ No newline at end of file diff --git a/src/dpp/sslclient.cpp b/src/dpp/sslclient.cpp index d1f15ac20a..54698eb0e4 100644 --- a/src/dpp/sslclient.cpp +++ b/src/dpp/sslclient.cpp @@ -70,6 +70,7 @@ #include #include #include +#include #include /* Maximum allowed time in milliseconds for socket read/write timeouts and connect() */ @@ -318,10 +319,11 @@ void ssl_client::connect() /* Resolve hostname to IP */ int err = 0; const dns_cache_entry* addr = resolve_hostname(hostname, port); - sfd = ::socket(addr->addr.ai_family, addr->addr.ai_socktype, addr->addr.ai_protocol); + sfd = addr->make_connecting_socket(); + address_t destination = addr->get_connecting_address(from_string(this->port, std::dec)); if (sfd == ERROR_STATUS) { err = errno; - } else if (connect_with_timeout(sfd, (sockaddr*)&addr->ai_addr, (int)addr->addr.ai_addrlen, SOCKET_OP_TIMEOUT) != 0) { + } else if (connect_with_timeout(sfd, destination.get_socket_address(), destination.size(), SOCKET_OP_TIMEOUT) != 0) { close_socket(sfd); sfd = ERROR_STATUS; } diff --git a/src/dpp/voice/enabled/discover_ip.cpp b/src/dpp/voice/enabled/discover_ip.cpp index ea4f310308..7061c20f08 100644 --- a/src/dpp/voice/enabled/discover_ip.cpp +++ b/src/dpp/voice/enabled/discover_ip.cpp @@ -38,13 +38,11 @@ namespace dpp { -/** - * https://discord.com/developers/docs/topics/voice-connections#ip-discovery - */ - /** * @brief Represents an IP discovery packet sent to Discord or received * from Discord. + * + * https://discord.com/developers/docs/topics/voice-connections#ip-discovery */ struct ip_discovery_packet { @@ -132,20 +130,6 @@ struct ip_discovery_packet { } }; -/** - * @brief Allocates a dpp::socket, closing it on destruction - */ -struct raii_socket { - dpp::socket fd; - - raii_socket() : fd(::socket(AF_INET, SOCK_DGRAM, 0)) { }; - raii_socket(raii_socket&) = delete; - raii_socket(raii_socket&&) = delete; - raii_socket operator=(raii_socket&) = delete; - raii_socket operator=(raii_socket&&) = delete; - ~raii_socket() { close_socket(fd); }; -}; - constexpr int discovery_timeout = 1000; std::string discord_voice_client::discover_ip() { @@ -158,19 +142,13 @@ std::string discord_voice_client::discover_ip() { ip_discovery_packet discovery(this->ssrc); if (socket.fd >= 0) { - sockaddr_in servaddr{}; - servaddr.sin_family = AF_INET; - servaddr.sin_addr.s_addr = htonl(INADDR_ANY); - servaddr.sin_port = htons(0); - if (bind(socket.fd, reinterpret_cast(&servaddr), sizeof(servaddr)) < 0) { + address_t bind_any; + if (bind(socket.fd, bind_any.get_socket_address(), bind_any.size()) < 0) { log(ll_warning, "Could not bind socket for IP discovery"); return ""; } - memset(&servaddr, 0, sizeof(servaddr)); - servaddr.sin_family = AF_INET; - servaddr.sin_port = htons(this->port); - servaddr.sin_addr.s_addr = inet_addr(this->ip.c_str()); - if (::connect(socket.fd, reinterpret_cast(&servaddr), sizeof(sockaddr_in)) < 0) { + address_t bind_port(this->ip, this->port); + if (::connect(socket.fd, bind_port.get_socket_address(), bind_port.size()) < 0) { log(ll_warning, "Could not connect socket for IP discovery"); return ""; } diff --git a/src/dpp/voice/enabled/handle_frame.cpp b/src/dpp/voice/enabled/handle_frame.cpp index 7b85cb7820..07423d3377 100644 --- a/src/dpp/voice/enabled/handle_frame.cpp +++ b/src/dpp/voice/enabled/handle_frame.cpp @@ -437,6 +437,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod this->ip = d["ip"].get(); this->port = d["port"].get(); this->ssrc = d["ssrc"].get(); + destination = address_t(this->ip, this->port); + // Modes for (auto & m : d["modes"]) { this->modes.push_back(m.get()); @@ -446,13 +448,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod dpp::socket newfd = 0; if ((newfd = ::socket(AF_INET, SOCK_DGRAM, 0)) >= 0) { - sockaddr_in servaddr{}; - memset(&servaddr, 0, sizeof(sockaddr_in)); - servaddr.sin_family = AF_INET; - servaddr.sin_addr.s_addr = htonl(INADDR_ANY); - servaddr.sin_port = htons(0); - - if (bind(newfd, reinterpret_cast(&servaddr), sizeof(servaddr)) < 0) { + address_t bind_any; + if (bind(newfd, bind_any.get_socket_address(), bind_any.size()) < 0) { throw dpp::connection_exception(err_bind_failure, "Can't bind() client UDP socket"); } @@ -467,13 +464,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod this->custom_writeable_ready = [this] { write_ready(); }; this->custom_readable_ready = [this] { read_ready(); }; - int bound_port = 0; - sockaddr_in sin{}; - socklen_t len = sizeof(sin); - if (getsockname(this->fd, reinterpret_cast(&sin), &len) > -1) { - bound_port = ntohs(sin.sin_port); - } - + int bound_port = address_t().get_port(this->fd); this->write(json({ { "op", voice_opcode_connection_select_protocol }, { "d", { diff --git a/src/dpp/voice/enabled/read_write.cpp b/src/dpp/voice/enabled/read_write.cpp index 1702e6cf13..c24200b49b 100644 --- a/src/dpp/voice/enabled/read_write.cpp +++ b/src/dpp/voice/enabled/read_write.cpp @@ -43,26 +43,23 @@ dpp::socket discord_voice_client::want_read() { void discord_voice_client::send(const char* packet, size_t len, uint64_t duration) { - std::lock_guard lock(this->stream_mutex); voice_out_packet frame; - frame.packet = std::string(packet, len); + frame.packet.assign(packet, packet + len); frame.duration = duration; - outbuf.emplace_back(frame); + { + std::lock_guard lock(this->stream_mutex); + outbuf.emplace_back(frame); + } } int discord_voice_client::udp_send(const char* data, size_t length) { - sockaddr_in servaddr{}; - memset(&servaddr, 0, sizeof(servaddr)); - servaddr.sin_family = AF_INET; - servaddr.sin_port = htons(this->port); - servaddr.sin_addr.s_addr = inet_addr(this->ip.c_str()); return static_cast(sendto( this->fd, data, static_cast(length), 0, - reinterpret_cast(&servaddr), - static_cast(sizeof(sockaddr_in)) + destination.get_socket_address(), + destination.size() )); }