diff --git a/internal/datafs/fsys.go b/internal/datafs/fsys.go
index f7894b32c..03abdf6d6 100644
--- a/internal/datafs/fsys.go
+++ b/internal/datafs/fsys.go
@@ -101,6 +101,8 @@ func FSysForPath(ctx context.Context, path string) (fs.FS, error) {
fsys = vaultauth.WithAuthMethod(compositeVaultAuthMethod(fileFsys), fsys)
}
+ fsys = fsimpl.WithContextFS(ctx, fsys)
+
return fsys, nil
}
diff --git a/internal/datafs/vaultauth.go b/internal/datafs/vaultauth.go
index 3ad733a0f..83b9daa57 100644
--- a/internal/datafs/vaultauth.go
+++ b/internal/datafs/vaultauth.go
@@ -20,15 +20,12 @@ func compositeVaultAuthMethod(envFsys fs.FS) api.AuthMethod {
return vaultauth.CompositeAuthMethod(
vaultauth.EnvAuthMethod(),
envEC2AuthAdapter(envFsys),
+ envIAMAuthAdapter(envFsys),
)
}
-// func CompositeVaultAuthMethod() api.AuthMethod {
-// return compositeVaultAuthMethod(WrapWdFS(osfs.NewFS()))
-// }
-
// envEC2AuthAdapter builds an AWS EC2 authentication method from environment
-// variables, for use only with [CompositeVaultAuthMethod]
+// variables, for use only with [compositeVaultAuthMethod]
func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod {
mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws")
@@ -61,8 +58,34 @@ func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod {
return &ec2AuthNonceWriter{AWSAuth: awsauth, nonce: nonce, output: output}
}
+// envIAMAuthAdapter builds an AWS IAM authentication method from environment
+// variables, for use only with [compositeVaultAuthMethod]
+func envIAMAuthAdapter(envFS fs.FS) api.AuthMethod {
+ mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws")
+ role := GetenvFsys(envFS, "VAULT_AUTH_AWS_ROLE")
+
+ // temporary workaround while we wait to deprecate AWS_META_ENDPOINT
+ if endpoint := os.Getenv("AWS_META_ENDPOINT"); endpoint != "" {
+ deprecated.WarnDeprecated(context.Background(), "Use AWS_EC2_METADATA_SERVICE_ENDPOINT instead of AWS_META_ENDPOINT")
+ if os.Getenv("AWS_EC2_METADATA_SERVICE_ENDPOINT") == "" {
+ os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", endpoint)
+ }
+ }
+
+ awsauth, err := aws.NewAWSAuth(
+ aws.WithIAMAuth(),
+ aws.WithMountPath(mountPath),
+ aws.WithRole(role),
+ )
+ if err != nil {
+ return nil
+ }
+
+ return awsauth
+}
+
// ec2AuthNonceWriter - wraps an AWSAuth, and writes the nonce to the nonce
-// output file
+// output file - only for ec2 auth
type ec2AuthNonceWriter struct {
*aws.AWSAuth
nonce string
diff --git a/internal/tests/integration/datasources_vault_ec2_test.go b/internal/tests/integration/datasources_vault_ec2_test.go
index c00a8d43e..4b57020d0 100644
--- a/internal/tests/integration/datasources_vault_ec2_test.go
+++ b/internal/tests/integration/datasources_vault_ec2_test.go
@@ -4,71 +4,14 @@
package integration
import (
- "encoding/pem"
- "io"
- "net/http"
- "net/http/httptest"
"testing"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "gotest.tools/v3/fs"
)
-func setupDatasourcesVaultEc2Test(t *testing.T) (*fs.Dir, *vaultClient, *httptest.Server, []byte) {
- t.Helper()
-
- priv, der, _ := certificateGenerate()
- cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
-
- mux := http.NewServeMux()
- mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der))
- mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler)
- mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- var b []byte
- if r.Body != nil {
- var err error
- b, err = io.ReadAll(r.Body)
- if !assert.NoError(t, err) {
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
- defer r.Body.Close()
- }
- t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b)
-
- w.Write([]byte("testtoken"))
- }))
- mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- t.Logf("IMDS request: %s %s", r.Method, r.URL)
- w.Write([]byte("i-00000000"))
- }))
- mux.HandleFunc("/sts/", stsHandler)
- mux.HandleFunc("/ec2/", ec2Handler)
- mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- t.Logf("unhandled request: %s %s", r.Method, r.URL)
- w.WriteHeader(http.StatusNotFound)
- }))
-
- srv := httptest.NewServer(mux)
- t.Cleanup(srv.Close)
-
- tmpDir, v := startVault(t)
-
- err := v.vc.Sys().PutPolicy("writepol", `path "*" {
- policy = "write"
-}`)
- require.NoError(t, err)
- err = v.vc.Sys().PutPolicy("readpol", `path "*" {
- policy = "read"
-}`)
- require.NoError(t, err)
-
- return tmpDir, v, srv, cert
-}
-
func TestDatasources_VaultEc2(t *testing.T) {
- tmpDir, v, srv, cert := setupDatasourcesVaultEc2Test(t)
+ accountID, user := "1", "Test"
+ tmpDir, v, srv, cert := setupDatasourcesVaultAWSTest(t, accountID, user)
v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"})
defer v.vc.Logical().Delete("secret/foo")
diff --git a/internal/tests/integration/datasources_vault_iam_test.go b/internal/tests/integration/datasources_vault_iam_test.go
new file mode 100644
index 000000000..0cee8d3d6
--- /dev/null
+++ b/internal/tests/integration/datasources_vault_iam_test.go
@@ -0,0 +1,56 @@
+//go:build !windows
+// +build !windows
+
+package integration
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestDatasources_VaultIAM(t *testing.T) {
+ accountID := "000000000000"
+ user := "foo"
+
+ tmpDir, v, srv, _ := setupDatasourcesVaultAWSTest(t, accountID, user)
+
+ v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"})
+ defer v.vc.Logical().Delete("secret/foo")
+
+ err := v.vc.Sys().EnableAuth("aws", "aws", "")
+ require.NoError(t, err)
+ defer v.vc.Sys().DisableAuth("aws")
+
+ endpoint := srv.URL
+
+ accessKeyID := "secret"
+ secretAccessKey := "access"
+
+ _, err = v.vc.Logical().Write("auth/aws/config/client", map[string]interface{}{
+ "access_key": accessKeyID,
+ "secret_key": secretAccessKey,
+ "endpoint": endpoint,
+ "iam_endpoint": endpoint + "/iam",
+ "sts_endpoint": endpoint + "/sts",
+ "sts_region": "us-east-1",
+ })
+ require.NoError(t, err)
+
+ _, err = v.vc.Logical().Write("auth/aws/role/foo", map[string]interface{}{
+ "auth_type": "iam",
+ "bound_iam_principal_arn": "arn:aws:iam::" + accountID + ":*",
+ "policies": "readpol",
+ "max_ttl": "5m",
+ })
+ require.NoError(t, err)
+
+ o, e, err := cmd(t, "-d", "vault=vault:///secret/",
+ "-i", `{{(ds "vault" "foo").value}}`).
+ withEnv("HOME", tmpDir.Join("home")).
+ withEnv("VAULT_ADDR", "http://"+v.addr).
+ withEnv("AWS_ACCESS_KEY_ID", accessKeyID).
+ withEnv("AWS_SECRET_ACCESS_KEY", secretAccessKey).
+ run()
+ assertSuccess(t, o, e, err, "bar")
+}
diff --git a/internal/tests/integration/datasources_vault_test.go b/internal/tests/integration/datasources_vault_test.go
index 5fd0124c0..74a2512bb 100644
--- a/internal/tests/integration/datasources_vault_test.go
+++ b/internal/tests/integration/datasources_vault_test.go
@@ -69,7 +69,7 @@ func startVault(t *testing.T) (*fs.Dir, *vaultClient) {
"-dev",
"-dev-root-token-id="+vaultRootToken,
"-dev-kv-v1", // default to v1, so we can test v1 and v2
- "-log-level=err",
+ "-log-level=info",
"-dev-listen-address="+vaultAddr,
"-config="+tmpDir.Join("config.json"),
)
diff --git a/internal/tests/integration/test_ec2_utils.go b/internal/tests/integration/test_ec2_utils_test.go
similarity index 66%
rename from internal/tests/integration/test_ec2_utils.go
rename to internal/tests/integration/test_ec2_utils_test.go
index f7ad1eaae..29b92350d 100644
--- a/internal/tests/integration/test_ec2_utils.go
+++ b/internal/tests/integration/test_ec2_utils_test.go
@@ -1,3 +1,6 @@
+//go:build !windows
+// +build !windows
+
package integration
import (
@@ -7,12 +10,21 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
+ "fmt"
+ "io"
"log"
"math/big"
"net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
"time"
"github.com/fullsailor/pkcs7"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gotest.tools/v3/fs"
)
const instanceDocument = `{
@@ -106,21 +118,34 @@ func pkcsHandler(priv *rsa.PrivateKey, derBytes []byte) func(http.ResponseWriter
}
}
-func stsHandler(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "text/xml")
- _, err := w.Write([]byte(`
-
- arn:aws:iam::1:user/Test
- AKIAI44QH8DHBEXAMPLE
- 1
-
-
- 01234567-89ab-cdef-0123-456789abcdef
-
-`))
- if err != nil {
- w.WriteHeader(500)
- }
+func stsHandler(t *testing.T, accountID, user string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ body, _ := io.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ form, _ := url.ParseQuery(string(body))
+
+ // action must be GetCallerIdentity
+ assert.Equal(t, "GetCallerIdentity", form.Get("Action"))
+
+ w.Header().Set("Content-Type", "text/xml")
+ _, err := w.Write([]byte(fmt.Sprintf(`
+
+
+ arn:aws:iam::%[1]s:user/%[2]s
+ AKIAI44QH8DHBEXAMPLE
+ %[1]s
+
+
+ 01234567-89ab-cdef-0123-456789abcdef
+
+ `, accountID, user)))
+ if err != nil {
+ t.Errorf("failed to write response: %s", err)
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ assert.NoError(t, err)
+ })
}
func ec2Handler(w http.ResponseWriter, _ *http.Request) {
@@ -246,6 +271,100 @@ func ec2Handler(w http.ResponseWriter, _ *http.Request) {
`))
if err != nil {
- w.WriteHeader(500)
+ w.WriteHeader(http.StatusInternalServerError)
}
}
+
+func iamGetUserHandler(t *testing.T, accountID string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ body, _ := io.ReadAll(r.Body)
+ form, _ := url.ParseQuery(string(body))
+
+ // action must be GetUser
+ assert.Equal(t, "GetUser", form.Get("Action"))
+
+ w.Header().Set("Content-Type", "text/xml")
+ _, err := w.Write([]byte(fmt.Sprintf(`
+
+
+
+ /
+ %[1]s
+ m3o9qmhhl9dnjlh2fflg
+ arn:aws:iam::%[2]s:user/%[1]s
+ 2024-07-21T17:21:27.259000Z
+
+
+
+ 3d0e2445-64ea-4bfb-9244-30d810773f9e
+
+ `, form.Get("UserName"), accountID)))
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ assert.NoError(t, err)
+ })
+}
+
+func setupDatasourcesVaultAWSTest(t *testing.T, accountID, user string) (*fs.Dir, *vaultClient, *httptest.Server, []byte) {
+ t.Helper()
+
+ priv, der, _ := certificateGenerate()
+ cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der))
+ mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler)
+ mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var b []byte
+ if r.Body != nil {
+ var err error
+ b, err = io.ReadAll(r.Body)
+ if !assert.NoError(t, err) {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ defer r.Body.Close()
+ }
+ t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b)
+
+ w.Write([]byte("testtoken"))
+ }))
+ mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ t.Logf("IMDS request: %s %s", r.Method, r.URL)
+ w.Write([]byte("i-00000000"))
+ }))
+ mux.Handle("/sts/", stsHandler(t, accountID, user))
+ mux.Handle("/iam/", iamGetUserHandler(t, accountID))
+ mux.HandleFunc("/ec2/", ec2Handler)
+ mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ t.Logf("unhandled request: %s %s", r.Method, r.URL)
+ w.WriteHeader(http.StatusNotFound)
+ }))
+
+ // Vault sends requests to "/sts///" for some reason, and the ServeMux
+ // responds by redirecting to "/sts/" which Vault rejects. So we need to
+ // handle the extra slashes in a middleware first.
+ stripSlashes := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ for strings.HasSuffix(r.URL.Path, "//") {
+ r.URL.Path = r.URL.Path[:len(r.URL.Path)-1]
+ }
+ mux.ServeHTTP(w, r)
+ })
+
+ srv := httptest.NewServer(stripSlashes)
+ t.Cleanup(srv.Close)
+
+ tmpDir, v := startVault(t)
+
+ err := v.vc.Sys().PutPolicy("writepol", `path "*" {
+ policy = "write"
+}`)
+ require.NoError(t, err)
+ err = v.vc.Sys().PutPolicy("readpol", `path "*" {
+ policy = "read"
+}`)
+ require.NoError(t, err)
+
+ return tmpDir, v, srv, cert
+}