diff --git a/provider/awsv4.go b/provider/awsv4.go new file mode 100644 index 0000000..103b33b --- /dev/null +++ b/provider/awsv4.go @@ -0,0 +1,39 @@ +package provider + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "io" + "net/http" +) + +var emptyStringSHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + +type awsV4SignerWrapper struct { + internal http.RoundTripper +} + +func (client *awsV4SignerWrapper) RoundTrip(request *http.Request) (*http.Response, error) { + var hash string + if request.Body == nil { + hash = emptyStringSHA256 + } else { + payload, error := io.ReadAll(request.Body) + request.Body = io.NopCloser(bytes.NewReader(payload)) + if error != nil { + return nil, error + } + hasher := sha256.New() + hasher.Write(payload) + hashBytes := hasher.Sum(nil) + hash = hex.EncodeToString(hashBytes) + } + request.Header.Set("X-Amz-Content-Sha256", hash) + + return client.internal.RoundTrip(request) +} + +func Wrap(internal http.RoundTripper) http.RoundTripper { + return &awsV4SignerWrapper{internal: internal} +} diff --git a/provider/provider.go b/provider/provider.go index 036b06b..36a6439 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -38,6 +38,8 @@ const ( ) var awsUrlRegexp = regexp.MustCompile(`([a-z0-9-]+).es.amazonaws.com$`) +var awsOpensearchServerlessUrlRegexp = regexp.MustCompile(`([a-z0-9-]+).aoss.amazonaws.com$`) +var minimalOpensearchServerlessVersion = "2.0.0" type ProviderConf struct { rawUrl string @@ -305,6 +307,25 @@ func getClient(conf *ProviderConf) (*elastic7.Client, error) { return nil, err } opts = append(opts, elastic7.SetHttpClient(client), elastic7.SetSniff(false)) + } else if m := awsOpensearchServerlessUrlRegexp.FindStringSubmatch(conf.parsedUrl.Hostname()); (m != nil || (conf.awsSig4Service == "aoss" && conf.awsRegion != "")) && conf.signAWSRequests { + var region string + if m != nil { + region = m[1] + } else { + region = conf.awsRegion + } + log.Printf("[INFO] Using AWS: %+v", region) + conf.awsSig4Service = "aoss" + client, err := awsHttpClient(region, conf, map[string]string{}) + if err != nil { + return nil, err + } + client.Transport = Wrap(client.Transport) + opts = append(opts, elastic7.SetHttpClient(client), elastic7.SetSniff(false)) + conf.flavor = OpenSearch + if conf.osVersion == "" { + conf.osVersion = minimalOpensearchServerlessVersion + } } else if awsRegion := conf.awsRegion; conf.awsRegion != "" && conf.signAWSRequests { log.Printf("[INFO] Using AWS: %+v", awsRegion) client, err := awsHttpClient(awsRegion, conf, map[string]string{})