Skip to content

Commit

Permalink
Merge pull request vmware-tanzu#195 from reasonerjt/config-builder
Browse files Browse the repository at this point in the history
Respect the TLS setting in BSL in object store plugin
  • Loading branch information
ywk253100 authored Feb 29, 2024
2 parents c89401a + 2774aff commit 1ecdc05
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 71 deletions.
87 changes: 54 additions & 33 deletions velero-plugin-for-aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,62 +9,83 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"net/http"
"os"
)

func newAWSConfig(region, profile, credentialsFile string, insecureSkipTLSVerify bool, caCert string) (aws.Config, error) {
empty := aws.Config{}
client := awshttp.NewBuildableClient().WithTransportOptions(func(tr *http.Transport) {
if len(caCert) > 0 {
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM([]byte(caCert))
if tr.TLSClientConfig == nil {
tr.TLSClientConfig = &tls.Config{
RootCAs: caCertPool,
}
} else {
tr.TLSClientConfig.RootCAs = caCertPool
}
}
tr.TLSClientConfig.InsecureSkipVerify = insecureSkipTLSVerify
})
opts := []func(*config.LoadOptions) error{
config.WithRegion(region),
config.WithSharedConfigProfile(profile),
config.WithHTTPClient(client),
type configBuilder struct {
log logrus.FieldLogger
opts []func(*config.LoadOptions) error
credsFlag bool
}

func newConfigBuilder(logger logrus.FieldLogger) *configBuilder {
return &configBuilder{
log: logger,
}
}

func (cb *configBuilder) WithRegion(region string) *configBuilder {
cb.opts = append(cb.opts, config.WithRegion(region))
return cb
}

func (cb *configBuilder) WithProfile(profile string) *configBuilder {
cb.opts = append(cb.opts, config.WithSharedConfigProfile(profile))
return cb
}

func (cb *configBuilder) WithCredentialsFile(credentialsFile string) *configBuilder {
if credentialsFile == "" && os.Getenv("AWS_SHARED_CREDENTIALS_FILE") != "" {
credentialsFile = os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
}

if credentialsFile != "" {
if _, err := os.Stat(credentialsFile); err != nil {
if os.IsNotExist(err) {
return empty, errors.Wrapf(err, "provided credentialsFile does not exist")
}
return empty, errors.Wrapf(err, "could not get credentialsFile info")
}
opts = append(opts, config.WithSharedCredentialsFiles([]string{credentialsFile}),
cb.opts = append(cb.opts, config.WithSharedCredentialsFiles([]string{credentialsFile}),
// To support the existing use case where config file is passed
// as credentials of a BSL
config.WithSharedConfigFiles([]string{credentialsFile}))
// unset the env variables to bypass the role assumption when IRSA is configured
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "")
os.Setenv("AWS_ROLE_SESSION_NAME", "")
os.Setenv("AWS_ROLE_ARN", "")
cb.credsFlag = true
}
return cb
}

func (cb *configBuilder) WithTLSSettings(insecureSkipTLSVerify bool, caCert string) *configBuilder {
cb.opts = append(cb.opts, config.WithHTTPClient(awshttp.NewBuildableClient().WithTransportOptions(func(tr *http.Transport) {
if tr.TLSClientConfig == nil {
tr.TLSClientConfig = &tls.Config{}
}
if len(caCert) > 0 {
var caCertPool *x509.CertPool
caCertPool, err := x509.SystemCertPool()
if err != nil {
cb.log.Warnf("Failed to load system cert pool, using empty cert pool, err: %v", err)
caCertPool = x509.NewCertPool()
}
caCertPool.AppendCertsFromPEM([]byte(caCert))
tr.TLSClientConfig.RootCAs = caCertPool
}
tr.TLSClientConfig.InsecureSkipVerify = insecureSkipTLSVerify
})))
return cb
}

awsConfig, err := config.LoadDefaultConfig(context.Background(), opts...)
func (cb *configBuilder) Build() (aws.Config, error) {
conf, err := config.LoadDefaultConfig(context.Background(), cb.opts...)
if err != nil {
return empty, errors.Wrapf(err, "could not load config")
return aws.Config{}, err
}
if _, err := awsConfig.Credentials.Retrieve(context.Background()); err != nil {
return empty, errors.WithStack(err)
if cb.credsFlag {
if _, err := conf.Credentials.Retrieve(context.Background()); err != nil {
return aws.Config{}, errors.WithStack(err)
}
}

return awsConfig, nil
return conf, nil
}

func newS3Client(cfg aws.Config, url string, forcePathStyle bool) (*s3.Client, error) {
Expand Down
24 changes: 0 additions & 24 deletions velero-plugin-for-aws/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,12 @@ limitations under the License.
package main

import (
"context"
"net/url"
"strings"

"github.com/aws/aws-sdk-go-v2/config"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/pkg/errors"
)

// GetBucketRegion returns the AWS region that a bucket is in, or an error
// if the region cannot be determined.
func GetBucketRegion(bucket string, usePathStyle bool) (string, error) {
cfg, err := config.LoadDefaultConfig(context.Background())
if err != nil {
return "", errors.WithStack(err)
}
client := s3.NewFromConfig(cfg)
region, err := s3manager.GetBucketRegion(context.Background(), client, bucket, func(o *s3.Options) {
o.UsePathStyle = usePathStyle
})
if err != nil {
return "", err
}
if region == "" {
return "", errors.New("unable to determine bucket's region")
}
return region, nil
}

// IsValidS3URLScheme returns true if the scheme is http:// or https://
// and the url parses correctly, otherwise, return false
func IsValidS3URLScheme(s3URL string) bool {
Expand Down
32 changes: 23 additions & 9 deletions velero-plugin-for-aws/object_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
Expand Down Expand Up @@ -130,24 +131,37 @@ func (o *ObjectStore) Init(config map[string]string) error {
}
}

if insecureSkipTLSVerifyVal != "" {
if insecureSkipTLSVerify, err = strconv.ParseBool(insecureSkipTLSVerifyVal); err != nil {
return errors.Wrapf(err, "could not parse %s (expected bool)", insecureSkipTLSVerifyKey)
}
}

// AWS (not an alternate S3-compatible API) and region not
// explicitly specified: determine the bucket's region
if s3URL == "" && region == "" {
var err error
region, err = GetBucketRegion(bucket, s3ForcePathStyle)
cfg, err := newConfigBuilder(o.log).WithTLSSettings(insecureSkipTLSVerify, caCert).Build()
if err != nil {
return errors.WithStack(err)
}
client, err := newS3Client(cfg, s3URL, s3ForcePathStyle)
if err != nil {
o.log.Errorf("Failed to get bucket region, bucket: %s, error: %v", bucket, err)
return errors.WithStack(err)
}
region, err = manager.GetBucketRegion(context.Background(), client, bucket)
if err != nil {
o.log.Errorf("Failed to determine bucket's region bucket: %s, error: %v", bucket, err)
return err
}
}

if insecureSkipTLSVerifyVal != "" {
if insecureSkipTLSVerify, err = strconv.ParseBool(insecureSkipTLSVerifyVal); err != nil {
return errors.Wrapf(err, "could not parse %s (expected bool)", insecureSkipTLSVerifyKey)
if region == "" {
return fmt.Errorf("unable to determine bucket's region, bucket: %s", bucket)
}
}

cfg, err := newAWSConfig(region, credentialProfile, credentialsFile, insecureSkipTLSVerify, caCert)
cfg, err := newConfigBuilder(o.log).WithRegion(region).
WithProfile(credentialProfile).
WithCredentialsFile(credentialsFile).
WithTLSSettings(insecureSkipTLSVerify, caCert).Build()
if err != nil {
return errors.WithStack(err)
}
Expand Down
9 changes: 4 additions & 5 deletions velero-plugin-for-aws/volume_snapshotter.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,18 @@ func (b *VolumeSnapshotter) Init(config map[string]string) error {
region := config[regionKey]
credentialProfile := config[credentialProfileKey]
credentialsFile := config[credentialsFileKey]
// enableSharedConfig := config[enableSharedConfigKey]

if region == "" {
return errors.Errorf("missing %s in aws configuration", regionKey)
}

cfg, err := newAWSConfig(region, credentialProfile, credentialsFile, false, "")
cfg, err := newConfigBuilder(b.log).
WithRegion(region).
WithProfile(credentialProfile).
WithCredentialsFile(credentialsFile).Build()
if err != nil {
return errors.WithStack(err)
}

b.ec2 = ec2.NewFromConfig(cfg)

return nil
}

Expand Down

0 comments on commit 1ecdc05

Please sign in to comment.