Skip to content

Commit

Permalink
apacheGH-43535: [C++] support the AWS S3 SSE-C encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
Hang Zheng committed Sep 27, 2024
1 parent 96d61c7 commit 3ae52f0
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 15 deletions.
40 changes: 40 additions & 0 deletions cpp/src/arrow/filesystem/s3_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <cctype>
#include <optional>
#include <sstream>
#include <string>
Expand All @@ -29,11 +30,13 @@
#include <aws/core/client/RetryStrategy.h>
#include <aws/core/http/HttpTypes.h>
#include <aws/core/utils/DateTime.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/StringUtils.h>

#include "arrow/filesystem/filesystem.h"
#include "arrow/filesystem/s3fs.h"
#include "arrow/status.h"
#include "arrow/util/base64.h"
#include "arrow/util/logging.h"
#include "arrow/util/print.h"
#include "arrow/util/string.h"
Expand Down Expand Up @@ -291,6 +294,43 @@ class ConnectRetryStrategy : public Aws::Client::RetryStrategy {
int32_t max_retry_duration_;
};

/// \brief Compute the base64-encoded MD5 digest of an SSE-C customer key
/// \param sse_customer_key the raw (binary, not base64-encoded) SSE-C key
/// \return the base64 encoding of the key's MD5 digest, or Status::Invalid
/// if the key is not exactly 32 bytes long
inline Result<std::string> CalculateSSECustomerKeyMD5(
    const std::string& sse_customer_key) {
  // SSE-C keys must be 256 bits (32 bytes) long, see
  // https://docs.aws.amazon.com/AmazonS3/latest/userguide/ServerSideEncryptionCustomerKeys.html#specifying-s3-c-encryption
  constexpr std::string::size_type kSSECustomerKeyLength = 32;
  if (sse_customer_key.length() != kSSECustomerKeyLength) {
    return Status::Invalid("32 bytes sse-c key is expected");
  }

  // Build the Aws::String from (pointer, length) so that a binary key
  // containing NUL bytes is copied in full.
  const Aws::String raw_key(sse_customer_key.data(), sse_customer_key.length());

  // Hash the raw key bytes
  Aws::Utils::ByteBuffer digest = Aws::Utils::HashingUtils::CalculateMD5(raw_key);

  // Expose the digest bytes as a string_view and base64-encode them
  const auto* digest_bytes = reinterpret_cast<const char*>(digest.GetUnderlyingData());
  const std::string_view digest_view(digest_bytes, digest.GetLength());
  return arrow::util::base64_encode(digest_view);
}

/// \brief Populate the SSE-C fields of an S3 request from a raw customer key
///
/// Does nothing when `sse_customer_key` is empty. Otherwise sets the
/// algorithm ("AES256"), the base64-encoded key and the base64-encoded MD5 of
/// the key on `request`.
/// \param request any request object exposing SetSSECustomerAlgorithm,
/// SetSSECustomerKey and SetSSECustomerKeyMD5
/// \param sse_customer_key the raw (not base64-encoded) SSE-C key, or empty
/// \return Status::OK on success, or the error from CalculateSSECustomerKeyMD5
template <typename S3RequestType>
Status SetSSECustomerKey(S3RequestType& request, const std::string& sse_customer_key) {
  // An empty key means SSE-C is not configured: leave the request untouched
  if (!sse_customer_key.empty()) {
    // Compute the digest first so that an invalid key leaves the request unmodified
    ARROW_ASSIGN_OR_RAISE(auto key_md5,
                          internal::CalculateSSECustomerKeyMD5(sse_customer_key));
    request.SetSSECustomerAlgorithm("AES256");
    request.SetSSECustomerKey(arrow::util::base64_encode(sse_customer_key));
    request.SetSSECustomerKeyMD5(key_md5);
  }
  return Status::OK();
}

} // namespace internal
} // namespace fs
} // namespace arrow
40 changes: 28 additions & 12 deletions cpp/src/arrow/filesystem/s3fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ using internal::IsNotFound;
using internal::OutcomeToResult;
using internal::OutcomeToStatus;
using internal::S3Backend;
using internal::SetSSECustomerKey;
using internal::ToAwsString;
using internal::ToURLEncodedAwsString;

Expand Down Expand Up @@ -439,7 +440,8 @@ bool S3Options::Equals(const S3Options& other) const {
background_writes == other.background_writes &&
allow_bucket_creation == other.allow_bucket_creation &&
allow_bucket_deletion == other.allow_bucket_deletion &&
default_metadata_equals && GetAccessKey() == other.GetAccessKey() &&
sse_customer_key == other.sse_customer_key && default_metadata_equals &&
GetAccessKey() == other.GetAccessKey() &&
GetSecretKey() == other.GetSecretKey() &&
GetSessionToken() == other.GetSessionToken());
}
Expand Down Expand Up @@ -1292,11 +1294,14 @@ Aws::IOStreamFactory AwsWriteableStreamFactory(void* data, int64_t nbytes) {
}

Result<S3Model::GetObjectResult> GetObjectRange(Aws::S3::S3Client* client,
const S3Path& path, int64_t start,
int64_t length, void* out) {
const S3Path& path,
const std::string& sse_customer_key,
int64_t start, int64_t length,
void* out) {
S3Model::GetObjectRequest req;
req.SetBucket(ToAwsString(path.bucket));
req.SetKey(ToAwsString(path.key));
RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key));
req.SetRange(ToAwsString(FormatRange(start, length)));
req.SetResponseStreamFactory(AwsWriteableStreamFactory(out, length));
return OutcomeToResult("GetObject", client->GetObject(req));
Expand Down Expand Up @@ -1433,11 +1438,13 @@ bool IsDirectory(std::string_view key, const S3Model::HeadObjectResult& result)
class ObjectInputFile final : public io::RandomAccessFile {
public:
ObjectInputFile(std::shared_ptr<S3ClientHolder> holder, const io::IOContext& io_context,
const S3Path& path, int64_t size = kNoSize)
const S3Path& path, int64_t size = kNoSize,
const std::string& sse_customer_key = "")
: holder_(std::move(holder)),
io_context_(io_context),
path_(path),
content_length_(size) {}
content_length_(size),
sse_customer_key_(sse_customer_key) {}

Status Init() {
// Issue a HEAD Object to get the content-length and ensure any
Expand All @@ -1450,6 +1457,7 @@ class ObjectInputFile final : public io::RandomAccessFile {
S3Model::HeadObjectRequest req;
req.SetBucket(ToAwsString(path_.bucket));
req.SetKey(ToAwsString(path_.key));
RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key_));

ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock());
auto outcome = client_lock.Move()->HeadObject(req);
Expand Down Expand Up @@ -1534,9 +1542,9 @@ class ObjectInputFile final : public io::RandomAccessFile {

// Read the desired range of bytes
ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock());
ARROW_ASSIGN_OR_RAISE(
S3Model::GetObjectResult result,
GetObjectRange(client_lock.get(), path_, position, nbytes, out));
ARROW_ASSIGN_OR_RAISE(S3Model::GetObjectResult result,
GetObjectRange(client_lock.get(), path_, sse_customer_key_,
position, nbytes, out));

auto& stream = result.GetBody();
stream.ignore(nbytes);
Expand Down Expand Up @@ -1584,6 +1592,7 @@ class ObjectInputFile final : public io::RandomAccessFile {
int64_t pos_ = 0;
int64_t content_length_ = kNoSize;
std::shared_ptr<const KeyValueMetadata> metadata_;
std::string sse_customer_key_;
};

// Upload size per part. While AWS and Minio support different sizes for each
Expand Down Expand Up @@ -1620,7 +1629,8 @@ class ObjectOutputStream final : public io::OutputStream {
metadata_(metadata),
default_metadata_(options.default_metadata),
background_writes_(options.background_writes),
allow_delayed_open_(options.allow_delayed_open) {}
allow_delayed_open_(options.allow_delayed_open),
sse_customer_key_(options.sse_customer_key) {}

~ObjectOutputStream() override {
// For compliance with the rest of the IO stack, Close rather than Abort,
Expand Down Expand Up @@ -1668,6 +1678,7 @@ class ObjectOutputStream final : public io::OutputStream {
S3Model::CreateMultipartUploadRequest req;
req.SetBucket(ToAwsString(path_.bucket));
req.SetKey(ToAwsString(path_.key));
RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key_));
RETURN_NOT_OK(SetMetadataInRequest(&req));

auto outcome = client_lock.Move()->CreateMultipartUpload(req);
Expand Down Expand Up @@ -1769,6 +1780,7 @@ class ObjectOutputStream final : public io::OutputStream {
S3Model::CompleteMultipartUploadRequest req;
req.SetBucket(ToAwsString(path_.bucket));
req.SetKey(ToAwsString(path_.key));
RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key_));
req.SetUploadId(multipart_upload_id_);
req.SetMultipartUpload(std::move(completed_upload));

Expand Down Expand Up @@ -1950,6 +1962,7 @@ class ObjectOutputStream final : public io::OutputStream {
req.SetKey(ToAwsString(path_.key));
req.SetBody(std::make_shared<StringViewStream>(data, nbytes));
req.SetContentLength(nbytes);
RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key_));

if (!background_writes_) {
req.SetBody(std::make_shared<StringViewStream>(data, nbytes));
Expand Down Expand Up @@ -2171,6 +2184,7 @@ class ObjectOutputStream final : public io::OutputStream {
Future<> pending_uploads_completed = Future<>::MakeFinished(Status::OK());
};
std::shared_ptr<UploadState> upload_state_;
std::string sse_customer_key_;
};

// This function assumes info->path() is already set
Expand Down Expand Up @@ -2321,6 +2335,7 @@ class S3FileSystem::Impl : public std::enable_shared_from_this<S3FileSystem::Imp

S3Model::CopyObjectRequest req;
req.SetBucket(ToAwsString(dest_path.bucket));
RETURN_NOT_OK(SetSSECustomerKey(req, options().sse_customer_key));
req.SetKey(ToAwsString(dest_path.key));
// ARROW-13048: Copy source "Must be URL-encoded" according to AWS SDK docs.
// However at least in 1.8 and 1.9 the SDK URL-encodes the path for you
Expand Down Expand Up @@ -2972,7 +2987,8 @@ class S3FileSystem::Impl : public std::enable_shared_from_this<S3FileSystem::Imp

RETURN_NOT_OK(CheckS3Initialized());

auto ptr = std::make_shared<ObjectInputFile>(holder_, fs->io_context(), path);
auto ptr = std::make_shared<ObjectInputFile>(holder_, fs->io_context(), path, kNoSize,
fs->options().sse_customer_key);
RETURN_NOT_OK(ptr->Init());
return ptr;
}
Expand All @@ -2992,8 +3008,8 @@ class S3FileSystem::Impl : public std::enable_shared_from_this<S3FileSystem::Imp

RETURN_NOT_OK(CheckS3Initialized());

auto ptr =
std::make_shared<ObjectInputFile>(holder_, fs->io_context(), path, info.size());
auto ptr = std::make_shared<ObjectInputFile>(
holder_, fs->io_context(), path, info.size(), fs->options().sse_customer_key);
RETURN_NOT_OK(ptr->Init());
return ptr;
}
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/filesystem/s3fs.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ struct ARROW_EXPORT S3Options {
/// delay between retries.
std::shared_ptr<S3RetryStrategy> retry_strategy;

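/// Customer-provided key for SSE-C (server-side encryption with customer keys).
/// Must be the raw 32-byte key, not its base64 encoding; leave empty to disable SSE-C.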
std::string sse_customer_key;

S3Options();

/// Configure with the default AWS credentials provider chain.
Expand Down
55 changes: 52 additions & 3 deletions cpp/src/arrow/filesystem/s3fs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ using ::arrow::internal::ToChars;
using ::arrow::internal::Zip;
using ::arrow::util::UriEscape;

using ::arrow::fs::internal::CalculateSSECustomerKeyMD5;
using ::arrow::fs::internal::ConnectRetryStrategy;
using ::arrow::fs::internal::ErrorToStatus;
using ::arrow::fs::internal::OutcomeToStatus;
Expand Down Expand Up @@ -530,17 +531,19 @@ class TestS3FS : public S3TestMixin {
}

Result<std::shared_ptr<S3FileSystem>> MakeNewFileSystem(
io::IOContext io_context = io::default_io_context()) {
io::IOContext io_context = io::default_io_context(), bool use_https = false) {
options_.ConfigureAccessKey(minio_->access_key(), minio_->secret_key());
options_.scheme = "http";
options_.scheme = use_https ? "https" : "http";
options_.endpoint_override = minio_->connect_string();
if (!options_.retry_strategy) {
options_.retry_strategy = std::make_shared<ShortRetryStrategy>();
}
return S3FileSystem::Make(options_, io_context);
}

void MakeFileSystem() { ASSERT_OK_AND_ASSIGN(fs_, MakeNewFileSystem()); }
void MakeFileSystem(bool use_https = false) {
ASSERT_OK_AND_ASSIGN(fs_, MakeNewFileSystem(io::default_io_context(), use_https));
}

template <typename Matcher>
void AssertMetadataRoundtrip(const std::string& path,
Expand Down Expand Up @@ -1288,6 +1291,36 @@ TEST_F(TestS3FS, OpenInputFile) {
ASSERT_RAISES(IOError, file->Seek(10));
}

TEST_F(TestS3FS, SSECustomerKeyMatch) {
  // Round trip: an object written with an SSE-C key can be read back with
  // that same key.
  const std::string object_path = "bucket/newfile_with_sse_c";
  options_.sse_customer_key = "12345678123456781234567812345678";
  // HTTPS is needed here, otherwise we get 'InvalidRequest Message: Requests
  // specifying Server Side Encryption with Customer provided keys must be made
  // over a secure connection.'
  MakeFileSystem(/*use_https=*/true);

  // Write a small object
  ASSERT_OK_AND_ASSIGN(auto out_stream, fs_->OpenOutputStream(object_path));
  ASSERT_OK(out_stream->Write("some"));
  ASSERT_OK(out_stream->Close());

  // Read it back and compare
  ASSERT_OK_AND_ASSIGN(auto in_file, fs_->OpenInputFile(object_path));
  ASSERT_OK_AND_ASSIGN(auto contents, in_file->Read(4));
  AssertBufferEqual(*contents, "some");

  ASSERT_OK(RestoreTestBucket());
}

TEST_F(TestS3FS, SSECustomerKeyMismatch) {
  // An object written with one SSE-C key cannot be opened through a
  // filesystem configured with a different key.
  const std::string object_path = "bucket/newfile_with_sse_c";

  // Write phase, using the first key
  options_.sse_customer_key = "12345678123456781234567812345678";
  MakeFileSystem(/*use_https=*/true);
  ASSERT_OK_AND_ASSIGN(auto out_stream, fs_->OpenOutputStream(object_path));
  ASSERT_OK(out_stream->Write("some"));
  ASSERT_OK(out_stream->Close());

  // Read phase: recreate the filesystem with a different key and expect failure
  options_.sse_customer_key = "87654321876543218765432187654321";
  MakeFileSystem(/*use_https=*/true);
  ASSERT_RAISES(IOError, fs_->OpenInputFile(object_path));

  ASSERT_OK(RestoreTestBucket());
}

struct S3OptionsTestParameters {
bool background_writes{false};
bool allow_delayed_open{false};
Expand Down Expand Up @@ -1579,5 +1612,21 @@ TEST(S3GlobalOptions, DefaultsLogLevel) {
}
}

// Checks the key-length validation of CalculateSSECustomerKeyMD5 and the
// base64-encoded MD5 it produces for a binary key.
TEST(CalculateSSECustomerKeyMD5, Sanity) {
  // Any length other than exactly 32 bytes is rejected, including the
  // off-by-one lengths on either side of the valid size.
  ASSERT_RAISES(Invalid, CalculateSSECustomerKeyMD5(""));
  ASSERT_RAISES(Invalid, CalculateSSECustomerKeyMD5(std::string(31, 'k')));
  ASSERT_RAISES(Invalid, CalculateSSECustomerKeyMD5(std::string(33, 'k')));
  ASSERT_RAISES(Invalid,
                CalculateSSECustomerKeyMD5("1234567890123456789012345678901234567890"));

  // Valid case, with some non-ASCII characters and a null byte in the key;
  // all bytes not set explicitly below are zero.
  char sse_customer_key[32] = {};
  sse_customer_key[0] = '\x40';   // '@' character
  sse_customer_key[1] = '\0';     // null byte
  sse_customer_key[2] = '\xFF';   // non-ASCII
  sse_customer_key[31] = '\xFA';  // non-ASCII
  // Construct with an explicit length so the embedded null bytes are kept
  std::string sse_customer_key_string(sse_customer_key, sizeof(sse_customer_key));
  ASSERT_OK_AND_ASSIGN(auto md5, CalculateSSECustomerKeyMD5(sse_customer_key_string));
  ASSERT_EQ(md5, "97FTa6lj0hE7lshKdBy61g==");
}

} // namespace fs
} // namespace arrow

0 comments on commit 3ae52f0

Please sign in to comment.