From 0607e55a4ab095289802e8a8777032044ae8f4e6 Mon Sep 17 00:00:00 2001 From: Matt Toohey Date: Wed, 14 Aug 2024 12:22:15 +1000 Subject: [PATCH] Revert "feat(KMS): derive keys for logs and async (#2338)" This reverts commit 2a3edbcce2b00cdbb99be53da4175bec6028b8fb. --- backend/controller/controller.go | 50 +++- .../cronjobs/cronjobs_integration_test.go | 8 +- backend/controller/cronjobs/cronjobs_test.go | 2 +- backend/controller/cronjobs/sql/models.go | 10 +- backend/controller/dal/async_calls.go | 19 +- backend/controller/dal/async_calls_test.go | 4 +- backend/controller/dal/dal.go | 44 +-- backend/controller/dal/dal_test.go | 7 +- backend/controller/dal/encryption.go | 106 -------- backend/controller/dal/events.go | 9 +- backend/controller/dal/fsm.go | 3 +- .../controller/dal/fsm_integration_test.go | 2 +- backend/controller/dal/fsm_test.go | 3 +- backend/controller/dal/lease_test.go | 4 +- backend/controller/dal/pubsub.go | 4 +- backend/controller/sql/migrate/migrate.go | 2 +- backend/controller/sql/models.go | 10 +- backend/controller/sql/querier.go | 4 +- backend/controller/sql/queries.sql | 11 +- backend/controller/sql/queries.sql.go | 31 +-- .../20240812011321_derive_encryption.sql | 19 -- cmd/ftl-controller/main.go | 10 +- cmd/ftl/cmd_box_run.go | 6 +- cmd/ftl/cmd_serve.go | 8 +- common/configuration/sql/models.go | 10 +- integration/harness.go | 7 +- internal/encryption/encryption.go | 253 +++++++++++++++--- internal/encryption/encryption_test.go | 65 ++++- internal/encryption/integration_test.go | 5 +- 29 files changed, 416 insertions(+), 300 deletions(-) delete mode 100644 backend/controller/dal/encryption.go delete mode 100644 backend/controller/sql/schema/20240812011321_derive_encryption.sql diff --git a/backend/controller/controller.go b/backend/controller/controller.go index 3b79041c1..0ef15118d 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -55,6 +55,7 @@ import ( cf "github.com/TBD54566975/ftl/common/configuration" frontend "github.com/TBD54566975/ftl/frontend" "github.com/TBD54566975/ftl/internal/cors" + "github.com/TBD54566975/ftl/internal/encryption" ftlhttp "github.com/TBD54566975/ftl/internal/http" "github.com/TBD54566975/ftl/internal/log" ftlmaps "github.com/TBD54566975/ftl/internal/maps" @@ -84,6 +85,42 @@ func (c *CommonConfig) Validate() error { return nil } +// EncryptionKeys for the controller config. +// Deprecated: Will remove this at some stage. +type EncryptionKeys struct { + Logs string `name:"log-key" help:"Key for sensitive log data in internal FTL tables." env:"FTL_LOG_ENCRYPTION_KEY"` + Async string `name:"async-key" help:"Key for sensitive async call data in internal FTL tables." env:"FTL_ASYNC_ENCRYPTION_KEY"` +} + +func (e EncryptionKeys) Encryptors(required bool) (*dal.Encryptors, error) { + encryptors := dal.Encryptors{} + if e.Logs != "" { + enc, err := encryption.NewForKeyOrURI(e.Logs) + if err != nil { + return nil, fmt.Errorf("could not create log encryptor: %w", err) + } + encryptors.Logs = enc + } else if required { + return nil, fmt.Errorf("FTL_LOG_ENCRYPTION_KEY is required") + } else { + encryptors.Logs = encryption.NoOpEncryptor{} + } + + if e.Async != "" { + enc, err := encryption.NewForKeyOrURI(e.Async) + if err != nil { + return nil, fmt.Errorf("could not create async calls encryptor: %w", err) + } + encryptors.Async = enc + } else if required { + return nil, fmt.Errorf("FTL_ASYNC_ENCRYPTION_KEY is required") + } else { + encryptors.Async = encryption.NoOpEncryptor{} + } + + return &encryptors, nil +} + type Config struct { Bind *url.URL `help:"Socket to bind to." default:"http://127.0.0.1:8892" env:"FTL_CONTROLLER_BIND"` IngressBind *url.URL `help:"Socket to bind to for ingress." default:"http://127.0.0.1:8891" env:"FTL_CONTROLLER_INGRESS_BIND"` @@ -98,7 +135,8 @@ type Config struct { ModuleUpdateFrequency time.Duration `help:"Frequency to send module updates." default:"30s"` EventLogRetention *time.Duration `help:"Delete call logs after this time period. 0 to disable" env:"FTL_EVENT_LOG_RETENTION" default:"24h"` ArtefactChunkSize int `help:"Size of each chunk streamed to the client." default:"1048576"` - KMSURI *string `help:"URI for KMS key e.g. with fake-kms:// or aws-kms://arn:aws:kms:ap-southeast-2:12345:key/0000-1111" env:"FTL_KMS_URI"` + KMSURI *url.URL `help:"URI for KMS key e.g. aws-kms://arn:aws:kms:ap-southeast-2:12345:key/0000-1111" env:"FTL_KMS_URI"` + EncryptionKeys CommonConfig } @@ -112,7 +150,7 @@ func (c *Config) SetDefaults() { } // Start the Controller. Blocks until the context is cancelled. -func Start(ctx context.Context, config Config, runnerScaling scaling.RunnerScaling, conn *sql.DB) error { +func Start(ctx context.Context, config Config, runnerScaling scaling.RunnerScaling, conn *sql.DB, encryptors *dal.Encryptors) error { config.SetDefaults() logger := log.FromContext(ctx) @@ -133,7 +171,7 @@ func Start(ctx context.Context, config Config, runnerScaling scaling.RunnerScali logger.Infof("Web console available at: %s", config.Bind) } - svc, err := New(ctx, conn, config, runnerScaling) + svc, err := New(ctx, conn, config, runnerScaling, encryptors) if err != nil { return err } @@ -215,7 +253,7 @@ type Service struct { asyncCallsLock sync.Mutex } -func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling.RunnerScaling) (*Service, error) { +func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling.RunnerScaling, encryptors *dal.Encryptors) (*Service, error) { key := config.Key if config.Key.IsZero() { key = model.NewControllerKey(config.Bind.Hostname(), config.Bind.Port()) @@ -229,7 +267,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling config.ControllerTimeout = time.Second * 5 } - db, err := dal.New(ctx, conn, optional.Ptr[string](config.KMSURI)) + db, err := dal.New(ctx, conn, encryptors) if err != nil { return nil, fmt.Errorf("failed to create DAL: %w", err) } @@ -1454,7 +1492,7 @@ func (s *Service) catchAsyncCall(ctx context.Context, logger *log.Logger, call * originalResult := either.RightOf[[]byte](originalError) request := map[string]any{ - "request": json.RawMessage(call.Request), + "request": call.Request, "error": originalError, } body, err := json.Marshal(request) diff --git a/backend/controller/cronjobs/cronjobs_integration_test.go b/backend/controller/cronjobs/cronjobs_integration_test.go index 62d9984c5..61c250e34 100644 --- a/backend/controller/cronjobs/cronjobs_integration_test.go +++ b/backend/controller/cronjobs/cronjobs_integration_test.go @@ -9,15 +9,13 @@ import ( "testing" "time" - "github.com/alecthomas/assert/v2" - "github.com/alecthomas/types/optional" - "github.com/benbjohnson/clock" - db "github.com/TBD54566975/ftl/backend/controller/cronjobs/dal" parentdb "github.com/TBD54566975/ftl/backend/controller/dal" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" in "github.com/TBD54566975/ftl/integration" "github.com/TBD54566975/ftl/internal/log" + "github.com/alecthomas/assert/v2" + "github.com/benbjohnson/clock" ) func TestServiceWithRealDal(t *testing.T) { @@ -28,7 +26,7 @@ func TestServiceWithRealDal(t *testing.T) { conn := sqltest.OpenForTesting(ctx, t) dal := db.New(conn) - parentDAL, err := parentdb.New(ctx, conn, optional.None[string]()) + parentDAL, err := parentdb.New(ctx, conn, parentdb.NoOpEncryptors()) assert.NoError(t, err) // Using a real clock because real db queries use db clock diff --git a/backend/controller/cronjobs/cronjobs_test.go b/backend/controller/cronjobs/cronjobs_test.go index 715476c93..bba46f5a4 100644 --- a/backend/controller/cronjobs/cronjobs_test.go +++ b/backend/controller/cronjobs/cronjobs_test.go @@ -37,7 +37,7 @@ func TestServiceWithMockDal(t *testing.T) { attemptCountMap: map[string]int{}, } conn := sqltest.OpenForTesting(ctx, t) - parentDAL, err := db.New(ctx, conn, optional.None[string]()) + parentDAL, err := db.New(ctx, conn, db.NoOpEncryptors()) assert.NoError(t, err) testServiceWithDal(ctx, t, mockDal, parentDAL, clk) diff --git a/backend/controller/cronjobs/sql/models.go b/backend/controller/cronjobs/sql/models.go index 731679ea0..5b5c1c76e 100644 --- a/backend/controller/cronjobs/sql/models.go +++ b/backend/controller/cronjobs/sql/models.go @@ -378,8 +378,8 @@ type AsyncCall struct { State AsyncCallState Origin string ScheduledAt time.Time - Request []byte - Response []byte + Request json.RawMessage + Response pqtype.NullRawMessage Error optional.Option[string] RemainingAttempts int32 Backoff sqltypes.Duration @@ -429,12 +429,6 @@ type DeploymentArtefact struct { Path string } -type EncryptionKey struct { - ID int64 - Key []byte - CreatedAt time.Time -} - type Event struct { ID int64 TimeStamp time.Time diff --git a/backend/controller/dal/async_calls.go b/backend/controller/dal/async_calls.go index e4ab2f468..696bb9fdf 100644 --- a/backend/controller/dal/async_calls.go +++ b/backend/controller/dal/async_calls.go @@ -2,6 +2,7 @@ package dal import ( "context" + "encoding/json" "errors" "fmt" "time" @@ -14,7 +15,6 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql/sqltypes" dalerrs "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/backend/schema" - "github.com/TBD54566975/ftl/internal/encryption" ) type asyncOriginParseRoot struct { @@ -77,7 +77,7 @@ type AsyncCall struct { Origin AsyncOrigin Verb schema.RefKey CatchVerb optional.Option[schema.RefKey] - Request []byte + Request json.RawMessage ScheduledAt time.Time QueueDepth int64 ParentRequestKey optional.Option[string] @@ -115,7 +115,8 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, err error) return nil, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err) } - decryptedRequest, err := d.decrypt(encryption.AsyncSubKey, row.Request) + var decryptedRequest json.RawMessage + err = d.encryptors.Async.DecryptJSON(row.Request, &decryptedRequest) if err != nil { return nil, fmt.Errorf("failed to decrypt async call request: %w", err) } @@ -158,11 +159,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context, didScheduleAnotherCall = false switch result := result.(type) { case either.Left[[]byte, string]: // Successful response. - encryptedResult, err := d.encrypt(encryption.AsyncSubKey, result.Get()) - if err != nil { - return false, fmt.Errorf("failed to encrypt async call result: %w", err) - } - _, err = tx.db.SucceedAsyncCall(ctx, encryptedResult, call.ID) + _, err = tx.db.SucceedAsyncCall(ctx, result.Get(), call.ID) if err != nil { return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck } @@ -227,14 +224,10 @@ func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) { if err != nil { return nil, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err) } - request, err := d.decrypt(encryption.AsyncSubKey, row.Request) - if err != nil { - return nil, fmt.Errorf("failed to decrypt async call request: %w", err) - } return &AsyncCall{ ID: row.ID, Verb: row.Verb, Origin: origin, - Request: request, + Request: row.Request, }, nil } diff --git a/backend/controller/dal/async_calls_test.go b/backend/controller/dal/async_calls_test.go index 965a3d6a2..7aab05128 100644 --- a/backend/controller/dal/async_calls_test.go +++ b/backend/controller/dal/async_calls_test.go @@ -4,8 +4,6 @@ import ( "context" "testing" - "github.com/alecthomas/types/optional" - "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" dalerrs "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/internal/log" @@ -15,7 +13,7 @@ import ( func TestNoCallToAcquire(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, NoOpEncryptors()) assert.NoError(t, err) _, err = dal.AcquireAsyncCall(ctx) diff --git a/backend/controller/dal/dal.go b/backend/controller/dal/dal.go index 9e1f7b4ac..55229c1b9 100644 --- a/backend/controller/dal/dal.go +++ b/backend/controller/dal/dal.go @@ -210,30 +210,35 @@ func WithReservation(ctx context.Context, reservation Reservation, fn func() err return reservation.Commit(ctx) } -func New(ctx context.Context, conn *stdsql.DB, kmsURL optional.Option[string]) (*DAL, error) { - d := &DAL{ +func New(ctx context.Context, conn *stdsql.DB, encryptors *Encryptors) (*DAL, error) { + return &DAL{ db: sql.NewDB(conn), DeploymentChanges: pubsub.New[DeploymentNotification](), - kmsURL: kmsURL, - } - - if err := d.setupEncryptor(ctx); err != nil { - return nil, fmt.Errorf("failed to setup encryptor: %w", err) - } - - return d, nil + encryptors: encryptors, + }, nil } type DAL struct { - db sql.DBI - - kmsURL optional.Option[string] - encryptor encryption.DataEncryptor + db sql.DBI + encryptors *Encryptors // DeploymentChanges is a Topic that receives changes to the deployments table. DeploymentChanges *pubsub.Topic[DeploymentNotification] } +type Encryptors struct { + Logs encryption.Encryptable + Async encryption.Encryptable +} + +// NoOpEncryptors do not encrypt potentially sensitive data. +func NoOpEncryptors() *Encryptors { + return &Encryptors{ + Logs: encryption.NoOpEncryptor{}, + Async: encryption.NoOpEncryptor{}, + } +} + // Tx is DAL within a transaction. type Tx struct { *DAL @@ -280,8 +285,7 @@ func (d *DAL) Begin(ctx context.Context) (*Tx, error) { return &Tx{&DAL{ db: tx, DeploymentChanges: d.DeploymentChanges, - kmsURL: d.kmsURL, - encryptor: d.encryptor, + encryptors: d.encryptors, }}, nil } @@ -709,7 +713,7 @@ func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey return dalerrs.TranslatePGError(err) } } - payload, err := d.encryptJSON(encryption.LogsSubKey, map[string]interface{}{ + payload, err := d.encryptors.Logs.EncryptJSON(map[string]any{ "prev_min_replicas": deployment.MinReplicas, "min_replicas": minReplicas, }) @@ -782,7 +786,7 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl } } - payload, err := d.encryptJSON(encryption.LogsSubKey, map[string]any{ + payload, err := d.encryptors.Logs.EncryptJSON(map[string]any{ "min_replicas": int32(minReplicas), "replaced": replacedDeploymentKey, }) @@ -1057,7 +1061,7 @@ func (d *DAL) InsertLogEvent(ctx context.Context, log *LogEvent) error { "error": log.Error, "stack": log.Stack, } - encryptedPayload, err := d.encryptJSON(encryption.LogsSubKey, payload) + encryptedPayload, err := d.encryptors.Logs.EncryptJSON(payload) if err != nil { return fmt.Errorf("failed to encrypt log payload: %w", err) } @@ -1137,7 +1141,7 @@ func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error { if pr, ok := call.ParentRequestKey.Get(); ok { parentRequestKey = optional.Some(pr.String()) } - payload, err := d.encryptJSON(encryption.LogsSubKey, map[string]any{ + payload, err := d.encryptors.Logs.EncryptJSON(map[string]any{ "duration_ms": call.Duration.Milliseconds(), "request": call.Request, "response": call.Response, diff --git a/backend/controller/dal/dal_test.go b/backend/controller/dal/dal_test.go index ee3618c44..0f3be9a3a 100644 --- a/backend/controller/dal/dal_test.go +++ b/backend/controller/dal/dal_test.go @@ -25,7 +25,7 @@ import ( func TestDAL(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, NoOpEncryptors()) assert.NoError(t, err) assert.NotZero(t, dal) var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100) @@ -235,7 +235,7 @@ func TestDAL(t *testing.T) { DeploymentKey: deploymentKey, RequestKey: optional.Some(requestKey), Request: []byte("{}"), - Response: []byte(`{"time":"now"}`), + Response: []byte(`{"time": "now"}`), DestVerb: schema.Ref{Module: "time", Name: "time"}, } t.Run("InsertCallEvent", func(t *testing.T) { @@ -396,7 +396,6 @@ func normaliseEvents(events []Event) []Event { f.Set(reflect.Zero(f.Type())) events[i] = event } - return events } @@ -408,7 +407,7 @@ func assertEventsEqual(t *testing.T, expected, actual []Event) { func TestDeleteOldEvents(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, NoOpEncryptors()) assert.NoError(t, err) var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100) diff --git a/backend/controller/dal/encryption.go b/backend/controller/dal/encryption.go deleted file mode 100644 index 3bedc9476..000000000 --- a/backend/controller/dal/encryption.go +++ /dev/null @@ -1,106 +0,0 @@ -package dal - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/TBD54566975/ftl/backend/dal" - "github.com/TBD54566975/ftl/internal/encryption" - "github.com/TBD54566975/ftl/internal/log" -) - -func (d *DAL) encrypt(subKey encryption.SubKey, cleartext []byte) ([]byte, error) { - if d.encryptor == nil { - return nil, fmt.Errorf("encryptor not set") - } - - v, err := d.encryptor.Encrypt(subKey, cleartext) - if err != nil { - return nil, fmt.Errorf("failed to encrypt binary with subkey %s: %w", subKey, err) - } - - return v, nil -} - -func (d *DAL) decrypt(subKey encryption.SubKey, encrypted []byte) ([]byte, error) { - if d.encryptor == nil { - return nil, fmt.Errorf("encryptor not set") - } - - v, err := d.encryptor.Decrypt(subKey, encrypted) - if err != nil { - return nil, fmt.Errorf("failed to decrypt binary with subkey %s: %w", subKey, err) - } - - return v, nil -} - -func (d *DAL) encryptJSON(subKey encryption.SubKey, v any) ([]byte, error) { - serialized, err := json.Marshal(v) - if err != nil { - return nil, fmt.Errorf("failed to marshal JSON: %w", err) - } - - return d.encrypt(subKey, serialized) -} - -func (d *DAL) decryptJSON(subKey encryption.SubKey, encrypted []byte, v any) error { //nolint:unparam - decrypted, err := d.decrypt(subKey, encrypted) - if err != nil { - return fmt.Errorf("failed to decrypt json with subkey %s: %w", subKey, err) - } - - if err = json.Unmarshal(decrypted, v); err != nil { - return fmt.Errorf("failed to unmarshal JSON: %w", err) - } - - return nil -} - -// setupEncryptor sets up the encryptor for the DAL. -// It will either create a key or load the existing one. -// If the KMS URL is not set, it will use a NoOpEncryptor which does not encrypt anything. -func (d *DAL) setupEncryptor(ctx context.Context) (err error) { - logger := log.FromContext(ctx) - tx, err := d.Begin(ctx) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - defer tx.CommitOrRollback(ctx, &err) - - url, ok := d.kmsURL.Get() - if !ok { - logger.Infof("KMS URL not set, encryption not enabled") - d.encryptor = encryption.NewNoOpEncryptor() - return nil - } - - encryptedKey, err := tx.db.GetOnlyEncryptionKey(ctx) - if err != nil { - if dal.IsNotFound(err) { - logger.Infof("No encryption key found, generating a new one") - encryptor, err := encryption.NewKMSEncryptorGenerateKey(url, nil) - if err != nil { - return fmt.Errorf("failed to create encryptor for generation: %w", err) - } - d.encryptor = encryptor - - if err = tx.db.CreateOnlyEncryptionKey(ctx, encryptor.GetEncryptedKeyset()); err != nil { - return fmt.Errorf("failed to create only encryption key: %w", err) - } - - return nil - } - return fmt.Errorf("failed to get only encryption key: %w", err) - } - - logger.Debugf("Encryption key found, using it") - encryptor, err := encryption.NewKMSEncryptorWithKMS(url, nil, encryptedKey) - if err != nil { - return fmt.Errorf("failed to create encryptor with encrypted key: %w", err) - } - d.encryptor = encryptor - - return nil -} diff --git a/backend/controller/dal/events.go b/backend/controller/dal/events.go index 75babd7fe..8889d1c43 100644 --- a/backend/controller/dal/events.go +++ b/backend/controller/dal/events.go @@ -13,7 +13,6 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql" dalerrs "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/backend/schema" - "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" ) @@ -349,7 +348,7 @@ func (d *DAL) transformRowsToEvents(deploymentKeys map[int64]model.DeploymentKey switch row.Type { case sql.EventTypeLog: var jsonPayload eventLogJSON - if err := d.decryptJSON(encryption.LogsSubKey, row.Payload, &jsonPayload); err != nil { + if err := d.encryptors.Logs.DecryptJSON(row.Payload, &jsonPayload); err != nil { return nil, fmt.Errorf("failed to decrypt log event: %w", err) } @@ -371,7 +370,7 @@ func (d *DAL) transformRowsToEvents(deploymentKeys map[int64]model.DeploymentKey case sql.EventTypeCall: var jsonPayload eventCallJSON - if err := d.decryptJSON(encryption.LogsSubKey, row.Payload, &jsonPayload); err != nil { + if err := d.encryptors.Logs.DecryptJSON(row.Payload, &jsonPayload); err != nil { return nil, fmt.Errorf("failed to decrypt call event: %w", err) } var sourceVerb optional.Option[schema.Ref] @@ -396,7 +395,7 @@ func (d *DAL) transformRowsToEvents(deploymentKeys map[int64]model.DeploymentKey case sql.EventTypeDeploymentCreated: var jsonPayload eventDeploymentCreatedJSON - if err := d.decryptJSON(encryption.LogsSubKey, row.Payload, &jsonPayload); err != nil { + if err := d.encryptors.Logs.DecryptJSON(row.Payload, &jsonPayload); err != nil { return nil, fmt.Errorf("failed to decrypt call event: %w", err) } out = append(out, &DeploymentCreatedEvent{ @@ -411,7 +410,7 @@ func (d *DAL) transformRowsToEvents(deploymentKeys map[int64]model.DeploymentKey case sql.EventTypeDeploymentUpdated: var jsonPayload eventDeploymentUpdatedJSON - if err := d.decryptJSON(encryption.LogsSubKey, row.Payload, &jsonPayload); err != nil { + if err := d.encryptors.Logs.DecryptJSON(row.Payload, &jsonPayload); err != nil { return nil, fmt.Errorf("failed to decrypt call event: %w", err) } out = append(out, &DeploymentUpdatedEvent{ diff --git a/backend/controller/dal/fsm.go b/backend/controller/dal/fsm.go index 1311c76ed..b463b39a5 100644 --- a/backend/controller/dal/fsm.go +++ b/backend/controller/dal/fsm.go @@ -15,7 +15,6 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql/sqltypes" dalerrs "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/backend/schema" - "github.com/TBD54566975/ftl/internal/encryption" ) // StartFSMTransition sends an event to an executing instance of an FSM. @@ -32,7 +31,7 @@ import ( // // Note: no validation of the FSM is performed. func (d *DAL) StartFSMTransition(ctx context.Context, fsm schema.RefKey, executionKey string, destinationState schema.RefKey, request json.RawMessage, retryParams schema.RetryParams) (err error) { - encryptedRequest, err := d.encryptJSON(encryption.AsyncSubKey, request) + encryptedRequest, err := d.encryptors.Async.EncryptJSON(request) if err != nil { return fmt.Errorf("failed to encrypt FSM request: %w", err) } diff --git a/backend/controller/dal/fsm_integration_test.go b/backend/controller/dal/fsm_integration_test.go index 6ae5a00e8..a783e79f5 100644 --- a/backend/controller/dal/fsm_integration_test.go +++ b/backend/controller/dal/fsm_integration_test.go @@ -98,7 +98,7 @@ func TestFSMRetry(t *testing.T) { in.Call("fsmretry", "startTransitionToThree", in.Obj{"id": "2"}, func(t testing.TB, response any) {}), in.Call("fsmretry", "startTransitionToTwo", in.Obj{"id": "3", "failCatch": true}, func(t testing.TB, response any) {}), - in.Sleep(7*time.Second), //6s is longest run of retries + in.Sleep(9*time.Second), //6s is longest run of retries // First two FSMs instances should have failed // Third one will not as it is still catching diff --git a/backend/controller/dal/fsm_test.go b/backend/controller/dal/fsm_test.go index 7f1eabe03..684768be9 100644 --- a/backend/controller/dal/fsm_test.go +++ b/backend/controller/dal/fsm_test.go @@ -2,7 +2,6 @@ package dal import ( "context" - "github.com/alecthomas/types/optional" "testing" "time" @@ -18,7 +17,7 @@ import ( func TestSendFSMEvent(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, NoOpEncryptors()) assert.NoError(t, err) _, err = dal.AcquireAsyncCall(ctx) diff --git a/backend/controller/dal/lease_test.go b/backend/controller/dal/lease_test.go index 9e2370d72..0c6531cec 100644 --- a/backend/controller/dal/lease_test.go +++ b/backend/controller/dal/lease_test.go @@ -36,7 +36,7 @@ func TestLease(t *testing.T) { } ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, NoOpEncryptors()) assert.NoError(t, err) // TTL is too short, expect an error @@ -71,7 +71,7 @@ func TestExpireLeases(t *testing.T) { } ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, NoOpEncryptors()) assert.NoError(t, err) leasei, _, err := dal.AcquireLease(ctx, leases.SystemKey("test"), time.Second*5, optional.None[any]()) diff --git a/backend/controller/dal/pubsub.go b/backend/controller/dal/pubsub.go index aed607747..667248262 100644 --- a/backend/controller/dal/pubsub.go +++ b/backend/controller/dal/pubsub.go @@ -2,6 +2,7 @@ package dal import ( "context" + "encoding/json" "fmt" "strings" "time" @@ -13,7 +14,6 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql/sqltypes" dalerrs "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/backend/schema" - "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" "github.com/TBD54566975/ftl/internal/rpc" @@ -21,7 +21,7 @@ import ( ) func (d *DAL) PublishEventForTopic(ctx context.Context, module, topic, caller string, payload []byte) error { - encryptedPayload, err := d.encrypt(encryption.AsyncSubKey, payload) + encryptedPayload, err := d.encryptors.Async.EncryptJSON(json.RawMessage(payload)) if err != nil { return fmt.Errorf("failed to encrypt payload: %w", err) } diff --git a/backend/controller/sql/migrate/migrate.go b/backend/controller/sql/migrate/migrate.go index 57f8522e9..d2e44d95b 100644 --- a/backend/controller/sql/migrate/migrate.go +++ b/backend/controller/sql/migrate/migrate.go @@ -38,7 +38,7 @@ func Migration(version, name string, migration MigrationFunc) Option { } } -// LogLevel sets the logging level of the migrator. +// LogLevel sets the loggging level of the migrator. func LogLevel(level log.Level) Option { return func(opts *migrateOptions) { opts.logLevel = level diff --git a/backend/controller/sql/models.go b/backend/controller/sql/models.go index 731679ea0..5b5c1c76e 100644 --- a/backend/controller/sql/models.go +++ b/backend/controller/sql/models.go @@ -378,8 +378,8 @@ type AsyncCall struct { State AsyncCallState Origin string ScheduledAt time.Time - Request []byte - Response []byte + Request json.RawMessage + Response pqtype.NullRawMessage Error optional.Option[string] RemainingAttempts int32 Backoff sqltypes.Duration @@ -429,12 +429,6 @@ type DeploymentArtefact struct { Path string } -type EncryptionKey struct { - ID int64 - Key []byte - CreatedAt time.Time -} - type Event struct { ID int64 TimeStamp time.Time diff --git a/backend/controller/sql/querier.go b/backend/controller/sql/querier.go index b1d5c2046..42fc740ff 100644 --- a/backend/controller/sql/querier.go +++ b/backend/controller/sql/querier.go @@ -32,7 +32,6 @@ type Querier interface { CreateCronJob(ctx context.Context, arg CreateCronJobParams) error CreateDeployment(ctx context.Context, moduleName string, schema []byte, key model.DeploymentKey) error CreateIngressRoute(ctx context.Context, arg CreateIngressRouteParams) error - CreateOnlyEncryptionKey(ctx context.Context, key []byte) error CreateRequest(ctx context.Context, origin Origin, key model.RequestKey, sourceAddr string) error DeleteOldEvents(ctx context.Context, timeout sqltypes.Duration, type_ EventType) (int64, error) DeleteSubscribers(ctx context.Context, deployment model.DeploymentKey) ([]model.SubscriberKey, error) @@ -72,7 +71,6 @@ type Querier interface { GetLeaseInfo(ctx context.Context, key leases.Key) (GetLeaseInfoRow, error) GetModulesByID(ctx context.Context, ids []int64) ([]Module, error) GetNextEventForSubscription(ctx context.Context, consumptionDelay sqltypes.Duration, topic model.TopicKey, cursor optional.Option[model.TopicEventKey]) (GetNextEventForSubscriptionRow, error) - GetOnlyEncryptionKey(ctx context.Context) ([]byte, error) GetProcessList(ctx context.Context) ([]GetProcessListRow, error) GetRandomSubscriber(ctx context.Context, key model.SubscriptionKey) (GetRandomSubscriberRow, error) // Retrieve routing information for a runner. @@ -113,7 +111,7 @@ type Querier interface { // // "key" is the unique identifier for the FSM execution. StartFSMTransition(ctx context.Context, arg StartFSMTransitionParams) (FsmInstance, error) - SucceedAsyncCall(ctx context.Context, response []byte, iD int64) (bool, error) + SucceedAsyncCall(ctx context.Context, response json.RawMessage, iD int64) (bool, error) SucceedFSMInstance(ctx context.Context, fsm schema.RefKey, key string) (bool, error) UpsertController(ctx context.Context, key model.ControllerKey, endpoint string) (int64, error) UpsertModule(ctx context.Context, language string, name string) (int64, error) diff --git a/backend/controller/sql/queries.sql b/backend/controller/sql/queries.sql index cd98b31a8..10854ad72 100644 --- a/backend/controller/sql/queries.sql +++ b/backend/controller/sql/queries.sql @@ -542,7 +542,7 @@ RETURNING UPDATE async_calls SET state = 'success'::async_call_state, - response = @response, + response = @response::JSONB, error = null WHERE id = @id RETURNING true; @@ -878,12 +878,3 @@ WHERE id = $1::BIGINT; SELECT * FROM topic_events WHERE id = $1::BIGINT; - --- name: GetOnlyEncryptionKey :one -SELECT key -FROM encryption_keys -WHERE id = 1; - --- name: CreateOnlyEncryptionKey :exec -INSERT INTO encryption_keys (id, key) -VALUES (1, $1); diff --git a/backend/controller/sql/queries.sql.go b/backend/controller/sql/queries.sql.go index fb7f4f10b..77ed4654a 100644 --- a/backend/controller/sql/queries.sql.go +++ b/backend/controller/sql/queries.sql.go @@ -67,7 +67,7 @@ type AcquireAsyncCallRow struct { Origin string Verb schema.RefKey CatchVerb optional.Option[schema.RefKey] - Request []byte + Request json.RawMessage ScheduledAt time.Time RemainingAttempts int32 Error optional.Option[string] @@ -216,7 +216,7 @@ RETURNING id type CreateAsyncCallParams struct { Verb schema.RefKey Origin string - Request []byte + Request json.RawMessage RemainingAttempts int32 Backoff sqltypes.Duration MaxBackoff sqltypes.Duration @@ -311,16 +311,6 @@ func (q *Queries) CreateIngressRoute(ctx context.Context, arg CreateIngressRoute return err } -const createOnlyEncryptionKey = `-- name: CreateOnlyEncryptionKey :exec -INSERT INTO encryption_keys (id, key) -VALUES (1, $1) -` - -func (q *Queries) CreateOnlyEncryptionKey(ctx context.Context, key []byte) error { - _, err := q.db.ExecContext(ctx, createOnlyEncryptionKey, key) - return err -} - const createRequest = `-- name: CreateRequest :exec INSERT INTO requests (origin, "key", source_addr) VALUES ($1, $2, $3) @@ -1462,19 +1452,6 @@ func (q *Queries) GetNextEventForSubscription(ctx context.Context, consumptionDe return i, err } -const getOnlyEncryptionKey = `-- name: GetOnlyEncryptionKey :one -SELECT key -FROM encryption_keys -WHERE id = 1 -` - -func (q *Queries) GetOnlyEncryptionKey(ctx context.Context) ([]byte, error) { - row := q.db.QueryRowContext(ctx, getOnlyEncryptionKey) - var key []byte - err := row.Scan(&key) - return key, err -} - const getProcessList = `-- name: GetProcessList :many SELECT d.min_replicas, d.key AS deployment_key, @@ -2573,13 +2550,13 @@ const succeedAsyncCall = `-- name: SucceedAsyncCall :one UPDATE async_calls SET state = 'success'::async_call_state, - response = $1, + response = $1::JSONB, error = null WHERE id = $2 RETURNING true ` -func (q *Queries) SucceedAsyncCall(ctx context.Context, response []byte, iD int64) (bool, error) { +func (q *Queries) SucceedAsyncCall(ctx context.Context, response json.RawMessage, iD int64) (bool, error) { row := q.db.QueryRowContext(ctx, succeedAsyncCall, response, iD) var column_1 bool err := row.Scan(&column_1) diff --git a/backend/controller/sql/schema/20240812011321_derive_encryption.sql b/backend/controller/sql/schema/20240812011321_derive_encryption.sql deleted file mode 100644 index bdb8b2396..000000000 --- a/backend/controller/sql/schema/20240812011321_derive_encryption.sql +++ /dev/null @@ -1,19 +0,0 @@ --- migrate:up - -CREATE TABLE encryption_keys ( - id BIGINT NOT NULL GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, - key bytea NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT (NOW() AT TIME ZONE 'utc') -); - -ALTER TABLE events - ALTER COLUMN payload TYPE bytea - USING payload::text::bytea; - -ALTER TABLE async_calls - ALTER COLUMN request TYPE bytea - USING request::text::bytea, - ALTER COLUMN response TYPE bytea - USING response::text::bytea; - --- migrate:down diff --git a/cmd/ftl-controller/main.go b/cmd/ftl-controller/main.go index 83c2e88a4..20c4fa5f5 100644 --- a/cmd/ftl-controller/main.go +++ b/cmd/ftl-controller/main.go @@ -9,7 +9,6 @@ import ( "github.com/XSAM/otelsql" "github.com/alecthomas/kong" - "github.com/alecthomas/types/optional" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/secretsmanager" semconv "go.opentelemetry.io/otel/semconv/v1.4.0" @@ -45,9 +44,8 @@ func main() { ) cli.ControllerConfig.SetDefaults() - if cli.ControllerConfig.KMSURI == nil { - kctx.Fatalf("KMSURI is required") - } + encryptors, err := cli.ControllerConfig.EncryptionKeys.Encryptors(true) + kctx.FatalIfErrorf(err, "failed to create encryptors") ctx := log.ContextWithLogger(context.Background(), log.Configure(os.Stderr, cli.LogConfig)) err = observability.Init(ctx, false, "", "ftl-controller", ftl.Version, cli.ObservabilityConfig) @@ -58,7 +56,7 @@ func main() { kctx.FatalIfErrorf(err) err = otelsql.RegisterDBStatsMetrics(conn, otelsql.WithAttributes(semconv.DBSystemPostgreSQL)) kctx.FatalIfErrorf(err) - dal, err := dal.New(ctx, conn, optional.Some[string](*cli.ControllerConfig.KMSURI)) + dal, err := dal.New(ctx, conn, encryptors) kctx.FatalIfErrorf(err) configDal, err := cfdal.New(ctx, conn) @@ -79,6 +77,6 @@ func main() { kctx.FatalIfErrorf(err) ctx = cf.ContextWithSecrets(ctx, sm) - err = controller.Start(ctx, cli.ControllerConfig, scaling.NewK8sScaling(), conn) + err = controller.Start(ctx, cli.ControllerConfig, scaling.NewK8sScaling(), conn, encryptors) kctx.FatalIfErrorf(err) } diff --git a/cmd/ftl/cmd_box_run.go b/cmd/ftl/cmd_box_run.go index 1fe671759..541dd679a 100644 --- a/cmd/ftl/cmd_box_run.go +++ b/cmd/ftl/cmd_box_run.go @@ -67,10 +67,14 @@ func (b *boxRunCmd) Run(ctx context.Context, projConfig projectconfig.Config) er if err != nil { return fmt.Errorf("failed to register DB metrics: %w", err) } + encryptors, err := config.EncryptionKeys.Encryptors(false) + if err != nil { + return fmt.Errorf("failed to create encryptors: %w", err) + } wg := errgroup.Group{} wg.Go(func() error { - return controller.Start(ctx, config, runnerScaling, conn) + return controller.Start(ctx, config, runnerScaling, conn, encryptors) }) // Wait for the controller to come up. diff --git a/cmd/ftl/cmd_serve.go b/cmd/ftl/cmd_serve.go index 985181b91..5956677ce 100644 --- a/cmd/ftl/cmd_serve.go +++ b/cmd/ftl/cmd_serve.go @@ -148,7 +148,7 @@ func (s *serveCmd) run(ctx context.Context, projConfig projectconfig.Config, ini } controllerCtx = cf.ContextWithSecrets(controllerCtx, sm) - // Bring up the DB connection for the controller. + // Bring up the DB connection and DAL. conn, err := otelsql.Open("pgx", config.DSN) if err != nil { return fmt.Errorf("failed to bring up DB connection: %w", err) @@ -157,9 +157,13 @@ func (s *serveCmd) run(ctx context.Context, projConfig projectconfig.Config, ini if err != nil { return fmt.Errorf("failed to register DB metrics: %w", err) } + encryptors, err := config.EncryptionKeys.Encryptors(false) + if err != nil { + return fmt.Errorf("failed to create encryptors: %w", err) + } wg.Go(func() error { - if err := controller.Start(controllerCtx, config, runnerScaling, conn); err != nil { + if err := controller.Start(controllerCtx, config, runnerScaling, conn, encryptors); err != nil { logger.Errorf(err, "controller%d failed: %v", i, err) return fmt.Errorf("controller%d failed: %w", i, err) } diff --git a/common/configuration/sql/models.go b/common/configuration/sql/models.go index 731679ea0..5b5c1c76e 100644 --- a/common/configuration/sql/models.go +++ b/common/configuration/sql/models.go @@ -378,8 +378,8 @@ type AsyncCall struct { State AsyncCallState Origin string ScheduledAt time.Time - Request []byte - Response []byte + Request json.RawMessage + Response pqtype.NullRawMessage Error optional.Option[string] RemainingAttempts int32 Backoff sqltypes.Duration @@ -429,12 +429,6 @@ type DeploymentArtefact struct { Path string } -type EncryptionKey struct { - ID int64 - Key []byte - CreatedAt time.Time -} - type Event struct { ID int64 TimeStamp time.Time diff --git a/integration/harness.go b/integration/harness.go index 5a9281ec5..aae96f7f1 100644 --- a/integration/harness.go +++ b/integration/harness.go @@ -77,8 +77,11 @@ func RunWithoutController(t *testing.T, ftlConfigPath string, actions ...Action) } func RunWithEncryption(t *testing.T, ftlConfigPath string, actions ...Action) { - uri := "fake-kms://CKbvh_ILElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEE6tD2yE5AWYOirhmkY-r3sYARABGKbvh_ILIAE" - t.Setenv("FTL_KMS_URI", uri) + logKey := `{"primaryKeyId":1467957621,"key":[{"keyData":{"typeUrl":"type.googleapis.com/google.crypto.tink.AesCtrHmacStreamingKey","value":"Eg4IgIBAECAYAyIECAMQIBog7t16YRvohzTJBKt0D4WcqFpoeWH0C20Hr09v+AxbOOE=","keyMaterialType":"SYMMETRIC"},"status":"ENABLED","keyId":1467957621,"outputPrefixType":"RAW"}]}` + asyncKey := `{"primaryKeyId":2710864232,"key":[{"keyData":{"typeUrl":"type.googleapis.com/google.crypto.tink.AesCtrHmacStreamingKey","value":"Eg4IgIBAECAYAyIECAMQIBogTFCSLcJGRRazu74LrehNGL82J0sicjnjG5uNZcDyjGE=","keyMaterialType":"SYMMETRIC"},"status":"ENABLED","keyId":2710864232,"outputPrefixType":"RAW"}]}` + + t.Setenv("FTL_LOG_ENCRYPTION_KEY", logKey) + t.Setenv("FTL_ASYNC_ENCRYPTION_KEY", asyncKey) run(t, ftlConfigPath, true, false, actions...) } diff --git a/internal/encryption/encryption.go b/internal/encryption/encryption.go index b6816e598..59b1366e0 100644 --- a/internal/encryption/encryption.go +++ b/internal/encryption/encryption.go @@ -2,20 +2,177 @@ package encryption import ( "bytes" + "encoding/json" "fmt" + "io" "strings" awsv1kms "github.com/aws/aws-sdk-go/service/kms" "github.com/tink-crypto/tink-go-awskms/integration/awskms" "github.com/tink-crypto/tink-go/v2/aead" "github.com/tink-crypto/tink-go/v2/core/registry" + "github.com/tink-crypto/tink-go/v2/insecurecleartextkeyset" "github.com/tink-crypto/tink-go/v2/keyderivation" "github.com/tink-crypto/tink-go/v2/keyset" "github.com/tink-crypto/tink-go/v2/prf" + "github.com/tink-crypto/tink-go/v2/streamingaead" "github.com/tink-crypto/tink-go/v2/testing/fakekms" "github.com/tink-crypto/tink-go/v2/tink" ) +// Encryptable is an interface for encrypting and decrypting JSON payloads. +// Deprecated: This is will be changed or removed very soon. +type Encryptable interface { + EncryptJSON(input any) (json.RawMessage, error) + DecryptJSON(input json.RawMessage, output any) error +} + +// NewForKeyOrURI creates a new encryptor using the provided key or URI. +// Deprecated: This is will be changed or removed very soon. +func NewForKeyOrURI(keyOrURI string) (Encryptable, error) { + if len(keyOrURI) == 0 { + return NoOpEncryptor{}, nil + } + + // If keyOrUri is a JSON string, it is a clear text key set. + if strings.TrimSpace(keyOrURI)[0] == '{' { + return NewClearTextEncryptor(keyOrURI) + // Otherwise should be a URI for KMS. + // aws-kms://arn:aws:kms:[region]:[account-id]:key/[key-id] + } else if strings.HasPrefix(keyOrURI, "aws-kms://") { + panic("not implemented") + } + return nil, fmt.Errorf("unsupported key or uri: %s", keyOrURI) +} + +// NoOpEncryptor does not encrypt or decrypt and just passes the input as is. +// Deprecated: This is will be changed or removed very soon. +type NoOpEncryptor struct { +} + +func (n NoOpEncryptor) EncryptJSON(input any) (json.RawMessage, error) { + msg, err := json.Marshal(input) + if err != nil { + return nil, fmt.Errorf("failed to marshal input: %w", err) + } + + return msg, nil +} + +func (n NoOpEncryptor) DecryptJSON(input json.RawMessage, output any) error { + err := json.Unmarshal(input, output) + if err != nil { + return fmt.Errorf("failed to unmarshal input: %w", err) + } + + return nil +} + +func NewClearTextEncryptor(key string) (Encryptable, error) { + keySetHandle, err := insecurecleartextkeyset.Read( + keyset.NewJSONReader(bytes.NewBufferString(key))) + if err != nil { + return nil, fmt.Errorf("failed to read clear text keyset: %w", err) + } + + encryptor, err := NewDeprecatedEncryptor(*keySetHandle) + if err != nil { + return nil, fmt.Errorf("failed to create clear text encryptor: %w", err) + } + + return encryptor, nil +} + +// NewDeprecatedEncryptor encrypts and decrypts JSON payloads using the provided key set. +// The key set must contain a primary key that is a streaming AEAD primitive. +func NewDeprecatedEncryptor(keySet keyset.Handle) (Encryptable, error) { + primitive, err := streamingaead.New(&keySet) + if err != nil { + return nil, fmt.Errorf("failed to create primitive during encryption: %w", err) + } + + return Encryptor{keySetHandle: keySet, primitive: primitive}, nil +} + +// Encryptor uses streaming with JSON payloads. +// Deprecated: This is will be changed or removed very soon. +type Encryptor struct { + keySetHandle keyset.Handle + primitive tink.StreamingAEAD +} + +// EncryptedPayload is a JSON payload that contains the encrypted data to put into the database. +// Deprecated: This is will be changed or removed very soon. +type EncryptedPayload struct { + Encrypted []byte `json:"encrypted"` +} + +func (e Encryptor) EncryptJSON(input any) (json.RawMessage, error) { + msg, err := json.Marshal(input) + if err != nil { + return nil, fmt.Errorf("failed to marshal input: %w", err) + } + + encrypted, err := encryptBytesForStreaming(e.primitive, msg) + if err != nil { + return nil, fmt.Errorf("failed to encrypt data: %w", err) + } + + out, err := json.Marshal(EncryptedPayload{Encrypted: encrypted}) + if err != nil { + return nil, fmt.Errorf("failed to marshal encrypted data: %w", err) + } + return out, nil +} + +func (e Encryptor) DecryptJSON(input json.RawMessage, output any) error { + var payload EncryptedPayload + if err := json.Unmarshal(input, &payload); err != nil { + return fmt.Errorf("failed to unmarshal encrypted payload: %w", err) + } + + decryptedBuffer, err := decryptBytesForStreaming(e.primitive, payload.Encrypted) + if err != nil { + return fmt.Errorf("failed to decrypt data: %w", err) + } + + if err := json.Unmarshal(decryptedBuffer, output); err != nil { + return fmt.Errorf("failed to unmarshal decrypted data: %w", err) + } + + return nil +} + +func encryptBytesForStreaming(streamingPrimitive tink.StreamingAEAD, clearText []byte) ([]byte, error) { + encryptedBuffer := &bytes.Buffer{} + msgBuffer := bytes.NewBuffer(clearText) + writer, err := streamingPrimitive.NewEncryptingWriter(encryptedBuffer, nil) + if err != nil { + return nil, fmt.Errorf("failed to create encrypting writer: %w", err) + } + if _, err := io.Copy(writer, msgBuffer); err != nil { + return nil, fmt.Errorf("failed to copy encrypted data: %w", err) + } + if err := writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close encrypted writer: %w", err) + } + + return encryptedBuffer.Bytes(), nil +} + +func decryptBytesForStreaming(streamingPrimitive tink.StreamingAEAD, encrypted []byte) ([]byte, error) { + encryptedBuffer := bytes.NewReader(encrypted) + decryptedBuffer := &bytes.Buffer{} + reader, err := streamingPrimitive.NewDecryptingReader(encryptedBuffer, nil) + if err != nil { + return nil, fmt.Errorf("failed to create decrypting reader: %w", err) + } + if _, err := io.Copy(decryptedBuffer, reader); err != nil { + return nil, fmt.Errorf("failed to copy decrypted data: %w", err) + } + return decryptedBuffer.Bytes(), nil +} + type SubKey string const ( @@ -23,7 +180,7 @@ const ( AsyncSubKey SubKey = "async" ) -type DataEncryptor interface { +type EncryptorNext interface { Encrypt(subKey SubKey, cleartext []byte) ([]byte, error) Decrypt(subKey SubKey, encrypted []byte) ([]byte, error) } @@ -31,10 +188,6 @@ type DataEncryptor interface { // NoOpEncryptorNext does not encrypt and just passes the input as is. type NoOpEncryptorNext struct{} -func NewNoOpEncryptor() NoOpEncryptorNext { - return NoOpEncryptorNext{} -} - func (n NoOpEncryptorNext) Encrypt(_ SubKey, cleartext []byte) ([]byte, error) { return cleartext, nil } @@ -43,12 +196,44 @@ func (n NoOpEncryptorNext) Decrypt(_ SubKey, encrypted []byte) ([]byte, error) { return encrypted, nil } +type PlaintextEncryptor struct { + root keyset.Handle +} + +func NewPlaintextEncryptor(key string) (*PlaintextEncryptor, error) { + handle, err := insecurecleartextkeyset.Read( + keyset.NewJSONReader(bytes.NewBufferString(key))) + if err != nil { + return nil, fmt.Errorf("failed to read clear text keyset: %w", err) + } + + return &PlaintextEncryptor{root: *handle}, nil +} + +func (p PlaintextEncryptor) Encrypt(subKey SubKey, cleartext []byte) ([]byte, error) { + encrypted, err := derivedEncrypt(p.root, subKey, cleartext) + if err != nil { + return nil, fmt.Errorf("failed to encrypt with derive: %w", err) + } + + return encrypted, nil +} + +func (p PlaintextEncryptor) Decrypt(subKey SubKey, encrypted []byte) ([]byte, error) { + decrypted, err := derivedDecrypt(p.root, subKey, encrypted) + if err != nil { + return nil, fmt.Errorf("failed to decrypt with derive: %w", err) + } + + return decrypted, nil +} + // KMSEncryptor encrypts and decrypts using a KMS key via tink. +// TODO: maybe change to DerivableEncryptor and integrate plaintext and kms encryptor. type KMSEncryptor struct { root keyset.Handle kekAEAD tink.AEAD encryptedKeyset []byte - cachedDerived map[SubKey]tink.AEAD } func newClientWithAEAD(uri string, kms *awsv1kms.KMS) (tink.AEAD, error) { @@ -117,7 +302,6 @@ func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncrypt root: *handle, kekAEAD: kekAEAD, encryptedKeyset: encryptedKeyset, - cachedDerived: make(map[SubKey]tink.AEAD), }, nil } @@ -137,7 +321,6 @@ func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset root: *handle, kekAEAD: kekAEAD, encryptedKeyset: encryptedKeyset, - cachedDerived: make(map[SubKey]tink.AEAD), }, nil } @@ -159,12 +342,26 @@ func deriveKeyset(root keyset.Handle, salt []byte) (*keyset.Handle, error) { return derived, nil } -func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) { - if primitive, ok := k.cachedDerived[subKey]; ok { - return primitive, nil +func (k *KMSEncryptor) Encrypt(subKey SubKey, cleartext []byte) ([]byte, error) { + encrypted, err := derivedEncrypt(k.root, subKey, cleartext) + if err != nil { + return nil, fmt.Errorf("failed to encrypt with derive: %w", err) } - derived, err := deriveKeyset(k.root, []byte(subKey)) + return encrypted, nil +} + +func (k *KMSEncryptor) Decrypt(subKey SubKey, encrypted []byte) ([]byte, error) { + decrypted, err := derivedDecrypt(k.root, subKey, encrypted) + if err != nil { + return nil, fmt.Errorf("failed to decrypt with derive: %w", err) + } + + return decrypted, nil +} + +func derivedDecrypt(root keyset.Handle, subKey SubKey, encrypted []byte) ([]byte, error) { + derived, err := deriveKeyset(root, []byte(subKey)) if err != nil { return nil, fmt.Errorf("failed to derive keyset: %w", err) } @@ -174,34 +371,30 @@ func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) { return nil, fmt.Errorf("failed to create primitive: %w", err) } - k.cachedDerived[subKey] = primitive - return primitive, nil -} - -func (k *KMSEncryptor) Encrypt(subKey SubKey, cleartext []byte) ([]byte, error) { - primitive, err := k.getDerivedPrimitive(subKey) + bytes, err := primitive.Decrypt(encrypted, nil) if err != nil { - return nil, fmt.Errorf("failed to get derived primitive: %w", err) + return nil, fmt.Errorf("failed to decrypt: %w", err) } - encrypted, err := primitive.Encrypt(cleartext, nil) + return bytes, nil +} + +func derivedEncrypt(root keyset.Handle, subKey SubKey, cleartext []byte) ([]byte, error) { + // TODO: Deriving might be expensive, consider caching the derived keyset. + derived, err := deriveKeyset(root, []byte(subKey)) if err != nil { - return nil, fmt.Errorf("failed to encrypt: %w", err) + return nil, fmt.Errorf("failed to derive keyset: %w", err) } - return encrypted, nil -} - -func (k *KMSEncryptor) Decrypt(subKey SubKey, encrypted []byte) ([]byte, error) { - primitive, err := k.getDerivedPrimitive(subKey) + primitive, err := aead.New(derived) if err != nil { - return nil, fmt.Errorf("failed to get derived primitive: %w", err) + return nil, fmt.Errorf("failed to create primitive: %w", err) } - decrypted, err := primitive.Decrypt(encrypted, nil) + bytes, err := primitive.Encrypt(cleartext, nil) if err != nil { - return nil, fmt.Errorf("failed to decrypt: %w", err) + return nil, fmt.Errorf("failed to encrypt: %w", err) } - return decrypted, nil + return bytes, nil } diff --git a/internal/encryption/encryption_test.go b/internal/encryption/encryption_test.go index f13e4358d..84542d39a 100644 --- a/internal/encryption/encryption_test.go +++ b/internal/encryption/encryption_test.go @@ -1,11 +1,73 @@ package encryption import ( + "encoding/json" + "fmt" "testing" "github.com/alecthomas/assert/v2" ) +const streamingKey = `{ + "primaryKeyId": 1720777699, + "key": [{ + "keyData": { + "typeUrl": "type.googleapis.com/google.crypto.tink.AesCtrHmacStreamingKey", + "keyMaterialType": "SYMMETRIC", + "value": "Eg0IgCAQIBgDIgQIAxAgGiDtesd/4gCnQdTrh+AXodwpm2b6BFJkp043n+8mqx0YGw==" + }, + "outputPrefixType": "RAW", + "keyId": 1720777699, + "status": "ENABLED" + }] +}` + +func TestDeprecatedNewEncryptor(t *testing.T) { + jsonInput := "\"hello\"" + + encryptor, err := NewForKeyOrURI(streamingKey) + assert.NoError(t, err) + + encrypted, err := encryptor.EncryptJSON(jsonInput) + assert.NoError(t, err) + fmt.Printf("Encrypted: %s\n", encrypted) + + var decrypted json.RawMessage + err = encryptor.DecryptJSON(encrypted, &decrypted) + assert.NoError(t, err) + fmt.Printf("Decrypted: %s\n", decrypted) + + var decryptedString string + err = json.Unmarshal(decrypted, &decryptedString) + assert.NoError(t, err) + fmt.Printf("Decrypted string: %s\n", decryptedString) + + assert.Equal(t, jsonInput, decryptedString) +} + +// tinkey create-keyset --key-template HKDF_SHA256_DERIVES_AES256_GCM +const key = `{"primaryKeyId":2304101620,"key":[{"keyData":{"typeUrl":"type.googleapis.com/google.crypto.tink.PrfBasedDeriverKey","value":"El0KMXR5cGUuZ29vZ2xlYXBpcy5jb20vZ29vZ2xlLmNyeXB0by50aW5rLkhrZGZQcmZLZXkSJhICCAMaIDnEx9gPgeF32LQYjFYNSZe8b9KUl41Xy6to8MqKcSjBGAEaOgo4CjB0eXBlLmdvb2dsZWFwaXMuY29tL2dvb2dsZS5jcnlwdG8udGluay5BZXNHY21LZXkSAhAgGAE=","keyMaterialType":"SYMMETRIC"},"status":"ENABLED","keyId":2304101620,"outputPrefixType":"TINK"}]}` + +func TestPlaintextEncryptor(t *testing.T) { + encryptor, err := NewPlaintextEncryptor(key) + assert.NoError(t, err) + + encrypted, err := encryptor.Encrypt(LogsSubKey, []byte("hunter2")) + assert.NoError(t, err) + fmt.Printf("Encrypted: %s\n", encrypted) + + decrypted, err := encryptor.Decrypt(LogsSubKey, encrypted) + assert.NoError(t, err) + fmt.Printf("Decrypted: %s\n", decrypted) + + assert.Equal(t, "hunter2", string(decrypted)) + + // Should fail to decrypt with the wrong subkey + _, err = encryptor.Decrypt(AsyncSubKey, encrypted) + assert.Error(t, err) + +} + func TestNoOpEncryptor(t *testing.T) { encryptor := NoOpEncryptorNext{} @@ -18,9 +80,8 @@ func TestNoOpEncryptor(t *testing.T) { assert.Equal(t, "hunter2", string(decrypted)) } -// echo -n "fake-kms://" && tinkey create-keyset --key-template AES128_GCM --out-format binary | base64 | tr '+/' '-_' | tr -d '=' func TestKMSEncryptorFakeKMS(t *testing.T) { - uri := "fake-kms://CKbvh_ILElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEE6tD2yE5AWYOirhmkY-r3sYARABGKbvh_ILIAE" + uri := "fake-kms://CM2b3_MDElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEIK75t5L-adlUwVhWvRuWUwYARABGM2b3_MDIAE" encryptor, err := NewKMSEncryptorGenerateKey(uri, nil) assert.NoError(t, err) diff --git a/internal/encryption/integration_test.go b/internal/encryption/integration_test.go index 92cca64ad..bef904c8b 100644 --- a/internal/encryption/integration_test.go +++ b/internal/encryption/integration_test.go @@ -55,12 +55,13 @@ func TestEncryptionForLogs(t *testing.T) { values := in.GetRow(t, ic, "ftl", "SELECT payload FROM events WHERE type = 'call' LIMIT 1", 1) payload, ok := values[0].([]byte) assert.True(t, ok, "could not convert payload to string") + assert.Contains(t, string(payload), "encrypted", "raw request string should not be stored in the table") assert.NotContains(t, string(payload), "Alice", "raw request string should not be stored in the table") }, ) } -func TestEncryptionForPubSub(t *testing.T) { +func TestEncryptionForubSub(t *testing.T) { in.RunWithEncryption(t, "", in.CopyModule("encryption"), in.Deploy("encryption"), @@ -74,6 +75,7 @@ func TestEncryptionForPubSub(t *testing.T) { values := in.GetRow(t, ic, "ftl", "SELECT payload FROM topic_events", 1) payload, ok := values[0].([]byte) assert.True(t, ok, "could not convert payload to string") + assert.Contains(t, string(payload), "encrypted", "raw request string should not be stored in the table") assert.NotContains(t, string(payload), "AliceInWonderland", "raw request string should not be stored in the table") }, validateAsyncCall("consume", "AliceInWonderland"), @@ -101,6 +103,7 @@ func validateAsyncCall(verb string, sensitive string) in.Action { values := in.GetRow(t, ic, "ftl", fmt.Sprintf("SELECT request FROM async_calls WHERE verb = 'encryption.%s' AND state = 'success'", verb), 1) request, ok := values[0].([]byte) assert.True(t, ok, "could not convert payload to string") + assert.Contains(t, string(request), "encrypted", "raw request string should not be stored in the table") assert.NotContains(t, string(request), sensitive, "raw request string should not be stored in the table") } }