Skip to content

Commit

Permalink
validate encryption mode change on initalization and download
Browse files Browse the repository at this point in the history
  • Loading branch information
ostempel committed Oct 14, 2024
1 parent 7a6cd79 commit 9533151
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 30 deletions.
2 changes: 1 addition & 1 deletion cmd/internal/backup/providers/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type BackupProvider interface {
ListBackups(ctx context.Context) (BackupVersions, error)
CleanupBackups(ctx context.Context) error
GetNextBackupName(ctx context.Context) string
DownloadBackup(ctx context.Context, version *BackupVersion, outPath string) (string, error)
DownloadBackup(ctx context.Context, version *BackupVersion, outDir string) (string, error)
UploadBackup(ctx context.Context, sourcePath string) error
}

Expand Down
22 changes: 10 additions & 12 deletions cmd/internal/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/spf13/afero"
)

// Suffix is appended on encryption and removed on decryption from given input
// suffix is appended on encryption and removed on decryption from given input
const suffix = ".aes"

// Encrypter is used to encrypt/decrypt backups
Expand Down Expand Up @@ -96,8 +96,8 @@ func (e *Encrypter) Decrypt(inputPath string) (string, error) {
output := strings.TrimSuffix(inputPath, suffix)
e.log.Debug("decrypt", "input", inputPath, "output", output)

if err := e.validateInput(inputPath); err != nil {
return "", err
if IsEncrypted(inputPath) {
return "", fmt.Errorf("input is not encrypted")
}

infile, err := e.fs.Open(inputPath)
Expand Down Expand Up @@ -177,23 +177,21 @@ func (e *Encrypter) encryptFile(infile, outfile afero.File, block cipher.Block,
break
}
if err != nil {
e.log.Info("Read %d bytes: %v", strconv.Itoa(n), err)
e.log.Info("read %d bytes: %s", strconv.Itoa(n), err)
break
}
}

if _, err := outfile.Write(iv); err != nil {
return err
return fmt.Errorf("could not append iv: %w", err)
}

return nil
}

// validateInput() throws error if input file doesn't have encryption suffix
func (e *Encrypter) validateInput(input string) error {
if filepath.Ext(input) != suffix {
return fmt.Errorf("input is not encrypted")
}
return nil
// IsEncrypted() tests if target file is encrypted
func IsEncrypted(path string) bool {
return filepath.Ext(path) == suffix
}

// readIVAndMessageLength() returns initialization vector and message length for decryption
Expand Down Expand Up @@ -234,7 +232,7 @@ func (e *Encrypter) decryptFile(infile, outfile afero.File, block cipher.Block,
break
}
if err != nil {
e.log.Info("Read %d bytes: %v", strconv.Itoa(n), err)
e.log.Info("read %d bytes: %s", strconv.Itoa(n), err)
break
}
}
Expand Down
35 changes: 22 additions & 13 deletions cmd/internal/initializer/initializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,29 +149,35 @@ func (i *Initializer) initialize(ctx context.Context) error {
i.currentStatus.Status = v1.StatusResponse_CHECKING
i.currentStatus.Message = "checking database"

needsBackup, err := i.db.Check(ctx)
versions, err := i.bp.ListBackups(ctx)
if err != nil {
return fmt.Errorf("unable to check data of database: %w", err)
return fmt.Errorf("unable retrieve backup versions: %w", err)
}

if !needsBackup {
i.log.Info("database does not need to be restored")
latestBackup := versions.Latest()
if latestBackup == nil {
i.log.Info("there are no backups available, it's a fresh database. allow database to start")
return nil
}

i.log.Info("database potentially needs to be restored, looking for backup")
if i.encrypter == nil {
if encryption.IsEncrypted(latestBackup.Name) {
return fmt.Errorf("latest backup is encrypted, but no encryption/decryption is configured")
}
}

versions, err := i.bp.ListBackups(ctx)
needsBackup, err := i.db.Check(ctx)
if err != nil {
return fmt.Errorf("unable retrieve backup versions: %w", err)
return fmt.Errorf("unable to check data of database: %w", err)
}

latestBackup := versions.Latest()
if latestBackup == nil {
i.log.Info("there are no backups available, it's a fresh database. allow database to start")
if !needsBackup {
i.log.Info("database does not need to be restored")
return nil
}

i.log.Info("database potentially needs to be restored, looking for backup")

err = i.Restore(ctx, latestBackup)
if err != nil {
return fmt.Errorf("unable to restore database: %w", err)
Expand Down Expand Up @@ -212,10 +218,13 @@ func (i *Initializer) Restore(ctx context.Context, version *providers.BackupVers
}

if i.encrypter != nil {
backupFilePath, err = i.encrypter.Decrypt(backupFilePath)
if err != nil {
return fmt.Errorf("unable to decrypt backup: %w", err)
if encryption.IsEncrypted(backupFilePath) {
backupFilePath, err = i.encrypter.Decrypt(backupFilePath)
if err != nil {
return fmt.Errorf("unable to decrypt backup: %w", err)
}
}
i.log.Info("restoring unencrypted backup with configured encryption - skipping decryption...")
}

i.currentStatus.Message = "uncompressing backup"
Expand Down
10 changes: 6 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,14 @@ var downloadBackupCmd = &cobra.Command{
}

if encrypter != nil {
_, err = encrypter.Decrypt(destination)
if err != nil {
return fmt.Errorf("failed to decrypt: %w", err)
if encryption.IsEncrypted(destination) {
_, err = encrypter.Decrypt(destination)
if err != nil {
return fmt.Errorf("unable to decrypt backup: %w", err)
}
}
logger.Info("downloading unencrypted backup with configured encryption - skipping decryption...")
}

return nil
},
}
Expand Down
2 changes: 2 additions & 0 deletions integration/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ func restoreFlow(t *testing.T, spec *flowSpec) {
require.NoError(t, err)
require.NotNil(t, backup)

require.True(t, strings.HasSuffix(backup.Name, ".aes"))

Check failure on line 146 in integration/main_test.go

View workflow job for this annotation

GitHub Actions / Integration Test

avoid direct access to proto field backup.Name, use backup.GetName() instead (protogetter)

t.Log("remove sts and delete data volume")

err = c.Delete(ctx, spec.sts(ns.Name))
Expand Down

0 comments on commit 9533151

Please sign in to comment.