Skip to content

Commit

Permalink
fix: DAVE protocol initiation for BOB
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Oct 1, 2024
1 parent 2e8000e commit ca0a82d
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 29 deletions.
15 changes: 12 additions & 3 deletions include/dpp/discordvoiceclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ inline constexpr size_t send_audio_raw_max_length = 11520;
inline constexpr size_t secret_key_size = 32;

struct dave_transient_key;
struct dave_encryptors;

/*
* @brief For holding a moving average of the number of current voice users, for applying a smooth gain ramp.
Expand Down Expand Up @@ -387,6 +388,12 @@ class DPP_EXPORT discord_voice_client : public websocket_client
*/
OpusRepacketizer* repacketizer;

std::unique_ptr<dave::mls::Session> dave_session{};

std::unique_ptr<dave_transient_key> transient_key{};

std::unique_ptr<dave_encryptors> encryptors{};

#else
/**
* @brief libopus encoder
Expand All @@ -398,11 +405,13 @@ class DPP_EXPORT discord_voice_client : public websocket_client
* (merges frames into one packet)
*/
void* repacketizer;
#endif

std::unique_ptr<dave::mls::Session> dave_session{};
std::unique_ptr<int> dave_session{};

std::unique_ptr<dave_transient_key> transient_key{};
std::unique_ptr<int> transient_key{};

std::unique_ptr<int> encryptors{};
#endif

std::set<std::string> dave_mls_user_list;

Expand Down
6 changes: 2 additions & 4 deletions mlspp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ option(TESTING "Build tests" OFF)
option(CLANG_TIDY "Perform linting with clang-tidy" OFF)
option(SANITIZERS "Enable sanitizers" OFF)
option(MLS_NAMESPACE_SUFFIX "Namespace Suffix for CXX and CMake Export")
option(DISABLE_GREASE "Disables the inclusion of MLS protocol recommended GREASE values" OFF)
option(DISABLE_GREASE "Disables the inclusion of MLS protocol recommended GREASE values" ON)
option(REQUIRE_BORINGSSL "Require BoringSSL instead of OpenSSL" OFF)

if(MLS_NAMESPACE_SUFFIX)
Expand Down Expand Up @@ -54,9 +54,7 @@ if("$ENV{MACOSX_DEPLOYMENT_TARGET}" STREQUAL "10.11")
add_compile_options(-DVARIANT_COMPAT)
endif()

if (DISABLE_GREASE)
add_compile_options(-DDISABLE_GREASE)
endif ()
add_compile_options(-DDISABLE_GREASE)

###
### Dependencies
Expand Down
8 changes: 3 additions & 5 deletions src/dpp/dave/persisted_key_pair.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <string>
#include <mutex>
#include <functional>
#include <iostream>

#include <bytes/bytes.h>
#include <mls/crypto.h>
Expand Down Expand Up @@ -37,14 +38,11 @@ static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(
std::string id = MakeKeyID(sessionID, suite);

if (auto it = map.find(id); it != map.end()) {
std::cout << "5\n";
return it->second;
}

std::shared_ptr<::mlspp::SignaturePrivateKey> ret;

if (!ret) {
ret = ::dpp::dave::mls::detail::GetGenericPersistedKeyPair(ctx, id, suite);
}
std::shared_ptr<::mlspp::SignaturePrivateKey> ret = ::dpp::dave::mls::detail::GetGenericPersistedKeyPair(ctx, id, suite);

if (!ret) {
DISCORD_LOG(LS_ERROR) << "Failed to get key in GetPersistedKeyPair";
Expand Down
15 changes: 12 additions & 3 deletions src/dpp/dave/persisted_key_pair_generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <string>
#include <mutex>
#include <functional>
#include <iostream>

#ifdef _WIN32
#include <io.h>
Expand Down Expand Up @@ -67,6 +68,8 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair
std::string curstr;
std::filesystem::path dir = GetKeyStorageDirectory();

std::cout << "KSD: " << dir << "\n";

if (dir.empty()) {
DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair";
return nullptr;
Expand All @@ -82,6 +85,8 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair

std::filesystem::path file = dir / (id + ".key");

std::cout << "FILE: " << file << "\n";

if (std::filesystem::exists(file)) {
std::ifstream ifs(file, std::ios_base::in | std::ios_base::binary);
if (!ifs) {
Expand All @@ -97,6 +102,8 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair
return nullptr;
}

std::cout << "CURSTR: " << curstr << "\n";

try {
ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr);
}
Expand All @@ -106,13 +113,16 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair
}
}
else {
std::cout << "GEN NEW\n";
ret = ::mlspp::SignaturePrivateKey::generate(suite);

std::string newstr = ret.to_jwk(suite);

std::filesystem::path tmpfile = file;
tmpfile += ".tmp";

std::cout << "TMPFILE " << tmpfile << "\n";

#ifdef _WIN32
int fd = _wopen(tmpfile.c_str(), _O_WRONLY | _O_CREAT | _O_TRUNC, _S_IREAD | _S_IWRITE);
#else
Expand Down Expand Up @@ -150,9 +160,8 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair
if (!ret.public_key.data.empty()) {
return std::make_shared<::mlspp::SignaturePrivateKey>(std::move(ret));
}
else {
return nullptr;
}
return nullptr;

}

bool DeleteGenericPersistedKeyPair(KeyPairContextType ctx, const std::string& id)
Expand Down
24 changes: 24 additions & 0 deletions src/dpp/dave/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <thread>
#include <vector>
#include <cstring>
#include <iostream>

#include <hpke/random.h>
#include <hpke/signature.h>
Expand Down Expand Up @@ -331,37 +332,57 @@ try {
DISCORD_LOG(LS_INFO) << "Processing commit";
DISCORD_LOG(LS_INFO) << "Commit: " << ::mlspp::bytes_ns::bytes(commit);

std::cout << "1\n";

auto commitMessage = ::mlspp::tls::get<::mlspp::MLSMessage>(commit);

std::cout << "2\n";

if (!CanProcessCommit(commitMessage)) {
DISCORD_LOG(LS_ERROR) << "ProcessCommit called with unprocessable MLS commit";
return ignored_t{};
}

std::cout << "3\n";

// in case we're the sender of this commit
// we need to pull the cached state from our outbound cache
std::optional<::mlspp::State> optionalCachedState = std::nullopt;
if (outboundCachedGroupState_) {
optionalCachedState = *(outboundCachedGroupState_.get());
}

std::cout << "4\n";

auto newState = stateWithProposals_->handle(commitMessage, optionalCachedState);

std::cout << "5\n";

if (!newState) {
DISCORD_LOG(LS_ERROR) << "MLS commit handling did not produce a new state";
return failed_t{};
}

std::cout << "6\n";

DISCORD_LOG(LS_INFO) << "Successfully processed MLS commit, updating state; our leaf index is "
<< newState->index().val << "; current epoch is " << newState->epoch();

std::cout << "7\n";

RosterMap ret = ReplaceState(std::make_unique<::mlspp::State>(std::move(*newState)));

std::cout << "8\n";

// reset the outbound cached group since we handled the commit for this epoch
outboundCachedGroupState_.reset();

std::cout << "9\n";

ClearPendingState();

std::cout << "10\n";

return ret;
}
catch (const std::exception& e) {
Expand Down Expand Up @@ -536,7 +557,9 @@ try {
auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_);

if (!transientKey) {
std::cout << "2\n";
if (!signingKeyId_.empty()) {
std::cout << "3\n";
transientKey = GetPersistedKeyPair(keyPairContext_, signingKeyId_, protocolVersion_);
if (!transientKey) {
DISCORD_LOG(LS_ERROR) << "Did not receive MLS signature private key from "
Expand All @@ -545,6 +568,7 @@ try {
}
}
else {
std::cout << "1\n";
transientKey = std::make_shared<::mlspp::SignaturePrivateKey>(
::mlspp::SignaturePrivateKey::generate(ciphersuite));
}
Expand Down
60 changes: 47 additions & 13 deletions src/dpp/discordvoiceclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
#include <sodium.h>
#include <opus/opus.h>
#include "dave/session.h"
#include "dave/decryptor.h"
#include "dave/encryptor.h"
#else
struct OpusDecoder {};
struct OpusEncoder {};
Expand Down Expand Up @@ -74,6 +76,12 @@ static std::string external_ip;

struct dave_transient_key {
std::shared_ptr<::mlspp::SignaturePrivateKey> mls_key;
std::vector<uint8_t> cached_commit;
};

struct dave_encryptors {
std::unique_ptr<dave::Encryptor> encryptor;
std::unique_ptr<dave::Decryptor> decryptor;
};

/**
Expand Down Expand Up @@ -500,32 +508,69 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
/**
* MLS frames come in as type OP_BINARY, we can also reply to them as type OP_BINARY.
*/
if (opcode == OP_BINARY && data.size() >= sizeof(dave_binary_header_t) && dave_session && transient_key && transient_key->mls_key) {
if (opcode == OP_BINARY && data.size() >= sizeof(dave_binary_header_t)) {

auto* dave_header = reinterpret_cast<const dave_binary_header_t*>(data.data());

switch (dave_header->opcode) {
case voice_client_dave_mls_external_sender: {
log(ll_debug, "voice_client_dave_mls_external_sender");


dave_session = std::make_unique<dave::mls::Session>(
nullptr, sessionid, [this](std::string const& s1, std::string const& s2) {
log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2);
});

dave_session->SetExternalSender(dave_header->get_data(data.length()));

transient_key = std::make_unique<dave_transient_key>();
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), transient_key->mls_key);

encryptors = std::make_unique<dave_encryptors>();
encryptors->encryptor = std::make_unique<dave::Encryptor>();
/**
* TODO: There should be one of these per user but only one of the encryptor, above
*/
encryptors->decryptor = std::make_unique<dave::Decryptor>();

auto epoch = dave_session->GetLastEpochAuthenticator();

auto key_response = dave_session->GetMarshalledKeyPackage();
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);

encryptors->encryptor->SetKeyRatchet(dave_session->GetKeyRatchet(creator->me.id.str()));
}
break;
case voice_client_dave_mls_proposals: {
log(ll_debug, "voice_client_dave_mls_proposals");

std::optional<std::vector<uint8_t>> response = dave_session->ProcessProposals(dave_header->get_data(data.length()), dave_mls_user_list);
if (response.has_value()) {
auto r = response.value();
transient_key->cached_commit = r;
r.insert(r.begin(), voice_client_dave_mls_commit_message);
this->write(std::string_view(reinterpret_cast<const char*>(r.data()), r.size()), OP_BINARY);
}
}
break;
case voice_client_dave_announce_commit_transaction: {
log(ll_debug, "voice_client_dave_announce_commit_transaction");
auto r = dave_session->ProcessCommit(transient_key->cached_commit);
for (const auto& user : dave_mls_user_list) {
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(dave_session->GetProtocolVersion()));
encryptors->decryptor->TransitionToKeyRatchet(dave_session->GetKeyRatchet(user));
}
}
break;
case voice_client_dave_mls_welcome: {
log(ll_debug, "voice_client_dave_mls_welcome");
auto r = dave_session->ProcessWelcome(dave_header->get_data(data.length()), dave_mls_user_list);
}
break;
case voice_client_dave_mls_invalid_commit_welcome: {
log(ll_debug, "voice_client_dave_mls_invalid_commit_welcome");
}
break;
default:
Expand Down Expand Up @@ -638,7 +683,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod

if (!modes.empty()) {
log(dpp::ll_debug, "Resuming voice session " + this->sessionid + "...");
json obj = {
json obj = {
{ "op", voice_opcode_connection_resume },
{
"d",
Expand Down Expand Up @@ -688,17 +733,6 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
log(ll_error, "We requested DAVE E2EE but didn't receive it from the server, downgrading...");
dave_version = dave_version_none;
send_silence(20);
} else {
dave_session = std::make_unique<dave::mls::Session>(
nullptr, sessionid, [this](std::string const& s1, std::string const& s2) {
log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2);
});
transient_key = std::make_unique<dave_transient_key>();
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), transient_key->mls_key);

auto key_response = dave_session->GetMarshalledKeyPackage();
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);
}
} else {
/* This is needed to start voice receiving and make sure that the start of sending isn't cut off */
Expand Down
2 changes: 1 addition & 1 deletion src/dpp/utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ std::string debug_dump(uint8_t* data, size_t length) {
size_t extra = addr % 16;
if (extra != 0) {
addr -= extra;
out << to_hex(addr);
out << "\n[" << to_hex(addr) << "] : ";
}
for (size_t n = 0; n < extra; ++n) {
out << "-- ";
Expand Down

0 comments on commit ca0a82d

Please sign in to comment.