From 83a86f645a5eb74564b6f026eee8a5de960bffba Mon Sep 17 00:00:00 2001 From: "Craig Edwards (Brain)" Date: Wed, 23 Oct 2024 07:23:38 +0100 Subject: [PATCH] refactor: dont use strings and uint64_t in dave, use snowflake type (#1307) --- include/dpp/discordvoiceclient.h | 4 +- src/dpp/dave/session.cpp | 51 +++++++++++--------------- src/dpp/dave/session.h | 26 +++++++------ src/dpp/voice/enabled/handle_frame.cpp | 28 +++++++------- 4 files changed, 50 insertions(+), 59 deletions(-) diff --git a/include/dpp/discordvoiceclient.h b/include/dpp/discordvoiceclient.h index ceb9d92829..ad0c5e93fa 100644 --- a/include/dpp/discordvoiceclient.h +++ b/include/dpp/discordvoiceclient.h @@ -487,13 +487,13 @@ class DPP_EXPORT discord_voice_client : public websocket_client * @brief The list of users that have E2EE potentially enabled for * DAVE protocol. */ - std::set dave_mls_user_list; + std::set dave_mls_user_list; /** * @brief The list of users that have left the voice channel but * not yet removed from MLS group. */ - std::set dave_mls_pending_remove_list; + std::set dave_mls_pending_remove_list; /** * @brief File descriptor for UDP connection diff --git a/src/dpp/dave/session.cpp b/src/dpp/dave/session.cpp index 0bb7f10568..8fa34409c8 100755 --- a/src/dpp/dave/session.cpp +++ b/src/dpp/dave/session.cpp @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include #include "mls_key_ratchet.h" @@ -50,7 +52,7 @@ struct queued_proposal { ::mlspp::bytes_ns::bytes ref; }; -session::session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept +session::session(dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept : signing_key_id(auth_session_id), key_pair_context(context), failure_callback(std::move(callback)), creator(cluster) { creator.log(dpp::ll_debug, "Creating a new MLS session"); @@ -58,12 +60,12 @@ session::session(dpp::cluster& cluster, key_pair_context_type context, const std session::~session() noexcept = default; -void session::init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept { +void session::init(protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept { reset(); bot_user_id = self_user_id; - creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + std::to_string(group_id)); + creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + group_id.str()); session_protocol_version = version; session_group_id = std::move(big_endian_bytes_from(group_id).as_vec()); @@ -123,7 +125,7 @@ catch (const std::exception& e) { return; } -std::optional> session::process_proposals(std::vector proposals, std::set const& recognised_user_ids) noexcept +std::optional> session::process_proposals(std::vector proposals, std::set const& recognised_user_ids) noexcept try { if (!pending_group_state && !current_state) { creator.log(dpp::ll_debug, "Cannot process proposals without any pending or established MLS group state"); @@ -183,9 +185,7 @@ try { for (const auto& proposal_message : messages) { auto validated_content = state_with_proposals->unwrap(proposal_message); - if (!validate_proposal_message(validated_content.authenticated_content(), - *state_with_proposals, - recognised_user_ids)) { + if (!validate_proposal_message(validated_content.authenticated_content(), *state_with_proposals, recognised_user_ids)) { return std::nullopt; } @@ -238,9 +238,9 @@ catch (const std::exception& e) { return std::nullopt; } -bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognised_user_ids) const +bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognised_user_ids) const { - std::string uid = user_credential_to_string(cred, session_protocol_version); + dpp::snowflake uid(user_credential_to_string(cred, session_protocol_version)); if (uid.empty()) { creator.log(dpp::ll_warning, "Attempted to verify credential of unexpected type"); return false; @@ -254,7 +254,7 @@ bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognised_user_ids) const { +bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set const& recognised_user_ids) const { if (message.wire_format != ::mlspp::WireFormat::mls_public_message) { creator.log(dpp::ll_warning, "MLS proposal message must be PublicMessage"); TRACK_MLS_ERROR("Invalid proposal wire format"); @@ -357,7 +357,7 @@ catch (const std::exception& e) { return failed_t{}; } -std::optional session::process_welcome(std::vector welcome, std::set const& recognised_user_ids) noexcept +std::optional session::process_welcome(std::vector welcome, std::set const& recognised_user_ids) noexcept try { if (!has_cryptographic_state_for_welcome()) { creator.log(dpp::ll_warning, "Missing local crypto state necessary to process MLS welcome"); @@ -461,7 +461,7 @@ bool session::has_cryptographic_state_for_welcome() const noexcept return join_key_package && join_init_private_key && signature_private_key && hpke_private_key; } -bool session::verify_welcome_state(::mlspp::State const& state, std::set const& recognised_user_ids) const +bool session::verify_welcome_state(::mlspp::State const& state, std::set const& recognised_user_ids) const { if (!mls_external_sender) { creator.log(dpp::ll_warning, "Cannot verify MLS welcome without an external sender"); @@ -502,13 +502,13 @@ bool session::verify_welcome_state(::mlspp::State const& state, std::set& transient_key) noexcept +void session::init_leaf_node(dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept try { auto ciphersuite = ciphersuite_for_protocol_version(session_protocol_version); if (!transient_key) { if (!signing_key_id.empty()) { - transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id, session_protocol_version); + transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id.str(), session_protocol_version); if (!transient_key) { creator.log(dpp::ll_warning, "Did not receive MLS signature private key from get_persisted_key_pair; aborting"); return; @@ -522,7 +522,7 @@ try { signature_private_key = transient_key; - auto self_credential = create_user_credential(self_user_id, session_protocol_version); + auto self_credential = create_user_credential(self_user_id.str(), session_protocol_version); hpke_private_key = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); self_leaf_node = std::make_unique<::mlspp::LeafNode>( ciphersuite, hpke_private_key->public_key, signature_private_key->public_key, std::move(self_credential), @@ -608,7 +608,7 @@ catch (const std::exception& e) { return {}; } -std::unique_ptr session::get_key_ratchet(std::string const& user_id) const noexcept +std::unique_ptr session::get_key_ratchet(dpp::snowflake user_id) const noexcept { if (!current_state) { creator.log(dpp::ll_warning, "Cannot get key ratchet without an established MLS group"); @@ -617,7 +617,7 @@ std::unique_ptr session::get_key_ratchet(std::string cons // change the string user ID to a little endian 64 bit user ID // TODO: Make this use dpp::snowflake - auto u64_user_id = strtoull(user_id.c_str(), nullptr, 10); + uint64_t u64_user_id = user_id; auto user_id_bytes = ::mlspp::bytes_ns::bytes(sizeof(u64_user_id)); memcpy(user_id_bytes.data(), &u64_user_id, sizeof(u64_user_id)); @@ -629,14 +629,14 @@ std::unique_ptr session::get_key_ratchet(std::string cons return std::make_unique(creator, current_state->cipher_suite(), std::move(secret)); } -void session::get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept +void session::get_pairwise_fingerprint(uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept try { if (!current_state || !signature_private_key) { throw std::invalid_argument("No established MLS group"); } - uint64_t remote_user_id = strtoull(user_id.c_str(), nullptr, 10); - uint64_t self_user_id = strtoull(bot_user_id.c_str(), nullptr, 10); + uint64_t remote_user_id = user_id; + uint64_t self_user_id = bot_user_id; auto it = roster.find(remote_user_id); if (it == roster.end()) { @@ -687,16 +687,7 @@ try { std::vector out(hash_len); - int ret = EVP_PBE_scrypt((const char*)data.data(), - data.size(), - salt, - sizeof(salt), - N, - r, - p, - max_mem, - out.data(), - out.size()); + int ret = EVP_PBE_scrypt((const char*)data.data(), data.size(), salt, sizeof(salt), N, r, p, max_mem, out.data(), out.size()); if (ret == 1) { callback(out); diff --git a/src/dpp/dave/session.h b/src/dpp/dave/session.h index 36e5dabb17..bb94067d05 100755 --- a/src/dpp/dave/session.h +++ b/src/dpp/dave/session.h @@ -33,6 +33,8 @@ #include #include #include +#include +#include #include "persisted_key_pair.h" #include "key_ratchet.h" #include "version.h" @@ -73,7 +75,7 @@ class session { // NOLINT * @param auth_session_id auth session id (set to empty string to use a transient key pair) * @param callback callback for failure */ - session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept; + session(dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept; /** * @brief Destructor @@ -90,7 +92,7 @@ class session { // NOLINT * @param self_user_id bot's user id * @param transient_key transient private key */ - void init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept; + void init(protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept; /** * @brief Reset the session to defaults @@ -129,7 +131,7 @@ class session { // NOLINT * @param recognised_user_ids list of recognised user IDs * @return optional vector to send in reply as commit welcome */ - std::optional> process_proposals(std::vector proposals, std::set const& recognised_user_ids) noexcept; + std::optional> process_proposals(std::vector proposals, std::set const& recognised_user_ids) noexcept; /** * @brief Process commit message from discord websocket @@ -144,7 +146,7 @@ class session { // NOLINT * @param recognised_user_ids Recognised user ID list * @return roster list of people in the vc */ - std::optional process_welcome(std::vector welcome, std::set const& recognised_user_ids) noexcept; + std::optional process_welcome(std::vector welcome, std::set const& recognised_user_ids) noexcept; /** * @brief Get the bot user's key package for sending to websocket @@ -157,7 +159,7 @@ class session { // NOLINT * @param user_id User id to get ratchet for * @return The user's key ratchet for use in an encryptor or decryptor */ - [[nodiscard]] std::unique_ptr get_key_ratchet(std::string const& user_id) const noexcept; + [[nodiscard]] std::unique_ptr get_key_ratchet(dpp::snowflake user_id) const noexcept; /** * @brief callback for completion of pairwise fingerprint @@ -172,7 +174,7 @@ class session { // NOLINT * @param user_id User ID to get fingerprint for * @param callback Callback for completion */ - void get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept; + void get_pairwise_fingerprint(uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept; private: /** @@ -180,7 +182,7 @@ class session { // NOLINT * @param self_user_id Bot user id * @param transient_key Transient key */ - void init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept; + void init_leaf_node(dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept; /** * @brief Reset join key @@ -204,7 +206,7 @@ class session { // NOLINT * @param recognised_user_ids list of recognised user IDs * @return */ - [[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognised_user_ids) const; + [[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set const& recognised_user_ids) const; /** * @brief Validate proposals message @@ -213,7 +215,7 @@ class session { // NOLINT * @param recognised_user_ids recognised list of user IDs * @return true if validated */ - [[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set const& recognised_user_ids) const; + [[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set const& recognised_user_ids) const; /** * @brief Verify that welcome state is valid @@ -221,7 +223,7 @@ class session { // NOLINT * @param recognised_user_ids list of recognised user IDs * @return */ - [[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set const& recognised_user_ids) const; + [[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set const& recognised_user_ids) const; /** * @brief Check if can process a commit now @@ -260,12 +262,12 @@ class session { // NOLINT /** * @brief Signing key id */ - std::string signing_key_id; + dpp::snowflake signing_key_id; /** * @brief The bot's user snowflake ID */ - std::string bot_user_id; + dpp::snowflake bot_user_id; /** * @brief The bot's key pair context diff --git a/src/dpp/voice/enabled/handle_frame.cpp b/src/dpp/voice/enabled/handle_frame.cpp index 27d260be9a..2a8faca93c 100644 --- a/src/dpp/voice/enabled/handle_frame.cpp +++ b/src/dpp/voice/enabled/handle_frame.cpp @@ -50,16 +50,15 @@ void discord_voice_client::update_ratchets(bool force) { */ log(ll_debug, "Updating MLS ratchets for " + std::to_string(dave_mls_user_list.size() + 1) + " user(s)"); for (const auto& user : dave_mls_user_list) { - dpp::snowflake u{user}; - if (u == creator->me.id) { + if (user == creator->me.id) { continue; } decryptor_list::iterator decryptor; /* New user join/old user leave - insert new ratchets if they don't exist */ - decryptor = mls_state->decryptors.find(u); + decryptor = mls_state->decryptors.find(user.str()); if (decryptor == mls_state->decryptors.end()) { - log(ll_debug, "Inserting decryptor key ratchet for NEW user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version())); - auto [iter, inserted] = mls_state->decryptors.emplace(u, std::make_unique(*creator)); + log(ll_debug, "Inserting decryptor key ratchet for NEW user: " + user.str() + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version())); + auto [iter, inserted] = mls_state->decryptors.emplace(user.str(), std::make_unique(*creator)); decryptor = iter; } decryptor->second->transition_to_key_ratchet(mls_state->dave_session->get_key_ratchet(user), RATCHET_EXPIRY); @@ -72,7 +71,7 @@ void discord_voice_client::update_ratchets(bool force) { if (mls_state->encryptor) { /* Updating key rachet should always be done on execute transition. Generally after group member add/remove. */ log(ll_debug, "Setting key ratchet for sending audio..."); - mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id.str())); + mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id)); } /** @@ -146,7 +145,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->mls_state->transition_id)); /* We should always recognize our own selves, but do we? */ - dave_mls_user_list.insert(this->creator->me.id.str()); + dave_mls_user_list.insert(this->creator->me.id); auto r = mls_state->dave_session->process_welcome(dave_header.get_data(), dave_mls_user_list); @@ -222,7 +221,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod /* Remove this user from pending remove list if exist */ for (const auto &user : joining_dave_users) { - dave_mls_pending_remove_list.erase(user); + dave_mls_pending_remove_list.erase(dpp::snowflake(user)); } log(ll_debug, "New of clients in voice channel: " + std::to_string(joining_dave_users.size()) + " total is " + std::to_string(dave_mls_user_list.size())); @@ -298,7 +297,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod } /* Mark this user for remove on immediate upgrade */ - dave_mls_pending_remove_list.insert(u_id.str()); + dave_mls_pending_remove_list.insert(u_id); if (!creator->on_voice_client_disconnect.empty()) { voice_client_disconnect_t vcd(nullptr, data); @@ -575,12 +574,12 @@ void discord_voice_client::reinit_dave_mls_group() { if (mls_state->dave_session == nullptr) { mls_state->dave_session = std::make_unique( *creator, - nullptr, "", [this](std::string const &s1, std::string const &s2) { + nullptr, snowflake(), [this](std::string const &s1, std::string const &s2) { log(ll_debug, "DAVE: " + s1 + ", " + s2); }); } - mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key); + mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id, mls_state->mls_key); auto key_response = mls_state->dave_session->get_marshalled_key_package(); key_response.insert(key_response.begin(), voice_client_dave_mls_key_package); @@ -630,12 +629,11 @@ void discord_voice_client::process_mls_group_rosters(const dave::roster_map &rma } dpp::snowflake u_id(k); - auto u_id_str = u_id.str(); - log(ll_debug, "Removed user from MLS Group: " + u_id_str); + log(ll_debug, "Removed user from MLS Group: " + u_id.str()); - dave_mls_user_list.erase(u_id_str); - dave_mls_pending_remove_list.erase(u_id_str); + dave_mls_user_list.erase(u_id); + dave_mls_pending_remove_list.erase(u_id); /* Remove this user's key ratchet */ mls_state->decryptors.erase(u_id);