Skip to content

Commit

Permalink
discord unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Oct 4, 2024
1 parent d92d46b commit 0ad6a45
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 68 deletions.
187 changes: 185 additions & 2 deletions src/davetest/dave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,58 @@
************************************************************************************/
#include <dpp/dpp.h>
#include <iostream>

#include <cstring>
#include <string>
#include <algorithm>

#include "../../mlspp/lib/bytes/include/bytes/bytes.h"
#include "../../src/dpp/dave/array_view.h"
#include "../../src/dpp/dave/openssl_aead_cipher.h"
#include "../../src/dpp/dave/key_ratchet.h"
#include "../../src/dpp/dave/common.h"

dpp::dave::EncryptionKey MakeStaticSenderKey(const std::string& userID);
dpp::dave::EncryptionKey MakeStaticSenderKey(uint64_t u64userID);

class StaticKeyRatchet : public dpp::dave::IKeyRatchet {
public:
StaticKeyRatchet(const std::string& userId) noexcept;
~StaticKeyRatchet() noexcept override = default;

dpp::dave::EncryptionKey GetKey(dpp::dave::KeyGeneration generation) noexcept override;
void DeleteKey(dpp::dave::KeyGeneration generation) noexcept override;

private:
uint64_t u64userID_;
};

dpp::dave::EncryptionKey MakeStaticSenderKey(const std::string& userID)
{
auto u64userID = strtoull(userID.c_str(), nullptr, 10);
return MakeStaticSenderKey(u64userID);
}

dpp::dave::EncryptionKey MakeStaticSenderKey(uint64_t u64userID)
{
static_assert(dpp::dave::kAesGcm128KeyBytes == 2 * sizeof(u64userID));
dpp::dave::EncryptionKey senderKey(dpp::dave::kAesGcm128KeyBytes);
const uint8_t* bytePtr = reinterpret_cast<const uint8_t*>(&u64userID);
std::copy_n(bytePtr, sizeof(u64userID), senderKey.begin());
std::copy_n(bytePtr, sizeof(u64userID), senderKey.begin() + sizeof(u64userID));
return senderKey;
}

StaticKeyRatchet::StaticKeyRatchet(const std::string& userId) noexcept
: u64userID_(strtoull(userId.c_str(), nullptr, 10)) {
}

dpp::dave::EncryptionKey StaticKeyRatchet::GetKey(dpp::dave::KeyGeneration generation) noexcept
{
return MakeStaticSenderKey(u64userID_);
}

void StaticKeyRatchet::DeleteKey([[maybe_unused]] dpp::dave::KeyGeneration generation) noexcept {
}

std::string get_testdata_dir() {
char *env_var = getenv("TEST_DATA_DIR");
Expand All @@ -47,7 +97,138 @@ std::vector<uint8_t> load_test_audio() {
return testaudio;
}

/*#define EXPECT_TRUE(expr) if (!(expr)) { std::cout << "Failed\n"; exit(1); }
#define EXPECT_FALSE(expr) if (expr) { std::cout << "Failed\n"; exit(1); }
void encryptor_unit_test() {
constexpr size_t PLAINTEXT_SIZE = 1024;
auto plaintextBufferIn = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto additionalDataBuffer = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto plaintextBufferOut = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto ciphertextBuffer = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto nonceBuffer = std::vector<uint8_t>(dpp::dave::kAesGcm128NonceBytes, 0);
auto tagBuffer = std::vector<uint8_t>(dpp::dave::kAesGcm128TruncatedTagBytes, 0);
auto plaintextIn =
dpp::dave::make_array_view<const uint8_t>(plaintextBufferIn.data(), plaintextBufferIn.size());
auto additionalData =
dpp::dave::make_array_view<const uint8_t>(additionalDataBuffer.data(), additionalDataBuffer.size());
auto plaintextOut =
dpp::dave::make_array_view<uint8_t>(plaintextBufferOut.data(), plaintextBufferOut.size());
auto ciphertextOut = dpp::dave::make_array_view<uint8_t>(ciphertextBuffer.data(), ciphertextBuffer.size());
auto ciphertextIn =
dpp::dave::make_array_view<const uint8_t>(ciphertextBuffer.data(), ciphertextBuffer.size());
auto nonce = dpp::dave::make_array_view<const uint8_t>(nonceBuffer.data(), nonceBuffer.size());
auto tagOut = dpp::dave::make_array_view<uint8_t>(tagBuffer.data(), tagBuffer.size());
auto tagIn = dpp::dave::make_array_view<const uint8_t>(tagBuffer.data(), tagBuffer.size());
dpp::dave::openssl_aead_cipher cryptor(MakeStaticSenderKey("12345678901234567890"));
EXPECT_TRUE(cryptor.encrypt(ciphertextOut, plaintextIn, nonce, additionalData, tagOut));
// The ciphertext should not be the same as the plaintext
EXPECT_FALSE(memcmp(plaintextBufferIn.data(), ciphertextBuffer.data(), PLAINTEXT_SIZE) == 0);
EXPECT_TRUE(cryptor.decrypt(plaintextOut, ciphertextIn, tagIn, nonce, additionalData));
// The plaintext should be the same as the original plaintext
EXPECT_TRUE(memcmp(plaintextBufferIn.data(), plaintextBufferOut.data(), PLAINTEXT_SIZE) == 0);
plaintextBufferIn = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
additionalDataBuffer = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
plaintextBufferOut = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
ciphertextBuffer = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
nonceBuffer = std::vector<uint8_t>(dpp::dave::kAesGcm128NonceBytes, 0);
tagBuffer = std::vector<uint8_t>(dpp::dave::kAesGcm128TruncatedTagBytes, 0);
plaintextIn =
dpp::dave::make_array_view<const uint8_t>(plaintextBufferIn.data(), plaintextBufferIn.size());
auto additionalData1 =
dpp::dave::make_array_view<const uint8_t>(additionalDataBuffer.data(), additionalDataBuffer.size());
plaintextOut =
dpp::dave::make_array_view<uint8_t>(plaintextBufferOut.data(), plaintextBufferOut.size());
ciphertextOut = dpp::dave::make_array_view<uint8_t>(ciphertextBuffer.data(), ciphertextBuffer.size());
ciphertextIn =
dpp::dave::make_array_view<const uint8_t>(ciphertextBuffer.data(), ciphertextBuffer.size());
nonce = dpp::dave::make_array_view<const uint8_t>(nonceBuffer.data(), nonceBuffer.size());
tagOut = dpp::dave::make_array_view<uint8_t>(tagBuffer.data(), tagBuffer.size());
tagIn = dpp::dave::make_array_view<const uint8_t>(tagBuffer.data(), tagBuffer.size());
dpp::dave::openssl_aead_cipher cryptor4(MakeStaticSenderKey("12345678901234567890"));
EXPECT_TRUE(cryptor4.encrypt(ciphertextOut, plaintextIn, nonce, additionalData1, tagOut));
// We modify the additional data before decryption
additionalDataBuffer[0] = 1;
EXPECT_FALSE(cryptor4.decrypt(plaintextOut, ciphertextIn, tagIn, nonce, additionalData1));
auto plaintextBuffer1 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto additionalDataBuffer1 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto plaintextBuffer2 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto additionalDataBuffer2 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto ciphertextBuffer1 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto ciphertextBuffer2 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
nonceBuffer = std::vector<uint8_t>(dpp::dave::kAesGcm128NonceBytes, 0);
tagBuffer = std::vector<uint8_t>(dpp::dave::kAesGcm128TruncatedTagBytes, 0);
auto plaintext1 =
dpp::dave::make_array_view<const uint8_t>(plaintextBuffer1.data(), plaintextBuffer1.size());
additionalData1 =
dpp::dave::make_array_view<const uint8_t>(additionalDataBuffer1.data(), additionalDataBuffer1.size());
auto plaintext2 =
dpp::dave::make_array_view<const uint8_t>(plaintextBuffer2.data(), plaintextBuffer2.size());
auto additionalData2 =
dpp::dave::make_array_view<const uint8_t>(additionalDataBuffer2.data(), additionalDataBuffer2.size());
auto ciphertext1 = dpp::dave::make_array_view<uint8_t>(ciphertextBuffer1.data(), ciphertextBuffer1.size());
auto ciphertext2 = dpp::dave::make_array_view<uint8_t>(ciphertextBuffer2.data(), ciphertextBuffer2.size());
nonce = dpp::dave::make_array_view<const uint8_t>(nonceBuffer.data(), nonceBuffer.size());
auto tag = dpp::dave::make_array_view<uint8_t>(tagBuffer.data(), tagBuffer.size());
dpp::dave::openssl_aead_cipher cryptor1(MakeStaticSenderKey("12345678901234567890"));
dpp::dave::openssl_aead_cipher cryptor2(MakeStaticSenderKey("09876543210987654321"));
EXPECT_TRUE(cryptor1.encrypt(ciphertext1, plaintext1, nonce, additionalData1, tag));
EXPECT_TRUE(cryptor2.encrypt(ciphertext2, plaintext2, nonce, additionalData2, tag));
EXPECT_FALSE(memcmp(ciphertextBuffer1.data(), ciphertextBuffer2.data(), PLAINTEXT_SIZE) == 0);
plaintextBuffer1 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
additionalDataBuffer1 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
plaintextBuffer2 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
additionalDataBuffer2 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
ciphertextBuffer1 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
ciphertextBuffer2 = std::vector<uint8_t>(PLAINTEXT_SIZE, 0);
auto nonceBuffer1 = std::vector<uint8_t>(dpp::dave::kAesGcm128NonceBytes, 0);
auto nonceBuffer2 = std::vector<uint8_t>(dpp::dave::kAesGcm128NonceBytes, 1);
tagBuffer = std::vector<uint8_t>(dpp::dave::kAesGcm128TruncatedTagBytes, 0);
plaintext1 =
dpp::dave::make_array_view<const uint8_t>(plaintextBuffer1.data(), plaintextBuffer1.size());
additionalData1 =
dpp::dave::make_array_view<const uint8_t>(additionalDataBuffer1.data(), additionalDataBuffer1.size());
plaintext2 =
dpp::dave::make_array_view<const uint8_t>(plaintextBuffer2.data(), plaintextBuffer2.size());
additionalData2 =
dpp::dave::make_array_view<const uint8_t>(additionalDataBuffer2.data(), additionalDataBuffer2.size());
ciphertext1 = dpp::dave::make_array_view<uint8_t>(ciphertextBuffer1.data(), ciphertextBuffer1.size());
ciphertext2 = dpp::dave::make_array_view<uint8_t>(ciphertextBuffer2.data(), ciphertextBuffer2.size());
auto nonce1 = dpp::dave::make_array_view<const uint8_t>(nonceBuffer1.data(), nonceBuffer1.size());
auto nonce2 = dpp::dave::make_array_view<const uint8_t>(nonceBuffer2.data(), nonceBuffer2.size());
tag = dpp::dave::make_array_view<uint8_t>(tagBuffer.data(), tagBuffer.size());
dpp::dave::openssl_aead_cipher cryptor3(MakeStaticSenderKey("12345678901234567890"));
EXPECT_TRUE(cryptor3.encrypt(ciphertext1, plaintext1, nonce1, additionalData1, tag));
EXPECT_TRUE(cryptor3.encrypt(ciphertext2, plaintext2, nonce2, additionalData2, tag));
EXPECT_FALSE(memcmp(ciphertextBuffer1.data(), ciphertextBuffer2.data(), PLAINTEXT_SIZE) == 0);
}*/

int main() {

//encryptor_unit_test();

using namespace std::chrono_literals;
char* t = getenv("DPP_UNIT_TEST_TOKEN");
if (t == nullptr || getenv("TEST_GUILD_ID") == nullptr || getenv("TEST_VC_ID") == nullptr) {
Expand All @@ -66,6 +247,7 @@ int main() {
std::vector<uint8_t> testaudio = load_test_audio();

dave_test.on_voice_ready([&](const dpp::voice_ready_t & event) {
dave_test.log(dpp::ll_info, "Voice channel ready, sending audio...");
dpp::discord_voice_client* v = event.voice_client;
if (v && v->is_ready()) {
v->send_audio_raw((uint16_t*)testaudio.data(), testaudio.size());
Expand All @@ -76,7 +258,8 @@ int main() {
dave_test.on_guild_create([&](const dpp::guild_create_t & event) {
if (event.created->id == TEST_GUILD_ID) {
dpp::discord_client* s = dave_test.get_shard(0);
s->connect_voice(TEST_GUILD_ID, TEST_VC_ID, false, false, true);
bool muted = false, deaf = false, enable_dave = true;
s->connect_voice(TEST_GUILD_ID, TEST_VC_ID, muted, deaf, enable_dave);
}
});
dave_test.start(false);
Expand Down
2 changes: 1 addition & 1 deletion src/dpp/dave/encryptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ Codec Encryptor::CodecForSsrc(uint32_t ssrc)
return existingCodecIt->second;
}
else {
return Codec::Unknown;
return Codec::Opus;
}
}

Expand Down
18 changes: 9 additions & 9 deletions src/dpp/dave/encryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,6 @@ class Encryptor {
}
ProtocolVersion GetProtocolVersion() const { return currentProtocolVersion_; }

private:
std::unique_ptr<OutboundFrameProcessor> GetOrCreateFrameProcessor();
void ReturnFrameProcessor(std::unique_ptr<OutboundFrameProcessor> frameProcessor);

using CryptorAndNonce = std::pair<std::shared_ptr<cipher_interface>, TruncatedSyncNonce>;
CryptorAndNonce GetNextCryptorAndNonce();

void UpdateCurrentProtocolVersion(ProtocolVersion version);

enum ResultCode {
Success,
UninitializedContext,
Expand All @@ -95,6 +86,15 @@ class Encryptor {
TagAppendFailure
};

private:
std::unique_ptr<OutboundFrameProcessor> GetOrCreateFrameProcessor();
void ReturnFrameProcessor(std::unique_ptr<OutboundFrameProcessor> frameProcessor);

using CryptorAndNonce = std::pair<std::shared_ptr<cipher_interface>, TruncatedSyncNonce>;
CryptorAndNonce GetNextCryptorAndNonce();

void UpdateCurrentProtocolVersion(ProtocolVersion version);

std::atomic_bool passthroughMode_{false};

std::mutex keyGenMutex_;
Expand Down
1 change: 1 addition & 0 deletions src/dpp/dave/openssl_aead_cipher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <openssl/err.h>
#include <openssl/evp.h>
#include <bytes/bytes.h>
#include <iostream>
#include "common.h"
#include "logger.h"

Expand Down
Loading

0 comments on commit 0ad6a45

Please sign in to comment.