diff --git a/src/dpp/discordvoiceclient.cpp b/src/dpp/discordvoiceclient.cpp index 046c8c784c..16f7e05f3f 100644 --- a/src/dpp/discordvoiceclient.cpp +++ b/src/dpp/discordvoiceclient.cpp @@ -125,7 +125,13 @@ void discord_voice_client::get_user_privacy_code(const dpp::snowflake user, priv bool discord_voice_client::is_end_to_end_encrypted() const { #ifdef HAVE_VOICE - return mls_state && !mls_state->privacy_code.empty(); + if (mls_state == nullptr) { + return false; + } + + bool has_pending_downgrade = mls_state->pending_transition.is_pending && mls_state->pending_transition.protocol_version != dave_version_1; + + return !has_pending_downgrade && !mls_state->privacy_code.empty(); #else return false; #endif diff --git a/src/dpp/voice/enabled/enabled.h b/src/dpp/voice/enabled/enabled.h index 9daf86dc55..78b55958d5 100644 --- a/src/dpp/voice/enabled/enabled.h +++ b/src/dpp/voice/enabled/enabled.h @@ -75,6 +75,11 @@ struct dave_state { std::shared_ptr<::mlspp::SignaturePrivateKey> mls_key; std::vector cached_commit; uint64_t transition_id{0}; + struct { + uint64_t id{0}; + uint64_t protocol_version{0}; + bool is_pending{false}; + } pending_transition; std::map> decryptors; std::unique_ptr encryptor; std::string privacy_code; diff --git a/src/dpp/voice/enabled/handle_frame.cpp b/src/dpp/voice/enabled/handle_frame.cpp index a0683ece2e..86ae682cc5 100644 --- a/src/dpp/voice/enabled/handle_frame.cpp +++ b/src/dpp/voice/enabled/handle_frame.cpp @@ -107,12 +107,16 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod mls_state->privacy_code = generate_displayable_code(mls_state->dave_session->get_last_epoch_authenticator()); log(ll_debug, "E2EE Privacy Code: " + mls_state->privacy_code); - if (!creator->on_voice_ready.empty()) { - voice_ready_t rdy(nullptr, data); - rdy.voice_client = this; - rdy.voice_channel_id = this->channel_id; - creator->on_voice_ready.call(rdy); - } + json obj = { + { "op", voice_client_dave_transition_ready }, + { + "d", + { + { "transition_id", this->mls_state->transition_id }, + } + } + }; + this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT); } break; default: @@ -176,6 +180,32 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod case voice_client_dave_execute_transition: { log(ll_debug, "voice_client_dave_execute_transition"); this->mls_state->transition_id = j["d"]["transition_id"]; + + if (this->mls_state->pending_transition.is_pending) { + if (this->mls_state->transition_id != this->mls_state->pending_transition.id) { + log(ll_debug, "voice_client_dave_execute_transition unexpected transition_id, we never received voice_client_dave_prepare_transition event with this id: " + std::to_string(this->mls_state->pending_transition.id)); + } else { + dave_version = this->mls_state->pending_transition.protocol_version == 1 ? dave_version_1 : dave_version_none; + + if (this->mls_state->pending_transition.protocol_version != 0 && dave_version == dave_version_none) { + log(ll_debug, "voice_client_dave_execute_transition unexpected protocol version: " + std::to_string(this->mls_state->pending_transition.protocol_version)+ " in transition " + std::to_string(this->mls_state->pending_transition.id)); + } + + this->mls_state->privacy_code.clear(); + this->dave_mls_user_list.clear(); + + this->mls_state->pending_transition.is_pending = false; + } + } + } + break; + /* "The protocol only uses this opcode to indicate when a downgrade to protocol version 0 is upcoming." */ + case voice_client_dave_prepare_transition: { + uint64_t transition_id = j["d"]["transition_id"]; + uint64_t protocol_version = j["d"]["protocol_version"]; + this->mls_state->pending_transition = {transition_id, protocol_version, true}; + log(ll_debug, "voice_client_dave_prepare_transition version=" + std::to_string(protocol_version) + " for transition " + std::to_string(transition_id)); + json obj = { { "op", voice_client_dave_transition_ready }, { @@ -188,13 +218,6 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT); } break; - /* "The protocol only uses this opcode to indicate when a downgrade to protocol version 0 is upcoming." */ - case voice_client_dave_prepare_transition: { - uint64_t transition_id = j["d"]["transition_id"]; - uint64_t protocol_version = j["d"]["protocol_version"]; - log(ll_debug, "voice_client_dave_prepare_transition version=" + std::to_string(protocol_version) + " for transition " + std::to_string(transition_id)); - } - break; case voice_client_dave_prepare_epoch: { uint64_t protocol_version = j["d"]["protocol_version"]; uint64_t epoch = j["d"]["epoch"]; @@ -202,6 +225,9 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod if (epoch == 1) { mls_state->dave_session->reset(); mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key); + auto key_response = mls_state->dave_session->get_marshalled_key_package(); + key_response.insert(key_response.begin(), voice_client_dave_mls_key_package); + this->write(std::string_view(reinterpret_cast(key_response.data()), key_response.size()), OP_BINARY); } } break; @@ -216,6 +242,11 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod ssrc_map.erase(it); } + auto it_dave = dave_mls_user_list.find(j["d"]["user_id"]); + if (it_dave != dave_mls_user_list.end()) { + dave_mls_user_list.erase(it_dave); + } + if (!creator->on_voice_client_disconnect.empty()) { voice_client_disconnect_t vcd(nullptr, data); vcd.voice_client = this; @@ -309,24 +340,31 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod /* Reset packet_nonce */ packet_nonce = 1; + bool ready_now = false; + if (dave_version != dave_version_none) { if (j["d"]["dave_protocol_version"] != static_cast(dave_version)) { log(ll_error, "We requested DAVE E2EE but didn't receive it from the server, downgrading..."); dave_version = dave_version_none; - send_silence(20); + ready_now = true; } - mls_state = std::make_unique(); - mls_state->dave_session = std::make_unique( - nullptr, "" /* sessionid */, [this](std::string const& s1, std::string const& s2) { - log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2); - }); - mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key); + if (mls_state == nullptr) { + mls_state = std::make_unique(); + } + if (mls_state->dave_session == nullptr) { + mls_state->dave_session = std::make_unique( + nullptr, "" /* sessionid */, [this](std::string const& s1, std::string const& s2) { + log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2); + }); + mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key); + } auto key_response = mls_state->dave_session->get_marshalled_key_package(); key_response.insert(key_response.begin(), voice_client_dave_mls_key_package); this->write(std::string_view(reinterpret_cast(key_response.data()), key_response.size()), OP_BINARY); + } - } else { + if (ready_now) { /* This is needed to start voice receiving and make sure that the start of sending isn't cut off */ send_silence(20); /* Fire on_voice_ready */ @@ -409,4 +447,4 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod } -} \ No newline at end of file +}