Skip to content

Commit

Permalink
fix: cleanup and fix upgrade not working
Browse files Browse the repository at this point in the history
  • Loading branch information
Neko-Life committed Oct 14, 2024
1 parent 91ee46c commit bc7dc97
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 49 deletions.
19 changes: 8 additions & 11 deletions include/dpp/discordvoiceclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,10 @@ class DPP_EXPORT discord_voice_client : public websocket_client
std::set<std::string> dave_mls_user_list;

/**
* @brief The list of users that just joined for DAVE
* @brief The list of users that have left the voice channel but
* not yet removed from MLS group.
*/
std::set<std::string> dave_mls_new_user_list;
std::set<std::string> dave_mls_pending_remove_list;

/**
* @brief File descriptor for UDP connection
Expand Down Expand Up @@ -583,11 +584,6 @@ class DPP_EXPORT discord_voice_client : public websocket_client
*/
dave_version_t dave_version;

/**
* @brief Current dave session key generation.
*/
uint32_t dave_current_epoch;

/**
* @brief Send data to UDP socket immediately.
*
Expand Down Expand Up @@ -1206,7 +1202,7 @@ class DPP_EXPORT discord_voice_client : public websocket_client
*/
void get_user_privacy_code(const dpp::snowflake user, privacy_code_callback_t callback) const;

/*
/**
* @brief Notify gateway ready for a DAVE transition.
*
* Fires Voice Ready event when appropriate.
Expand All @@ -1225,18 +1221,19 @@ class DPP_EXPORT discord_voice_client : public websocket_client

/**
* @brief Execute pending protocol upgrade/downgrade to/from dave.
* @return true if did an upgrade/downgrade
*/
void execute_pending_upgrade_downgrade();
bool execute_pending_upgrade_downgrade();

/**
* @brief Reset dave session and prepare initial session group.
*/
void reinit_dave_mls_group();
void reinit_dave_mls_group();

/**
* @brief Process roster map from commit/welcome.
*/
void process_mls_group_rosters(const std::map<unsigned long, std::vector<unsigned char>> &rmap);
void process_mls_group_rosters(const std::map<unsigned long, std::vector<unsigned char>> &rmap);
};

} // namespace dpp
Expand Down
8 changes: 6 additions & 2 deletions src/dpp/discordvoiceclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,17 @@ void discord_voice_client::get_user_privacy_code(const dpp::snowflake user, priv

bool discord_voice_client::is_end_to_end_encrypted() const {
#ifdef HAVE_VOICE
if (mls_state == nullptr) {
if (mls_state == nullptr || mls_state->encryptor == nullptr) {
return false;
}

bool has_pending_downgrade = mls_state->pending_transition.is_pending && mls_state->pending_transition.protocol_version != dave_version_1;

return !has_pending_downgrade && !mls_state->privacy_code.empty();
/*
* A dave_version 0 should be enough to know we're in non-e2ee session, we should also check for pending downgrade and
* whether session encryptor actually has key rachet set to encrypt opus packets.
*/
return !has_pending_downgrade && dave_version != dave_version_none && mls_state->encryptor->has_key_ratchet();
#else
return false;
#endif
Expand Down
1 change: 0 additions & 1 deletion src/dpp/voice/enabled/constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ discord_voice_client::discord_voice_client(dpp::cluster* _cluster, snowflake _ch
sending(false),
tracks(0),
dave_version(enable_dave ? dave_version_1 : dave_version_none),
dave_current_epoch(0),
creator(_cluster),
terminating(false),
heartbeat_interval(0),
Expand Down
8 changes: 2 additions & 6 deletions src/dpp/voice/enabled/enabled.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,6 @@ struct dave_state {
* @brief Our key package
*/
std::shared_ptr<::mlspp::SignaturePrivateKey> mls_key;
/**
* @brief Cached commit package for use in welcome
*/
std::vector<uint8_t> cached_commit;
/**
* @brief Current transition ID
*/
Expand Down Expand Up @@ -131,9 +127,9 @@ struct dave_state {
std::string privacy_code;

/**
* @brief Cached roster map.
* @brief Cached roster map to track rosters changes.
*/
dave::roster_map cached_roster_map;
dave::roster_map cached_roster_map;
};

/**
Expand Down
101 changes: 72 additions & 29 deletions src/dpp/voice/enabled/handle_frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,13 @@ void discord_voice_client::update_ratchets(bool force) {
}
decryptor->second->transition_to_key_ratchet(mls_state->dave_session->get_key_ratchet(user), RATCHET_EXPIRY);
}
/* No expiry on sender! It's up to the receiver to decide when to discard their old keys */
if (mls_state->encryptor // && !mls_state->encryptor->has_key_ratchet()
) {

/*
* Encryptor should always be present on execute transition.
* Should we throw error if it's missing here?
*/
if (mls_state->encryptor) {
/* Updating key rachet should always be done on execute transition. Generally after group member add/remove. */
log(ll_debug, "Setting key ratchet for sending audio...");
mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id.str()));
}
Expand Down Expand Up @@ -140,21 +144,23 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
case voice_client_dave_mls_welcome: {
this->mls_state->transition_id = dave_header.get_transition_id();
log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->mls_state->transition_id));
// dave_mls_user_list.erase(creator->me.id.str());

/* We should always recognize our own selves, but do we? */
dave_mls_user_list.insert(this->creator->me.id.str());

auto r = mls_state->dave_session->process_welcome(dave_header.get_data(), dave_mls_user_list);

/*
* We need to do recovery here when we failed processing the message
*/
if (!r) {
if (!r.has_value()) {
log(ll_debug, "Unable to process welcome in transition " + std::to_string(this->mls_state->transition_id));

this->recover_from_invalid_commit_welcome();
break;
}

this->process_mls_group_rosters(r.value());

this->ready_for_transition(data);
}
break;
Expand Down Expand Up @@ -207,9 +213,19 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
}
break;
case voice_opcode_multiple_clients_connect: {
dave_mls_new_user_list = j["d"]["user_ids"];
dave_mls_user_list.insert(dave_mls_new_user_list.begin(), dave_mls_new_user_list.end());
log(ll_debug, "New of clients in voice channel: " + std::to_string(dave_mls_new_user_list.size()) + " total is " + std::to_string(dave_mls_user_list.size()));
/**
* @brief The list of users that just joined for DAVE
*/
std::set<std::string> joining_dave_users = j["d"]["user_ids"];

dave_mls_user_list.insert(joining_dave_users.begin(), joining_dave_users.end());

/* Remove this user from pending remove list if exist */
for (const auto &user : joining_dave_users) {
dave_mls_pending_remove_list.erase(user);
}

log(ll_debug, "New of clients in voice channel: " + std::to_string(joining_dave_users.size()) + " total is " + std::to_string(dave_mls_user_list.size()));
}
break;
case voice_client_dave_mls_invalid_commit_welcome: {
Expand All @@ -221,8 +237,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
log(ll_debug, "voice_client_dave_execute_transition");
this->mls_state->transition_id = j["d"]["transition_id"];

if (this->mls_state->pending_transition.is_pending) {
this->execute_pending_upgrade_downgrade();
if (this->mls_state->pending_transition.is_pending && this->execute_pending_upgrade_downgrade()) {
break;
}

Expand All @@ -236,7 +251,9 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
case voice_client_dave_prepare_transition: {
this->mls_state->transition_id = j["d"]["transition_id"];
uint64_t protocol_version = j["d"]["protocol_version"];

this->mls_state->pending_transition = {this->mls_state->transition_id, protocol_version, true};

log(ll_debug, "voice_client_dave_prepare_transition version=" + std::to_string(protocol_version) + " for transition " + std::to_string(this->mls_state->transition_id));

if (this->mls_state->transition_id == 0) {
Expand All @@ -257,9 +274,12 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
break;
case voice_client_dave_prepare_epoch: {
uint64_t protocol_version = j["d"]["protocol_version"];
dave_current_epoch = j["d"]["epoch"];
log(ll_debug, "voice_client_dave_prepare_epoch version=" + std::to_string(protocol_version) + " for epoch " + std::to_string(dave_current_epoch));
if (dave_current_epoch == 1) {
uint32_t epoch = j["d"]["epoch"];
log(ll_debug, "voice_client_dave_prepare_epoch version=" + std::to_string(protocol_version) + " for epoch " + std::to_string(epoch));
if (epoch == 1) {
/* An epoch 1 is the start of new dave session, update dave_version */
dave_version = protocol_version == 1 ? dave_version_1 : dave_version_none;

this->reinit_dave_mls_group();
}
}
Expand All @@ -277,6 +297,9 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
ssrc_map.erase(it);
}

/* Mark this user for remove on immediate upgrade */
dave_mls_pending_remove_list.insert(u_id.str());

if (!creator->on_voice_client_disconnect.empty()) {
voice_client_disconnect_t vcd(nullptr, data);
vcd.voice_client = this;
Expand Down Expand Up @@ -315,9 +338,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
this->heartbeat_interval = j["d"]["heartbeat_interval"].get<uint32_t>();
}

/* Reset receive_sequence and dave_current_epoch on HELLO */
/* Reset receive_sequence on HELLO */
receive_sequence = -1;
dave_current_epoch = 0;

if (!modes.empty()) {
log(dpp::ll_debug, "Resuming voice session " + this->sessionid + "...");
Expand Down Expand Up @@ -473,6 +495,11 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
return true;
}


/*
* Handle DAVE frame utilities.
*/

void discord_voice_client::ready_for_transition(const std::string &data) {
log(ll_debug, "Ready to execute transition " + std::to_string(this->mls_state->transition_id));
json obj = {
Expand All @@ -485,6 +512,7 @@ void discord_voice_client::ready_for_transition(const std::string &data) {
}
};
this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT);
this->mls_state->pending_transition.id = this->mls_state->transition_id;

/* When the included transition ID is 0, the transition is for (re)initialization and it can be executed immediately. */
if (this->mls_state->transition_id == 0) {
Expand All @@ -505,8 +533,6 @@ void discord_voice_client::ready_for_transition(const std::string &data) {
}

void discord_voice_client::recover_from_invalid_commit_welcome() {
mls_state->dave_session->reset();

json obj = {
{"op", voice_client_dave_mls_invalid_commit_welcome},
{
Expand All @@ -519,21 +545,24 @@ void discord_voice_client::recover_from_invalid_commit_welcome() {
this->reinit_dave_mls_group();
}

void discord_voice_client::execute_pending_upgrade_downgrade() {
bool discord_voice_client::execute_pending_upgrade_downgrade() {
bool did_upgrade_downgrade = false;

if (this->mls_state->transition_id != this->mls_state->pending_transition.id) {
log(ll_debug, "voice_client_dave_execute_transition unexpected transition_id, we never received voice_client_dave_prepare_transition event with this id: " + std::to_string(this->mls_state->transition_id));
log(ll_debug, "execute_pending_upgrade_downgrade unexpected transition_id, we never received voice_client_dave_prepare_transition event with this id: " + std::to_string(this->mls_state->transition_id));
} else if (dave_version != this->mls_state->pending_transition.protocol_version) {
dave_version = this->mls_state->pending_transition.protocol_version == 1 ? dave_version_1 : dave_version_none;

if (this->mls_state->pending_transition.protocol_version != 0 && dave_version == dave_version_none) {
log(ll_debug, "voice_client_dave_execute_transition unexpected protocol version: " + std::to_string(this->mls_state->pending_transition.protocol_version)+ " in transition " + std::to_string(this->mls_state->transition_id));
log(ll_debug, "execute_pending_upgrade_downgrade unexpected protocol version: " + std::to_string(this->mls_state->pending_transition.protocol_version)+ " in transition " + std::to_string(this->mls_state->transition_id));
} else {
log(ll_debug, "execute_pending_upgrade_downgrade upgrade/downgrade successful");
did_upgrade_downgrade = true;
}

this->mls_state->privacy_code.clear();
this->dave_mls_user_list.clear();
}

this->mls_state->pending_transition.is_pending = false;
return did_upgrade_downgrade;
}

void discord_voice_client::reinit_dave_mls_group() {
Expand All @@ -542,11 +571,24 @@ void discord_voice_client::reinit_dave_mls_group() {
auto key_response = mls_state->dave_session->get_marshalled_key_package();
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);

mls_state->encryptor = std::make_unique<dave::encryptor>(*creator);
mls_state->decryptors.clear();

mls_state->cached_roster_map.clear();

mls_state->privacy_code.clear();

/* Remove any user in pending remove from MLS member list */
for (const auto &user : dave_mls_pending_remove_list) {
dave_mls_user_list.erase(user);
}
dave_mls_pending_remove_list.clear();
}

void discord_voice_client::process_mls_group_rosters(const dave::roster_map &rmap) {
log(ll_debug, "process_mls_group_rosters of size: " + std::to_string(rmap.size()));

for (const auto &[k, v] : rmap) {
bool user_has_key = !v.empty();

Expand All @@ -567,12 +609,13 @@ void discord_voice_client::process_mls_group_rosters(const dave::roster_map &rma
}

dpp::snowflake u_id(k);
log(ll_debug, "Removed user from MLS Group: " + u_id.str());
auto u_id_str = u_id.str();

log(ll_debug, "Removed user from MLS Group: " + u_id_str);

dave_mls_user_list.erase(u_id_str);
dave_mls_pending_remove_list.erase(u_id_str);

auto it_dave = dave_mls_user_list.find(u_id.str());
if (it_dave != dave_mls_user_list.end()) {
dave_mls_user_list.erase(it_dave);
}
/* Remove this user's key ratchet */
mls_state->decryptors.erase(u_id);
}
Expand Down

0 comments on commit bc7dc97

Please sign in to comment.