From c40d84c14bfb49c85b7e6faf318498c10f2a2ed5 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Mon, 16 Sep 2024 14:12:22 +1000 Subject: [PATCH] refactor: consolidate encryption logic into base encryption package (#2681) Quite a bit of logic was in encryption/api, which now only contains types. --- Justfile | 10 +- backend/controller/controller.go | 3 +- backend/controller/cronjobs/cronjobs_test.go | 3 +- backend/controller/dal/async_calls_test.go | 3 +- backend/controller/dal/dal_test.go | 5 +- backend/controller/dal/fsm_test.go | 3 +- .../controller/encryption/api/encryption.go | 235 ------------------ backend/controller/encryption/aws.go | 182 ++++++++++++++ backend/controller/encryption/builder.go | 49 ++++ backend/controller/encryption/dal/dal.go | 82 ++---- .../encryption/{api => }/encryption_test.go | 10 +- .../encryption/{api => }/integration_test.go | 7 +- backend/controller/encryption/noop.go | 23 ++ backend/controller/encryption/service.go | 91 ++++++- backend/controller/encryption/service_test.go | 4 +- .../testdata/go/encryption/encryption.go | 0 .../{api => }/testdata/go/encryption/ftl.toml | 0 .../{api => }/testdata/go/encryption/go.mod | 0 .../{api => }/testdata/go/encryption/go.sum | 0 backend/controller/timeline/dal/dal_test.go | 5 +- 20 files changed, 388 insertions(+), 327 deletions(-) create mode 100644 backend/controller/encryption/aws.go create mode 100644 backend/controller/encryption/builder.go rename backend/controller/encryption/{api => }/encryption_test.go (83%) rename backend/controller/encryption/{api => }/integration_test.go (97%) create mode 100644 backend/controller/encryption/noop.go rename backend/controller/encryption/{api => }/testdata/go/encryption/encryption.go (100%) rename backend/controller/encryption/{api => }/testdata/go/encryption/ftl.toml (100%) rename backend/controller/encryption/{api => }/testdata/go/encryption/go.mod (100%) rename backend/controller/encryption/{api => }/testdata/go/encryption/go.sum (100%) diff --git a/Justfile b/Justfile index 1223c01aa..d03d86619 100644 --- a/Justfile +++ b/Justfile @@ -39,8 +39,7 @@ dev *args: watchexec -r {{WATCHEXEC_ARGS}} -- "just build-sqlc && ftl dev {{args}}" # Build everything -build-all: build-protos-unconditionally build-frontend build-generate build-sqlc build-zips lsp-generate build-java - @just build ftl ftl-controller ftl-runner ftl-initdb +build-all: build-protos-unconditionally build-backend build-backend-tests build-frontend build-generate build-sqlc build-zips lsp-generate build-java # Run "go generate" on all packages build-generate: @@ -51,17 +50,22 @@ build-generate: build +tools: build-protos build-zips build-frontend #!/bin/bash shopt -s extglob + set -x if [ "${FTL_DEBUG:-}" = "true" ]; then for tool in $@; do go build -o "{{RELEASE}}/$tool" -tags release -gcflags=all="-N -l" -ldflags "-X github.com/TBD54566975/ftl.Version={{VERSION}} -X github.com/TBD54566975/ftl.Timestamp={{TIMESTAMP}}" "./cmd/$tool"; done else - for tool in $@; do mk "{{RELEASE}}/$tool" : !(build|integration) -- go build -o "{{RELEASE}}/$tool" -tags release -ldflags "-X github.com/TBD54566975/ftl.Version={{VERSION}} -X github.com/TBD54566975/ftl.Timestamp={{TIMESTAMP}}" "./cmd/$tool"; done + for tool in $@; do mk "{{RELEASE}}/$tool" : !(build|integration|node_modules|Procfile*|Dockerfile*) -- go build -o "{{RELEASE}}/$tool" -tags release -ldflags "-X github.com/TBD54566975/ftl.Version={{VERSION}} -X github.com/TBD54566975/ftl.Timestamp={{TIMESTAMP}}" "./cmd/$tool"; done fi # Build all backend binaries build-backend: just build ftl ftl-controller ftl-runner +# Build all backend tests +build-backend-tests: + go test -run ^NONE -tags integration ./... + build-java *args: mvn -f jvm-runtime/ftl-runtime install {{args}} diff --git a/backend/controller/controller.go b/backend/controller/controller.go index eb2b490c6..726d847ac 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -39,7 +39,6 @@ import ( "github.com/TBD54566975/ftl/backend/controller/cronjobs" "github.com/TBD54566975/ftl/backend/controller/dal" "github.com/TBD54566975/ftl/backend/controller/encryption" - "github.com/TBD54566975/ftl/backend/controller/encryption/api" "github.com/TBD54566975/ftl/backend/controller/ingress" "github.com/TBD54566975/ftl/backend/controller/leases" leasesdal "github.com/TBD54566975/ftl/backend/controller/leases/dal" @@ -237,7 +236,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, devel bool) (*Service config.ControllerTimeout = time.Second * 5 } - encryption, err := encryption.New(ctx, conn, api.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI))) + encryption, err := encryption.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI))) if err != nil { return nil, fmt.Errorf("failed to create encryption dal: %w", err) } diff --git a/backend/controller/cronjobs/cronjobs_test.go b/backend/controller/cronjobs/cronjobs_test.go index a39386600..0e6b75f53 100644 --- a/backend/controller/cronjobs/cronjobs_test.go +++ b/backend/controller/cronjobs/cronjobs_test.go @@ -15,7 +15,6 @@ import ( "github.com/TBD54566975/ftl/backend/controller/cronjobs/dal" parentdal "github.com/TBD54566975/ftl/backend/controller/dal" "github.com/TBD54566975/ftl/backend/controller/encryption" - "github.com/TBD54566975/ftl/backend/controller/encryption/api" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" "github.com/TBD54566975/ftl/backend/libdal" "github.com/TBD54566975/ftl/backend/schema" @@ -37,7 +36,7 @@ func TestNewCronJobsForModule(t *testing.T) { dal := dal.New(conn) uri := "fake-kms://CK6YwYkBElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEJy4TIQgfCuwxA3ZZgChp_wYARABGK6YwYkBIAE" - encryption, err := encryption.New(ctx, conn, api.NewBuilder().WithKMSURI(optional.Some(uri))) + encryption, err := encryption.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri))) assert.NoError(t, err) parentDAL := parentdal.New(ctx, conn, encryption) diff --git a/backend/controller/dal/async_calls_test.go b/backend/controller/dal/async_calls_test.go index 5f3b98833..e69c68bd3 100644 --- a/backend/controller/dal/async_calls_test.go +++ b/backend/controller/dal/async_calls_test.go @@ -7,7 +7,6 @@ import ( "github.com/alecthomas/assert/v2" "github.com/TBD54566975/ftl/backend/controller/encryption" - "github.com/TBD54566975/ftl/backend/controller/encryption/api" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" "github.com/TBD54566975/ftl/backend/libdal" "github.com/TBD54566975/ftl/backend/schema" @@ -18,7 +17,7 @@ import ( func TestNoCallToAcquire(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - encryption, err := encryption.New(ctx, conn, api.NewBuilder()) + encryption, err := encryption.New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) dal := New(ctx, conn, encryption) diff --git a/backend/controller/dal/dal_test.go b/backend/controller/dal/dal_test.go index 76cf8ba11..dafcbe653 100644 --- a/backend/controller/dal/dal_test.go +++ b/backend/controller/dal/dal_test.go @@ -13,7 +13,6 @@ import ( "golang.org/x/sync/errgroup" "github.com/TBD54566975/ftl/backend/controller/encryption" - "github.com/TBD54566975/ftl/backend/controller/encryption/api" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" "github.com/TBD54566975/ftl/backend/libdal" ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" @@ -26,7 +25,7 @@ import ( func TestDAL(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - encryption, err := encryption.New(ctx, conn, api.NewBuilder()) + encryption, err := encryption.New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) dal := New(ctx, conn, encryption) @@ -223,7 +222,7 @@ func TestDAL(t *testing.T) { func TestCreateArtefactConflict(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - encryption, err := encryption.New(ctx, conn, api.NewBuilder()) + encryption, err := encryption.New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) dal := New(ctx, conn, encryption) diff --git a/backend/controller/dal/fsm_test.go b/backend/controller/dal/fsm_test.go index 1cd415ab5..4f7fa1a6c 100644 --- a/backend/controller/dal/fsm_test.go +++ b/backend/controller/dal/fsm_test.go @@ -9,7 +9,6 @@ import ( "github.com/alecthomas/types/either" "github.com/TBD54566975/ftl/backend/controller/encryption" - "github.com/TBD54566975/ftl/backend/controller/encryption/api" leasedal "github.com/TBD54566975/ftl/backend/controller/leases/dal" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" "github.com/TBD54566975/ftl/backend/libdal" @@ -20,7 +19,7 @@ import ( func TestSendFSMEvent(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - encryption, err := encryption.New(ctx, conn, api.NewBuilder()) + encryption, err := encryption.New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) dal := New(ctx, conn, encryption) diff --git a/backend/controller/encryption/api/encryption.go b/backend/controller/encryption/api/encryption.go index f0d3dd327..cc7d5aa79 100644 --- a/backend/controller/encryption/api/encryption.go +++ b/backend/controller/encryption/api/encryption.go @@ -1,23 +1,7 @@ package api import ( - "bytes" "context" - "fmt" - "strings" - - "github.com/alecthomas/types/optional" - awsv1kms "github.com/aws/aws-sdk-go/service/kms" - "github.com/tink-crypto/tink-go-awskms/integration/awskms" - "github.com/tink-crypto/tink-go/v2/aead" - "github.com/tink-crypto/tink-go/v2/core/registry" - "github.com/tink-crypto/tink-go/v2/keyderivation" - "github.com/tink-crypto/tink-go/v2/keyset" - "github.com/tink-crypto/tink-go/v2/prf" - "github.com/tink-crypto/tink-go/v2/testing/fakekms" - "github.com/tink-crypto/tink-go/v2/tink" - - "github.com/TBD54566975/ftl/internal/mutex" ) // Encrypted is an interface for values that contain encrypted data. @@ -34,226 +18,7 @@ type KeyStoreProvider interface { EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) ([]byte, error) } -// Builder constructs a DataEncryptor when used with a provider. -// Use a chain of With* methods to configure the builder. -type Builder struct { - kmsURI optional.Option[string] -} - -func NewBuilder() Builder { - return Builder{ - kmsURI: optional.None[string](), - } -} - -// WithKMSURI sets the URI for the KMS key to use. Omitting this call or using None will create a NoOpEncryptor. -func (b Builder) WithKMSURI(kmsURI optional.Option[string]) Builder { - b.kmsURI = kmsURI - return b -} - -func (b Builder) Build(ctx context.Context, provider KeyStoreProvider) (DataEncryptor, error) { - kmsURI, ok := b.kmsURI.Get() - if !ok { - return NewNoOpEncryptor(), nil - } - - key, err := provider.EnsureKey(ctx, func() ([]byte, error) { - return newKey(kmsURI, nil) - }) - if err != nil { - return nil, fmt.Errorf("failed to ensure key from provider: %w", err) - } - - encryptor, err := NewKMSEncryptorWithKMS(kmsURI, nil, key) - if err != nil { - return nil, fmt.Errorf("failed to create KMS encryptor: %w", err) - } - - return encryptor, nil -} - type DataEncryptor interface { Encrypt(cleartext []byte, dest Encrypted) error Decrypt(encrypted Encrypted) ([]byte, error) } - -// NoOpEncryptor does not encrypt and just passes the input as is. -type NoOpEncryptor struct{} - -func NewNoOpEncryptor() NoOpEncryptor { - return NoOpEncryptor{} -} - -var _ DataEncryptor = NoOpEncryptor{} - -func (n NoOpEncryptor) Encrypt(cleartext []byte, dest Encrypted) error { - dest.Set(cleartext) - return nil -} - -func (n NoOpEncryptor) Decrypt(encrypted Encrypted) ([]byte, error) { - return encrypted.Bytes(), nil -} - -// KMSEncryptor encrypts and decrypts using a KMS key via tink. -type KMSEncryptor struct { - root keyset.Handle - kekAEAD tink.AEAD - encryptedKeyset []byte - cachedDerived *mutex.Mutex[map[SubKey]tink.AEAD] -} - -var _ DataEncryptor = &KMSEncryptor{} - -func newClientWithAEAD(uri string, kms *awsv1kms.KMS) (tink.AEAD, error) { - var client registry.KMSClient - var err error - - if strings.HasPrefix(strings.ToLower(uri), "fake-kms://") { - client, err = fakekms.NewClient(uri) - if err != nil { - return nil, fmt.Errorf("failed to create fake KMS client: %w", err) - } - - } else { - // tink does not support awsv2 yet - // https://github.com/tink-crypto/tink-go-awskms/issues/2 - var opts []awskms.ClientOption - if kms != nil { - opts = append(opts, awskms.WithKMS(kms)) - } - client, err = awskms.NewClientWithOptions(uri, opts...) - if err != nil { - return nil, fmt.Errorf("failed to create KMS client: %w", err) - } - } - - kekAEAD, err := client.GetAEAD(uri) - if err != nil { - return nil, fmt.Errorf("failed to get aead: %w", err) - } - - return kekAEAD, nil -} - -func newKey(uri string, v1client *awsv1kms.KMS) ([]byte, error) { - kekAEAD, err := newClientWithAEAD(uri, v1client) - if err != nil { - return nil, fmt.Errorf("failed to create KMS client: %w", err) - } - - // Create a PRF key template using HKDF-SHA256 - prfKeyTemplate := prf.HKDFSHA256PRFKeyTemplate() - - // Create an AES-256-GCM key template - aeadKeyTemplate := aead.AES256GCMKeyTemplate() - - template, err := keyderivation.CreatePRFBasedKeyTemplate(prfKeyTemplate, aeadKeyTemplate) - if err != nil { - return nil, fmt.Errorf("failed to create PRF based key template: %w", err) - } - - handle, err := keyset.NewHandle(template) - if err != nil { - return nil, fmt.Errorf("failed to create keyset handle: %w", err) - } - - // Encrypt the keyset with the KEK AEAD. - buf := new(bytes.Buffer) - writer := keyset.NewBinaryWriter(buf) - err = handle.Write(writer, kekAEAD) - if err != nil { - return nil, fmt.Errorf("failed to encrypt DEK: %w", err) - } - return buf.Bytes(), nil -} - -func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset []byte) (*KMSEncryptor, error) { - kekAEAD, err := newClientWithAEAD(uri, v1client) - if err != nil { - return nil, fmt.Errorf("failed to create KMS client: %w", err) - } - - reader := keyset.NewBinaryReader(bytes.NewReader(encryptedKeyset)) - handle, err := keyset.Read(reader, kekAEAD) - if err != nil { - return nil, fmt.Errorf("failed to read keyset: %w", err) - } - - return &KMSEncryptor{ - root: *handle, - kekAEAD: kekAEAD, - encryptedKeyset: encryptedKeyset, - cachedDerived: mutex.New(map[SubKey]tink.AEAD{}), - }, nil -} - -func (k *KMSEncryptor) GetEncryptedKeyset() []byte { - return k.encryptedKeyset -} - -func deriveKeyset(root keyset.Handle, salt []byte) (*keyset.Handle, error) { - deriver, err := keyderivation.New(&root) - if err != nil { - return nil, fmt.Errorf("failed to create deriver: %w", err) - } - - derived, err := deriver.DeriveKeyset(salt) - if err != nil { - return nil, fmt.Errorf("failed to derive keyset: %w", err) - } - - return derived, nil -} - -func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) { - cachedDerived := k.cachedDerived.Lock() - defer k.cachedDerived.Unlock() - primitive, ok := cachedDerived[subKey] - if ok { - return primitive, nil - } - - derived, err := deriveKeyset(k.root, []byte(subKey.SubKey())) - if err != nil { - return nil, fmt.Errorf("failed to derive keyset: %w", err) - } - - primitive, err = aead.New(derived) - if err != nil { - return nil, fmt.Errorf("failed to create primitive: %w", err) - } - - cachedDerived[subKey] = primitive - return primitive, nil -} - -func (k *KMSEncryptor) Encrypt(cleartext []byte, dest Encrypted) error { - primitive, err := k.getDerivedPrimitive(dest) - if err != nil { - return fmt.Errorf("%s: failed to get derived primitive: %w", dest.SubKey(), err) - } - - encrypted, err := primitive.Encrypt(cleartext, nil) - if err != nil { - return fmt.Errorf("%s: failed to encrypt: %w", dest.SubKey(), err) - } - - dest.Set(encrypted) - return nil -} - -func (k *KMSEncryptor) Decrypt(encrypted Encrypted) ([]byte, error) { - primitive, err := k.getDerivedPrimitive(encrypted) - if err != nil { - return nil, fmt.Errorf("%s: failed to get derived primitive: %w", encrypted.SubKey(), err) - } - - decrypted, err := primitive.Decrypt(encrypted.Bytes(), nil) - if err != nil { - return nil, fmt.Errorf("%s: failed to decrypt: %w", encrypted.SubKey(), err) - } - - return decrypted, nil -} diff --git a/backend/controller/encryption/aws.go b/backend/controller/encryption/aws.go new file mode 100644 index 000000000..34e01f414 --- /dev/null +++ b/backend/controller/encryption/aws.go @@ -0,0 +1,182 @@ +package encryption + +import ( + "bytes" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/service/kms" + "github.com/tink-crypto/tink-go-awskms/integration/awskms" + "github.com/tink-crypto/tink-go/v2/aead" + "github.com/tink-crypto/tink-go/v2/core/registry" + "github.com/tink-crypto/tink-go/v2/keyderivation" + "github.com/tink-crypto/tink-go/v2/keyset" + "github.com/tink-crypto/tink-go/v2/prf" + "github.com/tink-crypto/tink-go/v2/testing/fakekms" + "github.com/tink-crypto/tink-go/v2/tink" + + "github.com/TBD54566975/ftl/backend/controller/encryption/api" + "github.com/TBD54566975/ftl/internal/mutex" +) + +// KMSEncryptor encrypts and decrypts using a KMS key via tink. +type KMSEncryptor struct { + root keyset.Handle + kekAEAD tink.AEAD + encryptedKeyset []byte + cachedDerived *mutex.Mutex[map[api.SubKey]tink.AEAD] +} + +var _ api.DataEncryptor = &KMSEncryptor{} + +func newClientWithAEAD(uri string, kms *kms.KMS) (tink.AEAD, error) { + var client registry.KMSClient + var err error + + if strings.HasPrefix(strings.ToLower(uri), "fake-kms://") { + client, err = fakekms.NewClient(uri) + if err != nil { + return nil, fmt.Errorf("failed to create fake KMS client: %w", err) + } + + } else { + // tink does not support awsv2 yet + // https://github.com/tink-crypto/tink-go-awskms/issues/2 + var opts []awskms.ClientOption + if kms != nil { + opts = append(opts, awskms.WithKMS(kms)) + } + client, err = awskms.NewClientWithOptions(uri, opts...) + if err != nil { + return nil, fmt.Errorf("failed to create KMS client: %w", err) + } + } + + kekAEAD, err := client.GetAEAD(uri) + if err != nil { + return nil, fmt.Errorf("failed to get aead: %w", err) + } + + return kekAEAD, nil +} + +func newKey(uri string, v1client *kms.KMS) ([]byte, error) { + kekAEAD, err := newClientWithAEAD(uri, v1client) + if err != nil { + return nil, fmt.Errorf("failed to create KMS client: %w", err) + } + + // Create a PRF key template using HKDF-SHA256 + prfKeyTemplate := prf.HKDFSHA256PRFKeyTemplate() + + // Create an AES-256-GCM key template + aeadKeyTemplate := aead.AES256GCMKeyTemplate() + + template, err := keyderivation.CreatePRFBasedKeyTemplate(prfKeyTemplate, aeadKeyTemplate) + if err != nil { + return nil, fmt.Errorf("failed to create PRF based key template: %w", err) + } + + handle, err := keyset.NewHandle(template) + if err != nil { + return nil, fmt.Errorf("failed to create keyset handle: %w", err) + } + + // Encrypt the keyset with the KEK AEAD. + buf := new(bytes.Buffer) + writer := keyset.NewBinaryWriter(buf) + err = handle.Write(writer, kekAEAD) + if err != nil { + return nil, fmt.Errorf("failed to encrypt DEK: %w", err) + } + return buf.Bytes(), nil +} + +func NewKMSEncryptorWithKMS(uri string, v1client *kms.KMS, encryptedKeyset []byte) (*KMSEncryptor, error) { + kekAEAD, err := newClientWithAEAD(uri, v1client) + if err != nil { + return nil, fmt.Errorf("failed to create KMS client: %w", err) + } + + reader := keyset.NewBinaryReader(bytes.NewReader(encryptedKeyset)) + handle, err := keyset.Read(reader, kekAEAD) + if err != nil { + return nil, fmt.Errorf("failed to read keyset: %w", err) + } + + return &KMSEncryptor{ + root: *handle, + kekAEAD: kekAEAD, + encryptedKeyset: encryptedKeyset, + cachedDerived: mutex.New(map[api.SubKey]tink.AEAD{}), + }, nil +} + +func (k *KMSEncryptor) GetEncryptedKeyset() []byte { + return k.encryptedKeyset +} + +func deriveKeyset(root keyset.Handle, salt []byte) (*keyset.Handle, error) { + deriver, err := keyderivation.New(&root) + if err != nil { + return nil, fmt.Errorf("failed to create deriver: %w", err) + } + + derived, err := deriver.DeriveKeyset(salt) + if err != nil { + return nil, fmt.Errorf("failed to derive keyset: %w", err) + } + + return derived, nil +} + +func (k *KMSEncryptor) getDerivedPrimitive(subKey api.SubKey) (tink.AEAD, error) { + cachedDerived := k.cachedDerived.Lock() + defer k.cachedDerived.Unlock() + primitive, ok := cachedDerived[subKey] + if ok { + return primitive, nil + } + + derived, err := deriveKeyset(k.root, []byte(subKey.SubKey())) + if err != nil { + return nil, fmt.Errorf("failed to derive keyset: %w", err) + } + + primitive, err = aead.New(derived) + if err != nil { + return nil, fmt.Errorf("failed to create primitive: %w", err) + } + + cachedDerived[subKey] = primitive + return primitive, nil +} + +func (k *KMSEncryptor) Encrypt(cleartext []byte, dest api.Encrypted) error { + primitive, err := k.getDerivedPrimitive(dest) + if err != nil { + return fmt.Errorf("%s: failed to get derived primitive: %w", dest.SubKey(), err) + } + + encrypted, err := primitive.Encrypt(cleartext, nil) + if err != nil { + return fmt.Errorf("%s: failed to encrypt: %w", dest.SubKey(), err) + } + + dest.Set(encrypted) + return nil +} + +func (k *KMSEncryptor) Decrypt(encrypted api.Encrypted) ([]byte, error) { + primitive, err := k.getDerivedPrimitive(encrypted) + if err != nil { + return nil, fmt.Errorf("%s: failed to get derived primitive: %w", encrypted.SubKey(), err) + } + + decrypted, err := primitive.Decrypt(encrypted.Bytes(), nil) + if err != nil { + return nil, fmt.Errorf("%s: failed to decrypt: %w", encrypted.SubKey(), err) + } + + return decrypted, nil +} diff --git a/backend/controller/encryption/builder.go b/backend/controller/encryption/builder.go new file mode 100644 index 000000000..131f3a301 --- /dev/null +++ b/backend/controller/encryption/builder.go @@ -0,0 +1,49 @@ +package encryption + +import ( + "context" + "fmt" + + "github.com/alecthomas/types/optional" + + "github.com/TBD54566975/ftl/backend/controller/encryption/api" +) + +// Builder constructs a DataEncryptor when used with a provider. +// Use a chain of With* methods to configure the builder. +type Builder struct { + kmsURI optional.Option[string] +} + +func NewBuilder() Builder { + return Builder{ + kmsURI: optional.None[string](), + } +} + +// WithKMSURI sets the URI for the KMS key to use. Omitting this call or using None will create a NoOpEncryptor. +func (b Builder) WithKMSURI(kmsURI optional.Option[string]) Builder { + b.kmsURI = kmsURI + return b +} + +func (b Builder) Build(ctx context.Context, provider api.KeyStoreProvider) (api.DataEncryptor, error) { + kmsURI, ok := b.kmsURI.Get() + if !ok { + return NewNoOpEncryptor(), nil + } + + key, err := provider.EnsureKey(ctx, func() ([]byte, error) { + return newKey(kmsURI, nil) + }) + if err != nil { + return nil, fmt.Errorf("failed to ensure key from provider: %w", err) + } + + encryptor, err := NewKMSEncryptorWithKMS(kmsURI, nil, key) + if err != nil { + return nil, fmt.Errorf("failed to create KMS encryptor: %w", err) + } + + return encryptor, nil +} diff --git a/backend/controller/encryption/dal/dal.go b/backend/controller/encryption/dal/dal.go index 3717fe3b0..02cdc946a 100644 --- a/backend/controller/encryption/dal/dal.go +++ b/backend/controller/encryption/dal/dal.go @@ -4,8 +4,6 @@ import ( "context" "fmt" - "github.com/alecthomas/types/optional" - "github.com/TBD54566975/ftl/backend/controller/encryption/api" "github.com/TBD54566975/ftl/backend/controller/encryption/dal/internal/sql" "github.com/TBD54566975/ftl/backend/libdal" @@ -60,85 +58,41 @@ func (d *DAL) EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) return row.Key, nil } -const verification = "FTL - Towards a 𝝺-calculus for large-scale systems" +// VerificationKeys contains the verification keys for the timeline and async encryption. +type VerificationKeys struct { + VerifyTimeline api.OptionalEncryptedTimelineColumn + VerifyAsync api.OptionalEncryptedAsyncColumn +} -func (d *DAL) VerifyEncryptor(ctx context.Context, encryptor api.DataEncryptor) (err error) { +func (d *DAL) GetVerificationKeys(ctx context.Context) (keys VerificationKeys, err error) { tx, err := d.Begin(ctx) if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) + return VerificationKeys{}, fmt.Errorf("failed to begin transaction: %w", err) } defer tx.CommitOrRollback(ctx, &err) row, err := tx.db.GetOnlyEncryptionKey(ctx) if err != nil { - if libdal.IsNotFound(err) { - // No encryption key found, probably using noop. - return nil - } - return fmt.Errorf("failed to get encryption row from the db: %w", err) + return VerificationKeys{}, fmt.Errorf("failed to get encryption key from the db: %w", err) } - needsUpdate := false - newTimeline, err := verifySubkey(encryptor, row.VerifyTimeline) - if err != nil { - return fmt.Errorf("failed to verify timeline subkey: %w", err) - } - if newTimeline.Ok() { - needsUpdate = true - row.VerifyTimeline = newTimeline - } + return VerificationKeys{ + VerifyTimeline: row.VerifyTimeline, + VerifyAsync: row.VerifyAsync, + }, nil +} - newAsync, err := verifySubkey(encryptor, row.VerifyAsync) +func (d *DAL) UpdateVerificationKeys(ctx context.Context, keys VerificationKeys) (err error) { + tx, err := d.Begin(ctx) if err != nil { - return fmt.Errorf("failed to verify async subkey: %w", err) - } - if newAsync.Ok() { - needsUpdate = true - row.VerifyAsync = newAsync - } - - if !needsUpdate { - return nil - } - - if !row.VerifyTimeline.Ok() || !row.VerifyAsync.Ok() { - panic("should be unreachable. verifySubkey should have set the subkey") + return fmt.Errorf("failed to begin transaction: %w", err) } + defer tx.CommitOrRollback(ctx, &err) - err = tx.db.UpdateEncryptionVerification(ctx, row.VerifyTimeline, row.VerifyAsync) + err = tx.db.UpdateEncryptionVerification(ctx, keys.VerifyTimeline, keys.VerifyAsync) if err != nil { return fmt.Errorf("failed to update encryption verification: %w", err) } return nil } - -// verifySubkey checks if the subkey is set and if not, sets it to a verification string. -// returns (nil, nil) if verified and not changed -func verifySubkey[SK api.SubKey]( - encryptor api.DataEncryptor, - encrypted optional.Option[api.EncryptedColumn[SK]], -) (optional.Option[api.EncryptedColumn[SK]], error) { - type EC = api.EncryptedColumn[SK] - - verifyField, ok := encrypted.Get() - if !ok { - err := encryptor.Encrypt([]byte(verification), &verifyField) - if err != nil { - return optional.None[EC](), fmt.Errorf("failed to encrypt verification sanity string: %w", err) - } - return optional.Some(verifyField), nil - } - - decrypted, err := encryptor.Decrypt(&verifyField) - if err != nil { - return optional.None[EC](), fmt.Errorf("failed to decrypt verification sanity string: %w", err) - } - - if string(decrypted) != verification { - return optional.None[EC](), fmt.Errorf("decrypted verification string does not match expected value") - } - - // verified, no need to update - return optional.None[EC](), nil -} diff --git a/backend/controller/encryption/api/encryption_test.go b/backend/controller/encryption/encryption_test.go similarity index 83% rename from backend/controller/encryption/api/encryption_test.go rename to backend/controller/encryption/encryption_test.go index 5b20f5751..efab0af50 100644 --- a/backend/controller/encryption/api/encryption_test.go +++ b/backend/controller/encryption/encryption_test.go @@ -1,15 +1,17 @@ -package api +package encryption import ( "testing" "github.com/alecthomas/assert/v2" + + "github.com/TBD54566975/ftl/backend/controller/encryption/api" ) func TestNoOpEncryptor(t *testing.T) { encryptor := NoOpEncryptor{} - var encrypted EncryptedTimelineColumn + var encrypted api.EncryptedTimelineColumn err := encryptor.Encrypt([]byte("hunter2"), &encrypted) assert.NoError(t, err) @@ -29,7 +31,7 @@ func TestKMSEncryptorFakeKMS(t *testing.T) { encryptor, err := NewKMSEncryptorWithKMS(uri, nil, key) assert.NoError(t, err) - var encrypted EncryptedTimelineColumn + var encrypted api.EncryptedTimelineColumn err = encryptor.Encrypt([]byte("hunter2"), &encrypted) assert.NoError(t, err) @@ -37,7 +39,7 @@ func TestKMSEncryptorFakeKMS(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "hunter2", string(decrypted)) - wrongSubKey := EncryptedAsyncColumn(encrypted) + wrongSubKey := api.EncryptedAsyncColumn(encrypted) // Should fail to decrypt with the wrong subkey _, err = encryptor.Decrypt(&wrongSubKey) assert.Error(t, err) diff --git a/backend/controller/encryption/api/integration_test.go b/backend/controller/encryption/integration_test.go similarity index 97% rename from backend/controller/encryption/api/integration_test.go rename to backend/controller/encryption/integration_test.go index 4911ea4da..155668413 100644 --- a/backend/controller/encryption/api/integration_test.go +++ b/backend/controller/encryption/integration_test.go @@ -1,6 +1,6 @@ //go:build integration -package api +package encryption import ( "context" @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/TBD54566975/ftl/backend/controller/encryption/api" pbconsole "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/console" in "github.com/TBD54566975/ftl/internal/integration" "github.com/TBD54566975/ftl/internal/log" @@ -152,7 +153,7 @@ func TestKMSEncryptorLocalstack(t *testing.T) { encryptor, err := NewKMSEncryptorWithKMS(uri, v1client, key) assert.NoError(t, err) - var encrypted EncryptedTimelineColumn + var encrypted api.EncryptedTimelineColumn err = encryptor.Encrypt([]byte("hunter2"), &encrypted) assert.NoError(t, err) @@ -161,7 +162,7 @@ func TestKMSEncryptorLocalstack(t *testing.T) { assert.Equal(t, "hunter2", string(decrypted)) // Should fail to decrypt with the wrong subkey - wrongSubKey := EncryptedAsyncColumn(encrypted) + wrongSubKey := api.EncryptedAsyncColumn(encrypted) _, err = encryptor.Decrypt(&wrongSubKey) assert.Error(t, err) } diff --git a/backend/controller/encryption/noop.go b/backend/controller/encryption/noop.go new file mode 100644 index 000000000..0c6f21d89 --- /dev/null +++ b/backend/controller/encryption/noop.go @@ -0,0 +1,23 @@ +package encryption + +import ( + "github.com/TBD54566975/ftl/backend/controller/encryption/api" +) + +// NoOpEncryptor does not encrypt and just passes the input as is. +type NoOpEncryptor struct{} + +func NewNoOpEncryptor() NoOpEncryptor { + return NoOpEncryptor{} +} + +var _ api.DataEncryptor = NoOpEncryptor{} + +func (n NoOpEncryptor) Encrypt(cleartext []byte, dest api.Encrypted) error { + dest.Set(cleartext) + return nil +} + +func (n NoOpEncryptor) Decrypt(encrypted api.Encrypted) ([]byte, error) { + return encrypted.Bytes(), nil +} diff --git a/backend/controller/encryption/service.go b/backend/controller/encryption/service.go index 96a5c6c70..896e3de39 100644 --- a/backend/controller/encryption/service.go +++ b/backend/controller/encryption/service.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" + "github.com/alecthomas/types/optional" + "github.com/TBD54566975/ftl/backend/controller/encryption/api" "github.com/TBD54566975/ftl/backend/controller/encryption/dal" "github.com/TBD54566975/ftl/backend/libdal" @@ -14,7 +16,7 @@ type Service struct { encryptor api.DataEncryptor } -func New(ctx context.Context, conn libdal.Connection, encryptionBuilder api.Builder) (*Service, error) { +func New(ctx context.Context, conn libdal.Connection, encryptionBuilder Builder) (*Service, error) { d := dal.New(ctx, conn) encryptor, err := encryptionBuilder.Build(ctx, d) @@ -22,7 +24,7 @@ func New(ctx context.Context, conn libdal.Connection, encryptionBuilder api.Buil return nil, fmt.Errorf("build encryptor: %w", err) } - if err := d.VerifyEncryptor(ctx, encryptor); err != nil { + if err := verifyEncryptor(ctx, d, encryptor); err != nil { return nil, fmt.Errorf("verify encryptor: %w", err) } @@ -70,3 +72,88 @@ func (s *Service) Decrypt(encrypted api.Encrypted) ([]byte, error) { return v, nil } + +const verification = "FTL - Towards a 𝝺-calculus for large-scale systems" + +func verifyEncryptor(ctx context.Context, d *dal.DAL, encryptor api.DataEncryptor) (err error) { + tx, err := d.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.CommitOrRollback(ctx, &err) + + keys, err := tx.GetVerificationKeys(ctx) + if err != nil { + if libdal.IsNotFound(err) { + if _, ok := encryptor.(NoOpEncryptor); ok { + return nil + } + return fmt.Errorf("no encryption key found in the db for encryptor %T: %w", encryptor, err) + } + return fmt.Errorf("failed to get encryption row from the db: %w", err) + } + + needsUpdate := false + newTimeline, err := verifySubkey(encryptor, keys.VerifyTimeline) + if err != nil { + return fmt.Errorf("failed to verify timeline subkey: %w", err) + } + if newTimeline.Ok() { + needsUpdate = true + keys.VerifyTimeline = newTimeline + } + + newAsync, err := verifySubkey(encryptor, keys.VerifyAsync) + if err != nil { + return fmt.Errorf("failed to verify async subkey: %w", err) + } + if newAsync.Ok() { + needsUpdate = true + keys.VerifyAsync = newAsync + } + + if !needsUpdate { + return nil + } + + if !keys.VerifyTimeline.Ok() || !keys.VerifyAsync.Ok() { + panic("should be unreachable. verifySubkey should have set the subkey") + } + + err = tx.UpdateVerificationKeys(ctx, keys) + if err != nil { + return fmt.Errorf("failed to update encryption verification: %w", err) + } + + return nil +} + +// verifySubkey checks if the subkey is set and if not, sets it to a verification string. +// returns (nil, nil) if verified and not changed +func verifySubkey[SK api.SubKey]( + encryptor api.DataEncryptor, + encrypted optional.Option[api.EncryptedColumn[SK]], +) (optional.Option[api.EncryptedColumn[SK]], error) { + type EC = api.EncryptedColumn[SK] + + verifyField, ok := encrypted.Get() + if !ok { + err := encryptor.Encrypt([]byte(verification), &verifyField) + if err != nil { + return optional.None[EC](), fmt.Errorf("failed to encrypt verification sanity string: %w", err) + } + return optional.Some(verifyField), nil + } + + decrypted, err := encryptor.Decrypt(&verifyField) + if err != nil { + return optional.None[EC](), fmt.Errorf("failed to decrypt verification sanity string: %w", err) + } + + if string(decrypted) != verification { + return optional.None[EC](), fmt.Errorf("decrypted verification string does not match expected value") + } + + // verified, no need to update + return optional.None[EC](), nil +} diff --git a/backend/controller/encryption/service_test.go b/backend/controller/encryption/service_test.go index b67ccc268..f5b9751e7 100644 --- a/backend/controller/encryption/service_test.go +++ b/backend/controller/encryption/service_test.go @@ -19,7 +19,7 @@ func TestEncryptionService(t *testing.T) { uri := "fake-kms://CK6YwYkBElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEJy4TIQgfCuwxA3ZZgChp_wYARABGK6YwYkBIAE" t.Run("EncryptDecryptJSON", func(t *testing.T) { - service, err := New(ctx, conn, api.NewBuilder().WithKMSURI(optional.Some(uri))) + service, err := New(ctx, conn, NewBuilder().WithKMSURI(optional.Some(uri))) assert.NoError(t, err) type TestStruct struct { @@ -40,7 +40,7 @@ func TestEncryptionService(t *testing.T) { }) t.Run("EncryptDecryptBinary", func(t *testing.T) { - service, err := New(ctx, conn, api.NewBuilder().WithKMSURI(optional.Some(uri))) + service, err := New(ctx, conn, NewBuilder().WithKMSURI(optional.Some(uri))) assert.NoError(t, err) original := []byte("Hello, World!") diff --git a/backend/controller/encryption/api/testdata/go/encryption/encryption.go b/backend/controller/encryption/testdata/go/encryption/encryption.go similarity index 100% rename from backend/controller/encryption/api/testdata/go/encryption/encryption.go rename to backend/controller/encryption/testdata/go/encryption/encryption.go diff --git a/backend/controller/encryption/api/testdata/go/encryption/ftl.toml b/backend/controller/encryption/testdata/go/encryption/ftl.toml similarity index 100% rename from backend/controller/encryption/api/testdata/go/encryption/ftl.toml rename to backend/controller/encryption/testdata/go/encryption/ftl.toml diff --git a/backend/controller/encryption/api/testdata/go/encryption/go.mod b/backend/controller/encryption/testdata/go/encryption/go.mod similarity index 100% rename from backend/controller/encryption/api/testdata/go/encryption/go.mod rename to backend/controller/encryption/testdata/go/encryption/go.mod diff --git a/backend/controller/encryption/api/testdata/go/encryption/go.sum b/backend/controller/encryption/testdata/go/encryption/go.sum similarity index 100% rename from backend/controller/encryption/api/testdata/go/encryption/go.sum rename to backend/controller/encryption/testdata/go/encryption/go.sum diff --git a/backend/controller/timeline/dal/dal_test.go b/backend/controller/timeline/dal/dal_test.go index 1f9ec7438..80fc2fb01 100644 --- a/backend/controller/timeline/dal/dal_test.go +++ b/backend/controller/timeline/dal/dal_test.go @@ -12,7 +12,6 @@ import ( controllerdal "github.com/TBD54566975/ftl/backend/controller/dal" "github.com/TBD54566975/ftl/backend/controller/encryption" - ftlencryption "github.com/TBD54566975/ftl/backend/controller/encryption/api" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" "github.com/TBD54566975/ftl/backend/schema" "github.com/TBD54566975/ftl/internal/log" @@ -23,7 +22,7 @@ import ( func TestTimelineDAL(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - encryption, err := encryption.New(ctx, conn, ftlencryption.NewBuilder()) + encryption, err := encryption.New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) dal := New(conn, encryption) @@ -157,7 +156,7 @@ func assertEventsEqual(t *testing.T, expected, actual []TimelineEvent) { func TestDeleteOldEvents(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - encryption, err := encryption.New(ctx, conn, ftlencryption.NewBuilder()) + encryption, err := encryption.New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) dal := New(conn, encryption)