Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Sigv4 to Role based Access #390

Merged
merged 2 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions osbenchmark/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def __init__(self, hosts, client_options):
self.aws_log_in_dict = self.parse_aws_log_in_params()
masked_client_options["aws_access_key_id"] = "*****"
masked_client_options["aws_secret_access_key"] = "*****"
# session_token is optional and used only for role based access
if self.aws_log_in_dict.get("aws_session_token", None):
masked_client_options["aws_session_token"] = "*****"
self.logger.info("Creating OpenSearch client connected to %s with options [%s]", hosts, masked_client_options)

# we're using an SSL context now and it is not allowed to have use_ssl present in client options anymore
Expand Down Expand Up @@ -206,7 +209,7 @@ def __init__(self, hosts, client_options):
self.logger.info("HTTP basic authentication: off")

if self._is_set(self.client_options, "compressed"):
console.warn("You set the deprecated client option 'compressed. Please use 'http_compress' instead.", logger=self.logger)
console.warn("You set the deprecated client option 'compressed'. Please use 'http_compress' instead.", logger=self.logger)
self.client_options["http_compress"] = self.client_options.pop("compressed")

if self._is_set(self.client_options, "http_compress"):
Expand Down Expand Up @@ -251,12 +254,16 @@ def parse_aws_log_in_params(self):
aws_log_in_dict["aws_secret_access_key"] = os.environ.get("OSB_AWS_SECRET_ACCESS_KEY")
aws_log_in_dict["region"] = os.environ.get("OSB_REGION")
aws_log_in_dict["service"] = os.environ.get("OSB_SERVICE")
# optional: applicable only for role-based access
aws_log_in_dict["aws_session_token"] = os.environ.get("OSB_AWS_SESSION_TOKEN")
# aws log in : option 2) parameters are passed in from command line
elif self.client_options["amazon_aws_log_in"] == "client_option":
aws_log_in_dict["aws_access_key_id"] = self.client_options.get("aws_access_key_id")
aws_log_in_dict["aws_secret_access_key"] = self.client_options.get("aws_secret_access_key")
aws_log_in_dict["region"] = self.client_options.get("region")
aws_log_in_dict["service"] = self.client_options.get("service")
# optional: applicable only for role-based access
aws_log_in_dict["aws_session_token"] = self.client_options.get("aws_session_token")
if (not aws_log_in_dict["aws_access_key_id"] or not aws_log_in_dict["aws_secret_access_key"]
or not aws_log_in_dict["service"] or not aws_log_in_dict["region"]):
self.logger.error("Invalid amazon aws log in parameters, required input aws_access_key_id, "
Expand All @@ -282,7 +289,8 @@ def create(self):
return opensearchpy.OpenSearch(hosts=self.hosts, ssl_context=self.ssl_context, **self.client_options)

credentials = Credentials(access_key=self.aws_log_in_dict["aws_access_key_id"],
secret_key=self.aws_log_in_dict["aws_secret_access_key"])
secret_key=self.aws_log_in_dict["aws_secret_access_key"],
token=self.aws_log_in_dict["aws_session_token"])
aws_auth = opensearchpy.AWSV4SignerAuth(credentials, self.aws_log_in_dict["region"],
self.aws_log_in_dict["service"])
return opensearchpy.OpenSearch(hosts=self.hosts, use_ssl=True, verify_certs=True, http_auth=aws_auth,
Expand Down Expand Up @@ -332,7 +340,8 @@ class BenchmarkAsyncOpenSearch(opensearchpy.AsyncOpenSearch, RequestContextHolde
**self.client_options)

credentials = Credentials(access_key=self.aws_log_in_dict["aws_access_key_id"],
secret_key=self.aws_log_in_dict["aws_secret_access_key"])
secret_key=self.aws_log_in_dict["aws_secret_access_key"],
token=self.aws_log_in_dict["aws_session_token"])
aws_auth = opensearchpy.AWSV4SignerAsyncAuth(credentials, self.aws_log_in_dict["region"],
self.aws_log_in_dict["service"])
return BenchmarkAsyncOpenSearch(hosts=self.hosts,
Expand Down
10 changes: 9 additions & 1 deletion osbenchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(self, cfg):
default_value=None, mandatory=False)
metrics_aws_access_key_id = None
metrics_aws_secret_access_key = None
metrics_aws_session_token = None
metrics_aws_region = None
metrics_aws_service = None

Expand All @@ -196,13 +197,16 @@ def __init__(self, cfg):
default_value=None, mandatory=False)
metrics_aws_secret_access_key = self._config.opts("results_publishing", "datastore.aws_secret_access_key",
default_value=None, mandatory=False)
metrics_aws_session_token = self._config.opts("results_publishing", "datastore.aws_session_token",
default_value=None, mandatory=False)
metrics_aws_region = self._config.opts("results_publishing", "datastore.region",
default_value=None, mandatory=False)
metrics_aws_service = self._config.opts("results_publishing", "datastore.service",
default_value=None, mandatory=False)
elif metrics_amazon_aws_log_in == 'environment':
metrics_aws_access_key_id = os.getenv("OSB_DATASTORE_AWS_ACCESS_KEY_ID", default=None)
metrics_aws_secret_access_key = os.getenv("OSB_DATASTORE_AWS_SECRET_ACCESS_KEY", default=None)
metrics_aws_session_token = os.getenv("OSB_DATASTORE_AWS_SESSION_TOKEN", default=None)
metrics_aws_region = os.getenv("OSB_DATASTORE_REGION", default=None)
metrics_aws_service = os.getenv("OSB_DATASTORE_SERVICE", default=None)

Expand Down Expand Up @@ -254,14 +258,18 @@ def __init__(self, cfg):
client_options["basic_auth_user"] = user
client_options["basic_auth_password"] = password

#add options for aws user login: pass in aws access key id, aws secret access key, service and region on command
# add options for aws user login:
# pass in aws access key id, aws secret access key, aws session token, service and region on command
if metrics_amazon_aws_log_in is not None:
client_options["amazon_aws_log_in"] = 'client_option'
client_options["aws_access_key_id"] = metrics_aws_access_key_id
client_options["aws_secret_access_key"] = metrics_aws_secret_access_key
client_options["service"] = metrics_aws_service
client_options["region"] = metrics_aws_region

if metrics_aws_session_token:
client_options["aws_session_token"] = metrics_aws_session_token

factory = client.OsClientFactory(hosts=[{"host": host, "port": port}], client_options=client_options)
self._client = factory.create()

Expand Down
62 changes: 39 additions & 23 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_create_https_connection_unverified_certificate(self, mocked_load_cert_c
@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_with_aws_creds(self, mocked_load_cert_chain):
hosts = [{"host": "localhost", "port": 9200}]
client_options = {
user_based_client_options = {
"use_ssl": True,
"timeout": 120,
"amazon_aws_log_in": 'client_option',
Expand All @@ -259,34 +259,50 @@ def test_create_https_connection_with_aws_creds(self, mocked_load_cert_chain):
"region": "us-east-1",
"verify_certs": True
}
# make a copy so we can verify later that the factory did not modify it
original_client_options = dict(client_options)

role_based_client_options = dict(user_based_client_options)
role_based_client_options["aws_session_token"] = "dummy_token"

client_options_list = [
user_based_client_options,
role_based_client_options
]

logger = logging.getLogger("osbenchmark.client")
with mock.patch.object(logger, "info") as mocked_info_logger:
f = client.OsClientFactory(hosts, client_options)
mocked_info_logger.assert_has_calls([
mock.call("SSL support: on"),
mock.call("SSL certificate verification: on"),
mock.call("SSL client authentication: off")
])

assert not mocked_load_cert_chain.called, "ssl_context.load_cert_chain should not have been called as we have not supplied " \
"client certs"
for client_options in client_options_list:
# make a copy so we can verify later that the factory did not modify it
original_client_options = dict(client_options)

self.assertEqual(hosts, f.hosts)
self.assertTrue(f.ssl_context.check_hostname)
self.assertEqual(ssl.CERT_REQUIRED, f.ssl_context.verify_mode)
with mock.patch.object(logger, "info") as mocked_info_logger:
f = client.OsClientFactory(hosts, client_options)

self.assertEqual("https", f.client_options["scheme"])
self.assertIn("timeout", f.client_options)
self.assertIn("aws_access_key_id", f.client_options)
self.assertIn("aws_secret_access_key", f.client_options)
self.assertIn("amazon_aws_log_in", f.client_options)
self.assertIn("service", f.client_options)
self.assertIn("region", f.client_options)
mocked_info_logger.assert_has_calls([
mock.call("SSL support: on"),
mock.call("SSL certificate verification: on"),
mock.call("SSL client authentication: off")
])

assert not mocked_load_cert_chain.called, "ssl_context.load_cert_chain should not have been called as we have not supplied " \
"client certs"

self.assertEqual(hosts, f.hosts)
self.assertTrue(f.ssl_context.check_hostname)
self.assertEqual(ssl.CERT_REQUIRED, f.ssl_context.verify_mode)

self.assertEqual("https", f.client_options["scheme"])
self.assertIn("timeout", f.client_options)
self.assertIn("aws_access_key_id", f.client_options)
self.assertIn("aws_secret_access_key", f.client_options)
self.assertIn("amazon_aws_log_in", f.client_options)
self.assertIn("service", f.client_options)
self.assertIn("region", f.client_options)

if "aws_session_token" in original_client_options:
self.assertIn("aws_session_token", f.client_options)

self.assertDictEqual(original_client_options, client_options)

self.assertDictEqual(original_client_options, client_options)

@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_unverified_certificate_present_client_certificates(self, mocked_load_cert_chain):
Expand Down
67 changes: 47 additions & 20 deletions tests/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from osbenchmark.metrics import GlobalStatsCalculator
from osbenchmark.workload import Task, Operation, TestProcedure, Workload

AWS_ACCESS_KEY_ID_LENGTH = 12
AWS_SECRET_ACCESS_KEY_LENGTH = 40
AWS_SESSION_TOKEN_LENGTH = 752

class MockClientFactory:
def __init__(self, cfg):
Expand Down Expand Up @@ -239,24 +242,33 @@ def test_config_opts_parsing_aws_creds_with_env(self, client_OsClientfactory):
}
self.config_opts_parsing_aws_creds("environment", override_datastore=override_config)

# verify config parsing is successful when all required parameters are present
config_opts = self.config_opts_parsing_aws_creds("environment")

expected_client_options = {
"use_ssl": True,
"timeout": 120,
"amazon_aws_log_in": 'client_option',
"aws_access_key_id": config_opts["_datastore_aws_access_key_id"],
"aws_secret_access_key": config_opts["_datastore_aws_secret_access_key"],
"service": config_opts["_datastore_aws_service"],
"region": config_opts["_datastore_aws_region"],
"verify_certs": config_opts["_datastore_verify_certs"]
}

client_OsClientfactory.assert_called_with(
hosts=[{"host": config_opts["_datastore_host"], "port": config_opts["_datastore_port"]}],
client_options=expected_client_options
)
# validate client_options when session_token is passed
enable_role_access = [False, True]
for role_based in enable_role_access:
# verify config parsing is successful when all required parameters are present
config_opts = self.config_opts_parsing_aws_creds("environment", role_based=role_based)

expected_client_options = {
"use_ssl": True,
"timeout": 120,
"amazon_aws_log_in": 'client_option',
"aws_access_key_id": config_opts["_datastore_aws_access_key_id"],
"aws_secret_access_key": config_opts["_datastore_aws_secret_access_key"],
"service": config_opts["_datastore_aws_service"],
"region": config_opts["_datastore_aws_region"],
"verify_certs": config_opts["_datastore_verify_certs"]
}

if role_based:
expected_client_options["aws_session_token"] = config_opts["_datastore_aws_session_token"]

client_OsClientfactory.assert_called_with(
hosts=[{"host": config_opts["_datastore_host"], "port": config_opts["_datastore_port"]}],
client_options=expected_client_options
)


def config_opts_parsing(self, password_configuration):
cfg = config.Config()
Expand Down Expand Up @@ -302,7 +314,7 @@ def config_opts_parsing(self, password_configuration):
"_datastore_verify_certs": _datastore_verify_certs
}

def config_opts_parsing_aws_creds(self, configuration_source, override_datastore=None):
def config_opts_parsing_aws_creds(self, configuration_source, override_datastore=None, role_based=False):
if override_datastore is None:
override_datastore = {}
cfg = config.Config()
Expand All @@ -314,11 +326,16 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore
_datastore_password = ""
_datastore_verify_certs = random.choice([True, False])
_datastore_amazon_aws_log_in = configuration_source
_datastore_aws_access_key_id = "".join([random.choice(string.digits) for _ in range(12)])
_datastore_aws_secret_access_key = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(40)])
_datastore_aws_access_key_id = "".join([random.choice(string.digits) for _ in range(AWS_ACCESS_KEY_ID_LENGTH)])
_datastore_aws_secret_access_key = "".join([random.choice(string.ascii_letters + string.digits) \
for _ in range(AWS_SECRET_ACCESS_KEY_LENGTH)])
_datastore_aws_service = random.choice(['es', 'aoss'])
_datastore_aws_region = random.choice(['us-east-1', 'eu-west-1'])

# optional
_datastore_aws_session_token = "".join([random.choice(string.ascii_letters + string.digits) \
for _ in range(AWS_SESSION_TOKEN_LENGTH)])

cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.host", _datastore_host)
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.port", _datastore_port)
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.secure", _datastore_secure)
Expand All @@ -334,12 +351,17 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore
_datastore_aws_secret_access_key)
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.service", _datastore_aws_service)
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.region", _datastore_aws_region)
if role_based:
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.aws_session_token", _datastore_aws_session_token)
elif _datastore_amazon_aws_log_in == 'environment':
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setenv("OSB_DATASTORE_AWS_ACCESS_KEY_ID", _datastore_aws_access_key_id)
monkeypatch.setenv("OSB_DATASTORE_AWS_SECRET_ACCESS_KEY", _datastore_aws_secret_access_key)
monkeypatch.setenv("OSB_DATASTORE_SERVICE", _datastore_aws_service)
monkeypatch.setenv("OSB_DATASTORE_REGION", _datastore_aws_region)
if role_based:
monkeypatch.setenv("OSB_DATASTORE_AWS_SESSION_TOKEN", _datastore_aws_session_token)


if not _datastore_verify_certs:
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.ssl.verification_mode", "none")
Expand Down Expand Up @@ -375,7 +397,7 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore
assert e.message == missing_aws_credentials_message
return

return {
response = {
"_datastore_user": _datastore_user,
"_datastore_host": _datastore_host,
"_datastore_password": _datastore_password,
Expand All @@ -387,6 +409,11 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore
"_datastore_aws_region": _datastore_aws_region
}

if role_based:
response["_datastore_aws_session_token"] = _datastore_aws_session_token

return response

def test_raises_sytem_setup_error_on_connection_problems(self):
def raise_connection_error():
raise opensearchpy.exceptions.ConnectionError("unit-test")
Expand Down
Loading