Skip to content

Commit

Permalink
Merge pull request #2 from openstandia/max-session-and-default-iam-role
Browse files Browse the repository at this point in the history
Add options for max session and default IAM role
  • Loading branch information
wadahiro authored Jun 10, 2019
2 parents d784075 + 9bb80b3 commit 8977e3a
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
NAME := aws-cli-oidc
VERSION := v0.3.0
VERSION := v0.4.0
REVISION := $(shell git rev-parse --short HEAD)

SRCS := $(shell find . -type f -name '*.go')
Expand Down
36 changes: 22 additions & 14 deletions cmd/aws_saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@ import (
"github.com/versent/saml2aws"
)

func GetCredentialsWithSAML(samlResponse string) (*AWSCredentials, error) {
role, err := selectAwsRole(samlResponse)
func GetCredentialsWithSAML(samlResponse string, durationSeconds int64, defaultIAMRoleArn string) (*AWSCredentials, error) {
role, err := selectAwsRole(samlResponse, defaultIAMRoleArn)
if err != nil {
return nil, errors.Wrap(err, "Failed to assume role, please check you are permitted to assume the given role for the AWS service")
}

Writeln("Selected role: %s", role.RoleARN)
Writeln("Max Session Duration: %d seconds", durationSeconds)

return loginToStsUsingRole(role, samlResponse)
return loginToStsUsingRole(role, samlResponse, durationSeconds)
}

func selectAwsRole(samlResponse string) (*saml2aws.AWSRole, error) {
func selectAwsRole(samlResponse, defaultIAMRoleArn string) (*saml2aws.AWSRole, error) {
roles, err := saml2aws.ExtractAwsRoles([]byte(samlResponse))
if err != nil {
return nil, errors.Wrap(err, "Failed to extract aws roles")
return nil, errors.Wrap(err, "Failed to extract aws roles from SAML Assertion")
}

if len(roles) == 0 {
Expand All @@ -41,10 +42,10 @@ func selectAwsRole(samlResponse string) (*saml2aws.AWSRole, error) {
return nil, errors.Wrap(err, "Failed to parse aws roles")
}

return resolveRole(awsRoles, samlResponse)
return resolveRole(awsRoles, samlResponse, defaultIAMRoleArn)
}

func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string) (*saml2aws.AWSRole, error) {
func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion, defaultIAMRoleArn string) (*saml2aws.AWSRole, error) {
var role = new(saml2aws.AWSRole)

if len(awsRoles) == 1 {
Expand All @@ -57,7 +58,7 @@ func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string) (*saml2aws.

for {
var err error
role, err = promptForAWSRoleSelection(awsRoles)
role, err = promptForAWSRoleSelection(awsRoles, defaultIAMRoleArn)
if err == nil {
break
}
Expand All @@ -67,14 +68,21 @@ func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string) (*saml2aws.
return role, nil
}

func promptForAWSRoleSelection(awsRoles []*saml2aws.AWSRole) (*saml2aws.AWSRole, error) {
func promptForAWSRoleSelection(awsRoles []*saml2aws.AWSRole, defaultIAMRoleArn string) (*saml2aws.AWSRole, error) {
roles := map[string]*saml2aws.AWSRole{}
var roleOptions []string

for _, role := range awsRoles {
name := fmt.Sprintf("%s", role.RoleARN)
roles[name] = role
roleOptions = append(roleOptions, name)
if defaultIAMRoleArn == role.RoleARN {
Writeln("Selected default role: %s", defaultIAMRoleArn)
return role, nil
}
roles[role.RoleARN] = role
roleOptions = append(roleOptions, role.RoleARN)
}

if defaultIAMRoleArn != "" {
Writeln("Warning: You don't have the default role: %s", defaultIAMRoleArn)
}

sort.Strings(roleOptions)
Expand Down Expand Up @@ -104,7 +112,7 @@ func promptForAWSRoleSelection(awsRoles []*saml2aws.AWSRole) (*saml2aws.AWSRole,
return roles[roleOptions[i-1]], nil
}

func loginToStsUsingRole(role *saml2aws.AWSRole, samlResponse string) (*AWSCredentials, error) {
func loginToStsUsingRole(role *saml2aws.AWSRole, samlResponse string, durationSeconds int64) (*AWSCredentials, error) {
sess, err := session.NewSession()
if err != nil {
return nil, errors.Wrap(err, "Failed to create session")
Expand All @@ -123,7 +131,7 @@ func loginToStsUsingRole(role *saml2aws.AWSRole, samlResponse string) (*AWSCrede
PrincipalArn: aws.String(role.PrincipalARN), // Required
RoleArn: aws.String(role.RoleARN), // Required
SAMLAssertion: aws.String(b), // Required
DurationSeconds: aws.Int64(int64(900)),
DurationSeconds: aws.Int64(durationSeconds),
}

Writeln("Requesting AWS credentials using SAML assertion")
Expand Down
9 changes: 8 additions & 1 deletion cmd/get_cred.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -84,6 +85,12 @@ func getCred(cmd *cobra.Command, args []string) {
Traceln("ID token: %s", tokenResponse.IDToken)

awsFedType := client.config.GetString(AWS_FEDERATION_TYPE)
maxSessionDurationSecondsString := client.config.GetString(MAX_SESSION_DURATION_SECONDS)
maxSessionDurationSeconds, err := strconv.ParseInt(maxSessionDurationSecondsString, 10, 64)
if err != nil {
maxSessionDurationSeconds = 3600
}
defaultIAMRoleArn := client.config.GetString(DEFAULT_IAM_ROLE_ARN)

var awsCreds *AWSCredentials
if awsFedType == AWS_FEDERATION_TYPE_OIDC {
Expand All @@ -105,7 +112,7 @@ func getCred(cmd *cobra.Command, args []string) {
Exit(err)
}

awsCreds, err = GetCredentialsWithSAML(samlResponse)
awsCreds, err = GetCredentialsWithSAML(samlResponse, maxSessionDurationSeconds, defaultIAMRoleArn)
if err != nil {
Writeln("Failed to get aws credentials with SAML2")
Exit(err)
Expand Down
2 changes: 2 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ const FAILURE_REDIRECT_URL = "failure_redirect_url"
const CLIENT_ID = "client_id"
const CLIENT_SECRET = "client_secret"
const AWS_FEDERATION_TYPE = "aws_federation_type"
const MAX_SESSION_DURATION_SECONDS = "max_session_duration_seconds"
const DEFAULT_IAM_ROLE_ARN = "default_iam_role_arn"

// OIDC config
const AWS_FEDERATION_ROLE = "aws_federation_role"
Expand Down
33 changes: 33 additions & 0 deletions cmd/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cmd
import (
"fmt"
"os"
"strconv"
"strings"

input "github.com/natsukagami/go-input"
"github.com/pkg/errors"
Expand Down Expand Up @@ -64,6 +66,35 @@ func runSetup() {
return nil
},
})
maxSessionDurationSeconds, _ := ui.Ask("The max session duration, in seconds, of the role session [900-43200] (Default: 3600):", &input.Options{
Default: "3600",
Required: true,
Loop: true,
ValidateFunc: func(s string) error {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil || i < 900 || i > 43200 {
return errors.New(fmt.Sprintf("Input must be 900-43200"))
}
return nil
},
})
defaultIAMRoleArn, _ := ui.Ask("The default IAM Role ARN when you have multiple roles, as arn:aws:iam::<account-id>:role/<role-name> (Default: none):", &input.Options{
Default: "",
Required: false,
Loop: true,
ValidateFunc: func(s string) error {
if s == "" {
return nil
}
arn := strings.Split(s, ":")
if len(arn) == 6 {
if arn[0] == "arn" && arn[1] == "aws" && arn[2] == "iam" && arn[3] == "" && strings.HasPrefix(arn[5], "role/") {
return nil
}
}
return errors.New(fmt.Sprintf("Input must be IAM Role ARN"))
},
})

config := map[string]string{}

Expand All @@ -74,6 +105,8 @@ func runSetup() {
config[CLIENT_ID] = clientID
config[CLIENT_SECRET] = clientSecret
config[AWS_FEDERATION_TYPE] = answerFedType
config[MAX_SESSION_DURATION_SECONDS] = maxSessionDurationSeconds
config[DEFAULT_IAM_ROLE_ARN] = defaultIAMRoleArn

if answerFedType == AWS_FEDERATION_TYPE_OIDC {
oidcSetup(config)
Expand Down

0 comments on commit 8977e3a

Please sign in to comment.