Skip to content

Commit

Permalink
Add the unit test for the change
Browse files Browse the repository at this point in the history
  • Loading branch information
Hang Zheng committed Aug 27, 2024
1 parent 4f66938 commit bb2e024
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 41 deletions.
9 changes: 9 additions & 0 deletions cpp/src/arrow/filesystem/filesystem_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ TEST(FileInfo, BaseName) {
ASSERT_EQ(info.base_name(), "baz.qux");
}

// Sanity checks for CalculateSSECKeyMD5: malformed or wrongly sized keys are
// refused, while a well-formed key is accepted.
TEST(CalculateSSECKeyMD5, Sanity) {
  std::string md5_digest;

  // Refused: empty input
  ASSERT_FALSE(CalculateSSECKeyMD5("", md5_digest));
  // Refused: contains characters outside the Base64 alphabet
  ASSERT_FALSE(CalculateSSECKeyMD5("%^H", md5_digest));
  // Refused: not a valid Base64-encoded key
  ASSERT_FALSE(CalculateSSECKeyMD5("INVALID", md5_digest));
  // Refused: valid Base64, but the decoded key is not 32 bytes long
  ASSERT_FALSE(CalculateSSECKeyMD5("MTIzNDU2Nzg5", md5_digest));

  // Accepted: Base64-encoded key of the expected size
  ASSERT_TRUE(
      CalculateSSECKeyMD5("1WH9aTJ0+Tn0NLbTMHZn9aCW3Li3ViAdBsoIldPCREw=", md5_digest));
}

TEST(PathUtil, SplitAbstractPath) {
std::vector<std::string> parts;

Expand Down
42 changes: 1 addition & 41 deletions cpp/src/arrow/filesystem/s3fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <cctype>

#ifdef _WIN32
// Undefine preprocessor macros that interfere with AWS function / method names
Expand Down Expand Up @@ -78,8 +77,6 @@
#include <aws/s3/model/PutObjectRequest.h>
#include <aws/s3/model/PutObjectResult.h>
#include <aws/s3/model/UploadPartRequest.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/base64/Base64.h>

// AWS_SDK_VERSION_{MAJOR,MINOR,PATCH} are available since 1.9.7.
#if defined(AWS_SDK_VERSION_MAJOR) && defined(AWS_SDK_VERSION_MINOR) && \
Expand Down Expand Up @@ -451,51 +448,14 @@ bool S3Options::Equals(const S3Options& other) const {

namespace {

// Validate a Base64-encoded SSE-C customer key and compute the Base64-encoded
// MD5 digest of the decoded key.
//
// \param[in] base64EncodedKey the customer key in padded Base64 form
// \param[out] base64DecodedResult receives the Base64-encoded MD5 digest of the
//   decoded key (despite its name, this is not the decoded key); it is only
//   written on success
// \return true when the key is well-formed Base64 that decodes to exactly
//   32 bytes, false otherwise
bool ComputeMD5Base64(const std::string& base64EncodedKey,
                      std::string& base64DecodedResult) {
  // Padded Base64 maps every 3 bytes to 4 characters, so a well-formed input is
  // non-empty and a multiple of 4 characters long. Rejecting anything else up
  // front avoids handing malformed input to the decoder.
  if (base64EncodedKey.empty() || base64EncodedKey.size() % 4 != 0) {
    return false;
  }

  // Only characters of the Base64 alphabet (plus padding) are allowed. The
  // ranges are spelled out instead of calling std::isalnum: passing a negative
  // `char` to std::isalnum is undefined behaviour and its answer depends on the
  // current locale.
  for (const char c : base64EncodedKey) {
    const bool is_alnum =
        (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9');
    if (!is_alnum && c != '+' && c != '/' && c != '=') {
      return false;
    }
  }

  // Padding may only appear as the last one or two characters.
  const std::string::size_type first_pad = base64EncodedKey.find('=');
  if (first_pad != std::string::npos) {
    if (first_pad + 2 < base64EncodedKey.size() ||
        base64EncodedKey.find_first_not_of('=', first_pad) != std::string::npos) {
      return false;
    }
  }

  // Decode the Base64-encoded key to get the raw binary key
  Aws::Utils::ByteBuffer rawKey =
      Aws::Utils::HashingUtils::Base64Decode(base64EncodedKey);

  // The key needs to be 256 bits (32 bytes) according to
  // https://docs.aws.amazon.com/AmazonS3/latest/userguide/ServerSideEncryptionCustomerKeys.html#specifying-s3-c-encryption
  if (rawKey.GetLength() != 32) {
    return false;
  }

  // CalculateMD5 is called with an Aws::String, so wrap the raw bytes in one
  // (explicit length: the key may contain embedded NUL bytes)
  Aws::String rawKeyStr(reinterpret_cast<const char*>(rawKey.GetUnderlyingData()),
                        rawKey.GetLength());

  // Compute the MD5 hash of the raw binary key
  Aws::Utils::ByteBuffer md5Hash = Aws::Utils::HashingUtils::CalculateMD5(rawKeyStr);

  // Base64-encode the MD5 hash
  Aws::String awsEncodedHash = Aws::Utils::HashingUtils::Base64Encode(md5Hash);

  // Hand the Base64-encoded MD5 hash back as a std::string
  base64DecodedResult.assign(awsEncodedHash.begin(), awsEncodedHash.end());
  return true;
}

template <typename S3RequestType>
Status SetSSECustomerKey(S3RequestType& request,
const std::string& sse_customer_key) {
if (sse_customer_key.empty()) {
return Status::OK(); // do nothing if the sse_customer_key is not configured
}
std::string sse_customer_key_md5;
if (ComputeMD5Base64(sse_customer_key, sse_customer_key_md5)) {
if (internal::CalculateSSECKeyMD5(sse_customer_key, sse_customer_key_md5)) {
request.SetSSECustomerKeyMD5(sse_customer_key_md5);
request.SetSSECustomerKey(sse_customer_key);
request.SetSSECustomerAlgorithm("AES256");
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/filesystem/s3fs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ class TestS3FS : public S3TestMixin {
// Most tests will create buckets
options_.allow_bucket_creation = true;
options_.allow_bucket_deletion = true;
options_.sse_customer_key = "1WH9aTJ0+Tn0NLbTMHZn9aCW3Li3ViAdBsoIldPCREw=";
MakeFileSystem();
// Set up test bucket
{
Expand Down
41 changes: 41 additions & 0 deletions cpp/src/arrow/filesystem/util_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

#include <algorithm>
#include <cerrno>
#include <cctype>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/base64/Base64.h>

#include "arrow/buffer.h"
#include "arrow/filesystem/path_util.h"
Expand Down Expand Up @@ -260,6 +263,44 @@ Result<FileInfoVector> GlobFiles(const std::shared_ptr<FileSystem>& filesystem,
return out;
}

bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5_result,
int expect_input_key_size = 32) {
if (base64_encoded_key.size() < 2) {
return false;
}
// Check if the string contains only valid Base64 characters
for (char c : base64_encoded_key) {
if (!std::isalnum(c) && c != '+' && c != '/' && c != '=') {
return false;
}
}

// Decode the Base64-encoded key to get the raw binary key
Aws::Utils::ByteBuffer rawKey =
Aws::Utils::HashingUtils::Base64Decode(base64_encoded_key);

// the key needs to be // 256 bits(32 bytes)according to
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/ServerSideEncryptionCustomerKeys.html#specifying-s3-c-encryption
if (rawKey.GetLength() != expect_input_key_size) {
return false;
}

// Convert the raw binary key to an Aws::String
Aws::String rawKeyStr(reinterpret_cast<const char*>(rawKey.GetUnderlyingData()),
rawKey.GetLength());

// Compute the MD5 hash of the raw binary key
Aws::Utils::ByteBuffer md5Hash = Aws::Utils::HashingUtils::CalculateMD5(rawKeyStr);

// Base64-encode the MD5 hash
Aws::String awsEncodedHash = Aws::Utils::HashingUtils::Base64Encode(md5Hash);

// Return the Base64-encoded MD5 hash as a std::string
md5_result = std::string(awsEncodedHash.begin(), awsEncodedHash.end());
return true;
}


FileSystemGlobalOptions global_options;

} // namespace internal
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/filesystem/util_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ ARROW_EXPORT
Result<FileInfoVector> GlobFiles(const std::shared_ptr<FileSystem>& filesystem,
const std::string& glob);


/// \brief Compute the Base64-encoded MD5 digest of a Base64-encoded SSE-C key
///
/// The key is checked for Base64 characters, decoded, and required to have the
/// expected decoded size before its MD5 digest is computed.
///
/// \param[in] base64_encoded_key the customer key, Base64-encoded
/// \param[out] md5_result receives the Base64-encoded MD5 digest of the decoded
/// key when the function returns true
/// \param[in] expect_input_key_size required size of the decoded key in bytes
/// \return true if the key passed validation and `md5_result` was set,
/// false otherwise
// NOTE(review): unlike GlobFiles above, this declaration carries no
// ARROW_EXPORT although it is called from filesystem_test.cc -- confirm the
// symbol is visible to tests in shared-library builds.
bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5_result,
int expect_input_key_size = 32);


extern FileSystemGlobalOptions global_options;

} // namespace internal
Expand Down

0 comments on commit bb2e024

Please sign in to comment.