diff --git a/src/dpp/dave/array_view.h b/src/dpp/dave/array_view.h index 2c2198cfa0..883e3e27c3 100755 --- a/src/dpp/dave/array_view.h +++ b/src/dpp/dave/array_view.h @@ -32,34 +32,34 @@ namespace dpp::dave { template class array_view { public: - array_view() = default; - array_view(T* data, size_t size) - : data_(data) - , size_(size) - { - } + array_view() = default; + array_view(T* data, size_t size) + : data_(data) + , size_(size) + { + } - size_t size() const { return size_; } - T* data() const { return data_; } + size_t size() const { return size_; } + T* data() const { return data_; } - T* begin() const { return data_; } - T* end() const { return data_ + size_; } + T* begin() const { return data_; } + T* end() const { return data_ + size_; } private: - T* data_ = nullptr; - size_t size_ = 0; + T* data_ = nullptr; + size_t size_ = 0; }; template inline array_view make_array_view(T* data, size_t size) { - return array_view(data, size); + return array_view(data, size); } template inline array_view make_array_view(std::vector& data) { - return array_view(data.data(), data.size()); + return array_view(data.data(), data.size()); } } // namespace dpp::dave diff --git a/src/dpp/dave/cipher_interface.cpp b/src/dpp/dave/cipher_interface.cpp index 2444cf4483..38be00a61d 100755 --- a/src/dpp/dave/cipher_interface.cpp +++ b/src/dpp/dave/cipher_interface.cpp @@ -29,8 +29,8 @@ namespace dpp::dave { std::unique_ptr create_cipher(const encryption_key& encryptionKey) { - auto cipher = std::make_unique(encryptionKey); - return cipher->is_valid() ? std::move(cipher) : nullptr; + auto cipher = std::make_unique(encryptionKey); + return cipher->is_valid() ? std::move(cipher) : nullptr; } } // namespace dpp::dave diff --git a/src/dpp/dave/clock.h b/src/dpp/dave/clock.h index 13207b9bde..743fd79ef4 100755 --- a/src/dpp/dave/clock.h +++ b/src/dpp/dave/clock.h @@ -30,17 +30,17 @@ namespace dpp::dave { class clock_interface { public: - using base_clock = std::chrono::steady_clock; - using time_point = base_clock::time_point; - using clock_duration = base_clock::duration; + using base_clock = std::chrono::steady_clock; + using time_point = base_clock::time_point; + using clock_duration = base_clock::duration; - virtual ~clock_interface() = default; - virtual time_point now() const = 0; + virtual ~clock_interface() = default; + virtual time_point now() const = 0; }; class clock : public clock_interface { public: - time_point now() const override { return base_clock::now(); } + time_point now() const override { return base_clock::now(); } }; } // namespace dpp::dave diff --git a/src/dpp/dave/codec_utils.cpp b/src/dpp/dave/codec_utils.cpp index f886b5f9fb..2fd01233d5 100755 --- a/src/dpp/dave/codec_utils.cpp +++ b/src/dpp/dave/codec_utils.cpp @@ -36,54 +36,54 @@ namespace dpp::dave::codec_utils { unencrypted_frame_header_size BytesCoveringH264PPS(const uint8_t* payload, const uint64_t sizeRemaining) { - // the payload starts with three exponential golomb encoded values - // (first_mb_in_slice, sps_id, pps_id) - // the depacketizer needs the pps_id unencrypted - // and the payload has RBSP encoding that we need to work around - - constexpr uint8_t kEmulationPreventionByte = 0x03; - - uint64_t payloadBitIndex = 0; - auto zeroBitCount = 0; - auto parsedExpGolombValues = 0; - - while (payloadBitIndex < sizeRemaining * 8 && parsedExpGolombValues < 3) { - auto bitIndex = payloadBitIndex % 8; - auto byteIndex = payloadBitIndex / 8; - auto payloadByte = payload[byteIndex]; - - // if we're starting a new byte - // check if this is an emulation prevention byte - // which we skip over - if (bitIndex == 0) { - if (byteIndex >= 2 && payloadByte == kEmulationPreventionByte && - payload[byteIndex - 1] == 0 && payload[byteIndex - 2] == 0) { - payloadBitIndex += 8; - continue; - } - } - - if ((payloadByte & (1 << (7 - bitIndex))) == 0) { - // still in the run of leading zero bits - ++zeroBitCount; - ++payloadBitIndex; - - if (zeroBitCount >= 32) { - assert(false && "Unexpectedly large exponential golomb encoded value"); - return 0; - } - } - else { - // we hit a one - // skip forward the number of bits dictated by the leading number of zeroes - parsedExpGolombValues += 1; - payloadBitIndex += 1 + zeroBitCount; - zeroBitCount = 0; - } - } - - // return the number of bytes that covers the last exp golomb encoded value - return (payloadBitIndex / 8) + 1; + // the payload starts with three exponential golomb encoded values + // (first_mb_in_slice, sps_id, pps_id) + // the depacketizer needs the pps_id unencrypted + // and the payload has RBSP encoding that we need to work around + + constexpr uint8_t kEmulationPreventionByte = 0x03; + + uint64_t payloadBitIndex = 0; + auto zeroBitCount = 0; + auto parsedExpGolombValues = 0; + + while (payloadBitIndex < sizeRemaining * 8 && parsedExpGolombValues < 3) { + auto bitIndex = payloadBitIndex % 8; + auto byteIndex = payloadBitIndex / 8; + auto payloadByte = payload[byteIndex]; + + // if we're starting a new byte + // check if this is an emulation prevention byte + // which we skip over + if (bitIndex == 0) { + if (byteIndex >= 2 && payloadByte == kEmulationPreventionByte && + payload[byteIndex - 1] == 0 && payload[byteIndex - 2] == 0) { + payloadBitIndex += 8; + continue; + } + } + + if ((payloadByte & (1 << (7 - bitIndex))) == 0) { + // still in the run of leading zero bits + ++zeroBitCount; + ++payloadBitIndex; + + if (zeroBitCount >= 32) { + assert(false && "Unexpectedly large exponential golomb encoded value"); + return 0; + } + } + else { + // we hit a one + // skip forward the number of bits dictated by the leading number of zeroes + parsedExpGolombValues += 1; + payloadBitIndex += 1 + zeroBitCount; + zeroBitCount = 0; + } + } + + // return the number of bytes that covers the last exp golomb encoded value + return (payloadBitIndex / 8) + 1; } const uint8_t kH26XNaluLongStartCode[] = {0, 0, 0, 1}; @@ -92,364 +92,364 @@ constexpr uint8_t kH26XNaluShortStartSequenceSize = 3; using IndexStartCodeSizePair = std::pair; std::optional FindNextH26XNaluIndex(const uint8_t* buffer, - const size_t bufferSize, - const size_t searchStartIndex = 0) + const size_t bufferSize, + const size_t searchStartIndex = 0) { - constexpr uint8_t kH26XStartCodeHighestPossibleValue = 1; - constexpr uint8_t kH26XStartCodeEndByteValue = 1; - constexpr uint8_t kH26XStartCodeLeadingBytesValue = 0; - - if (bufferSize < kH26XNaluShortStartSequenceSize) { - return std::nullopt; - } - - // look for NAL unit 3 or 4 byte start code - for (size_t i = searchStartIndex; i < bufferSize - kH26XNaluShortStartSequenceSize;) { - if (buffer[i + 2] > kH26XStartCodeHighestPossibleValue) { - // third byte is not 0 or 1, can't be a start code - i += kH26XNaluShortStartSequenceSize; - } - else if (buffer[i + 2] == kH26XStartCodeEndByteValue) { - // third byte matches the start code end byte, might be a start code sequence - if (buffer[i + 1] == kH26XStartCodeLeadingBytesValue && - buffer[i] == kH26XStartCodeLeadingBytesValue) { - // confirmed start sequence {0, 0, 1} - auto nalUnitStartIndex = i + kH26XNaluShortStartSequenceSize; - - if (i >= 1 && buffer[i - 1] == kH26XStartCodeLeadingBytesValue) { - // 4 byte start code - return std::optional({nalUnitStartIndex, 4}); - } - else { - // 3 byte start code - return std::optional({nalUnitStartIndex, 3}); - } - } - - i += kH26XNaluShortStartSequenceSize; - } - else { - // third byte is 0, might be a four byte start code - ++i; - } - } - - return std::nullopt; + constexpr uint8_t kH26XStartCodeHighestPossibleValue = 1; + constexpr uint8_t kH26XStartCodeEndByteValue = 1; + constexpr uint8_t kH26XStartCodeLeadingBytesValue = 0; + + if (bufferSize < kH26XNaluShortStartSequenceSize) { + return std::nullopt; + } + + // look for NAL unit 3 or 4 byte start code + for (size_t i = searchStartIndex; i < bufferSize - kH26XNaluShortStartSequenceSize;) { + if (buffer[i + 2] > kH26XStartCodeHighestPossibleValue) { + // third byte is not 0 or 1, can't be a start code + i += kH26XNaluShortStartSequenceSize; + } + else if (buffer[i + 2] == kH26XStartCodeEndByteValue) { + // third byte matches the start code end byte, might be a start code sequence + if (buffer[i + 1] == kH26XStartCodeLeadingBytesValue && + buffer[i] == kH26XStartCodeLeadingBytesValue) { + // confirmed start sequence {0, 0, 1} + auto nalUnitStartIndex = i + kH26XNaluShortStartSequenceSize; + + if (i >= 1 && buffer[i - 1] == kH26XStartCodeLeadingBytesValue) { + // 4 byte start code + return std::optional({nalUnitStartIndex, 4}); + } + else { + // 3 byte start code + return std::optional({nalUnitStartIndex, 3}); + } + } + + i += kH26XNaluShortStartSequenceSize; + } + else { + // third byte is 0, might be a four byte start code + ++i; + } + } + + return std::nullopt; } bool process_frame_opus(outbound_frame_processor& processor, array_view frame) { processor.add_encrypted_bytes(frame.data(), frame.size()); - return true; + return true; } bool process_frame_vp8(outbound_frame_processor& processor, array_view frame) { - constexpr uint8_t kVP8KeyFrameUnencryptedBytes = 10; - constexpr uint8_t kVP8DeltaFrameUnencryptedBytes = 1; - - // parse the VP8 payload header to determine if it's a key frame - // https://datatracker.ietf.org/doc/html/rfc7741#section-4.3 - - // 0 1 2 3 4 5 6 7 - // +-+-+-+-+-+-+-+-+ - // |Size0|H| VER |P| - // +-+-+-+-+-+-+-+-+ - // P is an inverse key frame flag - - // if this is a key frame the depacketizer will read 10 bytes into the payload header - // if this is a delta frame the depacketizer only needs the first byte of the payload - // header (since that's where the key frame flag is) - - size_t unencryptedHeaderBytes = 0; - if ((frame.data()[0] & 0x01) == 0) { - unencryptedHeaderBytes = kVP8KeyFrameUnencryptedBytes; - } - else { - unencryptedHeaderBytes = kVP8DeltaFrameUnencryptedBytes; - } + constexpr uint8_t kVP8KeyFrameUnencryptedBytes = 10; + constexpr uint8_t kVP8DeltaFrameUnencryptedBytes = 1; + + // parse the VP8 payload header to determine if it's a key frame + // https://datatracker.ietf.org/doc/html/rfc7741#section-4.3 + + // 0 1 2 3 4 5 6 7 + // +-+-+-+-+-+-+-+-+ + // |Size0|H| VER |P| + // +-+-+-+-+-+-+-+-+ + // P is an inverse key frame flag + + // if this is a key frame the depacketizer will read 10 bytes into the payload header + // if this is a delta frame the depacketizer only needs the first byte of the payload + // header (since that's where the key frame flag is) + + size_t unencryptedHeaderBytes = 0; + if ((frame.data()[0] & 0x01) == 0) { + unencryptedHeaderBytes = kVP8KeyFrameUnencryptedBytes; + } + else { + unencryptedHeaderBytes = kVP8DeltaFrameUnencryptedBytes; + } processor.add_unencrypted_bytes(frame.data(), unencryptedHeaderBytes); processor.add_encrypted_bytes(frame.data() + unencryptedHeaderBytes, - frame.size() - unencryptedHeaderBytes); - return true; + frame.size() - unencryptedHeaderBytes); + return true; } bool process_frame_vp9(outbound_frame_processor& processor, array_view frame) { - // payload descriptor is unencrypted in each packet - // and includes all information the depacketizer needs + // payload descriptor is unencrypted in each packet + // and includes all information the depacketizer needs processor.add_encrypted_bytes(frame.data(), frame.size()); - return true; + return true; } bool process_frame_h264(outbound_frame_processor& processor, array_view frame) { - // minimize the amount of unencrypted header data for H264 depending on the NAL unit - // type from WebRTC, see: src/modules/rtp_rtcp/source/rtp_format_h264.cc - // src/common_video/h264/h264_common.cc - // src/modules/rtp_rtcp/source/video_rtp_depacketizer_h264.cc - - // constexpr uint8_t kH264SBit = 0x80; - constexpr uint8_t kH264NalHeaderTypeMask = 0x1F; - constexpr uint8_t kH264NalTypeSlice = 1; - constexpr uint8_t kH264NalTypeIdr = 5; - constexpr uint8_t kH264NalUnitHeaderSize = 1; - - // this frame can be packetized as a STAP-A or a FU-A - // so we need to look at the first NAL units to determine how many bytes - // the packetizer/depacketizer will need into the payload - if (frame.size() < kH26XNaluShortStartSequenceSize + kH264NalUnitHeaderSize) { - assert(false && "H264 frame is too small to contain a NAL unit"); - DISCORD_LOG(LS_WARNING) << "H264 frame is too small to contain a NAL unit"; - return false; - } - - auto naluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size()); - while (naluIndexPair && naluIndexPair->first < frame.size() - 1) { - auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; - - auto nalType = frame.data()[nalUnitStartIndex] & kH264NalHeaderTypeMask; - - // copy the start code and then the NAL unit - - // Because WebRTC will convert them all start codes to 4-byte on the receiver side - // always write a long start code and then the NAL unit - processor.add_unencrypted_bytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); - - auto nextNaluIndexPair = - FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); - auto nextNaluStart = nextNaluIndexPair.has_value() - ? nextNaluIndexPair->first - nextNaluIndexPair->second - : frame.size(); - - if (nalType == kH264NalTypeSlice || nalType == kH264NalTypeIdr) { - // once we've hit a slice or an IDR - // we just need to cover getting to the PPS ID - auto nalUnitPayloadStart = nalUnitStartIndex + kH264NalUnitHeaderSize; - auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart, - frame.size() - nalUnitPayloadStart); + // minimize the amount of unencrypted header data for H264 depending on the NAL unit + // type from WebRTC, see: src/modules/rtp_rtcp/source/rtp_format_h264.cc + // src/common_video/h264/h264_common.cc + // src/modules/rtp_rtcp/source/video_rtp_depacketizer_h264.cc + + // constexpr uint8_t kH264SBit = 0x80; + constexpr uint8_t kH264NalHeaderTypeMask = 0x1F; + constexpr uint8_t kH264NalTypeSlice = 1; + constexpr uint8_t kH264NalTypeIdr = 5; + constexpr uint8_t kH264NalUnitHeaderSize = 1; + + // this frame can be packetized as a STAP-A or a FU-A + // so we need to look at the first NAL units to determine how many bytes + // the packetizer/depacketizer will need into the payload + if (frame.size() < kH26XNaluShortStartSequenceSize + kH264NalUnitHeaderSize) { + assert(false && "H264 frame is too small to contain a NAL unit"); + DISCORD_LOG(LS_WARNING) << "H264 frame is too small to contain a NAL unit"; + return false; + } + + auto naluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size()); + while (naluIndexPair && naluIndexPair->first < frame.size() - 1) { + auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; + + auto nalType = frame.data()[nalUnitStartIndex] & kH264NalHeaderTypeMask; + + // copy the start code and then the NAL unit + + // Because WebRTC will convert them all start codes to 4-byte on the receiver side + // always write a long start code and then the NAL unit + processor.add_unencrypted_bytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); + + auto nextNaluIndexPair = + FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); + auto nextNaluStart = nextNaluIndexPair.has_value() + ? nextNaluIndexPair->first - nextNaluIndexPair->second + : frame.size(); + + if (nalType == kH264NalTypeSlice || nalType == kH264NalTypeIdr) { + // once we've hit a slice or an IDR + // we just need to cover getting to the PPS ID + auto nalUnitPayloadStart = nalUnitStartIndex + kH264NalUnitHeaderSize; + auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart, + frame.size() - nalUnitPayloadStart); processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH264NalUnitHeaderSize + nalUnitPPSBytes); processor.add_encrypted_bytes( frame.data() + nalUnitStartIndex + kH264NalUnitHeaderSize + nalUnitPPSBytes, nextNaluStart - nalUnitStartIndex - kH264NalUnitHeaderSize - nalUnitPPSBytes); - } - else { - // copy the whole NAL unit + } + else { + // copy the whole NAL unit processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex); - } + } - naluIndexPair = nextNaluIndexPair; - } + naluIndexPair = nextNaluIndexPair; + } - return true; + return true; } bool process_frame_h265(outbound_frame_processor& processor, array_view frame) { - // minimize the amount of unencrypted header data for H265 depending on the NAL unit - // type from WebRTC, see: src/modules/rtp_rtcp/source/rtp_format_h265.cc - // src/common_video/h265/h265_common.cc - // src/modules/rtp_rtcp/source/video_rtp_depacketizer_h265.cc - - constexpr uint8_t kH265NalHeaderTypeMask = 0x7E; - constexpr uint8_t kH265NalTypeVclCutoff = 32; - constexpr uint8_t kH265NalUnitHeaderSize = 2; - - // this frame can be packetized as a STAP-A or a FU-A - // so we need to look at the first NAL units to determine how many bytes - // the packetizer/depacketizer will need into the payload - if (frame.size() < kH26XNaluShortStartSequenceSize + kH265NalUnitHeaderSize) { - assert(false && "H265 frame is too small to contain a NAL unit"); - DISCORD_LOG(LS_WARNING) << "H265 frame is too small to contain a NAL unit"; - return false; - } - - // look for NAL unit 3 or 4 byte start code - auto naluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size()); - while (naluIndexPair && naluIndexPair->first < frame.size() - 1) { - auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; - - uint8_t nalType = (frame.data()[nalUnitStartIndex] & kH265NalHeaderTypeMask) >> 1; - - // copy the start code and then the NAL unit - - // Because WebRTC will convert them all start codes to 4-byte on the receiver side - // always write a long start code and then the NAL unit - processor.add_unencrypted_bytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); - - auto nextNaluIndexPair = - FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); - auto nextNaluStart = nextNaluIndexPair.has_value() - ? nextNaluIndexPair->first - nextNaluIndexPair->second - : frame.size(); - - if (nalType < kH265NalTypeVclCutoff) { - // found a VCL NAL, encrypt the payload only + // minimize the amount of unencrypted header data for H265 depending on the NAL unit + // type from WebRTC, see: src/modules/rtp_rtcp/source/rtp_format_h265.cc + // src/common_video/h265/h265_common.cc + // src/modules/rtp_rtcp/source/video_rtp_depacketizer_h265.cc + + constexpr uint8_t kH265NalHeaderTypeMask = 0x7E; + constexpr uint8_t kH265NalTypeVclCutoff = 32; + constexpr uint8_t kH265NalUnitHeaderSize = 2; + + // this frame can be packetized as a STAP-A or a FU-A + // so we need to look at the first NAL units to determine how many bytes + // the packetizer/depacketizer will need into the payload + if (frame.size() < kH26XNaluShortStartSequenceSize + kH265NalUnitHeaderSize) { + assert(false && "H265 frame is too small to contain a NAL unit"); + DISCORD_LOG(LS_WARNING) << "H265 frame is too small to contain a NAL unit"; + return false; + } + + // look for NAL unit 3 or 4 byte start code + auto naluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size()); + while (naluIndexPair && naluIndexPair->first < frame.size() - 1) { + auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; + + uint8_t nalType = (frame.data()[nalUnitStartIndex] & kH265NalHeaderTypeMask) >> 1; + + // copy the start code and then the NAL unit + + // Because WebRTC will convert them all start codes to 4-byte on the receiver side + // always write a long start code and then the NAL unit + processor.add_unencrypted_bytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); + + auto nextNaluIndexPair = + FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); + auto nextNaluStart = nextNaluIndexPair.has_value() + ? nextNaluIndexPair->first - nextNaluIndexPair->second + : frame.size(); + + if (nalType < kH265NalTypeVclCutoff) { + // found a VCL NAL, encrypt the payload only processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH265NalUnitHeaderSize); processor.add_encrypted_bytes(frame.data() + nalUnitStartIndex + kH265NalUnitHeaderSize, - nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize); - } - else { - // copy the whole NAL unit + nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize); + } + else { + // copy the whole NAL unit processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex); - } + } - naluIndexPair = nextNaluIndexPair; - } + naluIndexPair = nextNaluIndexPair; + } - return true; + return true; } bool process_frame_av1(outbound_frame_processor& processor, array_view frame) { - constexpr uint8_t kAv1ObuHeaderHasExtensionMask = 0b0'0000'100; - constexpr uint8_t kAv1ObuHeaderHasSizeMask = 0b0'0000'010; - constexpr uint8_t kAv1ObuHeaderTypeMask = 0b0'1111'000; - constexpr uint8_t kObuTypeTemporalDelimiter = 2; - constexpr uint8_t kObuTypeTileList = 8; - constexpr uint8_t kObuTypePadding = 15; - constexpr uint8_t kObuExtensionSizeBytes = 1; - - size_t i = 0; - while (i < frame.size()) { - // Read the OBU header. - size_t obuHeaderIndex = i; - uint8_t obuHeader = frame.data()[obuHeaderIndex]; - i += sizeof(obuHeader); - - bool obuHasExtension = obuHeader & kAv1ObuHeaderHasExtensionMask; - bool obuHasSize = obuHeader & kAv1ObuHeaderHasSizeMask; - int obuType = (obuHeader & kAv1ObuHeaderTypeMask) >> 3; - - if (obuHasExtension) { - // Skip extension byte - i += kObuExtensionSizeBytes; - } - - if (i >= frame.size()) { - // Malformed frame - assert(false && "Malformed AV1 frame: header overflows frame"); - DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: header overflows frame"; - return false; - } - - size_t obuPayloadSize = 0; - if (obuHasSize) { - // Read payload size - const uint8_t* start = frame.data() + i; - const uint8_t* ptr = start; - obuPayloadSize = ReadLeb128(ptr, frame.end()); - if (!ptr) { - // Malformed frame - assert(false && "Malformed AV1 frame: invalid LEB128 size"); - DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: invalid LEB128 size"; - return false; - } - i += ptr - start; - } - else { - // If the size is not present, the OBU extends to the end of the frame. - obuPayloadSize = frame.size() - i; - } - - const auto obuPayloadIndex = i; - - if (i + obuPayloadSize > frame.size()) { - // Malformed frame - assert(false && "Malformed AV1 frame: payload overflows frame"); - DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: payload overflows frame"; - return false; - } - - i += obuPayloadSize; - - // We only copy the OBUs that will not get dropped by the packetizer - if (obuType != kObuTypeTemporalDelimiter && obuType != kObuTypeTileList && - obuType != kObuTypePadding) { - // if this is the last OBU, we may need to flip the "has size" bit - // which allows us to append necessary protocol data to the frame - bool rewrittenWithoutSize = false; - - if (i == frame.size() && obuHasSize) { - // Flip the "has size" bit - obuHeader &= ~kAv1ObuHeaderHasSizeMask; - rewrittenWithoutSize = true; - } - - // write the OBU header unencrypted + constexpr uint8_t kAv1ObuHeaderHasExtensionMask = 0b0'0000'100; + constexpr uint8_t kAv1ObuHeaderHasSizeMask = 0b0'0000'010; + constexpr uint8_t kAv1ObuHeaderTypeMask = 0b0'1111'000; + constexpr uint8_t kObuTypeTemporalDelimiter = 2; + constexpr uint8_t kObuTypeTileList = 8; + constexpr uint8_t kObuTypePadding = 15; + constexpr uint8_t kObuExtensionSizeBytes = 1; + + size_t i = 0; + while (i < frame.size()) { + // Read the OBU header. + size_t obuHeaderIndex = i; + uint8_t obuHeader = frame.data()[obuHeaderIndex]; + i += sizeof(obuHeader); + + bool obuHasExtension = obuHeader & kAv1ObuHeaderHasExtensionMask; + bool obuHasSize = obuHeader & kAv1ObuHeaderHasSizeMask; + int obuType = (obuHeader & kAv1ObuHeaderTypeMask) >> 3; + + if (obuHasExtension) { + // Skip extension byte + i += kObuExtensionSizeBytes; + } + + if (i >= frame.size()) { + // Malformed frame + assert(false && "Malformed AV1 frame: header overflows frame"); + DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: header overflows frame"; + return false; + } + + size_t obuPayloadSize = 0; + if (obuHasSize) { + // Read payload size + const uint8_t* start = frame.data() + i; + const uint8_t* ptr = start; + obuPayloadSize = ReadLeb128(ptr, frame.end()); + if (!ptr) { + // Malformed frame + assert(false && "Malformed AV1 frame: invalid LEB128 size"); + DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: invalid LEB128 size"; + return false; + } + i += ptr - start; + } + else { + // If the size is not present, the OBU extends to the end of the frame. + obuPayloadSize = frame.size() - i; + } + + const auto obuPayloadIndex = i; + + if (i + obuPayloadSize > frame.size()) { + // Malformed frame + assert(false && "Malformed AV1 frame: payload overflows frame"); + DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: payload overflows frame"; + return false; + } + + i += obuPayloadSize; + + // We only copy the OBUs that will not get dropped by the packetizer + if (obuType != kObuTypeTemporalDelimiter && obuType != kObuTypeTileList && + obuType != kObuTypePadding) { + // if this is the last OBU, we may need to flip the "has size" bit + // which allows us to append necessary protocol data to the frame + bool rewrittenWithoutSize = false; + + if (i == frame.size() && obuHasSize) { + // Flip the "has size" bit + obuHeader &= ~kAv1ObuHeaderHasSizeMask; + rewrittenWithoutSize = true; + } + + // write the OBU header unencrypted processor.add_unencrypted_bytes(&obuHeader, sizeof(obuHeader)); - if (obuHasExtension) { - // write the extension byte unencrypted - processor.add_unencrypted_bytes(frame.data() + obuHeaderIndex + sizeof(obuHeader), - kObuExtensionSizeBytes); - } - - // write the OBU payload size unencrypted if it was present and we didn't rewrite - // without it - if (obuHasSize && !rewrittenWithoutSize) { - // The AMD AV1 encoder may pad LEB128 encoded sizes with a zero byte which the - // webrtc packetizer removes. To prevent the packetizer from changing the frame, - // we sanitize the size by re-writing it ourselves - uint8_t leb128Buffer[Leb128MaxSize]; - size_t additionalBytesToWrite = WriteLeb128(obuPayloadSize, leb128Buffer); - processor.add_unencrypted_bytes(leb128Buffer, additionalBytesToWrite); - } - - // add the OBU payload, encrypted + if (obuHasExtension) { + // write the extension byte unencrypted + processor.add_unencrypted_bytes(frame.data() + obuHeaderIndex + sizeof(obuHeader), + kObuExtensionSizeBytes); + } + + // write the OBU payload size unencrypted if it was present and we didn't rewrite + // without it + if (obuHasSize && !rewrittenWithoutSize) { + // The AMD AV1 encoder may pad LEB128 encoded sizes with a zero byte which the + // webrtc packetizer removes. To prevent the packetizer from changing the frame, + // we sanitize the size by re-writing it ourselves + uint8_t leb128Buffer[Leb128MaxSize]; + size_t additionalBytesToWrite = WriteLeb128(obuPayloadSize, leb128Buffer); + processor.add_unencrypted_bytes(leb128Buffer, additionalBytesToWrite); + } + + // add the OBU payload, encrypted processor.add_encrypted_bytes(frame.data() + obuPayloadIndex, obuPayloadSize); - } - } + } + } - return true; + return true; } bool validate_encrypted_frame(outbound_frame_processor& processor, array_view frame) { - auto codec = processor.get_codec(); - if (codec != Codec::H264 && codec != Codec::H265) { - return true; - } - - static_assert(kH26XNaluShortStartSequenceSize - 1 >= 0, "Padding will overflow!"); - constexpr size_t Padding = kH26XNaluShortStartSequenceSize - 1; - - const auto& unencryptedRanges = processor.get_unencrypted_ranges(); - - // H264 and H265 ciphertexts cannot contain a 3 or 4 byte start code {0, 0, 1} - // otherwise the packetizer gets confused - // and the frame we get on the decryption side will be shifted and fail to decrypt - size_t encryptedSectionStart = 0; - for (auto& range : unencryptedRanges) { - if (encryptedSectionStart == range.offset) { - encryptedSectionStart += range.size; - continue; - } - - auto start = encryptedSectionStart - std::min(encryptedSectionStart, size_t{Padding}); - auto end = std::min(range.offset + Padding, frame.size()); - if (FindNextH26XNaluIndex(frame.data() + start, end - start)) { - return false; - } - - encryptedSectionStart = range.offset + range.size; - } - - if (encryptedSectionStart == frame.size()) { - return true; - } - - auto start = encryptedSectionStart - std::min(encryptedSectionStart, size_t{Padding}); - auto end = frame.size(); - if (FindNextH26XNaluIndex(frame.data() + start, end - start)) { - return false; - } - - return true; + auto codec = processor.get_codec(); + if (codec != Codec::H264 && codec != Codec::H265) { + return true; + } + + static_assert(kH26XNaluShortStartSequenceSize - 1 >= 0, "Padding will overflow!"); + constexpr size_t Padding = kH26XNaluShortStartSequenceSize - 1; + + const auto& unencryptedRanges = processor.get_unencrypted_ranges(); + + // H264 and H265 ciphertexts cannot contain a 3 or 4 byte start code {0, 0, 1} + // otherwise the packetizer gets confused + // and the frame we get on the decryption side will be shifted and fail to decrypt + size_t encryptedSectionStart = 0; + for (auto& range : unencryptedRanges) { + if (encryptedSectionStart == range.offset) { + encryptedSectionStart += range.size; + continue; + } + + auto start = encryptedSectionStart - std::min(encryptedSectionStart, size_t{Padding}); + auto end = std::min(range.offset + Padding, frame.size()); + if (FindNextH26XNaluIndex(frame.data() + start, end - start)) { + return false; + } + + encryptedSectionStart = range.offset + range.size; + } + + if (encryptedSectionStart == frame.size()) { + return true; + } + + auto start = encryptedSectionStart - std::min(encryptedSectionStart, size_t{Padding}); + auto end = frame.size(); + if (FindNextH26XNaluIndex(frame.data() + start, end - start)) { + return false; + } + + return true; } } // namespace dpp::dave::codec_utils diff --git a/src/dpp/dave/common.h b/src/dpp/dave/common.h index 6700a97939..df94ab1a18 100755 --- a/src/dpp/dave/common.h +++ b/src/dpp/dave/common.h @@ -93,17 +93,17 @@ constexpr std::array OPUS_SILENCE_PACKET = {0xF8, 0xFF, 0xFE}; // Utility routine for variant return types template inline std::optional get_optional(V&& variant) { - if (auto map = std::get_if(&variant)) { - if constexpr (std::is_rvalue_reference_v) { - return std::move(*map); - } - else { - return *map; - } - } - else { - return std::nullopt; - } + if (auto map = std::get_if(&variant)) { + if constexpr (std::is_rvalue_reference_v) { + return std::move(*map); + } + else { + return *map; + } + } + else { + return std::nullopt; + } } } // namespace dpp::dave diff --git a/src/dpp/dave/cryptor_manager.cpp b/src/dpp/dave/cryptor_manager.cpp index 3433262f13..fc387e5458 100755 --- a/src/dpp/dave/cryptor_manager.cpp +++ b/src/dpp/dave/cryptor_manager.cpp @@ -37,19 +37,19 @@ namespace dpp::dave { KeyGeneration compute_wrapped_generation(KeyGeneration oldest, KeyGeneration generation) { - // Assume generation is greater than or equal to oldest, this may be wrong in a few cases but - // will be caught by the max generation gap check. - auto remainder = oldest % GENERATION_WRAP; - auto factor = oldest / GENERATION_WRAP + (generation < remainder ? 1 : 0); - return factor * GENERATION_WRAP + generation; + // Assume generation is greater than or equal to oldest, this may be wrong in a few cases but + // will be caught by the max generation gap check. + auto remainder = oldest % GENERATION_WRAP; + auto factor = oldest / GENERATION_WRAP + (generation < remainder ? 1 : 0); + return factor * GENERATION_WRAP + generation; } big_nonce compute_wrapped_big_nonce(KeyGeneration generation, truncated_sync_nonce nonce) { - // Remove the generation bits from the nonce - auto maskedNonce = nonce & ((1 << RATCHET_GENERATION_SHIFT_BITS) - 1); - // Add the wrapped generation bits back in - return static_cast(generation) << RATCHET_GENERATION_SHIFT_BITS | maskedNonce; + // Remove the generation bits from the nonce + auto maskedNonce = nonce & ((1 << RATCHET_GENERATION_SHIFT_BITS) - 1); + // Add the wrapped generation bits back in + return static_cast(generation) << RATCHET_GENERATION_SHIFT_BITS | maskedNonce; } aead_cipher_manager::aead_cipher_manager(const clock_interface& clock, std::unique_ptr keyRatchet) @@ -62,145 +62,145 @@ aead_cipher_manager::aead_cipher_manager(const clock_interface& clock, std::uniq bool aead_cipher_manager::can_process_nonce(KeyGeneration generation, truncated_sync_nonce nonce) const { - if (!newestProcessedNonce_) { - return true; - } + if (!newestProcessedNonce_) { + return true; + } - auto bigNonce = compute_wrapped_big_nonce(generation, nonce); - return bigNonce > *newestProcessedNonce_ || - std::find(missingNonces_.rbegin(), missingNonces_.rend(), bigNonce) != missingNonces_.rend(); + auto bigNonce = compute_wrapped_big_nonce(generation, nonce); + return bigNonce > *newestProcessedNonce_ || + std::find(missingNonces_.rbegin(), missingNonces_.rend(), bigNonce) != missingNonces_.rend(); } cipher_interface* aead_cipher_manager::get_cipher(KeyGeneration generation) { cleanup_expired_ciphers(); - if (generation < oldestGeneration_) { - DISCORD_LOG(LS_INFO) << "Received frame with old generation: " << generation - << ", oldest generation: " << oldestGeneration_; - return nullptr; - } - - if (generation > newestGeneration_ + MAX_GENERATION_GAP) { - DISCORD_LOG(LS_INFO) << "Received frame with future generation: " << generation - << ", newest generation: " << newestGeneration_; - return nullptr; - } - - auto ratchetLifetimeSec = - std::chrono::duration_cast(clock_.now() - ratchetCreation_).count(); - auto maxLifetimeFrames = MAX_FRAMES_PER_SECOND * ratchetLifetimeSec; - auto maxLifetimeGenerations = maxLifetimeFrames >> RATCHET_GENERATION_SHIFT_BITS; - if (generation > maxLifetimeGenerations) { - DISCORD_LOG(LS_INFO) << "Received frame with generation " << generation - << " beyond ratchet max lifetime generations: " - << maxLifetimeGenerations - << ", ratchet lifetime: " << ratchetLifetimeSec << "s"; - return nullptr; - } - - auto it = cryptors_.find(generation); - if (it == cryptors_.end()) { - // We don't have a cryptor for this generation, create one - std::tie(it, std::ignore) = cryptors_.emplace(generation, make_expiring_cipher(generation)); - } - - // Return a non-owning pointer to the cryptor - auto& [cryptor, expiry] = it->second; - return cryptor.get(); + if (generation < oldestGeneration_) { + DISCORD_LOG(LS_INFO) << "Received frame with old generation: " << generation + << ", oldest generation: " << oldestGeneration_; + return nullptr; + } + + if (generation > newestGeneration_ + MAX_GENERATION_GAP) { + DISCORD_LOG(LS_INFO) << "Received frame with future generation: " << generation + << ", newest generation: " << newestGeneration_; + return nullptr; + } + + auto ratchetLifetimeSec = + std::chrono::duration_cast(clock_.now() - ratchetCreation_).count(); + auto maxLifetimeFrames = MAX_FRAMES_PER_SECOND * ratchetLifetimeSec; + auto maxLifetimeGenerations = maxLifetimeFrames >> RATCHET_GENERATION_SHIFT_BITS; + if (generation > maxLifetimeGenerations) { + DISCORD_LOG(LS_INFO) << "Received frame with generation " << generation + << " beyond ratchet max lifetime generations: " + << maxLifetimeGenerations + << ", ratchet lifetime: " << ratchetLifetimeSec << "s"; + return nullptr; + } + + auto it = cryptors_.find(generation); + if (it == cryptors_.end()) { + // We don't have a cryptor for this generation, create one + std::tie(it, std::ignore) = cryptors_.emplace(generation, make_expiring_cipher(generation)); + } + + // Return a non-owning pointer to the cryptor + auto& [cryptor, expiry] = it->second; + return cryptor.get(); } void aead_cipher_manager::report_cipher_success(KeyGeneration generation, truncated_sync_nonce nonce) { - auto bigNonce = compute_wrapped_big_nonce(generation, nonce); - - // Add any missing nonces to the queue - if (!newestProcessedNonce_) { - newestProcessedNonce_ = bigNonce; - } - else if (bigNonce > *newestProcessedNonce_) { - auto oldestMissingNonce = bigNonce > MAX_MISSING_NONCES ? bigNonce - MAX_MISSING_NONCES : 0; - - while (!missingNonces_.empty() && missingNonces_.front() < oldestMissingNonce) { - missingNonces_.pop_front(); - } - - // If we're missing a lot, we don't want to add everything since newestProcessedNonce_ - auto missingRangeStart = std::max(oldestMissingNonce, *newestProcessedNonce_ + 1); - for (auto i = missingRangeStart; i < bigNonce; ++i) { - missingNonces_.push_back(i); - } - - // Update the newest processed nonce - newestProcessedNonce_ = bigNonce; - } - else { - auto it = std::find(missingNonces_.begin(), missingNonces_.end(), bigNonce); - if (it != missingNonces_.end()) { - missingNonces_.erase(it); - } - } - - if (generation <= newestGeneration_ || cryptors_.find(generation) == cryptors_.end()) { - return; - } - DISCORD_LOG(LS_INFO) << "Reporting cryptor success, generation: " << generation; - newestGeneration_ = generation; - - // Update the expiry time for all old cryptors - const auto expiryTime = clock_.now() + CIPHER_EXPIRY; - for (auto& [gen, cryptor] : cryptors_) { - if (gen < newestGeneration_) { - DISCORD_LOG(LS_INFO) << "Updating expiry for cryptor, generation: " << gen; - cryptor.expiry = std::min(cryptor.expiry, expiryTime); - } - } + auto bigNonce = compute_wrapped_big_nonce(generation, nonce); + + // Add any missing nonces to the queue + if (!newestProcessedNonce_) { + newestProcessedNonce_ = bigNonce; + } + else if (bigNonce > *newestProcessedNonce_) { + auto oldestMissingNonce = bigNonce > MAX_MISSING_NONCES ? bigNonce - MAX_MISSING_NONCES : 0; + + while (!missingNonces_.empty() && missingNonces_.front() < oldestMissingNonce) { + missingNonces_.pop_front(); + } + + // If we're missing a lot, we don't want to add everything since newestProcessedNonce_ + auto missingRangeStart = std::max(oldestMissingNonce, *newestProcessedNonce_ + 1); + for (auto i = missingRangeStart; i < bigNonce; ++i) { + missingNonces_.push_back(i); + } + + // Update the newest processed nonce + newestProcessedNonce_ = bigNonce; + } + else { + auto it = std::find(missingNonces_.begin(), missingNonces_.end(), bigNonce); + if (it != missingNonces_.end()) { + missingNonces_.erase(it); + } + } + + if (generation <= newestGeneration_ || cryptors_.find(generation) == cryptors_.end()) { + return; + } + DISCORD_LOG(LS_INFO) << "Reporting cryptor success, generation: " << generation; + newestGeneration_ = generation; + + // Update the expiry time for all old cryptors + const auto expiryTime = clock_.now() + CIPHER_EXPIRY; + for (auto& [gen, cryptor] : cryptors_) { + if (gen < newestGeneration_) { + DISCORD_LOG(LS_INFO) << "Updating expiry for cryptor, generation: " << gen; + cryptor.expiry = std::min(cryptor.expiry, expiryTime); + } + } } KeyGeneration aead_cipher_manager::compute_wrapped_generation(KeyGeneration generation) { - return ::dpp::dave::compute_wrapped_generation(oldestGeneration_, generation); + return ::dpp::dave::compute_wrapped_generation(oldestGeneration_, generation); } aead_cipher_manager::expiring_cipher aead_cipher_manager::make_expiring_cipher(KeyGeneration generation) { - // Get the new key from the ratchet - auto encryptionKey = keyRatchet_->GetKey(generation); - auto expiryTime = time_point::max(); - - // If we got frames out of order, we might have to create a cryptor for an old generation - // In that case, create it with a non-infinite expiry time as we have already transitioned - // to a newer generation - if (generation < newestGeneration_) { - DISCORD_LOG(LS_INFO) << "Creating cryptor for old generation: " << generation; - expiryTime = clock_.now() + CIPHER_EXPIRY; - } - else { - DISCORD_LOG(LS_INFO) << "Creating cryptor for new generation: " << generation; - } - - return {create_cipher(encryptionKey), expiryTime}; + // Get the new key from the ratchet + auto encryptionKey = keyRatchet_->GetKey(generation); + auto expiryTime = time_point::max(); + + // If we got frames out of order, we might have to create a cryptor for an old generation + // In that case, create it with a non-infinite expiry time as we have already transitioned + // to a newer generation + if (generation < newestGeneration_) { + DISCORD_LOG(LS_INFO) << "Creating cryptor for old generation: " << generation; + expiryTime = clock_.now() + CIPHER_EXPIRY; + } + else { + DISCORD_LOG(LS_INFO) << "Creating cryptor for new generation: " << generation; + } + + return {create_cipher(encryptionKey), expiryTime}; } void aead_cipher_manager::cleanup_expired_ciphers() { - for (auto it = cryptors_.begin(); it != cryptors_.end();) { - auto& [generation, cryptor] = *it; - - bool expired = cryptor.expiry < clock_.now(); - if (expired) { - DISCORD_LOG(LS_INFO) << "Removing expired cryptor, generation: " << generation; - } - - it = expired ? cryptors_.erase(it) : ++it; - } - - while (oldestGeneration_ < newestGeneration_ && - cryptors_.find(oldestGeneration_) == cryptors_.end()) { - DISCORD_LOG(LS_INFO) << "Deleting key for old generation: " << oldestGeneration_; - keyRatchet_->DeleteKey(oldestGeneration_); - ++oldestGeneration_; - } + for (auto it = cryptors_.begin(); it != cryptors_.end();) { + auto& [generation, cryptor] = *it; + + bool expired = cryptor.expiry < clock_.now(); + if (expired) { + DISCORD_LOG(LS_INFO) << "Removing expired cryptor, generation: " << generation; + } + + it = expired ? cryptors_.erase(it) : ++it; + } + + while (oldestGeneration_ < newestGeneration_ && + cryptors_.find(oldestGeneration_) == cryptors_.end()) { + DISCORD_LOG(LS_INFO) << "Deleting key for old generation: " << oldestGeneration_; + keyRatchet_->DeleteKey(oldestGeneration_); + ++oldestGeneration_; + } } } // namespace dpp::dave diff --git a/src/dpp/dave/cryptor_manager.h b/src/dpp/dave/cryptor_manager.h index 4c950b852d..ebf6a5052d 100755 --- a/src/dpp/dave/cryptor_manager.h +++ b/src/dpp/dave/cryptor_manager.h @@ -43,39 +43,39 @@ big_nonce compute_wrapped_big_nonce(KeyGeneration generation, truncated_sync_non class aead_cipher_manager { public: - using time_point = typename clock_interface::time_point; + using time_point = typename clock_interface::time_point; - aead_cipher_manager(const clock_interface& clock, std::unique_ptr keyRatchet); + aead_cipher_manager(const clock_interface& clock, std::unique_ptr keyRatchet); - void update_expiry(time_point expiry) { ratchetExpiry_ = expiry; } - bool is_expired() const { return clock_.now() > ratchetExpiry_; } + void update_expiry(time_point expiry) { ratchetExpiry_ = expiry; } + bool is_expired() const { return clock_.now() > ratchetExpiry_; } - bool can_process_nonce(KeyGeneration generation, truncated_sync_nonce nonce) const; - KeyGeneration compute_wrapped_generation(KeyGeneration generation); + bool can_process_nonce(KeyGeneration generation, truncated_sync_nonce nonce) const; + KeyGeneration compute_wrapped_generation(KeyGeneration generation); - cipher_interface* get_cipher(KeyGeneration generation); - void report_cipher_success(KeyGeneration generation, truncated_sync_nonce nonce); + cipher_interface* get_cipher(KeyGeneration generation); + void report_cipher_success(KeyGeneration generation, truncated_sync_nonce nonce); private: - struct expiring_cipher { - std::unique_ptr cryptor; - time_point expiry; - }; + struct expiring_cipher { + std::unique_ptr cryptor; + time_point expiry; + }; - expiring_cipher make_expiring_cipher(KeyGeneration generation); - void cleanup_expired_ciphers(); + expiring_cipher make_expiring_cipher(KeyGeneration generation); + void cleanup_expired_ciphers(); - const clock_interface& clock_; - std::unique_ptr keyRatchet_; - std::unordered_map cryptors_; + const clock_interface& clock_; + std::unique_ptr keyRatchet_; + std::unordered_map cryptors_; - time_point ratchetCreation_; - time_point ratchetExpiry_; - KeyGeneration oldestGeneration_{0}; - KeyGeneration newestGeneration_{0}; + time_point ratchetCreation_; + time_point ratchetExpiry_; + KeyGeneration oldestGeneration_{0}; + KeyGeneration newestGeneration_{0}; - std::optional newestProcessedNonce_; - std::deque missingNonces_; + std::optional newestProcessedNonce_; + std::deque missingNonces_; }; } // namespace dpp::dave diff --git a/src/dpp/dave/decryptor.cpp b/src/dpp/dave/decryptor.cpp index da96b30100..a91a922b4b 100755 --- a/src/dpp/dave/decryptor.cpp +++ b/src/dpp/dave/decryptor.cpp @@ -41,203 +41,203 @@ constexpr auto kStatsInterval = 10s; void decryptor::transition_to_key_ratchet(std::unique_ptr keyRatchet, Duration transitionExpiry) { - DISCORD_LOG(LS_INFO) << "Transitioning to new key ratchet: " << keyRatchet.get() - << ", expiry: " << transitionExpiry.count(); + DISCORD_LOG(LS_INFO) << "Transitioning to new key ratchet: " << keyRatchet.get() + << ", expiry: " << transitionExpiry.count(); - // Update the expiry time for all existing cryptor managers + // Update the expiry time for all existing cryptor managers update_cryptor_manager_expiry(transitionExpiry); - if (keyRatchet) { - cryptorManagers_.emplace_back(clock_, std::move(keyRatchet)); - } + if (keyRatchet) { + cryptorManagers_.emplace_back(clock_, std::move(keyRatchet)); + } } void decryptor::transition_to_passthrough_mode(bool passthroughMode, Duration transitionExpiry) { - if (passthroughMode) { - allowPassThroughUntil_ = time_point::max(); - } - else { - // Update the pass through mode expiry - auto maxExpiry = clock_.now() + transitionExpiry; - allowPassThroughUntil_ = std::min(allowPassThroughUntil_, maxExpiry); - } + if (passthroughMode) { + allowPassThroughUntil_ = time_point::max(); + } + else { + // Update the pass through mode expiry + auto maxExpiry = clock_.now() + transitionExpiry; + allowPassThroughUntil_ = std::min(allowPassThroughUntil_, maxExpiry); + } } size_t decryptor::decrypt(media_type mediaType, array_view encryptedFrame, array_view frame) { - if (mediaType != media_audio && mediaType != media_video) { - DISCORD_LOG(LS_WARNING) << "decrypt failed, invalid media type: " - << static_cast(mediaType); - return 0; - } + if (mediaType != media_audio && mediaType != media_video) { + DISCORD_LOG(LS_WARNING) << "decrypt failed, invalid media type: " + << static_cast(mediaType); + return 0; + } - auto start = clock_.now(); + auto start = clock_.now(); - auto localFrame = get_or_create_frame_processor(); - ScopeExit cleanup([&] { return_frame_processor(std::move(localFrame)); }); + auto localFrame = get_or_create_frame_processor(); + ScopeExit cleanup([&] { return_frame_processor(std::move(localFrame)); }); - // Skip decrypting for silence frames - if (mediaType == media_audio && encryptedFrame.size() == OPUS_SILENCE_PACKET.size() && + // Skip decrypting for silence frames + if (mediaType == media_audio && encryptedFrame.size() == OPUS_SILENCE_PACKET.size() && std::memcmp(encryptedFrame.data(), OPUS_SILENCE_PACKET.data(), OPUS_SILENCE_PACKET.size()) == 0) { - DISCORD_LOG(LS_VERBOSE) << "decrypt skipping silence of size: " << encryptedFrame.size(); - if (encryptedFrame.data() != frame.data()) { + DISCORD_LOG(LS_VERBOSE) << "decrypt skipping silence of size: " << encryptedFrame.size(); + if (encryptedFrame.data() != frame.data()) { std::memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); - } - return encryptedFrame.size(); - } + } + return encryptedFrame.size(); + } - // Remove any expired cryptor manager + // Remove any expired cryptor manager cleanup_expired_cryptor_managers(); - // Process the incoming frame - // This will check whether it looks like a valid encrypted frame - // and if so it will parse it into its different components + // Process the incoming frame + // This will check whether it looks like a valid encrypted frame + // and if so it will parse it into its different components localFrame->parse_frame(encryptedFrame); - // If the frame is not encrypted and we can pass it through, do it - bool canUsePassThrough = allowPassThroughUntil_ > start; - if (!localFrame->is_encrypted() && canUsePassThrough) { - if (encryptedFrame.data() != frame.data()) { + // If the frame is not encrypted and we can pass it through, do it + bool canUsePassThrough = allowPassThroughUntil_ > start; + if (!localFrame->is_encrypted() && canUsePassThrough) { + if (encryptedFrame.data() != frame.data()) { std::memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); - } - stats_[mediaType].passthroughs++; - return encryptedFrame.size(); - } - - // If the frame is not encrypted and we can't pass it through, fail - if (!localFrame->is_encrypted()) { - DISCORD_LOG(LS_INFO) - << "decrypt failed, frame is not encrypted and pass through is disabled"; - stats_[mediaType].decrypt_failure++; - return 0; - } - - // Try and decrypt with each valid cryptor - // reverse iterate to try the newest cryptors first - bool success = false; - for (auto it = cryptorManagers_.rbegin(); it != cryptorManagers_.rend(); ++it) { - auto& cryptorManager = *it; - success = decrypt_impl(cryptorManager, mediaType, *localFrame, frame); - if (success) { - break; - } - } - - size_t bytesWritten = 0; - if (success) { - stats_[mediaType].decrypt_success++; - bytesWritten = localFrame->reconstruct_frame(frame); - } - else { - stats_[mediaType].decrypt_failure++; - DISCORD_LOG(LS_WARNING) << "decrypt failed, no valid cryptor found, type: " - << (mediaType ? "video" : "audio") - << ", encrypted frame size: " << encryptedFrame.size() - << ", plaintext frame size: " << frame.size() - << ", number of cryptor managers: " << cryptorManagers_.size() - << ", pass through enabled: " << (canUsePassThrough ? "yes" : "no"); - } - - auto end = clock_.now(); - if (end > lastStatsTime_ + kStatsInterval) { - lastStatsTime_ = end; - DISCORD_LOG(LS_INFO) << "Decrypted audio: " << stats_[media_audio].decrypt_success - << ", video: " << stats_[media_video].decrypt_success - << ". Failed audio: " << stats_[media_audio].decrypt_failure - << ", video: " << stats_[media_video].decrypt_failure; - } - stats_[mediaType].decrypt_duration += - std::chrono::duration_cast(end - start).count(); - - return bytesWritten; + } + stats_[mediaType].passthroughs++; + return encryptedFrame.size(); + } + + // If the frame is not encrypted and we can't pass it through, fail + if (!localFrame->is_encrypted()) { + DISCORD_LOG(LS_INFO) + << "decrypt failed, frame is not encrypted and pass through is disabled"; + stats_[mediaType].decrypt_failure++; + return 0; + } + + // Try and decrypt with each valid cryptor + // reverse iterate to try the newest cryptors first + bool success = false; + for (auto it = cryptorManagers_.rbegin(); it != cryptorManagers_.rend(); ++it) { + auto& cryptorManager = *it; + success = decrypt_impl(cryptorManager, mediaType, *localFrame, frame); + if (success) { + break; + } + } + + size_t bytesWritten = 0; + if (success) { + stats_[mediaType].decrypt_success++; + bytesWritten = localFrame->reconstruct_frame(frame); + } + else { + stats_[mediaType].decrypt_failure++; + DISCORD_LOG(LS_WARNING) << "decrypt failed, no valid cryptor found, type: " + << (mediaType ? "video" : "audio") + << ", encrypted frame size: " << encryptedFrame.size() + << ", plaintext frame size: " << frame.size() + << ", number of cryptor managers: " << cryptorManagers_.size() + << ", pass through enabled: " << (canUsePassThrough ? "yes" : "no"); + } + + auto end = clock_.now(); + if (end > lastStatsTime_ + kStatsInterval) { + lastStatsTime_ = end; + DISCORD_LOG(LS_INFO) << "Decrypted audio: " << stats_[media_audio].decrypt_success + << ", video: " << stats_[media_video].decrypt_success + << ". Failed audio: " << stats_[media_audio].decrypt_failure + << ", video: " << stats_[media_video].decrypt_failure; + } + stats_[mediaType].decrypt_duration += + std::chrono::duration_cast(end - start).count(); + + return bytesWritten; } bool decryptor::decrypt_impl(aead_cipher_manager& cipher_manager, - media_type mediaType, - inbound_frame_processor& encryptedFrame, - array_view frame) + media_type mediaType, + inbound_frame_processor& encryptedFrame, + array_view frame) { - auto tag = encryptedFrame.get_tag(); - auto truncatedNonce = encryptedFrame.get_truncated_nonce(); + auto tag = encryptedFrame.get_tag(); + auto truncatedNonce = encryptedFrame.get_truncated_nonce(); - auto authenticatedData = encryptedFrame.get_authenticated_data(); - auto ciphertext = encryptedFrame.GetCiphertext(); - auto plaintext = encryptedFrame.get_plaintext(); + auto authenticatedData = encryptedFrame.get_authenticated_data(); + auto ciphertext = encryptedFrame.GetCiphertext(); + auto plaintext = encryptedFrame.get_plaintext(); - // expand the truncated nonce to the full sized one needed for decryption - auto nonceBuffer = std::array(); - memcpy(nonceBuffer.data() + AES_GCM_128_TRUNCATED_SYNC_NONCE_OFFSET, + // expand the truncated nonce to the full sized one needed for decryption + auto nonceBuffer = std::array(); + memcpy(nonceBuffer.data() + AES_GCM_128_TRUNCATED_SYNC_NONCE_OFFSET, &truncatedNonce, AES_GCM_128_TRUNCATED_SYNC_NONCE_BYTES); - auto nonceBufferView = make_array_view(nonceBuffer.data(), nonceBuffer.size()); + auto nonceBufferView = make_array_view(nonceBuffer.data(), nonceBuffer.size()); - auto generation = - cipher_manager.compute_wrapped_generation(truncatedNonce >> RATCHET_GENERATION_SHIFT_BITS); + auto generation = + cipher_manager.compute_wrapped_generation(truncatedNonce >> RATCHET_GENERATION_SHIFT_BITS); - if (!cipher_manager.can_process_nonce(generation, truncatedNonce)) { - DISCORD_LOG(LS_INFO) << "decrypt failed, cannot process nonce: " << truncatedNonce; - return false; - } + if (!cipher_manager.can_process_nonce(generation, truncatedNonce)) { + DISCORD_LOG(LS_INFO) << "decrypt failed, cannot process nonce: " << truncatedNonce; + return false; + } - // Get the cryptor for this generation - cipher_interface* cipher = cipher_manager.get_cipher(generation); + // Get the cryptor for this generation + cipher_interface* cipher = cipher_manager.get_cipher(generation); - if (cipher == nullptr) { - DISCORD_LOG(LS_INFO) << "decrypt failed, no cryptor found for generation: " << generation; - return false; - } + if (cipher == nullptr) { + DISCORD_LOG(LS_INFO) << "decrypt failed, no cryptor found for generation: " << generation; + return false; + } - // perform the decryption - bool success = cipher->decrypt(plaintext, ciphertext, tag, nonceBufferView, authenticatedData); - stats_[mediaType].decrypt_attempts++; + // perform the decryption + bool success = cipher->decrypt(plaintext, ciphertext, tag, nonceBufferView, authenticatedData); + stats_[mediaType].decrypt_attempts++; - if (success) { - cipher_manager.report_cipher_success(generation, truncatedNonce); - } + if (success) { + cipher_manager.report_cipher_success(generation, truncatedNonce); + } - return success; + return success; } size_t decryptor::get_max_plaintext_byte_size(media_type mediaType, size_t encryptedFrameSize) { - return encryptedFrameSize; + return encryptedFrameSize; } void decryptor::update_cryptor_manager_expiry(Duration expiry) { - auto maxExpiryTime = clock_.now() + expiry; - for (auto& cryptorManager : cryptorManagers_) { - cryptorManager.update_expiry(maxExpiryTime); - } + auto maxExpiryTime = clock_.now() + expiry; + for (auto& cryptorManager : cryptorManagers_) { + cryptorManager.update_expiry(maxExpiryTime); + } } void decryptor::cleanup_expired_cryptor_managers() { - while (!cryptorManagers_.empty() && cryptorManagers_.front().is_expired()) { - DISCORD_LOG(LS_INFO) << "Removing expired cryptor manager."; - cryptorManagers_.pop_front(); - } + while (!cryptorManagers_.empty() && cryptorManagers_.front().is_expired()) { + DISCORD_LOG(LS_INFO) << "Removing expired cryptor manager."; + cryptorManagers_.pop_front(); + } } std::unique_ptr decryptor::get_or_create_frame_processor() { - std::lock_guard lock(frameProcessorsMutex_); - if (frameProcessors_.empty()) { - return std::make_unique(); - } - auto frameProcessor = std::move(frameProcessors_.back()); - frameProcessors_.pop_back(); - return frameProcessor; + std::lock_guard lock(frameProcessorsMutex_); + if (frameProcessors_.empty()) { + return std::make_unique(); + } + auto frameProcessor = std::move(frameProcessors_.back()); + frameProcessors_.pop_back(); + return frameProcessor; } void decryptor::return_frame_processor(std::unique_ptr frameProcessor) { - std::lock_guard lock(frameProcessorsMutex_); - frameProcessors_.push_back(std::move(frameProcessor)); + std::lock_guard lock(frameProcessorsMutex_); + frameProcessors_.push_back(std::move(frameProcessor)); } } // namespace dpp::dave diff --git a/src/dpp/dave/decryptor.h b/src/dpp/dave/decryptor.h index e4cd550e75..9e38e979dc 100755 --- a/src/dpp/dave/decryptor.h +++ b/src/dpp/dave/decryptor.h @@ -44,53 +44,53 @@ namespace dpp::dave { class IKeyRatchet; struct decryption_stats { - uint64_t passthroughs = 0; - uint64_t decrypt_success = 0; - uint64_t decrypt_failure = 0; - uint64_t decrypt_duration = 0; - uint64_t decrypt_attempts = 0; + uint64_t passthroughs = 0; + uint64_t decrypt_success = 0; + uint64_t decrypt_failure = 0; + uint64_t decrypt_duration = 0; + uint64_t decrypt_attempts = 0; }; class decryptor { public: - using Duration = std::chrono::seconds; + using Duration = std::chrono::seconds; - void transition_to_key_ratchet(std::unique_ptr keyRatchet, + void transition_to_key_ratchet(std::unique_ptr keyRatchet, Duration transitionExpiry = DEFAULT_TRANSITION_EXPIRY); - void transition_to_passthrough_mode(bool passthroughMode, + void transition_to_passthrough_mode(bool passthroughMode, Duration transitionExpiry = DEFAULT_TRANSITION_EXPIRY); - size_t decrypt(media_type mediaType, + size_t decrypt(media_type mediaType, array_view encryptedFrame, array_view frame); - size_t get_max_plaintext_byte_size(media_type mediaType, size_t encryptedFrameSize); - decryption_stats get_stats(media_type mediaType) const { return stats_[mediaType]; } + size_t get_max_plaintext_byte_size(media_type mediaType, size_t encryptedFrameSize); + decryption_stats get_stats(media_type mediaType) const { return stats_[mediaType]; } private: - using time_point = clock_interface::time_point; + using time_point = clock_interface::time_point; - bool decrypt_impl(aead_cipher_manager& cipher_manager, - media_type mediaType, - inbound_frame_processor& encryptedFrame, - array_view frame); + bool decrypt_impl(aead_cipher_manager& cipher_manager, + media_type mediaType, + inbound_frame_processor& encryptedFrame, + array_view frame); - void update_cryptor_manager_expiry(Duration expiry); - void cleanup_expired_cryptor_managers(); + void update_cryptor_manager_expiry(Duration expiry); + void cleanup_expired_cryptor_managers(); - std::unique_ptr get_or_create_frame_processor(); - void return_frame_processor(std::unique_ptr frameProcessor); + std::unique_ptr get_or_create_frame_processor(); + void return_frame_processor(std::unique_ptr frameProcessor); - clock clock_; - std::deque cryptorManagers_; + clock clock_; + std::deque cryptorManagers_; - std::mutex frameProcessorsMutex_; - std::vector> frameProcessors_; + std::mutex frameProcessorsMutex_; + std::vector> frameProcessors_; - time_point allowPassThroughUntil_{time_point::min()}; + time_point allowPassThroughUntil_{time_point::min()}; - time_point lastStatsTime_{time_point::min()}; - std::array stats_; + time_point lastStatsTime_{time_point::min()}; + std::array stats_; }; } // namespace dpp::dave diff --git a/src/dpp/dave/encryptor.cpp b/src/dpp/dave/encryptor.cpp index 8937d0f7b7..d6d765edc7 100755 --- a/src/dpp/dave/encryptor.cpp +++ b/src/dpp/dave/encryptor.cpp @@ -42,273 +42,273 @@ constexpr auto kStatsInterval = 10s; void encryptor::set_key_ratchet(std::unique_ptr keyRatchet) { - std::lock_guard lock(keyGenMutex_); - keyRatchet_ = std::move(keyRatchet); - cryptor_ = nullptr; - currentKeyGeneration_ = 0; - truncatedNonce_ = 0; + std::lock_guard lock(keyGenMutex_); + keyRatchet_ = std::move(keyRatchet); + cryptor_ = nullptr; + currentKeyGeneration_ = 0; + truncatedNonce_ = 0; } void encryptor::set_passthrough_mode(bool passthroughMode) { - passthroughMode_ = passthroughMode; + passthroughMode_ = passthroughMode; update_current_protocol_version(passthroughMode ? 0 : MaxSupportedProtocolVersion()); } int encryptor::encrypt(media_type mediaType, - uint32_t ssrc, - array_view frame, - array_view encryptedFrame, - size_t* bytesWritten) + uint32_t ssrc, + array_view frame, + array_view encryptedFrame, + size_t* bytesWritten) { - if (mediaType != media_audio && mediaType != media_video) { - DISCORD_LOG(LS_WARNING) << "encrypt failed, invalid media type: " - << static_cast(mediaType); - return 0; - } - - if (passthroughMode_) { - // Pass frame through without encrypting + if (mediaType != media_audio && mediaType != media_video) { + DISCORD_LOG(LS_WARNING) << "encrypt failed, invalid media type: " + << static_cast(mediaType); + return 0; + } + + if (passthroughMode_) { + // Pass frame through without encrypting std::memcpy(encryptedFrame.data(), frame.data(), frame.size()); - *bytesWritten = frame.size(); - stats_[mediaType].passthroughs++; - return result_code::rc_success; - } + *bytesWritten = frame.size(); + stats_[mediaType].passthroughs++; + return result_code::rc_success; + } - { - std::lock_guard lock(keyGenMutex_); - if (!keyRatchet_) { - stats_[mediaType].encrypt_failure++; - return result_code::rc_encryption_failure; - } - } + { + std::lock_guard lock(keyGenMutex_); + if (!keyRatchet_) { + stats_[mediaType].encrypt_failure++; + return result_code::rc_encryption_failure; + } + } - auto start = std::chrono::steady_clock::now(); - auto result = result_code::rc_success; + auto start = std::chrono::steady_clock::now(); + auto result = result_code::rc_success; - // write the codec identifier - auto codec = codec_for_ssrc(ssrc); + // write the codec identifier + auto codec = codec_for_ssrc(ssrc); - auto frameProcessor = get_or_create_frame_processor(); - ScopeExit cleanup([&] { return_frame_processor(std::move(frameProcessor)); }); + auto frameProcessor = get_or_create_frame_processor(); + ScopeExit cleanup([&] { return_frame_processor(std::move(frameProcessor)); }); frameProcessor->process_frame(frame, codec); - const auto& unencryptedBytes = frameProcessor->get_unencrypted_bytes(); - const auto& encryptedBytes = frameProcessor->get_encrypted_bytes(); - auto& ciphertextBytes = frameProcessor->get_ciphertext_bytes(); + const auto& unencryptedBytes = frameProcessor->get_unencrypted_bytes(); + const auto& encryptedBytes = frameProcessor->get_encrypted_bytes(); + auto& ciphertextBytes = frameProcessor->get_ciphertext_bytes(); - const auto& unencryptedRanges = frameProcessor->get_unencrypted_ranges(); - auto unencryptedRangesSize = UnencryptedRangesSize(unencryptedRanges); + const auto& unencryptedRanges = frameProcessor->get_unencrypted_ranges(); + auto unencryptedRangesSize = UnencryptedRangesSize(unencryptedRanges); - auto additionalData = make_array_view(unencryptedBytes.data(), unencryptedBytes.size()); - auto plaintextBuffer = make_array_view(encryptedBytes.data(), encryptedBytes.size()); - auto ciphertextBuffer = make_array_view(ciphertextBytes.data(), ciphertextBytes.size()); + auto additionalData = make_array_view(unencryptedBytes.data(), unencryptedBytes.size()); + auto plaintextBuffer = make_array_view(encryptedBytes.data(), encryptedBytes.size()); + auto ciphertextBuffer = make_array_view(ciphertextBytes.data(), ciphertextBytes.size()); - auto frameSize = encryptedBytes.size() + unencryptedBytes.size(); - auto tagBuffer = make_array_view(encryptedFrame.data() + frameSize, AES_GCM_127_TRUNCATED_TAG_BYTES); + auto frameSize = encryptedBytes.size() + unencryptedBytes.size(); + auto tagBuffer = make_array_view(encryptedFrame.data() + frameSize, AES_GCM_127_TRUNCATED_TAG_BYTES); - auto nonceBuffer = std::array(); - auto nonceBufferView = make_array_view(nonceBuffer.data(), nonceBuffer.size()); + auto nonceBuffer = std::array(); + auto nonceBufferView = make_array_view(nonceBuffer.data(), nonceBuffer.size()); - constexpr auto MAX_CIPHERTEXT_VALIDATION_RETRIES = 10; + constexpr auto MAX_CIPHERTEXT_VALIDATION_RETRIES = 10; - // some codecs (e.g. H26X) have packetizers that cannot handle specific byte sequences - // so we attempt up to MAX_CIPHERTEXT_VALIDATION_RETRIES to encrypt the frame - // calling into codec utils to validate the ciphertext + supplemental section - // and re-rolling the truncated nonce if it fails + // some codecs (e.g. H26X) have packetizers that cannot handle specific byte sequences + // so we attempt up to MAX_CIPHERTEXT_VALIDATION_RETRIES to encrypt the frame + // calling into codec utils to validate the ciphertext + supplemental section + // and re-rolling the truncated nonce if it fails - // the nonce increment will definitely change the ciphertext and the tag - // incrementing the nonce will also change the appropriate bytes - // in the tail end of the nonce - // which can remove start codes from the last 1 or 2 bytes of the nonce - // and the two bytes of the unencrypted header bytes - for (auto attempt = 1; attempt <= MAX_CIPHERTEXT_VALIDATION_RETRIES; ++attempt) { - auto [cryptor, truncatedNonce] = get_next_cryptor_and_nonce(); + // the nonce increment will definitely change the ciphertext and the tag + // incrementing the nonce will also change the appropriate bytes + // in the tail end of the nonce + // which can remove start codes from the last 1 or 2 bytes of the nonce + // and the two bytes of the unencrypted header bytes + for (auto attempt = 1; attempt <= MAX_CIPHERTEXT_VALIDATION_RETRIES; ++attempt) { + auto [cryptor, truncatedNonce] = get_next_cryptor_and_nonce(); - if (!cryptor) { - result = result_code::rc_encryption_failure; - break; - } + if (!cryptor) { + result = result_code::rc_encryption_failure; + break; + } - // write the truncated nonce to our temporary full nonce array - // (since the encryption call expects a full size nonce) - std::memcpy(nonceBuffer.data() + AES_GCM_128_TRUNCATED_SYNC_NONCE_OFFSET, - &truncatedNonce, - AES_GCM_128_TRUNCATED_SYNC_NONCE_BYTES); + // write the truncated nonce to our temporary full nonce array + // (since the encryption call expects a full size nonce) + std::memcpy(nonceBuffer.data() + AES_GCM_128_TRUNCATED_SYNC_NONCE_OFFSET, + &truncatedNonce, + AES_GCM_128_TRUNCATED_SYNC_NONCE_BYTES); - // encrypt the plaintext, adding the unencrypted header to the tag - bool success = cryptor->encrypt( + // encrypt the plaintext, adding the unencrypted header to the tag + bool success = cryptor->encrypt( ciphertextBuffer, plaintextBuffer, nonceBufferView, additionalData, tagBuffer); - stats_[mediaType].encrypt_attempts++; - stats_[mediaType].encrypt_max_attempts = - std::max(stats_[mediaType].encrypt_max_attempts, (uint64_t)attempt); + stats_[mediaType].encrypt_attempts++; + stats_[mediaType].encrypt_max_attempts = + std::max(stats_[mediaType].encrypt_max_attempts, (uint64_t)attempt); - if (!success) { - assert(false && "Failed to encrypt frame"); - result = result_code::rc_encryption_failure; - break; - } + if (!success) { + assert(false && "Failed to encrypt frame"); + result = result_code::rc_encryption_failure; + break; + } - auto reconstructedFrameSize = frameProcessor->reconstruct_frame(encryptedFrame); - assert(reconstructedFrameSize == frameSize && "Failed to reconstruct frame"); + auto reconstructedFrameSize = frameProcessor->reconstruct_frame(encryptedFrame); + assert(reconstructedFrameSize == frameSize && "Failed to reconstruct frame"); - auto nonceSize = Leb128Size(truncatedNonce); + auto nonceSize = Leb128Size(truncatedNonce); - auto truncatedNonceBuffer = make_array_view(tagBuffer.end(), nonceSize); - auto unencryptedRangesBuffer = + auto truncatedNonceBuffer = make_array_view(tagBuffer.end(), nonceSize); + auto unencryptedRangesBuffer = make_array_view(truncatedNonceBuffer.end(), unencryptedRangesSize); - auto supplementalBytesBuffer = + auto supplementalBytesBuffer = make_array_view(unencryptedRangesBuffer.end(), sizeof(supplemental_bytes_size)); - auto markerBytesBuffer = make_array_view(supplementalBytesBuffer.end(), sizeof(magic_marker)); - - // write the nonce - auto res = WriteLeb128(truncatedNonce, truncatedNonceBuffer.begin()); - if (res != nonceSize) { - assert(false && "Failed to write truncated nonce"); - result = result_code::rc_encryption_failure; - break; - } - - // write the unencrypted ranges - res = SerializeUnencryptedRanges( - unencryptedRanges, unencryptedRangesBuffer.begin(), unencryptedRangesBuffer.size()); - if (res != unencryptedRangesSize) { - assert(false && "Failed to write unencrypted ranges"); - result = result_code::rc_encryption_failure; - break; - } - - // write the supplemental bytes size - supplemental_bytes_size supplementalBytes = + auto markerBytesBuffer = make_array_view(supplementalBytesBuffer.end(), sizeof(magic_marker)); + + // write the nonce + auto res = WriteLeb128(truncatedNonce, truncatedNonceBuffer.begin()); + if (res != nonceSize) { + assert(false && "Failed to write truncated nonce"); + result = result_code::rc_encryption_failure; + break; + } + + // write the unencrypted ranges + res = SerializeUnencryptedRanges( + unencryptedRanges, unencryptedRangesBuffer.begin(), unencryptedRangesBuffer.size()); + if (res != unencryptedRangesSize) { + assert(false && "Failed to write unencrypted ranges"); + result = result_code::rc_encryption_failure; + break; + } + + // write the supplemental bytes size + supplemental_bytes_size supplementalBytes = SUPPLEMENTAL_BYTES + nonceSize + unencryptedRangesSize; std::memcpy(supplementalBytesBuffer.data(), &supplementalBytes, sizeof(supplemental_bytes_size)); - // write the marker bytes, ends the frame + // write the marker bytes, ends the frame std::memcpy(markerBytesBuffer.data(), &MARKER_BYTES, sizeof(magic_marker)); - auto encryptedFrameBytes = reconstructedFrameSize + AES_GCM_127_TRUNCATED_TAG_BYTES + + auto encryptedFrameBytes = reconstructedFrameSize + AES_GCM_127_TRUNCATED_TAG_BYTES + nonceSize + unencryptedRangesSize + sizeof(supplemental_bytes_size) + sizeof(magic_marker); - if (codec_utils::validate_encrypted_frame( + if (codec_utils::validate_encrypted_frame( *frameProcessor, make_array_view(encryptedFrame.data(), encryptedFrameBytes))) { - *bytesWritten = encryptedFrameBytes; - break; - } - else if (attempt >= MAX_CIPHERTEXT_VALIDATION_RETRIES) { - assert(false && "Failed to validate encrypted section for codec"); - result = result_code::rc_encryption_failure; - break; - } - } - - auto now = std::chrono::steady_clock::now(); - stats_[mediaType].encrypt_duration += - std::chrono::duration_cast(now - start).count(); - if (result == result_code::rc_success) { - stats_[mediaType].encrypt_success++; - } - else { - stats_[mediaType].encrypt_failure++; - } - - if (now > lastStatsTime_ + kStatsInterval) { - lastStatsTime_ = now; - DISCORD_LOG(LS_INFO) << "Encrypted audio: " << stats_[media_audio].encrypt_success - << ", video: " << stats_[media_video].encrypt_success - << ". Failed audio: " << stats_[media_audio].encrypt_failure - << ", video: " << stats_[media_video].encrypt_failure; - DISCORD_LOG(LS_INFO) << "Last encrypted frame, type: " - << (mediaType == media_audio ? "audio" : "video") << ", ssrc: " << ssrc - << ", size: " << frame.size(); - } - - return result; + *bytesWritten = encryptedFrameBytes; + break; + } + else if (attempt >= MAX_CIPHERTEXT_VALIDATION_RETRIES) { + assert(false && "Failed to validate encrypted section for codec"); + result = result_code::rc_encryption_failure; + break; + } + } + + auto now = std::chrono::steady_clock::now(); + stats_[mediaType].encrypt_duration += + std::chrono::duration_cast(now - start).count(); + if (result == result_code::rc_success) { + stats_[mediaType].encrypt_success++; + } + else { + stats_[mediaType].encrypt_failure++; + } + + if (now > lastStatsTime_ + kStatsInterval) { + lastStatsTime_ = now; + DISCORD_LOG(LS_INFO) << "Encrypted audio: " << stats_[media_audio].encrypt_success + << ", video: " << stats_[media_video].encrypt_success + << ". Failed audio: " << stats_[media_audio].encrypt_failure + << ", video: " << stats_[media_video].encrypt_failure; + DISCORD_LOG(LS_INFO) << "Last encrypted frame, type: " + << (mediaType == media_audio ? "audio" : "video") << ", ssrc: " << ssrc + << ", size: " << frame.size(); + } + + return result; } size_t encryptor::get_max_ciphertext_byte_size(media_type mediaType, size_t frameSize) { - return frameSize + SUPPLEMENTAL_BYTES + TRANSFORM_PADDING_BYTES; + return frameSize + SUPPLEMENTAL_BYTES + TRANSFORM_PADDING_BYTES; } void encryptor::assign_ssrc_to_codec(uint32_t ssrc, Codec codecType) { - auto existingCodecIt = std::find_if( - ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { - return pair.first == ssrc; - }); - - if (existingCodecIt == ssrcCodecPairs_.end()) { - ssrcCodecPairs_.emplace_back(ssrc, codecType); - } - else { - existingCodecIt->second = codecType; - } + auto existingCodecIt = std::find_if( + ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { + return pair.first == ssrc; + }); + + if (existingCodecIt == ssrcCodecPairs_.end()) { + ssrcCodecPairs_.emplace_back(ssrc, codecType); + } + else { + existingCodecIt->second = codecType; + } } Codec encryptor::codec_for_ssrc(uint32_t ssrc) { - auto existingCodecIt = std::find_if( - ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { - return pair.first == ssrc; - }); - - if (existingCodecIt != ssrcCodecPairs_.end()) { - return existingCodecIt->second; - } - else { - return Codec::Opus; - } + auto existingCodecIt = std::find_if( + ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { + return pair.first == ssrc; + }); + + if (existingCodecIt != ssrcCodecPairs_.end()) { + return existingCodecIt->second; + } + else { + return Codec::Opus; + } } std::unique_ptr encryptor::get_or_create_frame_processor() { - std::lock_guard lock(frameProcessorsMutex_); - if (frameProcessors_.empty()) { - return std::make_unique(); - } - auto frameProcessor = std::move(frameProcessors_.back()); - frameProcessors_.pop_back(); - return frameProcessor; + std::lock_guard lock(frameProcessorsMutex_); + if (frameProcessors_.empty()) { + return std::make_unique(); + } + auto frameProcessor = std::move(frameProcessors_.back()); + frameProcessors_.pop_back(); + return frameProcessor; } void encryptor::return_frame_processor(std::unique_ptr frameProcessor) { - std::lock_guard lock(frameProcessorsMutex_); - frameProcessors_.push_back(std::move(frameProcessor)); + std::lock_guard lock(frameProcessorsMutex_); + frameProcessors_.push_back(std::move(frameProcessor)); } encryptor::cryptor_and_nonce encryptor::get_next_cryptor_and_nonce() { - std::lock_guard lock(keyGenMutex_); - if (!keyRatchet_) { - return {nullptr, 0}; - } + std::lock_guard lock(keyGenMutex_); + if (!keyRatchet_) { + return {nullptr, 0}; + } - auto generation = compute_wrapped_generation(currentKeyGeneration_, + auto generation = compute_wrapped_generation(currentKeyGeneration_, ++truncatedNonce_ >> RATCHET_GENERATION_SHIFT_BITS); - if (generation != currentKeyGeneration_ || !cryptor_) { - currentKeyGeneration_ = generation; + if (generation != currentKeyGeneration_ || !cryptor_) { + currentKeyGeneration_ = generation; - auto encryptionKey = keyRatchet_->GetKey(currentKeyGeneration_); - cryptor_ = create_cipher(encryptionKey); - } + auto encryptionKey = keyRatchet_->GetKey(currentKeyGeneration_); + cryptor_ = create_cipher(encryptionKey); + } - return {cryptor_, truncatedNonce_}; + return {cryptor_, truncatedNonce_}; } void encryptor::update_current_protocol_version(ProtocolVersion version) { - if (version == currentProtocolVersion_) { - return; - } - - currentProtocolVersion_ = version; - if (protocolVersionChangedCallback_) { - protocolVersionChangedCallback_(); - } + if (version == currentProtocolVersion_) { + return; + } + + currentProtocolVersion_ = version; + if (protocolVersionChangedCallback_) { + protocolVersionChangedCallback_(); + } } } // namespace dpp::dave diff --git a/src/dpp/dave/encryptor.h b/src/dpp/dave/encryptor.h index 0adc0617e4..ea83c63c65 100755 --- a/src/dpp/dave/encryptor.h +++ b/src/dpp/dave/encryptor.h @@ -41,75 +41,75 @@ namespace dpp::dave { struct encryption_stats { - uint64_t passthroughs = 0; - uint64_t encrypt_success = 0; - uint64_t encrypt_failure = 0; - uint64_t encrypt_duration = 0; - uint64_t encrypt_attempts = 0; - uint64_t encrypt_max_attempts = 0; + uint64_t passthroughs = 0; + uint64_t encrypt_success = 0; + uint64_t encrypt_failure = 0; + uint64_t encrypt_duration = 0; + uint64_t encrypt_attempts = 0; + uint64_t encrypt_max_attempts = 0; }; class encryptor { public: - void set_key_ratchet(std::unique_ptr keyRatchet); - void set_passthrough_mode(bool passthroughMode); + void set_key_ratchet(std::unique_ptr keyRatchet); + void set_passthrough_mode(bool passthroughMode); - bool has_key_ratchet() const { return keyRatchet_ != nullptr; } - bool is_passthrough_mode() const { return passthroughMode_; } + bool has_key_ratchet() const { return keyRatchet_ != nullptr; } + bool is_passthrough_mode() const { return passthroughMode_; } - void assign_ssrc_to_codec(uint32_t ssrc, Codec codecType); - Codec codec_for_ssrc(uint32_t ssrc); + void assign_ssrc_to_codec(uint32_t ssrc, Codec codecType); + Codec codec_for_ssrc(uint32_t ssrc); - int encrypt(media_type mediaType, + int encrypt(media_type mediaType, uint32_t ssrc, array_view frame, array_view encryptedFrame, size_t* bytesWritten); - size_t get_max_ciphertext_byte_size(media_type mediaType, size_t frameSize); - encryption_stats get_stats(media_type mediaType) const { return stats_[mediaType]; } + size_t get_max_ciphertext_byte_size(media_type mediaType, size_t frameSize); + encryption_stats get_stats(media_type mediaType) const { return stats_[mediaType]; } - using protocol_version_changed_callback = std::function; - void set_protocol_version_changed_callback(protocol_version_changed_callback callback) - { - protocolVersionChangedCallback_ = std::move(callback); - } - ProtocolVersion get_protocol_version() const { return currentProtocolVersion_; } + using protocol_version_changed_callback = std::function; + void set_protocol_version_changed_callback(protocol_version_changed_callback callback) + { + protocolVersionChangedCallback_ = std::move(callback); + } + ProtocolVersion get_protocol_version() const { return currentProtocolVersion_; } - enum result_code : uint8_t { - rc_success, - rc_encryption_failure, - }; + enum result_code : uint8_t { + rc_success, + rc_encryption_failure, + }; private: - std::unique_ptr get_or_create_frame_processor(); - void return_frame_processor(std::unique_ptr frameProcessor); + std::unique_ptr get_or_create_frame_processor(); + void return_frame_processor(std::unique_ptr frameProcessor); - using cryptor_and_nonce = std::pair, truncated_sync_nonce>; - cryptor_and_nonce get_next_cryptor_and_nonce(); + using cryptor_and_nonce = std::pair, truncated_sync_nonce>; + cryptor_and_nonce get_next_cryptor_and_nonce(); - void update_current_protocol_version(ProtocolVersion version); + void update_current_protocol_version(ProtocolVersion version); - std::atomic_bool passthroughMode_{false}; + std::atomic_bool passthroughMode_{false}; - std::mutex keyGenMutex_; - std::unique_ptr keyRatchet_; - std::shared_ptr cryptor_; - KeyGeneration currentKeyGeneration_{0}; - truncated_sync_nonce truncatedNonce_{0}; + std::mutex keyGenMutex_; + std::unique_ptr keyRatchet_; + std::shared_ptr cryptor_; + KeyGeneration currentKeyGeneration_{0}; + truncated_sync_nonce truncatedNonce_{0}; - std::mutex frameProcessorsMutex_; - std::vector> frameProcessors_; + std::mutex frameProcessorsMutex_; + std::vector> frameProcessors_; - using SsrcCodecPair = std::pair; - std::vector ssrcCodecPairs_; + using SsrcCodecPair = std::pair; + std::vector ssrcCodecPairs_; - using TimePoint = std::chrono::time_point; - TimePoint lastStatsTime_{TimePoint::min()}; - std::array stats_; + using TimePoint = std::chrono::time_point; + TimePoint lastStatsTime_{TimePoint::min()}; + std::array stats_; - protocol_version_changed_callback protocolVersionChangedCallback_; - ProtocolVersion currentProtocolVersion_{MaxSupportedProtocolVersion()}; + protocol_version_changed_callback protocolVersionChangedCallback_; + ProtocolVersion currentProtocolVersion_{MaxSupportedProtocolVersion()}; }; } // namespace dpp::dave diff --git a/src/dpp/dave/frame_processors.cpp b/src/dpp/dave/frame_processors.cpp index 9e313a2be3..177327232b 100755 --- a/src/dpp/dave/frame_processors.cpp +++ b/src/dpp/dave/frame_processors.cpp @@ -41,372 +41,372 @@ namespace dpp::dave { std::pair OverflowAdd(size_t a, size_t b) { - size_t res; + size_t res; #if defined(_MSC_VER) && defined(_M_X64) - bool didOverflow = _addcarry_u64(0, a, b, &res); + bool didOverflow = _addcarry_u64(0, a, b, &res); #elif defined(_MSC_VER) && defined(_M_IX86) - bool didOverflow = _addcarry_u32(0, a, b, &res); + bool didOverflow = _addcarry_u32(0, a, b, &res); #else - bool didOverflow = __builtin_add_overflow(a, b, &res); + bool didOverflow = __builtin_add_overflow(a, b, &res); #endif - return {didOverflow, res}; + return {didOverflow, res}; } uint8_t UnencryptedRangesSize(const Ranges& unencryptedRanges) { - size_t size = 0; - for (const auto& range : unencryptedRanges) { - size += Leb128Size(range.offset); - size += Leb128Size(range.size); - } - assert(size <= std::numeric_limits::max() && - "Unencrypted ranges size exceeds 255 bytes"); - return static_cast(size); + size_t size = 0; + for (const auto& range : unencryptedRanges) { + size += Leb128Size(range.offset); + size += Leb128Size(range.size); + } + assert(size <= std::numeric_limits::max() && + "Unencrypted ranges size exceeds 255 bytes"); + return static_cast(size); } uint8_t SerializeUnencryptedRanges(const Ranges& unencryptedRanges, - uint8_t* buffer, - size_t bufferSize) + uint8_t* buffer, + size_t bufferSize) { - auto writeAt = buffer; - auto end = buffer + bufferSize; - for (const auto& range : unencryptedRanges) { - auto rangeSize = Leb128Size(range.offset) + Leb128Size(range.size); - if (rangeSize > static_cast(end - writeAt)) { - assert(false && "Buffer is too small to serialize unencrypted ranges"); - break; - } - - writeAt += WriteLeb128(range.offset, writeAt); - writeAt += WriteLeb128(range.size, writeAt); - } - return writeAt - buffer; + auto writeAt = buffer; + auto end = buffer + bufferSize; + for (const auto& range : unencryptedRanges) { + auto rangeSize = Leb128Size(range.offset) + Leb128Size(range.size); + if (rangeSize > static_cast(end - writeAt)) { + assert(false && "Buffer is too small to serialize unencrypted ranges"); + break; + } + + writeAt += WriteLeb128(range.offset, writeAt); + writeAt += WriteLeb128(range.size, writeAt); + } + return writeAt - buffer; } uint8_t DeserializeUnencryptedRanges(const uint8_t*& readAt, - const size_t bufferSize, - Ranges& unencryptedRanges) + const size_t bufferSize, + Ranges& unencryptedRanges) { - auto start = readAt; - auto end = readAt + bufferSize; - while (readAt < end) { - size_t offset = ReadLeb128(readAt, end); - if (readAt == nullptr) { - break; - } - - size_t size = ReadLeb128(readAt, end); - if (readAt == nullptr) { - break; - } - unencryptedRanges.push_back({offset, size}); - } - - if (readAt != end) { - DISCORD_LOG(LS_WARNING) << "Failed to deserialize unencrypted ranges"; - unencryptedRanges.clear(); - readAt = nullptr; - return 0; - } - - return readAt - start; + auto start = readAt; + auto end = readAt + bufferSize; + while (readAt < end) { + size_t offset = ReadLeb128(readAt, end); + if (readAt == nullptr) { + break; + } + + size_t size = ReadLeb128(readAt, end); + if (readAt == nullptr) { + break; + } + unencryptedRanges.push_back({offset, size}); + } + + if (readAt != end) { + DISCORD_LOG(LS_WARNING) << "Failed to deserialize unencrypted ranges"; + unencryptedRanges.clear(); + readAt = nullptr; + return 0; + } + + return readAt - start; } bool ValidateUnencryptedRanges(const Ranges& unencryptedRanges, size_t frameSize) { - if (unencryptedRanges.empty()) { - return true; - } - - // validate that the ranges are in order and don't overlap - for (auto i = 0u; i < unencryptedRanges.size(); ++i) { - auto current = unencryptedRanges[i]; - // The current range should not overflow into the next range - // or if it is the last range, the end of the frame - auto maxEnd = - i + 1 < unencryptedRanges.size() ? unencryptedRanges[i + 1].offset : frameSize; - - auto [didOverflow, currentEnd] = OverflowAdd(current.offset, current.size); - if (didOverflow || currentEnd > maxEnd) { - DISCORD_LOG(LS_WARNING) - << "Unencrypted range may overlap or be out of order: current offset: " - << current.offset << ", current size: " << current.size << ", maximum end: " << maxEnd - << ", frame size: " << frameSize; - return false; - } - } - - return true; + if (unencryptedRanges.empty()) { + return true; + } + + // validate that the ranges are in order and don't overlap + for (auto i = 0u; i < unencryptedRanges.size(); ++i) { + auto current = unencryptedRanges[i]; + // The current range should not overflow into the next range + // or if it is the last range, the end of the frame + auto maxEnd = + i + 1 < unencryptedRanges.size() ? unencryptedRanges[i + 1].offset : frameSize; + + auto [didOverflow, currentEnd] = OverflowAdd(current.offset, current.size); + if (didOverflow || currentEnd > maxEnd) { + DISCORD_LOG(LS_WARNING) + << "Unencrypted range may overlap or be out of order: current offset: " + << current.offset << ", current size: " << current.size << ", maximum end: " << maxEnd + << ", frame size: " << frameSize; + return false; + } + } + + return true; } size_t Reconstruct(Ranges ranges, - const std::vector& rangeBytes, - const std::vector& otherBytes, - const array_view& output) + const std::vector& rangeBytes, + const std::vector& otherBytes, + const array_view& output) { - size_t frameIndex = 0; - size_t rangeBytesIndex = 0; - size_t otherBytesIndex = 0; + size_t frameIndex = 0; + size_t rangeBytesIndex = 0; + size_t otherBytesIndex = 0; - const auto CopyRangeBytes = [&](size_t size) { - assert(rangeBytesIndex + size <= rangeBytes.size()); - assert(frameIndex + size <= output.size()); + const auto CopyRangeBytes = [&](size_t size) { + assert(rangeBytesIndex + size <= rangeBytes.size()); + assert(frameIndex + size <= output.size()); std::memcpy(output.data() + frameIndex, rangeBytes.data() + rangeBytesIndex, size); - rangeBytesIndex += size; - frameIndex += size; - }; - - const auto CopyOtherBytes = [&](size_t size) { - assert(otherBytesIndex + size <= otherBytes.size()); - assert(frameIndex + size <= output.size()); - std::memcpy(output.data() + frameIndex, otherBytes.data() + otherBytesIndex, size); - otherBytesIndex += size; - frameIndex += size; - }; - - for (const auto& range : ranges) { - if (range.offset > frameIndex) { - CopyOtherBytes(range.offset - frameIndex); - } - - CopyRangeBytes(range.size); - } - - if (otherBytesIndex < otherBytes.size()) { - CopyOtherBytes(otherBytes.size() - otherBytesIndex); - } - - assert(rangeBytesIndex == rangeBytes.size()); - assert(otherBytesIndex == otherBytes.size()); - assert(frameIndex <= output.size()); - - return frameIndex; + rangeBytesIndex += size; + frameIndex += size; + }; + + const auto CopyOtherBytes = [&](size_t size) { + assert(otherBytesIndex + size <= otherBytes.size()); + assert(frameIndex + size <= output.size()); + std::memcpy(output.data() + frameIndex, otherBytes.data() + otherBytesIndex, size); + otherBytesIndex += size; + frameIndex += size; + }; + + for (const auto& range : ranges) { + if (range.offset > frameIndex) { + CopyOtherBytes(range.offset - frameIndex); + } + + CopyRangeBytes(range.size); + } + + if (otherBytesIndex < otherBytes.size()) { + CopyOtherBytes(otherBytes.size() - otherBytesIndex); + } + + assert(rangeBytesIndex == rangeBytes.size()); + assert(otherBytesIndex == otherBytes.size()); + assert(frameIndex <= output.size()); + + return frameIndex; } void inbound_frame_processor::clear() { - isEncrypted_ = false; - originalSize_ = 0; - truncatedNonce_ = std::numeric_limits::max(); - unencryptedRanges_.clear(); - authenticated_.clear(); - ciphertext_.clear(); - plaintext_.clear(); + isEncrypted_ = false; + originalSize_ = 0; + truncatedNonce_ = std::numeric_limits::max(); + unencryptedRanges_.clear(); + authenticated_.clear(); + ciphertext_.clear(); + plaintext_.clear(); } void inbound_frame_processor::parse_frame(array_view frame) { - clear(); - - constexpr auto MinSupplementalBytesSize = - AES_GCM_127_TRUNCATED_TAG_BYTES + sizeof(supplemental_bytes_size) + sizeof(magic_marker); - if (frame.size() < MinSupplementalBytesSize) { - DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain min supplemental bytes"; - return; - } - - // Check the frame ends with the magic marker - auto magicMarkerBuffer = frame.end() - sizeof(magic_marker); - if (memcmp(magicMarkerBuffer, &MARKER_BYTES, sizeof(magic_marker)) != 0) { - return; - } - - // Read the supplemental bytes size - supplemental_bytes_size supplementalBytesSize; - auto supplementalBytesSizeBuffer = magicMarkerBuffer - sizeof(supplemental_bytes_size); - assert(frame.begin() <= supplementalBytesSizeBuffer && - supplementalBytesSizeBuffer <= frame.end()); - memcpy(&supplementalBytesSize, supplementalBytesSizeBuffer, sizeof(supplemental_bytes_size)); - - // Check the frame is large enough to contain the supplemental bytes - if (frame.size() < supplementalBytesSize) { - DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain supplemental bytes"; - return; - } - - // Check that supplemental bytes size is large enough to contain the supplemental bytes - if (supplementalBytesSize < MinSupplementalBytesSize) { - DISCORD_LOG(LS_WARNING) - << "Supplemental bytes size is too small to contain supplemental bytes"; - return; - } - - auto supplementalBytesBuffer = frame.end() - supplementalBytesSize; - assert(frame.begin() <= supplementalBytesBuffer && supplementalBytesBuffer <= frame.end()); - - // Read the tag - tag_ = make_array_view(supplementalBytesBuffer, AES_GCM_127_TRUNCATED_TAG_BYTES); - - // Read the nonce - auto nonceBuffer = supplementalBytesBuffer + AES_GCM_127_TRUNCATED_TAG_BYTES; - assert(frame.begin() <= nonceBuffer && nonceBuffer <= frame.end()); - auto readAt = nonceBuffer; - auto end = supplementalBytesSizeBuffer; - truncatedNonce_ = ReadLeb128(readAt, end); - if (readAt == nullptr) { - DISCORD_LOG(LS_WARNING) << "Failed to read truncated nonce"; - return; - } - - // Read the unencrypted ranges - assert(nonceBuffer <= readAt && readAt <= end); - auto unencryptedRangesSize = end - readAt; - DeserializeUnencryptedRanges(readAt, unencryptedRangesSize, unencryptedRanges_); - if (readAt == nullptr) { - DISCORD_LOG(LS_WARNING) << "Failed to read unencrypted ranges"; - return; - } - - if (!ValidateUnencryptedRanges(unencryptedRanges_, frame.size())) { - DISCORD_LOG(LS_WARNING) << "Invalid unencrypted ranges"; - return; - } - - // This is overly aggressive but will keep reallocations to a minimum - authenticated_.reserve(frame.size()); - ciphertext_.reserve(frame.size()); - plaintext_.reserve(frame.size()); - - originalSize_ = frame.size(); - - // Split the frame into authenticated and ciphertext bytes - size_t frameIndex = 0; - for (const auto& range : unencryptedRanges_) { - auto encryptedBytes = range.offset - frameIndex; - if (encryptedBytes > 0) { - assert(frameIndex + encryptedBytes <= frame.size()); + clear(); + + constexpr auto MinSupplementalBytesSize = + AES_GCM_127_TRUNCATED_TAG_BYTES + sizeof(supplemental_bytes_size) + sizeof(magic_marker); + if (frame.size() < MinSupplementalBytesSize) { + DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain min supplemental bytes"; + return; + } + + // Check the frame ends with the magic marker + auto magicMarkerBuffer = frame.end() - sizeof(magic_marker); + if (memcmp(magicMarkerBuffer, &MARKER_BYTES, sizeof(magic_marker)) != 0) { + return; + } + + // Read the supplemental bytes size + supplemental_bytes_size supplementalBytesSize; + auto supplementalBytesSizeBuffer = magicMarkerBuffer - sizeof(supplemental_bytes_size); + assert(frame.begin() <= supplementalBytesSizeBuffer && + supplementalBytesSizeBuffer <= frame.end()); + memcpy(&supplementalBytesSize, supplementalBytesSizeBuffer, sizeof(supplemental_bytes_size)); + + // Check the frame is large enough to contain the supplemental bytes + if (frame.size() < supplementalBytesSize) { + DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain supplemental bytes"; + return; + } + + // Check that supplemental bytes size is large enough to contain the supplemental bytes + if (supplementalBytesSize < MinSupplementalBytesSize) { + DISCORD_LOG(LS_WARNING) + << "Supplemental bytes size is too small to contain supplemental bytes"; + return; + } + + auto supplementalBytesBuffer = frame.end() - supplementalBytesSize; + assert(frame.begin() <= supplementalBytesBuffer && supplementalBytesBuffer <= frame.end()); + + // Read the tag + tag_ = make_array_view(supplementalBytesBuffer, AES_GCM_127_TRUNCATED_TAG_BYTES); + + // Read the nonce + auto nonceBuffer = supplementalBytesBuffer + AES_GCM_127_TRUNCATED_TAG_BYTES; + assert(frame.begin() <= nonceBuffer && nonceBuffer <= frame.end()); + auto readAt = nonceBuffer; + auto end = supplementalBytesSizeBuffer; + truncatedNonce_ = ReadLeb128(readAt, end); + if (readAt == nullptr) { + DISCORD_LOG(LS_WARNING) << "Failed to read truncated nonce"; + return; + } + + // Read the unencrypted ranges + assert(nonceBuffer <= readAt && readAt <= end); + auto unencryptedRangesSize = end - readAt; + DeserializeUnencryptedRanges(readAt, unencryptedRangesSize, unencryptedRanges_); + if (readAt == nullptr) { + DISCORD_LOG(LS_WARNING) << "Failed to read unencrypted ranges"; + return; + } + + if (!ValidateUnencryptedRanges(unencryptedRanges_, frame.size())) { + DISCORD_LOG(LS_WARNING) << "Invalid unencrypted ranges"; + return; + } + + // This is overly aggressive but will keep reallocations to a minimum + authenticated_.reserve(frame.size()); + ciphertext_.reserve(frame.size()); + plaintext_.reserve(frame.size()); + + originalSize_ = frame.size(); + + // Split the frame into authenticated and ciphertext bytes + size_t frameIndex = 0; + for (const auto& range : unencryptedRanges_) { + auto encryptedBytes = range.offset - frameIndex; + if (encryptedBytes > 0) { + assert(frameIndex + encryptedBytes <= frame.size()); add_ciphertext_bytes(frame.data() + frameIndex, encryptedBytes); - } - - assert(range.offset + range.size <= frame.size()); - add_authenticated_bytes(frame.data() + range.offset, range.size); - frameIndex = range.offset + range.size; - } - auto actualFrameSize = frame.size() - supplementalBytesSize; - if (frameIndex < actualFrameSize) { - add_ciphertext_bytes(frame.data() + frameIndex, actualFrameSize - frameIndex); - } - - // Make sure the plaintext buffer is the same size as the ciphertext buffer - plaintext_.resize(ciphertext_.size()); - - // We've successfully parsed the frame - // Mark the frame as encrypted - isEncrypted_ = true; + } + + assert(range.offset + range.size <= frame.size()); + add_authenticated_bytes(frame.data() + range.offset, range.size); + frameIndex = range.offset + range.size; + } + auto actualFrameSize = frame.size() - supplementalBytesSize; + if (frameIndex < actualFrameSize) { + add_ciphertext_bytes(frame.data() + frameIndex, actualFrameSize - frameIndex); + } + + // Make sure the plaintext buffer is the same size as the ciphertext buffer + plaintext_.resize(ciphertext_.size()); + + // We've successfully parsed the frame + // Mark the frame as encrypted + isEncrypted_ = true; } size_t inbound_frame_processor::reconstruct_frame(array_view frame) const { - if (!isEncrypted_) { - DISCORD_LOG(LS_WARNING) << "Cannot reconstruct an invalid encrypted frame"; - return 0; - } + if (!isEncrypted_) { + DISCORD_LOG(LS_WARNING) << "Cannot reconstruct an invalid encrypted frame"; + return 0; + } - if (authenticated_.size() + plaintext_.size() > frame.size()) { - DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the decrypted frame"; - return 0; - } + if (authenticated_.size() + plaintext_.size() > frame.size()) { + DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the decrypted frame"; + return 0; + } - return Reconstruct(unencryptedRanges_, authenticated_, plaintext_, frame); + return Reconstruct(unencryptedRanges_, authenticated_, plaintext_, frame); } void inbound_frame_processor::add_authenticated_bytes(const uint8_t* data, size_t size) { - authenticated_.resize(authenticated_.size() + size); - memcpy(authenticated_.data() + authenticated_.size() - size, data, size); + authenticated_.resize(authenticated_.size() + size); + memcpy(authenticated_.data() + authenticated_.size() - size, data, size); } void inbound_frame_processor::add_ciphertext_bytes(const uint8_t* data, size_t size) { - ciphertext_.resize(ciphertext_.size() + size); - memcpy(ciphertext_.data() + ciphertext_.size() - size, data, size); + ciphertext_.resize(ciphertext_.size() + size); + memcpy(ciphertext_.data() + ciphertext_.size() - size, data, size); } void outbound_frame_processor::reset() { - codec_ = Codec::Unknown; - frameIndex_ = 0; - unencryptedBytes_.clear(); - encryptedBytes_.clear(); - unencryptedRanges_.clear(); + codec_ = Codec::Unknown; + frameIndex_ = 0; + unencryptedBytes_.clear(); + encryptedBytes_.clear(); + unencryptedRanges_.clear(); } void outbound_frame_processor::process_frame(array_view frame, Codec codec) { reset(); - codec_ = codec; - unencryptedBytes_.reserve(frame.size()); - encryptedBytes_.reserve(frame.size()); - - bool success = false; - switch (codec) { - case Codec::Opus: - success = codec_utils::process_frame_opus(*this, frame); - break; - case Codec::VP8: - success = codec_utils::process_frame_vp8(*this, frame); - break; - case Codec::VP9: - success = codec_utils::process_frame_vp9(*this, frame); - break; - case Codec::H264: - success = codec_utils::process_frame_h264(*this, frame); - break; - case Codec::H265: - success = codec_utils::process_frame_h265(*this, frame); - break; - case Codec::AV1: - success = codec_utils::process_frame_av1(*this, frame); - break; - default: - assert(false && "Unsupported codec for frame encryption"); - break; - } - - if (!success) { - frameIndex_ = 0; - unencryptedBytes_.clear(); - encryptedBytes_.clear(); - unencryptedRanges_.clear(); - add_encrypted_bytes(frame.data(), frame.size()); - } - - ciphertextBytes_.resize(encryptedBytes_.size()); + codec_ = codec; + unencryptedBytes_.reserve(frame.size()); + encryptedBytes_.reserve(frame.size()); + + bool success = false; + switch (codec) { + case Codec::Opus: + success = codec_utils::process_frame_opus(*this, frame); + break; + case Codec::VP8: + success = codec_utils::process_frame_vp8(*this, frame); + break; + case Codec::VP9: + success = codec_utils::process_frame_vp9(*this, frame); + break; + case Codec::H264: + success = codec_utils::process_frame_h264(*this, frame); + break; + case Codec::H265: + success = codec_utils::process_frame_h265(*this, frame); + break; + case Codec::AV1: + success = codec_utils::process_frame_av1(*this, frame); + break; + default: + assert(false && "Unsupported codec for frame encryption"); + break; + } + + if (!success) { + frameIndex_ = 0; + unencryptedBytes_.clear(); + encryptedBytes_.clear(); + unencryptedRanges_.clear(); + add_encrypted_bytes(frame.data(), frame.size()); + } + + ciphertextBytes_.resize(encryptedBytes_.size()); } size_t outbound_frame_processor::reconstruct_frame(array_view frame) { - if (unencryptedBytes_.size() + ciphertextBytes_.size() > frame.size()) { - DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the encrypted frame"; - return 0; - } + if (unencryptedBytes_.size() + ciphertextBytes_.size() > frame.size()) { + DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the encrypted frame"; + return 0; + } - return Reconstruct(unencryptedRanges_, unencryptedBytes_, ciphertextBytes_, frame); + return Reconstruct(unencryptedRanges_, unencryptedBytes_, ciphertextBytes_, frame); } void outbound_frame_processor::add_unencrypted_bytes(const uint8_t* bytes, size_t size) { - if (!unencryptedRanges_.empty() && - unencryptedRanges_.back().offset + unencryptedRanges_.back().size == frameIndex_) { - // extend the last range - unencryptedRanges_.back().size += size; - } - else { - // add a new range (offset, size) - unencryptedRanges_.push_back({frameIndex_, size}); - } - - unencryptedBytes_.resize(unencryptedBytes_.size() + size); - memcpy(unencryptedBytes_.data() + unencryptedBytes_.size() - size, bytes, size); - frameIndex_ += size; + if (!unencryptedRanges_.empty() && + unencryptedRanges_.back().offset + unencryptedRanges_.back().size == frameIndex_) { + // extend the last range + unencryptedRanges_.back().size += size; + } + else { + // add a new range (offset, size) + unencryptedRanges_.push_back({frameIndex_, size}); + } + + unencryptedBytes_.resize(unencryptedBytes_.size() + size); + memcpy(unencryptedBytes_.data() + unencryptedBytes_.size() - size, bytes, size); + frameIndex_ += size; } void outbound_frame_processor::add_encrypted_bytes(const uint8_t* bytes, size_t size) { - encryptedBytes_.resize(encryptedBytes_.size() + size); - memcpy(encryptedBytes_.data() + encryptedBytes_.size() - size, bytes, size); - frameIndex_ += size; + encryptedBytes_.resize(encryptedBytes_.size() + size); + memcpy(encryptedBytes_.data() + encryptedBytes_.size() - size, bytes, size); + frameIndex_ += size; } } // namespace dpp::dave diff --git a/src/dpp/dave/frame_processors.h b/src/dpp/dave/frame_processors.h index 9a5bb63d19..224bf5ab1e 100755 --- a/src/dpp/dave/frame_processors.h +++ b/src/dpp/dave/frame_processors.h @@ -34,77 +34,77 @@ namespace dpp::dave { struct Range { - size_t offset; - size_t size; + size_t offset; + size_t size; }; using Ranges = std::vector; uint8_t UnencryptedRangesSize(const Ranges& unencryptedRanges); uint8_t SerializeUnencryptedRanges(const Ranges& unencryptedRanges, - uint8_t* buffer, - size_t bufferSize); + uint8_t* buffer, + size_t bufferSize); uint8_t DeserializeUnencryptedRanges(const uint8_t*& buffer, - const size_t bufferSize, - Ranges& unencryptedRanges); + const size_t bufferSize, + Ranges& unencryptedRanges); bool ValidateUnencryptedRanges(const Ranges& unencryptedRanges, size_t frameSize); class inbound_frame_processor { public: - void parse_frame(array_view frame); - [[nodiscard]] size_t reconstruct_frame(array_view frame) const; + void parse_frame(array_view frame); + [[nodiscard]] size_t reconstruct_frame(array_view frame) const; - [[nodiscard]] bool is_encrypted() const { return isEncrypted_; } - [[nodiscard]] size_t size() const { return originalSize_; } - void clear(); + [[nodiscard]] bool is_encrypted() const { return isEncrypted_; } + [[nodiscard]] size_t size() const { return originalSize_; } + void clear(); - [[nodiscard]] array_view get_tag() const { return tag_; } - [[nodiscard]] truncated_sync_nonce get_truncated_nonce() const { return truncatedNonce_; } - array_view get_authenticated_data() const - { - return make_array_view(authenticated_.data(), authenticated_.size()); - } - array_view GetCiphertext() const - { - return make_array_view(ciphertext_.data(), ciphertext_.size()); - } - array_view get_plaintext() { return make_array_view(plaintext_); } + [[nodiscard]] array_view get_tag() const { return tag_; } + [[nodiscard]] truncated_sync_nonce get_truncated_nonce() const { return truncatedNonce_; } + array_view get_authenticated_data() const + { + return make_array_view(authenticated_.data(), authenticated_.size()); + } + array_view GetCiphertext() const + { + return make_array_view(ciphertext_.data(), ciphertext_.size()); + } + array_view get_plaintext() { return make_array_view(plaintext_); } private: - void add_authenticated_bytes(const uint8_t* data, size_t size); - void add_ciphertext_bytes(const uint8_t* data, size_t size); + void add_authenticated_bytes(const uint8_t* data, size_t size); + void add_ciphertext_bytes(const uint8_t* data, size_t size); - bool isEncrypted_{false}; - size_t originalSize_{0}; - array_view tag_; - truncated_sync_nonce truncatedNonce_; - Ranges unencryptedRanges_; - std::vector authenticated_; - std::vector ciphertext_; - std::vector plaintext_; + bool isEncrypted_{false}; + size_t originalSize_{0}; + array_view tag_; + truncated_sync_nonce truncatedNonce_; + Ranges unencryptedRanges_; + std::vector authenticated_; + std::vector ciphertext_; + std::vector plaintext_; }; class outbound_frame_processor { public: - void process_frame(array_view frame, Codec codec); - size_t reconstruct_frame(array_view frame); + void process_frame(array_view frame, Codec codec); + size_t reconstruct_frame(array_view frame); - Codec get_codec() const { return codec_; } - [[nodiscard]] const std::vector& get_unencrypted_bytes() const { return unencryptedBytes_; } - [[nodiscard]] const std::vector& get_encrypted_bytes() const { return encryptedBytes_; } - [[nodiscard]] std::vector& get_ciphertext_bytes() { return ciphertextBytes_; } - [[nodiscard]] const Ranges& get_unencrypted_ranges() const { return unencryptedRanges_; } + Codec get_codec() const { return codec_; } + [[nodiscard]] const std::vector& get_unencrypted_bytes() const { return unencryptedBytes_; } + [[nodiscard]] const std::vector& get_encrypted_bytes() const { return encryptedBytes_; } + [[nodiscard]] std::vector& get_ciphertext_bytes() { return ciphertextBytes_; } + [[nodiscard]] const Ranges& get_unencrypted_ranges() const { return unencryptedRanges_; } - void reset(); - void add_unencrypted_bytes(const uint8_t* bytes, size_t size); - void add_encrypted_bytes(const uint8_t* bytes, size_t size); + void reset(); + void add_unencrypted_bytes(const uint8_t* bytes, size_t size); + void add_encrypted_bytes(const uint8_t* bytes, size_t size); private: - Codec codec_{Codec::Unknown}; - size_t frameIndex_{0}; - std::vector unencryptedBytes_; - std::vector encryptedBytes_; - std::vector ciphertextBytes_; - Ranges unencryptedRanges_; + Codec codec_{Codec::Unknown}; + size_t frameIndex_{0}; + std::vector unencryptedBytes_; + std::vector encryptedBytes_; + std::vector ciphertextBytes_; + Ranges unencryptedRanges_; }; } // namespace dpp::dave diff --git a/src/dpp/dave/key_ratchet.h b/src/dpp/dave/key_ratchet.h index 1fd84d7d1a..64df8bc24e 100755 --- a/src/dpp/dave/key_ratchet.h +++ b/src/dpp/dave/key_ratchet.h @@ -33,9 +33,9 @@ using KeyGeneration = uint32_t; class IKeyRatchet { public: - virtual ~IKeyRatchet() noexcept = default; - virtual encryption_key GetKey(KeyGeneration generation) noexcept = 0; - virtual void DeleteKey(KeyGeneration generation) noexcept = 0; + virtual ~IKeyRatchet() noexcept = default; + virtual encryption_key GetKey(KeyGeneration generation) noexcept = 0; + virtual void DeleteKey(KeyGeneration generation) noexcept = 0; }; } // namespace dpp::dave diff --git a/src/dpp/dave/leb128.cpp b/src/dpp/dave/leb128.cpp index 6f168ad6e2..7c86aff298 100755 --- a/src/dpp/dave/leb128.cpp +++ b/src/dpp/dave/leb128.cpp @@ -34,51 +34,51 @@ namespace dpp::dave { size_t Leb128Size(uint64_t value) { - int size = 0; - while (value >= 0x80) { - ++size; - value >>= 7; - } - return size + 1; + int size = 0; + while (value >= 0x80) { + ++size; + value >>= 7; + } + return size + 1; } uint64_t ReadLeb128(const uint8_t*& readAt, const uint8_t* end) { - uint64_t value = 0; - int fillBits = 0; - while (readAt != end && fillBits < 64 - 7) { - uint8_t leb128Byte = *readAt; - value |= uint64_t{leb128Byte & 0x7Fu} << fillBits; - ++readAt; - fillBits += 7; - if ((leb128Byte & 0x80) == 0) { - return value; - } - } - // Read 9 bytes and didn't find the terminator byte. Check if 10th byte - // is that terminator, however to fit result into uint64_t it may carry only - // single bit. - if (readAt != end && *readAt <= 1) { - value |= uint64_t{*readAt} << fillBits; - ++readAt; - return value; - } - // Failed to find terminator leb128 byte. - readAt = nullptr; - return 0; + uint64_t value = 0; + int fillBits = 0; + while (readAt != end && fillBits < 64 - 7) { + uint8_t leb128Byte = *readAt; + value |= uint64_t{leb128Byte & 0x7Fu} << fillBits; + ++readAt; + fillBits += 7; + if ((leb128Byte & 0x80) == 0) { + return value; + } + } + // Read 9 bytes and didn't find the terminator byte. Check if 10th byte + // is that terminator, however to fit result into uint64_t it may carry only + // single bit. + if (readAt != end && *readAt <= 1) { + value |= uint64_t{*readAt} << fillBits; + ++readAt; + return value; + } + // Failed to find terminator leb128 byte. + readAt = nullptr; + return 0; } size_t WriteLeb128(uint64_t value, uint8_t* buffer) { - int size = 0; - while (value >= 0x80) { - buffer[size] = 0x80 | (value & 0x7F); - ++size; - value >>= 7; - } - buffer[size] = value; - ++size; - return size; + int size = 0; + while (value >= 0x80) { + buffer[size] = 0x80 | (value & 0x7F); + ++size; + value >>= 7; + } + buffer[size] = value; + ++size; + return size; } } // namespace dpp::dave diff --git a/src/dpp/dave/logger.cpp b/src/dpp/dave/logger.cpp index a0e855a7ad..b847e4ecd9 100755 --- a/src/dpp/dave/logger.cpp +++ b/src/dpp/dave/logger.cpp @@ -34,7 +34,7 @@ std::atomic gLogSink = nullptr; void SetLogSink(LogSink sink) { - gLogSink = sink; + gLogSink = sink; } LogStreamer::LogStreamer(LoggingSeverity severity, const char* file, int line) @@ -46,32 +46,32 @@ LogStreamer::LogStreamer(LoggingSeverity severity, const char* file, int line) LogStreamer::~LogStreamer() { - std::string logLine = stream_.str(); - if (logLine.empty()) { - return; - } + std::string logLine = stream_.str(); + if (logLine.empty()) { + return; + } - auto sink = gLogSink.load(); - if (sink) { - sink(severity_, file_, line_, logLine); - return; - } + auto sink = gLogSink.load(); + if (sink) { + sink(severity_, file_, line_, logLine); + return; + } - switch (severity_) { - case LS_VERBOSE: - case LS_INFO: - case LS_WARNING: - case LS_ERROR: { - const char* file = file_; - if (auto separator = strrchr(file, '/')) { - file = separator + 1; - } - std::cout << "(" << file << ":" << line_ << ") " << logLine << std::endl; - break; - } - case LS_NONE: - break; - } + switch (severity_) { + case LS_VERBOSE: + case LS_INFO: + case LS_WARNING: + case LS_ERROR: { + const char* file = file_; + if (auto separator = strrchr(file, '/')) { + file = separator + 1; + } + std::cout << "(" << file << ":" << line_ << ") " << logLine << std::endl; + break; + } + case LS_NONE: + break; + } } } // namespace dpp::dave diff --git a/src/dpp/dave/logger.h b/src/dpp/dave/logger.h index 8e4b40fcd5..8b4dca8c14 100755 --- a/src/dpp/dave/logger.h +++ b/src/dpp/dave/logger.h @@ -34,36 +34,36 @@ namespace dpp::dave { enum LoggingSeverity { - LS_VERBOSE, - LS_INFO, - LS_WARNING, - LS_ERROR, - LS_NONE, + LS_VERBOSE, + LS_INFO, + LS_WARNING, + LS_ERROR, + LS_NONE, }; using LogSink = void (*)(LoggingSeverity severity, - const char* file, - int line, - const std::string& message); + const char* file, + int line, + const std::string& message); void SetLogSink(LogSink sink); class LogStreamer { public: - LogStreamer(LoggingSeverity severity, const char* file, int line); - ~LogStreamer(); + LogStreamer(LoggingSeverity severity, const char* file, int line); + ~LogStreamer(); - template - LogStreamer& operator<<(const T& value) - { - stream_ << value; - return *this; - } + template + LogStreamer& operator<<(const T& value) + { + stream_ << value; + return *this; + } private: - LoggingSeverity severity_; - const char* file_; - int line_; - std::ostringstream stream_; + LoggingSeverity severity_; + const char* file_; + int line_; + std::ostringstream stream_; }; } // namespace dpp::dave diff --git a/src/dpp/dave/mls_key_ratchet.cpp b/src/dpp/dave/mls_key_ratchet.cpp index eb1d34b2b3..0a349a77e6 100755 --- a/src/dpp/dave/mls_key_ratchet.cpp +++ b/src/dpp/dave/mls_key_ratchet.cpp @@ -39,23 +39,23 @@ MlsKeyRatchet::~MlsKeyRatchet() noexcept = default; encryption_key MlsKeyRatchet::GetKey(KeyGeneration generation) noexcept { - DISCORD_LOG(LS_INFO) << "Retrieving key for generation " << generation << " from HashRatchet"; - - try { - auto keyAndNonce = hashRatchet_.get(generation); - assert(keyAndNonce.key.size() >= AES_GCM_128_KEY_BYTES); - return std::move(keyAndNonce.key.as_vec()); - } - catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to retrieve key for generation " << generation << ": " - << e.what(); - return {}; - } + DISCORD_LOG(LS_INFO) << "Retrieving key for generation " << generation << " from HashRatchet"; + + try { + auto keyAndNonce = hashRatchet_.get(generation); + assert(keyAndNonce.key.size() >= AES_GCM_128_KEY_BYTES); + return std::move(keyAndNonce.key.as_vec()); + } + catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to retrieve key for generation " << generation << ": " + << e.what(); + return {}; + } } void MlsKeyRatchet::DeleteKey(KeyGeneration generation) noexcept { - hashRatchet_.erase(generation); + hashRatchet_.erase(generation); } } // namespace dpp::dave diff --git a/src/dpp/dave/mls_key_ratchet.h b/src/dpp/dave/mls_key_ratchet.h index 74d8b4bde4..f6489e4ad1 100755 --- a/src/dpp/dave/mls_key_ratchet.h +++ b/src/dpp/dave/mls_key_ratchet.h @@ -31,14 +31,14 @@ namespace dpp::dave { class MlsKeyRatchet : public IKeyRatchet { public: - MlsKeyRatchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept; - ~MlsKeyRatchet() noexcept override; + MlsKeyRatchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept; + ~MlsKeyRatchet() noexcept override; - encryption_key GetKey(KeyGeneration generation) noexcept override; - void DeleteKey(KeyGeneration generation) noexcept override; + encryption_key GetKey(KeyGeneration generation) noexcept override; + void DeleteKey(KeyGeneration generation) noexcept override; private: - ::mlspp::HashRatchet hashRatchet_; + ::mlspp::HashRatchet hashRatchet_; }; } // namespace dpp::dave diff --git a/src/dpp/dave/openssl_aead_cipher.cpp b/src/dpp/dave/openssl_aead_cipher.cpp index 545e04d7ee..90532ffba6 100755 --- a/src/dpp/dave/openssl_aead_cipher.cpp +++ b/src/dpp/dave/openssl_aead_cipher.cpp @@ -31,12 +31,12 @@ namespace dpp::dave { void PrintSSLErrors() { - ERR_print_errors_cb( - [](const char* str, size_t len, void* ctx) -> int { - DISCORD_LOG(LS_ERROR) << std::string(str, len); - return 1; - }, - nullptr); + ERR_print_errors_cb( + [](const char* str, size_t len, void* ctx) -> int { + DISCORD_LOG(LS_ERROR) << std::string(str, len); + return 1; + }, + nullptr); } openssl_aead_cipher::openssl_aead_cipher(const encryption_key& encryptionKey) : diff --git a/src/dpp/dave/parameters.cpp b/src/dpp/dave/parameters.cpp index b4ddba3909..fbc436e597 100755 --- a/src/dpp/dave/parameters.cpp +++ b/src/dpp/dave/parameters.cpp @@ -28,50 +28,50 @@ namespace dpp::dave::mls { ::mlspp::CipherSuite::ID CiphersuiteIDForProtocolVersion(ProtocolVersion version) noexcept { - return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; + return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; } ::mlspp::CipherSuite CiphersuiteForProtocolVersion(ProtocolVersion version) noexcept { - return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; + return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; } ::mlspp::CipherSuite::ID CiphersuiteIDForSignatureVersion(SignatureVersion version) noexcept { - return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; + return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; } ::mlspp::CipherSuite CiphersuiteForSignatureVersion(SignatureVersion version) noexcept { - return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; + return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; } ::mlspp::Capabilities LeafNodeCapabilitiesForProtocolVersion(ProtocolVersion version) noexcept { - auto capabilities = ::mlspp::Capabilities::create_default(); + auto capabilities = ::mlspp::Capabilities::create_default(); - capabilities.cipher_suites = {CiphersuiteIDForProtocolVersion(version)}; - capabilities.credentials = {::mlspp::CredentialType::basic}; + capabilities.cipher_suites = {CiphersuiteIDForProtocolVersion(version)}; + capabilities.credentials = {::mlspp::CredentialType::basic}; - return capabilities; + return capabilities; } ::mlspp::ExtensionList LeafNodeExtensionsForProtocolVersion(ProtocolVersion version) noexcept { - return ::mlspp::ExtensionList{}; + return ::mlspp::ExtensionList{}; } ::mlspp::ExtensionList GroupExtensionsForProtocolVersion( ProtocolVersion version, const ::mlspp::ExternalSender& externalSender) noexcept { - auto extensionList = ::mlspp::ExtensionList{}; + auto extensionList = ::mlspp::ExtensionList{}; - extensionList.add(::mlspp::ExternalSendersExtension{{ - {externalSender.signature_key, externalSender.credential}, - }}); + extensionList.add(::mlspp::ExternalSendersExtension{{ + {externalSender.signature_key, externalSender.credential}, + }}); - return extensionList; + return extensionList; } } // namespace dpp::dave::mls diff --git a/src/dpp/dave/persisted_key_pair.cpp b/src/dpp/dave/persisted_key_pair.cpp index 23f56bd84c..014f8ae890 100755 --- a/src/dpp/dave/persisted_key_pair.cpp +++ b/src/dpp/dave/persisted_key_pair.cpp @@ -43,8 +43,8 @@ static const std::string SelfSignatureLabel = "DiscordSelfSignature"; static std::string MakeKeyID(const std::string& sessionID, ::mlspp::CipherSuite suite) { - return sessionID + "-" + std::to_string((uint16_t)suite.cipher_suite()) + "-" + - std::to_string(dpp::dave::mls::KeyVersion); + return sessionID + "-" + std::to_string((uint16_t)suite.cipher_suite()) + "-" + + std::to_string(dpp::dave::mls::KeyVersion); } static std::mutex mtx; @@ -57,64 +57,64 @@ static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair( const std::string& sessionID, ::mlspp::CipherSuite suite) { - std::lock_guard lk(mtx); + std::lock_guard lk(mtx); - std::string id = MakeKeyID(sessionID, suite); + std::string id = MakeKeyID(sessionID, suite); - if (auto it = map.find(id); it != map.end()) { - return it->second; - } + if (auto it = map.find(id); it != map.end()) { + return it->second; + } - std::shared_ptr<::mlspp::SignaturePrivateKey> ret = ::dpp::dave::mls::detail::GetGenericPersistedKeyPair(ctx, id, suite); + std::shared_ptr<::mlspp::SignaturePrivateKey> ret = ::dpp::dave::mls::detail::GetGenericPersistedKeyPair(ctx, id, suite); - if (!ret) { - DISCORD_LOG(LS_ERROR) << "Failed to get key in GetPersistedKeyPair"; - return nullptr; - } + if (!ret) { + DISCORD_LOG(LS_ERROR) << "Failed to get key in GetPersistedKeyPair"; + return nullptr; + } - map.emplace(id, ret); + map.emplace(id, ret); - return ret; + return ret; } std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(KeyPairContextType ctx, - const std::string& sessionID, - ProtocolVersion version) + const std::string& sessionID, + ProtocolVersion version) { - return GetPersistedKeyPair(ctx, sessionID, CiphersuiteForProtocolVersion(version)); + return GetPersistedKeyPair(ctx, sessionID, CiphersuiteForProtocolVersion(version)); } KeyAndSelfSignature GetPersistedPublicKey(KeyPairContextType ctx, - const std::string& sessionID, - SignatureVersion version) + const std::string& sessionID, + SignatureVersion version) { - auto suite = CiphersuiteForSignatureVersion(version); + auto suite = CiphersuiteForSignatureVersion(version); - auto pair = GetPersistedKeyPair(ctx, sessionID, suite); + auto pair = GetPersistedKeyPair(ctx, sessionID, suite); - if (!pair) { - return {}; - } + if (!pair) { + return {}; + } - bytes sign_data = from_ascii(sessionID + ":") + pair->public_key.data; + bytes sign_data = from_ascii(sessionID + ":") + pair->public_key.data; - return { - pair->public_key.data.as_vec(), - std::move(pair->sign(suite, SelfSignatureLabel, sign_data).as_vec()), - }; + return { + pair->public_key.data.as_vec(), + std::move(pair->sign(suite, SelfSignatureLabel, sign_data).as_vec()), + }; } bool DeletePersistedKeyPair(KeyPairContextType ctx, - const std::string& sessionID, - SignatureVersion version) + const std::string& sessionID, + SignatureVersion version) { - std::string id = MakeKeyID(sessionID, CiphersuiteForSignatureVersion(version)); + std::string id = MakeKeyID(sessionID, CiphersuiteForSignatureVersion(version)); - std::lock_guard lk(mtx); + std::lock_guard lk(mtx); - map.erase(id); + map.erase(id); - return ::dpp::dave::mls::detail::DeleteGenericPersistedKeyPair(ctx, id); + return ::dpp::dave::mls::detail::DeleteGenericPersistedKeyPair(ctx, id); } } // namespace dpp::dave::mls diff --git a/src/dpp/dave/persisted_key_pair.h b/src/dpp/dave/persisted_key_pair.h index 65bb93ba30..82af384a7b 100755 --- a/src/dpp/dave/persisted_key_pair.h +++ b/src/dpp/dave/persisted_key_pair.h @@ -47,21 +47,21 @@ namespace dpp::dave::mls { using KeyPairContextType = const char *; std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(KeyPairContextType ctx, - const std::string& sessionID, - ProtocolVersion version); + const std::string& sessionID, + ProtocolVersion version); struct KeyAndSelfSignature { - std::vector key; - std::vector signature; + std::vector key; + std::vector signature; }; KeyAndSelfSignature GetPersistedPublicKey(KeyPairContextType ctx, - const std::string& sessionID, - SignatureVersion version); + const std::string& sessionID, + SignatureVersion version); bool DeletePersistedKeyPair(KeyPairContextType ctx, - const std::string& sessionID, - SignatureVersion version); + const std::string& sessionID, + SignatureVersion version); constexpr unsigned KeyVersion = 1; diff --git a/src/dpp/dave/persisted_key_pair_generic.cpp b/src/dpp/dave/persisted_key_pair_generic.cpp index 7bd3411f1c..255375318b 100755 --- a/src/dpp/dave/persisted_key_pair_generic.cpp +++ b/src/dpp/dave/persisted_key_pair_generic.cpp @@ -49,148 +49,148 @@ static const std::string_view KeyStorageDir = "Discord Key Storage"; static std::filesystem::path GetKeyStorageDirectory() { - std::filesystem::path dir; + std::filesystem::path dir; #if defined(__ANDROID__) - dir = std::filesystem::path("/data/data"); - - { - std::ifstream idFile("/proc/self/cmdline", std::ios_base::in); - std::string appId; - std::getline(idFile, appId, '\0'); - dir /= appId; - } + dir = std::filesystem::path("/data/data"); + + { + std::ifstream idFile("/proc/self/cmdline", std::ios_base::in); + std::string appId; + std::getline(idFile, appId, '\0'); + dir /= appId; + } #else // __ANDROID__ #if defined(_WIN32) - if (const wchar_t* appdata = _wgetenv(L"LOCALAPPDATA")) { - dir = std::filesystem::path(appdata); - } + if (const wchar_t* appdata = _wgetenv(L"LOCALAPPDATA")) { + dir = std::filesystem::path(appdata); + } #else // _WIN32 - if (const char* xdg = getenv("XDG_CONFIG_HOME")) { - dir = std::filesystem::path(xdg); - } - else if (const char* home = getenv("HOME")) { - dir = std::filesystem::path(home); - dir /= ".config"; - } + if (const char* xdg = getenv("XDG_CONFIG_HOME")) { + dir = std::filesystem::path(xdg); + } + else if (const char* home = getenv("HOME")) { + dir = std::filesystem::path(home); + dir /= ".config"; + } #endif // !_WIN32 - else { - return dir; - } + else { + return dir; + } #endif // !__ANDROID__ - return dir / KeyStorageDir; + return dir / KeyStorageDir; } namespace dpp::dave::mls::detail { std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPairContextType ctx, - const std::string& id, - ::mlspp::CipherSuite suite) + const std::string& id, + ::mlspp::CipherSuite suite) { - ::mlspp::SignaturePrivateKey ret; - std::string curstr; - std::filesystem::path dir = GetKeyStorageDirectory(); - - if (dir.empty()) { - DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair"; - return nullptr; - } - - std::error_code errc; - std::filesystem::create_directories(dir, errc); - if (errc) { - DISCORD_LOG(LS_ERROR) << "Failed to create key storage directory in GetPersistedKeyPair: " - << errc; - return nullptr; - } - - std::filesystem::path file = dir / (id + ".key"); - - if (std::filesystem::exists(file)) { - std::ifstream ifs(file, std::ios_base::in | std::ios_base::binary); - if (!ifs) { - DISCORD_LOG(LS_ERROR) << "Failed to open key in GetPersistedKeyPair"; - return nullptr; - } + ::mlspp::SignaturePrivateKey ret; + std::string curstr; + std::filesystem::path dir = GetKeyStorageDirectory(); + + if (dir.empty()) { + DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair"; + return nullptr; + } + + std::error_code errc; + std::filesystem::create_directories(dir, errc); + if (errc) { + DISCORD_LOG(LS_ERROR) << "Failed to create key storage directory in GetPersistedKeyPair: " + << errc; + return nullptr; + } + + std::filesystem::path file = dir / (id + ".key"); + + if (std::filesystem::exists(file)) { + std::ifstream ifs(file, std::ios_base::in | std::ios_base::binary); + if (!ifs) { + DISCORD_LOG(LS_ERROR) << "Failed to open key in GetPersistedKeyPair"; + return nullptr; + } std::stringstream s; s << ifs.rdbuf(); curstr = s.str(); - if (!ifs) { - DISCORD_LOG(LS_ERROR) << "Failed to read key in GetPersistedKeyPair"; - return nullptr; - } - - try { - ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr); - } - catch (std::exception& ex) { - DISCORD_LOG(LS_ERROR) << "Failed to parse key in GetPersistedKeyPair: " << ex.what(); - return nullptr; - } - } - else { - ret = ::mlspp::SignaturePrivateKey::generate(suite); - - std::string newstr = ret.to_jwk(suite); - - std::filesystem::path tmpfile = file; - tmpfile += ".tmp"; + if (!ifs) { + DISCORD_LOG(LS_ERROR) << "Failed to read key in GetPersistedKeyPair"; + return nullptr; + } + + try { + ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr); + } + catch (std::exception& ex) { + DISCORD_LOG(LS_ERROR) << "Failed to parse key in GetPersistedKeyPair: " << ex.what(); + return nullptr; + } + } + else { + ret = ::mlspp::SignaturePrivateKey::generate(suite); + + std::string newstr = ret.to_jwk(suite); + + std::filesystem::path tmpfile = file; + tmpfile += ".tmp"; #ifdef _WIN32 - int fd = _wopen(tmpfile.c_str(), _O_WRONLY | _O_CREAT | _O_TRUNC, _S_IREAD | _S_IWRITE); + int fd = _wopen(tmpfile.c_str(), _O_WRONLY | _O_CREAT | _O_TRUNC, _S_IREAD | _S_IWRITE); #else - int fd = open(tmpfile.c_str(), - O_WRONLY | O_CLOEXEC | O_NOFOLLOW | O_CREAT | O_TRUNC, - S_IRUSR | S_IWUSR); + int fd = open(tmpfile.c_str(), + O_WRONLY | O_CLOEXEC | O_NOFOLLOW | O_CREAT | O_TRUNC, + S_IRUSR | S_IWUSR); #endif - if (fd < 0) { - DISCORD_LOG(LS_ERROR) << "Failed to open output file in GetPersistedKeyPair: " << errno - << "(" << tmpfile << ")"; - return nullptr; - } + if (fd < 0) { + DISCORD_LOG(LS_ERROR) << "Failed to open output file in GetPersistedKeyPair: " << errno + << "(" << tmpfile << ")"; + return nullptr; + } #ifdef _WIN32 - int wret = _write(fd, newstr.c_str(), newstr.size()); - _close(fd); + int wret = _write(fd, newstr.c_str(), newstr.size()); + _close(fd); #else - ssize_t wret = write(fd, newstr.c_str(), newstr.size()); - close(fd); + ssize_t wret = write(fd, newstr.c_str(), newstr.size()); + close(fd); #endif - if (wret < 0 || (size_t)wret != newstr.size()) { - DISCORD_LOG(LS_ERROR) << "Failed to write output file in GetPersistedKeyPair: " - << errno; - return nullptr; - } - - std::filesystem::rename(tmpfile, file, errc); - if (errc) { - DISCORD_LOG(LS_ERROR) << "Failed to rename output file in GetPersistedKeyPair: " - << errc; - return nullptr; - } - } - - if (!ret.public_key.data.empty()) { - return std::make_shared<::mlspp::SignaturePrivateKey>(std::move(ret)); - } - return nullptr; + if (wret < 0 || (size_t)wret != newstr.size()) { + DISCORD_LOG(LS_ERROR) << "Failed to write output file in GetPersistedKeyPair: " + << errno; + return nullptr; + } + + std::filesystem::rename(tmpfile, file, errc); + if (errc) { + DISCORD_LOG(LS_ERROR) << "Failed to rename output file in GetPersistedKeyPair: " + << errc; + return nullptr; + } + } + + if (!ret.public_key.data.empty()) { + return std::make_shared<::mlspp::SignaturePrivateKey>(std::move(ret)); + } + return nullptr; } bool DeleteGenericPersistedKeyPair(KeyPairContextType ctx, const std::string& id) { - std::error_code errc; - std::filesystem::path dir = GetKeyStorageDirectory(); - if (dir.empty()) { - DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair"; - return false; - } + std::error_code errc; + std::filesystem::path dir = GetKeyStorageDirectory(); + if (dir.empty()) { + DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair"; + return false; + } - std::filesystem::path file = dir / (id + ".key"); + std::filesystem::path file = dir / (id + ".key"); - return std::filesystem::remove(file, errc); + return std::filesystem::remove(file, errc); } } // namespace dpp::dave::mls::detail diff --git a/src/dpp/dave/scope_exit.h b/src/dpp/dave/scope_exit.h index eca6b7a583..219c779135 100755 --- a/src/dpp/dave/scope_exit.h +++ b/src/dpp/dave/scope_exit.h @@ -32,39 +32,39 @@ namespace dpp::dave { class [[nodiscard]] ScopeExit final { public: - template - explicit ScopeExit(Cleanup&& cleanup) - : cleanup_{std::forward(cleanup)} - { - } + template + explicit ScopeExit(Cleanup&& cleanup) + : cleanup_{std::forward(cleanup)} + { + } - ScopeExit(ScopeExit&& rhs) - : cleanup_{std::move(rhs.cleanup_)} - { - rhs.cleanup_ = nullptr; - } + ScopeExit(ScopeExit&& rhs) + : cleanup_{std::move(rhs.cleanup_)} + { + rhs.cleanup_ = nullptr; + } - ~ScopeExit() - { - if (cleanup_) { - cleanup_(); - } - } + ~ScopeExit() + { + if (cleanup_) { + cleanup_(); + } + } - ScopeExit& operator=(ScopeExit&& rhs) - { - cleanup_ = std::move(rhs.cleanup_); - rhs.cleanup_ = nullptr; - return *this; - } + ScopeExit& operator=(ScopeExit&& rhs) + { + cleanup_ = std::move(rhs.cleanup_); + rhs.cleanup_ = nullptr; + return *this; + } - void Dismiss() { cleanup_ = nullptr; } + void Dismiss() { cleanup_ = nullptr; } private: - ScopeExit(ScopeExit const&) = delete; - ScopeExit& operator=(ScopeExit const&) = delete; + ScopeExit(ScopeExit const&) = delete; + ScopeExit& operator=(ScopeExit const&) = delete; - std::function cleanup_; + std::function cleanup_; }; } // namespace dpp::dave diff --git a/src/dpp/dave/session.cpp b/src/dpp/dave/session.cpp index e52a5ac16b..56e8bcce54 100755 --- a/src/dpp/dave/session.cpp +++ b/src/dpp/dave/session.cpp @@ -44,763 +44,763 @@ #include "openssl/evp.h" -#define TRACK_MLS_ERROR(reason) \ - if (onMLSFailureCallback_) { \ - onMLSFailureCallback_(__FUNCTION__, reason); \ - } +#define TRACK_MLS_ERROR(reason) \ + if (onMLSFailureCallback_) { \ + onMLSFailureCallback_(__FUNCTION__, reason); \ + } namespace dpp::dave::mls { struct QueuedProposal { - ::mlspp::ValidatedContent content; - ::mlspp::bytes_ns::bytes ref; + ::mlspp::ValidatedContent content; + ::mlspp::bytes_ns::bytes ref; }; Session::Session(KeyPairContextType context, - const std::string& authSessionId, - MLSFailureCallback callback) noexcept + const std::string& authSessionId, + MLSFailureCallback callback) noexcept : signingKeyId_(authSessionId) , keyPairContext_(context) , onMLSFailureCallback_(std::move(callback)) { - DISCORD_LOG(LS_INFO) << "Creating a new MLS session"; + DISCORD_LOG(LS_INFO) << "Creating a new MLS session"; } Session::~Session() noexcept = default; void Session::Init(ProtocolVersion protocolVersion, - uint64_t groupId, - std::string const& selfUserId, - std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept + uint64_t groupId, + std::string const& selfUserId, + std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept { - Reset(); + Reset(); - selfUserId_ = selfUserId; + selfUserId_ = selfUserId; - DISCORD_LOG(LS_INFO) << "Initializing MLS session with protocol version " << protocolVersion - << " and group ID " << groupId; - protocolVersion_ = protocolVersion; - groupId_ = std::move(BigEndianBytesFrom(groupId).as_vec()); + DISCORD_LOG(LS_INFO) << "Initializing MLS session with protocol version " << protocolVersion + << " and group ID " << groupId; + protocolVersion_ = protocolVersion; + groupId_ = std::move(BigEndianBytesFrom(groupId).as_vec()); - InitLeafNode(selfUserId, transientKey); + InitLeafNode(selfUserId, transientKey); - CreatePendingGroup(); + CreatePendingGroup(); } void Session::Reset() noexcept { - DISCORD_LOG(LS_INFO) << "Resetting MLS session"; + DISCORD_LOG(LS_INFO) << "Resetting MLS session"; - ClearPendingState(); + ClearPendingState(); - currentState_.reset(); - outboundCachedGroupState_.reset(); + currentState_.reset(); + outboundCachedGroupState_.reset(); - protocolVersion_ = 0; - groupId_.clear(); + protocolVersion_ = 0; + groupId_.clear(); } void Session::SetProtocolVersion(ProtocolVersion version) noexcept { - if (version != protocolVersion_) { - // when we need to retain backwards compatibility - // there may be some changes to the MLS objects required here - // until then we can just update the stored version - protocolVersion_ = version; - } + if (version != protocolVersion_) { + // when we need to retain backwards compatibility + // there may be some changes to the MLS objects required here + // until then we can just update the stored version + protocolVersion_ = version; + } } std::vector Session::GetLastEpochAuthenticator() const noexcept { - if (!currentState_) { - DISCORD_LOG(LS_ERROR) << "Cannot get epoch authenticator without an established MLS group"; - return {}; - } + if (!currentState_) { + DISCORD_LOG(LS_ERROR) << "Cannot get epoch authenticator without an established MLS group"; + return {}; + } - return std::move(currentState_->epoch_authenticator().as_vec()); + return std::move(currentState_->epoch_authenticator().as_vec()); } void Session::SetExternalSender(const std::vector& marshalledExternalSender) noexcept try { - if (currentState_) { - DISCORD_LOG(LS_ERROR) << "Cannot set external sender after joining/creating an MLS group"; - return; - } + if (currentState_) { + DISCORD_LOG(LS_ERROR) << "Cannot set external sender after joining/creating an MLS group"; + return; + } - DISCORD_LOG(LS_INFO) << "Unmarshalling MLS external sender"; + DISCORD_LOG(LS_INFO) << "Unmarshalling MLS external sender"; - DISCORD_LOG(LS_INFO) << "Sender: " << ::mlspp::bytes_ns::bytes(marshalledExternalSender); + DISCORD_LOG(LS_INFO) << "Sender: " << ::mlspp::bytes_ns::bytes(marshalledExternalSender); - externalSender_ = std::make_unique<::mlspp::ExternalSender>( - ::mlspp::tls::get<::mlspp::ExternalSender>(marshalledExternalSender)); + externalSender_ = std::make_unique<::mlspp::ExternalSender>( + ::mlspp::tls::get<::mlspp::ExternalSender>(marshalledExternalSender)); - if (!groupId_.empty()) { - CreatePendingGroup(); - } + if (!groupId_.empty()) { + CreatePendingGroup(); + } } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to unmarshal external sender: " << e.what(); - TRACK_MLS_ERROR(e.what()); - return; + DISCORD_LOG(LS_ERROR) << "Failed to unmarshal external sender: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return; } std::optional> Session::ProcessProposals( std::vector proposals, std::set const& recognizedUserIDs) noexcept try { - if (!pendingGroupState_ && !currentState_) { - DISCORD_LOG(LS_ERROR) - << "Cannot process proposals without any pending or established MLS group state"; - return std::nullopt; - } - - if (!stateWithProposals_) { - stateWithProposals_ = std::make_unique<::mlspp::State>( - pendingGroupState_ ? *pendingGroupState_ : *currentState_); - } + if (!pendingGroupState_ && !currentState_) { + DISCORD_LOG(LS_ERROR) + << "Cannot process proposals without any pending or established MLS group state"; + return std::nullopt; + } + + if (!stateWithProposals_) { + stateWithProposals_ = std::make_unique<::mlspp::State>( + pendingGroupState_ ? *pendingGroupState_ : *currentState_); + } - DISCORD_LOG(LS_INFO) << "Processing MLS proposals message of " << proposals.size() << " bytes"; - - DISCORD_LOG(LS_INFO) << "Proposals: " << ::mlspp::bytes_ns::bytes(proposals); - - ::mlspp::tls::istream inStream(proposals); - - bool isRevoke = false; - inStream >> isRevoke; - - DISCORD_LOG(LS_INFO) << "Revoking: " << isRevoke; - - const auto suite = stateWithProposals_->cipher_suite(); - - if (isRevoke) { - std::vector<::mlspp::bytes_ns::bytes> refs; - inStream >> refs; - - for (const auto& ref : refs) { - bool found = false; - for (auto it = proposalQueue_.begin(); it != proposalQueue_.end(); it++) { - if (it->ref == ref) { - found = true; - proposalQueue_.erase(it); - break; - } - } - - if (!found) { - DISCORD_LOG(LS_ERROR) << "Cannot revoke unrecognized proposal ref"; - TRACK_MLS_ERROR("Unrecognized proposal revocation"); - return std::nullopt; - } - } - - stateWithProposals_ = std::make_unique<::mlspp::State>( - pendingGroupState_ ? *pendingGroupState_ : *currentState_); - - for (auto& prop : proposalQueue_) { - // success will queue the proposal, failure will throw - stateWithProposals_->handle(prop.content); - } - } - else { - std::vector<::mlspp::MLSMessage> messages; - inStream >> messages; - - for (const auto& proposalMessage : messages) { - auto validatedMessage = stateWithProposals_->unwrap(proposalMessage); - - if (!ValidateProposalMessage(validatedMessage.authenticated_content(), - *stateWithProposals_, - recognizedUserIDs)) { - return std::nullopt; - } - - // success will queue the proposal, failure will throw - stateWithProposals_->handle(validatedMessage); - - auto ref = suite.ref(validatedMessage.authenticated_content()); - - proposalQueue_.push_back({ - std::move(validatedMessage), - std::move(ref), - }); - } - } - - // generate a commit - auto commitSecret = ::mlspp::hpke::random_bytes(suite.secret_size()); - - auto commitOpts = ::mlspp::CommitOpts{ - {}, // no extra proposals - true, // inline tree in welcome - false, // do not force path - {} // default leaf node options - }; - - auto [commitMessage, welcomeMessage, newState] = - stateWithProposals_->commit(commitSecret, commitOpts, {}); - - DISCORD_LOG(LS_INFO) - << "Prepared commit/welcome/next state for MLS group from received proposals"; - - // combine the commit and welcome messages into a single buffer - auto outStream = ::mlspp::tls::ostream(); - outStream << commitMessage; - - // keep a copy of the commit, we can check incoming pending group commit later for a match - pendingGroupCommit_ = std::make_unique<::mlspp::MLSMessage>(std::move(commitMessage)); - - // if there were any add proposals in this commit, then we also include the welcome message - if (welcomeMessage.secrets.size() > 0) { - outStream << welcomeMessage; - } + DISCORD_LOG(LS_INFO) << "Processing MLS proposals message of " << proposals.size() << " bytes"; + + DISCORD_LOG(LS_INFO) << "Proposals: " << ::mlspp::bytes_ns::bytes(proposals); + + ::mlspp::tls::istream inStream(proposals); + + bool isRevoke = false; + inStream >> isRevoke; + + DISCORD_LOG(LS_INFO) << "Revoking: " << isRevoke; + + const auto suite = stateWithProposals_->cipher_suite(); + + if (isRevoke) { + std::vector<::mlspp::bytes_ns::bytes> refs; + inStream >> refs; + + for (const auto& ref : refs) { + bool found = false; + for (auto it = proposalQueue_.begin(); it != proposalQueue_.end(); it++) { + if (it->ref == ref) { + found = true; + proposalQueue_.erase(it); + break; + } + } + + if (!found) { + DISCORD_LOG(LS_ERROR) << "Cannot revoke unrecognized proposal ref"; + TRACK_MLS_ERROR("Unrecognized proposal revocation"); + return std::nullopt; + } + } + + stateWithProposals_ = std::make_unique<::mlspp::State>( + pendingGroupState_ ? *pendingGroupState_ : *currentState_); + + for (auto& prop : proposalQueue_) { + // success will queue the proposal, failure will throw + stateWithProposals_->handle(prop.content); + } + } + else { + std::vector<::mlspp::MLSMessage> messages; + inStream >> messages; + + for (const auto& proposalMessage : messages) { + auto validatedMessage = stateWithProposals_->unwrap(proposalMessage); + + if (!ValidateProposalMessage(validatedMessage.authenticated_content(), + *stateWithProposals_, + recognizedUserIDs)) { + return std::nullopt; + } + + // success will queue the proposal, failure will throw + stateWithProposals_->handle(validatedMessage); + + auto ref = suite.ref(validatedMessage.authenticated_content()); + + proposalQueue_.push_back({ + std::move(validatedMessage), + std::move(ref), + }); + } + } + + // generate a commit + auto commitSecret = ::mlspp::hpke::random_bytes(suite.secret_size()); + + auto commitOpts = ::mlspp::CommitOpts{ + {}, // no extra proposals + true, // inline tree in welcome + false, // do not force path + {} // default leaf node options + }; + + auto [commitMessage, welcomeMessage, newState] = + stateWithProposals_->commit(commitSecret, commitOpts, {}); + + DISCORD_LOG(LS_INFO) + << "Prepared commit/welcome/next state for MLS group from received proposals"; + + // combine the commit and welcome messages into a single buffer + auto outStream = ::mlspp::tls::ostream(); + outStream << commitMessage; + + // keep a copy of the commit, we can check incoming pending group commit later for a match + pendingGroupCommit_ = std::make_unique<::mlspp::MLSMessage>(std::move(commitMessage)); + + // if there were any add proposals in this commit, then we also include the welcome message + if (welcomeMessage.secrets.size() > 0) { + outStream << welcomeMessage; + } - // cache the outbound state in case we're the winning sender - outboundCachedGroupState_ = std::make_unique<::mlspp::State>(std::move(newState)); + // cache the outbound state in case we're the winning sender + outboundCachedGroupState_ = std::make_unique<::mlspp::State>(std::move(newState)); - DISCORD_LOG(LS_INFO) << "Output: " << ::mlspp::bytes_ns::bytes(outStream.bytes()); + DISCORD_LOG(LS_INFO) << "Output: " << ::mlspp::bytes_ns::bytes(outStream.bytes()); - return outStream.bytes(); + return outStream.bytes(); } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to parse MLS proposals: " << e.what(); - TRACK_MLS_ERROR(e.what()); - return std::nullopt; + DISCORD_LOG(LS_ERROR) << "Failed to parse MLS proposals: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return std::nullopt; } bool Session::IsRecognizedUserID(const ::mlspp::Credential& cred, - std::set const& recognizedUserIDs) const + std::set const& recognizedUserIDs) const { - std::string uid = UserCredentialToString(cred, protocolVersion_); - if (uid.empty()) { - DISCORD_LOG(LS_ERROR) << "Attempted to verify credential of unexpected type"; - return false; - } - - if (recognizedUserIDs.find(uid) == recognizedUserIDs.end()) { - DISCORD_LOG(LS_ERROR) << "Attempted to verify credential for unrecognized user ID: " << uid; - return false; - } - - return true; + std::string uid = UserCredentialToString(cred, protocolVersion_); + if (uid.empty()) { + DISCORD_LOG(LS_ERROR) << "Attempted to verify credential of unexpected type"; + return false; + } + + if (recognizedUserIDs.find(uid) == recognizedUserIDs.end()) { + DISCORD_LOG(LS_ERROR) << "Attempted to verify credential for unrecognized user ID: " << uid; + return false; + } + + return true; } bool Session::ValidateProposalMessage(::mlspp::AuthenticatedContent const& message, - ::mlspp::State const& targetState, - std::set const& recognizedUserIDs) const + ::mlspp::State const& targetState, + std::set const& recognizedUserIDs) const { - if (message.wire_format != ::mlspp::WireFormat::mls_public_message) { - DISCORD_LOG(LS_ERROR) << "MLS proposal message must be PublicMessage"; - TRACK_MLS_ERROR("Invalid proposal wire format"); - return false; - } - - if (message.content.epoch != targetState.epoch()) { - DISCORD_LOG(LS_ERROR) << "MLS proposal message must be for current epoch (" - << message.content.epoch << " != " << targetState.epoch() << ")"; - TRACK_MLS_ERROR("Proposal epoch mismatch"); - return false; - } - - if (message.content.content_type() != ::mlspp::ContentType::proposal) { - DISCORD_LOG(LS_ERROR) << "ProcessProposals called with non-proposal message"; - TRACK_MLS_ERROR("Unexpected message type"); - return false; - } - - if (message.content.sender.sender_type() != ::mlspp::SenderType::external) { - DISCORD_LOG(LS_ERROR) << "MLS proposal must be from external sender"; - TRACK_MLS_ERROR("Unexpected proposal sender type"); - return false; - } - - const auto& proposal = ::mlspp::tls::var::get<::mlspp::Proposal>(message.content.content); - switch (proposal.proposal_type()) { - case ::mlspp::ProposalType::add: { - const auto& credential = - ::mlspp::tls::var::get<::mlspp::Add>(proposal.content).key_package.leaf_node.credential; - if (!IsRecognizedUserID(credential, recognizedUserIDs)) { - DISCORD_LOG(LS_ERROR) << "MLS add proposal must be for recognized user"; - TRACK_MLS_ERROR("Unexpected user ID in add proposal"); - return false; - } - break; - } - case ::mlspp::ProposalType::remove: - // Remove proposals are always allowed (mlspp will validate that it's a recognized user) - break; - default: - DISCORD_LOG(LS_ERROR) << "MLS proposal must be add or remove"; - TRACK_MLS_ERROR("Unexpected proposal type"); - return false; - } - - return true; + if (message.wire_format != ::mlspp::WireFormat::mls_public_message) { + DISCORD_LOG(LS_ERROR) << "MLS proposal message must be PublicMessage"; + TRACK_MLS_ERROR("Invalid proposal wire format"); + return false; + } + + if (message.content.epoch != targetState.epoch()) { + DISCORD_LOG(LS_ERROR) << "MLS proposal message must be for current epoch (" + << message.content.epoch << " != " << targetState.epoch() << ")"; + TRACK_MLS_ERROR("Proposal epoch mismatch"); + return false; + } + + if (message.content.content_type() != ::mlspp::ContentType::proposal) { + DISCORD_LOG(LS_ERROR) << "ProcessProposals called with non-proposal message"; + TRACK_MLS_ERROR("Unexpected message type"); + return false; + } + + if (message.content.sender.sender_type() != ::mlspp::SenderType::external) { + DISCORD_LOG(LS_ERROR) << "MLS proposal must be from external sender"; + TRACK_MLS_ERROR("Unexpected proposal sender type"); + return false; + } + + const auto& proposal = ::mlspp::tls::var::get<::mlspp::Proposal>(message.content.content); + switch (proposal.proposal_type()) { + case ::mlspp::ProposalType::add: { + const auto& credential = + ::mlspp::tls::var::get<::mlspp::Add>(proposal.content).key_package.leaf_node.credential; + if (!IsRecognizedUserID(credential, recognizedUserIDs)) { + DISCORD_LOG(LS_ERROR) << "MLS add proposal must be for recognized user"; + TRACK_MLS_ERROR("Unexpected user ID in add proposal"); + return false; + } + break; + } + case ::mlspp::ProposalType::remove: + // Remove proposals are always allowed (mlspp will validate that it's a recognized user) + break; + default: + DISCORD_LOG(LS_ERROR) << "MLS proposal must be add or remove"; + TRACK_MLS_ERROR("Unexpected proposal type"); + return false; + } + + return true; } bool Session::CanProcessCommit(const ::mlspp::MLSMessage& commit) noexcept { - if (!stateWithProposals_) { - return false; - } + if (!stateWithProposals_) { + return false; + } - if (commit.group_id() != groupId_) { - DISCORD_LOG(LS_ERROR) << "MLS commit message was for unexpected group"; - return false; - } + if (commit.group_id() != groupId_) { + DISCORD_LOG(LS_ERROR) << "MLS commit message was for unexpected group"; + return false; + } - return true; + return true; } roster_variant Session::ProcessCommit(std::vector commit) noexcept try { - DISCORD_LOG(LS_INFO) << "Processing commit"; - DISCORD_LOG(LS_INFO) << "Commit: " << ::mlspp::bytes_ns::bytes(commit); + DISCORD_LOG(LS_INFO) << "Processing commit"; + DISCORD_LOG(LS_INFO) << "Commit: " << ::mlspp::bytes_ns::bytes(commit); - auto commitMessage = ::mlspp::tls::get<::mlspp::MLSMessage>(commit); + auto commitMessage = ::mlspp::tls::get<::mlspp::MLSMessage>(commit); - if (!CanProcessCommit(commitMessage)) { - DISCORD_LOG(LS_ERROR) << "ProcessCommit called with unprocessable MLS commit"; - return ignored_t{}; - } + if (!CanProcessCommit(commitMessage)) { + DISCORD_LOG(LS_ERROR) << "ProcessCommit called with unprocessable MLS commit"; + return ignored_t{}; + } - // in case we're the sender of this commit - // we need to pull the cached state from our outbound cache - std::optional<::mlspp::State> optionalCachedState = std::nullopt; - if (outboundCachedGroupState_) { - optionalCachedState = *(outboundCachedGroupState_.get()); - } + // in case we're the sender of this commit + // we need to pull the cached state from our outbound cache + std::optional<::mlspp::State> optionalCachedState = std::nullopt; + if (outboundCachedGroupState_) { + optionalCachedState = *(outboundCachedGroupState_.get()); + } - auto newState = stateWithProposals_->handle(commitMessage, optionalCachedState); + auto newState = stateWithProposals_->handle(commitMessage, optionalCachedState); - if (!newState) { - DISCORD_LOG(LS_ERROR) << "MLS commit handling did not produce a new state"; - return failed_t{}; - } + if (!newState) { + DISCORD_LOG(LS_ERROR) << "MLS commit handling did not produce a new state"; + return failed_t{}; + } - DISCORD_LOG(LS_INFO) << "Successfully processed MLS commit, updating state; our leaf index is " - << newState->index().val << "; current epoch is " << newState->epoch(); + DISCORD_LOG(LS_INFO) << "Successfully processed MLS commit, updating state; our leaf index is " + << newState->index().val << "; current epoch is " << newState->epoch(); - roster_map ret = ReplaceState(std::make_unique<::mlspp::State>(std::move(*newState))); + roster_map ret = ReplaceState(std::make_unique<::mlspp::State>(std::move(*newState))); - // reset the outbound cached group since we handled the commit for this epoch - outboundCachedGroupState_.reset(); + // reset the outbound cached group since we handled the commit for this epoch + outboundCachedGroupState_.reset(); - ClearPendingState(); + ClearPendingState(); - return ret; + return ret; } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to process MLS commit: " << e.what(); - TRACK_MLS_ERROR(e.what()); - return failed_t{}; + DISCORD_LOG(LS_ERROR) << "Failed to process MLS commit: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return failed_t{}; } std::optional Session::ProcessWelcome( std::vector welcome, std::set const& recognizedUserIDs) noexcept try { - if (!HasCryptographicStateForWelcome()) { - DISCORD_LOG(LS_ERROR) << "Missing local crypto state necessary to process MLS welcome"; - return std::nullopt; - } + if (!HasCryptographicStateForWelcome()) { + DISCORD_LOG(LS_ERROR) << "Missing local crypto state necessary to process MLS welcome"; + return std::nullopt; + } - if (!externalSender_) { - DISCORD_LOG(LS_ERROR) << "Cannot process MLS welcome without an external sender"; - return std::nullopt; - } + if (!externalSender_) { + DISCORD_LOG(LS_ERROR) << "Cannot process MLS welcome without an external sender"; + return std::nullopt; + } - if (currentState_) { - DISCORD_LOG(LS_ERROR) << "Cannot process MLS welcome after joining/creating an MLS group"; - return std::nullopt; - } + if (currentState_) { + DISCORD_LOG(LS_ERROR) << "Cannot process MLS welcome after joining/creating an MLS group"; + return std::nullopt; + } - DISCORD_LOG(LS_INFO) << "Processing welcome: " << ::mlspp::bytes_ns::bytes(welcome); + DISCORD_LOG(LS_INFO) << "Processing welcome: " << ::mlspp::bytes_ns::bytes(welcome); - // unmarshal the incoming welcome - auto unmarshalledWelcome = ::mlspp::tls::get<::mlspp::Welcome>(welcome); + // unmarshal the incoming welcome + auto unmarshalledWelcome = ::mlspp::tls::get<::mlspp::Welcome>(welcome); - // construct the state from the unmarshalled welcome - auto newState = std::make_unique<::mlspp::State>( - *joinInitPrivateKey_, - *selfHPKEPrivateKey_, - *selfSigPrivateKey_, - *joinKeyPackage_, - unmarshalledWelcome, - std::nullopt, - std::map<::mlspp::bytes_ns::bytes, ::mlspp::bytes_ns::bytes>()); + // construct the state from the unmarshalled welcome + auto newState = std::make_unique<::mlspp::State>( + *joinInitPrivateKey_, + *selfHPKEPrivateKey_, + *selfSigPrivateKey_, + *joinKeyPackage_, + unmarshalledWelcome, + std::nullopt, + std::map<::mlspp::bytes_ns::bytes, ::mlspp::bytes_ns::bytes>()); - // perform application-level verification of the new state - if (!VerifyWelcomeState(*newState, recognizedUserIDs)) { - DISCORD_LOG(LS_ERROR) << "Group received in MLS welcome is not valid"; + // perform application-level verification of the new state + if (!VerifyWelcomeState(*newState, recognizedUserIDs)) { + DISCORD_LOG(LS_ERROR) << "Group received in MLS welcome is not valid"; - return std::nullopt; - } + return std::nullopt; + } - DISCORD_LOG(LS_INFO) << "Successfully welcomed to MLS Group, our leaf index is " - << newState->index().val << "; current epoch is " << newState->epoch(); + DISCORD_LOG(LS_INFO) << "Successfully welcomed to MLS Group, our leaf index is " + << newState->index().val << "; current epoch is " << newState->epoch(); - // make the verified state our new (and only) state - roster_map ret = ReplaceState(std::move(newState)); + // make the verified state our new (and only) state + roster_map ret = ReplaceState(std::move(newState)); - // clear out any pending state for creating/joining a group - ClearPendingState(); + // clear out any pending state for creating/joining a group + ClearPendingState(); - return ret; + return ret; } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to create group state from MLS welcome: " << e.what(); - TRACK_MLS_ERROR(e.what()); - return std::nullopt; + DISCORD_LOG(LS_ERROR) << "Failed to create group state from MLS welcome: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return std::nullopt; } roster_map Session::ReplaceState(std::unique_ptr<::mlspp::State>&& state) { - roster_map newRoster; - for (const ::mlspp::LeafNode& node : state->roster()) { - if (node.credential.type() != ::mlspp::CredentialType::basic) { - continue; - } + roster_map newRoster; + for (const ::mlspp::LeafNode& node : state->roster()) { + if (node.credential.type() != ::mlspp::CredentialType::basic) { + continue; + } - const auto& cred = node.credential.template get<::mlspp::BasicCredential>(); + const auto& cred = node.credential.template get<::mlspp::BasicCredential>(); - newRoster[FromBigEndianBytes(cred.identity)] = node.signature_key.data.as_vec(); - } + newRoster[FromBigEndianBytes(cred.identity)] = node.signature_key.data.as_vec(); + } - roster_map changeMap; + roster_map changeMap; - std::set_difference(newRoster.begin(), - newRoster.end(), - roster_.begin(), - roster_.end(), - std::inserter(changeMap, changeMap.end())); + std::set_difference(newRoster.begin(), + newRoster.end(), + roster_.begin(), + roster_.end(), + std::inserter(changeMap, changeMap.end())); - struct MissingItemWrapper { - roster_map& changeMap_; + struct MissingItemWrapper { + roster_map& changeMap_; - using iterator = roster_map::iterator; - using const_iterator = roster_map::const_iterator; - using value_type = roster_map::value_type; + using iterator = roster_map::iterator; + using const_iterator = roster_map::const_iterator; + using value_type = roster_map::value_type; - iterator insert(const_iterator it, const value_type& value) - { - return changeMap_.try_emplace(it, value.first, std::vector{}); - } + iterator insert(const_iterator it, const value_type& value) + { + return changeMap_.try_emplace(it, value.first, std::vector{}); + } - iterator begin() { return changeMap_.begin(); } + iterator begin() { return changeMap_.begin(); } - iterator end() { return changeMap_.end(); } - }; + iterator end() { return changeMap_.end(); } + }; - MissingItemWrapper wrapper{changeMap}; + MissingItemWrapper wrapper{changeMap}; - std::set_difference(roster_.begin(), - roster_.end(), - newRoster.begin(), - newRoster.end(), - std::inserter(wrapper, wrapper.end())); + std::set_difference(roster_.begin(), + roster_.end(), + newRoster.begin(), + newRoster.end(), + std::inserter(wrapper, wrapper.end())); - roster_ = std::move(newRoster); - currentState_ = std::move(state); + roster_ = std::move(newRoster); + currentState_ = std::move(state); - return changeMap; + return changeMap; } bool Session::HasCryptographicStateForWelcome() const noexcept { - return joinKeyPackage_ && joinInitPrivateKey_ && selfSigPrivateKey_ && selfHPKEPrivateKey_; + return joinKeyPackage_ && joinInitPrivateKey_ && selfSigPrivateKey_ && selfHPKEPrivateKey_; } bool Session::VerifyWelcomeState(::mlspp::State const& state, - std::set const& recognizedUserIDs) const + std::set const& recognizedUserIDs) const { - if (!externalSender_) { - DISCORD_LOG(LS_ERROR) << "Cannot verify MLS welcome without an external sender"; - TRACK_MLS_ERROR("Missing external sender when processing Welcome"); - return false; - } - - auto ext = state.extensions().template find(); - if (!ext) { - DISCORD_LOG(LS_ERROR) << "MLS welcome missing external senders extension"; - TRACK_MLS_ERROR("Welcome message missing external sender extension"); - return false; - } - - if (ext->senders.size() != 1) { - DISCORD_LOG(LS_ERROR) << "MLS welcome lists unexpected number of external senders: " - << ext->senders.size(); - TRACK_MLS_ERROR("Welcome message lists unexpected external sender count"); - return false; - } - - if (ext->senders.front() != *externalSender_) { - DISCORD_LOG(LS_ERROR) << "MLS welcome lists unexpected external sender"; - TRACK_MLS_ERROR("Welcome message lists unexpected external sender"); - return false; - } - - // TODO: Until we leverage revocation in the protocol - // if we re-enable this change we will refuse welcome messages - // because someone was previously supposed to be added but disconnected - // before all in-flight proposals were handled. - - for (const auto& leaf : state.roster()) { - if (!IsRecognizedUserID(leaf.credential, recognizedUserIDs)) { - DISCORD_LOG(LS_ERROR) << "MLS welcome lists unrecognized user ID"; - // TRACK_MLS_ERROR("Welcome message lists unrecognized user ID"); - // return false; - } - } - - return true; + if (!externalSender_) { + DISCORD_LOG(LS_ERROR) << "Cannot verify MLS welcome without an external sender"; + TRACK_MLS_ERROR("Missing external sender when processing Welcome"); + return false; + } + + auto ext = state.extensions().template find(); + if (!ext) { + DISCORD_LOG(LS_ERROR) << "MLS welcome missing external senders extension"; + TRACK_MLS_ERROR("Welcome message missing external sender extension"); + return false; + } + + if (ext->senders.size() != 1) { + DISCORD_LOG(LS_ERROR) << "MLS welcome lists unexpected number of external senders: " + << ext->senders.size(); + TRACK_MLS_ERROR("Welcome message lists unexpected external sender count"); + return false; + } + + if (ext->senders.front() != *externalSender_) { + DISCORD_LOG(LS_ERROR) << "MLS welcome lists unexpected external sender"; + TRACK_MLS_ERROR("Welcome message lists unexpected external sender"); + return false; + } + + // TODO: Until we leverage revocation in the protocol + // if we re-enable this change we will refuse welcome messages + // because someone was previously supposed to be added but disconnected + // before all in-flight proposals were handled. + + for (const auto& leaf : state.roster()) { + if (!IsRecognizedUserID(leaf.credential, recognizedUserIDs)) { + DISCORD_LOG(LS_ERROR) << "MLS welcome lists unrecognized user ID"; + // TRACK_MLS_ERROR("Welcome message lists unrecognized user ID"); + // return false; + } + } + + return true; } void Session::InitLeafNode(std::string const& selfUserId, - std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept + std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept try { - auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); - - if (!transientKey) { - if (!signingKeyId_.empty()) { - transientKey = GetPersistedKeyPair(keyPairContext_, signingKeyId_, protocolVersion_); - if (!transientKey) { - DISCORD_LOG(LS_ERROR) << "Did not receive MLS signature private key from " - "GetPersistedKeyPair; aborting"; - return; - } - } - else { - transientKey = std::make_shared<::mlspp::SignaturePrivateKey>( - ::mlspp::SignaturePrivateKey::generate(ciphersuite)); - } - } - - selfSigPrivateKey_ = transientKey; - - auto selfCredential = CreateUserCredential(selfUserId, protocolVersion_); - - selfHPKEPrivateKey_ = - std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); - - selfLeafNode_ = - std::make_unique<::mlspp::LeafNode>(ciphersuite, - selfHPKEPrivateKey_->public_key, - selfSigPrivateKey_->public_key, - std::move(selfCredential), - LeafNodeCapabilitiesForProtocolVersion(protocolVersion_), - ::mlspp::Lifetime::create_default(), - LeafNodeExtensionsForProtocolVersion(protocolVersion_), - *selfSigPrivateKey_); - - DISCORD_LOG(LS_INFO) << "Created MLS leaf node"; + auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); + + if (!transientKey) { + if (!signingKeyId_.empty()) { + transientKey = GetPersistedKeyPair(keyPairContext_, signingKeyId_, protocolVersion_); + if (!transientKey) { + DISCORD_LOG(LS_ERROR) << "Did not receive MLS signature private key from " + "GetPersistedKeyPair; aborting"; + return; + } + } + else { + transientKey = std::make_shared<::mlspp::SignaturePrivateKey>( + ::mlspp::SignaturePrivateKey::generate(ciphersuite)); + } + } + + selfSigPrivateKey_ = transientKey; + + auto selfCredential = CreateUserCredential(selfUserId, protocolVersion_); + + selfHPKEPrivateKey_ = + std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); + + selfLeafNode_ = + std::make_unique<::mlspp::LeafNode>(ciphersuite, + selfHPKEPrivateKey_->public_key, + selfSigPrivateKey_->public_key, + std::move(selfCredential), + LeafNodeCapabilitiesForProtocolVersion(protocolVersion_), + ::mlspp::Lifetime::create_default(), + LeafNodeExtensionsForProtocolVersion(protocolVersion_), + *selfSigPrivateKey_); + + DISCORD_LOG(LS_INFO) << "Created MLS leaf node"; } catch (const std::exception& e) { - DISCORD_LOG(LS_INFO) << "Failed to initialize MLS leaf node: " << e.what(); - TRACK_MLS_ERROR(e.what()); + DISCORD_LOG(LS_INFO) << "Failed to initialize MLS leaf node: " << e.what(); + TRACK_MLS_ERROR(e.what()); } void Session::ResetJoinKeyPackage() noexcept try { - if (!selfLeafNode_) { - DISCORD_LOG(LS_ERROR) << "Cannot initialize join key package without a leaf node"; - return; - } + if (!selfLeafNode_) { + DISCORD_LOG(LS_ERROR) << "Cannot initialize join key package without a leaf node"; + return; + } - auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); + auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); - joinInitPrivateKey_ = - std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); + joinInitPrivateKey_ = + std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); - joinKeyPackage_ = - std::make_unique<::mlspp::KeyPackage>(ciphersuite, - joinInitPrivateKey_->public_key, - *selfLeafNode_, - LeafNodeExtensionsForProtocolVersion(protocolVersion_), - *selfSigPrivateKey_); + joinKeyPackage_ = + std::make_unique<::mlspp::KeyPackage>(ciphersuite, + joinInitPrivateKey_->public_key, + *selfLeafNode_, + LeafNodeExtensionsForProtocolVersion(protocolVersion_), + *selfSigPrivateKey_); - DISCORD_LOG(LS_INFO) << "Generated key package: " - << ::mlspp::bytes_ns::bytes(::mlspp::tls::marshal(*joinKeyPackage_)); + DISCORD_LOG(LS_INFO) << "Generated key package: " + << ::mlspp::bytes_ns::bytes(::mlspp::tls::marshal(*joinKeyPackage_)); } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to initialize join key package: " << e.what(); - TRACK_MLS_ERROR(e.what()); + DISCORD_LOG(LS_ERROR) << "Failed to initialize join key package: " << e.what(); + TRACK_MLS_ERROR(e.what()); } void Session::CreatePendingGroup() noexcept try { - if (groupId_.empty()) { - DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without a group ID"; - return; - } - - if (!externalSender_) { - DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without ExternalSender"; - return; - } - - if (!selfLeafNode_) { - DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without self leaf node"; - return; - } - - DISCORD_LOG(LS_INFO) << "Creating a pending MLS group"; - - auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); - - pendingGroupState_ = std::make_unique<::mlspp::State>( - groupId_, - ciphersuite, - *selfHPKEPrivateKey_, - *selfSigPrivateKey_, - *selfLeafNode_, - GroupExtensionsForProtocolVersion(protocolVersion_, *externalSender_)); - - DISCORD_LOG(LS_INFO) << "Created a pending MLS group"; + if (groupId_.empty()) { + DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without a group ID"; + return; + } + + if (!externalSender_) { + DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without ExternalSender"; + return; + } + + if (!selfLeafNode_) { + DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without self leaf node"; + return; + } + + DISCORD_LOG(LS_INFO) << "Creating a pending MLS group"; + + auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); + + pendingGroupState_ = std::make_unique<::mlspp::State>( + groupId_, + ciphersuite, + *selfHPKEPrivateKey_, + *selfSigPrivateKey_, + *selfLeafNode_, + GroupExtensionsForProtocolVersion(protocolVersion_, *externalSender_)); + + DISCORD_LOG(LS_INFO) << "Created a pending MLS group"; } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to create MLS group: " << e.what(); - TRACK_MLS_ERROR(e.what()); - return; + DISCORD_LOG(LS_ERROR) << "Failed to create MLS group: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return; } std::vector Session::GetMarshalledKeyPackage() noexcept try { - // key packages are not meant to be re-used - // so every time the client asks for a key package we create a new one - ResetJoinKeyPackage(); + // key packages are not meant to be re-used + // so every time the client asks for a key package we create a new one + ResetJoinKeyPackage(); - if (!joinKeyPackage_) { - DISCORD_LOG(LS_ERROR) << "Cannot marshal an uninitialized key package"; - return {}; - } + if (!joinKeyPackage_) { + DISCORD_LOG(LS_ERROR) << "Cannot marshal an uninitialized key package"; + return {}; + } - return ::mlspp::tls::marshal(*joinKeyPackage_); + return ::mlspp::tls::marshal(*joinKeyPackage_); } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to marshal join key package: " << e.what(); - TRACK_MLS_ERROR(e.what()); - return {}; + DISCORD_LOG(LS_ERROR) << "Failed to marshal join key package: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return {}; } std::unique_ptr Session::GetKeyRatchet(std::string const& userId) const noexcept { - if (!currentState_) { - DISCORD_LOG(LS_ERROR) << "Cannot get key ratchet without an established MLS group"; - return nullptr; - } - - // change the string user ID to a little endian 64 bit user ID - auto u64userId = strtoull(userId.c_str(), nullptr, 10); - auto userIdBytes = ::mlspp::bytes_ns::bytes(sizeof(u64userId)); - memcpy(userIdBytes.data(), &u64userId, sizeof(u64userId)); - - // generate the base secret for the hash ratchet - auto baseSecret = - currentState_->do_export(Session::USER_MEDIA_KEY_BASE_LABEL, userIdBytes, AES_GCM_128_KEY_BYTES); - - // this assumes the MLS ciphersuite produces a kAesGcm128KeyBytes sized key - // would need to be updated to a different ciphersuite if there's a future mismatch - return std::make_unique(currentState_->cipher_suite(), std::move(baseSecret)); + if (!currentState_) { + DISCORD_LOG(LS_ERROR) << "Cannot get key ratchet without an established MLS group"; + return nullptr; + } + + // change the string user ID to a little endian 64 bit user ID + auto u64userId = strtoull(userId.c_str(), nullptr, 10); + auto userIdBytes = ::mlspp::bytes_ns::bytes(sizeof(u64userId)); + memcpy(userIdBytes.data(), &u64userId, sizeof(u64userId)); + + // generate the base secret for the hash ratchet + auto baseSecret = + currentState_->do_export(Session::USER_MEDIA_KEY_BASE_LABEL, userIdBytes, AES_GCM_128_KEY_BYTES); + + // this assumes the MLS ciphersuite produces a kAesGcm128KeyBytes sized key + // would need to be updated to a different ciphersuite if there's a future mismatch + return std::make_unique(currentState_->cipher_suite(), std::move(baseSecret)); } void Session::GetPairwiseFingerprint(uint16_t version, - std::string const& userId, - PairwiseFingerprintCallback callback) const noexcept + std::string const& userId, + PairwiseFingerprintCallback callback) const noexcept try { - if (!currentState_ || !selfSigPrivateKey_) { - throw std::invalid_argument("No established MLS group"); - } - - uint64_t u64RemoteUserId = strtoull(userId.c_str(), nullptr, 10); - uint64_t u64SelfUserId = strtoull(selfUserId_.c_str(), nullptr, 10); - - auto it = roster_.find(u64RemoteUserId); - if (it == roster_.end()) { - throw std::invalid_argument("Unknown user ID: " + userId); - } - - ::mlspp::tls::ostream toHash1; - ::mlspp::tls::ostream toHash2; - - toHash1 << version; - toHash1.write_raw(it->second); - toHash1 << u64RemoteUserId; - - toHash2 << version; - toHash2.write_raw(selfSigPrivateKey_->public_key.data); - toHash2 << u64SelfUserId; - - std::vector> keyData = { - toHash1.bytes(), - toHash2.bytes(), - }; - - std::sort(keyData.begin(), keyData.end()); - - std::thread([callback = std::move(callback), - data = ::mlspp::bytes_ns::bytes(std::move(keyData[0])) + keyData[1]] { - static constexpr uint8_t salt[] = { - 0x24, - 0xca, - 0xb1, - 0x7a, - 0x7a, - 0xf8, - 0xec, - 0x2b, - 0x82, - 0xb4, - 0x12, - 0xb9, - 0x2d, - 0xab, - 0x19, - 0x2e, - }; - - constexpr uint64_t N = 16384, r = 8, p = 2, max_mem = 32 * 1024 * 1024; - constexpr size_t hash_len = 64; - - std::vector out(hash_len); - - int ret = EVP_PBE_scrypt((const char*)data.data(), - data.size(), - salt, - sizeof(salt), - N, - r, - p, - max_mem, - out.data(), - out.size()); - - if (ret == 1) { - callback(out); - } - else { - callback({}); - } - }).detach(); + if (!currentState_ || !selfSigPrivateKey_) { + throw std::invalid_argument("No established MLS group"); + } + + uint64_t u64RemoteUserId = strtoull(userId.c_str(), nullptr, 10); + uint64_t u64SelfUserId = strtoull(selfUserId_.c_str(), nullptr, 10); + + auto it = roster_.find(u64RemoteUserId); + if (it == roster_.end()) { + throw std::invalid_argument("Unknown user ID: " + userId); + } + + ::mlspp::tls::ostream toHash1; + ::mlspp::tls::ostream toHash2; + + toHash1 << version; + toHash1.write_raw(it->second); + toHash1 << u64RemoteUserId; + + toHash2 << version; + toHash2.write_raw(selfSigPrivateKey_->public_key.data); + toHash2 << u64SelfUserId; + + std::vector> keyData = { + toHash1.bytes(), + toHash2.bytes(), + }; + + std::sort(keyData.begin(), keyData.end()); + + std::thread([callback = std::move(callback), + data = ::mlspp::bytes_ns::bytes(std::move(keyData[0])) + keyData[1]] { + static constexpr uint8_t salt[] = { + 0x24, + 0xca, + 0xb1, + 0x7a, + 0x7a, + 0xf8, + 0xec, + 0x2b, + 0x82, + 0xb4, + 0x12, + 0xb9, + 0x2d, + 0xab, + 0x19, + 0x2e, + }; + + constexpr uint64_t N = 16384, r = 8, p = 2, max_mem = 32 * 1024 * 1024; + constexpr size_t hash_len = 64; + + std::vector out(hash_len); + + int ret = EVP_PBE_scrypt((const char*)data.data(), + data.size(), + salt, + sizeof(salt), + N, + r, + p, + max_mem, + out.data(), + out.size()); + + if (ret == 1) { + callback(out); + } + else { + callback({}); + } + }).detach(); } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to generate pairwise fingerprint: " << e.what(); - callback({}); + DISCORD_LOG(LS_ERROR) << "Failed to generate pairwise fingerprint: " << e.what(); + callback({}); } void Session::ClearPendingState() { - pendingGroupState_.reset(); - pendingGroupCommit_.reset(); + pendingGroupState_.reset(); + pendingGroupCommit_.reset(); - joinInitPrivateKey_.reset(); - joinKeyPackage_.reset(); + joinInitPrivateKey_.reset(); + joinKeyPackage_.reset(); - selfHPKEPrivateKey_.reset(); + selfHPKEPrivateKey_.reset(); - selfLeafNode_.reset(); + selfLeafNode_.reset(); - stateWithProposals_.reset(); - proposalQueue_.clear(); + stateWithProposals_.reset(); + proposalQueue_.clear(); } } // namespace dpp::dave::mls diff --git a/src/dpp/dave/session.h b/src/dpp/dave/session.h index 32c7410ac5..7c728726b4 100755 --- a/src/dpp/dave/session.h +++ b/src/dpp/dave/session.h @@ -55,99 +55,99 @@ struct QueuedProposal; class Session { public: - using MLSFailureCallback = std::function; + using MLSFailureCallback = std::function; - Session(KeyPairContextType context, - const std::string& authSessionId, - MLSFailureCallback callback) noexcept; + Session(KeyPairContextType context, + const std::string& authSessionId, + MLSFailureCallback callback) noexcept; - ~Session() noexcept; + ~Session() noexcept; - void Init(ProtocolVersion version, - uint64_t groupId, - std::string const& selfUserId, - std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; - void Reset() noexcept; + void Init(ProtocolVersion version, + uint64_t groupId, + std::string const& selfUserId, + std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; + void Reset() noexcept; - void SetProtocolVersion(ProtocolVersion version) noexcept; - ProtocolVersion GetProtocolVersion() const noexcept { return protocolVersion_; } + void SetProtocolVersion(ProtocolVersion version) noexcept; + ProtocolVersion GetProtocolVersion() const noexcept { return protocolVersion_; } - std::vector GetLastEpochAuthenticator() const noexcept; + std::vector GetLastEpochAuthenticator() const noexcept; - void SetExternalSender(std::vector const& externalSenderPackage) noexcept; + void SetExternalSender(std::vector const& externalSenderPackage) noexcept; - std::optional> ProcessProposals( - std::vector proposals, - std::set const& recognizedUserIDs) noexcept; + std::optional> ProcessProposals( + std::vector proposals, + std::set const& recognizedUserIDs) noexcept; - roster_variant ProcessCommit(std::vector commit) noexcept; + roster_variant ProcessCommit(std::vector commit) noexcept; - std::optional ProcessWelcome( - std::vector welcome, - std::set const& recognizedUserIDs) noexcept; + std::optional ProcessWelcome( + std::vector welcome, + std::set const& recognizedUserIDs) noexcept; - std::vector GetMarshalledKeyPackage() noexcept; + std::vector GetMarshalledKeyPackage() noexcept; - std::unique_ptr GetKeyRatchet(std::string const& userId) const noexcept; + std::unique_ptr GetKeyRatchet(std::string const& userId) const noexcept; - using PairwiseFingerprintCallback = std::function const&)>; + using PairwiseFingerprintCallback = std::function const&)>; - void GetPairwiseFingerprint(uint16_t version, - std::string const& userId, - PairwiseFingerprintCallback callback) const noexcept; + void GetPairwiseFingerprint(uint16_t version, + std::string const& userId, + PairwiseFingerprintCallback callback) const noexcept; private: - void InitLeafNode(std::string const& selfUserId, - std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; - void ResetJoinKeyPackage() noexcept; + void InitLeafNode(std::string const& selfUserId, + std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; + void ResetJoinKeyPackage() noexcept; - void CreatePendingGroup() noexcept; + void CreatePendingGroup() noexcept; - bool HasCryptographicStateForWelcome() const noexcept; + bool HasCryptographicStateForWelcome() const noexcept; - bool IsRecognizedUserID(const ::mlspp::Credential& cred, - std::set const& recognizedUserIDs) const; - bool ValidateProposalMessage(::mlspp::AuthenticatedContent const& message, - ::mlspp::State const& targetState, - std::set const& recognizedUserIDs) const; - bool VerifyWelcomeState(::mlspp::State const& state, - std::set const& recognizedUserIDs) const; + bool IsRecognizedUserID(const ::mlspp::Credential& cred, + std::set const& recognizedUserIDs) const; + bool ValidateProposalMessage(::mlspp::AuthenticatedContent const& message, + ::mlspp::State const& targetState, + std::set const& recognizedUserIDs) const; + bool VerifyWelcomeState(::mlspp::State const& state, + std::set const& recognizedUserIDs) const; - bool CanProcessCommit(const ::mlspp::MLSMessage& commit) noexcept; + bool CanProcessCommit(const ::mlspp::MLSMessage& commit) noexcept; - roster_map ReplaceState(std::unique_ptr<::mlspp::State>&& state); + roster_map ReplaceState(std::unique_ptr<::mlspp::State>&& state); - void ClearPendingState(); + void ClearPendingState(); - inline static const std::string USER_MEDIA_KEY_BASE_LABEL = "Discord Secure Frames v0"; + inline static const std::string USER_MEDIA_KEY_BASE_LABEL = "Discord Secure Frames v0"; - ProtocolVersion protocolVersion_; - std::vector groupId_; - std::string signingKeyId_; - std::string selfUserId_; - KeyPairContextType keyPairContext_{nullptr}; + ProtocolVersion protocolVersion_; + std::vector groupId_; + std::string signingKeyId_; + std::string selfUserId_; + KeyPairContextType keyPairContext_{nullptr}; - std::unique_ptr<::mlspp::LeafNode> selfLeafNode_; - std::shared_ptr<::mlspp::SignaturePrivateKey> selfSigPrivateKey_; - std::unique_ptr<::mlspp::HPKEPrivateKey> selfHPKEPrivateKey_; + std::unique_ptr<::mlspp::LeafNode> selfLeafNode_; + std::shared_ptr<::mlspp::SignaturePrivateKey> selfSigPrivateKey_; + std::unique_ptr<::mlspp::HPKEPrivateKey> selfHPKEPrivateKey_; - std::unique_ptr<::mlspp::HPKEPrivateKey> joinInitPrivateKey_; - std::unique_ptr<::mlspp::KeyPackage> joinKeyPackage_; + std::unique_ptr<::mlspp::HPKEPrivateKey> joinInitPrivateKey_; + std::unique_ptr<::mlspp::KeyPackage> joinKeyPackage_; - std::unique_ptr<::mlspp::ExternalSender> externalSender_; + std::unique_ptr<::mlspp::ExternalSender> externalSender_; - std::unique_ptr<::mlspp::State> pendingGroupState_; - std::unique_ptr<::mlspp::MLSMessage> pendingGroupCommit_; + std::unique_ptr<::mlspp::State> pendingGroupState_; + std::unique_ptr<::mlspp::MLSMessage> pendingGroupCommit_; - std::unique_ptr<::mlspp::State> outboundCachedGroupState_; + std::unique_ptr<::mlspp::State> outboundCachedGroupState_; - std::unique_ptr<::mlspp::State> currentState_; - roster_map roster_; + std::unique_ptr<::mlspp::State> currentState_; + roster_map roster_; - std::unique_ptr<::mlspp::State> stateWithProposals_; - std::list proposalQueue_; + std::unique_ptr<::mlspp::State> stateWithProposals_; + std::list proposalQueue_; - MLSFailureCallback onMLSFailureCallback_{}; + MLSFailureCallback onMLSFailureCallback_{}; }; } // namespace dpp::dave::mls diff --git a/src/dpp/dave/user_credential.cpp b/src/dpp/dave/user_credential.cpp index 72a9bae5f4..38d3236f61 100755 --- a/src/dpp/dave/user_credential.cpp +++ b/src/dpp/dave/user_credential.cpp @@ -30,24 +30,24 @@ namespace dpp::dave::mls { ::mlspp::Credential CreateUserCredential(const std::string& userId, ProtocolVersion version) { - // convert the string user ID to a big endian uint64_t - auto userID = std::stoull(userId); - auto credentialBytes = BigEndianBytesFrom(userID); + // convert the string user ID to a big endian uint64_t + auto userID = std::stoull(userId); + auto credentialBytes = BigEndianBytesFrom(userID); - return ::mlspp::Credential::basic(credentialBytes); + return ::mlspp::Credential::basic(credentialBytes); } std::string UserCredentialToString(const ::mlspp::Credential& cred, ProtocolVersion version) { - if (cred.type() != ::mlspp::CredentialType::basic) { - return ""; - } + if (cred.type() != ::mlspp::CredentialType::basic) { + return ""; + } - const auto& basic = cred.template get<::mlspp::BasicCredential>(); + const auto& basic = cred.template get<::mlspp::BasicCredential>(); - auto uidVal = FromBigEndianBytes(basic.identity); + auto uidVal = FromBigEndianBytes(basic.identity); - return std::to_string(uidVal); + return std::to_string(uidVal); } } // namespace dpp::dave::mls diff --git a/src/dpp/dave/util.cpp b/src/dpp/dave/util.cpp index f91a03eec0..ebb4aa7a58 100755 --- a/src/dpp/dave/util.cpp +++ b/src/dpp/dave/util.cpp @@ -28,27 +28,27 @@ namespace dpp::dave::mls { ::mlspp::bytes_ns::bytes BigEndianBytesFrom(uint64_t value) noexcept { - auto buffer = ::mlspp::bytes_ns::bytes(); - buffer.reserve(sizeof(value)); + auto buffer = ::mlspp::bytes_ns::bytes(); + buffer.reserve(sizeof(value)); - for (int i = sizeof(value) - 1; i >= 0; --i) { - buffer.push_back(static_cast(value >> (i * 8))); - } + for (int i = sizeof(value) - 1; i >= 0; --i) { + buffer.push_back(static_cast(value >> (i * 8))); + } - return buffer; + return buffer; } uint64_t FromBigEndianBytes(const ::mlspp::bytes_ns::bytes& buffer) noexcept { - uint64_t val = 0; + uint64_t val = 0; - if (buffer.size() <= sizeof(val)) { - for (uint8_t byte : buffer) { - val = (val << 8) | byte; - } - } + if (buffer.size() <= sizeof(val)) { + for (uint8_t byte : buffer) { + val = (val << 8) | byte; + } + } - return val; + return val; } } // namespace dpp::dave::mls diff --git a/src/dpp/dave/version.cpp b/src/dpp/dave/version.cpp index 780d58fab7..32e9a9138d 100755 --- a/src/dpp/dave/version.cpp +++ b/src/dpp/dave/version.cpp @@ -30,7 +30,7 @@ constexpr ProtocolVersion CurrentDaveProtocolVersion = 1; ProtocolVersion MaxSupportedProtocolVersion() { - return CurrentDaveProtocolVersion; + return CurrentDaveProtocolVersion; } } // namespace dpp::dave