From d2adbc5c0014d2b7de556e837a508ba999b08f51 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Wed, 21 Aug 2024 02:07:04 -0300 Subject: [PATCH] aws/all: Add a "hostname_immutable" URL parameter for AWS v2 URLs with endpoints --- aws/aws.go | 31 ++++++++++++++++++++++--------- aws/aws_test.go | 37 +++++++++++++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/aws/aws.go b/aws/aws.go index 21d4ec5afa..0f2d9c61ff 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -177,23 +177,24 @@ func NewDefaultV2Config(ctx context.Context) (awsv2.Config, error) { // - region: The AWS region for requests; sets WithRegion. // - profile: The shared config profile to use; sets SharedConfigProfile. // - endpoint: The AWS service endpoint to send HTTP request. +// - hostname_immutable: Make the hostname immutable, only works if endpoint is also set. func V2ConfigFromURLParams(ctx context.Context, q url.Values) (awsv2.Config, error) { + var endpoint string + var hostnameImmutable bool var opts []func(*awsv2cfg.LoadOptions) error for param, values := range q { value := values[0] switch param { + case "hostname_immutable": + var err error + hostnameImmutable, err = strconv.ParseBool(value) + if err != nil { + return awsv2.Config{}, fmt.Errorf("invalid value for hostname_immutable: %w", err) + } case "region": opts = append(opts, awsv2cfg.WithRegion(value)) case "endpoint": - customResolver := awsv2.EndpointResolverWithOptionsFunc( - func(service, region string, options ...interface{}) (awsv2.Endpoint, error) { - return awsv2.Endpoint{ - PartitionID: "aws", - URL: value, - SigningRegion: region, - }, nil - }) - opts = append(opts, awsv2cfg.WithEndpointResolverWithOptions(customResolver)) + endpoint = value case "profile": opts = append(opts, awsv2cfg.WithSharedConfigProfile(value)) case "awssdk": @@ -202,5 +203,17 @@ func V2ConfigFromURLParams(ctx context.Context, q url.Values) (awsv2.Config, err return awsv2.Config{}, fmt.Errorf("unknown query parameter %q", param) } } + if endpoint != "" { + customResolver := awsv2.EndpointResolverWithOptionsFunc( + func(service, region string, options ...interface{}) (awsv2.Endpoint, error) { + return awsv2.Endpoint{ + PartitionID: "aws", + URL: endpoint, + SigningRegion: region, + HostnameImmutable: hostnameImmutable, + }, nil + }) + opts = append(opts, awsv2cfg.WithEndpointResolverWithOptions(customResolver)) + } return awsv2cfg.LoadDefaultConfig(ctx, opts...) } diff --git a/aws/aws_test.go b/aws/aws_test.go index 5d04dae75d..29dbee1ae5 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -17,8 +17,10 @@ package aws_test import ( "context" "net/url" + "reflect" "testing" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws" "github.com/google/go-cmp/cmp" gcaws "gocloud.dev/aws" @@ -147,12 +149,16 @@ func TestUseV2(t *testing.T) { } func TestV2ConfigFromURLParams(t *testing.T) { + const service = "s3" + const region = "us-east-1" + const partitionID = "aws" ctx := context.Background() tests := []struct { - name string - query url.Values - wantRegion string - wantErr bool + name string + query url.Values + wantRegion string + wantErr bool + wantEndpoint *awsv2.Endpoint }{ { name: "No overrides", @@ -168,6 +174,16 @@ func TestV2ConfigFromURLParams(t *testing.T) { query: url.Values{"region": {"my_region"}}, wantRegion: "my_region", }, + { + name: "Endpoint and hostname immutable", + query: url.Values{"endpoint": {"foo"}, "hostname_immutable": {"true"}}, + wantEndpoint: &awsv2.Endpoint{ + PartitionID: partitionID, + SigningRegion: region, + URL: "foo", + HostnameImmutable: true, + }, + }, // Can't test "profile", since AWS validates that the profile exists. } @@ -184,6 +200,19 @@ func TestV2ConfigFromURLParams(t *testing.T) { if test.wantRegion != "" && got.Region != test.wantRegion { t.Errorf("got region %q, want %q", got.Region, test.wantRegion) } + + if test.wantEndpoint != nil { + if got.EndpointResolverWithOptions == nil { + t.Fatalf("expected an EndpointResolverWithOptions, got nil") + } + gotE, err := got.EndpointResolverWithOptions.ResolveEndpoint(service, region) + if err != nil { + return + } + if !reflect.DeepEqual(gotE, *test.wantEndpoint) { + t.Errorf("got endpoint %+v, want %+v", gotE, *test.wantEndpoint) + } + } }) } }