Skip to content

Commit

Permalink
fix: remove the return value of TurnOffAutoMigrate (#170)
Browse files Browse the repository at this point in the history
Signed-off-by: tangyang9464 <[email protected]>
  • Loading branch information
tangyang9464 authored Jul 3, 2022
1 parent 1bd95e0 commit 3d2fa84
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 27 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func main() {
New an adapter will use ``AutoMigrate`` by default for create table, if you want to turn it off, please use API ``TurnOffAutoMigrate(db *gorm.DB) *gorm.DB``. See example:
```go
db, err := gorm.Open(mysql.Open("root:@tcp(127.0.0.1:3306)/casbin"), &gorm.Config{})
db = TurnOffAutoMigrate(db)
TurnOffAutoMigrate(db)
// a,_ := NewAdapterByDB(...)
// a,_ := NewAdapterByDBUseTableName(...)
a,_ := NewAdapterByDBWithCustomTable(...)
Expand Down
37 changes: 12 additions & 25 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package gormadapter

import (
"context"
"errors"
"fmt"
"runtime"
Expand All @@ -37,9 +36,8 @@ const (
defaultTableName = "casbin_rule"
)

type customTableKey struct{}

const disableMigrateKey = "disableMigrateKey"
const customTableKey = "customTableKey"

type CasbinRule struct {
ID uint `gorm:"primaryKey;autoIncrement"`
Expand Down Expand Up @@ -252,31 +250,19 @@ func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
return NewAdapterByDBUseTableName(db, "", defaultTableName)
}

func TurnOffAutoMigrate(db *gorm.DB) *gorm.DB {
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}

ctx = context.WithValue(ctx, disableMigrateKey, false)

return db.WithContext(ctx)
func TurnOffAutoMigrate(db *gorm.DB) {
*db = *db.Set(disableMigrateKey, false)
}

func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) {
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}

ctx = context.WithValue(ctx, customTableKey{}, t)
*db = *db.Set(customTableKey, t)

curTableName := defaultTableName
if len(tableName) > 0 {
curTableName = tableName[0]
}

return NewAdapterByDBUseTableName(db.WithContext(ctx), "", curTableName)
return NewAdapterByDBUseTableName(db, "", curTableName)
}

func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) {
Expand Down Expand Up @@ -380,14 +366,14 @@ func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB {
}

func (a *Adapter) createTable() error {
disableMigrate := a.db.Statement.Context.Value(disableMigrateKey)
if disableMigrate != nil {
disableMigrate, ok := a.db.Get(disableMigrateKey)
if ok && disableMigrate != nil {
return nil
}

t := a.db.Statement.Context.Value(customTableKey{})
t, ok := a.db.Get(customTableKey)

if t != nil {
if ok && t != nil {
return a.db.AutoMigrate(t)
}

Expand All @@ -408,8 +394,9 @@ func (a *Adapter) createTable() error {
}

func (a *Adapter) dropTable() error {
t := a.db.Statement.Context.Value(customTableKey{})
if t == nil {
t, ok := a.db.Get(customTableKey)

if !ok || t == nil {
return a.db.Migrator().DropTable(a.getTableInstance())
}

Expand Down
2 changes: 1 addition & 1 deletion adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func initAdapterWithoutAutoMigrate(t *testing.T, db *gorm.DB) *Adapter {
}
}

db = TurnOffAutoMigrate(db)
TurnOffAutoMigrate(db)

type CustomCasbinRule struct {
ID uint `gorm:"primaryKey;autoIncrement"`
Expand Down

0 comments on commit 3d2fa84

Please sign in to comment.