Skip to content

Commit

Permalink
added opensearch serverless support (#92)
Browse files Browse the repository at this point in the history
* added opensearch serverless support

Signed-off-by: vasyaxparfenov <[email protected]>

* added unit tests

Signed-off-by: vasyaxparfenov <[email protected]>

* applied suggestions

Signed-off-by: vasyaxparfenov <[email protected]>

---------

Signed-off-by: vasyaxparfenov <[email protected]>
Co-authored-by: vasiliyparfenov <[email protected]>
  • Loading branch information
vasyaxparfenov and vasiliyparfenov authored Nov 8, 2023
1 parent 97c7627 commit 2245e1d
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 0 deletions.
44 changes: 44 additions & 0 deletions provider/awsv4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package provider

import (
"bytes"
"crypto/sha256"
"encoding/hex"
"hash"
"io"
"net/http"
)

type awsV4SignerWrapper struct {
internal http.RoundTripper
}

func hashPayload(hasher hash.Hash, payload []byte) string {
hasher.Write(payload)
hashBytes := hasher.Sum(nil)
hash := hex.EncodeToString(hashBytes)
return hash
}

func (client *awsV4SignerWrapper) RoundTrip(request *http.Request) (*http.Response, error) {
hasher := sha256.New()
var hash string
if request.Body == nil {
hash = hashPayload(hasher, []byte(""))
} else {
payload, error := io.ReadAll(request.Body)
request.Body = io.NopCloser(bytes.NewReader(payload))
if error != nil {
return nil, error
}

hash = hashPayload(hasher, payload)
}
request.Header.Set("X-Amz-Content-Sha256", hash)

return client.internal.RoundTrip(request)
}

func Wrap(internal http.RoundTripper) http.RoundTripper {
return &awsV4SignerWrapper{internal: internal}
}
57 changes: 57 additions & 0 deletions provider/awsv4_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package provider

import (
"crypto/sha256"
"encoding/hex"
"io"
"net/http"
"strings"
"testing"
)

type MockRoundTripper struct {
}

func (roundTripper *MockRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
return new(http.Response), nil
}

func TestRoundTripWithEmptyBody(t *testing.T) {
sut := Wrap(new(MockRoundTripper))

request := new(http.Request)
request.Header = make(http.Header)

_, err := sut.RoundTrip(request)

if err != nil {
t.Fatal(err)
}

if header, contains := request.Header["X-Amz-Content-Sha256"]; !contains || len(header) != 1 || header[0] != "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" {
t.Fatal("Request with empty body doesn't contain X-Amz-Content-Sha256 header with empty string hash value.")
}
}

func TestRoundTripWithBody(t *testing.T) {
sut := Wrap(new(MockRoundTripper))
request := new(http.Request)
request.Header = make(http.Header)

body := "body"
request.Body = io.NopCloser(strings.NewReader(body))

hasher := sha256.New()
hasher.Write([]byte(body))
hashBytes := hasher.Sum(nil)
expectedHash := hex.EncodeToString(hashBytes)

_, err := sut.RoundTrip(request)
if err != nil {
t.Fatal(err)
}

if header, contains := request.Header["X-Amz-Content-Sha256"]; !contains || len(header) != 1 || header[0] != expectedHash {
t.Fatal("Request with body doesn't contain X-Amz-Content-Sha256 header with correct hash value.")
}
}
21 changes: 21 additions & 0 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ const (
)

var awsUrlRegexp = regexp.MustCompile(`([a-z0-9-]+).es.amazonaws.com$`)
var awsOpensearchServerlessUrlRegexp = regexp.MustCompile(`([a-z0-9-]+).aoss.amazonaws.com$`)
var minimalOpensearchServerlessVersion = "2.0.0"

type ProviderConf struct {
rawUrl string
Expand Down Expand Up @@ -313,6 +315,25 @@ func getClient(conf *ProviderConf) (*elastic7.Client, error) {
return nil, err
}
opts = append(opts, elastic7.SetHttpClient(client), elastic7.SetSniff(false))
} else if m := awsOpensearchServerlessUrlRegexp.FindStringSubmatch(conf.parsedUrl.Hostname()); (m != nil || (conf.awsSig4Service == "aoss" && conf.awsRegion != "")) && conf.signAWSRequests {
var region string
if m != nil {
region = m[1]
} else {
region = conf.awsRegion
}
log.Printf("[INFO] Using AWS: %+v", region)
conf.awsSig4Service = "aoss"
client, err := awsHttpClient(region, conf, map[string]string{})
if err != nil {
return nil, err
}
client.Transport = Wrap(client.Transport)
opts = append(opts, elastic7.SetHttpClient(client), elastic7.SetSniff(false))
conf.flavor = OpenSearch
if conf.osVersion == "" {
conf.osVersion = minimalOpensearchServerlessVersion
}
} else if awsRegion := conf.awsRegion; conf.awsRegion != "" && conf.signAWSRequests {
log.Printf("[INFO] Using AWS: %+v", awsRegion)
client, err := awsHttpClient(awsRegion, conf, map[string]string{})
Expand Down

0 comments on commit 2245e1d

Please sign in to comment.