diff --git a/Dockerfile b/Dockerfile index 37c07f7..cd5a2a2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang@sha256:ec67c62f48ddfbca1ccaef18f9b3addccd707e1885fa28702a3954340786fcf6 as dependency +FROM golang:1.18.1 as dependency WORKDIR /work ADD ./go.* ./ RUN go mod download diff --git a/Taskfile.yml b/Taskfile.yml index 970f41c..225fd31 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -137,9 +137,6 @@ tasks: go:integration-tests: cmds: - KUBECTL_CONTEXT=kind-{{.KIND_CLUSTER_NAME}} go test --tags=integration ./... - go:mod-tidy: - cmds: - - go mod tidy example:load: desc: load demo data cmds: diff --git a/cmd/kubernetes-kms-vault/main.go b/cmd/kubernetes-kms-vault/main.go index 0d6c62b..ee7b810 100644 --- a/cmd/kubernetes-kms-vault/main.go +++ b/cmd/kubernetes-kms-vault/main.go @@ -24,7 +24,10 @@ import ( ) const ( - healthPort = 8787 + defaultHealthzTimeout = 20 * time.Second + hostPortFormatBase = 10 + + healthPort = 8787 metricsPort = "8095" ) @@ -34,7 +37,7 @@ var ( configFilePath = flag.String("config-file-path", "./config.yaml", "Path for Vault Provider config file") healthzPort = flag.Int("healthz-port", healthPort, "port for health check") healthzPath = flag.String("healthz-path", "/healthz", "path for health check") - healthzTimeout = flag.Duration("healthz-timeout", 20*time.Second, "RPC timeout for health check") + healthzTimeout = flag.Duration("healthz-timeout", defaultHealthzTimeout, "RPC timeout for health check") metricsBackend = flag.String("metrics-backend", "prometheus", "Backend used for metrics") metricsAddress = flag.String("metrics-addr", metricsPort, "The address the metric endpoint binds to") ) @@ -43,9 +46,11 @@ func main() { klog.InitFlags(nil) flag.Parse() + if *logFormatJSON { klog.SetLogger(json.JSONLogger) } + ctx := withShutdownSignal(context.Background()) // initialize metrics exporter @@ -55,20 +60,24 @@ func main() { klog.Errorln(err) os.Exit(1) } + klog.Fatalln("metrics service has stopped gracefully") }() klog.InfoS("Starting VaultEncryptionServiceServer service", "version", version.BuildVersion, "buildDate", version.BuildDate) + cfg, err := config.New(*configFilePath) if err != nil { klog.Errorln(err) os.Exit(1) } + proto, addr, err := utils.ParseEndpoint(*listenAddr) if err != nil { klog.Errorln(err) os.Exit(1) } + listener, err := net.Listen(proto, addr) if err != nil { klog.Errorln(err) @@ -82,34 +91,39 @@ func main() { s := grpc.NewServer(opts...) kmsServer, err := server.New(ctx, cfg) pb.RegisterKeyManagementServiceServer(s, kmsServer) + if err != nil { klog.Errorln(fmt.Errorf("failed to listen: %w", err)) os.Exit(1) } + klog.Infof("Listening for connections on address: %v", listener.Addr()) + go func() { - err := s.Serve(listener) - if err != nil { + if err := s.Serve(listener); err != nil { klog.Errorln(err) os.Exit(1) } + klog.Fatalln("GRPC service has stopped gracefully") }() + healthz := &server.HealthZ{ Service: kmsServer, HealthCheckURL: &url.URL{ - Host: net.JoinHostPort("", strconv.FormatUint(uint64(*healthzPort), 10)), + Host: net.JoinHostPort("", strconv.FormatUint(uint64(*healthzPort), hostPortFormatBase)), Path: *healthzPath, }, UnixSocketPath: listener.Addr().String(), RPCTimeout: *healthzTimeout, } + go func() { - err := healthz.Serve() - if err != nil { + if err := healthz.Serve(); err != nil { klog.Errorln(err) os.Exit(1) } + klog.Fatalln("healtz service has stopped gracefully") }() @@ -135,5 +149,6 @@ func withShutdownSignal(ctx context.Context) context.Context { klog.Info("received shutdown signal") cancel() }() + return nctx } diff --git a/go.mod b/go.mod index cbf0d6a..6996505 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ondat/trousseau -go 1.17 +go 1.18 replace github.com/ondat/trousseau => ./ diff --git a/internal/config/config.go b/internal/config/config.go index 92cd2dd..e3545d6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,19 +20,22 @@ type ProviderConfig interface { func New(cfpPath string) (ProviderConfig, error) { klog.V(klogv).Infof("Populating AppConfig from %s", cfpPath) viper.SetConfigType("yaml") + file, err := os.ReadFile(filepath.Clean(cfpPath)) if err != nil { return nil, fmt.Errorf("unable to open config file %s: %w", cfpPath, err) } + err = viper.ReadConfig(bytes.NewBuffer(file)) if err != nil { return nil, fmt.Errorf("unable to read config file %s: %w", cfpPath, err) } + var cfg appConfig - err = viper.Unmarshal(&cfg) - if err != nil { + if err = viper.Unmarshal(&cfg); err != nil { return nil, fmt.Errorf("unable to unmarshal config file %s: %w", cfpPath, err) } + return &cfg, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 8c2c708..4b677ba 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -20,8 +20,11 @@ vault: func TestMain(m *testing.M) { setUp() + retCode := m.Run() + tearDown() + os.Exit(retCode) } @@ -30,22 +33,24 @@ func setUp() { if err != nil { log.Fatal(err) } - defer f.Close() + _, err = f.Write(data) + f.Close() + if err != nil { log.Fatal(err) } } func tearDown() { - err := os.Remove(file) - if err != nil { + if err := os.Remove(file); err != nil { log.Fatal(err) } } func TestParseProvderInConfig(t *testing.T) { r, err := cfg.New(file) + assert.NoError(t, err) assert.Equal(t, "vault", r.GetProvider(), "Provider should return vault") } @@ -54,6 +59,7 @@ func TestParseVaultAddressInConfig(t *testing.T) { r, err := cfg.New(file) vaultCfg := r.GetVaultConfig() + assert.NoError(t, err) assert.Equal(t, "http://localhost:9200", vaultCfg.Address, "Config should return vault address") } diff --git a/internal/encrypt/vault.go b/internal/encrypt/vault.go index dc5802d..b93058a 100644 --- a/internal/encrypt/vault.go +++ b/internal/encrypt/vault.go @@ -16,7 +16,7 @@ import ( const ( defaultTransitPath = "transit" - defaultAuthPath = "auth" + defaultAuthPath = "auth" ) // Handle all communication with Vault server. @@ -30,8 +30,8 @@ type vaultWrapper struct { } // Initialize a client wrapper for vault kms provider. -func newClientWrapper(config *config.VaultConfig) (*vaultWrapper, error) { - client, err := newVaultApiClient(config) +func newClientWrapper(vaultConfig *config.VaultConfig) (*vaultWrapper, error) { + client, err := newVaultAPIClient(vaultConfig) if err != nil { return nil, fmt.Errorf("unable to create vault client: %w", err) } @@ -39,28 +39,29 @@ func newClientWrapper(config *config.VaultConfig) (*vaultWrapper, error) { // Vault transit path is configurable. "path", "/path", "path/" and "/path/" // are the same. transit := defaultTransitPath - if config.TransitPath != "" { - transit = config.TransitPath + if vaultConfig.TransitPath != "" { + transit = vaultConfig.TransitPath } // auth path is configurable. "path", "/path", "path/" and "/path/" are the same. auth := defaultAuthPath - if config.AuthPath != "" { - auth = config.AuthPath + if vaultConfig.AuthPath != "" { + auth = vaultConfig.AuthPath } + wrapper := &vaultWrapper{ client: client, encryptPath: path.Join("v1", transit, "encrypt"), decryptPath: path.Join("v1", transit, "decrypt"), authPath: path.Join(auth), - config: config, + config: vaultConfig, } // Set token for the vaultapi.client. - if config.Token != "" { - client.SetToken(config.Token) + if vaultConfig.Token != "" { + client.SetToken(vaultConfig.Token) } else { - if err := wrapper.getInitialToken(config); err != nil { + if err := wrapper.getInitialToken(vaultConfig); err != nil { return nil, fmt.Errorf("unable to get initial token: %w", err) } } @@ -68,36 +69,38 @@ func newClientWrapper(config *config.VaultConfig) (*vaultWrapper, error) { return wrapper, nil } -func newVaultApiClient(config *config.VaultConfig) (*vaultapi.Client, error) { - vaultConfig := vaultapi.DefaultConfig() - vaultConfig.Address = config.Address +func newVaultAPIClient(vaultConfig *config.VaultConfig) (*vaultapi.Client, error) { + defaultConfig := vaultapi.DefaultConfig() + defaultConfig.Address = vaultConfig.Address tlsConfig := &vaultapi.TLSConfig{ - CACert: config.VaultCACert, - ClientCert: config.ClientCert, - ClientKey: config.ClientKey, - TLSServerName: config.TLSServerName, + CACert: vaultConfig.VaultCACert, + ClientCert: vaultConfig.ClientCert, + ClientKey: vaultConfig.ClientKey, + TLSServerName: vaultConfig.TLSServerName, } - if err := vaultConfig.ConfigureTLS(tlsConfig); err != nil { - return nil, fmt.Errorf("unable to configure TLS for %s: %w", config.TLSServerName, err) + if err := defaultConfig.ConfigureTLS(tlsConfig); err != nil { + return nil, fmt.Errorf("unable to configure TLS for %s: %w", vaultConfig.TLSServerName, err) } - return vaultapi.NewClient(vaultConfig) + return vaultapi.NewClient(defaultConfig) } -func (c *vaultWrapper) getInitialToken(config *config.VaultConfig) error { +func (c *vaultWrapper) getInitialToken(vaultConfig *config.VaultConfig) error { switch { - case config.ClientCert != "" && config.ClientKey != "": - token, err := c.tlsToken(config) + case vaultConfig.ClientCert != "" && vaultConfig.ClientKey != "": + token, err := c.tlsToken() if err != nil { return fmt.Errorf("rotating token through TLS auth backend: %w", err) } + c.client.SetToken(token) - case config.RoleID != "": - token, err := c.appRoleToken(config) + case vaultConfig.RoleID != "": + token, err := c.appRoleToken(vaultConfig) if err != nil { return fmt.Errorf("rotating token through app role backend: %w", err) } + c.client.SetToken(token) default: // configuration has already been validated, flow should not reach here @@ -107,11 +110,12 @@ func (c *vaultWrapper) getInitialToken(config *config.VaultConfig) error { return nil } -func (c *vaultWrapper) tlsToken(config *config.VaultConfig) (string, error) { - path := path.Join("/", c.authPath, "cert", "login") - resp, err := c.client.Logical().Write(path, nil) +func (c *vaultWrapper) tlsToken() (string, error) { + loginPath := path.Join("/", c.authPath, "cert", "login") + + resp, err := c.client.Logical().Write(loginPath, nil) if err != nil { - return "", fmt.Errorf("unable to write TLS via API on %s: %w", path, err) + return "", fmt.Errorf("unable to write TLS via API on %s: %w", loginPath, err) } else if resp.Auth == nil { return "", errors.New("authentication information not found") } @@ -119,15 +123,16 @@ func (c *vaultWrapper) tlsToken(config *config.VaultConfig) (string, error) { return resp.Auth.ClientToken, nil } -func (c *vaultWrapper) appRoleToken(config *config.VaultConfig) (string, error) { +func (c *vaultWrapper) appRoleToken(vaultConfig *config.VaultConfig) (string, error) { data := map[string]interface{}{ - "role_id": config.RoleID, - "secret_id": config.SecretID, + "role_id": vaultConfig.RoleID, + "secret_id": vaultConfig.SecretID, } - path := path.Join("/", c.authPath, "approle", "login") - resp, err := c.client.Logical().Write(path, data) + loginPath := path.Join("/", c.authPath, "approle", "login") + + resp, err := c.client.Logical().Write(loginPath, data) if err != nil { - return "", fmt.Errorf("unable to write app role token via API on %s: %w", path, err) + return "", fmt.Errorf("unable to write app role token via API on %s: %w", loginPath, err) } else if resp.Auth == nil { return "", errors.New("authentication information not found") } @@ -139,6 +144,7 @@ func (c *vaultWrapper) Encrypt(data []byte) ([]byte, error) { if err != nil { return nil, fmt.Errorf("unable to encrypt data: %w", err) } + return []byte(response), nil } func (c *vaultWrapper) Decrypt(data []byte) ([]byte, error) { @@ -146,13 +152,14 @@ func (c *vaultWrapper) Decrypt(data []byte) ([]byte, error) { if err != nil { return nil, fmt.Errorf("unable to decrypt data: %w", err) } + return []byte(response), nil } -func (c *vaultWrapper) request(path string, data interface{}) (*vaultapi.Secret, error) { - req := c.client.NewRequest("POST", "/"+path) +func (c *vaultWrapper) request(requestPath string, data interface{}) (*vaultapi.Secret, error) { + req := c.client.NewRequest("POST", "/"+requestPath) if err := req.SetJSONBody(data); err != nil { - return nil, fmt.Errorf("unable to set request JSON on %s: %w", path, err) + return nil, fmt.Errorf("unable to set request JSON on %s: %w", requestPath, err) } resp, err := c.client.RawRequest(req) @@ -160,15 +167,17 @@ func (c *vaultWrapper) request(path string, data interface{}) (*vaultapi.Secret, if resp.StatusCode == http.StatusForbidden { return nil, newForbiddenError(err) } - return nil, fmt.Errorf("error making POST request on %s: %w", path, err) - } - if resp == nil { - return nil, fmt.Errorf("no response received for POST request on %s: %w", path, err) + + return nil, fmt.Errorf("error making POST request on %s: %w", requestPath, err) + } else if resp == nil { + return nil, fmt.Errorf("no response received for POST request on %s: %w", requestPath, err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected response code: %v received for POST request to %v", resp.StatusCode, path) + return nil, fmt.Errorf("unexpected response code: %v received for POST request to %v", resp.StatusCode, requestPath) } + return vaultapi.ParseSecret(resp.Body) } @@ -176,48 +185,57 @@ func (c *vaultWrapper) withRefreshToken(isEncrypt bool, key, data string) (strin // Execute operation first time. var ( result string - err error + err error ) func() { c.rwmutex.RLock() defer c.rwmutex.RUnlock() + if isEncrypt { result, err = c.encryptLocked(key, data) } else { result, err = c.decryptLocked(key, data) } }() + if err == nil || c.config.Token != "" { return result, nil } - _, ok := err.(*forbiddenError) - if !ok { + + if _, ok := err.(*forbiddenError); !ok { return result, fmt.Errorf("error during connection: %w", err) } + c.rwmutex.Lock() defer c.rwmutex.Unlock() + err = c.refreshTokenLocked(c.config) if err != nil { return result, fmt.Errorf("error refresh token request: %w", err) } - klog.V(2).Infof("vault token refreshed") + + klog.Infof("vault token refreshed") + if isEncrypt { result, err = c.encryptLocked(key, data) } else { result, err = c.decryptLocked(key, data) } + if err != nil { err = fmt.Errorf("error during en/de-cryption: %w", err) } + return result, err } -func (c *vaultWrapper) refreshTokenLocked(config *config.VaultConfig) error { - return c.getInitialToken(config) +func (c *vaultWrapper) refreshTokenLocked(vaultConfig *config.VaultConfig) error { + return c.getInitialToken(vaultConfig) } -func (c *vaultWrapper) encryptLocked(key string, data string) (string, error) { +func (c *vaultWrapper) encryptLocked(key, data string) (string, error) { dataReq := map[string]string{"plaintext": data} + resp, err := c.request(path.Join(c.encryptPath, key), dataReq) if err != nil { return "", fmt.Errorf("error during encrypt request: %w", err) @@ -231,8 +249,9 @@ func (c *vaultWrapper) encryptLocked(key string, data string) (string, error) { return result, nil } -func (c *vaultWrapper) decryptLocked(key string, data string) (string, error) { - dataReq := map[string]string{"ciphertext": string(data)} +func (c *vaultWrapper) decryptLocked(_, data string) (string, error) { + dataReq := map[string]string{"ciphertext": data} + resp, err := c.request(path.Join(c.decryptPath, c.config.KeyNames[0]), dataReq) if err != nil { return "", fmt.Errorf("error during decrypt request: %w", err) diff --git a/internal/metrics/prometheus_exporter.go b/internal/metrics/prometheus_exporter.go index c4d746e..19b44ea 100644 --- a/internal/metrics/prometheus_exporter.go +++ b/internal/metrics/prometheus_exporter.go @@ -25,6 +25,7 @@ func servePrometheusExporter(metricsAddress string) error { klog.InfoS("Prometheus metrics server starting", "address", metricsAddress) http.HandleFunc(fmt.Sprintf("/%s", metricsEndpoint), exporter.ServeHTTP) + if err := http.ListenAndServe(fmt.Sprintf(":%s", metricsAddress), nil); err != nil { return fmt.Errorf("failed to register prometheus endpoint: %w", err) } diff --git a/internal/server/grpc.go b/internal/server/grpc.go index 3920be3..0139774 100644 --- a/internal/server/grpc.go +++ b/internal/server/grpc.go @@ -32,6 +32,7 @@ func New(ctx context.Context, cfg config.ProviderConfig) (KeyManagementService, if err != nil { return nil, fmt.Errorf("unable to create encrypt service: %w", err) } + return &keyManagementServiceServer{ kvClient: kvClient, reporter: metrics.NewStatsReporter(), @@ -45,24 +46,30 @@ func (k *keyManagementServiceServer) Decrypt(ctx context.Context, data *v1beta1. defer func() { errors := "" status := metrics.SuccessStatusTypeValue + if err != nil { status = metrics.ErrorStatusTypeValue errors = err.Error() } + k.reporter.ReportRequest(ctx, metrics.DecryptOperationTypeValue, status, time.Since(start).Seconds(), errors) }() klog.V(klogv).Infof("decrypt request started ") + r, err := k.kvClient.Decrypt(data.Cipher) if err != nil { klog.ErrorS(err, "failed to decrypt") return nil, fmt.Errorf("failed to decrypt: %w", err) } + w, err := base64.StdEncoding.DecodeString(string(r)) if err != nil { klog.ErrorS(err, "failed decode encrypted data") return nil, fmt.Errorf("failed decode encrypted data: %w", err) } - klog.V(2).Infof("decrypt request complete") + + klog.Infof("decrypt request complete") + return &v1beta1.DecryptResponse{Plain: w}, nil } @@ -73,20 +80,26 @@ func (k *keyManagementServiceServer) Encrypt(ctx context.Context, data *v1beta1. defer func() { errors := "" status := metrics.SuccessStatusTypeValue + if err != nil { status = metrics.ErrorStatusTypeValue errors = err.Error() } + k.reporter.ReportRequest(ctx, metrics.EncryptOperationTypeValue, status, time.Since(start).Seconds(), errors) }() klog.V(klogv).Infof("encrypt request started") + plain := base64.StdEncoding.EncodeToString(data.Plain) + response, err := k.kvClient.Encrypt([]byte(plain)) if err != nil { klog.ErrorS(err, "failed to encrypt") return nil, fmt.Errorf("failed to encrypt: %w", err) } - klog.V(2).Infof("encrypt request complete") + + klog.Infof("encrypt request complete") + return &v1beta1.EncryptResponse{Cipher: response}, nil } diff --git a/internal/server/health.go b/internal/server/health.go index cef3239..de65301 100644 --- a/internal/server/health.go +++ b/internal/server/health.go @@ -31,6 +31,7 @@ type HealthZ struct { func (h *HealthZ) Serve() error { serveMux := http.NewServeMux() serveMux.HandleFunc(h.HealthCheckURL.EscapedPath(), h.ServeHTTP) + if err := http.ListenAndServe(h.HealthCheckURL.Host, serveMux); err != nil && err != http.ErrServerClosed { return fmt.Errorf("failed to start health check server: %w", err) } @@ -40,6 +41,7 @@ func (h *HealthZ) Serve() error { func (h *HealthZ) ServeHTTP(w http.ResponseWriter, r *http.Request) { klog.V(klogv).Infof("Started health check") + ctx, cancel := context.WithTimeout(context.Background(), h.RPCTimeout) defer cancel() @@ -51,8 +53,8 @@ func (h *HealthZ) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer conn.Close() kmsClient := pb.NewKeyManagementServiceClient(conn) - err = h.checkRPC(ctx, kmsClient) - if err != nil { + + if err = h.checkRPC(ctx, kmsClient); err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } @@ -62,21 +64,23 @@ func (h *HealthZ) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } + dec, err := h.Service.Decrypt(ctx, &pb.DecryptRequest{Cipher: enc.Cipher}) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return - } - if string(dec.Plain) != healthCheckPlainText { + } else if string(dec.Plain) != healthCheckPlainText { http.Error(w, "plain text mismatch after decryption", http.StatusInternalServerError) return } + w.WriteHeader(http.StatusOK) - _, err = w.Write([]byte("ok")) - if err != nil { + + if _, err := w.Write([]byte("ok")); err != nil { http.Error(w, err.Error(), http.StatusNotFound) return } + klog.V(klogv).Infof("Completed health check") } @@ -87,9 +91,11 @@ func (h *HealthZ) checkRPC(ctx context.Context, client pb.KeyManagementServiceCl if err != nil { return fmt.Errorf("unable to get version: %w", err) } + if v.Version != version.APIVersion || v.RuntimeName != version.Runtime || v.RuntimeVersion != version.BuildVersion { return errors.New("failed to get correct version response") } + return nil } diff --git a/internal/utils/grpc.go b/internal/utils/grpc.go index 88c7280..80ce232 100644 --- a/internal/utils/grpc.go +++ b/internal/utils/grpc.go @@ -15,24 +15,31 @@ const ( ) // ParseEndpoint returns unix socket's protocol and address -func ParseEndpoint(ep string) (string, string, error) { +func ParseEndpoint(ep string) (proto, address string, err error) { + err = fmt.Errorf("invalid endpoint: %s", ep) + if strings.HasPrefix(strings.ToLower(ep), "unix://") { s := strings.SplitN(ep, "://", splitin) if s[1] != "" { - return s[0], s[1], nil + proto = s[0] + address = s[1] + err = nil } } - return "", "", fmt.Errorf("invalid endpoint: %s", ep) + + return } // UnaryServerInterceptor provides metrics around Unary RPCs. func UnaryServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { klog.V(klogv).Infof("GRPC call: %s", info.FullMethod) + resp, err := handler(ctx, req) if err != nil { klog.ErrorS(err, "GRPC request error") err = fmt.Errorf("GRPC request error: %w", err) } + return resp, err } diff --git a/internal/utils/vault.go b/internal/utils/vault.go index 1a48793..ee5ec11 100644 --- a/internal/utils/vault.go +++ b/internal/utils/vault.go @@ -11,52 +11,58 @@ type AppRoleCredentials struct { RoleID string } -func CreateVaultTransitKey(cli *api.Client, prefix, name string, params map[string]interface{}, configParams map[string]interface{}) error { +func CreateVaultTransitKey(cli *api.Client, prefix, name string, params, configParams map[string]interface{}) error { path := fmt.Sprintf("%s/keys/%s", prefix, name) - _, err := cli.Logical().Write(path, params) - if err != nil { + + if _, err := cli.Logical().Write(path, params); err != nil { return fmt.Errorf("unable to create params for %s: %w", path, err) } + if configParams != nil { path := fmt.Sprintf("transit/keys/%s/config", name) - _, err := cli.Logical().Write(path, configParams) - if err != nil { + + if _, err := cli.Logical().Write(path, configParams); err != nil { return fmt.Errorf("unable to create config params for %s: %w", path, err) } } + return nil } -func RotateVaultTransitKey(cli *api.Client, prefix, name string, params map[string]interface{}, configParams map[string]interface{}) error { +func RotateVaultTransitKey(cli *api.Client, prefix, name string, params, configParams map[string]interface{}) error { path := fmt.Sprintf("%s/keys/%s/rotate", prefix, name) - _, err := cli.Logical().Write(path, params) - if err != nil { + + if _, err := cli.Logical().Write(path, params); err != nil { return fmt.Errorf("unable to rotate params for %s: %w", path, err) } + return nil } func CreateVaultAppRole(cli *api.Client, prefix, name string, params map[string]interface{}) (*AppRoleCredentials, error) { path := fmt.Sprintf("auth/%s/role/%s", prefix, name) - _, err := cli.Logical().Write(path, params) - if err != nil { + + if _, err := cli.Logical().Write(path, params); err != nil { return nil, fmt.Errorf("unable to create role for %s: %w", path, err) } + roleSecret, err := cli.Logical().Read(path + "/role-id") if err != nil { return nil, fmt.Errorf("unable to read role for %s: %w", path, err) } - SecretIDSecret, err := cli.Logical().Write(path+"/secret-id", nil) + + secretIDSecret, err := cli.Logical().Write(path+"/secret-id", nil) if err != nil { return nil, fmt.Errorf("unable to read secret for %s: %w", path, err) } + return &AppRoleCredentials{ RoleID: roleSecret.Data["role_id"].(string), - SecretID: SecretIDSecret.Data["secret_id"].(string), + SecretID: secretIDSecret.Data["secret_id"].(string), }, nil } -func CreateVaultPolicy(api *api.Client, policyName string, keyName string) error { +func CreateVaultPolicy(client *api.Client, policyName, keyName string) error { policy := fmt.Sprintf(` path "transit/encrypt/%s" { capabilities = ["update"] @@ -66,20 +72,24 @@ func CreateVaultPolicy(api *api.Client, policyName string, keyName string) error } `, keyName, keyName) path := fmt.Sprintf("sys/policy/%s", policyName) - _, err := api.Logical().Write(path, map[string]interface{}{ + + _, err := client.Logical().Write(path, map[string]interface{}{ "policy": policy, }) if err != nil { return fmt.Errorf("unable to create policy for %s: %w", path, err) } + return nil } func CreateVaultToken(cli *api.Client, name string, params map[string]interface{}) (string, error) { path := "/auth/token/create" + r, err := cli.Logical().Write(path, params) if err != nil { return "", fmt.Errorf("unable to create vault token for %s: %w", path, err) } + return r.Auth.ClientToken, nil } diff --git a/internal/version/version.go b/internal/version/version.go index eb740f3..5106f9e 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -29,12 +29,13 @@ func PrintVersion() (err error) { GitCommit: GitCommit, } - var res []byte - if res, err = json.Marshal(pv); err != nil { + res, err := json.Marshal(pv) + if err != nil { return } fmt.Printf(string(res) + "\n") + return }