Skip to content

Commit

Permalink
gormschema: supports trigger (#50)
Browse files Browse the repository at this point in the history
* refactor: re-order code blocks: public first

* feat: allow defining Triggers() for a model

* fix: use Triggers return type as an array of Options

* tests: update test cases for trigger

* docs: update usage for trigger

* refactor: change creating trigger API

* chore: re-hash atlas.sum

* fix: remove unicode character

* fix: correct sqlite migration files

* chore: add comments

* docs: change trigger usage
  • Loading branch information
luantranminh authored Jul 11, 2024
1 parent cc90e50 commit 66b75de
Show file tree
Hide file tree
Showing 18 changed files with 487 additions and 80 deletions.
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ env "gorm" {
> Note: Views are available for logged-in users, run `atlas login` if you haven't already. To learn more about logged-in features for Atlas, visit [Feature Availability](https://atlasgo.io/features#database-features).
To define a Go struct as a database `VIEW`, implement the `ViewDef` method as follow:

```go
// User is a regular gorm.Model stored in the "users" table.
type User struct {
Expand All @@ -158,7 +159,9 @@ func (WorkingAgedUsers) ViewDef(dialect string) []gormschema.ViewOption {
}
}
```

In order to pass a plain `CREATE VIEW` statement, use the `CreateStmt` as follows:

```go
type BotlTracker struct {
ID uint
Expand All @@ -176,19 +179,48 @@ func (BotlTracker) ViewDef(dialect string) []gormschema.ViewOption {
}
}
```

To include both VIEWs and TABLEs in the migration generation, pass all models to the `Load` function:

```go
stmts, err := gormschema.New("mysql").Load(
&models.User{}, // Table-based model.
&models.WorkingAgedUsers{}, // View-based model.
)
```

The view-based model works just like a regular models in GORM queries. However, make sure the view name is identical to the struct name, and in case they are differ, configure the name using the `TableName` method:

```go
func (WorkingAgedUsers) TableName() string {
return "working_aged_users_custom_name" // View name is different than pluralized struct name.
}
```

#### Trigger

> Note: Trigger feature is only available for logged-in users, run `atlas login` if you haven't already. To learn more about logged-in features for Atlas, visit [Feature Availability](https://atlasgo.io/features#database-features).
To attach triggers to a table, use the `Triggers` method as follows:

```go
type Pet struct {
gorm.Model
Name string
}

func (Pet) Triggers(dialect string) []gormschema.Trigger {
var stmt string
switch dialect {
case "mysql":
stmt = "CREATE TRIGGER pet_insert BEFORE INSERT ON pets FOR EACH ROW SET NEW.name = UPPER(NEW.name)"
}
return []gormschema.Trigger{
gormschema.NewTrigger(gormschema.CreateStmt(stmt)),
}
}
```

### Additional Configuration

To supply custom `gorm.Config{}` object to the provider use the [Go Program Mode](#as-go-file) with
Expand Down
183 changes: 116 additions & 67 deletions gormschema/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@ import (
gormig "gorm.io/gorm/migrator"
)

// New returns a new Loader.
func New(dialect string, opts ...Option) *Loader {
l := &Loader{dialect: dialect, config: &gorm.Config{}}
for _, opt := range opts {
opt(l)
}
return l
}

type (
// Loader is a Loader for gorm schema.
Loader struct {
Expand All @@ -35,6 +26,33 @@ type (
}
// Option configures the Loader.
Option func(*Loader)
// ViewOption implemented by VIEW's related options
ViewOption interface {
isViewOption()
apply(*schemaBuilder)
}
// TriggerOption implemented by TRIGGER's related options
TriggerOption interface {
isTriggerOption()
apply(*schemaBuilder)
}
// Trigger defines a trigger.
Trigger struct {
opts []TriggerOption
}
// ViewDefiner defines a view.
ViewDefiner interface {
ViewDef(dialect string) []ViewOption
}
// schemaOption configures the schemaBuilder.
schemaOption func(*schemaBuilder)
schemaBuilder struct {
db *gorm.DB
createStmt string
// viewName is only used for the BuildStmt option.
// BuildStmt returns only a subquery; viewName helps to create a full CREATE VIEW statement.
viewName string
}
)

// WithConfig sets the gorm config.
Expand All @@ -44,6 +62,60 @@ func WithConfig(cfg *gorm.Config) Option {
}
}

// WithJoinTable sets up a join table for the given model and field.
// Deprecated: put the join tables alongside the models in the Load call.
func WithJoinTable(model any, field string, jointable any) Option {
return func(l *Loader) {
l.beforeAutoMigrate = append(l.beforeAutoMigrate, func(db *gorm.DB) error {
return db.SetupJoinTable(model, field, jointable)
})
}
}

// New returns a new Loader.
func New(dialect string, opts ...Option) *Loader {
l := &Loader{dialect: dialect, config: &gorm.Config{}}
for _, opt := range opts {
opt(l)
}
return l
}

// NewTrigger receives a list of TriggerOption to build a Trigger.
func NewTrigger(opts ...TriggerOption) Trigger {
return Trigger{opts: opts}
}

func (s schemaOption) apply(b *schemaBuilder) {
s(b)
}

func (schemaOption) isViewOption() {}
func (schemaOption) isTriggerOption() {}

// CreateStmt accepts raw SQL to create a view or trigger
func CreateStmt(stmt string) interface {
ViewOption
TriggerOption
} {
return schemaOption(func(b *schemaBuilder) {
b.createStmt = stmt
})
}

// BuildStmt accepts a function with gorm query builder to create a CREATE VIEW statement.
// With this option, the view's name will be the same as the model's table name
func BuildStmt(fn func(db *gorm.DB) *gorm.DB) ViewOption {
return schemaOption(func(b *schemaBuilder) {
vd := b.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return fn(tx).
Unscoped(). // Skip gorm deleted_at filtering.
Find(nil) // Execute the query and convert it to SQL.
})
b.createStmt = fmt.Sprintf("CREATE VIEW %s AS %s", b.viewName, vd)
})
}

// Load loads the models and returns the DDL statements representing the schema.
func (l *Loader) Load(models ...any) (string, error) {
var (
Expand Down Expand Up @@ -125,6 +197,9 @@ func (l *Loader) Load(models ...any) (string, error) {
if err = cm.CreateViews(views); err != nil {
return "", err
}
if err = cm.CreateTriggers(models); err != nil {
return "", err
}
if !l.config.DisableForeignKeyConstraintWhenMigrating && l.dialect != "sqlite" {
if err = cm.CreateConstraints(tables); err != nil {
return "", err
Expand Down Expand Up @@ -242,75 +317,20 @@ func (m *migrator) CreateViews(views []ViewDefiner) error {
}); ok {
viewName = namer.TableName()
}
viewBuilder := &viewBuilder{
schemaBuilder := &schemaBuilder{
db: m.DB,
viewName: viewName,
}
for _, opt := range view.ViewDef(m.Dialector.Name()) {
opt(viewBuilder)
opt.apply(schemaBuilder)
}
if err := m.DB.Exec(viewBuilder.createStmt).Error; err != nil {
if err := m.DB.Exec(schemaBuilder.createStmt).Error; err != nil {
return err
}
}
return nil
}

// WithJoinTable sets up a join table for the given model and field.
// Deprecated: put the join tables alongside the models in the Load call.
func WithJoinTable(model any, field string, jointable any) Option {
return func(l *Loader) {
l.beforeAutoMigrate = append(l.beforeAutoMigrate, func(db *gorm.DB) error {
return db.SetupJoinTable(model, field, jointable)
})
}
}

func indirect(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}

type (
// ViewOption configures a viewBuilder.
ViewOption func(*viewBuilder)
// ViewDefiner defines a view.
ViewDefiner interface {
ViewDef(dialect string) []ViewOption
}
viewBuilder struct {
db *gorm.DB
createStmt string
// viewName is only used for the BuildStmt option.
// BuildStmt returns only a subquery; viewName helps to create a full CREATE VIEW statement.
viewName string
}
)

// CreateStmt accepts raw SQL to create a CREATE VIEW statement.
func CreateStmt(stmt string) ViewOption {
return func(b *viewBuilder) {
b.createStmt = b.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Exec(stmt)
})
}
}

// BuildStmt accepts a function with gorm query builder to create a CREATE VIEW statement.
// With this option, the view's name will be the same as the model's table name
func BuildStmt(fn func(db *gorm.DB) *gorm.DB) ViewOption {
return func(b *viewBuilder) {
vd := b.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return fn(tx).
Unscoped(). // Skip gorm deleted_at filtering.
Find(nil) // Execute the query and convert it to SQL.
})
b.createStmt = fmt.Sprintf("CREATE VIEW %s AS %s", b.viewName, vd)
}
}

// orderModels places join tables at the end of the list of models (if any),
// which helps GORM resolve m2m relationships correctly.
func (m *migrator) orderModels(models ...any) ([]any, error) {
Expand Down Expand Up @@ -348,3 +368,32 @@ func (m *migrator) orderModels(models ...any) ([]any, error) {
}
return append(otherTables, joinTables...), nil
}

// CreateTriggers creates the triggers for the given models.
func (m *migrator) CreateTriggers(models []any) error {
for _, model := range models {
if md, ok := model.(interface {
Triggers(string) []Trigger
}); ok {
for _, trigger := range md.Triggers(m.Dialector.Name()) {
schemaBuilder := &schemaBuilder{
db: m.DB,
}
for _, opt := range trigger.opts {
opt.apply(schemaBuilder)
if err := m.DB.Exec(schemaBuilder.createStmt).Error; err != nil {
return err
}
}
}
}
}
return nil
}

func indirect(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
41 changes: 36 additions & 5 deletions gormschema/gorm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@ import (
func TestSQLiteConfig(t *testing.T) {
resetSession()
l := gormschema.New("sqlite")
sql, err := l.Load(models.WorkingAgedUsers{}, models.Pet{}, ckmodels.Event{}, ckmodels.Location{}, models.TopPetOwner{})
sql, err := l.Load(
models.WorkingAgedUsers{},
models.Pet{},
models.UserPetHistory{},
ckmodels.Event{},
ckmodels.Location{},
models.TopPetOwner{},
)
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlite_default")
resetSession()
l = gormschema.New("sqlite", gormschema.WithConfig(&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
}))
sql, err = l.Load(models.Pet{}, models.User{})
sql, err = l.Load(models.UserPetHistory{}, models.Pet{}, models.User{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlite_no_fk")
resetSession()
Expand All @@ -32,7 +39,15 @@ func TestSQLiteConfig(t *testing.T) {
func TestPostgreSQLConfig(t *testing.T) {
resetSession()
l := gormschema.New("postgres")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
sql, err := l.Load(
models.WorkingAgedUsers{},
ckmodels.Location{},
ckmodels.Event{},
models.UserPetHistory{},
models.User{},
models.Pet{},
models.TopPetOwner{},
)
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/postgresql_default")
resetSession()
Expand All @@ -48,7 +63,15 @@ func TestPostgreSQLConfig(t *testing.T) {
func TestMySQLConfig(t *testing.T) {
resetSession()
l := gormschema.New("mysql")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
sql, err := l.Load(
models.WorkingAgedUsers{},
ckmodels.Location{},
ckmodels.Event{},
models.UserPetHistory{},
models.User{},
models.Pet{},
models.TopPetOwner{},
)
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/mysql_default")
resetSession()
Expand Down Expand Up @@ -80,7 +103,15 @@ func TestMySQLConfig(t *testing.T) {
func TestSQLServerConfig(t *testing.T) {
resetSession()
l := gormschema.New("sqlserver")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
sql, err := l.Load(
models.WorkingAgedUsers{},
ckmodels.Location{},
ckmodels.Event{},
models.UserPetHistory{},
models.User{},
models.Pet{},
models.TopPetOwner{},
)
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlserver_default")
resetSession()
Expand Down
Loading

0 comments on commit 66b75de

Please sign in to comment.