Skip to content

Commit

Permalink
#22 when migration table name defined with table schema migrator can'…
Browse files Browse the repository at this point in the history
…t find that table after creation (#23)

* #22 when migration table name defined with table schema migrator can't find that table after creation

* #22 when migration table name defined with table schema migrator can't find that table after creation
  • Loading branch information
perfectio authored May 16, 2021
1 parent f70abd2 commit b8b1a50
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 22 deletions.
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ services:

postgres:
image: postgres:12-alpine
volumes:
- "./docker/volume/postgres/dump:/docker-entrypoint-initdb.d/"
env_file:
- .env

Expand Down
1 change: 1 addition & 0 deletions docker/volume/postgres/dump/init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE SCHEMA docker;
36 changes: 21 additions & 15 deletions migrator/db/postgresMigration/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ import (
"time"
)

const defaultSchema = "public"
const DefaultSchema = "public"

type (
Migration struct {
connection *sql.DB
tableName string
directory string
connection *sql.DB
tableName string
tableSchema string
directory string
}
)

func New(connection *sql.DB, tableName, directory string) *Migration {
func New(connection *sql.DB, tableName, tableSchema, directory string) *Migration {
return &Migration{
connection: connection,
tableName: tableName,
directory: directory,
connection: connection,
tableName: tableName,
tableSchema: tableSchema,
directory: directory,
}
}

Expand Down Expand Up @@ -85,7 +87,7 @@ func (s *Migration) GetMigrationHistory(limit int) (db.MigrationEntityList, erro
FROM %s
ORDER BY apply_time DESC, version DESC
LIMIT $1`,
s.tableName,
s.getTableNameWithSchema(),
)
result db.MigrationEntityList
)
Expand Down Expand Up @@ -126,22 +128,22 @@ func (s *Migration) AddMigrationHistory(version string) error {
q := fmt.Sprintf(`
INSERT INTO %s (version, apply_time)
VALUES ($1, $2)`,
s.tableName,
s.getTableNameWithSchema(),
)
_, err := s.connection.Exec(q, version, now)

return s.internalConvertError(err, q)
}

func (s *Migration) RemoveMigrationHistory(version string) error {
q := fmt.Sprintf(`DELETE FROM %s WHERE (version) = ($1)`, s.tableName)
q := fmt.Sprintf(`DELETE FROM %s WHERE (version) = ($1)`, s.getTableNameWithSchema())
_, err := s.connection.Exec(q, version)

return err
}

func (s *Migration) createMigrationHistoryTable() error {
log.Printf(console.Yellow("Creating migration history table %s..."), s.tableName)
log.Printf(console.Yellow("Creating migration history table %s..."), s.getTableNameWithSchema())

q := fmt.Sprintf(
`
Expand All @@ -150,14 +152,14 @@ func (s *Migration) createMigrationHistoryTable() error {
apply_time integer
)
`,
s.tableName,
s.getTableNameWithSchema(),
)

if _, err := s.connection.Exec(q); err != nil {
return s.internalConvertError(err, q)
}
if err := s.AddMigrationHistory(db.BaseMigration); err != nil {
q2 := fmt.Sprintf(`DROP TABLE %s`, s.tableName)
q2 := fmt.Sprintf(`DROP TABLE %s`, s.getTableNameWithSchema())
_, _ = s.connection.Exec(q2)

return err
Expand All @@ -181,7 +183,7 @@ func (s *Migration) getTableScheme() (exists bool, err error) {
rows *sql.Rows
)

rows, err = s.connection.Query(q, s.tableName, defaultSchema)
rows, err = s.connection.Query(q, s.tableName, s.tableSchema)
if err != nil {
return false, s.internalConvertError(err, q)
}
Expand All @@ -203,3 +205,7 @@ func (s *Migration) getTableScheme() (exists bool, err error) {

return false, nil
}

func (s *Migration) getTableNameWithSchema() string {
return s.tableSchema + "." + s.tableName
}
12 changes: 11 additions & 1 deletion migrator/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,19 @@ func (s *Service) initPostgres() error {
return err
}

var tableName, tableSchema string
if strings.Contains(s.options.TableName, ".") {
parts := strings.Split(s.options.TableName, ".")
tableSchema = parts[0]
tableName = parts[1]
} else {
tableSchema = postgresMigration.DefaultSchema
tableName = s.options.TableName
}

s.db = connection
s.migration = db.NewMigration(
postgresMigration.New(connection, s.options.TableName, s.options.Directory),
postgresMigration.New(connection, tableName, tableSchema, s.options.Directory),
connection,
db.MigrationOptions{
MaxSqlOutputLength: s.options.MaxSqlOutputLength,
Expand Down
66 changes: 60 additions & 6 deletions migrator/migrator_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func TestMigrateService_Postgres_UpDown_Successfully(t *testing.T) {
m, err := createPostgresMigrator()
m, err := createPostgresMigrator("migration")
assert.NoError(t, err)

err = m.Down("all")
Expand All @@ -24,13 +24,64 @@ func TestMigrateService_Postgres_UpDown_Successfully(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 0, c)

err = m.db.QueryRow("select count(*) from migration").Scan(&c)
err = m.db.QueryRow("select count(*) from public.migration").Scan(&c)
assert.NoError(t, err)
assert.Equal(t, 3, c)

err = m.Down("all")
assert.NoError(t, err)
}

func TestMigrateService_Postgres_Redo_Successfully(t *testing.T) {
m, err := createPostgresMigrator()
m, err := createPostgresMigrator("migration")
assert.NoError(t, err)

err = m.Down("all")
assert.NoError(t, err)

err = m.Up("2")
assert.NoError(t, err)

err = m.Redo("1")
assert.NoError(t, err)

var c int
err = m.db.QueryRow("select count(*) from public.migration").Scan(&c)
assert.NoError(t, err)
assert.Equal(t, 3, c)

err = m.Down("all")
assert.NoError(t, err)
}

func TestMigrateService_Postgres_UpDown_WithSchema_Successfully(t *testing.T) {
m, err := createPostgresMigrator("docker.migration")
assert.NoError(t, err)

err = m.Down("all")
assert.NoError(t, err)

err = m.Up("2")
assert.NoError(t, err)

err = m.Up("1")
assert.Error(t, err)

var c int
err = m.db.QueryRow("select count(*) from test").Scan(&c)
assert.NoError(t, err)
assert.Equal(t, 0, c)

err = m.db.QueryRow("select count(*) from docker.migration").Scan(&c)
assert.NoError(t, err)
assert.Equal(t, 3, c)

err = m.Down("all")
assert.NoError(t, err)
}

func TestMigrateService_Postgres_Redo_WithSchema_Successfully(t *testing.T) {
m, err := createPostgresMigrator("docker.migration")
assert.NoError(t, err)

err = m.Down("all")
Expand All @@ -43,16 +94,19 @@ func TestMigrateService_Postgres_Redo_Successfully(t *testing.T) {
assert.NoError(t, err)

var c int
err = m.db.QueryRow("select count(*) from migration").Scan(&c)
err = m.db.QueryRow("select count(*) from docker.migration").Scan(&c)
assert.NoError(t, err)
assert.Equal(t, 3, c)

err = m.Down("all")
assert.NoError(t, err)
}

func createPostgresMigrator() (*Service, error) {
func createPostgresMigrator(migrationTableName string) (*Service, error) {
return New(Options{
DSN: os.Getenv("POSTGRES_DSN"),
Directory: os.Getenv("POSTGRES_MIGRATIONS_PATH"),
TableName: "migration",
TableName: migrationTableName,
Compact: false,
Interactive: false,
})
Expand Down

0 comments on commit b8b1a50

Please sign in to comment.