diff --git a/README.md b/README.md index ffc3b21..b3669df 100644 --- a/README.md +++ b/README.md @@ -210,9 +210,9 @@ The provider supports the following databases: you should use the following code: ```go stmts, err := gormschema.New("mysql") - .WithJoinTables( - gormschema.SetupJoinTableParams{Model: &models.Person{}, Field: "Addresses", JoinTable: &models.PersonAddress{}}, - ) + .WithJoinTables( func(db *gorm.DB) { + db.SetupJoinTable(&customjointable.Person{}, "Addresses", &customjointable.PersonAddress{}) + }) .Load(&models.Person{}, &models.Address{}) ``` diff --git a/gormschema/gorm.go b/gormschema/gorm.go index fa7c3c6..bcd9934 100644 --- a/gormschema/gorm.go +++ b/gormschema/gorm.go @@ -27,9 +27,9 @@ func New(dialect string, opts ...Option) *Loader { type ( // Loader is a Loader for gorm schema. Loader struct { - dialect string - config *gorm.Config - joinTableCallbacks []func(*gorm.DB) + dialect string + config *gorm.Config + beforeAutoMigrate func(*gorm.DB) } // Option configures the Loader. Option func(*Loader) @@ -86,8 +86,8 @@ func (l *Loader) Load(models ...any) (string, error) { db.Config.DisableForeignKeyConstraintWhenMigrating = true } - for _, setupJoinTable := range l.joinTableCallbacks { - setupJoinTable(db) + if l.beforeAutoMigrate != nil { + l.beforeAutoMigrate(db) } if err = db.AutoMigrate(models...); err != nil { @@ -167,18 +167,7 @@ func (m *migrator) CreateConstraints(models []any) error { return nil } -type SetupJoinTableParams struct { - Model any - Field string - JoinTable any -} - -func (l *Loader) WithJoinTables(params ...SetupJoinTableParams) *Loader { - for _, p := range params { - l.joinTableCallbacks = append(l.joinTableCallbacks, func(db *gorm.DB) { - db.SetupJoinTable(p.Model, p.Field, p.JoinTable) - }) - } - +func (l *Loader) WithJoinTables(cb func(*gorm.DB)) *Loader { + l.beforeAutoMigrate = cb return l }