Skip to content

Commit

Permalink
Feat/dave downgrade-upgrade (#1260)
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis authored Oct 6, 2024
2 parents e7cb773 + 7aeda79 commit 856ee3b
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 23 deletions.
8 changes: 7 additions & 1 deletion src/dpp/discordvoiceclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,13 @@ void discord_voice_client::get_user_privacy_code(const dpp::snowflake user, priv

bool discord_voice_client::is_end_to_end_encrypted() const {
#ifdef HAVE_VOICE
return mls_state && !mls_state->privacy_code.empty();
if (mls_state == nullptr) {
return false;
}

bool has_pending_downgrade = mls_state->pending_transition.is_pending && mls_state->pending_transition.protocol_version != dave_version_1;

return !has_pending_downgrade && !mls_state->privacy_code.empty();
#else
return false;
#endif
Expand Down
5 changes: 5 additions & 0 deletions src/dpp/voice/enabled/enabled.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ struct dave_state {
std::shared_ptr<::mlspp::SignaturePrivateKey> mls_key;
std::vector<uint8_t> cached_commit;
uint64_t transition_id{0};
struct {
uint64_t id{0};
uint64_t protocol_version{0};
bool is_pending{false};
} pending_transition;
std::map<dpp::snowflake, std::unique_ptr<dave::decryptor>> decryptors;
std::unique_ptr<dave::encryptor> encryptor;
std::string privacy_code;
Expand Down
82 changes: 60 additions & 22 deletions src/dpp/voice/enabled/handle_frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,16 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
mls_state->privacy_code = generate_displayable_code(mls_state->dave_session->get_last_epoch_authenticator());
log(ll_debug, "E2EE Privacy Code: " + mls_state->privacy_code);

if (!creator->on_voice_ready.empty()) {
voice_ready_t rdy(nullptr, data);
rdy.voice_client = this;
rdy.voice_channel_id = this->channel_id;
creator->on_voice_ready.call(rdy);
}
json obj = {
{ "op", voice_client_dave_transition_ready },
{
"d",
{
{ "transition_id", this->mls_state->transition_id },
}
}
};
this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT);
}
break;
default:
Expand Down Expand Up @@ -176,6 +180,32 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
case voice_client_dave_execute_transition: {
log(ll_debug, "voice_client_dave_execute_transition");
this->mls_state->transition_id = j["d"]["transition_id"];

if (this->mls_state->pending_transition.is_pending) {
if (this->mls_state->transition_id != this->mls_state->pending_transition.id) {
log(ll_debug, "voice_client_dave_execute_transition unexpected transition_id, we never received voice_client_dave_prepare_transition event with this id: " + std::to_string(this->mls_state->pending_transition.id));
} else {
dave_version = this->mls_state->pending_transition.protocol_version == 1 ? dave_version_1 : dave_version_none;

if (this->mls_state->pending_transition.protocol_version != 0 && dave_version == dave_version_none) {
log(ll_debug, "voice_client_dave_execute_transition unexpected protocol version: " + std::to_string(this->mls_state->pending_transition.protocol_version)+ " in transition " + std::to_string(this->mls_state->pending_transition.id));
}

this->mls_state->privacy_code.clear();
this->dave_mls_user_list.clear();

this->mls_state->pending_transition.is_pending = false;
}
}
}
break;
/* "The protocol only uses this opcode to indicate when a downgrade to protocol version 0 is upcoming." */
case voice_client_dave_prepare_transition: {
uint64_t transition_id = j["d"]["transition_id"];
uint64_t protocol_version = j["d"]["protocol_version"];
this->mls_state->pending_transition = {transition_id, protocol_version, true};
log(ll_debug, "voice_client_dave_prepare_transition version=" + std::to_string(protocol_version) + " for transition " + std::to_string(transition_id));

json obj = {
{ "op", voice_client_dave_transition_ready },
{
Expand All @@ -188,20 +218,16 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT);
}
break;
/* "The protocol only uses this opcode to indicate when a downgrade to protocol version 0 is upcoming." */
case voice_client_dave_prepare_transition: {
uint64_t transition_id = j["d"]["transition_id"];
uint64_t protocol_version = j["d"]["protocol_version"];
log(ll_debug, "voice_client_dave_prepare_transition version=" + std::to_string(protocol_version) + " for transition " + std::to_string(transition_id));
}
break;
case voice_client_dave_prepare_epoch: {
uint64_t protocol_version = j["d"]["protocol_version"];
uint64_t epoch = j["d"]["epoch"];
log(ll_debug, "voice_client_dave_prepare_epoch version=" + std::to_string(protocol_version) + " for epoch " + std::to_string(epoch));
if (epoch == 1) {
mls_state->dave_session->reset();
mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key);
auto key_response = mls_state->dave_session->get_marshalled_key_package();
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);
}
}
break;
Expand All @@ -216,6 +242,11 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
ssrc_map.erase(it);
}

auto it_dave = dave_mls_user_list.find(j["d"]["user_id"]);
if (it_dave != dave_mls_user_list.end()) {
dave_mls_user_list.erase(it_dave);
}

if (!creator->on_voice_client_disconnect.empty()) {
voice_client_disconnect_t vcd(nullptr, data);
vcd.voice_client = this;
Expand Down Expand Up @@ -309,24 +340,31 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
/* Reset packet_nonce */
packet_nonce = 1;

bool ready_now = false;

if (dave_version != dave_version_none) {
if (j["d"]["dave_protocol_version"] != static_cast<uint32_t>(dave_version)) {
log(ll_error, "We requested DAVE E2EE but didn't receive it from the server, downgrading...");
dave_version = dave_version_none;
send_silence(20);
ready_now = true;
}

mls_state = std::make_unique<dave_state>();
mls_state->dave_session = std::make_unique<dave::mls::session>(
nullptr, "" /* sessionid */, [this](std::string const& s1, std::string const& s2) {
log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2);
});
mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key);
if (mls_state == nullptr) {
mls_state = std::make_unique<dave_state>();
}
if (mls_state->dave_session == nullptr) {
mls_state->dave_session = std::make_unique<dave::mls::session>(
nullptr, "" /* sessionid */, [this](std::string const& s1, std::string const& s2) {
log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2);
});
mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key);
}
auto key_response = mls_state->dave_session->get_marshalled_key_package();
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);
}

} else {
if (ready_now) {
/* This is needed to start voice receiving and make sure that the start of sending isn't cut off */
send_silence(20);
/* Fire on_voice_ready */
Expand Down Expand Up @@ -409,4 +447,4 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
}


}
}

0 comments on commit 856ee3b

Please sign in to comment.