Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parse methods with central log instance #7141

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion association.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (db *DB) Association(column string) *Association {
association := &Association{DB: db}
table := db.Statement.Table

if err := db.Statement.Parse(db.Statement.Model); err == nil {
if err := db.Statement.ParseWithLogger(db.Statement.Model, db.Logger); err == nil {
db.Statement.Table = table
association.Relationship = db.Statement.Schema.Relationships.Relations[column]

Expand Down
2 changes: 1 addition & 1 deletion callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (p *processor) Execute(db *DB) *DB {

// parse model values
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if err := stmt.ParseWithLogger(stmt.Model, p.db.Logger); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
Expand Down
2 changes: 1 addition & 1 deletion callbacks/preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.
return true
})

if err := tx.Statement.Parse(dest); err != nil {
if err := tx.Statement.ParseWithLogger(dest, db.Logger); err != nil {
tx.AddError(err)
return tx
}
Expand Down
2 changes: 1 addition & 1 deletion callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func BuildQuerySQL(db *gorm.DB) {
if queryFields {
stmt := gorm.Statement{DB: db}
// smaller struct
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
if err := stmt.ParseWithLogger(db.Statement.Dest, db.Logger); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))

for idx, dbName := range stmt.Schema.DBNames {
Expand Down
2 changes: 1 addition & 1 deletion callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema
updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil {
if err := updatingStmt.ParseWithLogger(stmt.Dest, stmt.Logger); err == nil {
updatingSchema = updatingStmt.Schema
isDiffSchema = true
}
Expand Down
8 changes: 4 additions & 4 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
}
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
if err := tx.Statement.ParseWithLogger(value, db.Logger); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
return tx.callbacks.Create().Execute(tx)
Expand Down Expand Up @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if tx.Statement.ParseWithLogger(tx.Statement.Model, db.Logger) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
}
Expand Down Expand Up @@ -554,7 +554,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Model != nil {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if tx.Statement.ParseWithLogger(tx.Statement.Model, db.Logger) == nil {
if f := tx.Statement.Schema.LookUpField(column); f != nil {
column = f.DBName
}
Expand All @@ -574,7 +574,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {

func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance()
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
if err := tx.Statement.ParseWithLogger(dest, db.Logger); !errors.Is(err, schema.ErrUnsupportedDataType) {
tx.AddError(err)
}
tx.Statement.Dest = dest
Expand Down
4 changes: 2 additions & 2 deletions gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,13 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
modelSchema, joinSchema *schema.Schema
)

err := stmt.Parse(model)
err := stmt.ParseWithLogger(model, db.Logger)
if err != nil {
return err
}
modelSchema = stmt.Schema

err = stmt.Parse(joinTable)
err = stmt.ParseWithLogger(joinTable, db.Logger)
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions migrator/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error

if table, ok := value.(string); ok {
stmt.Table = table
} else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
} else if err := stmt.ParseWithSpecialTableNameWithLogger(value, stmt.Table, m.DB.Logger); err != nil {
return err
}

Expand Down Expand Up @@ -348,7 +348,7 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
oldTable = clause.Table{Name: v}
} else {
stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(oldName); err == nil {
if err := stmt.ParseWithLogger(oldName, m.DB.Logger); err == nil {
oldTable = m.CurrentTable(stmt)
} else {
return err
Expand All @@ -359,7 +359,7 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
newTable = clause.Table{Name: v}
} else {
stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(newName); err == nil {
if err := stmt.ParseWithLogger(newName, m.DB.Logger); err == nil {
newTable = m.CurrentTable(stmt)
} else {
return err
Expand Down Expand Up @@ -918,7 +918,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
}
beDependedOn := map[*schema.Schema]bool{}
// support for special table name
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
if err := dep.ParseWithSpecialTableNameWithLogger(value, m.DB.Statement.Table, m.Config.DB.Logger); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
}
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
Expand Down
14 changes: 11 additions & 3 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,19 @@ type TablerWithNamer interface {

// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
return ParseWithSpecialTableNameWithLogger(dest, cacheStore, namer, "", logger.Default)
}

func ParseWithLogger(dest interface{}, cacheStore *sync.Map, namer Namer, log logger.Interface) (*Schema, error) {
return ParseWithSpecialTableNameWithLogger(dest, cacheStore, namer, "", log)
}

// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
return ParseWithSpecialTableNameWithLogger(dest, cacheStore, namer, specialTableName, logger.Default)
}

func ParseWithSpecialTableNameWithLogger(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string, log logger.Interface) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
Expand Down Expand Up @@ -316,7 +324,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
case "func(*gorm.DB) error": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
log.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
}
}
}
Expand All @@ -331,7 +339,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam

defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
log.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
Expand Down
12 changes: 10 additions & 2 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,19 @@ func (stmt *Statement) Build(clauses ...string) {
}

func (stmt *Statement) Parse(value interface{}) (err error) {
return stmt.ParseWithSpecialTableName(value, "")
return stmt.ParseWithSpecialTableNameWithLogger(value, "", logger.Default)
}

func (stmt *Statement) ParseWithLogger(value interface{}, log logger.Interface) (err error) {
return stmt.ParseWithSpecialTableNameWithLogger(value, "", log)
}

func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
return stmt.ParseWithSpecialTableNameWithLogger(value, specialTableName, logger.Default)
}

func (stmt *Statement) ParseWithSpecialTableNameWithLogger(value interface{}, specialTableName string, log logger.Interface) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableNameWithLogger(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName, log); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]
Expand Down
Loading