diff --git a/osbenchmark/client.py b/osbenchmark/client.py index 810992073..f779a5999 100644 --- a/osbenchmark/client.py +++ b/osbenchmark/client.py @@ -135,6 +135,9 @@ def __init__(self, hosts, client_options): self.aws_log_in_dict = self.parse_aws_log_in_params() masked_client_options["aws_access_key_id"] = "*****" masked_client_options["aws_secret_access_key"] = "*****" + # session_token is optional and used only for role based access + if self.aws_log_in_dict.get("aws_session_token", None): + masked_client_options["aws_session_token"] = "*****" self.logger.info("Creating OpenSearch client connected to %s with options [%s]", hosts, masked_client_options) # we're using an SSL context now and it is not allowed to have use_ssl present in client options anymore @@ -206,7 +209,7 @@ def __init__(self, hosts, client_options): self.logger.info("HTTP basic authentication: off") if self._is_set(self.client_options, "compressed"): - console.warn("You set the deprecated client option 'compressed‘. Please use 'http_compress' instead.", logger=self.logger) + console.warn("You set the deprecated client option 'compressed'. Please use 'http_compress' instead.", logger=self.logger) self.client_options["http_compress"] = self.client_options.pop("compressed") if self._is_set(self.client_options, "http_compress"): @@ -251,12 +254,16 @@ def parse_aws_log_in_params(self): aws_log_in_dict["aws_secret_access_key"] = os.environ.get("OSB_AWS_SECRET_ACCESS_KEY") aws_log_in_dict["region"] = os.environ.get("OSB_REGION") aws_log_in_dict["service"] = os.environ.get("OSB_SERVICE") + # optional: applicable only for role-based access + aws_log_in_dict["aws_session_token"] = os.environ.get("OSB_AWS_SESSION_TOKEN") # aws log in : option 2) parameters are passed in from command line elif self.client_options["amazon_aws_log_in"] == "client_option": aws_log_in_dict["aws_access_key_id"] = self.client_options.get("aws_access_key_id") aws_log_in_dict["aws_secret_access_key"] = self.client_options.get("aws_secret_access_key") aws_log_in_dict["region"] = self.client_options.get("region") aws_log_in_dict["service"] = self.client_options.get("service") + # optional: applicable only for role-based access + aws_log_in_dict["aws_session_token"] = self.client_options.get("aws_session_token") if (not aws_log_in_dict["aws_access_key_id"] or not aws_log_in_dict["aws_secret_access_key"] or not aws_log_in_dict["service"] or not aws_log_in_dict["region"]): self.logger.error("Invalid amazon aws log in parameters, required input aws_access_key_id, " @@ -282,7 +289,8 @@ def create(self): return opensearchpy.OpenSearch(hosts=self.hosts, ssl_context=self.ssl_context, **self.client_options) credentials = Credentials(access_key=self.aws_log_in_dict["aws_access_key_id"], - secret_key=self.aws_log_in_dict["aws_secret_access_key"]) + secret_key=self.aws_log_in_dict["aws_secret_access_key"], + token=self.aws_log_in_dict["aws_session_token"]) aws_auth = opensearchpy.AWSV4SignerAuth(credentials, self.aws_log_in_dict["region"], self.aws_log_in_dict["service"]) return opensearchpy.OpenSearch(hosts=self.hosts, use_ssl=True, verify_certs=True, http_auth=aws_auth, @@ -332,7 +340,8 @@ class BenchmarkAsyncOpenSearch(opensearchpy.AsyncOpenSearch, RequestContextHolde **self.client_options) credentials = Credentials(access_key=self.aws_log_in_dict["aws_access_key_id"], - secret_key=self.aws_log_in_dict["aws_secret_access_key"]) + secret_key=self.aws_log_in_dict["aws_secret_access_key"], + token=self.aws_log_in_dict["aws_session_token"]) aws_auth = opensearchpy.AWSV4SignerAsyncAuth(credentials, self.aws_log_in_dict["region"], self.aws_log_in_dict["service"]) return BenchmarkAsyncOpenSearch(hosts=self.hosts, diff --git a/osbenchmark/metrics.py b/osbenchmark/metrics.py index a296db0a0..a316b35ca 100644 --- a/osbenchmark/metrics.py +++ b/osbenchmark/metrics.py @@ -188,6 +188,7 @@ def __init__(self, cfg): default_value=None, mandatory=False) metrics_aws_access_key_id = None metrics_aws_secret_access_key = None + metrics_aws_session_token = None metrics_aws_region = None metrics_aws_service = None @@ -196,6 +197,8 @@ def __init__(self, cfg): default_value=None, mandatory=False) metrics_aws_secret_access_key = self._config.opts("results_publishing", "datastore.aws_secret_access_key", default_value=None, mandatory=False) + metrics_aws_session_token = self._config.opts("results_publishing", "datastore.aws_session_token", + default_value=None, mandatory=False) metrics_aws_region = self._config.opts("results_publishing", "datastore.region", default_value=None, mandatory=False) metrics_aws_service = self._config.opts("results_publishing", "datastore.service", @@ -203,6 +206,7 @@ def __init__(self, cfg): elif metrics_amazon_aws_log_in == 'environment': metrics_aws_access_key_id = os.getenv("OSB_DATASTORE_AWS_ACCESS_KEY_ID", default=None) metrics_aws_secret_access_key = os.getenv("OSB_DATASTORE_AWS_SECRET_ACCESS_KEY", default=None) + metrics_aws_session_token = os.getenv("OSB_DATASTORE_AWS_SESSION_TOKEN", default=None) metrics_aws_region = os.getenv("OSB_DATASTORE_REGION", default=None) metrics_aws_service = os.getenv("OSB_DATASTORE_SERVICE", default=None) @@ -254,7 +258,8 @@ def __init__(self, cfg): client_options["basic_auth_user"] = user client_options["basic_auth_password"] = password - #add options for aws user login: pass in aws access key id, aws secret access key, service and region on command + # add options for aws user login: + # pass in aws access key id, aws secret access key, aws session token, service and region on command if metrics_amazon_aws_log_in is not None: client_options["amazon_aws_log_in"] = 'client_option' client_options["aws_access_key_id"] = metrics_aws_access_key_id @@ -262,6 +267,9 @@ def __init__(self, cfg): client_options["service"] = metrics_aws_service client_options["region"] = metrics_aws_region + if metrics_aws_session_token: + client_options["aws_session_token"] = metrics_aws_session_token + factory = client.OsClientFactory(hosts=[{"host": host, "port": port}], client_options=client_options) self._client = factory.create() diff --git a/tests/client_test.py b/tests/client_test.py index 33de0dfd1..12f8a3f49 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -249,7 +249,7 @@ def test_create_https_connection_unverified_certificate(self, mocked_load_cert_c @mock.patch.object(ssl.SSLContext, "load_cert_chain") def test_create_https_connection_with_aws_creds(self, mocked_load_cert_chain): hosts = [{"host": "localhost", "port": 9200}] - client_options = { + user_based_client_options = { "use_ssl": True, "timeout": 120, "amazon_aws_log_in": 'client_option', @@ -259,34 +259,50 @@ def test_create_https_connection_with_aws_creds(self, mocked_load_cert_chain): "region": "us-east-1", "verify_certs": True } - # make a copy so we can verify later that the factory did not modify it - original_client_options = dict(client_options) + + role_based_client_options = dict(user_based_client_options) + role_based_client_options["aws_session_token"] = "dummy_token" + + client_options_list = [ + user_based_client_options, + role_based_client_options + ] logger = logging.getLogger("osbenchmark.client") - with mock.patch.object(logger, "info") as mocked_info_logger: - f = client.OsClientFactory(hosts, client_options) - mocked_info_logger.assert_has_calls([ - mock.call("SSL support: on"), - mock.call("SSL certificate verification: on"), - mock.call("SSL client authentication: off") - ]) - assert not mocked_load_cert_chain.called, "ssl_context.load_cert_chain should not have been called as we have not supplied " \ - "client certs" + for client_options in client_options_list: + # make a copy so we can verify later that the factory did not modify it + original_client_options = dict(client_options) - self.assertEqual(hosts, f.hosts) - self.assertTrue(f.ssl_context.check_hostname) - self.assertEqual(ssl.CERT_REQUIRED, f.ssl_context.verify_mode) + with mock.patch.object(logger, "info") as mocked_info_logger: + f = client.OsClientFactory(hosts, client_options) - self.assertEqual("https", f.client_options["scheme"]) - self.assertIn("timeout", f.client_options) - self.assertIn("aws_access_key_id", f.client_options) - self.assertIn("aws_secret_access_key", f.client_options) - self.assertIn("amazon_aws_log_in", f.client_options) - self.assertIn("service", f.client_options) - self.assertIn("region", f.client_options) + mocked_info_logger.assert_has_calls([ + mock.call("SSL support: on"), + mock.call("SSL certificate verification: on"), + mock.call("SSL client authentication: off") + ]) + + assert not mocked_load_cert_chain.called, "ssl_context.load_cert_chain should not have been called as we have not supplied " \ + "client certs" + + self.assertEqual(hosts, f.hosts) + self.assertTrue(f.ssl_context.check_hostname) + self.assertEqual(ssl.CERT_REQUIRED, f.ssl_context.verify_mode) + + self.assertEqual("https", f.client_options["scheme"]) + self.assertIn("timeout", f.client_options) + self.assertIn("aws_access_key_id", f.client_options) + self.assertIn("aws_secret_access_key", f.client_options) + self.assertIn("amazon_aws_log_in", f.client_options) + self.assertIn("service", f.client_options) + self.assertIn("region", f.client_options) + + if "aws_session_token" in original_client_options: + self.assertIn("aws_session_token", f.client_options) + + self.assertDictEqual(original_client_options, client_options) - self.assertDictEqual(original_client_options, client_options) @mock.patch.object(ssl.SSLContext, "load_cert_chain") def test_create_https_connection_unverified_certificate_present_client_certificates(self, mocked_load_cert_chain): diff --git a/tests/metrics_test.py b/tests/metrics_test.py index 0fa190775..6ee80e2d6 100644 --- a/tests/metrics_test.py +++ b/tests/metrics_test.py @@ -42,6 +42,9 @@ from osbenchmark.metrics import GlobalStatsCalculator from osbenchmark.workload import Task, Operation, TestProcedure, Workload +AWS_ACCESS_KEY_ID_LENGTH = 12 +AWS_SECRET_ACCESS_KEY_LENGTH = 40 +AWS_SESSION_TOKEN_LENGTH = 752 class MockClientFactory: def __init__(self, cfg): @@ -239,24 +242,33 @@ def test_config_opts_parsing_aws_creds_with_env(self, client_OsClientfactory): } self.config_opts_parsing_aws_creds("environment", override_datastore=override_config) - # verify config parsing is successful when all required parameters are present - config_opts = self.config_opts_parsing_aws_creds("environment") - expected_client_options = { - "use_ssl": True, - "timeout": 120, - "amazon_aws_log_in": 'client_option', - "aws_access_key_id": config_opts["_datastore_aws_access_key_id"], - "aws_secret_access_key": config_opts["_datastore_aws_secret_access_key"], - "service": config_opts["_datastore_aws_service"], - "region": config_opts["_datastore_aws_region"], - "verify_certs": config_opts["_datastore_verify_certs"] - } - client_OsClientfactory.assert_called_with( - hosts=[{"host": config_opts["_datastore_host"], "port": config_opts["_datastore_port"]}], - client_options=expected_client_options - ) + # validate client_options when session_token is passed + enable_role_access = [False, True] + for role_based in enable_role_access: + # verify config parsing is successful when all required parameters are present + config_opts = self.config_opts_parsing_aws_creds("environment", role_based=role_based) + + expected_client_options = { + "use_ssl": True, + "timeout": 120, + "amazon_aws_log_in": 'client_option', + "aws_access_key_id": config_opts["_datastore_aws_access_key_id"], + "aws_secret_access_key": config_opts["_datastore_aws_secret_access_key"], + "service": config_opts["_datastore_aws_service"], + "region": config_opts["_datastore_aws_region"], + "verify_certs": config_opts["_datastore_verify_certs"] + } + + if role_based: + expected_client_options["aws_session_token"] = config_opts["_datastore_aws_session_token"] + + client_OsClientfactory.assert_called_with( + hosts=[{"host": config_opts["_datastore_host"], "port": config_opts["_datastore_port"]}], + client_options=expected_client_options + ) + def config_opts_parsing(self, password_configuration): cfg = config.Config() @@ -302,7 +314,7 @@ def config_opts_parsing(self, password_configuration): "_datastore_verify_certs": _datastore_verify_certs } - def config_opts_parsing_aws_creds(self, configuration_source, override_datastore=None): + def config_opts_parsing_aws_creds(self, configuration_source, override_datastore=None, role_based=False): if override_datastore is None: override_datastore = {} cfg = config.Config() @@ -314,11 +326,14 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore _datastore_password = "" _datastore_verify_certs = random.choice([True, False]) _datastore_amazon_aws_log_in = configuration_source - _datastore_aws_access_key_id = "".join([random.choice(string.digits) for _ in range(12)]) - _datastore_aws_secret_access_key = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(40)]) + _datastore_aws_access_key_id = "".join([random.choice(string.digits) for _ in range(AWS_ACCESS_KEY_ID_LENGTH)]) + _datastore_aws_secret_access_key = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(AWS_SECRET_ACCESS_KEY_LENGTH)]) _datastore_aws_service = random.choice(['es', 'aoss']) _datastore_aws_region = random.choice(['us-east-1', 'eu-west-1']) + # optional + _datastore_aws_session_token = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(AWS_SESSION_TOKEN_LENGTH)]) + cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.host", _datastore_host) cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.port", _datastore_port) cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.secure", _datastore_secure) @@ -334,12 +349,17 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore _datastore_aws_secret_access_key) cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.service", _datastore_aws_service) cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.region", _datastore_aws_region) + if role_based: + cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.aws_session_token", _datastore_aws_session_token) elif _datastore_amazon_aws_log_in == 'environment': monkeypatch = pytest.MonkeyPatch() monkeypatch.setenv("OSB_DATASTORE_AWS_ACCESS_KEY_ID", _datastore_aws_access_key_id) monkeypatch.setenv("OSB_DATASTORE_AWS_SECRET_ACCESS_KEY", _datastore_aws_secret_access_key) monkeypatch.setenv("OSB_DATASTORE_SERVICE", _datastore_aws_service) monkeypatch.setenv("OSB_DATASTORE_REGION", _datastore_aws_region) + if role_based: + monkeypatch.setenv("OSB_DATASTORE_AWS_SESSION_TOKEN", _datastore_aws_session_token) + if not _datastore_verify_certs: cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.ssl.verification_mode", "none") @@ -375,7 +395,7 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore assert e.message == missing_aws_credentials_message return - return { + response = { "_datastore_user": _datastore_user, "_datastore_host": _datastore_host, "_datastore_password": _datastore_password, @@ -387,6 +407,11 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore "_datastore_aws_region": _datastore_aws_region } + if role_based: + response["_datastore_aws_session_token"] = _datastore_aws_session_token + + return response + def test_raises_sytem_setup_error_on_connection_problems(self): def raise_connection_error(): raise opensearchpy.exceptions.ConnectionError("unit-test")