Skip to content

Commit

Permalink
refactor: consolidate encryption logic into base encryption package (#…
Browse files Browse the repository at this point in the history
…2681)

Quite a bit of logic was in encryption/api, which now only contains
types.
  • Loading branch information
alecthomas authored Sep 16, 2024
1 parent c6f4754 commit c40d84c
Show file tree
Hide file tree
Showing 20 changed files with 388 additions and 327 deletions.
10 changes: 7 additions & 3 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ dev *args:
watchexec -r {{WATCHEXEC_ARGS}} -- "just build-sqlc && ftl dev {{args}}"

# Build everything
build-all: build-protos-unconditionally build-frontend build-generate build-sqlc build-zips lsp-generate build-java
@just build ftl ftl-controller ftl-runner ftl-initdb
build-all: build-protos-unconditionally build-backend build-backend-tests build-frontend build-generate build-sqlc build-zips lsp-generate build-java

# Run "go generate" on all packages
build-generate:
Expand All @@ -51,17 +50,22 @@ build-generate:
build +tools: build-protos build-zips build-frontend
#!/bin/bash
shopt -s extglob
set -x

if [ "${FTL_DEBUG:-}" = "true" ]; then
for tool in $@; do go build -o "{{RELEASE}}/$tool" -tags release -gcflags=all="-N -l" -ldflags "-X github.com/TBD54566975/ftl.Version={{VERSION}} -X github.com/TBD54566975/ftl.Timestamp={{TIMESTAMP}}" "./cmd/$tool"; done
else
for tool in $@; do mk "{{RELEASE}}/$tool" : !(build|integration) -- go build -o "{{RELEASE}}/$tool" -tags release -ldflags "-X github.com/TBD54566975/ftl.Version={{VERSION}} -X github.com/TBD54566975/ftl.Timestamp={{TIMESTAMP}}" "./cmd/$tool"; done
for tool in $@; do mk "{{RELEASE}}/$tool" : !(build|integration|node_modules|Procfile*|Dockerfile*) -- go build -o "{{RELEASE}}/$tool" -tags release -ldflags "-X github.com/TBD54566975/ftl.Version={{VERSION}} -X github.com/TBD54566975/ftl.Timestamp={{TIMESTAMP}}" "./cmd/$tool"; done
fi

# Build all backend binaries
build-backend:
just build ftl ftl-controller ftl-runner

# Build all backend tests
build-backend-tests:
go test -run ^NONE -tags integration ./...

build-java *args:
mvn -f jvm-runtime/ftl-runtime install {{args}}

Expand Down
3 changes: 1 addition & 2 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import (
"github.com/TBD54566975/ftl/backend/controller/cronjobs"
"github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/encryption/api"
"github.com/TBD54566975/ftl/backend/controller/ingress"
"github.com/TBD54566975/ftl/backend/controller/leases"
leasesdal "github.com/TBD54566975/ftl/backend/controller/leases/dal"
Expand Down Expand Up @@ -237,7 +236,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, devel bool) (*Service
config.ControllerTimeout = time.Second * 5
}

encryption, err := encryption.New(ctx, conn, api.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI)))
encryption, err := encryption.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI)))
if err != nil {
return nil, fmt.Errorf("failed to create encryption dal: %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/TBD54566975/ftl/backend/controller/cronjobs/dal"
parentdal "github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/encryption/api"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
Expand All @@ -37,7 +36,7 @@ func TestNewCronJobsForModule(t *testing.T) {
dal := dal.New(conn)

uri := "fake-kms://CK6YwYkBElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEJy4TIQgfCuwxA3ZZgChp_wYARABGK6YwYkBIAE"
encryption, err := encryption.New(ctx, conn, api.NewBuilder().WithKMSURI(optional.Some(uri)))
encryption, err := encryption.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Some(uri)))
assert.NoError(t, err)

parentDAL := parentdal.New(ctx, conn, encryption)
Expand Down
3 changes: 1 addition & 2 deletions backend/controller/dal/async_calls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/alecthomas/assert/v2"

"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/encryption/api"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
Expand All @@ -18,7 +17,7 @@ import (
func TestNoCallToAcquire(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
encryption, err := encryption.New(ctx, conn, api.NewBuilder())
encryption, err := encryption.New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

dal := New(ctx, conn, encryption)
Expand Down
5 changes: 2 additions & 3 deletions backend/controller/dal/dal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"golang.org/x/sync/errgroup"

"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/encryption/api"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/backend/libdal"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
Expand All @@ -26,7 +25,7 @@ import (
func TestDAL(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
encryption, err := encryption.New(ctx, conn, api.NewBuilder())
encryption, err := encryption.New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

dal := New(ctx, conn, encryption)
Expand Down Expand Up @@ -223,7 +222,7 @@ func TestDAL(t *testing.T) {
func TestCreateArtefactConflict(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
encryption, err := encryption.New(ctx, conn, api.NewBuilder())
encryption, err := encryption.New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

dal := New(ctx, conn, encryption)
Expand Down
3 changes: 1 addition & 2 deletions backend/controller/dal/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/alecthomas/types/either"

"github.com/TBD54566975/ftl/backend/controller/encryption"
"github.com/TBD54566975/ftl/backend/controller/encryption/api"
leasedal "github.com/TBD54566975/ftl/backend/controller/leases/dal"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/backend/libdal"
Expand All @@ -20,7 +19,7 @@ import (
func TestSendFSMEvent(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
encryption, err := encryption.New(ctx, conn, api.NewBuilder())
encryption, err := encryption.New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

dal := New(ctx, conn, encryption)
Expand Down
235 changes: 0 additions & 235 deletions backend/controller/encryption/api/encryption.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,7 @@
package api

import (
"bytes"
"context"
"fmt"
"strings"

"github.com/alecthomas/types/optional"
awsv1kms "github.com/aws/aws-sdk-go/service/kms"
"github.com/tink-crypto/tink-go-awskms/integration/awskms"
"github.com/tink-crypto/tink-go/v2/aead"
"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/keyderivation"
"github.com/tink-crypto/tink-go/v2/keyset"
"github.com/tink-crypto/tink-go/v2/prf"
"github.com/tink-crypto/tink-go/v2/testing/fakekms"
"github.com/tink-crypto/tink-go/v2/tink"

"github.com/TBD54566975/ftl/internal/mutex"
)

// Encrypted is an interface for values that contain encrypted data.
Expand All @@ -34,226 +18,7 @@ type KeyStoreProvider interface {
EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) ([]byte, error)
}

// Builder constructs a DataEncryptor when used with a provider.
// Use a chain of With* methods to configure the builder.
type Builder struct {
kmsURI optional.Option[string]
}

func NewBuilder() Builder {
return Builder{
kmsURI: optional.None[string](),
}
}

// WithKMSURI sets the URI for the KMS key to use. Omitting this call or using None will create a NoOpEncryptor.
func (b Builder) WithKMSURI(kmsURI optional.Option[string]) Builder {
b.kmsURI = kmsURI
return b
}

func (b Builder) Build(ctx context.Context, provider KeyStoreProvider) (DataEncryptor, error) {
kmsURI, ok := b.kmsURI.Get()
if !ok {
return NewNoOpEncryptor(), nil
}

key, err := provider.EnsureKey(ctx, func() ([]byte, error) {
return newKey(kmsURI, nil)
})
if err != nil {
return nil, fmt.Errorf("failed to ensure key from provider: %w", err)
}

encryptor, err := NewKMSEncryptorWithKMS(kmsURI, nil, key)
if err != nil {
return nil, fmt.Errorf("failed to create KMS encryptor: %w", err)
}

return encryptor, nil
}

type DataEncryptor interface {
Encrypt(cleartext []byte, dest Encrypted) error
Decrypt(encrypted Encrypted) ([]byte, error)
}

// NoOpEncryptor does not encrypt and just passes the input as is.
type NoOpEncryptor struct{}

func NewNoOpEncryptor() NoOpEncryptor {
return NoOpEncryptor{}
}

var _ DataEncryptor = NoOpEncryptor{}

func (n NoOpEncryptor) Encrypt(cleartext []byte, dest Encrypted) error {
dest.Set(cleartext)
return nil
}

func (n NoOpEncryptor) Decrypt(encrypted Encrypted) ([]byte, error) {
return encrypted.Bytes(), nil
}

// KMSEncryptor encrypts and decrypts using a KMS key via tink.
type KMSEncryptor struct {
root keyset.Handle
kekAEAD tink.AEAD
encryptedKeyset []byte
cachedDerived *mutex.Mutex[map[SubKey]tink.AEAD]
}

var _ DataEncryptor = &KMSEncryptor{}

func newClientWithAEAD(uri string, kms *awsv1kms.KMS) (tink.AEAD, error) {
var client registry.KMSClient
var err error

if strings.HasPrefix(strings.ToLower(uri), "fake-kms://") {
client, err = fakekms.NewClient(uri)
if err != nil {
return nil, fmt.Errorf("failed to create fake KMS client: %w", err)
}

} else {
// tink does not support awsv2 yet
// https://github.com/tink-crypto/tink-go-awskms/issues/2
var opts []awskms.ClientOption
if kms != nil {
opts = append(opts, awskms.WithKMS(kms))
}
client, err = awskms.NewClientWithOptions(uri, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create KMS client: %w", err)
}
}

kekAEAD, err := client.GetAEAD(uri)
if err != nil {
return nil, fmt.Errorf("failed to get aead: %w", err)
}

return kekAEAD, nil
}

func newKey(uri string, v1client *awsv1kms.KMS) ([]byte, error) {
kekAEAD, err := newClientWithAEAD(uri, v1client)
if err != nil {
return nil, fmt.Errorf("failed to create KMS client: %w", err)
}

// Create a PRF key template using HKDF-SHA256
prfKeyTemplate := prf.HKDFSHA256PRFKeyTemplate()

// Create an AES-256-GCM key template
aeadKeyTemplate := aead.AES256GCMKeyTemplate()

template, err := keyderivation.CreatePRFBasedKeyTemplate(prfKeyTemplate, aeadKeyTemplate)
if err != nil {
return nil, fmt.Errorf("failed to create PRF based key template: %w", err)
}

handle, err := keyset.NewHandle(template)
if err != nil {
return nil, fmt.Errorf("failed to create keyset handle: %w", err)
}

// Encrypt the keyset with the KEK AEAD.
buf := new(bytes.Buffer)
writer := keyset.NewBinaryWriter(buf)
err = handle.Write(writer, kekAEAD)
if err != nil {
return nil, fmt.Errorf("failed to encrypt DEK: %w", err)
}
return buf.Bytes(), nil
}

func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset []byte) (*KMSEncryptor, error) {
kekAEAD, err := newClientWithAEAD(uri, v1client)
if err != nil {
return nil, fmt.Errorf("failed to create KMS client: %w", err)
}

reader := keyset.NewBinaryReader(bytes.NewReader(encryptedKeyset))
handle, err := keyset.Read(reader, kekAEAD)
if err != nil {
return nil, fmt.Errorf("failed to read keyset: %w", err)
}

return &KMSEncryptor{
root: *handle,
kekAEAD: kekAEAD,
encryptedKeyset: encryptedKeyset,
cachedDerived: mutex.New(map[SubKey]tink.AEAD{}),
}, nil
}

func (k *KMSEncryptor) GetEncryptedKeyset() []byte {
return k.encryptedKeyset
}

func deriveKeyset(root keyset.Handle, salt []byte) (*keyset.Handle, error) {
deriver, err := keyderivation.New(&root)
if err != nil {
return nil, fmt.Errorf("failed to create deriver: %w", err)
}

derived, err := deriver.DeriveKeyset(salt)
if err != nil {
return nil, fmt.Errorf("failed to derive keyset: %w", err)
}

return derived, nil
}

func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) {
cachedDerived := k.cachedDerived.Lock()
defer k.cachedDerived.Unlock()
primitive, ok := cachedDerived[subKey]
if ok {
return primitive, nil
}

derived, err := deriveKeyset(k.root, []byte(subKey.SubKey()))
if err != nil {
return nil, fmt.Errorf("failed to derive keyset: %w", err)
}

primitive, err = aead.New(derived)
if err != nil {
return nil, fmt.Errorf("failed to create primitive: %w", err)
}

cachedDerived[subKey] = primitive
return primitive, nil
}

func (k *KMSEncryptor) Encrypt(cleartext []byte, dest Encrypted) error {
primitive, err := k.getDerivedPrimitive(dest)
if err != nil {
return fmt.Errorf("%s: failed to get derived primitive: %w", dest.SubKey(), err)
}

encrypted, err := primitive.Encrypt(cleartext, nil)
if err != nil {
return fmt.Errorf("%s: failed to encrypt: %w", dest.SubKey(), err)
}

dest.Set(encrypted)
return nil
}

func (k *KMSEncryptor) Decrypt(encrypted Encrypted) ([]byte, error) {
primitive, err := k.getDerivedPrimitive(encrypted)
if err != nil {
return nil, fmt.Errorf("%s: failed to get derived primitive: %w", encrypted.SubKey(), err)
}

decrypted, err := primitive.Decrypt(encrypted.Bytes(), nil)
if err != nil {
return nil, fmt.Errorf("%s: failed to decrypt: %w", encrypted.SubKey(), err)
}

return decrypted, nil
}
Loading

0 comments on commit c40d84c

Please sign in to comment.