Skip to content

Commit

Permalink
Merge pull request crossplane-contrib#1003 from ulucinar/optimize-sts…
Browse files Browse the repository at this point in the history
…-calls

refactor external client to make single getAWSConfig call per connect
  • Loading branch information
erhancagirici authored Dec 8, 2023
2 parents ee4e51b + 5672e7c commit b8b5ae3
Showing 1 changed file with 21 additions and 44 deletions.
65 changes: 21 additions & 44 deletions internal/clients/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ const (
keyRoleArn = "role_arn"
keySessionName = "session_name"
keyWebIdentityTokenFile = "web_identity_token_file"
keyAssumeRole = "assume_role"
keyTags = "tags"
keyTransitiveTagKeys = "transitive_tag_keys"
keyExternalID = "external_id"
keySkipCredsValidation = "skip_credentials_validation"
keyS3UsePathStyle = "s3_use_path_style"
keySkipMetadataApiCheck = "skip_metadata_api_check"
Expand Down Expand Up @@ -69,9 +65,19 @@ func SelectTerraformSetup(log logging.Logger, config *SetupConfig) terraform.Set
},
Scheduler: config.DefaultScheduler,
}
awsCfg, err := getAWSConfig(ctx, c, mg)
if err != nil {
return terraform.Setup{}, errors.Wrap(err, "cannot get aws config")
} else if awsCfg == nil {
return terraform.Setup{}, errors.Wrap(err, "obtained aws config cannot be nil")
}
creds, err := awsCfg.Credentials.Retrieve(ctx)
if err != nil {
return terraform.Setup{}, errors.Wrap(err, "failed to retrieve aws credentials from aws config")
}
account := "000000000"
if !pc.Spec.SkipCredsValidation {
account, err = getAccountId(ctx, c, mg)
account, err = getAccountId(ctx, awsCfg, creds)
if err != nil {
return terraform.Setup{}, errors.Wrap(err, "cannot get account id")
}
Expand All @@ -81,8 +87,8 @@ func SelectTerraformSetup(log logging.Logger, config *SetupConfig) terraform.Set
keyAccountId: account,
}

if len(pc.Spec.AssumeRoleChain) > 1 || pc.Spec.Endpoint != nil {
err = DefaultTerraformSetupBuilder(ctx, c, mg, pc, &ps)
if len(pc.Spec.AssumeRoleChain) > 0 || pc.Spec.Endpoint != nil {
err = DefaultTerraformSetupBuilder(ctx, pc, &ps, awsCfg, creds)
if err != nil {
return terraform.Setup{}, errors.Wrap(err, "cannot build terraform configuration")
}
Expand All @@ -92,7 +98,7 @@ func SelectTerraformSetup(log logging.Logger, config *SetupConfig) terraform.Set
ps.Scheduler = terraform.NewWorkspaceProviderScheduler(log, terraform.WithNativeProviderPath(*config.NativeProviderPath), terraform.WithNativeProviderName("registry.terraform.io/"+*config.NativeProviderSource))
}
} else {
err = pushDownTerraformSetupBuilder(ctx, c, mg, pc, &ps)
err = pushDownTerraformSetupBuilder(ctx, c, pc, &ps, awsCfg)
if err != nil {
return terraform.Setup{}, errors.Wrap(err, "cannot build terraform configuration")
}
Expand All @@ -106,16 +112,12 @@ func SelectTerraformSetup(log logging.Logger, config *SetupConfig) terraform.Set
}
}

func pushDownTerraformSetupBuilder(ctx context.Context, c client.Client, mg resource.Managed, pc *v1beta1.ProviderConfig, ps *terraform.Setup) error { //nolint:gocyclo
if len(pc.Spec.AssumeRoleChain) > 1 || pc.Spec.Endpoint != nil {
func pushDownTerraformSetupBuilder(ctx context.Context, c client.Client, pc *v1beta1.ProviderConfig, ps *terraform.Setup, cfg *aws.Config) error { //nolint:gocyclo
if len(pc.Spec.AssumeRoleChain) > 0 || pc.Spec.Endpoint != nil {
return errors.New("shared scheduler cannot be used because the length of assume role chain array " +
"is more than 1 or endpoint configuration is not nil")
"is more than 0 or endpoint configuration is not nil")
}

cfg, err := getAWSConfig(ctx, c, mg)
if err != nil {
return errors.Wrap(err, "cannot get AWS config")
}
ps.Configuration = map[string]any{
keyRegion: cfg.Region,
}
Expand Down Expand Up @@ -163,27 +165,10 @@ func pushDownTerraformSetupBuilder(ctx context.Context, c client.Client, mg reso
keySessionToken: creds.SessionToken,
}
}
if len(pc.Spec.AssumeRoleChain) != 0 {
ps.Configuration[keyAssumeRole] = map[string]any{
keyRoleArn: pc.Spec.AssumeRoleChain[0].RoleARN,
keyTags: pc.Spec.AssumeRoleChain[0].Tags,
keyTransitiveTagKeys: pc.Spec.AssumeRoleChain[0].TransitiveTagKeys,
keyExternalID: pc.Spec.AssumeRoleChain[0].ExternalID,
}
}
return nil
}

func DefaultTerraformSetupBuilder(ctx context.Context, c client.Client, mg resource.Managed, pc *v1beta1.ProviderConfig, ps *terraform.Setup) error {
cfg, err := getAWSConfig(ctx, c, mg)
if err != nil {
return errors.Wrap(err, "cannot get AWS config")
}
creds, err := cfg.Credentials.Retrieve(ctx)
if err != nil {
return errors.Wrap(err, "failed to retrieve aws credentials from aws config")
}

func DefaultTerraformSetupBuilder(_ context.Context, pc *v1beta1.ProviderConfig, ps *terraform.Setup, cfg *aws.Config, creds aws.Credentials) error {
ps.Configuration = map[string]any{
keyRegion: cfg.Region,
keyAccessKeyID: creds.AccessKeyID,
Expand All @@ -199,7 +184,7 @@ func DefaultTerraformSetupBuilder(ctx context.Context, c client.Client, mg resou
if pc.Spec.Endpoint != nil {
if pc.Spec.Endpoint.URL.Static != nil {
if len(pc.Spec.Endpoint.Services) > 0 && *pc.Spec.Endpoint.URL.Static == "" {
return errors.Wrap(err, "endpoint is wrong")
return errors.New("endpoint.url.static cannot be empty")
} else {
endpoints := make(map[string]string)
for _, service := range pc.Spec.Endpoint.Services {
Expand All @@ -209,18 +194,10 @@ func DefaultTerraformSetupBuilder(ctx context.Context, c client.Client, mg resou
}
}
}
return err
return nil
}

func getAccountId(ctx context.Context, c client.Client, mg resource.Managed) (string, error) {
cfg, err := getAWSConfig(ctx, c, mg)
if err != nil {
return "", errors.Wrap(err, "cannot get AWS config")
}
creds, err := cfg.Credentials.Retrieve(ctx)
if err != nil {
return "", errors.Wrap(err, "failed to retrieve aws credentials from aws config")
}
func getAccountId(ctx context.Context, cfg *aws.Config, creds aws.Credentials) (string, error) {
identity, err := GlobalCallerIdentityCache.GetCallerIdentity(ctx, *cfg, creds)
if err != nil {
return "", errors.Wrap(err, "cannot get the caller identity")
Expand Down

0 comments on commit b8b5ae3

Please sign in to comment.