diff --git a/pkg/providers/v1/aws.go b/pkg/providers/v1/aws.go index 8456633742..a4cbcf4415 100644 --- a/pkg/providers/v1/aws.go +++ b/pkg/providers/v1/aws.go @@ -249,8 +249,8 @@ const volumeAttachmentStuck = "VolumeAttachmentStuck" // Indicates that a node has volumes stuck in attaching state and hence it is not fit for scheduling more pods const nodeWithImpairedVolumes = "NodeWithImpairedVolumes" -const sourceKey = "x-amz-source-arn" -const accountKey = "x-amz-source-account" +const headerSourceArn = "x-amz-source-arn" +const headerSourceAccount = "x-amz-source-account" const ( // volumeAttachmentConsecutiveErrorLimit is the number of consecutive errors we will ignore when waiting for a volume to attach/detach @@ -617,8 +617,10 @@ type CloudConfig struct { // RoleARN is the IAM role to assume when interaction with AWS APIs. RoleARN string - // SourceARN is value which is passed while assuming role using RoleARN and used for - // restricting the access. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html + // SourceARN is value which is passed while assuming role specified by RoleARN. When a service + // assumes a role in your account, you can include the aws:SourceAccount and aws:SourceArn global + // condition context keys in your role trust policy to limit access to the role to only requests that are generated + // by expected resources. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html SourceARN string // KubernetesClusterTag is the legacy cluster id we'll use to identify our cluster resources @@ -1266,23 +1268,10 @@ func init() { var creds *credentials.Credentials if cfg.Global.RoleARN != "" { - klog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN) - stsClient := sts.New(sess) - sourceAcct, err := GetSourceAcct(cfg.Global.RoleARN) + stsClient, err := getSTSClient(sess, cfg.Global.RoleARN, cfg.Global.SourceARN) if err != nil { - return nil, err - } - reqHeaders := map[string]string{ - accountKey: sourceAcct, - } - if cfg.Global.SourceARN != "" { - reqHeaders[sourceKey] = cfg.Global.SourceARN + return nil, fmt.Errorf("unable to create sts client, %v", err) } - stsClient.Handlers.Sign.PushFront(func(s *request.Request) { - s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) - }) - klog.Infof("configuring STS client with extra headers") - creds = credentials.NewChainCredentials( []credentials.Provider{ &credentials.EnvProvider{}, @@ -1298,6 +1287,26 @@ func init() { }) } +func getSTSClient(sess *session.Session, roleARN, sourceARN string) (*sts.STS, error) { + klog.Infof("Using AWS assumed role %v", roleARN) + stsClient := sts.New(sess) + sourceAcct, err := GetSourceAccount(roleARN) + if err != nil { + return nil, err + } + reqHeaders := map[string]string{ + headerSourceAccount: sourceAcct, + } + if sourceARN != "" { + reqHeaders[headerSourceArn] = sourceARN + } + stsClient.Handlers.Sign.PushFront(func(s *request.Request) { + s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) + }) + klog.V(4).Infof("configuring STS client with extra headers, %v", reqHeaders) + return stsClient, nil +} + // readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. func readAWSCloudConfig(config io.Reader) (*CloudConfig, error) { var cfg CloudConfig diff --git a/pkg/providers/v1/aws_utils.go b/pkg/providers/v1/aws_utils.go index 0b2b83b7ad..621731ed1c 100644 --- a/pkg/providers/v1/aws_utils.go +++ b/pkg/providers/v1/aws_utils.go @@ -47,8 +47,8 @@ func stringSetFromPointers(in []*string) sets.String { return out } -// GetSourceAcct constructs source acct and return them for use -func GetSourceAcct(roleARN string) (string, error) { +// GetSourceAccount constructs source acct and return them for use +func GetSourceAccount(roleARN string) (string, error) { // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) // arn:partition:service:region:account-id:resource-type/resource-id // IAM format, region is always blank diff --git a/pkg/providers/v1/aws_utils_test.go b/pkg/providers/v1/aws_utils_test.go index 3a3925b076..5dfe9460d6 100644 --- a/pkg/providers/v1/aws_utils_test.go +++ b/pkg/providers/v1/aws_utils_test.go @@ -47,13 +47,13 @@ func TestGetSourceAcctAndArn(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetSourceAcct(tt.args.roleARN) + got, err := GetSourceAccount(tt.args.roleARN) if (err != nil) != tt.wantErr { - t.Errorf("GetSourceAcct() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { - t.Errorf("GetSourceAcct() got = %v, want %v", got, tt.want) + t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want) } }) }