From 7c61d11d9328e73c6ea3a2ef810daddb0cb3fe38 Mon Sep 17 00:00:00 2001 From: Massimo Battestini Date: Thu, 26 Oct 2023 10:30:28 +0100 Subject: [PATCH] Adds unit tests for AWS profile change (#86) Signed-off-by: Massimo Battestini --- provider/provider.go | 15 +- provider/provider_test.go | 221 ++++++++++++++---- .../api_assume_role_response.xml | 25 ++ provider/test-fixtures/test_aws_config | 3 - provider/test-fixtures/test_aws_credentials | 7 + 5 files changed, 216 insertions(+), 55 deletions(-) create mode 100644 provider/test-fixtures/api_assume_role_response.xml delete mode 100644 provider/test-fixtures/test_aws_config create mode 100644 provider/test-fixtures/test_aws_credentials diff --git a/provider/provider.go b/provider/provider.go index 53288bc..033a427 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -395,8 +395,8 @@ func getClient(conf *ProviderConf) (*elastic7.Client, error) { return client, nil } -func assumeRoleCredentials(region, roleARN, roleExternalID, profile string) *awscredentials.Credentials { - sessOpts := awsSessionOptions(region) +func assumeRoleCredentials(region, roleARN, roleExternalID, profile string, endpoint string) *awscredentials.Credentials { + sessOpts := awsSessionOptions(region, endpoint) if profile != "" { sessOpts.Profile = profile } @@ -415,7 +415,7 @@ func assumeRoleCredentials(region, roleARN, roleExternalID, profile string) *aws return awscredentials.NewChainCredentials([]awscredentials.Provider{assumeRoleProvider}) } -func awsSessionOptions(region string) awssession.Options { +func awsSessionOptions(region string, endpoint string) awssession.Options { return awssession.Options{ Config: aws.Config{ Region: aws.String(region), @@ -430,13 +430,14 @@ func awsSessionOptions(region string) awssession.Options { // it fail with Credential error // https://github.com/aws/aws-sdk-go/issues/2914 HTTPClient: &http.Client{Timeout: 10 * time.Second}, + Endpoint: aws.String(endpoint), }, SharedConfigState: awssession.SharedConfigEnable, } } -func awsSession(region string, conf *ProviderConf) *awssession.Session { - sessOpts := awsSessionOptions(region) +func awsSession(region string, conf *ProviderConf, endpoint string) *awssession.Session { + sessOpts := awsSessionOptions(region, endpoint) // 1. access keys take priority // 2. next is an assume role configuration @@ -450,7 +451,7 @@ func awsSession(region string, conf *ProviderConf) *awssession.Session { if conf.awsAssumeRoleExternalID == "" { conf.awsAssumeRoleExternalID = "" } - sessOpts.Config.Credentials = assumeRoleCredentials(region, conf.awsAssumeRoleArn, conf.awsAssumeRoleExternalID, conf.awsProfile) + sessOpts.Config.Credentials = assumeRoleCredentials(region, conf.awsAssumeRoleArn, conf.awsAssumeRoleExternalID, conf.awsProfile, endpoint) } else if conf.awsProfile != "" { sessOpts.Profile = conf.awsProfile } @@ -473,7 +474,7 @@ func awsSession(region string, conf *ProviderConf) *awssession.Session { } func awsHttpClient(region string, conf *ProviderConf, headers map[string]string) (*http.Client, error) { - session := awsSession(region, conf) + session := awsSession(region, conf, "") // Call Get() to ensure concurrency safe retrieval of credentials. Since the // client is created in many go routines, this synchronizes it. _, err := session.Config.Credentials.Get() diff --git a/provider/provider_test.go b/provider/provider_test.go index dae0016..865a31e 100644 --- a/provider/provider_test.go +++ b/provider/provider_test.go @@ -2,7 +2,10 @@ package provider import ( "context" + "net/http" + "net/http/httptest" "os" + "strings" "testing" "github.com/aws/aws-sdk-go/aws/credentials" @@ -82,13 +85,13 @@ func TestAWSCredsManualKey(t *testing.T) { os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET") // first, check that if we set aws_profile with aws_access_key_id - the latter takes precedence - testConfig := map[string]interface{}{ - "aws_profile": namedProfile, - "aws_access_key": manualAccessKeyID, - "aws_secret_key": "MANUAL_SECRET_KEY", + testConfig := &ProviderConf{ + awsAccessKeyId: manualAccessKeyID, + awsSecretAccessKey: "MANUAL_SECRET_KEY", + awsProfile: namedProfile, } - creds := getCreds(t, testRegion, testConfig) + creds := getCreds(t, testRegion, testConfig, "") if creds.AccessKeyID != manualAccessKeyID { t.Errorf("access key id should have been %s (we got %s)", manualAccessKeyID, creds.AccessKeyID) @@ -106,16 +109,16 @@ func TestAWSCredsNamedProfile(t *testing.T) { namedProfile := "testing" profileAccessKeyID := "PROFILE_ACCESS_KEY" - os.Setenv("AWS_CONFIG_FILE", "./test-fixtures/test_aws_config") // set config file so we can ensure the profile we want to test exists + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists os.Setenv("AWS_SDK_LOAD_CONFIG", "1") os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID) os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET") - testConfig := map[string]interface{}{ - "aws_profile": namedProfile, + testConfig := &ProviderConf{ + awsProfile: namedProfile, } - creds := getCreds(t, testRegion, testConfig) + creds := getCreds(t, testRegion, testConfig, "") if creds.AccessKeyID != profileAccessKeyID { t.Errorf("access key id should have been %s (we got %s)", profileAccessKeyID, creds.AccessKeyID) @@ -123,7 +126,7 @@ func TestAWSCredsNamedProfile(t *testing.T) { os.Unsetenv("AWS_ACCESS_KEY_ID") os.Unsetenv("AWS_SECRET_ACCESS_KEY") - os.Unsetenv("AWS_CONFIG_FILE") + os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE") os.Unsetenv("AWS_SDK_LOAD_CONFIG") } @@ -132,7 +135,6 @@ func TestAWSCredsNamedProfile(t *testing.T) { // 2. No configuration provided to the provider // // This tests that: we get the credentials from the environment variables (ie: from the default credentials provider chain) - func TestAWSCredsEnv(t *testing.T) { envAccessKeyID := "ENV_ACCESS_KEY" testRegion := "us-east-1" @@ -140,9 +142,9 @@ func TestAWSCredsEnv(t *testing.T) { os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID) os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET") - testConfig := map[string]interface{}{} + testConfig := &ProviderConf{} - creds := getCreds(t, testRegion, testConfig) + creds := getCreds(t, testRegion, testConfig, "") if creds.AccessKeyID != envAccessKeyID { t.Errorf("access key id should have been %s (we got %s)", envAccessKeyID, creds.AccessKeyID) @@ -152,6 +154,11 @@ func TestAWSCredsEnv(t *testing.T) { os.Unsetenv("AWS_SECRET_ACCESS_KEY") } +// Given: +// 1. AWS profile is specified via environment variables +// 2. No configuration provided to the provider +// +// This tests that: we get the credentials from the environment variables (ie: from the default credentials provider chain) func TestAWSCredsEnvNamedProfile(t *testing.T) { namedProfile := "testing" testRegion := "us-east-1" @@ -159,65 +166,140 @@ func TestAWSCredsEnvNamedProfile(t *testing.T) { os.Setenv("AWS_PROFILE", namedProfile) os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_CONFIG_FILE", "./test-fixtures/test_aws_config") // set config file so we can ensure the profile we want to test exists + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists - testConfig := map[string]interface{}{} + testConfig := &ProviderConf{} - creds := getCreds(t, testRegion, testConfig) + creds := getCreds(t, testRegion, testConfig, "") if creds.AccessKeyID != profileAccessKeyID { t.Errorf("access key id should have been %s (we got %s)", profileAccessKeyID, creds.AccessKeyID) } os.Unsetenv("AWS_PROFILE") - os.Unsetenv("AWS_CONFIG_FILE") + os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE") os.Unsetenv("AWS_SDK_LOAD_CONFIG") } // Given: -// 1. An AWS role ARN is specified -// 2. No additional AWS configuration is provided to the provider +// 1. AWS credentials are specified via environment variables +// 2. An AWS role ARN and External ID are specified via the provider configuration // -// This tests that: we can safely generate a session. Note we cannot get the credentials, because that requires connecting to AWS +// This tests that: we can get the credentials after having assumed the given role from the specified AWS credentials. func TestAWSCredsAssumeRole(t *testing.T) { + envAccessKeyID := "ENV_ACCESS_KEY" testRegion := "us-east-1" + assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR" + assumeRoleExternalId := "secret_id" + assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE" + + os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID) + os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET") - testConfig := map[string]interface{}{ - "aws_assume_role_arn": "test_arn", - "aws_assume_role_external_id": "secret_id", + server := mockServer{ + ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml", + ExpectedAccessKeyId: envAccessKeyID, + ExpectedRoleArn: assumeRoleArn, + ExpectedExternalId: assumeRoleExternalId, } - testConfigData := schema.TestResourceDataRaw(t, Provider().Schema, testConfig) + server.Start(t) + defer server.Stop() - conf := &ProviderConf{ - awsAssumeRoleArn: testConfigData.Get("aws_assume_role_arn").(string), - awsAssumeRoleExternalID: testConfigData.Get("aws_assume_role_external_id").(string), + testConfig := &ProviderConf{ + awsAssumeRoleArn: assumeRoleArn, + awsAssumeRoleExternalID: assumeRoleExternalId, } - s := awsSession(testRegion, conf) - if s == nil { - t.Fatalf("awsSession returned nil") + + creds := getCreds(t, testRegion, testConfig, server.Endpoint) + + if creds.AccessKeyID != assumeRoleAccessKeyID { + t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID) } + + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") } -func getCreds(t *testing.T, region string, config map[string]interface{}) credentials.Value { - awsAccessKey := "" - awsSecretKey := "" - awsProfile := "" - if val, ok := config["aws_access_key"]; ok { - awsAccessKey = val.(string) +// Given: +// 1. An AWS profile, role ARN and External ID are specified via the provider configuration +// +// This tests that: we can get the credentials after having assumed the given role from the specified profile. +func TestAWSCredsAssumeRoleFromProfile(t *testing.T) { + testRegion := "us-east-1" + assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR" + assumeRoleExternalId := "secret_id" + namedProfile := "testing" + assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE" + + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists + + server := mockServer{ + ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml", + ExpectedAccessKeyId: "PROFILE_ACCESS_KEY", // from the test-fixture config file + ExpectedRoleArn: assumeRoleArn, + ExpectedExternalId: assumeRoleExternalId, } - if val, ok := config["aws_secret_key"]; ok { - awsSecretKey = val.(string) + + server.Start(t) + defer server.Stop() + + testConfig := &ProviderConf{ + awsAssumeRoleArn: assumeRoleArn, + awsAssumeRoleExternalID: assumeRoleExternalId, + awsProfile: namedProfile, } - if val, ok := config["aws_profile"]; ok { - awsProfile = val.(string) + + creds := getCreds(t, testRegion, testConfig, server.Endpoint) + + if creds.AccessKeyID != assumeRoleAccessKeyID { + t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID) } - conf := &ProviderConf{ - awsAccessKeyId: awsAccessKey, - awsSecretAccessKey: awsSecretKey, - awsProfile: awsProfile, + os.Unsetenv("AWS_SDK_LOAD_CONFIG") + os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE") +} + +// Given: +// 1. An AWS role ARN and External ID are specified via the provider configuration +// +// This tests that: we can get the credentials after having assumed the given role from the default profile. +func TestAWSCredsAssumeRoleFromDefaultProfile(t *testing.T) { + testRegion := "us-east-1" + assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR" + assumeRoleExternalId := "secret_id" + assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE" + + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists + + server := mockServer{ + ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml", + ExpectedAccessKeyId: "PROFILE_DEFAULT_ACCESS_KEY", // from the test-fixture config file + ExpectedRoleArn: assumeRoleArn, + ExpectedExternalId: assumeRoleExternalId, + } + + server.Start(t) + defer server.Stop() + + testConfig := &ProviderConf{ + awsAssumeRoleArn: assumeRoleArn, + awsAssumeRoleExternalID: assumeRoleExternalId, } - s := awsSession(region, conf) + + creds := getCreds(t, testRegion, testConfig, server.Endpoint) + + if creds.AccessKeyID != assumeRoleAccessKeyID { + t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID) + } + + os.Unsetenv("AWS_SDK_LOAD_CONFIG") + os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE") +} + +func getCreds(t *testing.T, region string, config *ProviderConf, endpoint string) credentials.Value { + s := awsSession(region, config, endpoint) if s == nil { t.Fatalf("awsSession returned nil") } @@ -227,3 +309,52 @@ func getCreds(t *testing.T, region string, config map[string]interface{}) creden } return creds } + +type mockServer struct { + ResponseFixturePath string + ExpectedAccessKeyId string + ExpectedRoleArn string + ExpectedExternalId string + Endpoint string + server *httptest.Server +} + +func (s *mockServer) Start(t *testing.T) { + s.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + auth := r.Header.Get("Authorization") + if !strings.Contains(auth, s.ExpectedAccessKeyId) { + t.Errorf("Could not find expected access key id %s in authorization header %s", s.ExpectedAccessKeyId, auth) + } + + err := r.ParseForm() + if err != nil { + t.Errorf("Error while parsing form: %v", err) + } + + if r.PostForm.Get("RoleArn") != s.ExpectedRoleArn { + t.Errorf("expected RoleArn to be equal to %s, but got %s", s.ExpectedRoleArn, r.PostForm.Get("RoleArn")) + } + + if r.PostForm.Get("ExternalId") != s.ExpectedExternalId { + t.Errorf("expected ExternalId to be equal to %s, but got %s", s.ExpectedExternalId, r.PostForm.Get("ExternalId")) + } + + response, err := os.ReadFile(s.ResponseFixturePath) + if err != nil { + t.Errorf("Error while reading mockResponse %v", err) + } + + w.WriteHeader(http.StatusOK) + _, err = w.Write(response) + if err != nil { + t.Errorf("Error while writing mock server response %v", err) + } + })) + + s.Endpoint = s.server.URL +} + +func (s *mockServer) Stop() { + s.server.Close() +} diff --git a/provider/test-fixtures/api_assume_role_response.xml b/provider/test-fixtures/api_assume_role_response.xml new file mode 100644 index 0000000..4cd2930 --- /dev/null +++ b/provider/test-fixtures/api_assume_role_response.xml @@ -0,0 +1,25 @@ + + + Alice + + arn:aws:sts::123456789012:assumed-role/demo/TestAR + ARO123EXAMPLE123:TestAR + + + ASIAIOSFODNN7EXAMPLE + wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY + + AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW + LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd + QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU + 9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz + +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA== + + 2019-11-09T13:34:41Z + + 6 + + + c6104cbe-af31-11e0-8154-cbc7ccf896c7 + + \ No newline at end of file diff --git a/provider/test-fixtures/test_aws_config b/provider/test-fixtures/test_aws_config deleted file mode 100644 index 1020481..0000000 --- a/provider/test-fixtures/test_aws_config +++ /dev/null @@ -1,3 +0,0 @@ -[profile testing] -aws_access_key_id = PROFILE_ACCESS_KEY -aws_secret_access_key = PROFILE_SECRET_KEY diff --git a/provider/test-fixtures/test_aws_credentials b/provider/test-fixtures/test_aws_credentials new file mode 100644 index 0000000..88b33b7 --- /dev/null +++ b/provider/test-fixtures/test_aws_credentials @@ -0,0 +1,7 @@ +[testing] +aws_access_key_id = PROFILE_ACCESS_KEY +aws_secret_access_key = PROFILE_SECRET_KEY + +[default] +aws_access_key_id = PROFILE_DEFAULT_ACCESS_KEY +aws_secret_access_key = PROFILE_DEFAULT_SECRET_KEY