Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom headers in s3 #4400

Merged
merged 5 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/src/unit-cppapi-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*
* The MIT License
*
* @copyright Copyright (c) 2017-2021 TileDB, Inc.
* @copyright Copyright (c) 2017-2023 TileDB, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down
23 changes: 23 additions & 0 deletions test/src/unit-s3-no-multipart.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,27 @@ TEST_CASE_METHOD(
auto badbuffer = (char*)malloc(11000000);
CHECK(!(s3_.write(URI(badfile), badbuffer, 11000000).ok()));
}

TEST_CASE_METHOD(
S3DirectFx, "Validate vfs.s3.custom_headers.*", "[s3][custom-headers]") {
Config cfg = set_config_params();

// Check the edge case of a key matching the ConfigIter prefix.
REQUIRE(cfg.set("vfs.s3.custom_headers.", "").ok());

// Set an unexpected value for Content-MD5, which minio should reject
REQUIRE(cfg.set("vfs.s3.custom_headers.Content-MD5", "unexpected").ok());

// Recreate a new S3 client because config is not dynamic
tiledb::sm::S3 s3{&g_helper_stats, &thread_pool_, cfg};
auto uri = URI(TEST_DIR + "writefailure");

// This is a buffered write, which is why it returns ok.
auto st = s3.write(uri, "Validate s3 custom headers", 26);
REQUIRE(st.ok());

auto matcher = Catch::Matchers::ContainsSubstring(
"The Content-Md5 you specified is not valid.");
REQUIRE_THROWS_WITH(s3.flush_object(uri), matcher);
}
#endif
4 changes: 4 additions & 0 deletions tiledb/api/c_api/config/config_api_external.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,10 @@ TILEDB_EXPORT void tiledb_config_free(tiledb_config_t** config) TILEDB_NOEXCEPT;
* The scale factor for exponential backoff when connecting to S3.
* Any `long` value is acceptable. <br>
* **Default**: 25
* - `vfs.s3.custom_headers.*` <br>
* (Optional) Prefix for custom headers on s3 requests. For each custom
* header, use "vfs.s3.custom_headers.header_key" = "header_value" <br>
* **Optional. No Default**
* - `vfs.s3.logging_level` <br>
* The AWS SDK logging level. This is a process-global setting. The
* configuration of the most recently constructed context will set
Expand Down
4 changes: 4 additions & 0 deletions tiledb/sm/cpp_api/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,10 @@ class Config {
* The scale factor for exponential backofff when connecting to S3.
* Any `long` value is acceptable. <br>
* **Default**: 25
* - `vfs.s3.custom_headers.*` <br>
* (Optional) Prefix for custom headers on s3 requests. For each custom
* header, use "vfs.s3.custom_headers.header_key" = "header_value" <br>
* **Optional. No Default**
* - `vfs.s3.logging_level` <br>
* The AWS SDK logging level. This is a process-global setting. The
* configuration of the most recently constructed context will set
Expand Down
73 changes: 70 additions & 3 deletions tiledb/sm/filesystem/s3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "tiledb/common/logger.h"
#include "tiledb/common/unique_rwlock.h"
#include "tiledb/platform/platform.h"
#include "tiledb/sm/config/config_iter.h"
#include "tiledb/sm/global_state/unit_test_config.h"
#include "tiledb/sm/misc/tdb_math.h"
#include "tiledb/sm/misc/utils.h"
Expand Down Expand Up @@ -222,6 +223,71 @@ class S3Exception : public StatusException {
}
};

S3Parameters::Headers S3Parameters::load_headers(const Config& cfg) {
Headers ret;
auto iter = ConfigIter(cfg, constants::s3_header_prefix);
for (; !iter.end(); iter.next()) {
auto key = iter.param();
if (key.size() == 0) {
continue;
}
ret[key] = iter.value();
}
return ret;
}

/**
* Helper class which overrides Aws::S3::S3Client to set headers from
* vfs.s3.custom_headers.*
*
* @note The AWS SDK does not have a common base class, so there's no
* straightforward way to add a header to a unique request before submitting
* it. This class exists solely to override the S3Client, adding custom headers
* upon building the Http Request.
*/
class TileDBS3Client : public Aws::S3::S3Client {
public:
TileDBS3Client(
const S3Parameters& s3_params,
const Aws::Client::ClientConfiguration& client_config,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads,
bool use_virtual_addressing)
: Aws::S3::S3Client(client_config, sign_payloads, use_virtual_addressing)
, params_(s3_params) {
}

TileDBS3Client(
const S3Parameters& s3_params,
const std::shared_ptr<Aws::Auth::AWSCredentialsProvider>& creds,
const Aws::Client::ClientConfiguration& client_config,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads,
bool use_virtual_addressing)
: Aws::S3::S3Client(
creds, client_config, sign_payloads, use_virtual_addressing)
, params_(s3_params) {
}

virtual void BuildHttpRequest(
const Aws::AmazonWebServiceRequest& request,
const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest)
const override {
S3Client::BuildHttpRequest(request, httpRequest);

// Set header from S3Parameters custom headers
for (auto& [key, val] : params_.custom_headers_) {
httpRequest->SetHeaderValue(key, val);
}
}

protected:
/**
* A reference to the S3 configuration parameters, which stores the header.
*
* @note Until the removal of init_client(), this must be const-qualified.
*/
const S3Parameters& params_;
};

/* ********************************* */
/* GLOBAL VARIABLES */
/* ********************************* */
Expand Down Expand Up @@ -1490,16 +1556,17 @@ Status S3::init_client() const {
static std::mutex static_client_init_mtx;
{
std::lock_guard<std::mutex> static_lck(static_client_init_mtx);

if (credentials_provider_ == nullptr) {
client_ = make_shared<Aws::S3::S3Client>(
client_ = make_shared<TileDBS3Client>(
HERE(),
s3_params_,
*client_config_,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
s3_params_.use_virtual_addressing_);
} else {
client_ = make_shared<Aws::S3::S3Client>(
client_ = make_shared<TileDBS3Client>(
HERE(),
s3_params_,
credentials_provider_,
*client_config_,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
Expand Down
15 changes: 14 additions & 1 deletion tiledb/sm/filesystem/s3.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "tiledb/sm/stats/stats.h"
#include "uri.h"

#undef GetObject
#include <aws/core/Aws.h>
#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/core/client/ClientConfiguration.h>
Expand Down Expand Up @@ -85,6 +86,8 @@ using tiledb::common::filesystem::directory_entry;

namespace tiledb::sm {

class TileDBS3Client;

/**
* The s3-specific configuration parameters.
*
Expand All @@ -93,6 +96,9 @@ namespace tiledb::sm {
* @note Not all vfs.s3 config parameters are present in this struct.
*/
struct S3Parameters {
/** Stores parsed custom headers from config. */
using Headers = std::unordered_map<std::string, std::string>;

S3Parameters() = delete;

S3Parameters(const Config& config)
Expand Down Expand Up @@ -129,6 +135,7 @@ struct S3Parameters {
config.get<int64_t>("vfs.s3.connect_max_tries", Config::must_find))
, connect_scale_factor_(config.get<int64_t>(
"vfs.s3.connect_scale_factor", Config::must_find))
, custom_headers_(load_headers(config))
, logging_level_(
config.get<std::string>("vfs.s3.logging_level", Config::must_find))
, request_timeout_ms_(
Expand Down Expand Up @@ -161,6 +168,9 @@ struct S3Parameters {

~S3Parameters() = default;

/** Load all custom headers from the given config. */
static Headers load_headers(const Config& cfg);

/** The AWS region. */
std::string region_;

Expand Down Expand Up @@ -215,6 +225,9 @@ struct S3Parameters {
/** The scale factor for exponential backoff when connecting to S3. */
int64_t connect_scale_factor_;

/** Custom headers to add to all s3 requests. */
Headers custom_headers_;

/** Process-global AWS SDK logging level. */
std::string logging_level_;

Expand Down Expand Up @@ -861,7 +874,7 @@ class S3 {
* The lazily-initialized S3 client. This is mutable so that nominally const
* functions can call init_client().
*/
mutable shared_ptr<Aws::S3::S3Client> client_;
mutable shared_ptr<TileDBS3Client> client_;

/** The AWS credetial provider. */
mutable shared_ptr<Aws::Auth::AWSCredentialsProvider> credentials_provider_;
Expand Down
5 changes: 4 additions & 1 deletion tiledb/sm/misc/constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*
* The MIT License
*
* @copyright Copyright (c) 2017-2022 TileDB, Inc.
* @copyright Copyright (c) 2017-2023 TileDB, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -731,6 +731,9 @@ const unsigned int gcs_attempt_sleep_ms = 1000;
/** An allocation tag used for logging. */
const std::string s3_allocation_tag = "TileDB";

/** The config key prefix for S3 custom headers. */
const std::string s3_header_prefix = "vfs.s3.custom_headers.";

/** Prefix indicating a special name reserved by TileDB. */
const std::string special_name_prefix = "__";

Expand Down
5 changes: 4 additions & 1 deletion tiledb/sm/misc/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*
* The MIT License
*
* @copyright Copyright (c) 2017-2022 TileDB, Inc.
* @copyright Copyright (c) 2017-2023 TileDB, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -699,6 +699,9 @@ extern const unsigned int gcs_attempt_sleep_ms;
/** An allocation tag used for logging. */
extern const std::string s3_allocation_tag;

/** The S3 custom headers config key prefix. */
extern const std::string s3_header_prefix;

/** Prefix indicating a special name reserved by TileDB. */
extern const std::string special_name_prefix;

Expand Down