diff --git a/internal/datafs/fsys.go b/internal/datafs/fsys.go index f7894b32c..03abdf6d6 100644 --- a/internal/datafs/fsys.go +++ b/internal/datafs/fsys.go @@ -101,6 +101,8 @@ func FSysForPath(ctx context.Context, path string) (fs.FS, error) { fsys = vaultauth.WithAuthMethod(compositeVaultAuthMethod(fileFsys), fsys) } + fsys = fsimpl.WithContextFS(ctx, fsys) + return fsys, nil } diff --git a/internal/datafs/vaultauth.go b/internal/datafs/vaultauth.go index 3ad733a0f..83b9daa57 100644 --- a/internal/datafs/vaultauth.go +++ b/internal/datafs/vaultauth.go @@ -20,15 +20,12 @@ func compositeVaultAuthMethod(envFsys fs.FS) api.AuthMethod { return vaultauth.CompositeAuthMethod( vaultauth.EnvAuthMethod(), envEC2AuthAdapter(envFsys), + envIAMAuthAdapter(envFsys), ) } -// func CompositeVaultAuthMethod() api.AuthMethod { -// return compositeVaultAuthMethod(WrapWdFS(osfs.NewFS())) -// } - // envEC2AuthAdapter builds an AWS EC2 authentication method from environment -// variables, for use only with [CompositeVaultAuthMethod] +// variables, for use only with [compositeVaultAuthMethod] func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod { mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws") @@ -61,8 +58,34 @@ func envEC2AuthAdapter(envFS fs.FS) api.AuthMethod { return &ec2AuthNonceWriter{AWSAuth: awsauth, nonce: nonce, output: output} } +// envIAMAuthAdapter builds an AWS IAM authentication method from environment +// variables, for use only with [compositeVaultAuthMethod] +func envIAMAuthAdapter(envFS fs.FS) api.AuthMethod { + mountPath := GetenvFsys(envFS, "VAULT_AUTH_AWS_MOUNT", "aws") + role := GetenvFsys(envFS, "VAULT_AUTH_AWS_ROLE") + + // temporary workaround while we wait to deprecate AWS_META_ENDPOINT + if endpoint := os.Getenv("AWS_META_ENDPOINT"); endpoint != "" { + deprecated.WarnDeprecated(context.Background(), "Use AWS_EC2_METADATA_SERVICE_ENDPOINT instead of AWS_META_ENDPOINT") + if os.Getenv("AWS_EC2_METADATA_SERVICE_ENDPOINT") == "" { + os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", endpoint) + } + } + + awsauth, err := aws.NewAWSAuth( + aws.WithIAMAuth(), + aws.WithMountPath(mountPath), + aws.WithRole(role), + ) + if err != nil { + return nil + } + + return awsauth +} + // ec2AuthNonceWriter - wraps an AWSAuth, and writes the nonce to the nonce -// output file +// output file - only for ec2 auth type ec2AuthNonceWriter struct { *aws.AWSAuth nonce string diff --git a/internal/tests/integration/datasources_vault_ec2_test.go b/internal/tests/integration/datasources_vault_ec2_test.go index c00a8d43e..4b57020d0 100644 --- a/internal/tests/integration/datasources_vault_ec2_test.go +++ b/internal/tests/integration/datasources_vault_ec2_test.go @@ -4,71 +4,14 @@ package integration import ( - "encoding/pem" - "io" - "net/http" - "net/http/httptest" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gotest.tools/v3/fs" ) -func setupDatasourcesVaultEc2Test(t *testing.T) (*fs.Dir, *vaultClient, *httptest.Server, []byte) { - t.Helper() - - priv, der, _ := certificateGenerate() - cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) - - mux := http.NewServeMux() - mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der)) - mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler) - mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var b []byte - if r.Body != nil { - var err error - b, err = io.ReadAll(r.Body) - if !assert.NoError(t, err) { - w.WriteHeader(http.StatusInternalServerError) - return - } - defer r.Body.Close() - } - t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b) - - w.Write([]byte("testtoken")) - })) - mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Logf("IMDS request: %s %s", r.Method, r.URL) - w.Write([]byte("i-00000000")) - })) - mux.HandleFunc("/sts/", stsHandler) - mux.HandleFunc("/ec2/", ec2Handler) - mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Logf("unhandled request: %s %s", r.Method, r.URL) - w.WriteHeader(http.StatusNotFound) - })) - - srv := httptest.NewServer(mux) - t.Cleanup(srv.Close) - - tmpDir, v := startVault(t) - - err := v.vc.Sys().PutPolicy("writepol", `path "*" { - policy = "write" -}`) - require.NoError(t, err) - err = v.vc.Sys().PutPolicy("readpol", `path "*" { - policy = "read" -}`) - require.NoError(t, err) - - return tmpDir, v, srv, cert -} - func TestDatasources_VaultEc2(t *testing.T) { - tmpDir, v, srv, cert := setupDatasourcesVaultEc2Test(t) + accountID, user := "1", "Test" + tmpDir, v, srv, cert := setupDatasourcesVaultAWSTest(t, accountID, user) v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"}) defer v.vc.Logical().Delete("secret/foo") diff --git a/internal/tests/integration/datasources_vault_iam_test.go b/internal/tests/integration/datasources_vault_iam_test.go new file mode 100644 index 000000000..0cee8d3d6 --- /dev/null +++ b/internal/tests/integration/datasources_vault_iam_test.go @@ -0,0 +1,56 @@ +//go:build !windows +// +build !windows + +package integration + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDatasources_VaultIAM(t *testing.T) { + accountID := "000000000000" + user := "foo" + + tmpDir, v, srv, _ := setupDatasourcesVaultAWSTest(t, accountID, user) + + v.vc.Logical().Write("secret/foo", map[string]interface{}{"value": "bar"}) + defer v.vc.Logical().Delete("secret/foo") + + err := v.vc.Sys().EnableAuth("aws", "aws", "") + require.NoError(t, err) + defer v.vc.Sys().DisableAuth("aws") + + endpoint := srv.URL + + accessKeyID := "secret" + secretAccessKey := "access" + + _, err = v.vc.Logical().Write("auth/aws/config/client", map[string]interface{}{ + "access_key": accessKeyID, + "secret_key": secretAccessKey, + "endpoint": endpoint, + "iam_endpoint": endpoint + "/iam", + "sts_endpoint": endpoint + "/sts", + "sts_region": "us-east-1", + }) + require.NoError(t, err) + + _, err = v.vc.Logical().Write("auth/aws/role/foo", map[string]interface{}{ + "auth_type": "iam", + "bound_iam_principal_arn": "arn:aws:iam::" + accountID + ":*", + "policies": "readpol", + "max_ttl": "5m", + }) + require.NoError(t, err) + + o, e, err := cmd(t, "-d", "vault=vault:///secret/", + "-i", `{{(ds "vault" "foo").value}}`). + withEnv("HOME", tmpDir.Join("home")). + withEnv("VAULT_ADDR", "http://"+v.addr). + withEnv("AWS_ACCESS_KEY_ID", accessKeyID). + withEnv("AWS_SECRET_ACCESS_KEY", secretAccessKey). + run() + assertSuccess(t, o, e, err, "bar") +} diff --git a/internal/tests/integration/datasources_vault_test.go b/internal/tests/integration/datasources_vault_test.go index 5fd0124c0..74a2512bb 100644 --- a/internal/tests/integration/datasources_vault_test.go +++ b/internal/tests/integration/datasources_vault_test.go @@ -69,7 +69,7 @@ func startVault(t *testing.T) (*fs.Dir, *vaultClient) { "-dev", "-dev-root-token-id="+vaultRootToken, "-dev-kv-v1", // default to v1, so we can test v1 and v2 - "-log-level=err", + "-log-level=info", "-dev-listen-address="+vaultAddr, "-config="+tmpDir.Join("config.json"), ) diff --git a/internal/tests/integration/test_ec2_utils.go b/internal/tests/integration/test_ec2_utils_test.go similarity index 66% rename from internal/tests/integration/test_ec2_utils.go rename to internal/tests/integration/test_ec2_utils_test.go index f7ad1eaae..29b92350d 100644 --- a/internal/tests/integration/test_ec2_utils.go +++ b/internal/tests/integration/test_ec2_utils_test.go @@ -1,3 +1,6 @@ +//go:build !windows +// +build !windows + package integration import ( @@ -7,12 +10,21 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" + "io" "log" "math/big" "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" "time" "github.com/fullsailor/pkcs7" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gotest.tools/v3/fs" ) const instanceDocument = `{ @@ -106,21 +118,34 @@ func pkcsHandler(priv *rsa.PrivateKey, derBytes []byte) func(http.ResponseWriter } } -func stsHandler(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/xml") - _, err := w.Write([]byte(` - - arn:aws:iam::1:user/Test - AKIAI44QH8DHBEXAMPLE - 1 - - - 01234567-89ab-cdef-0123-456789abcdef - -`)) - if err != nil { - w.WriteHeader(500) - } +func stsHandler(t *testing.T, accountID, user string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + + form, _ := url.ParseQuery(string(body)) + + // action must be GetCallerIdentity + assert.Equal(t, "GetCallerIdentity", form.Get("Action")) + + w.Header().Set("Content-Type", "text/xml") + _, err := w.Write([]byte(fmt.Sprintf(` + + + arn:aws:iam::%[1]s:user/%[2]s + AKIAI44QH8DHBEXAMPLE + %[1]s + + + 01234567-89ab-cdef-0123-456789abcdef + + `, accountID, user))) + if err != nil { + t.Errorf("failed to write response: %s", err) + w.WriteHeader(http.StatusInternalServerError) + } + assert.NoError(t, err) + }) } func ec2Handler(w http.ResponseWriter, _ *http.Request) { @@ -246,6 +271,100 @@ func ec2Handler(w http.ResponseWriter, _ *http.Request) { `)) if err != nil { - w.WriteHeader(500) + w.WriteHeader(http.StatusInternalServerError) } } + +func iamGetUserHandler(t *testing.T, accountID string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + form, _ := url.ParseQuery(string(body)) + + // action must be GetUser + assert.Equal(t, "GetUser", form.Get("Action")) + + w.Header().Set("Content-Type", "text/xml") + _, err := w.Write([]byte(fmt.Sprintf(` + + + + / + %[1]s + m3o9qmhhl9dnjlh2fflg + arn:aws:iam::%[2]s:user/%[1]s + 2024-07-21T17:21:27.259000Z + + + + 3d0e2445-64ea-4bfb-9244-30d810773f9e + + `, form.Get("UserName"), accountID))) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + assert.NoError(t, err) + }) +} + +func setupDatasourcesVaultAWSTest(t *testing.T, accountID, user string) (*fs.Dir, *vaultClient, *httptest.Server, []byte) { + t.Helper() + + priv, der, _ := certificateGenerate() + cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + + mux := http.NewServeMux() + mux.HandleFunc("/latest/dynamic/instance-identity/pkcs7", pkcsHandler(priv, der)) + mux.HandleFunc("/latest/dynamic/instance-identity/document", instanceDocumentHandler) + mux.HandleFunc("/latest/api/token", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var b []byte + if r.Body != nil { + var err error + b, err = io.ReadAll(r.Body) + if !assert.NoError(t, err) { + w.WriteHeader(http.StatusInternalServerError) + return + } + defer r.Body.Close() + } + t.Logf("IMDS Token request: %s %s: %s", r.Method, r.URL, b) + + w.Write([]byte("testtoken")) + })) + mux.HandleFunc("/latest/meta-data/instance-id", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("IMDS request: %s %s", r.Method, r.URL) + w.Write([]byte("i-00000000")) + })) + mux.Handle("/sts/", stsHandler(t, accountID, user)) + mux.Handle("/iam/", iamGetUserHandler(t, accountID)) + mux.HandleFunc("/ec2/", ec2Handler) + mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("unhandled request: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + })) + + // Vault sends requests to "/sts///" for some reason, and the ServeMux + // responds by redirecting to "/sts/" which Vault rejects. So we need to + // handle the extra slashes in a middleware first. + stripSlashes := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for strings.HasSuffix(r.URL.Path, "//") { + r.URL.Path = r.URL.Path[:len(r.URL.Path)-1] + } + mux.ServeHTTP(w, r) + }) + + srv := httptest.NewServer(stripSlashes) + t.Cleanup(srv.Close) + + tmpDir, v := startVault(t) + + err := v.vc.Sys().PutPolicy("writepol", `path "*" { + policy = "write" +}`) + require.NoError(t, err) + err = v.vc.Sys().PutPolicy("readpol", `path "*" { + policy = "read" +}`) + require.NoError(t, err) + + return tmpDir, v, srv, cert +}