diff --git a/pkg/sm/sm.go b/pkg/sm/sm.go index bed8060bf..101b23c4d 100644 --- a/pkg/sm/sm.go +++ b/pkg/sm/sm.go @@ -24,6 +24,7 @@ import ( "net/http" "os" "os/signal" + "time" "github.com/Peripli/service-manager/api" "github.com/Peripli/service-manager/config" @@ -85,7 +86,7 @@ func New(ctx context.Context, cancel context.CancelFunc, env env.Environment) *S } securityStorage := smStorage.Security() - if err := initializeSecureStorage(securityStorage); err != nil { + if err := initializeSecureStorage(ctx, securityStorage); err != nil { panic(fmt.Sprintf("error initialzing secure storage: %v", err)) } @@ -154,7 +155,12 @@ func (smb *ServiceManagerBuilder) Build() *ServiceManager { } } -func initializeSecureStorage(secureStorage storage.Security) error { +func initializeSecureStorage(ctx context.Context, secureStorage storage.Security) error { + ctx, cancelFunc := context.WithTimeout(ctx, 2*time.Second) + defer cancelFunc() + if err := secureStorage.Lock(ctx); err != nil { + return err + } keyFetcher := secureStorage.Fetcher() encryptionKey, err := keyFetcher.GetEncryptionKey() if err != nil { @@ -172,7 +178,7 @@ func initializeSecureStorage(secureStorage storage.Security) error { } logrus.Debug("Successfully generated new encryption key") } - return nil + return secureStorage.Unlock() } func handleInterrupts(ctx context.Context, cancelFunc context.CancelFunc) { diff --git a/storage/interfaces.go b/storage/interfaces.go index a3f826a9c..dda99e5ba 100644 --- a/storage/interfaces.go +++ b/storage/interfaces.go @@ -18,6 +18,7 @@ package storage import ( + "context" "fmt" "github.com/Peripli/service-manager/pkg/types" @@ -107,9 +108,14 @@ type Credentials interface { } // Security interface for encryption key operations -type Security interface{ +type Security interface { + // Lock locks the storage so that only one process can manipulate the encryption key. + // Returns an error if the process has already acquired the lock + Lock(ctx context.Context) error + // Unlock releases the acquired lock. + Unlock() error // Fetcher provides means to obtain the encryption key Fetcher() security.KeyFetcher // Setter provides means to change the encryption key Setter() security.KeySetter -} \ No newline at end of file +} diff --git a/storage/postgres/security.go b/storage/postgres/security.go index 16d1bae72..43a64a8bb 100644 --- a/storage/postgres/security.go +++ b/storage/postgres/security.go @@ -17,7 +17,9 @@ package postgres import ( + "context" "fmt" + "sync" "time" "github.com/Peripli/service-manager/security" @@ -25,9 +27,43 @@ import ( "github.com/sirupsen/logrus" ) +const securityLockIndex = 111 + type securityStorage struct { db *sqlx.DB encryptionKey []byte + isLocked bool + mutex *sync.Mutex +} + +// Lock acquires a database lock so that only one process can manipulate the encryption key. +// Returns an error if the process has already acquired the lock +func (s *securityStorage) Lock(ctx context.Context) error { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.isLocked { + return fmt.Errorf("Lock is already acquired") + } + if _, err := s.db.ExecContext(ctx, "SELECT pg_advisory_lock($1)", securityLockIndex); err != nil { + return err + } + s.isLocked = true + return nil +} + +// Unlock releases the database lock. +func (s *securityStorage) Unlock() error { + s.mutex.Lock() + defer s.mutex.Unlock() + if !s.isLocked { + return nil + } + + if _, err := s.db.Exec("SELECT pg_advisory_unlock($1)", securityLockIndex); err != nil { + return err + } + s.isLocked = false + return nil } // Fetcher returns a KeyFetcher configured to fetch a key from the database diff --git a/storage/postgres/security_test.go b/storage/postgres/security_test.go index 04580b033..e4e87ec40 100644 --- a/storage/postgres/security_test.go +++ b/storage/postgres/security_test.go @@ -17,10 +17,12 @@ package postgres import ( + "context" "crypto/rand" "database/sql" "fmt" - "time" + "sync" + "time" "github.com/DATA-DOG/go-sqlmock" "github.com/Peripli/service-manager/security" @@ -159,4 +161,73 @@ var _ = Describe("Security", func() { }) }) }) + + Describe("Locking", func() { + var mockdb *sql.DB + var mock sqlmock.Sqlmock + var storage *securityStorage + envEncryptionKey := make([]byte, 32) + + JustBeforeEach(func() { + storage = &securityStorage{ + db: sqlx.NewDb(mockdb, "sqlmock"), + encryptionKey: envEncryptionKey, + isLocked: false, + mutex: &sync.Mutex{}, + } + }) + BeforeEach(func() { + mockdb, mock, _ = sqlmock.New() + rand.Read(envEncryptionKey) + }) + AfterEach(func() { + mockdb.Close() + }) + + Describe("Lock", func() { + + Context("When lock is already acquired", func() { + It("Should return an error", func() { + storage.isLocked = true + err := storage.Lock(context.TODO()) + Expect(err).ToNot(BeNil()) + }) + }) + + Context("When lock is not yet acquired", func() { + AfterEach(func() { + storage.Unlock() + }) + BeforeEach(func() { + mock.ExpectExec("SELECT").WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + }) + It("Should acquire lock", func() { + err := storage.Lock(context.TODO()) + Expect(err).To(BeNil()) + Expect(storage.isLocked).To(Equal(true)) + }) + }) + }) + + Describe("Unlock", func() { + Context("When lock is not acquired", func() { + It("Should return nil", func() { + storage.isLocked = false + err := storage.Unlock() + Expect(err).To(BeNil()) + }) + }) + Context("When lock is acquired", func() { + BeforeEach(func() { + mock.ExpectExec("SELECT").WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + }) + It("Should release lock", func() { + storage.isLocked = true + err := storage.Unlock() + Expect(err).To(BeNil()) + Expect(storage.isLocked).To(Equal(false)) + }) + }) + }) + }) }) diff --git a/storage/postgres/storage.go b/storage/postgres/storage.go index 8a1ea2179..2621ffd92 100644 --- a/storage/postgres/storage.go +++ b/storage/postgres/storage.go @@ -18,11 +18,10 @@ package postgres import ( + "fmt" "sync" "time" - "fmt" - "github.com/Peripli/service-manager/storage" "github.com/golang-migrate/migrate" migratepg "github.com/golang-migrate/migrate/database/postgres" @@ -39,9 +38,10 @@ func init() { } type postgresStorage struct { - db *sqlx.DB - state *storageState + db *sqlx.DB + state *storageState encryptionKey []byte + mutex *sync.Mutex } func (storage *postgresStorage) checkOpen() { @@ -72,9 +72,9 @@ func (storage *postgresStorage) Credentials() storage.Credentials { return &credentialStorage{storage.db} } -func (storage *postgresStorage) Security() storage.Security{ +func (storage *postgresStorage) Security() storage.Security { storage.checkOpen() - return &securityStorage{storage.db, storage.encryptionKey} + return &securityStorage{storage.db, storage.encryptionKey, false, storage.mutex} } func (storage *postgresStorage) Open(uri string, encryptionKey []byte) error { @@ -94,6 +94,7 @@ func (storage *postgresStorage) Open(uri string, encryptionKey []byte) error { db: storage.db, storageCheckInterval: time.Second * 5, } + storage.mutex = &sync.Mutex{} storage.encryptionKey = encryptionKey logrus.Debug("Updating database schema") if err := updateSchema(storage.db); err != nil {