Skip to content

Commit

Permalink
refine according to review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Hang Zheng committed Aug 30, 2024
1 parent 4f813a7 commit de05328
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 20 deletions.
14 changes: 8 additions & 6 deletions cpp/src/arrow/filesystem/filesystem_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ TEST(FileInfo, BaseName) {
}

TEST(CalculateSSECKeyMD5, Sanity) {
std::string lResult;
ASSERT_FALSE(CalculateSSECKeyMD5("", lResult)); // invalid base64
ASSERT_FALSE(CalculateSSECKeyMD5("%^H", lResult)); // invalid base64
ASSERT_FALSE(CalculateSSECKeyMD5("INVALID", lResult)); // invalid base64
ASSERT_FALSE(CalculateSSECKeyMD5("MTIzNDU2Nzg5", lResult)); // invalid, the input key size not match
ASSERT_TRUE(CalculateSSECKeyMD5("1WH9aTJ0+Tn0NLbTMHZn9aCW3Li3ViAdBsoIldPCREw=", lResult)); // valid case
ASSERT_FALSE(CalculateSSECKeyMD5("").ok()); // invalid base64
ASSERT_FALSE(CalculateSSECKeyMD5("%^H").ok()); // invalid base64
ASSERT_FALSE(CalculateSSECKeyMD5("INVALID").ok()); // invalid base64
ASSERT_FALSE(CalculateSSECKeyMD5("MTIzNDU2Nzg5").ok()); // invalid: decoded key size does not match the expected size
// valid case
auto result = CalculateSSECKeyMD5("1WH9aTJ0+Tn0NLbTMHZn9aCW3Li3ViAdBsoIldPCREw=");
ASSERT_TRUE(result.ok()); // valid case
ASSERT_STREQ(result->c_str(), "3HYIM58NCLwrIOdPpWnYwQ=="); // valid case
}

TEST(PathUtil, SplitAbstractPath) {
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/filesystem/s3fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,14 +454,14 @@ Status SetSSECustomerKey(S3RequestType& request,
if (sse_customer_key.empty()) {
return Status::OK(); // do nothing if the sse_customer_key is not configured
}
std::string sse_customer_key_md5;
if (internal::CalculateSSECKeyMD5(sse_customer_key, sse_customer_key_md5)) {
request.SetSSECustomerKeyMD5(sse_customer_key_md5);
auto result = internal::CalculateSSECKeyMD5(sse_customer_key);
if (result.ok()) {
request.SetSSECustomerKeyMD5(*result);
request.SetSSECustomerKey(sse_customer_key);
request.SetSSECustomerAlgorithm("AES256");
return Status::OK();
} else {
return Status::Invalid("sse_customer_key is not a vaild 256-bit base64-encoded encryption key");
return result.status();
}
}

Expand Down
12 changes: 5 additions & 7 deletions cpp/src/arrow/filesystem/util_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ Result<FileInfoVector> GlobFiles(const std::shared_ptr<FileSystem>& filesystem,
return out;
}

bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5_result,
Result<std::string> CalculateSSECKeyMD5(const std::string& base64_encoded_key,
int expect_input_key_size) {
if (base64_encoded_key.size() < 2) {
return false;
return Status::Invalid("At least 2 bytes needed for the base64 encoded string");
}
// Check if the string contains only valid Base64 characters
for (char c : base64_encoded_key) {
if (!std::isalnum(c) && c != '+' && c != '/' && c != '=') {
return false;
return Status::Invalid("Invalid character found in the base64 encoded string");
}
}

Expand All @@ -282,7 +282,7 @@ bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5
// the key needs to be 256 bits (32 bytes) according to
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/ServerSideEncryptionCustomerKeys.html#specifying-s3-c-encryption
if (rawKey.GetLength() != expect_input_key_size) {
return false;
return Status::Invalid("Invalid Length for the key");
}

// Convert the raw binary key to an Aws::String
Expand All @@ -295,9 +295,7 @@ bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5
// Base64-encode the MD5 hash
Aws::String awsEncodedHash = Aws::Utils::HashingUtils::Base64Encode(md5Hash);

// Return the Base64-encoded MD5 hash as a std::string
md5_result = std::string(awsEncodedHash.begin(), awsEncodedHash.end());
return true;
return std::string(awsEncodedHash.begin(), awsEncodedHash.end());
}


Expand Down
5 changes: 2 additions & 3 deletions cpp/src/arrow/filesystem/util_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,10 @@ Result<FileInfoVector> GlobFiles(const std::shared_ptr<FileSystem>& filesystem,

/// \brief Decode the input SSE key and calculate its MD5
/// \param base64_encoded_key is the input base64 encoded sse key
/// \param md5_result, output resut
/// \param expect_input_key_size expected length of the decoded key (default 32)
/// \return true if the decode and calculate MD5 success, otherwise return false
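/// \return the base64-encoded MD5 of the input key, or an Invalid status if the
/// input contains non-base64 characters or the decoded key has an unexpected length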
ARROW_EXPORT
bool CalculateSSECKeyMD5(const std::string& base64_encoded_key, std::string& md5_result,
Result<std::string> CalculateSSECKeyMD5(const std::string& base64_encoded_key,
int expect_input_key_size = 32);


Expand Down

0 comments on commit de05328

Please sign in to comment.