Skip to content

Commit

Permalink
get source arn from the config
Browse files Browse the repository at this point in the history
  • Loading branch information
kmala committed Sep 13, 2023
1 parent 4f33db3 commit 619334a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 51 deletions.
29 changes: 17 additions & 12 deletions pkg/providers/v1/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ type CloudConfig struct {

// RoleARN is the IAM role to assume when interaction with AWS APIs.
RoleARN string
// SourceARN is value which is passed while assuming role using RoleARN and used for
// restricting the access. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
SourceARN string

// KubernetesClusterTag is the legacy cluster id we'll use to identify our cluster resources
KubernetesClusterTag string
Expand Down Expand Up @@ -1282,19 +1285,21 @@ func init() {
if cfg.Global.RoleARN != "" {
klog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN)
stsClient := sts.New(sess)
if cfg.Global.KubernetesClusterID != "" {
sourceAcct, sourceArn, err := GetSourceAcctAndArn(cfg.Global.RoleARN, regionName, cfg.Global.KubernetesClusterID)
if err != nil {
return nil, err
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(map[string]string{
sourceKey: sourceArn,
accountKey: sourceAcct,
}))
})
klog.Infof("configuring STS client with extra headers")
sourceAcct, err := GetSourceAcct(cfg.Global.RoleARN)
if err != nil {
return nil, err
}
reqHeaders := map[string]string{
accountKey: sourceAcct,
}
if cfg.Global.SourceARN != "" {
reqHeaders[sourceKey] = cfg.Global.SourceARN
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders))
})
klog.Infof("configuring STS client with extra headers")

creds = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
Expand Down
15 changes: 5 additions & 10 deletions pkg/providers/v1/aws_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package aws

import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -48,24 +47,20 @@ func stringSetFromPointers(in []*string) sets.String {
return out
}

// GetSourceAcctAndArn constructs source acct and arn and return them for use
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, error) {
// GetSourceAcct constructs source acct and return them for use
func GetSourceAcct(roleARN string) (string, error) {
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
// arn:partition:service:region:account-id:resource-type/resource-id
// IAM format, region is always blank
// arn:aws:iam::account:role/role-name-with-path
if !arn.IsARN(roleARN) {
return "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
}
if region == "" {
return "", "", errors.New("region can't be empty")
return "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
}

parsedArn, err := arn.Parse(roleARN)
if err != nil {
return "", "", err
return "", err
}

sourceArn := fmt.Sprintf("arn:%s:eks:%s:%s:cluster/%s", parsedArn.Partition, region, parsedArn.AccountID, clusterName)
return parsedArn.AccountID, sourceArn, nil
return parsedArn.AccountID, nil
}
35 changes: 6 additions & 29 deletions pkg/providers/v1/aws_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,63 +20,40 @@ import "testing"

func TestGetSourceAcctAndArn(t *testing.T) {
type args struct {
roleARN string
region string
clusterName string
roleARN string
}
tests := []struct {
name string
args args
want string
want1 string
wantErr bool
}{
{
name: "corect role arn",
args: args{
roleARN: "arn:aws:iam::123456789876:role/test-cluster",
region: "us-west-2",
clusterName: "test-cluster",
roleARN: "arn:aws:iam::123456789876:role/test-cluster",
},
want: "123456789876",
want1: "arn:aws:eks:us-west-2:123456789876:cluster/test-cluster",
wantErr: false,
},
{
name: "incorect role arn",
args: args{
roleARN: "arn:aws:iam::123456789876",
region: "us-west-2",
clusterName: "test-cluster",
roleARN: "arn:aws:iam::123456789876",
},
want: "",
want1: "",
wantErr: true,
},
{
name: "empty region",
args: args{
roleARN: "arn:aws:iam::123456789876:role/test-cluster",
region: "",
clusterName: "test-cluster",
},
want: "",
want1: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, err := GetSourceAcctAndArn(tt.args.roleARN, tt.args.region, tt.args.clusterName)
got, err := GetSourceAcct(tt.args.roleARN)
if (err != nil) != tt.wantErr {
t.Errorf("GetSourceAcctAndArn() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("GetSourceAcct() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GetSourceAcctAndArn() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("GetSourceAcctAndArn() got1 = %v, want %v", got1, tt.want1)
t.Errorf("GetSourceAcct() got = %v, want %v", got, tt.want)
}
})
}
Expand Down

0 comments on commit 619334a

Please sign in to comment.