Skip to content

Commit

Permalink
remove cout based logger and replace with a dpp cluster log
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Oct 7, 2024
1 parent ac755a0 commit 89d549a
Show file tree
Hide file tree
Showing 25 changed files with 250 additions and 404 deletions.
5 changes: 3 additions & 2 deletions src/dpp/dave/cipher_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
************************************************************************************/
#include "cipher_interface.h"
#include "openssl_aead_cipher.h"
#include <dpp/cluster.h>

namespace dpp::dave {

std::unique_ptr<cipher_interface> create_cipher(const encryption_key& encryptionKey)
std::unique_ptr<cipher_interface> create_cipher(dpp::cluster& cl, const encryption_key& encryptionKey)
{
auto cipher = std::make_unique<openssl_aead_cipher>(encryptionKey);
auto cipher = std::make_unique<openssl_aead_cipher>(cl, encryptionKey);
return cipher->is_valid() ? std::move(cipher) : nullptr;
}

Expand Down
19 changes: 18 additions & 1 deletion src/dpp/dave/cipher_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
#include "common.h"
#include "array_view.h"

namespace dpp {
class cluster;
}

namespace dpp::dave {

/**
Expand All @@ -47,6 +51,12 @@ using byte_view = array_view<uint8_t>;
*/
class cipher_interface { // NOLINT
public:
/**
* @brief Create cipher interface
* @param _creator Creating cluster
*/
cipher_interface(dpp::cluster& _creator) : creator(_creator) { };

/**
* @brief Default destructor
*/
Expand All @@ -73,14 +83,21 @@ class cipher_interface { // NOLINT
* @return true if decryption succeeded, false if it failed
*/
virtual bool decrypt(byte_view plaintextBufferOut, const_byte_view ciphertextBuffer, const_byte_view tagBuffer, const_byte_view nonceBuffer, const_byte_view additionalData) = 0;

protected:

/**
* @brief DPP Cluster, used for logging
*/
dpp::cluster& creator;
};

/**
* @brief Factory function to create new cipher interface of the best supported type for DAVE
* @param encryptionKey encryption key
* @return an instance of a class derived from cipher_interface
*/
std::unique_ptr<cipher_interface> create_cipher(const encryption_key& encryptionKey);
std::unique_ptr<cipher_interface> create_cipher(dpp::cluster& cl, const encryption_key& encryptionKey);

} // namespace dpp::dave

46 changes: 15 additions & 31 deletions src/dpp/dave/codec_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,11 @@
#include <limits>
#include <optional>
#include <dpp/exception.h>

#include "logger.h"
#include "leb128.h"

namespace dpp::dave::codec_utils {

unencrypted_frame_header_size BytesCoveringH264PPS(const uint8_t* payload,
const uint64_t sizeRemaining)
unencrypted_frame_header_size BytesCoveringH264PPS(const uint8_t* payload, const uint64_t sizeRemaining)
{
// the payload starts with three exponential golomb encoded values
// (first_mb_in_slice, sps_id, pps_id)
Expand Down Expand Up @@ -212,29 +209,23 @@ bool process_frame_h264(outbound_frame_processor& processor, array_view<const ui
// always write a long start code and then the NAL unit
processor.add_unencrypted_bytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode));

auto nextNaluIndexPair =
FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex);
auto nextNaluStart = nextNaluIndexPair.has_value()
? nextNaluIndexPair->first - nextNaluIndexPair->second
: frame.size();
auto nextNaluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex);
auto nextNaluStart = nextNaluIndexPair.has_value() ? nextNaluIndexPair->first - nextNaluIndexPair->second : frame.size();

if (nalType == kH264NalTypeSlice || nalType == kH264NalTypeIdr) {
// once we've hit a slice or an IDR
// we just need to cover getting to the PPS ID
auto nalUnitPayloadStart = nalUnitStartIndex + kH264NalUnitHeaderSize;
auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart,
frame.size() - nalUnitPayloadStart);
auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart, frame.size() - nalUnitPayloadStart);

processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex,
kH264NalUnitHeaderSize + nalUnitPPSBytes);
processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH264NalUnitHeaderSize + nalUnitPPSBytes);
processor.add_encrypted_bytes(
frame.data() + nalUnitStartIndex + kH264NalUnitHeaderSize + nalUnitPPSBytes,
nextNaluStart - nalUnitStartIndex - kH264NalUnitHeaderSize - nalUnitPPSBytes);
}
else {
// copy the whole NAL unit
processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex,
nextNaluStart - nalUnitStartIndex);
processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex);
}

naluIndexPair = nextNaluIndexPair;
Expand Down Expand Up @@ -274,22 +265,16 @@ bool process_frame_h265(outbound_frame_processor& processor, array_view<const ui
// always write a long start code and then the NAL unit
processor.add_unencrypted_bytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode));

auto nextNaluIndexPair =
FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex);
auto nextNaluStart = nextNaluIndexPair.has_value()
? nextNaluIndexPair->first - nextNaluIndexPair->second
: frame.size();
auto nextNaluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex);
auto nextNaluStart = nextNaluIndexPair.has_value() ? nextNaluIndexPair->first - nextNaluIndexPair->second : frame.size();

if (nalType < kH265NalTypeVclCutoff) {
// found a VCL NAL, encrypt the payload only
processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH265NalUnitHeaderSize);
processor.add_encrypted_bytes(frame.data() + nalUnitStartIndex + kH265NalUnitHeaderSize,
nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize);
}
else {
processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, kH265NalUnitHeaderSize);
processor.add_encrypted_bytes(frame.data() + nalUnitStartIndex + kH265NalUnitHeaderSize, nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize);
} else {
// copy the whole NAL unit
processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex,
nextNaluStart - nalUnitStartIndex);
processor.add_unencrypted_bytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex);
}

naluIndexPair = nextNaluIndexPair;
Expand Down Expand Up @@ -372,8 +357,7 @@ bool process_frame_av1(outbound_frame_processor& processor, array_view<const uin
processor.add_unencrypted_bytes(&obuHeader, sizeof(obuHeader));
if (obuHasExtension) {
// write the extension byte unencrypted
processor.add_unencrypted_bytes(frame.data() + obuHeaderIndex + sizeof(obuHeader),
kObuExtensionSizeBytes);
processor.add_unencrypted_bytes(frame.data() + obuHeaderIndex + sizeof(obuHeader), kObuExtensionSizeBytes);
}

// write the OBU payload size unencrypted if it was present and we didn't rewrite
Expand All @@ -384,11 +368,11 @@ bool process_frame_av1(outbound_frame_processor& processor, array_view<const uin
// we sanitize the size by re-writing it ourselves
uint8_t leb128Buffer[LEB128_MAX_SIZE];
size_t additionalBytesToWrite = write_leb128(obuPayloadSize, leb128Buffer);
processor.add_unencrypted_bytes(leb128Buffer, additionalBytesToWrite);
processor.add_unencrypted_bytes(leb128Buffer, additionalBytesToWrite);
}

// add the OBU payload, encrypted
processor.add_encrypted_bytes(frame.data() + obuPayloadIndex, obuPayloadSize);
processor.add_encrypted_bytes(frame.data() + obuPayloadIndex, obuPayloadSize);
}
}

Expand Down
36 changes: 14 additions & 22 deletions src/dpp/dave/cryptor_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,10 @@
*
************************************************************************************/
#include "cryptor_manager.h"

#include <limits>

#include "key_ratchet.h"
#include "logger.h"

#include <bytes/bytes.h>
#include <dpp/cluster.h>

using namespace std::chrono_literals;

Expand All @@ -52,11 +49,12 @@ big_nonce compute_wrapped_big_nonce(key_generation generation, truncated_sync_no
return static_cast<big_nonce>(generation) << RATCHET_GENERATION_SHIFT_BITS | maskedNonce;
}

aead_cipher_manager::aead_cipher_manager(const clock_interface& clock, std::unique_ptr<key_ratchet_interface> keyRatchet)
aead_cipher_manager::aead_cipher_manager(dpp::cluster& cl, const clock_interface& clock, std::unique_ptr<key_ratchet_interface> keyRatchet)
: clock_(clock)
, keyRatchet_(std::move(keyRatchet))
, ratchetCreation_(clock.now())
, ratchetExpiry_(time_point::max())
, creator(cl)
{
}

Expand All @@ -76,14 +74,12 @@ cipher_interface* aead_cipher_manager::get_cipher(key_generation generation)
cleanup_expired_ciphers();

if (generation < oldestGeneration_) {
DISCORD_LOG(LS_INFO) << "Received frame with old generation: " << generation
<< ", oldest generation: " << oldestGeneration_;
creator.log(dpp::ll_trace, "Received frame with old generation: " + std::to_string(generation) + ", oldest generation: " + std::to_string(oldestGeneration_));
return nullptr;
}

if (generation > newestGeneration_ + MAX_GENERATION_GAP) {
DISCORD_LOG(LS_INFO) << "Received frame with future generation: " << generation
<< ", newest generation: " << newestGeneration_;
creator.log(dpp::ll_trace, "Received frame with future generation: " + std::to_string(generation) + ", newest generation: " + std::to_string(newestGeneration_));
return nullptr;
}

Expand All @@ -92,10 +88,7 @@ cipher_interface* aead_cipher_manager::get_cipher(key_generation generation)
auto maxLifetimeFrames = MAX_FRAMES_PER_SECOND * ratchetLifetimeSec;
auto maxLifetimeGenerations = maxLifetimeFrames >> RATCHET_GENERATION_SHIFT_BITS;
if (generation > maxLifetimeGenerations) {
DISCORD_LOG(LS_INFO) << "Received frame with generation " << generation
<< " beyond ratchet max lifetime generations: "
<< maxLifetimeGenerations
<< ", ratchet lifetime: " << ratchetLifetimeSec << "s";
creator.log(dpp::ll_debug, "Received frame with generation " + std::to_string(generation) + " beyond ratchet max lifetime generations: " + std::to_string(maxLifetimeGenerations) + ", ratchet lifetime: " + std::to_string(ratchetLifetimeSec) + "s");
return nullptr;
}

Expand Down Expand Up @@ -144,14 +137,14 @@ void aead_cipher_manager::report_cipher_success(key_generation generation, trunc
if (generation <= newestGeneration_ || cryptors_.find(generation) == cryptors_.end()) {
return;
}
DISCORD_LOG(LS_INFO) << "Reporting cryptor success, generation: " << generation;
creator.log(dpp::ll_trace, "Reporting cryptor success, generation: " + std::to_string(generation));
newestGeneration_ = generation;

// Update the expiry time for all old cryptors
const auto expiryTime = clock_.now() + CIPHER_EXPIRY;
for (auto& [gen, cryptor] : cryptors_) {
if (gen < newestGeneration_) {
DISCORD_LOG(LS_INFO) << "Updating expiry for cryptor, generation: " << gen;
creator.log(dpp::ll_trace, "Updating expiry for cryptor, generation: " + std::to_string(gen));
cryptor.expiry = std::min(cryptor.expiry, expiryTime);
}
}
Expand All @@ -172,14 +165,14 @@ aead_cipher_manager::expiring_cipher aead_cipher_manager::make_expiring_cipher(k
// In that case, create it with a non-infinite expiry time as we have already transitioned
// to a newer generation
if (generation < newestGeneration_) {
DISCORD_LOG(LS_INFO) << "Creating cryptor for old generation: " << generation;
creator.log(dpp::ll_debug, "Creating cryptor for old generation: " + std::to_string(generation));
expiryTime = clock_.now() + CIPHER_EXPIRY;
}
else {
DISCORD_LOG(LS_INFO) << "Creating cryptor for new generation: " << generation;
creator.log(dpp::ll_debug, "Creating cryptor for new generation: " + std::to_string(generation));
}

return {create_cipher(encryptionKey), expiryTime};
return {create_cipher(creator, encryptionKey), expiryTime};
}

void aead_cipher_manager::cleanup_expired_ciphers()
Expand All @@ -189,15 +182,14 @@ void aead_cipher_manager::cleanup_expired_ciphers()

bool expired = cryptor.expiry < clock_.now();
if (expired) {
DISCORD_LOG(LS_INFO) << "Removing expired cryptor, generation: " << generation;
creator.log(dpp::ll_trace, "Removing expired cryptor, generation: " + std::to_string(generation));
}

it = expired ? cryptors_.erase(it) : ++it;
}

while (oldestGeneration_ < newestGeneration_ &&
cryptors_.find(oldestGeneration_) == cryptors_.end()) {
DISCORD_LOG(LS_INFO) << "Deleting key for old generation: " << oldestGeneration_;
while (oldestGeneration_ < newestGeneration_ && cryptors_.find(oldestGeneration_) == cryptors_.end()) {
creator.log(dpp::ll_trace, "Deleting key for old generation: " + std::to_string(oldestGeneration_));
keyRatchet_->delete_key(oldestGeneration_);
++oldestGeneration_;
}
Expand Down
13 changes: 12 additions & 1 deletion src/dpp/dave/cryptor_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
#include "common.h"
#include "clock.h"

namespace dpp {
class cluster;
}

namespace dpp::dave {

/**
Expand Down Expand Up @@ -71,10 +75,11 @@ class aead_cipher_manager {

/**
* @brief Constructor
* @param cl Creating cluster
* @param clock chrono clock
* @param keyRatchet key ratchet for cipher
*/
aead_cipher_manager(const clock_interface& clock, std::unique_ptr<key_ratchet_interface> keyRatchet);
aead_cipher_manager(dpp::cluster& cl, const clock_interface& clock, std::unique_ptr<key_ratchet_interface> keyRatchet);

/**
* @brief Update cipher expiry
Expand Down Expand Up @@ -155,6 +160,12 @@ class aead_cipher_manager {

std::optional<big_nonce> newestProcessedNonce_;
std::deque<big_nonce> missingNonces_;

/**
* @brief DPP Cluster, used for logging
*/
dpp::cluster& creator;

};

} // namespace dpp::dave
Expand Down
Loading

0 comments on commit 89d549a

Please sign in to comment.