diff --git a/callbacks/delete.go b/callbacks/delete.go index 84f446a3f..bd04854ef 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -126,7 +126,19 @@ func Delete(config *Config) func(db *gorm.DB) { if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) - db.Statement.AddClauseIfNotExists(clause.Delete{}) + + deleteClause := clause.Delete{} + + HandleJoins( + db, + func(db *gorm.DB) { + deleteClause.Table = db.Statement.Table + }, + func(db *gorm.DB, tableAliasName string, idx int, relation *schema.Relationship) { + }, + ) + + db.Statement.AddClauseIfNotExists(deleteClause) if db.Statement.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) diff --git a/callbacks/join.go b/callbacks/join.go new file mode 100644 index 000000000..92897ed32 --- /dev/null +++ b/callbacks/join.go @@ -0,0 +1,154 @@ +package callbacks + +import ( + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" + "strings" +) + +func HandleJoins(db *gorm.DB, prejoinCallback func(db *gorm.DB), perFieldNameCallback func(db *gorm.DB, tableAliasName string, idx int, relation *schema.Relationship)) { + // inline joins + fromClause := clause.From{} + if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + fromClause = v + } + + if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { + prejoinCallback(db) + + specifiedRelationsName := make(map[string]interface{}) + for idx, join := range db.Statement.Joins { + if db.Statement.Schema != nil { + var isRelations bool // is relations or raw sql + var relations []*schema.Relationship + relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] + if ok { + isRelations = true + relations = append(relations, relation) + } else { + // handle nested join like "Manager.Company" + nestedJoinNames := strings.Split(join.Name, ".") + if len(nestedJoinNames) > 1 { + isNestedJoin := true + gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + currentRelations := db.Statement.Schema.Relationships.Relations + for _, relname := range nestedJoinNames { + // incomplete match, only treated as raw sql + if relation, ok = currentRelations[relname]; ok { + gussNestedRelations = append(gussNestedRelations, relation) + currentRelations = relation.FieldSchema.Relationships.Relations + } else { + isNestedJoin = false + break + } + } + + if isNestedJoin { + isRelations = true + relations = gussNestedRelations + } + } + } + + if isRelations { + genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { + tableAliasName := relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } + + perFieldNameCallback(db, tableAliasName, idx, relation) + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } + } + } + } + + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) + } + + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } + } + + return clause.Join{ + Type: joinType, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + } + } + + parentTableName := clause.CurrentTable + for _, rel := range relations { + // joins table alias like "Manager, Company, Manager__Company" + nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) + if _, ok := specifiedRelationsName[nestedAlias]; !ok { + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) + specifiedRelationsName[nestedAlias] = nil + } + + if parentTableName != clause.CurrentTable { + parentTableName = utils.NestedRelationName(parentTableName, rel.Name) + } else { + parentTableName = rel.Name + } + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } + + db.Statement.AddClause(fromClause) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + +} diff --git a/callbacks/query.go b/callbacks/query.go index 9b2b17ea9..aa4774452 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -2,13 +2,11 @@ package callbacks import ( "fmt" - "reflect" - "strings" - "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils" + "reflect" ) func Query(db *gorm.DB) { @@ -96,166 +94,34 @@ func BuildQuerySQL(db *gorm.DB) { } } - // inline joins - fromClause := clause.From{} - if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { - fromClause = v - } - - if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { - if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { - clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) - for idx, dbName := range db.Statement.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} - } - } - - specifiedRelationsName := make(map[string]interface{}) - for _, join := range db.Statement.Joins { - if db.Statement.Schema != nil { - var isRelations bool // is relations or raw sql - var relations []*schema.Relationship - relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] - if ok { - isRelations = true - relations = append(relations, relation) - } else { - // handle nested join like "Manager.Company" - nestedJoinNames := strings.Split(join.Name, ".") - if len(nestedJoinNames) > 1 { - isNestedJoin := true - gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) - currentRelations := db.Statement.Schema.Relationships.Relations - for _, relname := range nestedJoinNames { - // incomplete match, only treated as raw sql - if relation, ok = currentRelations[relname]; ok { - gussNestedRelations = append(gussNestedRelations, relation) - currentRelations = relation.FieldSchema.Relationships.Relations - } else { - isNestedJoin = false - break - } - } - - if isNestedJoin { - isRelations = true - relations = gussNestedRelations - } - } + HandleJoins( + db, + func(db *gorm.DB) { + if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} } + } + }, + func(db *gorm.DB, tableAliasName string, idx int, relation *schema.Relationship) { + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: db.Statement.Joins[idx].Selects, Omits: db.Statement.Joins[idx].Omits, + } - if isRelations { - genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { - tableAliasName := relation.Name - if parentTableName != clause.CurrentTable { - tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) - } - - columnStmt := gorm.Statement{ - Table: tableAliasName, DB: db, Schema: relation.FieldSchema, - Selects: join.Selects, Omits: join.Omits, - } - - selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) - for _, s := range relation.FieldSchema.DBNames { - if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: utils.NestedRelationName(tableAliasName, s), - }) - } - } - - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - } - } else { - if ref.PrimaryValue == "" { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - } - } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - } - } - } - } - - { - onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} - for _, c := range relation.FieldSchema.QueryClauses { - onStmt.AddClause(c) - } - - if join.On != nil { - onStmt.AddClause(join.On) - } - - if cs, ok := onStmt.Clauses["WHERE"]; ok { - if where, ok := cs.Expression.(clause.Where); ok { - where.Build(&onStmt) - - if onSQL := onStmt.SQL.String(); onSQL != "" { - vars := onStmt.Vars - for idx, v := range vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) - } - - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) - } - } - } - } - - return clause.Join{ - Type: joinType, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - } - } - - parentTableName := clause.CurrentTable - for _, rel := range relations { - // joins table alias like "Manager, Company, Manager__Company" - nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) - if _, ok := specifiedRelationsName[nestedAlias]; !ok { - fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) - specifiedRelationsName[nestedAlias] = nil - } - - if parentTableName != clause.CurrentTable { - parentTableName = utils.NestedRelationName(parentTableName, rel.Name) - } else { - parentTableName = rel.Name - } - } - } else { - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) + for _, s := range relation.FieldSchema.DBNames { + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: utils.NestedRelationName(tableAliasName, s), }) } - } else { - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, - }) } - } - - db.Statement.AddClause(fromClause) - } else { - db.Statement.AddClauseIfNotExists(clause.From{}) - } + }, + ) db.Statement.AddClauseIfNotExists(clauseSelect) diff --git a/chainable_api.go b/chainable_api.go index 8953413d5..8a6aea343 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { // Unscoped allows queries to include records marked as deleted, // overriding the soft deletion behavior. // Example: -// var users []User -// db.Unscoped().Find(&users) -// // Retrieves all users, including deleted ones. +// +// var users []User +// db.Unscoped().Find(&users) +// // Retrieves all users, including deleted ones. func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() tx.Statement.Unscoped = true diff --git a/clause/delete.go b/clause/delete.go index fc462cd7f..3091d9da5 100644 --- a/clause/delete.go +++ b/clause/delete.go @@ -2,6 +2,7 @@ package clause type Delete struct { Modifier string + Table string } func (d Delete) Name() string { @@ -15,6 +16,10 @@ func (d Delete) Build(builder Builder) { builder.WriteByte(' ') builder.WriteString(d.Modifier) } + if d.Table != "" { + builder.WriteByte(' ') + builder.WriteQuoted(d.Table) + } } func (d Delete) MergeClause(clause *Clause) { diff --git a/go.sum b/go.sum index e3e29009d..72406eb61 100644 --- a/go.sum +++ b/go.sum @@ -4,3 +4,5 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gorm.io/gorm v1.25.11 h1:/Wfyg1B/je1hnDx3sMkX+gAlxrlZpn6X0BXRlwXlvHg= +gorm.io/gorm v1.25.11/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/tests/go.mod b/tests/go.mod index 350d17946..4b51b0841 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -11,7 +11,7 @@ require ( gorm.io/driver/postgres v1.5.9 gorm.io/driver/sqlite v1.5.6 gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.10 + gorm.io/gorm v1.25.11 ) require (