diff --git a/cpp/src/arrow/filesystem/s3_internal.h b/cpp/src/arrow/filesystem/s3_internal.h index 54da3d5987e8a..dcffdc5739884 100644 --- a/cpp/src/arrow/filesystem/s3_internal.h +++ b/cpp/src/arrow/filesystem/s3_internal.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -29,11 +30,13 @@ #include #include #include +#include #include #include "arrow/filesystem/filesystem.h" #include "arrow/filesystem/s3fs.h" #include "arrow/status.h" +#include "arrow/util/base64.h" #include "arrow/util/logging.h" #include "arrow/util/print.h" #include "arrow/util/string.h" @@ -291,6 +294,43 @@ class ConnectRetryStrategy : public Aws::Client::RetryStrategy { int32_t max_retry_duration_; }; +/// \brief calculate the MD5 of the input sse-c key (raw key, not base64 encoded) +/// \param sse_customer_key is the input sse key +/// \return the base64 encoded MD5 for the input key +inline Result CalculateSSECustomerKeyMD5( + const std::string& sse_customer_key) { + // the key needs to be 256 bits (32 bytes) according to + // https://docs.aws.amazon.com/AmazonS3/latest/userguide/ServerSideEncryptionCustomerKeys.html#specifying-s3-c-encryption + if (sse_customer_key.length() != 32) { + return Status::Invalid("32 bytes sse-c key is expected"); + } + + // Convert the raw binary key to an Aws::String + Aws::String sse_customer_key_aws_string(sse_customer_key.data(), + sse_customer_key.length()); + + // Compute the MD5 hash of the raw binary key + Aws::Utils::ByteBuffer sse_customer_key_md5 = + Aws::Utils::HashingUtils::CalculateMD5(sse_customer_key_aws_string); + + // Base64-encode the MD5 hash + return arrow::util::base64_encode(std::string_view( + reinterpret_cast(sse_customer_key_md5.GetUnderlyingData()), + sse_customer_key_md5.GetLength())); +} + +template +Status SetSSECustomerKey(S3RequestType& request, const std::string& sse_customer_key) { + if (sse_customer_key.empty()) { + return Status::OK(); // do nothing if the sse_customer_key is not configured + } + ARROW_ASSIGN_OR_RAISE(auto md5, internal::CalculateSSECustomerKeyMD5(sse_customer_key)); + request.SetSSECustomerKeyMD5(md5); + request.SetSSECustomerKey(arrow::util::base64_encode(sse_customer_key)); + request.SetSSECustomerAlgorithm("AES256"); + return Status::OK(); +} + } // namespace internal } // namespace fs } // namespace arrow diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 77b111f61bf4c..ceaa9fff69b6b 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -160,6 +160,7 @@ using internal::IsNotFound; using internal::OutcomeToResult; using internal::OutcomeToStatus; using internal::S3Backend; +using internal::SetSSECustomerKey; using internal::ToAwsString; using internal::ToURLEncodedAwsString; @@ -439,7 +440,8 @@ bool S3Options::Equals(const S3Options& other) const { background_writes == other.background_writes && allow_bucket_creation == other.allow_bucket_creation && allow_bucket_deletion == other.allow_bucket_deletion && - default_metadata_equals && GetAccessKey() == other.GetAccessKey() && + sse_customer_key == other.sse_customer_key && default_metadata_equals && + GetAccessKey() == other.GetAccessKey() && GetSecretKey() == other.GetSecretKey() && GetSessionToken() == other.GetSessionToken()); } @@ -1292,11 +1294,14 @@ Aws::IOStreamFactory AwsWriteableStreamFactory(void* data, int64_t nbytes) { } Result GetObjectRange(Aws::S3::S3Client* client, - const S3Path& path, int64_t start, - int64_t length, void* out) { + const S3Path& path, + const std::string& sse_customer_key, + int64_t start, int64_t length, + void* out) { S3Model::GetObjectRequest req; req.SetBucket(ToAwsString(path.bucket)); req.SetKey(ToAwsString(path.key)); + RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key)); req.SetRange(ToAwsString(FormatRange(start, length))); req.SetResponseStreamFactory(AwsWriteableStreamFactory(out, length)); return OutcomeToResult("GetObject", client->GetObject(req)); @@ -1433,11 +1438,13 @@ bool IsDirectory(std::string_view key, const S3Model::HeadObjectResult& result) class ObjectInputFile final : public io::RandomAccessFile { public: ObjectInputFile(std::shared_ptr holder, const io::IOContext& io_context, - const S3Path& path, int64_t size = kNoSize) + const S3Path& path, int64_t size = kNoSize, + const std::string& sse_customer_key = "") : holder_(std::move(holder)), io_context_(io_context), path_(path), - content_length_(size) {} + content_length_(size), + sse_customer_key_(sse_customer_key) {} Status Init() { // Issue a HEAD Object to get the content-length and ensure any @@ -1450,6 +1457,7 @@ class ObjectInputFile final : public io::RandomAccessFile { S3Model::HeadObjectRequest req; req.SetBucket(ToAwsString(path_.bucket)); req.SetKey(ToAwsString(path_.key)); + RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key_)); ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); auto outcome = client_lock.Move()->HeadObject(req); @@ -1534,9 +1542,9 @@ class ObjectInputFile final : public io::RandomAccessFile { // Read the desired range of bytes ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); - ARROW_ASSIGN_OR_RAISE( - S3Model::GetObjectResult result, - GetObjectRange(client_lock.get(), path_, position, nbytes, out)); + ARROW_ASSIGN_OR_RAISE(S3Model::GetObjectResult result, + GetObjectRange(client_lock.get(), path_, sse_customer_key_, + position, nbytes, out)); auto& stream = result.GetBody(); stream.ignore(nbytes); @@ -1584,6 +1592,7 @@ class ObjectInputFile final : public io::RandomAccessFile { int64_t pos_ = 0; int64_t content_length_ = kNoSize; std::shared_ptr metadata_; + std::string sse_customer_key_; }; // Upload size per part. While AWS and Minio support different sizes for each @@ -1620,7 +1629,8 @@ class ObjectOutputStream final : public io::OutputStream { metadata_(metadata), default_metadata_(options.default_metadata), background_writes_(options.background_writes), - allow_delayed_open_(options.allow_delayed_open) {} + allow_delayed_open_(options.allow_delayed_open), + sse_customer_key_(options.sse_customer_key) {} ~ObjectOutputStream() override { // For compliance with the rest of the IO stack, Close rather than Abort, @@ -1668,6 +1678,7 @@ class ObjectOutputStream final : public io::OutputStream { S3Model::CreateMultipartUploadRequest req; req.SetBucket(ToAwsString(path_.bucket)); req.SetKey(ToAwsString(path_.key)); + RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key_)); RETURN_NOT_OK(SetMetadataInRequest(&req)); auto outcome = client_lock.Move()->CreateMultipartUpload(req); @@ -1769,6 +1780,7 @@ class ObjectOutputStream final : public io::OutputStream { S3Model::CompleteMultipartUploadRequest req; req.SetBucket(ToAwsString(path_.bucket)); req.SetKey(ToAwsString(path_.key)); + RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key_)); req.SetUploadId(multipart_upload_id_); req.SetMultipartUpload(std::move(completed_upload)); @@ -1950,6 +1962,7 @@ class ObjectOutputStream final : public io::OutputStream { req.SetKey(ToAwsString(path_.key)); req.SetBody(std::make_shared(data, nbytes)); req.SetContentLength(nbytes); + RETURN_NOT_OK(SetSSECustomerKey(req, sse_customer_key_)); if (!background_writes_) { req.SetBody(std::make_shared(data, nbytes)); @@ -2171,6 +2184,7 @@ class ObjectOutputStream final : public io::OutputStream { Future<> pending_uploads_completed = Future<>::MakeFinished(Status::OK()); }; std::shared_ptr upload_state_; + std::string sse_customer_key_; }; // This function assumes info->path() is already set @@ -2321,6 +2335,7 @@ class S3FileSystem::Impl : public std::enable_shared_from_this(holder_, fs->io_context(), path); + auto ptr = std::make_shared(holder_, fs->io_context(), path, kNoSize, + fs->options().sse_customer_key); RETURN_NOT_OK(ptr->Init()); return ptr; } @@ -2992,8 +3008,8 @@ class S3FileSystem::Impl : public std::enable_shared_from_this(holder_, fs->io_context(), path, info.size()); + auto ptr = std::make_shared( + holder_, fs->io_context(), path, info.size(), fs->options().sse_customer_key); RETURN_NOT_OK(ptr->Init()); return ptr; } diff --git a/cpp/src/arrow/filesystem/s3fs.h b/cpp/src/arrow/filesystem/s3fs.h index 85d5ff8fed553..46aa5430ed914 100644 --- a/cpp/src/arrow/filesystem/s3fs.h +++ b/cpp/src/arrow/filesystem/s3fs.h @@ -196,6 +196,9 @@ struct ARROW_EXPORT S3Options { /// delay between retries. std::shared_ptr retry_strategy; + /// the SSE-C customized key (raw 32 bytes key). + std::string sse_customer_key; + S3Options(); /// Configure with the default AWS credentials provider chain. diff --git a/cpp/src/arrow/filesystem/s3fs_test.cc b/cpp/src/arrow/filesystem/s3fs_test.cc index 82a7d6e546ef3..d7f1bf8c89b19 100644 --- a/cpp/src/arrow/filesystem/s3fs_test.cc +++ b/cpp/src/arrow/filesystem/s3fs_test.cc @@ -80,6 +80,7 @@ using ::arrow::internal::ToChars; using ::arrow::internal::Zip; using ::arrow::util::UriEscape; +using ::arrow::fs::internal::CalculateSSECustomerKeyMD5; using ::arrow::fs::internal::ConnectRetryStrategy; using ::arrow::fs::internal::ErrorToStatus; using ::arrow::fs::internal::OutcomeToStatus; @@ -530,9 +531,9 @@ class TestS3FS : public S3TestMixin { } Result> MakeNewFileSystem( - io::IOContext io_context = io::default_io_context()) { + io::IOContext io_context = io::default_io_context(), bool use_https = false) { options_.ConfigureAccessKey(minio_->access_key(), minio_->secret_key()); - options_.scheme = "http"; + options_.scheme = use_https ? "https" : "http"; options_.endpoint_override = minio_->connect_string(); if (!options_.retry_strategy) { options_.retry_strategy = std::make_shared(); @@ -540,7 +541,9 @@ class TestS3FS : public S3TestMixin { return S3FileSystem::Make(options_, io_context); } - void MakeFileSystem() { ASSERT_OK_AND_ASSIGN(fs_, MakeNewFileSystem()); } + void MakeFileSystem(bool use_https = false) { + ASSERT_OK_AND_ASSIGN(fs_, MakeNewFileSystem(io::default_io_context(), use_https)); + } template void AssertMetadataRoundtrip(const std::string& path, @@ -1288,6 +1291,36 @@ TEST_F(TestS3FS, OpenInputFile) { ASSERT_RAISES(IOError, file->Seek(10)); } +TEST_F(TestS3FS, SSECustomerKeyMatch) { + // normal write/read with correct SSEC key + std::shared_ptr stream; + options_.sse_customer_key = "12345678123456781234567812345678"; + MakeFileSystem(true); // need to use https, otherwise get 'InvalidRequest Message: + // Requests specifying Server Side Encryption with Customer + // provided keys must be made over a secure connection.' + ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile_with_sse_c")); + ASSERT_OK(stream->Write("some")); + ASSERT_OK(stream->Close()); + ASSERT_OK_AND_ASSIGN(auto file, fs_->OpenInputFile("bucket/newfile_with_sse_c")); + ASSERT_OK_AND_ASSIGN(auto buf, file->Read(4)); + AssertBufferEqual(*buf, "some"); + ASSERT_OK(RestoreTestBucket()); +} + +TEST_F(TestS3FS, SSECustomerKeyMismatch) { + std::shared_ptr stream; + options_.sse_customer_key = "12345678123456781234567812345678"; + MakeFileSystem(true); + ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile_with_sse_c")); + ASSERT_OK(stream->Write("some")); + ASSERT_OK(stream->Close()); + + options_.sse_customer_key = "87654321876543218765432187654321"; + MakeFileSystem(true); + ASSERT_RAISES(IOError, fs_->OpenInputFile("bucket/newfile_with_sse_c")); + ASSERT_OK(RestoreTestBucket()); +} + struct S3OptionsTestParameters { bool background_writes{false}; bool allow_delayed_open{false}; @@ -1579,5 +1612,21 @@ TEST(S3GlobalOptions, DefaultsLogLevel) { } } +TEST(CalculateSSECustomerKeyMD5, Sanity) { + ASSERT_RAISES(Invalid, CalculateSSECustomerKeyMD5("")); // invalid length + ASSERT_RAISES(Invalid, + CalculateSSECustomerKeyMD5( + "1234567890123456789012345678901234567890")); // invalid length + // valid case, with some non-ASCII character and a null byte in the sse_customer_key + char sse_customer_key[32] = {}; + sse_customer_key[0] = '\x40'; // '@' character + sse_customer_key[1] = '\0'; // null byte + sse_customer_key[2] = '\xFF'; // non-ASCII + sse_customer_key[31] = '\xFA'; // non-ASCII + std::string sse_customer_key_string(sse_customer_key, sizeof(sse_customer_key)); + ASSERT_OK_AND_ASSIGN(auto md5, CalculateSSECustomerKeyMD5(sse_customer_key_string)) + ASSERT_STREQ(md5.c_str(), "97FTa6lj0hE7lshKdBy61g=="); // valid case +} + } // namespace fs } // namespace arrow