Skip to content

Commit

Permalink
move dave session into mls state
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Oct 3, 2024
1 parent 89e412e commit 5afa97c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
7 changes: 1 addition & 6 deletions include/dpp/discordvoiceclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ inline constexpr size_t send_audio_raw_max_length = 11520;
inline constexpr size_t secret_key_size = 32;

struct dave_state;
struct dave_encryptors;

/*
* @brief For holding a moving average of the number of current voice users, for applying a smooth gain ramp.
Expand Down Expand Up @@ -413,9 +412,7 @@ class DPP_EXPORT discord_voice_client : public websocket_client
*/
OpusRepacketizer* repacketizer;

std::unique_ptr<dave::mls::Session> dave_session{};

std::unique_ptr<dave_state> mls_state{};
std::unique_ptr<dave_state> mls_state;

#else
/**
Expand All @@ -429,8 +426,6 @@ class DPP_EXPORT discord_voice_client : public websocket_client
*/
void* repacketizer;

std::unique_ptr<int> dave_session{};

std::unique_ptr<int> mls_state{};
#endif

Expand Down
41 changes: 21 additions & 20 deletions src/dpp/discordvoiceclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ constexpr uint8_t voice_protocol_version = 8;
static std::string external_ip;

struct dave_state {
std::unique_ptr<dave::mls::Session> dave_session{};
std::shared_ptr<::mlspp::SignaturePrivateKey> mls_key;
std::vector<uint8_t> cached_commit;
uint64_t transition_id{0};
Expand Down Expand Up @@ -538,14 +539,14 @@ void discord_voice_client::get_user_privacy_code(const dpp::snowflake user, priv
callback("");
return;
}
dave_session->GetPairwiseFingerprint(0x0000, user.str(), [callback](const std::vector<uint8_t>& data) {
mls_state->dave_session->GetPairwiseFingerprint(0x0000, user.str(), [callback](const std::vector<uint8_t>& data) {
std::cout << dpp::utility::debug_dump((uint8_t*)data.data(), data.size());
callback(data.size() == 64 ? generate_displayable_code(data, 45) : "");
});
}

bool discord_voice_client::is_end_to_end_encrypted() const {
return dave_session && mls_state && !mls_state->privacy_code.empty();
return mls_state && !mls_state->privacy_code.empty();
}

bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcode) {
Expand All @@ -562,7 +563,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
case voice_client_dave_mls_external_sender: {
log(ll_debug, "voice_client_dave_mls_external_sender");

dave_session->SetExternalSender(dave_header->get_data(data.length()));
mls_state->dave_session->SetExternalSender(dave_header->get_data(data.length()));

mls_state->encryptor = std::make_unique<dave::Encryptor>();
mls_state->decryptors.clear();
Expand All @@ -571,7 +572,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
case voice_client_dave_mls_proposals: {
log(ll_debug, "voice_client_dave_mls_proposals");

std::optional<std::vector<uint8_t>> response = dave_session->ProcessProposals(dave_header->get_data(data.length()), dave_mls_user_list);
std::optional<std::vector<uint8_t>> response = mls_state->dave_session->ProcessProposals(dave_header->get_data(data.length()), dave_mls_user_list);
if (response.has_value()) {
auto r = response.value();
mls_state->cached_commit = r;
Expand All @@ -582,39 +583,39 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
break;
case voice_client_dave_announce_commit_transaction: {
log(ll_debug, "voice_client_dave_announce_commit_transaction");
auto r = dave_session->ProcessCommit(mls_state->cached_commit);
auto r = mls_state->dave_session->ProcessCommit(mls_state->cached_commit);
for (const auto& user : dave_mls_user_list) {
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(dave_session->GetProtocolVersion()));
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->GetProtocolVersion()));
dpp::snowflake u{user};
mls_state->decryptors.emplace(u, std::make_unique<dpp::dave::Decryptor>());
mls_state->decryptors.find(u)->second->TransitionToKeyRatchet(dave_session->GetKeyRatchet(user));
mls_state->decryptors.find(u)->second->TransitionToKeyRatchet(mls_state->dave_session->GetKeyRatchet(user));
}
mls_state->encryptor->SetKeyRatchet(dave_session->GetKeyRatchet(creator->me.id.str()));
mls_state->encryptor->SetKeyRatchet(mls_state->dave_session->GetKeyRatchet(creator->me.id.str()));

/**
* https://www.ietf.org/archive/id/draft-ietf-mls-protocol-14.html#name-epoch-authenticators
* 9.7. Epoch Authenticators
* The main MLS key schedule provides a per-epoch epoch_authenticator. If one member of the group is being impersonated by an active attacker,
* the epoch_authenticator computed by their client will differ from those computed by the other group members.
*/
mls_state->privacy_code = generate_displayable_code(dave_session->GetLastEpochAuthenticator());
mls_state->privacy_code = generate_displayable_code(mls_state->dave_session->GetLastEpochAuthenticator());
log(ll_debug, "E2EE Privacy Code: " + mls_state->privacy_code);
}
break;
case voice_client_dave_mls_welcome: {
this->mls_state->transition_id = dave_header->get_welcome_transition_id();
log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->mls_state->transition_id));
auto r = dave_session->ProcessWelcome(dave_header->get_welcome_data(data.length()), dave_mls_user_list);
auto r = mls_state->dave_session->ProcessWelcome(dave_header->get_welcome_data(data.length()), dave_mls_user_list);
if (r.has_value()) {
for (const auto& user : dave_mls_user_list) {
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(dave_session->GetProtocolVersion()));
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->GetProtocolVersion()));
dpp::snowflake u{user};
mls_state->decryptors.emplace(u, std::make_unique<dpp::dave::Decryptor>());
mls_state->decryptors.find(u)->second->TransitionToKeyRatchet(dave_session->GetKeyRatchet(user));
mls_state->decryptors.find(u)->second->TransitionToKeyRatchet(mls_state->dave_session->GetKeyRatchet(user));
}
mls_state->encryptor->SetKeyRatchet(dave_session->GetKeyRatchet(creator->me.id.str()));
mls_state->encryptor->SetKeyRatchet(mls_state->dave_session->GetKeyRatchet(creator->me.id.str()));
}
mls_state->privacy_code = generate_displayable_code(dave_session->GetLastEpochAuthenticator());
mls_state->privacy_code = generate_displayable_code(mls_state->dave_session->GetLastEpochAuthenticator());
log(ll_debug, "E2EE Privacy Code: " + mls_state->privacy_code);
}
break;
Expand Down Expand Up @@ -704,8 +705,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
uint64_t epoch = j["d"]["epoch"];
log(ll_debug, "voice_client_dave_prepare_epoch version=" + std::to_string(protocol_version) + " for epoch " + std::to_string(epoch));
if (epoch == 1) {
dave_session->Reset();
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), mls_state->mls_key);
mls_state->dave_session->Reset();
mls_state->dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), mls_state->mls_key);
}
}
break;
Expand Down Expand Up @@ -817,13 +818,13 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
send_silence(20);
}

dave_session = std::make_unique<dave::mls::Session>(
mls_state = std::make_unique<dave_state>();
mls_state->dave_session = std::make_unique<dave::mls::Session>(
nullptr, "" /* sessionid */, [this](std::string const& s1, std::string const& s2) {
log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2);
});
mls_state = std::make_unique<dave_state>();
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), mls_state->mls_key);
auto key_response = dave_session->GetMarshalledKeyPackage();
mls_state->dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), mls_state->mls_key);
auto key_response = mls_state->dave_session->GetMarshalledKeyPackage();
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);

Expand Down

0 comments on commit 5afa97c

Please sign in to comment.