From f85c04b1895fee21827b85ba31951f925f642d08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Budnik?= Date: Thu, 15 Jul 2021 14:54:49 +0200 Subject: [PATCH] s3 baseLocation now supports optional prefixes for nested buckets --- .gitignore | 1 + README.md | 4 +-- config/config_test.go | 3 ++- loader/disk_loader.go | 3 +-- loader/s3_loader.go | 41 +++++++++++++++++----------- loader/s3_loader_test.go | 52 ++++++++++++++++++++++++++++++++++++ test/migrator-test-envs.yaml | 2 +- 7 files changed, 85 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 435b506..7677bbb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ test/migrator.yaml coverage-*.txt coverage.txt debug.test +.vscode diff --git a/README.md b/README.md index 9123e6d..30105c1 100644 --- a/README.md +++ b/README.md @@ -548,11 +548,11 @@ baseLocation: /project/migrations ### AWS S3 -If `baseLocation` starts with `s3://` prefix, AWS S3 implementation is used. In such case the `baseLocation` property is treated as a bucket name: +If `baseLocation` starts with `s3://` prefix, AWS S3 implementation is used. In such case the `baseLocation` property is treated as a bucket name followed by an optional prefix: ``` # S3 bucket -baseLocation: s3://your-bucket-migrator +baseLocation: s3://your-bucket-migrator/application-x/prod ``` migrator uses official AWS SDK for Go and uses a well known [default credential provider chain](https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html). Please setup your env variables accordingly. 
diff --git a/config/config_test.go b/config/config_test.go index d67b7fc..66ba45a 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -26,9 +26,10 @@ func TestFromFile(t *testing.T) { } func TestWithEnvFromFile(t *testing.T) { + os.Setenv("COMMIT_SHA", "62fd74506651982fe317721d7e07145f8c2fa166") config, err := FromFile("../test/migrator-test-envs.yaml") assert.Nil(t, err) - assert.Equal(t, os.Getenv("TERM"), config.BaseLocation) + assert.Equal(t, "s3://bucket-name/application-x/"+os.Getenv("TERM")+"/"+os.Getenv("COMMIT_SHA"), config.BaseLocation) assert.Equal(t, os.Getenv("PATH"), config.TenantSelectSQL) assert.Equal(t, os.Getenv("GOPATH"), config.TenantInsertSQL) assert.Equal(t, os.Getenv("PWD"), config.Driver) diff --git a/loader/disk_loader.go b/loader/disk_loader.go index ebc8478..1c42f78 100644 --- a/loader/disk_loader.go +++ b/loader/disk_loader.go @@ -3,12 +3,11 @@ package loader import ( "crypto/sha256" "encoding/hex" + "fmt" "io/ioutil" "path/filepath" "strings" - "fmt" - "github.com/lukaszbudnik/migrator/types" ) diff --git a/loader/s3_loader.go b/loader/s3_loader.go index 3bc5c0a..44ff685 100644 --- a/loader/s3_loader.go +++ b/loader/s3_loader.go @@ -32,37 +32,50 @@ func (s3l *s3Loader) GetSourceMigrations() []types.Migration { func (s3l *s3Loader) doGetSourceMigrations(client s3iface.S3API) []types.Migration { migrations := []types.Migration{} - singleMigrationsObjects := s3l.getObjectList(client, s3l.config.SingleMigrations) - tenantMigrationsObjects := s3l.getObjectList(client, s3l.config.TenantMigrations) - singleScriptsObjects := s3l.getObjectList(client, s3l.config.SingleScripts) - tenantScriptsObjects := s3l.getObjectList(client, s3l.config.TenantScripts) + bucketWithPrefixes := strings.Split(strings.Replace(strings.TrimRight(s3l.config.BaseLocation, "/"), "s3://", "", 1), "/") + + bucket := bucketWithPrefixes[0] + optionalPrefixes := "" + if len(bucketWithPrefixes) > 1 { + optionalPrefixes = strings.Join(bucketWithPrefixes[1:], "/") 
+ } + + singleMigrationsObjects := s3l.getObjectList(client, bucket, optionalPrefixes, s3l.config.SingleMigrations) + tenantMigrationsObjects := s3l.getObjectList(client, bucket, optionalPrefixes, s3l.config.TenantMigrations) + singleScriptsObjects := s3l.getObjectList(client, bucket, optionalPrefixes, s3l.config.SingleScripts) + tenantScriptsObjects := s3l.getObjectList(client, bucket, optionalPrefixes, s3l.config.TenantScripts) migrationsMap := make(map[string][]types.Migration) - s3l.getObjects(client, migrationsMap, singleMigrationsObjects, types.MigrationTypeSingleMigration) - s3l.getObjects(client, migrationsMap, tenantMigrationsObjects, types.MigrationTypeTenantMigration) + s3l.getObjects(client, bucket, migrationsMap, singleMigrationsObjects, types.MigrationTypeSingleMigration) + s3l.getObjects(client, bucket, migrationsMap, tenantMigrationsObjects, types.MigrationTypeTenantMigration) s3l.sortMigrations(migrationsMap, &migrations) migrationsMap = make(map[string][]types.Migration) - s3l.getObjects(client, migrationsMap, singleScriptsObjects, types.MigrationTypeSingleScript) + s3l.getObjects(client, bucket, migrationsMap, singleScriptsObjects, types.MigrationTypeSingleScript) s3l.sortMigrations(migrationsMap, &migrations) migrationsMap = make(map[string][]types.Migration) - s3l.getObjects(client, migrationsMap, tenantScriptsObjects, types.MigrationTypeTenantScript) + s3l.getObjects(client, bucket, migrationsMap, tenantScriptsObjects, types.MigrationTypeTenantScript) s3l.sortMigrations(migrationsMap, &migrations) return migrations } -func (s3l *s3Loader) getObjectList(client s3iface.S3API, prefixes []string) []*string { +func (s3l *s3Loader) getObjectList(client s3iface.S3API, bucket, optionalPrefixes string, prefixes []string) []*string { objects := []*string{} - bucket := strings.Replace(s3l.config.BaseLocation, "s3://", "", 1) - for _, prefix := range prefixes { + var fullPrefix string + if optionalPrefixes != "" { + fullPrefix = optionalPrefixes + "/" + 
prefix + } else { + fullPrefix = prefix + } + input := &s3.ListObjectsV2Input{ Bucket: aws.String(bucket), - Prefix: aws.String(prefix), + Prefix: aws.String(fullPrefix), MaxKeys: aws.Int64(1000), } @@ -84,9 +97,7 @@ func (s3l *s3Loader) getObjectList(client s3iface.S3API, prefixes []string) []*s return objects } -func (s3l *s3Loader) getObjects(client s3iface.S3API, migrationsMap map[string][]types.Migration, objects []*string, migrationType types.MigrationType) { - bucket := strings.Replace(s3l.config.BaseLocation, "s3://", "", 1) - +func (s3l *s3Loader) getObjects(client s3iface.S3API, bucket string, migrationsMap map[string][]types.Migration, objects []*string, migrationType types.MigrationType) { objectInput := &s3.GetObjectInput{Bucket: aws.String(bucket)} for _, o := range objects { objectInput.Key = o diff --git a/loader/s3_loader_test.go b/loader/s3_loader_test.go index 518dd7c..b57fce5 100644 --- a/loader/s3_loader_test.go +++ b/loader/s3_loader_test.go @@ -45,6 +45,28 @@ func (m *mockS3Client) ListObjectsV2Pages(input *s3.ListObjectsV2Input, callback file2 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "cleanup.sql"))} file3 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "run-reports.sql"))} contents = []*s3.Object{file1, file2, file3} + case "application-x/prod/migrations/config": + file1 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "201602160001.sql"))} + file2 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "201602160002.sql"))} + contents = []*s3.Object{file1, file2} + case "application-x/prod/migrations/ref": + file1 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "202001100003.sql"))} + file2 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "202001100005.sql"))} + contents = []*s3.Object{file1, file2} + case "application-x/prod/migrations/tenants": + file1 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, 
"201602160002.sql"))} + file2 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "202001100004.sql"))} + file3 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "202001100007.sql"))} + contents = []*s3.Object{file1, file2, file3} + case "application-x/prod/migrations/config-scripts": + file1 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "recreate-triggers.sql"))} + file2 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "cleanup.sql"))} + contents = []*s3.Object{file1, file2} + case "application-x/prod/migrations/tenants-scripts": + file1 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "recreate-triggers.sql"))} + file2 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "cleanup.sql"))} + file3 := &s3.Object{Key: aws.String(fmt.Sprintf("%v/%v", *input.Prefix, "run-reports.sql"))} + contents = []*s3.Object{file1, file2, file3} } callback(&s3.ListObjectsV2Output{ @@ -88,3 +110,33 @@ func TestS3GetSourceMigrations(t *testing.T) { assert.Contains(t, migrations[11].File, "migrations/tenants-scripts/run-reports.sql") } + +func TestS3GetSourceMigrationsBucketWithPrefix(t *testing.T) { + mock := &mockS3Client{} + + config := &config.Config{ + BaseLocation: "s3://your-bucket-migrator/application-x/prod/", + SingleMigrations: []string{"migrations/config", "migrations/ref"}, + TenantMigrations: []string{"migrations/tenants"}, + SingleScripts: []string{"migrations/config-scripts"}, + TenantScripts: []string{"migrations/tenants-scripts"}, + } + + loader := &s3Loader{baseLoader{context.TODO(), config}} + migrations := loader.doGetSourceMigrations(mock) + + assert.Len(t, migrations, 12) + + assert.Contains(t, migrations[0].File, "application-x/prod/migrations/config/201602160001.sql") + assert.Contains(t, migrations[1].File, "application-x/prod/migrations/config/201602160002.sql") + assert.Contains(t, migrations[2].File, "application-x/prod/migrations/tenants/201602160002.sql") + 
assert.Contains(t, migrations[3].File, "application-x/prod/migrations/ref/202001100003.sql") + assert.Contains(t, migrations[4].File, "application-x/prod/migrations/tenants/202001100004.sql") + assert.Contains(t, migrations[5].File, "application-x/prod/migrations/ref/202001100005.sql") + assert.Contains(t, migrations[6].File, "application-x/prod/migrations/tenants/202001100007.sql") + assert.Contains(t, migrations[7].File, "application-x/prod/migrations/config-scripts/cleanup.sql") + assert.Contains(t, migrations[8].File, "application-x/prod/migrations/config-scripts/recreate-triggers.sql") + assert.Contains(t, migrations[9].File, "application-x/prod/migrations/tenants-scripts/cleanup.sql") + assert.Contains(t, migrations[10].File, "application-x/prod/migrations/tenants-scripts/recreate-triggers.sql") + assert.Contains(t, migrations[11].File, "application-x/prod/migrations/tenants-scripts/run-reports.sql") +} diff --git a/test/migrator-test-envs.yaml b/test/migrator-test-envs.yaml index 98d8242..1707df8 100644 --- a/test/migrator-test-envs.yaml +++ b/test/migrator-test-envs.yaml @@ -1,5 +1,5 @@ # migrator configuration -baseLocation: ${TERM} +baseLocation: s3://bucket-name/application-x/${TERM}/${COMMIT_SHA} driver: ${PWD} dataSource: "lets_assume_password=${HOME}&and_something_else=${USER}&param=value" # override only if you have own specific way of determining tenants