Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: dont use strings and uint64_t in dave, use snowflake type #1307

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/dpp/discordvoiceclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,13 +487,13 @@ class DPP_EXPORT discord_voice_client : public websocket_client
* @brief The list of users that have E2EE potentially enabled for
* DAVE protocol.
*/
std::set<std::string> dave_mls_user_list;
std::set<dpp::snowflake> dave_mls_user_list;

/**
* @brief The list of users that have left the voice channel but
* not yet removed from MLS group.
*/
std::set<std::string> dave_mls_pending_remove_list;
std::set<dpp::snowflake> dave_mls_pending_remove_list;

/**
* @brief File descriptor for UDP connection
Expand Down
51 changes: 21 additions & 30 deletions src/dpp/dave/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <iostream>
#include <mls/crypto.h>
#include <mls/messages.h>
#include <dpp/export.h>
#include <dpp/snowflake.h>
#include <mls/state.h>
#include <dpp/cluster.h>
#include "mls_key_ratchet.h"
Expand All @@ -50,20 +52,20 @@ struct queued_proposal {
::mlspp::bytes_ns::bytes ref;
};

session::session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept
session::session(dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept
: signing_key_id(auth_session_id), key_pair_context(context), failure_callback(std::move(callback)), creator(cluster)
{
creator.log(dpp::ll_debug, "Creating a new MLS session");
}

session::~session() noexcept = default;

void session::init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept {
void session::init(protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept {
reset();

bot_user_id = self_user_id;

creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + std::to_string(group_id));
creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + group_id.str());
session_protocol_version = version;
session_group_id = std::move(big_endian_bytes_from(group_id).as_vec());

Expand Down Expand Up @@ -123,7 +125,7 @@ catch (const std::exception& e) {
return;
}

std::optional<std::vector<uint8_t>> session::process_proposals(std::vector<uint8_t> proposals, std::set<std::string> const& recognised_user_ids) noexcept
std::optional<std::vector<uint8_t>> session::process_proposals(std::vector<uint8_t> proposals, std::set<dpp::snowflake> const& recognised_user_ids) noexcept
try {
if (!pending_group_state && !current_state) {
creator.log(dpp::ll_debug, "Cannot process proposals without any pending or established MLS group state");
Expand Down Expand Up @@ -183,9 +185,7 @@ try {
for (const auto& proposal_message : messages) {
auto validated_content = state_with_proposals->unwrap(proposal_message);

if (!validate_proposal_message(validated_content.authenticated_content(),
*state_with_proposals,
recognised_user_ids)) {
if (!validate_proposal_message(validated_content.authenticated_content(), *state_with_proposals, recognised_user_ids)) {
return std::nullopt;
}

Expand Down Expand Up @@ -238,9 +238,9 @@ catch (const std::exception& e) {
return std::nullopt;
}

bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<std::string> const& recognised_user_ids) const
bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<dpp::snowflake> const& recognised_user_ids) const
{
std::string uid = user_credential_to_string(cred, session_protocol_version);
dpp::snowflake uid(user_credential_to_string(cred, session_protocol_version));
if (uid.empty()) {
creator.log(dpp::ll_warning, "Attempted to verify credential of unexpected type");
return false;
Expand All @@ -254,7 +254,7 @@ bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<st
return true;
}

bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<std::string> const& recognised_user_ids) const {
bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<dpp::snowflake> const& recognised_user_ids) const {
if (message.wire_format != ::mlspp::WireFormat::mls_public_message) {
creator.log(dpp::ll_warning, "MLS proposal message must be PublicMessage");
TRACK_MLS_ERROR("Invalid proposal wire format");
Expand Down Expand Up @@ -357,7 +357,7 @@ catch (const std::exception& e) {
return failed_t{};
}

std::optional<roster_map> session::process_welcome(std::vector<uint8_t> welcome, std::set<std::string> const& recognised_user_ids) noexcept
std::optional<roster_map> session::process_welcome(std::vector<uint8_t> welcome, std::set<dpp::snowflake> const& recognised_user_ids) noexcept
try {
if (!has_cryptographic_state_for_welcome()) {
creator.log(dpp::ll_warning, "Missing local crypto state necessary to process MLS welcome");
Expand Down Expand Up @@ -461,7 +461,7 @@ bool session::has_cryptographic_state_for_welcome() const noexcept
return join_key_package && join_init_private_key && signature_private_key && hpke_private_key;
}

bool session::verify_welcome_state(::mlspp::State const& state, std::set<std::string> const& recognised_user_ids) const
bool session::verify_welcome_state(::mlspp::State const& state, std::set<dpp::snowflake> const& recognised_user_ids) const
{
if (!mls_external_sender) {
creator.log(dpp::ll_warning, "Cannot verify MLS welcome without an external sender");
Expand Down Expand Up @@ -502,13 +502,13 @@ bool session::verify_welcome_state(::mlspp::State const& state, std::set<std::st
return true;
}

void session::init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept
void session::init_leaf_node(dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept
try {
auto ciphersuite = ciphersuite_for_protocol_version(session_protocol_version);

if (!transient_key) {
if (!signing_key_id.empty()) {
transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id, session_protocol_version);
transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id.str(), session_protocol_version);
if (!transient_key) {
creator.log(dpp::ll_warning, "Did not receive MLS signature private key from get_persisted_key_pair; aborting");
return;
Expand All @@ -522,7 +522,7 @@ try {

signature_private_key = transient_key;

auto self_credential = create_user_credential(self_user_id, session_protocol_version);
auto self_credential = create_user_credential(self_user_id.str(), session_protocol_version);
hpke_private_key = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite));
self_leaf_node = std::make_unique<::mlspp::LeafNode>(
ciphersuite, hpke_private_key->public_key, signature_private_key->public_key, std::move(self_credential),
Expand Down Expand Up @@ -608,7 +608,7 @@ catch (const std::exception& e) {
return {};
}

std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string const& user_id) const noexcept
std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(dpp::snowflake user_id) const noexcept
{
if (!current_state) {
creator.log(dpp::ll_warning, "Cannot get key ratchet without an established MLS group");
Expand All @@ -617,7 +617,7 @@ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string cons

// change the string user ID to a little endian 64 bit user ID
// TODO: Make this use dpp::snowflake
auto u64_user_id = strtoull(user_id.c_str(), nullptr, 10);
uint64_t u64_user_id = user_id;
auto user_id_bytes = ::mlspp::bytes_ns::bytes(sizeof(u64_user_id));
memcpy(user_id_bytes.data(), &u64_user_id, sizeof(u64_user_id));

Expand All @@ -629,14 +629,14 @@ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string cons
return std::make_unique<mls_key_ratchet>(creator, current_state->cipher_suite(), std::move(secret));
}

void session::get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept
void session::get_pairwise_fingerprint(uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept
try {
if (!current_state || !signature_private_key) {
throw std::invalid_argument("No established MLS group");
}

uint64_t remote_user_id = strtoull(user_id.c_str(), nullptr, 10);
uint64_t self_user_id = strtoull(bot_user_id.c_str(), nullptr, 10);
uint64_t remote_user_id = user_id;
uint64_t self_user_id = bot_user_id;

auto it = roster.find(remote_user_id);
if (it == roster.end()) {
Expand Down Expand Up @@ -687,16 +687,7 @@ try {

std::vector<uint8_t> out(hash_len);

int ret = EVP_PBE_scrypt((const char*)data.data(),
data.size(),
salt,
sizeof(salt),
N,
r,
p,
max_mem,
out.data(),
out.size());
int ret = EVP_PBE_scrypt((const char*)data.data(), data.size(), salt, sizeof(salt), N, r, p, max_mem, out.data(), out.size());

if (ret == 1) {
callback(out);
Expand Down
26 changes: 14 additions & 12 deletions src/dpp/dave/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <vector>
#include <map>
#include <set>
#include <dpp/export.h>
#include <dpp/snowflake.h>
#include "persisted_key_pair.h"
#include "key_ratchet.h"
#include "version.h"
Expand Down Expand Up @@ -73,7 +75,7 @@ class session { // NOLINT
* @param auth_session_id auth session id (set to empty string to use a transient key pair)
* @param callback callback for failure
*/
session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept;
session(dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept;

/**
* @brief Destructor
Expand All @@ -90,7 +92,7 @@ class session { // NOLINT
* @param self_user_id bot's user id
* @param transient_key transient private key
*/
void init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;
void init(protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;

/**
* @brief Reset the session to defaults
Expand Down Expand Up @@ -129,7 +131,7 @@ class session { // NOLINT
* @param recognised_user_ids list of recognised user IDs
* @return optional vector to send in reply as commit welcome
*/
std::optional<std::vector<uint8_t>> process_proposals(std::vector<uint8_t> proposals, std::set<std::string> const& recognised_user_ids) noexcept;
std::optional<std::vector<uint8_t>> process_proposals(std::vector<uint8_t> proposals, std::set<dpp::snowflake> const& recognised_user_ids) noexcept;

/**
* @brief Process commit message from discord websocket
Expand All @@ -144,7 +146,7 @@ class session { // NOLINT
* @param recognised_user_ids Recognised user ID list
* @return roster list of people in the vc
*/
std::optional<roster_map> process_welcome(std::vector<uint8_t> welcome, std::set<std::string> const& recognised_user_ids) noexcept;
std::optional<roster_map> process_welcome(std::vector<uint8_t> welcome, std::set<dpp::snowflake> const& recognised_user_ids) noexcept;

/**
* @brief Get the bot user's key package for sending to websocket
Expand All @@ -157,7 +159,7 @@ class session { // NOLINT
* @param user_id User id to get ratchet for
* @return The user's key ratchet for use in an encryptor or decryptor
*/
[[nodiscard]] std::unique_ptr<key_ratchet_interface> get_key_ratchet(std::string const& user_id) const noexcept;
[[nodiscard]] std::unique_ptr<key_ratchet_interface> get_key_ratchet(dpp::snowflake user_id) const noexcept;

/**
* @brief callback for completion of pairwise fingerprint
Expand All @@ -172,15 +174,15 @@ class session { // NOLINT
* @param user_id User ID to get fingerprint for
* @param callback Callback for completion
*/
void get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept;
void get_pairwise_fingerprint(uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept;

private:
/**
* @brief Initialise leaf node
* @param self_user_id Bot user id
* @param transient_key Transient key
*/
void init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;
void init_leaf_node(dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;

/**
* @brief Reset join key
Expand All @@ -204,7 +206,7 @@ class session { // NOLINT
* @param recognised_user_ids list of recognised user IDs
* @return
*/
[[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set<std::string> const& recognised_user_ids) const;
[[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set<dpp::snowflake> const& recognised_user_ids) const;

/**
* @brief Validate proposals message
Expand All @@ -213,15 +215,15 @@ class session { // NOLINT
* @param recognised_user_ids recognised list of user IDs
* @return true if validated
*/
[[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<std::string> const& recognised_user_ids) const;
[[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<dpp::snowflake> const& recognised_user_ids) const;

/**
* @brief Verify that welcome state is valid
* @param state current state
* @param recognised_user_ids list of recognised user IDs
* @return
*/
[[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set<std::string> const& recognised_user_ids) const;
[[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set<dpp::snowflake> const& recognised_user_ids) const;

/**
* @brief Check if can process a commit now
Expand Down Expand Up @@ -260,12 +262,12 @@ class session { // NOLINT
/**
* @brief Signing key id
*/
std::string signing_key_id;
dpp::snowflake signing_key_id;

/**
* @brief The bot's user snowflake ID
*/
std::string bot_user_id;
dpp::snowflake bot_user_id;

/**
* @brief The bot's key pair context
Expand Down
28 changes: 13 additions & 15 deletions src/dpp/voice/enabled/handle_frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,15 @@ void discord_voice_client::update_ratchets(bool force) {
*/
log(ll_debug, "Updating MLS ratchets for " + std::to_string(dave_mls_user_list.size() + 1) + " user(s)");
for (const auto& user : dave_mls_user_list) {
dpp::snowflake u{user};
if (u == creator->me.id) {
if (user == creator->me.id) {
continue;
}
decryptor_list::iterator decryptor;
/* New user join/old user leave - insert new ratchets if they don't exist */
decryptor = mls_state->decryptors.find(u);
decryptor = mls_state->decryptors.find(user.str());
if (decryptor == mls_state->decryptors.end()) {
log(ll_debug, "Inserting decryptor key ratchet for NEW user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version()));
auto [iter, inserted] = mls_state->decryptors.emplace(u, std::make_unique<dpp::dave::decryptor>(*creator));
log(ll_debug, "Inserting decryptor key ratchet for NEW user: " + user.str() + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version()));
auto [iter, inserted] = mls_state->decryptors.emplace(user.str(), std::make_unique<dpp::dave::decryptor>(*creator));
decryptor = iter;
}
decryptor->second->transition_to_key_ratchet(mls_state->dave_session->get_key_ratchet(user), RATCHET_EXPIRY);
Expand All @@ -72,7 +71,7 @@ void discord_voice_client::update_ratchets(bool force) {
if (mls_state->encryptor) {
/* Updating key rachet should always be done on execute transition. Generally after group member add/remove. */
log(ll_debug, "Setting key ratchet for sending audio...");
mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id.str()));
mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id));
}

/**
Expand Down Expand Up @@ -146,7 +145,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->mls_state->transition_id));

/* We should always recognize our own selves, but do we? */
dave_mls_user_list.insert(this->creator->me.id.str());
dave_mls_user_list.insert(this->creator->me.id);

auto r = mls_state->dave_session->process_welcome(dave_header.get_data(), dave_mls_user_list);

Expand Down Expand Up @@ -222,7 +221,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod

/* Remove this user from pending remove list if exist */
for (const auto &user : joining_dave_users) {
dave_mls_pending_remove_list.erase(user);
dave_mls_pending_remove_list.erase(dpp::snowflake(user));
}

log(ll_debug, "New of clients in voice channel: " + std::to_string(joining_dave_users.size()) + " total is " + std::to_string(dave_mls_user_list.size()));
Expand Down Expand Up @@ -298,7 +297,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
}

/* Mark this user for remove on immediate upgrade */
dave_mls_pending_remove_list.insert(u_id.str());
dave_mls_pending_remove_list.insert(u_id);

if (!creator->on_voice_client_disconnect.empty()) {
voice_client_disconnect_t vcd(nullptr, data);
Expand Down Expand Up @@ -575,12 +574,12 @@ void discord_voice_client::reinit_dave_mls_group() {
if (mls_state->dave_session == nullptr) {
mls_state->dave_session = std::make_unique<dave::mls::session>(
*creator,
nullptr, "", [this](std::string const &s1, std::string const &s2) {
nullptr, snowflake(), [this](std::string const &s1, std::string const &s2) {
log(ll_debug, "DAVE: " + s1 + ", " + s2);
});
}

mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key);
mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id, mls_state->mls_key);

auto key_response = mls_state->dave_session->get_marshalled_key_package();
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
Expand Down Expand Up @@ -630,12 +629,11 @@ void discord_voice_client::process_mls_group_rosters(const dave::roster_map &rma
}

dpp::snowflake u_id(k);
auto u_id_str = u_id.str();

log(ll_debug, "Removed user from MLS Group: " + u_id_str);
log(ll_debug, "Removed user from MLS Group: " + u_id.str());

dave_mls_user_list.erase(u_id_str);
dave_mls_pending_remove_list.erase(u_id_str);
dave_mls_user_list.erase(u_id);
dave_mls_pending_remove_list.erase(u_id);

/* Remove this user's key ratchet */
mls_state->decryptors.erase(u_id);
Expand Down
Loading