Skip to content

Commit

Permalink
Adds unit tests for AWS profile change (#86)
Browse files Browse the repository at this point in the history
Signed-off-by: Massimo Battestini <[email protected]>
  • Loading branch information
massimob76 committed Oct 26, 2023
1 parent b113bb3 commit e1b7494
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 55 deletions.
15 changes: 8 additions & 7 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ func getClient(conf *ProviderConf) (*elastic7.Client, error) {
return client, nil
}

func assumeRoleCredentials(region, roleARN, roleExternalID, profile string) *awscredentials.Credentials {
sessOpts := awsSessionOptions(region)
func assumeRoleCredentials(region, roleARN, roleExternalID, profile string, endpoint string) *awscredentials.Credentials {
sessOpts := awsSessionOptions(region, endpoint)
if profile != "" {
sessOpts.Profile = profile
}
Expand All @@ -415,7 +415,7 @@ func assumeRoleCredentials(region, roleARN, roleExternalID, profile string) *aws
return awscredentials.NewChainCredentials([]awscredentials.Provider{assumeRoleProvider})
}

func awsSessionOptions(region string) awssession.Options {
func awsSessionOptions(region string, endpoint string) awssession.Options {
return awssession.Options{
Config: aws.Config{
Region: aws.String(region),
Expand All @@ -430,13 +430,14 @@ func awsSessionOptions(region string) awssession.Options {
// it fail with Credential error
// https://github.com/aws/aws-sdk-go/issues/2914
HTTPClient: &http.Client{Timeout: 10 * time.Second},
Endpoint: aws.String(endpoint),
},
SharedConfigState: awssession.SharedConfigEnable,
}
}

func awsSession(region string, conf *ProviderConf) *awssession.Session {
sessOpts := awsSessionOptions(region)
func awsSession(region string, conf *ProviderConf, endpoint string) *awssession.Session {
sessOpts := awsSessionOptions(region, endpoint)

// 1. access keys take priority
// 2. next is an assume role configuration
Expand All @@ -450,7 +451,7 @@ func awsSession(region string, conf *ProviderConf) *awssession.Session {
if conf.awsAssumeRoleExternalID == "" {
conf.awsAssumeRoleExternalID = ""
}
sessOpts.Config.Credentials = assumeRoleCredentials(region, conf.awsAssumeRoleArn, conf.awsAssumeRoleExternalID, conf.awsProfile)
sessOpts.Config.Credentials = assumeRoleCredentials(region, conf.awsAssumeRoleArn, conf.awsAssumeRoleExternalID, conf.awsProfile, endpoint)
} else if conf.awsProfile != "" {
sessOpts.Profile = conf.awsProfile
}
Expand All @@ -473,7 +474,7 @@ func awsSession(region string, conf *ProviderConf) *awssession.Session {
}

func awsHttpClient(region string, conf *ProviderConf, headers map[string]string) (*http.Client, error) {
session := awsSession(region, conf)
session := awsSession(region, conf, "")
// Call Get() to ensure concurrency safe retrieval of credentials. Since the
// client is created in many go routines, this synchronizes it.
_, err := session.Config.Credentials.Get()
Expand Down
218 changes: 173 additions & 45 deletions provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package provider

import (
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"

"github.com/aws/aws-sdk-go/aws/credentials"
Expand Down Expand Up @@ -82,13 +85,13 @@ func TestAWSCredsManualKey(t *testing.T) {
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

// first, check that if we set aws_profile with aws_access_key_id - the latter takes precedence
testConfig := map[string]interface{}{
"aws_profile": namedProfile,
"aws_access_key": manualAccessKeyID,
"aws_secret_key": "MANUAL_SECRET_KEY",
testConfig := &ProviderConf{
awsAccessKeyId: manualAccessKeyID,
awsSecretAccessKey: "MANUAL_SECRET_KEY",
awsProfile: namedProfile,
}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != manualAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", manualAccessKeyID, creds.AccessKeyID)
Expand All @@ -106,24 +109,24 @@ func TestAWSCredsNamedProfile(t *testing.T) {
namedProfile := "testing"
profileAccessKeyID := "PROFILE_ACCESS_KEY"

os.Setenv("AWS_CONFIG_FILE", "./test-fixtures/test_aws_config") // set config file so we can ensure the profile we want to test exists
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

testConfig := map[string]interface{}{
"aws_profile": namedProfile,
testConfig := &ProviderConf{
awsProfile: namedProfile,
}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != profileAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", profileAccessKeyID, creds.AccessKeyID)
}

os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Unsetenv("AWS_CONFIG_FILE")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
}

Expand All @@ -132,17 +135,16 @@ func TestAWSCredsNamedProfile(t *testing.T) {
// 2. No configuration provided to the provider
//
// This tests that: we get the credentials from the environment variables (ie: from the default credentials provider chain)

func TestAWSCredsEnv(t *testing.T) {
envAccessKeyID := "ENV_ACCESS_KEY"
testRegion := "us-east-1"

os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

testConfig := map[string]interface{}{}
testConfig := &ProviderConf{}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != envAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", envAccessKeyID, creds.AccessKeyID)
Expand All @@ -152,72 +154,152 @@ func TestAWSCredsEnv(t *testing.T) {
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}

// Given:
// 1. AWS profile is specified via environment variables
// 2. No configuration provided to the provider
//
// This tests that: we get the credentials from the environment variables (ie: from the default credentials provider chain)
func TestAWSCredsEnvNamedProfile(t *testing.T) {
namedProfile := "testing"
testRegion := "us-east-1"
profileAccessKeyID := "PROFILE_ACCESS_KEY"

os.Setenv("AWS_PROFILE", namedProfile)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_CONFIG_FILE", "./test-fixtures/test_aws_config") // set config file so we can ensure the profile we want to test exists
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

testConfig := map[string]interface{}{}
testConfig := &ProviderConf{}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != profileAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", profileAccessKeyID, creds.AccessKeyID)
}
os.Unsetenv("AWS_PROFILE")
os.Unsetenv("AWS_CONFIG_FILE")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
}

// Given:
// 1. An AWS role ARN is specified
// 2. No additional AWS configuration is provided to the provider
// 1. AWS credentials are specified via environment variables
// 2. An AWS role ARN and External ID are specified via the provider configuration
//
// This tests that: we can safely generate a session. Note we cannot get the credentials, because that requires connecting to AWS
// This tests that: we can get the credentials after having assumed the given role from the specified AWS credentials.
func TestAWSCredsAssumeRole(t *testing.T) {
envAccessKeyID := "ENV_ACCESS_KEY"
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

testConfig := map[string]interface{}{
"aws_assume_role_arn": "test_arn",
"aws_assume_role_external_id": "secret_id",
os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: envAccessKeyID,
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}

testConfigData := schema.TestResourceDataRaw(t, Provider().Schema, testConfig)
server.Start(t)
defer server.Stop()

conf := &ProviderConf{
awsAssumeRoleArn: testConfigData.Get("aws_assume_role_arn").(string),
awsAssumeRoleExternalID: testConfigData.Get("aws_assume_role_external_id").(string),
testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
}
s := awsSession(testRegion, conf)
if s == nil {
t.Fatalf("awsSession returned nil")

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}

os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}

func getCreds(t *testing.T, region string, config map[string]interface{}) credentials.Value {
awsAccessKey := ""
awsSecretKey := ""
awsProfile := ""
if val, ok := config["aws_access_key"]; ok {
awsAccessKey = val.(string)
// Given:
// 1. An AWS profile, role ARN and External ID are specified via the provider configuration
//
// This tests that: we can get the credentials after having assumed the given role from the specified profile.
func TestAWSCredsAssumeRoleFromProfile(t *testing.T) {
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
namedProfile := "testing"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: "PROFILE_ACCESS_KEY", // from the test-fixture config file
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}
if val, ok := config["aws_secret_key"]; ok {
awsSecretKey = val.(string)

server.Start(t)
defer server.Stop()

testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
awsProfile: namedProfile,
}
if val, ok := config["aws_profile"]; ok {
awsProfile = val.(string)

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}

conf := &ProviderConf{
awsAccessKeyId: awsAccessKey,
awsSecretAccessKey: awsSecretKey,
awsProfile: awsProfile,
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
}

// Given:
// 1. An AWS role ARN and External ID are specified via the provider configuration
//
// This tests that: we can get the credentials after having assumed the given role from the default profile.
func TestAWSCredsAssumeRoleFromDefaultProfile(t *testing.T) {
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: "PROFILE_DEFAULT_ACCESS_KEY", // from the test-fixture config file
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}

server.Start(t)
defer server.Stop()

testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
}

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}
s := awsSession(region, conf)

os.Unsetenv("AWS_SDK_LOAD_CONFIG")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
}

func getCreds(t *testing.T, region string, config *ProviderConf, endpoint string) credentials.Value {
s := awsSession(region, config, endpoint)
if s == nil {
t.Fatalf("awsSession returned nil")
}
Expand All @@ -227,3 +309,49 @@ func getCreds(t *testing.T, region string, config map[string]interface{}) creden
}
return creds
}

type mockServer struct {
ResponseFixturePath string
ExpectedAccessKeyId string
ExpectedRoleArn string
ExpectedExternalId string
Endpoint string
server *httptest.Server
}

func (s *mockServer) Start(t *testing.T) {
s.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

auth := r.Header.Get("Authorization")
if !strings.Contains(auth, s.ExpectedAccessKeyId) {
t.Errorf("Could not find expected access key id %s in authorization header %s", s.ExpectedAccessKeyId, auth)
}

err := r.ParseForm()
if err != nil {
t.Errorf("Error while parsing form: %v", err)
}

if r.PostForm.Get("RoleArn") != s.ExpectedRoleArn {
t.Errorf("expected RoleArn to be equal to %s, but got %s", s.ExpectedRoleArn, r.PostForm.Get("RoleArn"))
}

if r.PostForm.Get("ExternalId") != s.ExpectedExternalId {
t.Errorf("expected ExternalId to be equal to %s, but got %s", s.ExpectedExternalId, r.PostForm.Get("ExternalId"))
}

response, err := os.ReadFile(s.ResponseFixturePath)
if err != nil {
t.Errorf("Error while reading mockResponse %v", err)
}

w.WriteHeader(http.StatusOK)
w.Write(response)
}))

s.Endpoint = s.server.URL
}

func (s *mockServer) Stop() {
s.server.Close()
}
25 changes: 25 additions & 0 deletions provider/test-fixtures/api_assume_role_response.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleResult>
<SourceIdentity>Alice</SourceIdentity>
<AssumedRoleUser>
<Arn>arn:aws:sts::123456789012:assumed-role/demo/TestAR</Arn>
<AssumedRoleId>ARO123EXAMPLE123:TestAR</AssumedRoleId>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>ASIAIOSFODNN7EXAMPLE</AccessKeyId>
<SecretAccessKey>wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY</SecretAccessKey>
<SessionToken>
AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW
LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd
QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU
9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz
+scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==
</SessionToken>
<Expiration>2019-11-09T13:34:41Z</Expiration>
</Credentials>
<PackedPolicySize>6</PackedPolicySize>
</AssumeRoleResult>
<ResponseMetadata>
<RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>
</ResponseMetadata>
</AssumeRoleResponse>
Loading

0 comments on commit e1b7494

Please sign in to comment.