Skip to content

Commit

Permalink
HPCC-30214 Use openssl aes encrypt/decrypt functions
Browse files Browse the repository at this point in the history
Changes following review, and additional tests

Signed-off-by: Richard Chapman <[email protected]>
  • Loading branch information
richardkchapman committed Nov 22, 2023
1 parent f20334b commit 9521d7a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 30 deletions.
4 changes: 1 addition & 3 deletions roxie/udplib/udpmsgpk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,8 @@ class PackageSequencer : public CInterface, implements IInterface
// MORE - could argue that we would prefer to wait even longer - until we know consumer wants it - but that might be complex
if (encrypted)
{
// MORE - This is decrypting in-place. Is that ok?? Seems to be with the code we currently use, but if that changed
// might need to rethink this
const MemoryAttr &udpkey = getSecretUdpKey(true);
size_t decryptedSize = aesDecrypt(udpkey.get(), udpkey.length(), pktHdr+1, pktHdr->length-sizeof(UdpPacketHeader), pktHdr+1, DATA_PAYLOAD-sizeof(UdpPacketHeader));
size_t decryptedSize = aesDecryptInPlace(udpkey.get(), udpkey.length(), pktHdr+1, pktHdr->length-sizeof(UdpPacketHeader));
pktHdr->length = decryptedSize + sizeof(UdpPacketHeader);
}

Expand Down
27 changes: 16 additions & 11 deletions system/jlib/jencrypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1818,14 +1818,14 @@ MemoryBuffer &aesDecrypt(const void *key, size_t keylen, const void *input, size
return output;
}

size_t aesDecrypt(const void *key, size_t keylen, const void *input, size_t inlen, void *output, size_t outlen)
size_t aesDecryptInPlace(const void *key, size_t keylen, void *data, size_t inlen)
{
Rijndael rin;
Rijndael::KeyLength keyType = getAesKeyType(keylen);

rin.init(Rijndael::CBC, Rijndael::Decrypt, (const UINT8 *)key, keyType);
size32_t truncInLen = (size32_t)inlen;
int len = rin.padDecrypt((const UINT8 *)input, truncInLen, (UINT8 *) output, outlen);
int len = rin.padDecrypt((const UINT8 *)data, truncInLen, (UINT8 *) data, inlen);
if(len < 0)
throw MakeStringException(-1,"AES Decryption error: %d, %s", len, getAesErrorText(len));
return len;
Expand All @@ -1844,6 +1844,8 @@ static void encryptError(const char *what)

MemoryBuffer &aesEncrypt(const void *key, size_t keylen, const void *plaintext, size_t plaintext_len, MemoryBuffer &output)
{
if (!plaintext || !plaintext_len)
return output;
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (!ctx)
encryptError("Failed to create context");
Expand Down Expand Up @@ -1896,6 +1898,8 @@ static void decryptError(const char *what)

MemoryBuffer &aesDecrypt(const void *key, size_t keylen, const void *ciphertext, size_t ciphertext_len, MemoryBuffer &output)
{
if (!ciphertext || !ciphertext_len)
return output;
EVP_CIPHER_CTX *ctx;

int thislen = 0;
Expand Down Expand Up @@ -1924,14 +1928,15 @@ MemoryBuffer &aesDecrypt(const void *key, size_t keylen, const void *ciphertext,
decryptError("Unsupported key length");
break;
}
byte *plaintext = (byte *) output.reserve(ciphertext_len + 100);
byte *plaintext = (byte *) output.reserve(ciphertext_len);
if(1 != EVP_DecryptUpdate(ctx, plaintext, &thislen, (const unsigned char *) ciphertext, ciphertext_len))
decryptError("Error in EVP_DecryptUpdate");
plaintext_len += thislen;

if(1 != EVP_DecryptFinal_ex(ctx, plaintext + plaintext_len, &thislen))
decryptError("Error in EVP_DecryptFinal_ex");
plaintext_len += thislen;
output.setLength(plaintext_len);
EVP_CIPHER_CTX_free(ctx);
return output;
}
Expand All @@ -1942,8 +1947,10 @@ MemoryBuffer &aesDecrypt(const void *key, size_t keylen, const void *ciphertext,
}
}

size_t aesDecrypt(const void *key, size_t keylen, const void *ciphertext, size_t ciphertext_len, void *output, size_t outlen)
size_t aesDecryptInPlace(const void *key, size_t keylen, void *ciphertext, size_t ciphertext_len)
{
if (!ciphertext || !ciphertext_len)
return 0;
EVP_CIPHER_CTX *ctx;

int thislen = 0;
Expand All @@ -1954,8 +1961,6 @@ size_t aesDecrypt(const void *key, size_t keylen, const void *ciphertext, size_t

try
{
if (outlen < ciphertext_len)
decryptError("output length too small"); // MORE - not sure this is actually true?
unsigned char iv[16] = { 0 };
switch (keylen)
{
Expand All @@ -1975,15 +1980,15 @@ size_t aesDecrypt(const void *key, size_t keylen, const void *ciphertext, size_t
decryptError("Unsupported key length");
break;
}
byte *plaintext = (byte *) output;
byte *plaintext = (byte *) ciphertext;
if(1 != EVP_DecryptUpdate(ctx, plaintext, &thislen, (const unsigned char *) ciphertext, ciphertext_len))
decryptError("Error in EVP_DecryptUpdate");
plaintext_len += thislen;

if(1 != EVP_DecryptFinal_ex(ctx, plaintext + plaintext_len, &thislen))
decryptError("Error in EVP_DecryptFinal_ex");
plaintext_len += thislen;
assertex(plaintext_len <= outlen);
assertex(plaintext_len <= ciphertext_len);
EVP_CIPHER_CTX_free(ctx);
return plaintext_len;
}
Expand Down Expand Up @@ -2015,12 +2020,12 @@ MemoryBuffer &aesDecrypt(const void *key, size_t keylen, const void *input, size
#endif
}

size_t aesDecrypt(const void *key, size_t keylen, const void *input, size_t inlen, void *output, size_t outlen)
size_t aesDecryptInPlace(const void *key, size_t keylen, void *data, size_t inlen)
{
#ifdef _USE_OPENSSL
return openssl::aesDecrypt(key, keylen, input, inlen, output, outlen);
return openssl::aesDecryptInPlace(key, keylen, data, inlen);
#else
return jlib::aesDecrypt(key, keylen, input, inlen, output, outlen);
return jlib::aesDecryptInPlace(key, keylen, data, inlen);
#endif
}

Expand Down
6 changes: 3 additions & 3 deletions system/jlib/jencrypt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ namespace jlib
{
extern jlib_decl MemoryBuffer &aesEncrypt(const void *key, size_t keylen, const void *input, size_t inlen, MemoryBuffer &output);
extern jlib_decl MemoryBuffer &aesDecrypt(const void *key, size_t keylen, const void *input, size_t inlen, MemoryBuffer &output);
extern jlib_decl size_t aesDecrypt(const void *key, size_t keylen, const void *input, size_t inlen, void *output, size_t outlen);
extern jlib_decl size_t aesDecryptInPlace(const void *key, size_t keylen, void *data, size_t inlen);
} // end of namespace jlib;

#ifdef _USE_OPENSSL
namespace openssl
{
extern jlib_decl MemoryBuffer &aesEncrypt(const void *key, size_t keylen, const void *input, size_t inlen, MemoryBuffer &output);
extern jlib_decl MemoryBuffer &aesDecrypt(const void *key, size_t keylen, const void *input, size_t inlen, MemoryBuffer &output);
extern jlib_decl size_t aesDecrypt(const void *key, size_t keylen, const void *input, size_t inlen, void *output, size_t outlen);
extern jlib_decl size_t aesDecryptInPlace(const void *key, size_t keylen, void *data, size_t inlen);
} // end of namespace openssl;
#endif

// NB: these are wrappers to either the openssl versions (if USE_OPENSSL) or the jlib version.

extern jlib_decl MemoryBuffer &aesEncrypt(const void *key, size_t keylen, const void *input, size_t inlen, MemoryBuffer &output);
extern jlib_decl MemoryBuffer &aesDecrypt(const void *key, size_t keylen, const void *input, size_t inlen, MemoryBuffer &output);
extern jlib_decl size_t aesDecrypt(const void *key, size_t keylen, const void *input, size_t inlen, void *output, size_t outlen);
extern jlib_decl size_t aesDecryptInPlace(const void *key, size_t keylen, void *data, size_t inlen);


#define encrypt _LogProcessError12
Expand Down
49 changes: 36 additions & 13 deletions testing/unittests/jlibtests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3641,7 +3641,7 @@ class JLibOpensslAESTest : public CppUnit::TestFixture

protected:

void test()
void testOne(unsigned len, const char *intext)
{
/* A 256 bit key */
unsigned char key[] = { 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
Expand All @@ -3650,25 +3650,48 @@ class JLibOpensslAESTest : public CppUnit::TestFixture
0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x30, 0x31
};

/* Message to be encrypted */
unsigned char *plaintext = (unsigned char *)"The quick brown fox jumps over the lazy dog";

MemoryBuffer ciphertext1, ciphertext2, decrypted;
MemoryBuffer ciphertext1, ciphertext2, decrypted1, decrypted2;

openssl::aesEncrypt(key, 32, plaintext, strlen ((char *)plaintext), ciphertext1);
jlib::aesEncrypt(key, 32, plaintext, strlen ((char *)plaintext), ciphertext2);
openssl::aesEncrypt(key, 32, intext, len, ciphertext1);
jlib::aesEncrypt(key, 32, intext, len, ciphertext2);

CPPUNIT_ASSERT(ciphertext1.length()==ciphertext2.length());
CPPUNIT_ASSERT(memcmp(ciphertext1.bytes(), ciphertext2.bytes(), ciphertext1.length()) == 0);

/* Decrypt the ciphertext */
openssl::aesDecrypt(key, 32, ciphertext1.bytes(), ciphertext1.length(), decrypted);

/* Add a NULL terminator. We are expecting printable text */
decrypted.append('\0');
openssl::aesDecrypt(key, 32, ciphertext1.bytes(), ciphertext1.length(), decrypted1);
assert(decrypted1.length() == len);
CPPUNIT_ASSERT(decrypted1.length() == len);
CPPUNIT_ASSERT(memcmp(decrypted1.bytes(), intext, len) == 0);
CPPUNIT_ASSERT(memcmp(ciphertext1.bytes(), ciphertext2.bytes(), ciphertext1.length()) == 0); // check input unchanged

jlib::aesDecrypt(key, 32, ciphertext2.bytes(), ciphertext2.length(), decrypted2);
CPPUNIT_ASSERT(decrypted2.length() == len);
CPPUNIT_ASSERT(memcmp(decrypted2.bytes(), intext, len) == 0);
CPPUNIT_ASSERT(memcmp(ciphertext1.bytes(), ciphertext2.bytes(), ciphertext1.length()) == 0); // check input unchanged

// Now test in-place decrypt
unsigned cipherlen = ciphertext1.length();
ciphertext1.append(4, "XXXX"); // Marker
unsigned decryptedlen = openssl::aesDecryptInPlace(key, 32, (void *) ciphertext1.bytes(), cipherlen);
CPPUNIT_ASSERT(decryptedlen == len);
CPPUNIT_ASSERT(memcmp(ciphertext1.bytes(), intext, len) == 0);
CPPUNIT_ASSERT(memcmp(ciphertext1.bytes()+cipherlen, "XXXX", 4) == 0);

cipherlen = ciphertext2.length();
ciphertext2.append(4, "XXXX"); // Marker
decryptedlen = jlib::aesDecryptInPlace(key, 32, (void *) ciphertext2.bytes(), cipherlen);
CPPUNIT_ASSERT(decryptedlen == len);
CPPUNIT_ASSERT(memcmp(ciphertext2.bytes(), intext, len) == 0);
CPPUNIT_ASSERT(memcmp(ciphertext2.bytes()+cipherlen, "XXXX", 4) == 0);
}

/* Show the decrypted text */
DBGLOG("Decrypted text is: %s", (const char *) decrypted.bytes());
void test()
{
/* Message to be encrypted */
const char *plaintext = "The quick brown fox jumps over the lazy dog";
for (unsigned l = 0; l < strlen(plaintext); l++)
testOne(l, plaintext);
}

};
Expand Down

0 comments on commit 9521d7a

Please sign in to comment.