diff --git a/include/dpp/discordvoiceclient.h b/include/dpp/discordvoiceclient.h index 88bd5f7241..00b75d54cf 100644 --- a/include/dpp/discordvoiceclient.h +++ b/include/dpp/discordvoiceclient.h @@ -166,7 +166,7 @@ enum voice_websocket_opcode_t : uint8_t { voice_client_dave_mls_key_package = 26, voice_client_dave_mls_proposals = 27, voice_client_dave_mls_commit_message = 28, - voice_client_dave_announce_commit_transaction = 29, + voice_client_dave_announce_commit_transition = 29, voice_client_dave_mls_welcome = 30, voice_client_dave_mls_invalid_commit_welcome = 31, }; @@ -206,7 +206,7 @@ struct dave_binary_header_t { [[nodiscard]] std::vector get_data() const; /** - * Get transition ID for process_welcome + * Get transition ID for process_commit and process_welcome * * @return Transition ID */ @@ -215,7 +215,7 @@ struct dave_binary_header_t { private: /** * @brief Transition id, only valid when the opcode is - * welcome state. Use get_transition_id() to obtain value. + * commit and welcome state. Use get_transition_id() to obtain value. */ uint16_t transition_id; }; @@ -482,7 +482,7 @@ class DPP_EXPORT discord_voice_client : public websocket_client */ std::set dave_mls_new_user_list; - /** + /** * @brief File descriptor for UDP connection */ dpp::socket fd; @@ -583,6 +583,11 @@ class DPP_EXPORT discord_voice_client : public websocket_client */ dave_version_t dave_version; + /** + * @brief Current dave session key generation. + */ + uint32_t dave_current_epoch; + /** * @brief Send data to UDP socket immediately. * @@ -1200,6 +1205,38 @@ class DPP_EXPORT discord_voice_client : public websocket_client * which internally uses scrypt. */ void get_user_privacy_code(const dpp::snowflake user, privacy_code_callback_t callback) const; + + /* + * @brief Notify gateway ready for a DAVE transition. + * + * Fires Voice Ready event when appropriate. + * + * https://daveprotocol.com/#commit-handling + */ + void ready_for_transition(const std::string &data); + + /** + * @brief Reset dave session, send voice_client_dave_mls_invalid_commit_welcome + * payload with current transition Id and our new key package to gateway. + * + * https://daveprotocol.com/#recovery-from-invalid-commit-or-welcome + */ + void recover_from_invalid_commit_welcome(); + + /** + * @brief Execute pending protocol upgrade/downgrade to/from dave. + */ + void execute_pending_upgrade_downgrade(); + + /** + * @brief Reset dave session and prepare initial session group. + */ + void reinit_dave_mls_group(); + + /** + * @brief Process roster map from commit/welcome. + */ + void process_mls_group_rosters(const std::map> &rmap); }; } // namespace dpp diff --git a/src/dpp/discordvoiceclient.cpp b/src/dpp/discordvoiceclient.cpp index ca44b19c06..dc548984e3 100644 --- a/src/dpp/discordvoiceclient.cpp +++ b/src/dpp/discordvoiceclient.cpp @@ -85,8 +85,9 @@ bool discord_voice_client::is_playing() { } uint16_t dave_binary_header_t::get_transition_id() const { - if (opcode != voice_client_dave_mls_welcome) { - throw dpp::logic_exception("Can't get transition ID from buffer that is not of type voice_client_dave_mls_welcome(30)"); + bool has_transition_id = opcode == voice_client_dave_mls_welcome || opcode == voice_client_dave_announce_commit_transition; + if (!has_transition_id) { + throw dpp::logic_exception("Can't get transition ID from buffer that is not of type voice_client_dave_announce_commit_transition(29) or voice_client_dave_mls_welcome(30)"); } return transition_id; } @@ -98,7 +99,9 @@ dave_binary_header_t::dave_binary_header_t(const std::string& buffer) { seq = (buffer[0] << 8) | buffer[1]; opcode = buffer[2]; transition_id = (buffer[3] << 8) | buffer[4]; - package.assign(buffer.begin() + (opcode == voice_client_dave_mls_welcome ? 5 : 3), buffer.end()); + + bool has_transition_id = opcode == voice_client_dave_mls_welcome || opcode == voice_client_dave_announce_commit_transition; + package.assign(buffer.begin() + (has_transition_id ? 5 : 3), buffer.end()); } std::vector dave_binary_header_t::get_data() const { diff --git a/src/dpp/voice/enabled/constructor.cpp b/src/dpp/voice/enabled/constructor.cpp index 1e0a20b8f3..45ab5eae89 100644 --- a/src/dpp/voice/enabled/constructor.cpp +++ b/src/dpp/voice/enabled/constructor.cpp @@ -55,6 +55,7 @@ discord_voice_client::discord_voice_client(dpp::cluster* _cluster, snowflake _ch sending(false), tracks(0), dave_version(enable_dave ? dave_version_1 : dave_version_none), + dave_current_epoch(0), creator(_cluster), terminating(false), heartbeat_interval(0), diff --git a/src/dpp/voice/enabled/enabled.h b/src/dpp/voice/enabled/enabled.h index 996b249492..65025e2e16 100644 --- a/src/dpp/voice/enabled/enabled.h +++ b/src/dpp/voice/enabled/enabled.h @@ -95,10 +95,6 @@ struct dave_state { * @brief Current transition ID */ uint64_t transition_id{0}; - /** - * @brief True when we have a sending ratchet - */ - bool have_sending_ratchet{false}; /** * @brief Have sent ready event to listeners */ @@ -133,6 +129,11 @@ struct dave_state { * MLS group is not established. */ std::string privacy_code; + + /** + * @brief Cached roster map. + */ + dave::roster_map cached_roster_map; }; /** diff --git a/src/dpp/voice/enabled/handle_frame.cpp b/src/dpp/voice/enabled/handle_frame.cpp index 0bcc049862..10d21eac5f 100644 --- a/src/dpp/voice/enabled/handle_frame.cpp +++ b/src/dpp/voice/enabled/handle_frame.cpp @@ -25,7 +25,6 @@ #include #include #include -#include #include "../../dave/encryptor.h" #include "enabled.h" @@ -66,10 +65,10 @@ void discord_voice_client::update_ratchets(bool force) { decryptor->second->transition_to_key_ratchet(mls_state->dave_session->get_key_ratchet(user), RATCHET_EXPIRY); } /* No expiry on sender! It's up to the receiver to decide when to discard their old keys */ - if (mls_state->encryptor && !mls_state->have_sending_ratchet) { + if (mls_state->encryptor // && !mls_state->encryptor->has_key_ratchet() + ) { log(ll_debug, "Setting key ratchet for sending audio..."); mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id.str())); - mls_state->have_sending_ratchet = true; } /** @@ -96,66 +95,67 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod dave_binary_header_t dave_header(data); + /* These binaries also contains sequence number we need to save */ + receive_sequence = dave_header.seq; + switch (dave_header.opcode) { case voice_client_dave_mls_external_sender: { log(ll_debug, "voice_client_dave_mls_external_sender"); mls_state->dave_session->set_external_sender(dave_header.get_data()); - mls_state->encryptor = std::make_unique(*creator); - mls_state->decryptors.clear(); } break; case voice_client_dave_mls_proposals: { log(ll_debug, "voice_client_dave_mls_proposals"); - std::optional> response = mls_state->dave_session->process_proposals(dave_header.get_data(), dave_mls_new_user_list); + std::optional> response = mls_state->dave_session->process_proposals(dave_header.get_data(), dave_mls_user_list); if (response.has_value()) { auto r = response.value(); - mls_state->cached_commit = r; r.insert(r.begin(), voice_client_dave_mls_commit_message); this->write(std::string_view(reinterpret_cast(r.data()), r.size()), OP_BINARY); } } break; - case voice_client_dave_announce_commit_transaction: { - log(ll_debug, "voice_client_dave_announce_commit_transaction"); - auto r = mls_state->dave_session->process_commit(mls_state->cached_commit); - update_ratchets(!mls_state->have_sending_ratchet); - if (!creator->on_voice_ready.empty() && !mls_state->done_ready) { - voice_ready_t rdy(nullptr, data); - rdy.voice_client = this; - rdy.voice_channel_id = this->channel_id; - creator->on_voice_ready.call(rdy); - mls_state->done_ready = true; + case voice_client_dave_announce_commit_transition: { + this->mls_state->transition_id = dave_header.get_transition_id(); + log(ll_debug, "voice_client_dave_announce_commit_transition"); + auto r = mls_state->dave_session->process_commit(dave_header.get_data()); + + /* + * We need to do recovery here when we failed processing the message + */ + if (!std::holds_alternative(r)) { + log(ll_debug, "Unable to process commit in transition " + std::to_string(this->mls_state->transition_id)); + + this->recover_from_invalid_commit_welcome(); + break; } + + auto rmap = std::get(r); + this->process_mls_group_rosters(rmap); + + this->ready_for_transition(data); } break; case voice_client_dave_mls_welcome: { this->mls_state->transition_id = dave_header.get_transition_id(); log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->mls_state->transition_id)); - dave_mls_user_list.erase(creator->me.id.str()); + // dave_mls_user_list.erase(creator->me.id.str()); auto r = mls_state->dave_session->process_welcome(dave_header.get_data(), dave_mls_user_list); - update_ratchets(); - log(ll_debug, "Ready to execute transition " + std::to_string(this->mls_state->transition_id)); - json obj = { - { "op", voice_client_dave_transition_ready }, - { - "d", - { - { "transition_id", this->mls_state->transition_id }, - } - } - }; - this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT); - if (!mls_state->done_ready) { - if (!creator->on_voice_ready.empty()) { - voice_ready_t rdy(nullptr, data); - rdy.voice_client = this; - rdy.voice_channel_id = this->channel_id; - creator->on_voice_ready.call(rdy); - mls_state->done_ready = true; - } + + /* + * We need to do recovery here when we failed processing the message + */ + if (!r) { + log(ll_debug, "Unable to process welcome in transition " + std::to_string(this->mls_state->transition_id)); + + this->recover_from_invalid_commit_welcome(); + break; } + + this->process_mls_group_rosters(r.value()); + + this->ready_for_transition(data); } break; default: @@ -222,53 +222,45 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod this->mls_state->transition_id = j["d"]["transition_id"]; if (this->mls_state->pending_transition.is_pending) { - if (this->mls_state->transition_id != this->mls_state->pending_transition.id) { - log(ll_debug, "voice_client_dave_execute_transition unexpected transition_id, we never received voice_client_dave_prepare_transition event with this id: " + std::to_string(this->mls_state->pending_transition.id)); - } else { - dave_version = this->mls_state->pending_transition.protocol_version == 1 ? dave_version_1 : dave_version_none; - - if (this->mls_state->pending_transition.protocol_version != 0 && dave_version == dave_version_none) { - log(ll_debug, "voice_client_dave_execute_transition unexpected protocol version: " + std::to_string(this->mls_state->pending_transition.protocol_version)+ " in transition " + std::to_string(this->mls_state->pending_transition.id)); - } - - this->mls_state->privacy_code.clear(); - this->dave_mls_user_list.clear(); - - this->mls_state->pending_transition.is_pending = false; - } + this->execute_pending_upgrade_downgrade(); + break; } + + /* + * Execute transition from a commit/welcome message. + */ + update_ratchets(); } break; /* "The protocol only uses this opcode to indicate when a downgrade to protocol version 0 is upcoming." */ case voice_client_dave_prepare_transition: { - uint64_t transition_id = j["d"]["transition_id"]; + this->mls_state->transition_id = j["d"]["transition_id"]; uint64_t protocol_version = j["d"]["protocol_version"]; - this->mls_state->pending_transition = {transition_id, protocol_version, true}; - log(ll_debug, "voice_client_dave_prepare_transition version=" + std::to_string(protocol_version) + " for transition " + std::to_string(transition_id)); + this->mls_state->pending_transition = {this->mls_state->transition_id, protocol_version, true}; + log(ll_debug, "voice_client_dave_prepare_transition version=" + std::to_string(protocol_version) + " for transition " + std::to_string(this->mls_state->transition_id)); - json obj = { - { "op", voice_client_dave_transition_ready }, - { - "d", + if (this->mls_state->transition_id == 0) { + this->execute_pending_upgrade_downgrade(); + } else { + json obj = { + { "op", voice_client_dave_transition_ready }, { - { "transition_id", this->mls_state->transition_id }, + "d", + { + { "transition_id", this->mls_state->transition_id }, + } } - } - }; - this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT); + }; + this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT); + } } break; case voice_client_dave_prepare_epoch: { uint64_t protocol_version = j["d"]["protocol_version"]; - uint64_t epoch = j["d"]["epoch"]; - log(ll_debug, "voice_client_dave_prepare_epoch version=" + std::to_string(protocol_version) + " for epoch " + std::to_string(epoch)); - if (epoch == 1) { - mls_state->dave_session->reset(); - mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key); - auto key_response = mls_state->dave_session->get_marshalled_key_package(); - key_response.insert(key_response.begin(), voice_client_dave_mls_key_package); - this->write(std::string_view(reinterpret_cast(key_response.data()), key_response.size()), OP_BINARY); - mls_state->decryptors.clear(); + dave_current_epoch = j["d"]["epoch"]; + log(ll_debug, "voice_client_dave_prepare_epoch version=" + std::to_string(protocol_version) + " for epoch " + std::to_string(dave_current_epoch)); + if (dave_current_epoch == 1) { + this->reinit_dave_mls_group(); } } break; @@ -285,14 +277,6 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod ssrc_map.erase(it); } - auto it_dave = dave_mls_user_list.find(j["d"]["user_id"]); - if (it_dave != dave_mls_user_list.end()) { - /* Ongoing MLS session, update everyone's ratchets and VC privacy code */ - dave_mls_user_list.erase(it_dave); - } - /* Remove this user's key ratchet */ - mls_state->decryptors.erase(u_id); - if (!creator->on_voice_client_disconnect.empty()) { voice_client_disconnect_t vcd(nullptr, data); vcd.voice_client = this; @@ -331,8 +315,9 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod this->heartbeat_interval = j["d"]["heartbeat_interval"].get(); } - /* Reset receive_sequence on HELLO */ + /* Reset receive_sequence and dave_current_epoch on HELLO */ receive_sequence = -1; + dave_current_epoch = 0; if (!modes.empty()) { log(dpp::ll_debug, "Resuming voice session " + this->sessionid + "..."); @@ -402,11 +387,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod nullptr, "", [this](std::string const& s1, std::string const& s2) { log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2); }); - mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key); } - auto key_response = mls_state->dave_session->get_marshalled_key_package(); - key_response.insert(key_response.begin(), voice_client_dave_mls_key_package); - this->write(std::string_view(reinterpret_cast(key_response.data()), key_response.size()), OP_BINARY); + this->reinit_dave_mls_group(); } if (ready_now) { @@ -491,5 +473,113 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod return true; } +void discord_voice_client::ready_for_transition(const std::string &data) { + log(ll_debug, "Ready to execute transition " + std::to_string(this->mls_state->transition_id)); + json obj = { + { "op", voice_client_dave_transition_ready }, + { + "d", + { + { "transition_id", this->mls_state->transition_id }, + } + } + }; + this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT); + + /* When the included transition ID is 0, the transition is for (re)initialization and it can be executed immediately. */ + if (this->mls_state->transition_id == 0) { + /* Mark state ready and update rachets the first time */ + update_ratchets(); + } + + if (!this->mls_state->done_ready) { + this->mls_state->done_ready = true; + + if (!creator->on_voice_ready.empty()) { + voice_ready_t rdy(nullptr, data); + rdy.voice_client = this; + rdy.voice_channel_id = this->channel_id; + creator->on_voice_ready.call(rdy); + } + } +} + +void discord_voice_client::recover_from_invalid_commit_welcome() { + mls_state->dave_session->reset(); + + json obj = { + {"op", voice_client_dave_mls_invalid_commit_welcome}, + { + "d", { + "transition_id", this->mls_state->transition_id + } + } + }; + this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT); + this->reinit_dave_mls_group(); +} + +void discord_voice_client::execute_pending_upgrade_downgrade() { + if (this->mls_state->transition_id != this->mls_state->pending_transition.id) { + log(ll_debug, "voice_client_dave_execute_transition unexpected transition_id, we never received voice_client_dave_prepare_transition event with this id: " + std::to_string(this->mls_state->transition_id)); + } else if (dave_version != this->mls_state->pending_transition.protocol_version) { + dave_version = this->mls_state->pending_transition.protocol_version == 1 ? dave_version_1 : dave_version_none; + + if (this->mls_state->pending_transition.protocol_version != 0 && dave_version == dave_version_none) { + log(ll_debug, "voice_client_dave_execute_transition unexpected protocol version: " + std::to_string(this->mls_state->pending_transition.protocol_version)+ " in transition " + std::to_string(this->mls_state->transition_id)); + } + + this->mls_state->privacy_code.clear(); + this->dave_mls_user_list.clear(); + } + this->mls_state->pending_transition.is_pending = false; } + +void discord_voice_client::reinit_dave_mls_group() { + mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key); + + auto key_response = mls_state->dave_session->get_marshalled_key_package(); + key_response.insert(key_response.begin(), voice_client_dave_mls_key_package); + this->write(std::string_view(reinterpret_cast(key_response.data()), key_response.size()), OP_BINARY); + mls_state->encryptor = std::make_unique(*creator); + mls_state->decryptors.clear(); +} + +void discord_voice_client::process_mls_group_rosters(const dave::roster_map &rmap) { + for (const auto &[k, v] : rmap) { + bool user_has_key = !v.empty(); + + /* Debug log for changed and added keys */ + auto cached_user = mls_state->cached_roster_map.find(k); + if (cached_user == mls_state->cached_roster_map.end()) { + log(ll_debug, "Added user to MLS Group: " + std::to_string(k)); + } else if (user_has_key && cached_user->second != v) { + log(ll_debug, "Changed user key in MLS Group: " + std::to_string(k)); + } + + /* + * Remove user from recognized list. + * Do not remove user with non-empty key. + */ + if (user_has_key) { + continue; + } + + dpp::snowflake u_id(k); + log(ll_debug, "Removed user from MLS Group: " + u_id.str()); + + auto it_dave = dave_mls_user_list.find(u_id.str()); + if (it_dave != dave_mls_user_list.end()) { + dave_mls_user_list.erase(it_dave); + } + /* Remove this user's key ratchet */ + mls_state->decryptors.erase(u_id); + } + + mls_state->cached_roster_map = rmap; +} + +} + +// vim: sw=4 noet diff --git a/src/dpp/voice/enabled/opus.cpp b/src/dpp/voice/enabled/opus.cpp index e57fcb69f4..2834624b77 100644 --- a/src/dpp/voice/enabled/opus.cpp +++ b/src/dpp/voice/enabled/opus.cpp @@ -106,7 +106,7 @@ discord_voice_client& discord_voice_client::send_audio_opus(uint8_t* opus_packet log(ll_warning, "DAVE Encryption failure: " + std::to_string(result)); } else { encoded_audio = encrypted_buffer; - encoded_audio_length = encoded_audio.size(); + encoded_audio_length = encrypted_buffer.size(); } }