Skip to content

Commit

Permalink
feat(fs): Support Vault AWS IAM auth (#2264)
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Henderson <[email protected]>
  • Loading branch information
hairyhenderson authored Nov 17, 2024
1 parent b772227 commit 1da9105
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 82 deletions.
2 changes: 2 additions & 0 deletions internal/datafs/fsys.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ func FSysForPath(ctx context.Context, path string) (fs.FS, error) {
fsys = vaultauth.WithAuthMethod(compositeVaultAuthMethod(fileFsys), fsys)
}

fsys = fsimpl.WithContextFS(ctx, fsys)

return fsys, nil
}

Expand Down
35 changes: 29 additions & 6 deletions internal/datafs/vaultauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,12 @@ func compositeVaultAuthMethod(envFsys fs.FS) api.AuthMethod {
return vaultauth.CompositeAuthMethod(
vaultauth.EnvAuthMethod(),
envEC2AuthAdapter(envFsys),
envIAMAuthAdapter(envFsys),
)
}

// func CompositeVaultAuthMethod() api.AuthMethod {
// return compositeVaultAuthMethod(WrapWdFS(osfs.NewFS()))
// }

// envEC2AuthAdapter builds an AWS EC2 authentication method from environment
// variables, for use only with [CompositeVaultAuthMethod]
// variables, for use only with [compositeVaultAuthMethod]
func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod {
mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws")

Expand Down Expand Up @@ -61,8 +58,34 @@ func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod {
return &ec2AuthNonceWriter{AWSAuth: awsauth, nonce: nonce, output: output}
}

// envIAMAuthAdapter builds an AWS IAM authentication method from environment
// variables, for use only with [compositeVaultAuthMethod]
func envIAMAuthAdapter(envFS fs.FS) api.AuthMethod {
mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws")
role := GetenvFsys(envFS, "VAULT_AUTH_AWS_ROLE")

// temporary workaround while we wait to deprecate AWS_META_ENDPOINT
if endpoint := os.Getenv("AWS_META_ENDPOINT"); endpoint != "" {
deprecated.WarnDeprecated(context.Background(), "Use AWS_EC2_METADATA_SERVICE_ENDPOINT instead of AWS_META_ENDPOINT")
if os.Getenv("AWS_EC2_METADATA_SERVICE_ENDPOINT") == "" {
os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", endpoint)
}
}

awsauth, err := aws.NewAWSAuth(
aws.WithIAMAuth(),
aws.WithMountPath(mountPath),
aws.WithRole(role),
)
if err != nil {
return nil
}

return awsauth
}

// ec2AuthNonceWriter - wraps an AWSAuth, and writes the nonce to the nonce
// output file
// output file - only for ec2 auth
type ec2AuthNonceWriter struct {
*aws.AWSAuth
nonce string
Expand Down
61 changes: 2 additions & 59 deletions internal/tests/integration/datasources_vault_ec2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,14 @@
package integration

import (
"encoding/pem"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gotest.tools/v3/fs"
)

func setupDatasourcesVaultEc2Test(t *testing.T) (*fs.Dir, *vaultClient, *httptest.Server, []byte) {
t.Helper()

priv, der, _ := certificateGenerate()
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})

mux := http.NewServeMux()
mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der))
mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler)
mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var b []byte
if r.Body != nil {
var err error
b, err = io.ReadAll(r.Body)
if !assert.NoError(t, err) {
w.WriteHeader(http.StatusInternalServerError)
return
}
defer r.Body.Close()
}
t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b)

w.Write([]byte("testtoken"))
}))
mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("IMDS request: %s %s", r.Method, r.URL)
w.Write([]byte("i-00000000"))
}))
mux.HandleFunc("/sts/", stsHandler)
mux.HandleFunc("/ec2/", ec2Handler)
mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("unhandled request: %s %s", r.Method, r.URL)
w.WriteHeader(http.StatusNotFound)
}))

srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)

tmpDir, v := startVault(t)

err := v.vc.Sys().PutPolicy("writepol", `path "*" {
policy = "write"
}`)
require.NoError(t, err)
err = v.vc.Sys().PutPolicy("readpol", `path "*" {
policy = "read"
}`)
require.NoError(t, err)

return tmpDir, v, srv, cert
}

func TestDatasources_VaultEc2(t *testing.T) {
tmpDir, v, srv, cert := setupDatasourcesVaultEc2Test(t)
accountID, user := "1", "Test"
tmpDir, v, srv, cert := setupDatasourcesVaultAWSTest(t, accountID, user)

v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"})
defer v.vc.Logical().Delete("secret/foo")
Expand Down
56 changes: 56 additions & 0 deletions internal/tests/integration/datasources_vault_iam_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//go:build !windows
// +build !windows

package integration

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestDatasources_VaultIAM(t *testing.T) {
accountID := "000000000000"
user := "foo"

tmpDir, v, srv, _ := setupDatasourcesVaultAWSTest(t, accountID, user)

v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"})
defer v.vc.Logical().Delete("secret/foo")

err := v.vc.Sys().EnableAuth("aws", "aws", "")
require.NoError(t, err)
defer v.vc.Sys().DisableAuth("aws")

endpoint := srv.URL

accessKeyID := "secret"
secretAccessKey := "access"

_, err = v.vc.Logical().Write("auth/aws/config/client", map[string]interface{}{
"access_key": accessKeyID,
"secret_key": secretAccessKey,
"endpoint": endpoint,
"iam_endpoint": endpoint + "/iam",
"sts_endpoint": endpoint + "/sts",
"sts_region": "us-east-1",
})
require.NoError(t, err)

_, err = v.vc.Logical().Write("auth/aws/role/foo", map[string]interface{}{
"auth_type": "iam",
"bound_iam_principal_arn": "arn:aws:iam::" + accountID + ":*",
"policies": "readpol",
"max_ttl": "5m",
})
require.NoError(t, err)

o, e, err := cmd(t, "-d", "vault=vault:///secret/",
"-i", `{{(ds "vault" "foo").value}}`).
withEnv("HOME", tmpDir.Join("home")).
withEnv("VAULT_ADDR", "http://"+v.addr).
withEnv("AWS_ACCESS_KEY_ID", accessKeyID).
withEnv("AWS_SECRET_ACCESS_KEY", secretAccessKey).
run()
assertSuccess(t, o, e, err, "bar")
}
2 changes: 1 addition & 1 deletion internal/tests/integration/datasources_vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func startVault(t *testing.T) (*fs.Dir, *vaultClient) {
"-dev",
"-dev-root-token-id="+vaultRootToken,
"-dev-kv-v1", // default to v1, so we can test v1 and v2
"-log-level=err",
"-log-level=info",
"-dev-listen-address="+vaultAddr,
"-config="+tmpDir.Join("config.json"),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//go:build !windows
// +build !windows

package integration

import (
Expand All @@ -7,12 +10,21 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"log"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/fullsailor/pkcs7"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gotest.tools/v3/fs"
)

const instanceDocument = `{
Expand Down Expand Up @@ -106,21 +118,34 @@ func pkcsHandler(priv *rsa.PrivateKey, derBytes []byte) func(http.ResponseWriter
}
}

func stsHandler(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/xml")
_, err := w.Write([]byte(`<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>arn:aws:iam::1:user/Test</Arn>
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
<Account>1</Account>
</GetCallerIdentityResult>
<ResponseMetadata>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ResponseMetadata>
</GetCallerIdentityResponse>`))
if err != nil {
w.WriteHeader(500)
}
func stsHandler(t *testing.T, accountID, user string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()

form, _ := url.ParseQuery(string(body))

// action must be GetCallerIdentity
assert.Equal(t, "GetCallerIdentity", form.Get("Action"))

w.Header().Set("Content-Type", "text/xml")
_, err := w.Write([]byte(fmt.Sprintf(`<?xml version='1.0' encoding='utf-8'?>
<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>arn:aws:iam::%[1]s:user/%[2]s</Arn>
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
<Account>%[1]s</Account>
</GetCallerIdentityResult>
<ResponseMetadata>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ResponseMetadata>
</GetCallerIdentityResponse>`, accountID, user)))
if err != nil {
t.Errorf("failed to write response: %s", err)
w.WriteHeader(http.StatusInternalServerError)
}
assert.NoError(t, err)
})
}

func ec2Handler(w http.ResponseWriter, _ *http.Request) {
Expand Down Expand Up @@ -246,6 +271,100 @@ func ec2Handler(w http.ResponseWriter, _ *http.Request) {
</reservationSet>
</DescribeInstancesResponse>`))
if err != nil {
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
}
}

func iamGetUserHandler(t *testing.T, accountID string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
form, _ := url.ParseQuery(string(body))

// action must be GetUser
assert.Equal(t, "GetUser", form.Get("Action"))

w.Header().Set("Content-Type", "text/xml")
_, err := w.Write([]byte(fmt.Sprintf(`<?xml version='1.0' encoding='utf-8'?>
<GetUserResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<GetUserResult>
<User>
<Path>/</Path>
<UserName>%[1]s</UserName>
<UserId>m3o9qmhhl9dnjlh2fflg</UserId>
<Arn>arn:aws:iam::%[2]s:user/%[1]s</Arn>
<CreateDate>2024-07-21T17:21:27.259000Z</CreateDate>
</User>
</GetUserResult>
<ResponseMetadata>
<RequestId>3d0e2445-64ea-4bfb-9244-30d810773f9e</RequestId>
</ResponseMetadata>
</GetUserResponse>`, form.Get("UserName"), accountID)))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
assert.NoError(t, err)
})
}

func setupDatasourcesVaultAWSTest(t *testing.T, accountID, user string) (*fs.Dir, *vaultClient, *httptest.Server, []byte) {
t.Helper()

priv, der, _ := certificateGenerate()
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})

mux := http.NewServeMux()
mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der))
mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler)
mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var b []byte
if r.Body != nil {
var err error
b, err = io.ReadAll(r.Body)
if !assert.NoError(t, err) {
w.WriteHeader(http.StatusInternalServerError)
return
}
defer r.Body.Close()
}
t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b)

w.Write([]byte("testtoken"))
}))
mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("IMDS request: %s %s", r.Method, r.URL)
w.Write([]byte("i-00000000"))
}))
mux.Handle("/sts/", stsHandler(t, accountID, user))
mux.Handle("/iam/", iamGetUserHandler(t, accountID))
mux.HandleFunc("/ec2/", ec2Handler)
mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("unhandled request: %s %s", r.Method, r.URL)
w.WriteHeader(http.StatusNotFound)
}))

// Vault sends requests to "/sts///" for some reason, and the ServeMux
// responds by redirecting to "/sts/" which Vault rejects. So we need to
// handle the extra slashes in a middleware first.
stripSlashes := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for strings.HasSuffix(r.URL.Path, "//") {
r.URL.Path = r.URL.Path[:len(r.URL.Path)-1]
}
mux.ServeHTTP(w, r)
})

srv := httptest.NewServer(stripSlashes)
t.Cleanup(srv.Close)

tmpDir, v := startVault(t)

err := v.vc.Sys().PutPolicy("writepol", `path "*" {
policy = "write"
}`)
require.NoError(t, err)
err = v.vc.Sys().PutPolicy("readpol", `path "*" {
policy = "read"
}`)
require.NoError(t, err)

return tmpDir, v, srv, cert
}

0 comments on commit 1da9105

Please sign in to comment.