Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Create psql DB in the CF provisioner #3334

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/devel-provisioner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func main() {
"ftl-provisioner-cloudformation",
provisionerconnect.NewProvisionerPluginServiceClient,
plugin.WithEnvars("FTL_PROVISIONER_CF_DB_SUBNET_GROUP=aurora-postgres-subnet-group"),
plugin.WithEnvars("FTL_PROVISIONER_CF_DB_SECURITY_GROUP=sg-08e06d6f8327024de"),
)
if err != nil {
panic(err)
Expand Down
14 changes: 14 additions & 0 deletions cmd/ftl-provisioner-cloudformation/cloudformation_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/cloudformation"
"github.com/aws/aws-sdk-go-v2/service/cloudformation/types"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
"github.com/aws/smithy-go"
goformation "github.com/awslabs/goformation/v7/cloudformation"
"github.com/jpillora/backoff"
Expand Down Expand Up @@ -104,6 +105,19 @@ func createClient(ctx context.Context) (*cloudformation.Client, error) {
), nil
}

func createSecretsClient(ctx context.Context) (*secretsmanager.Client, error) {
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load default aws config: %w", err)
}
return secretsmanager.New(
secretsmanager.Options{
Credentials: cfg.Credentials,
Region: cfg.Region,
},
), nil
}

// CloudformationOutputKey is structured key to be used as an output from a CF stack
type CloudformationOutputKey struct {
ResourceID string `json:"r"`
Expand Down
17 changes: 14 additions & 3 deletions cmd/ftl-provisioner-cloudformation/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"connectrpc.com/connect"
"github.com/aws/aws-sdk-go-v2/service/cloudformation"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
goformation "github.com/awslabs/goformation/v7/cloudformation"
cf "github.com/awslabs/goformation/v7/cloudformation/cloudformation"
"github.com/awslabs/goformation/v7/cloudformation/rds"
Expand All @@ -26,6 +27,7 @@ import (
const (
PropertyDBReadEndpoint = "db:read_endpoint"
PropertyDBWriteEndpoint = "db:write_endpoint"
PropertyMasterUserARN = "db:maser_user_secret_arn"
)

type Config struct {
Expand All @@ -35,8 +37,9 @@ type Config struct {
}

type CloudformationProvisioner struct {
client *cloudformation.Client
confg *Config
client *cloudformation.Client
secrets *secretsmanager.Client
confg *Config
}

var _ provisionerconnect.ProvisionerPluginServiceHandler = (*CloudformationProvisioner)(nil)
Expand All @@ -46,8 +49,12 @@ func NewCloudformationProvisioner(ctx context.Context, config Config) (context.C
if err != nil {
return nil, nil, fmt.Errorf("failed to create cloudformation client: %w", err)
}
secrets, err := createSecretsClient(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to create secretsmanager client: %w", err)
}

return ctx, &CloudformationProvisioner{client: client, confg: &config}, nil
return ctx, &CloudformationProvisioner{client: client, secrets: secrets, confg: &config}, nil
}

func (c *CloudformationProvisioner) Ping(context.Context, *connect.Request[ftlv1.PingRequest]) (*connect.Response[ftlv1.PingResponse], error) {
Expand Down Expand Up @@ -165,6 +172,10 @@ func (c *CloudformationProvisioner) resourceToCF(cluster, module string, templat
ResourceID: resource.ResourceId,
PropertyName: PropertyDBReadEndpoint,
})
addOutput(template.Outputs, goformation.GetAtt(clusterID, "MasterUserSecret.SecretArn"), &CloudformationOutputKey{
ResourceID: resource.ResourceId,
PropertyName: PropertyMasterUserARN,
})
return nil
}
return errors.New("unsupported resource type")
Expand Down
148 changes: 108 additions & 40 deletions cmd/ftl-provisioner-cloudformation/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@ package main

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"strings"

"connectrpc.com/connect"
"github.com/aws/aws-sdk-go-v2/service/cloudformation"
"github.com/aws/aws-sdk-go-v2/service/cloudformation/types"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
_ "github.com/lib/pq"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use pgx, let's not have two PG drivers.


"github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1beta1/provisioner"
)
Expand All @@ -32,7 +39,7 @@ func (c *CloudformationProvisioner) Status(ctx context.Context, req *connect.Req
case types.StackStatusCreateFailed:
return failure(&stack)
case types.StackStatusCreateComplete:
return success(&stack, req.Msg.DesiredResources)
return c.success(ctx, &stack, req.Msg.DesiredResources)
case types.StackStatusRollbackInProgress:
return failure(&stack)
case types.StackStatusRollbackFailed:
Expand All @@ -44,13 +51,13 @@ func (c *CloudformationProvisioner) Status(ctx context.Context, req *connect.Req
case types.StackStatusDeleteFailed:
return failure(&stack)
case types.StackStatusDeleteComplete:
return success(&stack, req.Msg.DesiredResources)
return c.success(ctx, &stack, req.Msg.DesiredResources)
case types.StackStatusUpdateInProgress:
return running()
case types.StackStatusUpdateCompleteCleanupInProgress:
return running()
case types.StackStatusUpdateComplete:
return success(&stack, req.Msg.DesiredResources)
return c.success(ctx, &stack, req.Msg.DesiredResources)
case types.StackStatusUpdateFailed:
return failure(&stack)
case types.StackStatusUpdateRollbackInProgress:
Expand All @@ -60,8 +67,8 @@ func (c *CloudformationProvisioner) Status(ctx context.Context, req *connect.Req
}
}

func success(stack *types.Stack, resources []*provisioner.Resource) (*connect.Response[provisioner.StatusResponse], error) {
err := updateResources(stack.Outputs, resources)
func (c *CloudformationProvisioner) success(ctx context.Context, stack *types.Stack, resources []*provisioner.Resource) (*connect.Response[provisioner.StatusResponse], error) {
err := c.updateResources(ctx, stack.Outputs, resources)
if err != nil {
return nil, err
}
Expand All @@ -86,49 +93,110 @@ func failure(stack *types.Stack) (*connect.Response[provisioner.StatusResponse],
return nil, connect.NewError(connect.CodeUnknown, errors.New(*stack.StackStatusReason))
}

func updateResources(outputs []types.Output, update []*provisioner.Resource) error {
func outputsByResourceID(outputs []types.Output) (map[string][]types.Output, error) {
m := make(map[string][]types.Output)
for _, output := range outputs {
key, err := decodeOutputKey(output)
if err != nil {
return fmt.Errorf("failed to decode output key: %w", err)
return nil, fmt.Errorf("failed to decode output key: %w", err)
}
for _, resource := range update {
if resource.ResourceId == key.ResourceID {
if postgres, ok := resource.Resource.(*provisioner.Resource_Postgres); ok {
if postgres.Postgres == nil {
postgres.Postgres = &provisioner.PostgresResource{}
}
if postgres.Postgres.Output == nil {
postgres.Postgres.Output = &provisioner.PostgresResource_PostgresResourceOutput{}
}

switch key.PropertyName {
case PropertyDBReadEndpoint:
postgres.Postgres.Output.ReadDsn = endpointToDSN(*output.OutputValue, key.ResourceID, 5432)
case PropertyDBWriteEndpoint:
postgres.Postgres.Output.WriteDsn = endpointToDSN(*output.OutputValue, key.ResourceID, 5432)
}
} else if mysql, ok := resource.Resource.(*provisioner.Resource_Mysql); ok {
if mysql.Mysql == nil {
mysql.Mysql = &provisioner.MysqlResource{}
}
if mysql.Mysql.Output == nil {
mysql.Mysql.Output = &provisioner.MysqlResource_MysqlResourceOutput{}
}

switch key.PropertyName {
case PropertyDBReadEndpoint:
mysql.Mysql.Output.ReadDsn = endpointToDSN(*output.OutputValue, key.ResourceID, 5432)
case PropertyDBWriteEndpoint:
mysql.Mysql.Output.WriteDsn = endpointToDSN(*output.OutputValue, key.ResourceID, 3306)
}
}
m[key.ResourceID] = append(m[key.ResourceID], output)
}
return m, nil
}

func outputsByPropertyName(outputs []types.Output) (map[string]types.Output, error) {
m := make(map[string]types.Output)
for _, output := range outputs {
key, err := decodeOutputKey(output)
if err != nil {
return nil, fmt.Errorf("failed to decode output key: %w", err)
}
m[key.PropertyName] = output
}
return m, nil
}

func (c *CloudformationProvisioner) updateResources(ctx context.Context, outputs []types.Output, update []*provisioner.Resource) error {
byResourceID, err := outputsByResourceID(outputs)
if err != nil {
return fmt.Errorf("failed to group outputs by resource ID: %w", err)
}

for _, resource := range update {
if postgres, ok := resource.Resource.(*provisioner.Resource_Postgres); ok {
if postgres.Postgres == nil {
postgres.Postgres = &provisioner.PostgresResource{}
}
if postgres.Postgres.Output == nil {
postgres.Postgres.Output = &provisioner.PostgresResource_PostgresResourceOutput{}
}

if err := c.updatePostgresOutputs(ctx, postgres.Postgres.Output, resource.ResourceId, byResourceID[resource.ResourceId]); err != nil {
return fmt.Errorf("failed to update postgres outputs: %w", err)
}
} else if _, ok := resource.Resource.(*provisioner.Resource_Mysql); ok {
panic("mysql not implemented")
}
}
return nil
}

func endpointToDSN(endpoint, database string, port int) string {
return fmt.Sprintf("postgres://%s:%d/%s?user=postgres&password=password", endpoint, port, database)
func (c *CloudformationProvisioner) updatePostgresOutputs(ctx context.Context, to *provisioner.PostgresResource_PostgresResourceOutput, resourceID string, outputs []types.Output) error {
byName, err := outputsByPropertyName(outputs)
if err != nil {
return fmt.Errorf("failed to group outputs by property name: %w", err)
}

fmt.Fprintf(os.Stderr, "byName: %v\n", byName)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

??


// TODO: Move to provisioner workflow
secretARN := *byName[PropertyMasterUserARN].OutputValue
username, password, err := c.secretARNToUsernamePassword(ctx, secretARN)
if err != nil {
return fmt.Errorf("failed to get username and password from secret ARN: %w", err)
}

to.ReadDsn = endpointToDSN(*byName[PropertyDBReadEndpoint].OutputValue, resourceID, 5432, username, password)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These blind pointer deferences are horrifying.

to.WriteDsn = endpointToDSN(*byName[PropertyDBWriteEndpoint].OutputValue, resourceID, 5432, username, password)
adminEndpoint := endpointToDSN(*byName[PropertyDBReadEndpoint].OutputValue, "postgres", 5432, username, password)

// Connect to postgres without a specific database to create the new one
db, err := sql.Open("postgres", adminEndpoint)
if err != nil {
return fmt.Errorf("failed to connect to postgres: %w", err)
}
defer db.Close()

// Create the database if it doesn't exist
if _, err := db.ExecContext(ctx, "CREATE DATABASE "+resourceID); err != nil {
// Ignore if database already exists
if !strings.Contains(err.Error(), "already exists") {
return fmt.Errorf("failed to create database: %w", err)
}
}

return nil
}

func endpointToDSN(endpoint, database string, port int, username, password string) string {
urlEncodedPassword := url.QueryEscape(password)
return fmt.Sprintf("postgres://%s:%d/%s?user=%s&password=%s", endpoint, port, database, username, urlEncodedPassword)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to construct a url.URL here, then stringify it, to ensure all escaping is correct?

}

func (c *CloudformationProvisioner) secretARNToUsernamePassword(ctx context.Context, secretARN string) (string, string, error) {
secret, err := c.secrets.GetSecretValue(ctx, &secretsmanager.GetSecretValueInput{
SecretId: &secretARN,
})
if err != nil {
return "", "", fmt.Errorf("failed to get secret value: %w", err)
}
secretString := *secret.SecretString

var secretData map[string]string
if err := json.Unmarshal([]byte(secretString), &secretData); err != nil {
return "", "", fmt.Errorf("failed to unmarshal secret data: %w", err)
}

return secretData["username"], secretData["password"], nil
}
Loading