diff --git a/src/dpp/dave/cipher_interface.cpp b/src/dpp/dave/cipher_interface.cpp index 38be00a61d..3dae2807cc 100755 --- a/src/dpp/dave/cipher_interface.cpp +++ b/src/dpp/dave/cipher_interface.cpp @@ -24,12 +24,13 @@ ************************************************************************************/ #include "cipher_interface.h" #include "openssl_aead_cipher.h" +#include namespace dpp::dave { -std::unique_ptr create_cipher(const encryption_key& encryptionKey) +std::unique_ptr create_cipher(dpp::cluster& cl, const encryption_key& encryptionKey) { - auto cipher = std::make_unique(encryptionKey); + auto cipher = std::make_unique(cl, encryptionKey); return cipher->is_valid() ? std::move(cipher) : nullptr; } diff --git a/src/dpp/dave/cipher_interface.h b/src/dpp/dave/cipher_interface.h index 679f18dddb..35b849c279 100755 --- a/src/dpp/dave/cipher_interface.h +++ b/src/dpp/dave/cipher_interface.h @@ -29,6 +29,10 @@ #include "common.h" #include "array_view.h" +namespace dpp { + class cluster; +} + namespace dpp::dave { /** @@ -47,6 +51,12 @@ using byte_view = array_view; */ class cipher_interface { // NOLINT public: + /** + * @brief Create cipher interface + * @param _creator Creating cluster + */ + cipher_interface(dpp::cluster& _creator) : creator(_creator) { }; + /** * @brief Default destructor */ @@ -73,6 +83,13 @@ class cipher_interface { // NOLINT * @return true if decryption succeeded, false if it failed */ virtual bool decrypt(byte_view plaintextBufferOut, const_byte_view ciphertextBuffer, const_byte_view tagBuffer, const_byte_view nonceBuffer, const_byte_view additionalData) = 0; + +protected: + + /** + * @brief DPP Cluster, used for logging + */ + dpp::cluster& creator; }; /** @@ -80,7 +97,7 @@ class cipher_interface { // NOLINT * @param encryptionKey encryption key * @return an instance of a class derived from cipher_interface */ -std::unique_ptr create_cipher(const encryption_key& encryptionKey); +std::unique_ptr create_cipher(dpp::cluster& cl, const encryption_key& encryptionKey); } // namespace dpp::dave diff --git a/src/dpp/dave/codec_utils.cpp b/src/dpp/dave/codec_utils.cpp index 05d5314e88..ea4269c050 100755 --- a/src/dpp/dave/codec_utils.cpp +++ b/src/dpp/dave/codec_utils.cpp @@ -27,14 +27,11 @@ #include #include #include - -#include "logger.h" #include "leb128.h" namespace dpp::dave::codec_utils { -unencrypted_frame_header_size BytesCoveringH264PPS(const uint8_t* payload, - const uint64_t sizeRemaining) +unencrypted_frame_header_size BytesCoveringH264PPS(const uint8_t* payload, const uint64_t sizeRemaining) { // the payload starts with three exponential golomb encoded values // (first_mb_in_slice, sps_id, pps_id) @@ -212,29 +209,23 @@ bool process_frame_h264(outbound_frame_processor& processor, array_viewfirst - nextNaluIndexPair->second - : frame.size(); + auto nextNaluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); + auto nextNaluStart = nextNaluIndexPair.has_value() ? nextNaluIndexPair->first - nextNaluIndexPair->second : frame.size(); if (nalType == kH264NalTypeSlice || nalType == kH264NalTypeIdr) { // once we've hit a slice or an IDR // we just need to cover getting to the PPS ID auto nalUnitPayloadStart = nalUnitStartIndex + kH264NalUnitHeaderSize; - auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart, - frame.size() - nalUnitPayloadStart); + auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart, frame.size() - nalUnitPayloadStart); - processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, - kH264NalUnitHeaderSize + nalUnitPPSBytes); + processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH264NalUnitHeaderSize + nalUnitPPSBytes); processor.add_encrypted_bytes( frame.data() + nalUnitStartIndex + kH264NalUnitHeaderSize + nalUnitPPSBytes, nextNaluStart - nalUnitStartIndex - kH264NalUnitHeaderSize - nalUnitPPSBytes); } else { // copy the whole NAL unit - processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, - nextNaluStart - nalUnitStartIndex); + processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex); } naluIndexPair = nextNaluIndexPair; @@ -274,22 +265,16 @@ bool process_frame_h265(outbound_frame_processor& processor, array_viewfirst - nextNaluIndexPair->second - : frame.size(); + auto nextNaluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); + auto nextNaluStart = nextNaluIndexPair.has_value() ? nextNaluIndexPair->first - nextNaluIndexPair->second : frame.size(); if (nalType < kH265NalTypeVclCutoff) { // found a VCL NAL, encrypt the payload only - processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH265NalUnitHeaderSize); - processor.add_encrypted_bytes(frame.data() + nalUnitStartIndex + kH265NalUnitHeaderSize, - nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize); - } - else { + processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH265NalUnitHeaderSize); + processor.add_encrypted_bytes(frame.data() + nalUnitStartIndex + kH265NalUnitHeaderSize, nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize); + } else { // copy the whole NAL unit - processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, - nextNaluStart - nalUnitStartIndex); + processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex); } naluIndexPair = nextNaluIndexPair; @@ -372,8 +357,7 @@ bool process_frame_av1(outbound_frame_processor& processor, array_view - #include "key_ratchet.h" -#include "logger.h" - #include +#include using namespace std::chrono_literals; @@ -52,11 +49,12 @@ big_nonce compute_wrapped_big_nonce(key_generation generation, truncated_sync_no return static_cast(generation) << RATCHET_GENERATION_SHIFT_BITS | maskedNonce; } -aead_cipher_manager::aead_cipher_manager(const clock_interface& clock, std::unique_ptr keyRatchet) +aead_cipher_manager::aead_cipher_manager(dpp::cluster& cl, const clock_interface& clock, std::unique_ptr keyRatchet) : clock_(clock) , keyRatchet_(std::move(keyRatchet)) , ratchetCreation_(clock.now()) , ratchetExpiry_(time_point::max()) + , creator(cl) { } @@ -76,14 +74,12 @@ cipher_interface* aead_cipher_manager::get_cipher(key_generation generation) cleanup_expired_ciphers(); if (generation < oldestGeneration_) { - DISCORD_LOG(LS_INFO) << "Received frame with old generation: " << generation - << ", oldest generation: " << oldestGeneration_; + creator.log(dpp::ll_trace, "Received frame with old generation: " + std::to_string(generation) + ", oldest generation: " + std::to_string(oldestGeneration_)); return nullptr; } if (generation > newestGeneration_ + MAX_GENERATION_GAP) { - DISCORD_LOG(LS_INFO) << "Received frame with future generation: " << generation - << ", newest generation: " << newestGeneration_; + creator.log(dpp::ll_trace, "Received frame with future generation: " + std::to_string(generation) + ", newest generation: " + std::to_string(newestGeneration_)); return nullptr; } @@ -92,10 +88,7 @@ cipher_interface* aead_cipher_manager::get_cipher(key_generation generation) auto maxLifetimeFrames = MAX_FRAMES_PER_SECOND * ratchetLifetimeSec; auto maxLifetimeGenerations = maxLifetimeFrames >> RATCHET_GENERATION_SHIFT_BITS; if (generation > maxLifetimeGenerations) { - DISCORD_LOG(LS_INFO) << "Received frame with generation " << generation - << " beyond ratchet max lifetime generations: " - << maxLifetimeGenerations - << ", ratchet lifetime: " << ratchetLifetimeSec << "s"; + creator.log(dpp::ll_debug, "Received frame with generation " + std::to_string(generation) + " beyond ratchet max lifetime generations: " + std::to_string(maxLifetimeGenerations) + ", ratchet lifetime: " + std::to_string(ratchetLifetimeSec) + "s"); return nullptr; } @@ -144,14 +137,14 @@ void aead_cipher_manager::report_cipher_success(key_generation generation, trunc if (generation <= newestGeneration_ || cryptors_.find(generation) == cryptors_.end()) { return; } - DISCORD_LOG(LS_INFO) << "Reporting cryptor success, generation: " << generation; + creator.log(dpp::ll_trace, "Reporting cryptor success, generation: " + std::to_string(generation)); newestGeneration_ = generation; // Update the expiry time for all old cryptors const auto expiryTime = clock_.now() + CIPHER_EXPIRY; for (auto& [gen, cryptor] : cryptors_) { if (gen < newestGeneration_) { - DISCORD_LOG(LS_INFO) << "Updating expiry for cryptor, generation: " << gen; + creator.log(dpp::ll_trace, "Updating expiry for cryptor, generation: " + std::to_string(gen)); cryptor.expiry = std::min(cryptor.expiry, expiryTime); } } @@ -172,14 +165,14 @@ aead_cipher_manager::expiring_cipher aead_cipher_manager::make_expiring_cipher(k // In that case, create it with a non-infinite expiry time as we have already transitioned // to a newer generation if (generation < newestGeneration_) { - DISCORD_LOG(LS_INFO) << "Creating cryptor for old generation: " << generation; + creator.log(dpp::ll_debug, "Creating cryptor for old generation: " + std::to_string(generation)); expiryTime = clock_.now() + CIPHER_EXPIRY; } else { - DISCORD_LOG(LS_INFO) << "Creating cryptor for new generation: " << generation; + creator.log(dpp::ll_debug, "Creating cryptor for new generation: " + std::to_string(generation)); } - return {create_cipher(encryptionKey), expiryTime}; + return {create_cipher(creator, encryptionKey), expiryTime}; } void aead_cipher_manager::cleanup_expired_ciphers() @@ -189,15 +182,14 @@ void aead_cipher_manager::cleanup_expired_ciphers() bool expired = cryptor.expiry < clock_.now(); if (expired) { - DISCORD_LOG(LS_INFO) << "Removing expired cryptor, generation: " << generation; + creator.log(dpp::ll_trace, "Removing expired cryptor, generation: " + std::to_string(generation)); } it = expired ? cryptors_.erase(it) : ++it; } - while (oldestGeneration_ < newestGeneration_ && - cryptors_.find(oldestGeneration_) == cryptors_.end()) { - DISCORD_LOG(LS_INFO) << "Deleting key for old generation: " << oldestGeneration_; + while (oldestGeneration_ < newestGeneration_ && cryptors_.find(oldestGeneration_) == cryptors_.end()) { + creator.log(dpp::ll_trace, "Deleting key for old generation: " + std::to_string(oldestGeneration_)); keyRatchet_->delete_key(oldestGeneration_); ++oldestGeneration_; } diff --git a/src/dpp/dave/cryptor_manager.h b/src/dpp/dave/cryptor_manager.h index 6a8d6f2cfc..774e87f23e 100755 --- a/src/dpp/dave/cryptor_manager.h +++ b/src/dpp/dave/cryptor_manager.h @@ -34,6 +34,10 @@ #include "common.h" #include "clock.h" +namespace dpp { + class cluster; +} + namespace dpp::dave { /** @@ -71,10 +75,11 @@ class aead_cipher_manager { /** * @brief Constructor + * @param cl Creating cluster * @param clock chrono clock * @param keyRatchet key ratchet for cipher */ - aead_cipher_manager(const clock_interface& clock, std::unique_ptr keyRatchet); + aead_cipher_manager(dpp::cluster& cl, const clock_interface& clock, std::unique_ptr keyRatchet); /** * @brief Update cipher expiry @@ -155,6 +160,12 @@ class aead_cipher_manager { std::optional newestProcessedNonce_; std::deque missingNonces_; + + /** + * @brief DPP Cluster, used for logging + */ + dpp::cluster& creator; + }; } // namespace dpp::dave diff --git a/src/dpp/dave/decryptor.cpp b/src/dpp/dave/decryptor.cpp index 9dd0160f63..e2fa7272be 100755 --- a/src/dpp/dave/decryptor.cpp +++ b/src/dpp/dave/decryptor.cpp @@ -23,12 +23,10 @@ * ************************************************************************************/ #include "decryptor.h" - #include - +#include #include #include "common.h" -#include "logger.h" #include "leb128.h" #include "scope_exit.h" @@ -41,14 +39,15 @@ constexpr auto kStatsInterval = 10s; void decryptor::transition_to_key_ratchet(std::unique_ptr keyRatchet, duration transitionExpiry) { - DISCORD_LOG(LS_INFO) << "Transitioning to new key ratchet: " << keyRatchet.get() - << ", expiry: " << transitionExpiry.count(); + if (keyRatchet) { + creator.log(dpp::ll_trace, "Transitioning to new key ratchet, expiry: " + std::to_string(transitionExpiry.count())); + } // Update the expiry time for all existing cryptor managers update_cryptor_manager_expiry(transitionExpiry); if (keyRatchet) { - cryptorManagers_.emplace_back(clock_, std::move(keyRatchet)); + cryptorManagers_.emplace_back(creator, clock_, std::move(keyRatchet)); } } @@ -69,8 +68,7 @@ size_t decryptor::decrypt(media_type mediaType, array_view frame) { if (mediaType != media_audio && mediaType != media_video) { - DISCORD_LOG(LS_WARNING) << "decrypt failed, invalid media type: " - << static_cast(mediaType); + creator.log(dpp::ll_trace, "decrypt failed, invalid media type: " + std::to_string(static_cast(mediaType))); return 0; } @@ -82,9 +80,9 @@ size_t decryptor::decrypt(media_type mediaType, // Skip decrypting for silence frames if (mediaType == media_audio && encryptedFrame.size() == OPUS_SILENCE_PACKET.size() && std::memcmp(encryptedFrame.data(), OPUS_SILENCE_PACKET.data(), OPUS_SILENCE_PACKET.size()) == 0) { - DISCORD_LOG(LS_VERBOSE) << "decrypt skipping silence of size: " << encryptedFrame.size(); + creator.log(dpp::ll_trace, "decrypt skipping silence of size: " + std::to_string(encryptedFrame.size())); if (encryptedFrame.data() != frame.data()) { - std::memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); + std::memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); } return encryptedFrame.size(); } @@ -109,8 +107,7 @@ size_t decryptor::decrypt(media_type mediaType, // If the frame is not encrypted and we can't pass it through, fail if (!localFrame->is_encrypted()) { - DISCORD_LOG(LS_INFO) - << "decrypt failed, frame is not encrypted and pass through is disabled"; + creator.log(dpp::ll_warning, "decrypt failed, frame is not encrypted and pass through is disabled"); stats_[mediaType].decrypt_failure++; return 0; } @@ -133,24 +130,16 @@ size_t decryptor::decrypt(media_type mediaType, } else { stats_[mediaType].decrypt_failure++; - DISCORD_LOG(LS_WARNING) << "decrypt failed, no valid cryptor found, type: " - << (mediaType ? "video" : "audio") - << ", encrypted frame size: " << encryptedFrame.size() - << ", plaintext frame size: " << frame.size() - << ", number of cryptor managers: " << cryptorManagers_.size() - << ", pass through enabled: " << (canUsePassThrough ? "yes" : "no"); + creator.log(dpp::ll_warning, "decrypt failed, no valid cryptor found, type: " + std::string(mediaType ? "video" : "audio") + + ", encrypted frame size: " + std::to_string(encryptedFrame.size()) + + ", plaintext frame size: " + std::to_string(frame.size()) + + ", number of cryptor managers: " + std::to_string(cryptorManagers_.size()) + + ", pass through enabled: " + std::string(canUsePassThrough ? "yes" : "no") + ); } auto end = clock_.now(); - if (end > lastStatsTime_ + kStatsInterval) { - lastStatsTime_ = end; - DISCORD_LOG(LS_INFO) << "Decrypted audio: " << stats_[media_audio].decrypt_success - << ", video: " << stats_[media_video].decrypt_success - << ". Failed audio: " << stats_[media_audio].decrypt_failure - << ", video: " << stats_[media_video].decrypt_failure; - } - stats_[mediaType].decrypt_duration += - std::chrono::duration_cast(end - start).count(); + stats_[mediaType].decrypt_duration += std::chrono::duration_cast(end - start).count(); return bytesWritten; } @@ -179,7 +168,7 @@ bool decryptor::decrypt_impl(aead_cipher_manager& cipher_manager, cipher_manager.compute_wrapped_generation(truncatedNonce >> RATCHET_GENERATION_SHIFT_BITS); if (!cipher_manager.can_process_nonce(generation, truncatedNonce)) { - DISCORD_LOG(LS_INFO) << "decrypt failed, cannot process nonce: " << truncatedNonce; + creator.log(dpp::ll_trace, "decrypt failed, cannot process nonce"); return false; } @@ -187,7 +176,7 @@ bool decryptor::decrypt_impl(aead_cipher_manager& cipher_manager, cipher_interface* cipher = cipher_manager.get_cipher(generation); if (cipher == nullptr) { - DISCORD_LOG(LS_INFO) << "decrypt failed, no cryptor found for generation: " << generation; + creator.log(dpp::ll_warning, "decrypt failed, no cryptor found for generation: " + std::to_string(generation)); return false; } @@ -218,7 +207,7 @@ void decryptor::update_cryptor_manager_expiry(duration expiry) void decryptor::cleanup_expired_cryptor_managers() { while (!cryptorManagers_.empty() && cryptorManagers_.front().is_expired()) { - DISCORD_LOG(LS_INFO) << "Removing expired cryptor manager."; + creator.log(dpp::ll_trace, "Removing expired cryptor manager"); cryptorManagers_.pop_front(); } } @@ -227,7 +216,7 @@ std::unique_ptr decryptor::get_or_create_frame_processo { std::lock_guard lock(frameProcessorsMutex_); if (frameProcessors_.empty()) { - return std::make_unique(); + return std::make_unique(creator); } auto frameProcessor = std::move(frameProcessors_.back()); frameProcessors_.pop_back(); diff --git a/src/dpp/dave/decryptor.h b/src/dpp/dave/decryptor.h index 9f48ac6cde..804c41c826 100755 --- a/src/dpp/dave/decryptor.h +++ b/src/dpp/dave/decryptor.h @@ -39,6 +39,10 @@ #include "version.h" #include "clock.h" +namespace dpp { + class cluster; +} + namespace dpp::dave { class key_ratchet_interface; @@ -74,6 +78,12 @@ struct decryption_stats { */ class decryptor { public: + /** + * @brief Constructor + * @param cl Creator + */ + decryptor(dpp::cluster& cl) : creator(cl) { }; + /** * @brief Chrono duration */ @@ -179,6 +189,11 @@ class decryptor { time_point lastStatsTime_{time_point::min()}; std::array stats_; + + /** + * @brief DPP Cluster, used for logging + */ + dpp::cluster& creator; }; } // namespace dpp::dave diff --git a/src/dpp/dave/encryptor.cpp b/src/dpp/dave/encryptor.cpp index de5ffee028..b169662754 100755 --- a/src/dpp/dave/encryptor.cpp +++ b/src/dpp/dave/encryptor.cpp @@ -27,9 +27,9 @@ #include #include #include +#include #include "common.h" #include "cryptor_manager.h" -#include "logger.h" #include "codec_utils.h" #include "array_view.h" #include "leb128.h" @@ -63,8 +63,7 @@ encryptor::result_code encryptor::encrypt(media_type mediaType, size_t* bytesWritten) { if (mediaType != media_audio && mediaType != media_video) { - DISCORD_LOG(LS_WARNING) << "encrypt failed, invalid media type: " - << static_cast(mediaType); + creator.log(dpp::ll_warning, "encrypt failed, invalid media type: " + std::to_string(static_cast(mediaType))); return result_code::rc_encryption_failure; } @@ -209,17 +208,6 @@ encryptor::result_code encryptor::encrypt(media_type mediaType, stats_[mediaType].encrypt_failure++; } - if (now > lastStatsTime_ + kStatsInterval) { - lastStatsTime_ = now; - DISCORD_LOG(LS_INFO) << "Encrypted audio: " << stats_[media_audio].encrypt_success - << ", video: " << stats_[media_video].encrypt_success - << ". Failed audio: " << stats_[media_audio].encrypt_failure - << ", video: " << stats_[media_video].encrypt_failure; - DISCORD_LOG(LS_INFO) << "Last encrypted frame, type: " - << (mediaType == media_audio ? "audio" : "video") << ", ssrc: " << ssrc - << ", size: " << frame.size(); - } - return result; } @@ -262,7 +250,7 @@ std::unique_ptr encryptor::get_or_create_frame_process { std::lock_guard lock(frameProcessorsMutex_); if (frameProcessors_.empty()) { - return std::make_unique(); + return std::make_unique(creator); } auto frameProcessor = std::move(frameProcessors_.back()); frameProcessors_.pop_back(); @@ -289,7 +277,7 @@ encryptor::cryptor_and_nonce encryptor::get_next_cryptor_and_nonce() currentKeyGeneration_ = generation; auto encryptionKey = keyRatchet_->get_key(currentKeyGeneration_); - cryptor_ = create_cipher(encryptionKey); + cryptor_ = create_cipher(creator, encryptionKey); } return {cryptor_, truncatedNonce_}; diff --git a/src/dpp/dave/encryptor.h b/src/dpp/dave/encryptor.h index 1b2b8c798f..978238d744 100755 --- a/src/dpp/dave/encryptor.h +++ b/src/dpp/dave/encryptor.h @@ -38,6 +38,10 @@ #include "frame_processors.h" #include "version.h" +namespace dpp { + class cluster; +} + namespace dpp::dave { /** @@ -73,6 +77,12 @@ struct encryption_stats { class encryptor { public: /** + * @brief Constructor + * @param cl Creator + */ + encryptor(dpp::cluster& cl) : creator(cl) { }; + + /** * @brief Return codes for encryptor::encrypt */ enum result_code : uint8_t { @@ -233,6 +243,11 @@ class encryptor { protocol_version_changed_callback protocolVersionChangedCallback_; protocol_version currentProtocolVersion_{max_protocol_version()}; + + /** + * @brief DPP Cluster, used for logging + */ + dpp::cluster& creator; }; } // namespace dpp::dave diff --git a/src/dpp/dave/frame_processors.cpp b/src/dpp/dave/frame_processors.cpp index 673c990a53..9cd8b8c4a6 100755 --- a/src/dpp/dave/frame_processors.cpp +++ b/src/dpp/dave/frame_processors.cpp @@ -28,8 +28,8 @@ #include #include #include +#include #include "codec_utils.h" -#include "logger.h" #include "array_view.h" #include "leb128.h" @@ -100,7 +100,6 @@ uint8_t deserialize_unencrypted_ranges(const uint8_t*& readAt, } if (readAt != end) { - DISCORD_LOG(LS_WARNING) << "Failed to deserialize unencrypted ranges"; unencryptedRanges.clear(); readAt = nullptr; return 0; @@ -125,10 +124,6 @@ bool validate_unencrypted_ranges(const ranges& unencryptedRanges, size_t frameSi auto [didOverflow, currentEnd] = OverflowAdd(current.offset, current.size); if (didOverflow || currentEnd > maxEnd) { - DISCORD_LOG(LS_WARNING) - << "Unencrypted range may overlap or be out of order: current offset: " - << current.offset << ", current size: " << current.size << ", maximum end: " << maxEnd - << ", frame size: " << frameSize; return false; } } @@ -190,7 +185,7 @@ void inbound_frame_processor::parse_frame(array_view frame) constexpr auto MinSupplementalBytesSize = AES_GCM_127_TRUNCATED_TAG_BYTES + sizeof(supplemental_bytes_size) + sizeof(magic_marker); if (frame.size() < MinSupplementalBytesSize) { - DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain min supplemental bytes"; + creator.log(dpp::ll_warning, "Encrypted frame is too small to contain min supplemental bytes"); return; } @@ -207,14 +202,13 @@ void inbound_frame_processor::parse_frame(array_view frame) // Check the frame is large enough to contain the supplemental bytes if (frame.size() < supplementalBytesSize) { - DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain supplemental bytes"; + creator.log(dpp::ll_warning, "Encrypted frame is too small to contain supplemental bytes"); return; } // Check that supplemental bytes size is large enough to contain the supplemental bytes if (supplementalBytesSize < MinSupplementalBytesSize) { - DISCORD_LOG(LS_WARNING) - << "Supplemental bytes size is too small to contain supplemental bytes"; + creator.log(dpp::ll_warning, "Supplemental bytes size is too small to contain supplemental bytes"); return; } @@ -229,7 +223,7 @@ void inbound_frame_processor::parse_frame(array_view frame) auto end = supplementalBytesSizeBuffer; truncatedNonce_ = read_leb128(readAt, end); if (readAt == nullptr) { - DISCORD_LOG(LS_WARNING) << "Failed to read truncated nonce"; + creator.log(dpp::ll_warning, "Failed to read truncated nonce"); return; } @@ -237,12 +231,12 @@ void inbound_frame_processor::parse_frame(array_view frame) auto unencryptedRangesSize = end - readAt; deserialize_unencrypted_ranges(readAt, unencryptedRangesSize, unencryptedRanges_); if (readAt == nullptr) { - DISCORD_LOG(LS_WARNING) << "Failed to read unencrypted ranges"; + creator.log(dpp::ll_warning, "Failed to read unencrypted ranges"); return; } if (!validate_unencrypted_ranges(unencryptedRanges_, frame.size())) { - DISCORD_LOG(LS_WARNING) << "Invalid unencrypted ranges"; + creator.log(dpp::ll_warning, "Invalid unencrypted ranges"); return; } @@ -280,12 +274,12 @@ void inbound_frame_processor::parse_frame(array_view frame) size_t inbound_frame_processor::reconstruct_frame(array_view frame) const { if (!isEncrypted_) { - DISCORD_LOG(LS_WARNING) << "Cannot reconstruct an invalid encrypted frame"; + creator.log(dpp::ll_warning, "Cannot reconstruct an invalid encrypted frame"); return 0; } if (authenticated_.size() + plaintext_.size() > frame.size()) { - DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the decrypted frame"; + creator.log(dpp::ll_warning, "Frame is too small to contain the decrypted frame"); return 0; } @@ -359,7 +353,7 @@ void outbound_frame_processor::process_frame(array_view frame, co size_t outbound_frame_processor::reconstruct_frame(array_view frame) { if (unencryptedBytes_.size() + ciphertextBytes_.size() > frame.size()) { - DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the encrypted frame"; + creator.log(dpp::ll_warning, "Frame is too small to contain the encrypted frame"); return 0; } diff --git a/src/dpp/dave/frame_processors.h b/src/dpp/dave/frame_processors.h index 106ce01122..26e58d0b42 100755 --- a/src/dpp/dave/frame_processors.h +++ b/src/dpp/dave/frame_processors.h @@ -31,6 +31,10 @@ #include "common.h" #include "array_view.h" +namespace dpp { + class cluster; +} + namespace dpp::dave { /** @@ -84,6 +88,12 @@ bool validate_unencrypted_ranges(const ranges& unencryptedRanges, size_t frameSi */ class inbound_frame_processor { public: + /** + * @brief Create inbound frame processor + * @param _creator creating cluster + */ + inbound_frame_processor(dpp::cluster& _creator) : creator(_creator) { }; + /** * @brief Parse inbound frame * @param frame frame bytes @@ -179,6 +189,11 @@ class inbound_frame_processor { std::vector authenticated_; std::vector ciphertext_; std::vector plaintext_; + + /** + * @brief DPP Cluster, used for logging + */ + dpp::cluster& creator; }; /** @@ -186,6 +201,12 @@ class inbound_frame_processor { */ class outbound_frame_processor { public: + /** + * @brief Create outbound frame processor + * @param _creator creating cluster + */ + outbound_frame_processor(dpp::cluster& _creator) : creator(_creator) { }; + /** * @brief Process outbound frame * @param frame frame data @@ -266,6 +287,11 @@ class outbound_frame_processor { std::vector encryptedBytes_; std::vector ciphertextBytes_; ranges unencryptedRanges_; + + /** + * @brief DPP Cluster, used for logging + */ + dpp::cluster& creator; }; } // namespace dpp::dave diff --git a/src/dpp/dave/logger.cpp b/src/dpp/dave/logger.cpp deleted file mode 100755 index b847e4ecd9..0000000000 --- a/src/dpp/dave/logger.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/************************************************************************************ - * - * D++, A Lightweight C++ library for Discord - * - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2021 Craig Edwards and D++ contributors - * (https://github.com/brainboxdotcc/DPP/graphs/contributors) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * This folder is a modified fork of libdave, https://github.com/discord/libdave - * Copyright (c) 2024 Discord, Licensed under MIT - * - ************************************************************************************/ -#include "logger.h" - -#include -#include -#include - -namespace dpp::dave { - -std::atomic gLogSink = nullptr; - -void SetLogSink(LogSink sink) -{ - gLogSink = sink; -} - -LogStreamer::LogStreamer(LoggingSeverity severity, const char* file, int line) - : severity_(severity) - , file_(file) - , line_(line) -{ -} - -LogStreamer::~LogStreamer() -{ - std::string logLine = stream_.str(); - if (logLine.empty()) { - return; - } - - auto sink = gLogSink.load(); - if (sink) { - sink(severity_, file_, line_, logLine); - return; - } - - switch (severity_) { - case LS_VERBOSE: - case LS_INFO: - case LS_WARNING: - case LS_ERROR: { - const char* file = file_; - if (auto separator = strrchr(file, '/')) { - file = separator + 1; - } - std::cout << "(" << file << ":" << line_ << ") " << logLine << std::endl; - break; - } - case LS_NONE: - break; - } -} - -} // namespace dpp::dave - diff --git a/src/dpp/dave/logger.h b/src/dpp/dave/logger.h deleted file mode 100755 index 8b4dca8c14..0000000000 --- a/src/dpp/dave/logger.h +++ /dev/null @@ -1,70 +0,0 @@ -/************************************************************************************ - * - * D++, A Lightweight C++ library for Discord - * - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2021 Craig Edwards and D++ contributors - * (https://github.com/brainboxdotcc/DPP/graphs/contributors) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * This folder is a modified fork of libdave, https://github.com/discord/libdave - * Copyright (c) 2024 Discord, Licensed under MIT - * - ************************************************************************************/ -#pragma once - -#include - -#if !defined(DISCORD_LOG) -#define DISCORD_LOG_FILE_LINE(sev, file, line) ::dpp::dave::LogStreamer(sev, file, line) -#define DISCORD_LOG(sev) DISCORD_LOG_FILE_LINE(::dpp::dave::sev, __FILE__, __LINE__) -#endif - -namespace dpp::dave { - -enum LoggingSeverity { - LS_VERBOSE, - LS_INFO, - LS_WARNING, - LS_ERROR, - LS_NONE, -}; - -using LogSink = void (*)(LoggingSeverity severity, - const char* file, - int line, - const std::string& message); -void SetLogSink(LogSink sink); - -class LogStreamer { -public: - LogStreamer(LoggingSeverity severity, const char* file, int line); - ~LogStreamer(); - - template - LogStreamer& operator<<(const T& value) - { - stream_ << value; - return *this; - } - -private: - LoggingSeverity severity_; - const char* file_; - int line_; - std::ostringstream stream_; -}; - -} // namespace dpp::dave - diff --git a/src/dpp/dave/mls_key_ratchet.cpp b/src/dpp/dave/mls_key_ratchet.cpp index a30b46cefe..1f8b1149dc 100755 --- a/src/dpp/dave/mls_key_ratchet.cpp +++ b/src/dpp/dave/mls_key_ratchet.cpp @@ -23,12 +23,12 @@ * ************************************************************************************/ #include "mls_key_ratchet.h" -#include "logger.h" +#include namespace dpp::dave { -mls_key_ratchet::mls_key_ratchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept - : hashRatchet_(suite, std::move(baseSecret)) +mls_key_ratchet::mls_key_ratchet(dpp::cluster& cl, ::mlspp::CipherSuite suite, bytes baseSecret) noexcept + : hashRatchet_(suite, std::move(baseSecret)), creator(cl) { } @@ -36,15 +36,13 @@ mls_key_ratchet::~mls_key_ratchet() noexcept = default; encryption_key mls_key_ratchet::get_key(key_generation generation) noexcept { - DISCORD_LOG(LS_INFO) << "Retrieving key for generation " << generation << " from HashRatchet"; - + creator.log(dpp::ll_debug, "Retrieving key for generation " + std::to_string(generation) + " from HashRatchet"); try { auto keyAndNonce = hashRatchet_.get(generation); return std::move(keyAndNonce.key.as_vec()); } catch (const std::exception& e) { - DISCORD_LOG(LS_ERROR) << "Failed to retrieve key for generation " << generation << ": " - << e.what(); + creator.log(dpp::ll_warning, "Failed to retrieve key for generation " + std::to_string(generation) + ": " + std::string(e.what())); return {}; } } diff --git a/src/dpp/dave/mls_key_ratchet.h b/src/dpp/dave/mls_key_ratchet.h index 65f4f382b0..0a4d053819 100755 --- a/src/dpp/dave/mls_key_ratchet.h +++ b/src/dpp/dave/mls_key_ratchet.h @@ -27,6 +27,10 @@ #include #include "key_ratchet.h" +namespace dpp { + class cluster; +} + namespace dpp::dave { /** @@ -39,7 +43,7 @@ class mls_key_ratchet : public key_ratchet_interface { // NOLINT * @param suite MLS ciphersuite to use * @param baseSecret base secret */ - mls_key_ratchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept; + mls_key_ratchet(dpp::cluster& cl, ::mlspp::CipherSuite suite, bytes baseSecret) noexcept; /** * @brief Destructor @@ -64,6 +68,11 @@ class mls_key_ratchet : public key_ratchet_interface { // NOLINT * @brief MLS hash ratchet */ ::mlspp::HashRatchet hashRatchet_; + + /** + * @brief DPP Cluster, used for logging + */ + dpp::cluster& creator; }; } // namespace dpp::dave diff --git a/src/dpp/dave/openssl_aead_cipher.cpp b/src/dpp/dave/openssl_aead_cipher.cpp index 90532ffba6..fcb2cad352 100755 --- a/src/dpp/dave/openssl_aead_cipher.cpp +++ b/src/dpp/dave/openssl_aead_cipher.cpp @@ -23,23 +23,14 @@ #include #include #include +#include #include #include "common.h" -#include "logger.h" namespace dpp::dave { -void PrintSSLErrors() -{ - ERR_print_errors_cb( - [](const char* str, size_t len, void* ctx) -> int { - DISCORD_LOG(LS_ERROR) << std::string(str, len); - return 1; - }, - nullptr); -} - -openssl_aead_cipher::openssl_aead_cipher(const encryption_key& encryptionKey) : +openssl_aead_cipher::openssl_aead_cipher(dpp::cluster& _creator, const encryption_key& encryptionKey) : + cipher_interface(_creator), cipherCtx_(EVP_CIPHER_CTX_new()), key_(std::vector(encryptionKey.data(), encryptionKey.data() + encryptionKey.size())) { } @@ -53,7 +44,7 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view int len{}; if (EVP_EncryptInit_ex(cipherCtx_, EVP_aes_128_gcm(), nullptr, nullptr, nullptr) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -61,13 +52,13 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view * Set IV length */ if (EVP_CIPHER_CTX_ctrl(cipherCtx_, EVP_CTRL_GCM_SET_IVLEN, AES_GCM_128_NONCE_BYTES, nullptr) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Initialise key and IV */ if (EVP_EncryptInit_ex(cipherCtx_, nullptr, nullptr, key_.data(), nonceBuffer.data()) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -76,7 +67,7 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view * required */ if (EVP_EncryptUpdate(cipherCtx_, nullptr, &len, additionalData.data(), (int)additionalData.size()) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -85,7 +76,7 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view * EVP_EncryptUpdate can be called multiple times if necessary */ if (EVP_EncryptUpdate(cipherCtx_, ciphertextBufferOut.data(), &len, plaintextBuffer.data(), (int)plaintextBuffer.size()) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -94,13 +85,13 @@ bool openssl_aead_cipher::encrypt(byte_view ciphertextBufferOut, const_byte_view * this stage, but this does not occur in GCM mode */ if (EVP_EncryptFinal_ex(cipherCtx_, ciphertextBufferOut.data() + len, &len) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Get the tag */ if (EVP_CIPHER_CTX_ctrl(cipherCtx_, EVP_CTRL_GCM_GET_TAG, AES_GCM_127_TRUNCATED_TAG_BYTES, tagBufferOut.data()) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -113,19 +104,19 @@ bool openssl_aead_cipher::decrypt(byte_view plaintextBufferOut, const_byte_view /* Initialise the decryption operation. */ if (EVP_DecryptInit_ex(cipherCtx_, EVP_aes_128_gcm(), nullptr, nullptr, nullptr) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Set IV length. Not necessary if this is 12 bytes (96 bits) */ if (EVP_CIPHER_CTX_ctrl(cipherCtx_, EVP_CTRL_GCM_SET_IVLEN, AES_GCM_128_NONCE_BYTES, nullptr) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Initialise key and IV */ if (EVP_DecryptInit_ex(cipherCtx_, nullptr, nullptr, key_.data(), nonceBuffer.data()) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -134,7 +125,7 @@ bool openssl_aead_cipher::decrypt(byte_view plaintextBufferOut, const_byte_view * required */ if (EVP_DecryptUpdate(cipherCtx_, nullptr, &len, additionalData.data(), (int)additionalData.size()) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -143,13 +134,13 @@ bool openssl_aead_cipher::decrypt(byte_view plaintextBufferOut, const_byte_view * EVP_DecryptUpdate can be called multiple times if necessary */ if (EVP_DecryptUpdate(cipherCtx_, plaintextBufferOut.data(), &len, ciphertextBuffer.data(), (int)ciphertextBuffer.size()) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } /* Set expected tag value. Works in OpenSSL 1.0.1d and later */ if (EVP_CIPHER_CTX_ctrl(cipherCtx_, EVP_CTRL_GCM_SET_TAG, AES_GCM_127_TRUNCATED_TAG_BYTES, (void*)tagBuffer.data()) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } @@ -158,7 +149,7 @@ bool openssl_aead_cipher::decrypt(byte_view plaintextBufferOut, const_byte_view * anything else is a failure - the plaintext is not trustworthy. */ if (EVP_DecryptFinal_ex(cipherCtx_, plaintextBufferOut.data() + len, &len) == 0) { - PrintSSLErrors(); + creator.log(dpp::ll_warning, "SSL Error: " + std::to_string(ERR_get_error())); return false; } diff --git a/src/dpp/dave/openssl_aead_cipher.h b/src/dpp/dave/openssl_aead_cipher.h index 26070a91c0..b69ae2c322 100755 --- a/src/dpp/dave/openssl_aead_cipher.h +++ b/src/dpp/dave/openssl_aead_cipher.h @@ -26,6 +26,10 @@ #include #include "cipher_interface.h" +namespace dpp { + class cluster; +} + namespace dpp::dave { /** @@ -36,11 +40,13 @@ namespace dpp::dave { */ class openssl_aead_cipher : public cipher_interface { // NOLINT public: + /** * @brief constructor + * @param _creator Creator * @param encryptionKey encryption key */ - openssl_aead_cipher(const encryption_key& encryptionKey); + openssl_aead_cipher(dpp::cluster& _creator, const encryption_key& encryptionKey); /** * @brief Destructor diff --git a/src/dpp/dave/persisted_key_pair.cpp b/src/dpp/dave/persisted_key_pair.cpp index 12d2ee9b74..77ea4f415e 100755 --- a/src/dpp/dave/persisted_key_pair.cpp +++ b/src/dpp/dave/persisted_key_pair.cpp @@ -31,19 +31,16 @@ #include #include #include - +#include #include #include - #include "parameters.h" -#include "logger.h" static const std::string SelfSignatureLabel = "DiscordSelfSignature"; static std::string MakeKeyID(const std::string& sessionID, ::mlspp::CipherSuite suite) { - return sessionID + "-" + std::to_string((uint16_t)suite.cipher_suite()) + "-" + - std::to_string(dpp::dave::mls::KeyVersion); + return sessionID + "-" + std::to_string((uint16_t)suite.cipher_suite()) + "-" + std::to_string(dpp::dave::mls::KeyVersion); } static std::mutex mtx; @@ -51,10 +48,7 @@ static std::map> map; namespace dpp::dave::mls { -static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair( - key_pair_context_type ctx, - const std::string& sessionID, - ::mlspp::CipherSuite suite) +static std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, ::mlspp::CipherSuite suite) { std::lock_guard lk(mtx); @@ -64,10 +58,10 @@ static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair( return it->second; } - std::shared_ptr<::mlspp::SignaturePrivateKey> ret = ::dpp::dave::mls::detail::get_generic_persisted_key_pair(ctx, id, suite); + std::shared_ptr<::mlspp::SignaturePrivateKey> ret = ::dpp::dave::mls::detail::get_generic_persisted_key_pair(creator, ctx, id, suite); if (!ret) { - DISCORD_LOG(LS_ERROR) << "Failed to get key in get_persisted_key_pair"; + creator.log(dpp::ll_warning, "Failed to get key in get_persisted_key_pair"); return nullptr; } @@ -76,20 +70,15 @@ static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair( return ret; } -std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(key_pair_context_type ctx, - const std::string& sessionID, - protocol_version version) +std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, protocol_version version) { - return GetPersistedKeyPair(ctx, sessionID, ciphersuite_for_protocol_version(version)); + return get_persisted_key_pair(creator, ctx, sessionID, ciphersuite_for_protocol_version(version)); } -KeyAndSelfSignature get_persisted_public_key(key_pair_context_type ctx, - const std::string& sessionID, - signature_version version) +KeyAndSelfSignature get_persisted_public_key(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, signature_version version) { auto suite = ciphersuite_for_signature_version(version); - - auto pair = GetPersistedKeyPair(ctx, sessionID, suite); + auto pair = get_persisted_key_pair(creator, ctx, sessionID, suite); if (!pair) { return {}; @@ -98,22 +87,17 @@ KeyAndSelfSignature get_persisted_public_key(key_pair_context_type ctx, bytes sign_data = from_ascii(sessionID + ":") + pair->public_key.data; return { - pair->public_key.data.as_vec(), - std::move(pair->sign(suite, SelfSignatureLabel, sign_data).as_vec()), + pair->public_key.data.as_vec(), + std::move(pair->sign(suite, SelfSignatureLabel, sign_data).as_vec()), }; } -bool delete_persisted_key_pair(key_pair_context_type ctx, - const std::string& sessionID, - signature_version version) +bool delete_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, signature_version version) { std::string id = MakeKeyID(sessionID, ciphersuite_for_signature_version(version)); - std::lock_guard lk(mtx); - map.erase(id); - - return ::dpp::dave::mls::detail::delete_generic_persisted_key_pair(ctx, id); + return ::dpp::dave::mls::detail::delete_generic_persisted_key_pair(creator, ctx, id); } } // namespace dpp::dave::mls diff --git a/src/dpp/dave/persisted_key_pair.h b/src/dpp/dave/persisted_key_pair.h index 803b4f65de..7c13c6578b 100755 --- a/src/dpp/dave/persisted_key_pair.h +++ b/src/dpp/dave/persisted_key_pair.h @@ -37,6 +37,10 @@ #include "parameters.h" #include "version.h" +namespace dpp { + class cluster; +} + namespace mlspp { struct SignaturePrivateKey; }; @@ -55,7 +59,7 @@ using key_pair_context_type = const char *; * @param version Protocol version * @return MLS signature private key */ -std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(key_pair_context_type ctx, const std::string& sessionID, protocol_version version); +std::shared_ptr<::mlspp::SignaturePrivateKey> get_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, protocol_version version); /** * @brief self signed signature and key @@ -78,7 +82,7 @@ struct KeyAndSelfSignature { * @param version protocol version * @return Key and self signature */ -KeyAndSelfSignature get_persisted_public_key(key_pair_context_type ctx, const std::string& sessionID, signature_version version); +KeyAndSelfSignature get_persisted_public_key(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, signature_version version); /** * @brief Delete persisted key pair @@ -87,7 +91,7 @@ KeyAndSelfSignature get_persisted_public_key(key_pair_context_type ctx, const st * @param version protocol version * @return true if deleted */ -bool delete_persisted_key_pair(key_pair_context_type ctx, const std::string& sessionID, signature_version version); +bool delete_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& sessionID, signature_version version); /** * @brief Key version for DAVE @@ -102,7 +106,7 @@ namespace detail { * @param suite ciphersuite * @return signature and private key */ - std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(key_pair_context_type ctx, const std::string& id, ::mlspp::CipherSuite suite); + std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& id, ::mlspp::CipherSuite suite); /** * Delete generic persisted key pair @@ -110,7 +114,7 @@ namespace detail { * @param id id * @return true if deleted */ - bool delete_generic_persisted_key_pair(key_pair_context_type ctx, const std::string& id); + bool delete_generic_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& id); } // namespace dpp::dave::mls::detail } // namespace dpp::dave::mls diff --git a/src/dpp/dave/persisted_key_pair_generic.cpp b/src/dpp/dave/persisted_key_pair_generic.cpp index dc97f965ed..c0a473acc7 100755 --- a/src/dpp/dave/persisted_key_pair_generic.cpp +++ b/src/dpp/dave/persisted_key_pair_generic.cpp @@ -29,19 +29,16 @@ #include #include #include - #ifdef _WIN32 #include #else #include #endif #include - #include #include - +#include #include "parameters.h" -#include "logger.h" #include "persisted_key_pair.h" static const std::string_view KeyStorageDir = "Discord Key Storage"; @@ -83,24 +80,21 @@ static std::filesystem::path GetKeyStorageDirectory() namespace dpp::dave::mls::detail { -std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(key_pair_context_type ctx, - const std::string& id, - ::mlspp::CipherSuite suite) +std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& id, ::mlspp::CipherSuite suite) { ::mlspp::SignaturePrivateKey ret; std::string curstr; std::filesystem::path dir = GetKeyStorageDirectory(); if (dir.empty()) { - DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in get_persisted_key_pair"; + creator.log(dpp::ll_warning, "Failed to determine key storage directory in get_persisted_key_pair"); return nullptr; } std::error_code errc; std::filesystem::create_directories(dir, errc); if (errc) { - DISCORD_LOG(LS_ERROR) << "Failed to create key storage directory in get_persisted_key_pair: " - << errc; + creator.log(dpp::ll_warning, "Failed to create key storage directory in get_persisted_key_pair: " + std::to_string(errc.value())); return nullptr; } @@ -109,7 +103,7 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(key if (std::filesystem::exists(file)) { std::ifstream ifs(file, std::ios_base::in | std::ios_base::binary); if (!ifs) { - DISCORD_LOG(LS_ERROR) << "Failed to open key in get_persisted_key_pair"; + creator.log(dpp::ll_warning, "Failed to open key in get_persisted_key_pair"); return nullptr; } @@ -117,7 +111,7 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(key s << ifs.rdbuf(); curstr = s.str(); if (!ifs) { - DISCORD_LOG(LS_ERROR) << "Failed to read key in get_persisted_key_pair"; + creator.log(dpp::ll_warning, "Failed to read key in get_persisted_key_pair"); return nullptr; } @@ -125,7 +119,7 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(key ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr); } catch (std::exception& ex) { - DISCORD_LOG(LS_ERROR) << "Failed to parse key in get_persisted_key_pair: " << ex.what(); + creator.log(dpp::ll_warning, "Failed to parse key in get_persisted_key_pair: " + std::string(ex.what())); return nullptr; } } @@ -145,8 +139,7 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(key S_IRUSR | S_IWUSR); #endif if (fd < 0) { - DISCORD_LOG(LS_ERROR) << "Failed to open output file in get_persisted_key_pair: " << errno - << "(" << tmpfile << ")"; + creator.log(dpp::ll_warning, "Failed to open output file in get_persisted_key_pair: " + std::to_string(errno) + " (" + tmpfile.generic_string() + ")"); return nullptr; } @@ -158,15 +151,13 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(key close(fd); #endif if (wret < 0 || (size_t)wret != newstr.size()) { - DISCORD_LOG(LS_ERROR) << "Failed to write output file in get_persisted_key_pair: " - << errno; + creator.log(dpp::ll_warning, "Failed to write output file in get_persisted_key_pair: " + std::to_string(errno)); return nullptr; } std::filesystem::rename(tmpfile, file, errc); if (errc) { - DISCORD_LOG(LS_ERROR) << "Failed to rename output file in get_persisted_key_pair: " - << errc; + creator.log(dpp::ll_warning, "Failed to rename output file in get_persisted_key_pair: " + std::to_string(errc.value())); return nullptr; } } @@ -178,17 +169,16 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> get_generic_persisted_key_pair(key } -bool delete_generic_persisted_key_pair(key_pair_context_type ctx, const std::string& id) +bool delete_generic_persisted_key_pair(dpp::cluster& creator, key_pair_context_type ctx, const std::string& id) { std::error_code errc; std::filesystem::path dir = GetKeyStorageDirectory(); if (dir.empty()) { - DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in get_persisted_key_pair"; + creator.log(dpp::ll_warning, "Failed to determine key storage directory in get_persisted_key_pair"); return false; } std::filesystem::path file = dir / (id + ".key"); - return std::filesystem::remove(file, errc); } diff --git a/src/dpp/dave/session.cpp b/src/dpp/dave/session.cpp index bdfb63024e..b619a2913a 100755 --- a/src/dpp/dave/session.cpp +++ b/src/dpp/dave/session.cpp @@ -41,9 +41,9 @@ #include "openssl/evp.h" -#define TRACK_MLS_ERROR(reason) \ - if (onMLSFailureCallback_) { \ - onMLSFailureCallback_(__FUNCTION__, reason); \ +#define TRACK_MLS_ERROR(reason) \ + if (onMLSFailureCallback_) { \ + onMLSFailureCallback_(__FUNCTION__, reason); \ } namespace dpp::dave::mls { @@ -61,8 +61,7 @@ session::session(dpp::cluster& cluster, key_pair_context_type context, const std session::~session() noexcept = default; -void session::init(protocol_version version, uint64_t groupId, std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept -{ +void session::init(protocol_version version, uint64_t groupId, std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept { reset(); selfUserId_ = selfUserId; @@ -76,8 +75,7 @@ void session::init(protocol_version version, uint64_t groupId, std::string const create_pending_group(); } -void session::reset() noexcept -{ +void session::reset() noexcept { creator.log(dpp::ll_debug, "Resetting MLS session"); clear_pending_state(); @@ -89,8 +87,7 @@ void session::reset() noexcept groupId_.clear(); } -void session::set_protocol_version(protocol_version version) noexcept -{ +void session::set_protocol_version(protocol_version version) noexcept { if (version != protocolVersion_) { // when we need to retain backwards compatibility // there may be some changes to the MLS objects required here @@ -99,13 +96,11 @@ void session::set_protocol_version(protocol_version version) noexcept } } -std::vector session::get_last_epoch_authenticator() const noexcept -{ +std::vector session::get_last_epoch_authenticator() const noexcept { if (!currentState_) { creator.log(dpp::ll_debug, "Cannot get epoch authenticator without an established MLS group"); return {}; } - return std::move(currentState_->epoch_authenticator().as_vec()); } @@ -131,9 +126,7 @@ catch (const std::exception& e) { return; } -std::optional> session::process_proposals( - std::vector proposals, - std::set const& recognizedUserIDs) noexcept +std::optional> session::process_proposals(std::vector proposals, std::set const& recognizedUserIDs) noexcept try { if (!pendingGroupState_ && !currentState_) { creator.log(dpp::ll_debug, "Cannot process proposals without any pending or established MLS group state"); @@ -267,8 +260,7 @@ bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognizedUserIDs) const -{ +bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& targetState, std::set const& recognizedUserIDs) const { if (message.wire_format != ::mlspp::WireFormat::mls_public_message) { creator.log(dpp::ll_warning, "MLS proposal message must be PublicMessage"); TRACK_MLS_ERROR("Invalid proposal wire format"); @@ -526,14 +518,13 @@ bool session::verify_welcome_state(::mlspp::State const& state, return true; } -void session::init_leaf_node(std::string const& selfUserId, - std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept +void session::init_leaf_node(std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept try { auto ciphersuite = ciphersuite_for_protocol_version(protocolVersion_); if (!transientKey) { if (!signingKeyId_.empty()) { - transientKey = get_persisted_key_pair(keyPairContext_, signingKeyId_, protocolVersion_); + transientKey = get_persisted_key_pair(creator, keyPairContext_, signingKeyId_, protocolVersion_); if (!transientKey) { creator.log(dpp::ll_warning, "Did not receive MLS signature private key from get_persisted_key_pair; aborting"); return; @@ -549,18 +540,11 @@ try { auto selfCredential = create_user_credential(selfUserId, protocolVersion_); - selfHPKEPrivateKey_ = - std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); + selfHPKEPrivateKey_ = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); - selfLeafNode_ = - std::make_unique<::mlspp::LeafNode>(ciphersuite, - selfHPKEPrivateKey_->public_key, - selfSigPrivateKey_->public_key, - std::move(selfCredential), - leaf_node_capabilities_for_protocol_version(protocolVersion_), - ::mlspp::Lifetime::create_default(), - leaf_node_extensions_for_protocol_version(protocolVersion_), - *selfSigPrivateKey_); + selfLeafNode_ = std::make_unique<::mlspp::LeafNode>(ciphersuite, selfHPKEPrivateKey_->public_key, selfSigPrivateKey_->public_key, std::move(selfCredential), + leaf_node_capabilities_for_protocol_version(protocolVersion_), ::mlspp::Lifetime::create_default(), + leaf_node_extensions_for_protocol_version(protocolVersion_), *selfSigPrivateKey_); creator.log(dpp::ll_debug, "Created MLS leaf node"); } @@ -667,7 +651,7 @@ std::unique_ptr session::get_key_ratchet(std::string cons // this assumes the MLS ciphersuite produces a kAesGcm128KeyBytes sized key // would need to be updated to a different ciphersuite if there's a future mismatch - return std::make_unique(currentState_->cipher_suite(), std::move(baseSecret)); + return std::make_unique(creator, currentState_->cipher_suite(), std::move(baseSecret)); } void session::get_pairwise_fingerprint(uint16_t version, diff --git a/src/dpp/dave/user_credential.cpp b/src/dpp/dave/user_credential.cpp index 592cfcc004..f5fd9a6493 100755 --- a/src/dpp/dave/user_credential.cpp +++ b/src/dpp/dave/user_credential.cpp @@ -28,8 +28,7 @@ namespace dpp::dave::mls { -::mlspp::Credential create_user_credential(const std::string& userId, protocol_version version) -{ +::mlspp::Credential create_user_credential(const std::string& userId, protocol_version version) { // convert the string user ID to a big endian uint64_t auto userID = std::stoull(userId); auto credentialBytes = big_endian_bytes_from(userID); @@ -37,8 +36,7 @@ ::mlspp::Credential create_user_credential(const std::string& userId, protocol_v return ::mlspp::Credential::basic(credentialBytes); } -std::string user_credential_to_string(const ::mlspp::Credential& cred, protocol_version version) -{ +std::string user_credential_to_string(const ::mlspp::Credential& cred, protocol_version version) { if (cred.type() != ::mlspp::CredentialType::basic) { return ""; } diff --git a/src/dpp/dave/util.cpp b/src/dpp/dave/util.cpp index 8708d56f9d..1abdcc2faf 100755 --- a/src/dpp/dave/util.cpp +++ b/src/dpp/dave/util.cpp @@ -26,8 +26,7 @@ namespace dpp::dave::mls { -::mlspp::bytes_ns::bytes big_endian_bytes_from(uint64_t value) noexcept -{ +::mlspp::bytes_ns::bytes big_endian_bytes_from(uint64_t value) noexcept { auto buffer = ::mlspp::bytes_ns::bytes(); buffer.reserve(sizeof(value)); @@ -38,8 +37,7 @@ ::mlspp::bytes_ns::bytes big_endian_bytes_from(uint64_t value) noexcept return buffer; } -uint64_t from_big_endian_bytes(const ::mlspp::bytes_ns::bytes& buffer) noexcept -{ +uint64_t from_big_endian_bytes(const ::mlspp::bytes_ns::bytes& buffer) noexcept { uint64_t val = 0; if (buffer.size() <= sizeof(val)) { diff --git a/src/dpp/dave/version.cpp b/src/dpp/dave/version.cpp index b0f5b68928..34a5ac8269 100755 --- a/src/dpp/dave/version.cpp +++ b/src/dpp/dave/version.cpp @@ -26,11 +26,11 @@ namespace dpp::dave { -constexpr protocol_version CurrentDaveProtocolVersion = 1; +constexpr protocol_version current_dave_protocol_version = 1; protocol_version max_protocol_version() { - return CurrentDaveProtocolVersion; + return current_dave_protocol_version; } } // namespace dpp::dave diff --git a/src/dpp/voice/enabled/handle_frame.cpp b/src/dpp/voice/enabled/handle_frame.cpp index c2900dfae9..5ade5ac4f0 100644 --- a/src/dpp/voice/enabled/handle_frame.cpp +++ b/src/dpp/voice/enabled/handle_frame.cpp @@ -47,7 +47,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod mls_state->dave_session->set_external_sender(dave_header.get_data()); - mls_state->encryptor = std::make_unique(); + mls_state->encryptor = std::make_unique(*creator); mls_state->decryptors.clear(); } break; @@ -69,7 +69,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod for (const auto& user : dave_mls_user_list) { log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version())); dpp::snowflake u{user}; - mls_state->decryptors.emplace(u, std::make_unique()); + mls_state->decryptors.emplace(u, std::make_unique(*creator)); mls_state->decryptors.find(u)->second->transition_to_key_ratchet(mls_state->dave_session->get_key_ratchet(user)); } mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id.str())); @@ -99,7 +99,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod for (const auto& user : dave_mls_user_list) { log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version())); dpp::snowflake u{user}; - mls_state->decryptors.emplace(u, std::make_unique()); + mls_state->decryptors.emplace(u, std::make_unique(*creator)); mls_state->decryptors.find(u)->second->transition_to_key_ratchet(mls_state->dave_session->get_key_ratchet(user)); } mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id.str()));