Skip to content

Commit

Permalink
aws/all: Add a "hostname_immutable" URL parameter for AWS v2 URLs wit…
Browse files Browse the repository at this point in the history
…h endpoints
  • Loading branch information
caarlos0 authored Aug 21, 2024
1 parent 54419b7 commit d2adbc5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
31 changes: 22 additions & 9 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,23 +177,24 @@ func NewDefaultV2Config(ctx context.Context) (awsv2.Config, error) {
// - region: The AWS region for requests; sets WithRegion.
// - profile: The shared config profile to use; sets SharedConfigProfile.
// - endpoint: The AWS service endpoint to send HTTP request.
// - hostname_immutable: Make the hostname immutable, only works if endpoint is also set.
func V2ConfigFromURLParams(ctx context.Context, q url.Values) (awsv2.Config, error) {
var endpoint string
var hostnameImmutable bool
var opts []func(*awsv2cfg.LoadOptions) error
for param, values := range q {
value := values[0]
switch param {
case "hostname_immutable":
var err error
hostnameImmutable, err = strconv.ParseBool(value)
if err != nil {
return awsv2.Config{}, fmt.Errorf("invalid value for hostname_immutable: %w", err)
}
case "region":
opts = append(opts, awsv2cfg.WithRegion(value))
case "endpoint":
customResolver := awsv2.EndpointResolverWithOptionsFunc(
func(service, region string, options ...interface{}) (awsv2.Endpoint, error) {
return awsv2.Endpoint{
PartitionID: "aws",
URL: value,
SigningRegion: region,
}, nil
})
opts = append(opts, awsv2cfg.WithEndpointResolverWithOptions(customResolver))
endpoint = value
case "profile":
opts = append(opts, awsv2cfg.WithSharedConfigProfile(value))
case "awssdk":
Expand All @@ -202,5 +203,17 @@ func V2ConfigFromURLParams(ctx context.Context, q url.Values) (awsv2.Config, err
return awsv2.Config{}, fmt.Errorf("unknown query parameter %q", param)
}
}
if endpoint != "" {
customResolver := awsv2.EndpointResolverWithOptionsFunc(
func(service, region string, options ...interface{}) (awsv2.Endpoint, error) {
return awsv2.Endpoint{
PartitionID: "aws",
URL: endpoint,
SigningRegion: region,
HostnameImmutable: hostnameImmutable,
}, nil
})
opts = append(opts, awsv2cfg.WithEndpointResolverWithOptions(customResolver))
}
return awsv2cfg.LoadDefaultConfig(ctx, opts...)
}
37 changes: 33 additions & 4 deletions aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package aws_test
import (
"context"
"net/url"
"reflect"
"testing"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws"
"github.com/google/go-cmp/cmp"
gcaws "gocloud.dev/aws"
Expand Down Expand Up @@ -147,12 +149,16 @@ func TestUseV2(t *testing.T) {
}

func TestV2ConfigFromURLParams(t *testing.T) {
const service = "s3"
const region = "us-east-1"
const partitionID = "aws"
ctx := context.Background()
tests := []struct {
name string
query url.Values
wantRegion string
wantErr bool
name string
query url.Values
wantRegion string
wantErr bool
wantEndpoint *awsv2.Endpoint
}{
{
name: "No overrides",
Expand All @@ -168,6 +174,16 @@ func TestV2ConfigFromURLParams(t *testing.T) {
query: url.Values{"region": {"my_region"}},
wantRegion: "my_region",
},
{
name: "Endpoint and hostname immutable",
query: url.Values{"endpoint": {"foo"}, "hostname_immutable": {"true"}},
wantEndpoint: &awsv2.Endpoint{
PartitionID: partitionID,
SigningRegion: region,
URL: "foo",
HostnameImmutable: true,
},
},
// Can't test "profile", since AWS validates that the profile exists.
}

Expand All @@ -184,6 +200,19 @@ func TestV2ConfigFromURLParams(t *testing.T) {
if test.wantRegion != "" && got.Region != test.wantRegion {
t.Errorf("got region %q, want %q", got.Region, test.wantRegion)
}

if test.wantEndpoint != nil {
if got.EndpointResolverWithOptions == nil {
t.Fatalf("expected an EndpointResolverWithOptions, got nil")
}
gotE, err := got.EndpointResolverWithOptions.ResolveEndpoint(service, region)
if err != nil {
return
}
if !reflect.DeepEqual(gotE, *test.wantEndpoint) {
t.Errorf("got endpoint %+v, want %+v", gotE, *test.wantEndpoint)
}
}
})
}
}

0 comments on commit d2adbc5

Please sign in to comment.