Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kmala committed Sep 13, 2023
1 parent 6f120a6 commit ac55dc8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 24 deletions.
47 changes: 28 additions & 19 deletions pkg/providers/v1/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ const volumeAttachmentStuck = "VolumeAttachmentStuck"
// Indicates that a node has volumes stuck in attaching state and hence it is not fit for scheduling more pods
const nodeWithImpairedVolumes = "NodeWithImpairedVolumes"

const sourceKey = "x-amz-source-arn"
const accountKey = "x-amz-source-account"
const headerSourceArn = "x-amz-source-arn"
const headerSourceAccount = "x-amz-source-account"

const (
// volumeAttachmentConsecutiveErrorLimit is the number of consecutive errors we will ignore when waiting for a volume to attach/detach
Expand Down Expand Up @@ -617,8 +617,10 @@ type CloudConfig struct {

// RoleARN is the IAM role to assume when interaction with AWS APIs.
RoleARN string
// SourceARN is value which is passed while assuming role using RoleARN and used for
// restricting the access. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
// SourceARN is value which is passed while assuming role specified by RoleARN. When a service
// assumes a role in your account, you can include the aws:SourceAccount and aws:SourceArn global
// condition context keys in your role trust policy to limit access to the role to only requests that are generated
// by expected resources. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
SourceARN string

// KubernetesClusterTag is the legacy cluster id we'll use to identify our cluster resources
Expand Down Expand Up @@ -1266,23 +1268,10 @@ func init() {

var creds *credentials.Credentials
if cfg.Global.RoleARN != "" {
klog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN)
stsClient := sts.New(sess)
sourceAcct, err := GetSourceAcct(cfg.Global.RoleARN)
stsClient, err := getSTSClient(sess, cfg.Global.RoleARN, cfg.Global.SourceARN)
if err != nil {
return nil, err
}
reqHeaders := map[string]string{
accountKey: sourceAcct,
}
if cfg.Global.SourceARN != "" {
reqHeaders[sourceKey] = cfg.Global.SourceARN
return nil, fmt.Errorf("unable to create sts client, %v", err)
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders))
})
klog.Infof("configuring STS client with extra headers")

creds = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
Expand All @@ -1298,6 +1287,26 @@ func init() {
})
}

func getSTSClient(sess *session.Session, roleARN, sourceARN string) (*sts.STS, error) {
klog.Infof("Using AWS assumed role %v", roleARN)
stsClient := sts.New(sess)
sourceAcct, err := GetSourceAccount(roleARN)
if err != nil {
return nil, err
}
reqHeaders := map[string]string{
headerSourceAccount: sourceAcct,
}
if sourceARN != "" {
reqHeaders[headerSourceArn] = sourceARN
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders))
})
klog.V(4).Infof("configuring STS client with extra headers, %v", reqHeaders)
return stsClient, nil
}

// readAWSCloudConfig reads an instance of AWSCloudConfig from config reader.
func readAWSCloudConfig(config io.Reader) (*CloudConfig, error) {
var cfg CloudConfig
Expand Down
4 changes: 2 additions & 2 deletions pkg/providers/v1/aws_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func stringSetFromPointers(in []*string) sets.String {
return out
}

// GetSourceAcct constructs source acct and return them for use
func GetSourceAcct(roleARN string) (string, error) {
// GetSourceAccount constructs source acct and return them for use
func GetSourceAccount(roleARN string) (string, error) {
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
// arn:partition:service:region:account-id:resource-type/resource-id
// IAM format, region is always blank
Expand Down
6 changes: 3 additions & 3 deletions pkg/providers/v1/aws_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ func TestGetSourceAcctAndArn(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GetSourceAcct(tt.args.roleARN)
got, err := GetSourceAccount(tt.args.roleARN)
if (err != nil) != tt.wantErr {
t.Errorf("GetSourceAcct() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GetSourceAcct() got = %v, want %v", got, tt.want)
t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want)
}
})
}
Expand Down

0 comments on commit ac55dc8

Please sign in to comment.