diff --git a/velero-plugin-for-aws/config.go b/velero-plugin-for-aws/config.go index 4a3aed0..b368d91 100644 --- a/velero-plugin-for-aws/config.go +++ b/velero-plugin-for-aws/config.go @@ -9,44 +9,40 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/pkg/errors" + "github.com/sirupsen/logrus" "net/http" "os" ) -func newAWSConfig(region, profile, credentialsFile string, insecureSkipTLSVerify bool, caCert string) (aws.Config, error) { - empty := aws.Config{} - client := awshttp.NewBuildableClient().WithTransportOptions(func(tr *http.Transport) { - if len(caCert) > 0 { - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM([]byte(caCert)) - if tr.TLSClientConfig == nil { - tr.TLSClientConfig = &tls.Config{ - RootCAs: caCertPool, - } - } else { - tr.TLSClientConfig.RootCAs = caCertPool - } - } - tr.TLSClientConfig.InsecureSkipVerify = insecureSkipTLSVerify - }) - opts := []func(*config.LoadOptions) error{ - config.WithRegion(region), - config.WithSharedConfigProfile(profile), - config.WithHTTPClient(client), +type configBuilder struct { + log logrus.FieldLogger + opts []func(*config.LoadOptions) error + credsFlag bool +} + +func newConfigBuilder(logger logrus.FieldLogger) *configBuilder { + return &configBuilder{ + log: logger, } +} + +func (cb *configBuilder) WithRegion(region string) *configBuilder { + cb.opts = append(cb.opts, config.WithRegion(region)) + return cb +} +func (cb *configBuilder) WithProfile(profile string) *configBuilder { + cb.opts = append(cb.opts, config.WithSharedConfigProfile(profile)) + return cb +} + +func (cb *configBuilder) WithCredentialsFile(credentialsFile string) *configBuilder { if credentialsFile == "" && os.Getenv("AWS_SHARED_CREDENTIALS_FILE") != "" { credentialsFile = os.Getenv("AWS_SHARED_CREDENTIALS_FILE") } if credentialsFile != "" { - if _, err := os.Stat(credentialsFile); err != nil { - if os.IsNotExist(err) { - return empty, errors.Wrapf(err, "provided credentialsFile does not exist") - } - return empty, errors.Wrapf(err, "could not get credentialsFile info") - } - opts = append(opts, config.WithSharedCredentialsFiles([]string{credentialsFile}), + cb.opts = append(cb.opts, config.WithSharedCredentialsFiles([]string{credentialsFile}), // To support the existing use case where config file is passed // as credentials of a BSL config.WithSharedConfigFiles([]string{credentialsFile})) @@ -54,17 +50,42 @@ func newAWSConfig(region, profile, credentialsFile string, insecureSkipTLSVerify os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "") os.Setenv("AWS_ROLE_SESSION_NAME", "") os.Setenv("AWS_ROLE_ARN", "") + cb.credsFlag = true } + return cb +} + +func (cb *configBuilder) WithTLSSettings(insecureSkipTLSVerify bool, caCert string) *configBuilder { + cb.opts = append(cb.opts, config.WithHTTPClient(awshttp.NewBuildableClient().WithTransportOptions(func(tr *http.Transport) { + if tr.TLSClientConfig == nil { + tr.TLSClientConfig = &tls.Config{} + } + if len(caCert) > 0 { + var caCertPool *x509.CertPool + caCertPool, err := x509.SystemCertPool() + if err != nil { + cb.log.Warnf("Failed to load system cert pool, using empty cert pool, err: %v", err) + caCertPool = x509.NewCertPool() + } + caCertPool.AppendCertsFromPEM([]byte(caCert)) + tr.TLSClientConfig.RootCAs = caCertPool + } + tr.TLSClientConfig.InsecureSkipVerify = insecureSkipTLSVerify + }))) + return cb +} - awsConfig, err := config.LoadDefaultConfig(context.Background(), opts...) +func (cb *configBuilder) Build() (aws.Config, error) { + conf, err := config.LoadDefaultConfig(context.Background(), cb.opts...) if err != nil { - return empty, errors.Wrapf(err, "could not load config") + return aws.Config{}, err } - if _, err := awsConfig.Credentials.Retrieve(context.Background()); err != nil { - return empty, errors.WithStack(err) + if cb.credsFlag { + if _, err := conf.Credentials.Retrieve(context.Background()); err != nil { + return aws.Config{}, errors.WithStack(err) + } } - - return awsConfig, nil + return conf, nil } func newS3Client(cfg aws.Config, url string, forcePathStyle bool) (*s3.Client, error) { diff --git a/velero-plugin-for-aws/helpers.go b/velero-plugin-for-aws/helpers.go index ae840a9..5616dc4 100644 --- a/velero-plugin-for-aws/helpers.go +++ b/velero-plugin-for-aws/helpers.go @@ -17,36 +17,12 @@ limitations under the License. package main import ( - "context" "net/url" "strings" - "github.com/aws/aws-sdk-go-v2/config" - s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" - "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/pkg/errors" ) -// GetBucketRegion returns the AWS region that a bucket is in, or an error -// if the region cannot be determined. -func GetBucketRegion(bucket string, usePathStyle bool) (string, error) { - cfg, err := config.LoadDefaultConfig(context.Background()) - if err != nil { - return "", errors.WithStack(err) - } - client := s3.NewFromConfig(cfg) - region, err := s3manager.GetBucketRegion(context.Background(), client, bucket, func(o *s3.Options) { - o.UsePathStyle = usePathStyle - }) - if err != nil { - return "", err - } - if region == "" { - return "", errors.New("unable to determine bucket's region") - } - return region, nil -} - // IsValidS3URLScheme returns true if the scheme is http:// or https:// // and the url parses correctly, otherwise, return false func IsValidS3URLScheme(s3URL string) bool { diff --git a/velero-plugin-for-aws/object_store.go b/velero-plugin-for-aws/object_store.go index 66c56db..11bf9ea 100644 --- a/velero-plugin-for-aws/object_store.go +++ b/velero-plugin-for-aws/object_store.go @@ -18,6 +18,7 @@ package main import ( "context" + "fmt" "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" @@ -130,24 +131,37 @@ func (o *ObjectStore) Init(config map[string]string) error { } } + if insecureSkipTLSVerifyVal != "" { + if insecureSkipTLSVerify, err = strconv.ParseBool(insecureSkipTLSVerifyVal); err != nil { + return errors.Wrapf(err, "could not parse %s (expected bool)", insecureSkipTLSVerifyKey) + } + } + // AWS (not an alternate S3-compatible API) and region not // explicitly specified: determine the bucket's region if s3URL == "" && region == "" { - var err error - region, err = GetBucketRegion(bucket, s3ForcePathStyle) + cfg, err := newConfigBuilder(o.log).WithTLSSettings(insecureSkipTLSVerify, caCert).Build() + if err != nil { + return errors.WithStack(err) + } + client, err := newS3Client(cfg, s3URL, s3ForcePathStyle) if err != nil { - o.log.Errorf("Failed to get bucket region, bucket: %s, error: %v", bucket, err) + return errors.WithStack(err) + } + region, err = manager.GetBucketRegion(context.Background(), client, bucket) + if err != nil { + o.log.Errorf("Failed to determine bucket's region bucket: %s, error: %v", bucket, err) return err } - } - - if insecureSkipTLSVerifyVal != "" { - if insecureSkipTLSVerify, err = strconv.ParseBool(insecureSkipTLSVerifyVal); err != nil { - return errors.Wrapf(err, "could not parse %s (expected bool)", insecureSkipTLSVerifyKey) + if region == "" { + return fmt.Errorf("unable to determine bucket's region, bucket: %s", bucket) } } - cfg, err := newAWSConfig(region, credentialProfile, credentialsFile, insecureSkipTLSVerify, caCert) + cfg, err := newConfigBuilder(o.log).WithRegion(region). + WithProfile(credentialProfile). + WithCredentialsFile(credentialsFile). + WithTLSSettings(insecureSkipTLSVerify, caCert).Build() if err != nil { return errors.WithStack(err) } diff --git a/velero-plugin-for-aws/volume_snapshotter.go b/velero-plugin-for-aws/volume_snapshotter.go index c3889bf..118884e 100644 --- a/velero-plugin-for-aws/volume_snapshotter.go +++ b/velero-plugin-for-aws/volume_snapshotter.go @@ -63,19 +63,18 @@ func (b *VolumeSnapshotter) Init(config map[string]string) error { region := config[regionKey] credentialProfile := config[credentialProfileKey] credentialsFile := config[credentialsFileKey] - // enableSharedConfig := config[enableSharedConfigKey] if region == "" { return errors.Errorf("missing %s in aws configuration", regionKey) } - - cfg, err := newAWSConfig(region, credentialProfile, credentialsFile, false, "") + cfg, err := newConfigBuilder(b.log). + WithRegion(region). + WithProfile(credentialProfile). + WithCredentialsFile(credentialsFile).Build() if err != nil { return errors.WithStack(err) } - b.ec2 = ec2.NewFromConfig(cfg) - return nil }