diff --git a/internal/encryption/encryption.go b/internal/encryption/encryption.go index 3f56062d9..aafafc387 100644 --- a/internal/encryption/encryption.go +++ b/internal/encryption/encryption.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "strings" - "sync" "github.com/alecthomas/types/optional" awsv1kms "github.com/aws/aws-sdk-go/service/kms" @@ -17,6 +16,8 @@ import ( "github.com/tink-crypto/tink-go/v2/prf" "github.com/tink-crypto/tink-go/v2/testing/fakekms" "github.com/tink-crypto/tink-go/v2/tink" + + "github.com/TBD54566975/ftl/internal/mutex" ) // Encrypted is an interface for values that contain encrypted data. @@ -100,8 +101,7 @@ type KMSEncryptor struct { root keyset.Handle kekAEAD tink.AEAD encryptedKeyset []byte - cachedDerivedMu sync.RWMutex - cachedDerived map[SubKey]tink.AEAD + cachedDerived *mutex.Mutex[map[SubKey]tink.AEAD] } var _ DataEncryptor = &KMSEncryptor{} @@ -185,7 +185,7 @@ func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset root: *handle, kekAEAD: kekAEAD, encryptedKeyset: encryptedKeyset, - cachedDerived: make(map[SubKey]tink.AEAD), + cachedDerived: mutex.New(map[SubKey]tink.AEAD{}), }, nil } @@ -208,9 +208,9 @@ func deriveKeyset(root keyset.Handle, salt []byte) (*keyset.Handle, error) { } func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) { - k.cachedDerivedMu.RLock() - primitive, ok := k.cachedDerived[subKey] - k.cachedDerivedMu.RUnlock() + cachedDerived := k.cachedDerived.Lock() + defer k.cachedDerived.Unlock() + primitive, ok := cachedDerived[subKey] if ok { return primitive, nil } @@ -225,10 +225,7 @@ func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) { return nil, fmt.Errorf("failed to create primitive: %w", err) } - k.cachedDerivedMu.Lock() - k.cachedDerived[subKey] = primitive - k.cachedDerivedMu.Unlock() - + cachedDerived[subKey] = primitive return primitive, nil } diff --git a/internal/mutex/mutex.go b/internal/mutex/mutex.go new file mode 100644 index 000000000..2e411c0fb --- /dev/null +++ b/internal/mutex/mutex.go @@ -0,0 +1,33 @@ +package mutex + +import "sync" + +// Mutex is a simple mutex that can be used to protect a value. +// +// The zero value is safe to use if the zero value of T is safe to use. +// +// Example: +// +// var m mutex.Mutex[*string] +// s := m.Lock() +// defer m.Unlock() +// *s = "hello" +type Mutex[T any] struct { + m sync.Mutex + v T +} + +func New[T any](v T) *Mutex[T] { + return &Mutex[T]{v: v} +} + +// Lock the Mutex and return its protected value. +func (l *Mutex[T]) Lock() T { + l.m.Lock() + return l.v +} + +// Unlock the Mutex. The value returned by Lock is no longer valid. +func (l *Mutex[T]) Unlock() { + l.m.Unlock() +}