-
-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(expression): support case-when expression
- Loading branch information
1 parent
deceebf
commit 400d51c
Showing
2 changed files
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package clause | ||
|
||
type ExprCaseCondition struct { | ||
When string | ||
Then string | ||
Vars []any | ||
} | ||
|
||
type ExprCaseElse struct { | ||
Then string | ||
Vars []any | ||
} | ||
|
||
type ExprCase struct { | ||
Cases []*ExprCaseCondition | ||
Else *ExprCaseElse | ||
} | ||
|
||
func (expr ExprCase) Name() string { | ||
return "CASE" | ||
} | ||
|
||
func (expr ExprCase) Build(builder Builder) { | ||
var vars []any | ||
for idx, condition := range expr.Cases { | ||
if idx > 0 { | ||
_ = builder.WriteByte(' ') | ||
} | ||
_, _ = builder.WriteString("WHEN ") | ||
_, _ = builder.WriteString(condition.When) | ||
_, _ = builder.WriteString(" THEN ") | ||
_, _ = builder.WriteString(condition.Then) | ||
if len(condition.Vars) > 0 { | ||
vars = append(vars, condition.Vars...) | ||
} | ||
} | ||
|
||
if expr.Else != nil { | ||
elseExpr := expr.Else | ||
_, _ = builder.WriteString(" ELSE ") | ||
_, _ = builder.WriteString(elseExpr.Then) | ||
if len(elseExpr.Vars) > 0 { | ||
vars = append(vars, elseExpr.Vars...) | ||
} | ||
} | ||
_, _ = builder.WriteString(" END") | ||
|
||
clauseExpr := Expr{SQL: "", Vars: vars} | ||
clauseExpr.Build(builder) | ||
} | ||
|
||
func (expr ExprCase) MergeClause(clause *Clause) { | ||
clause.Expression = expr | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package clause_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"gorm.io/gorm" | ||
"gorm.io/gorm/clause" | ||
) | ||
|
||
func Test_ExprCase(t *testing.T) { | ||
type exampleUser struct { | ||
ID string | ||
Name string | ||
} | ||
|
||
inputUsers := []*exampleUser{ | ||
{ | ||
ID: "user-001", | ||
Name: "user-name-001", | ||
}, | ||
{ | ||
ID: "user-002", | ||
Name: "user-name-002", | ||
}, | ||
} | ||
|
||
userIDs := make([]string, len(inputUsers)) | ||
userNameCases := make([]*clause.ExprCaseCondition, len(inputUsers)) | ||
for idx, user := range inputUsers { | ||
userIDs[idx] = user.ID | ||
userNameCases[idx] = &clause.ExprCaseCondition{ | ||
When: "user_id=?", | ||
Then: "?", | ||
Vars: []any{ | ||
user.ID, | ||
user.Name, | ||
}, | ||
} | ||
} | ||
|
||
sqlQuery := db.ToSQL(func(db *gorm.DB) *gorm.DB { | ||
return db. | ||
Table("users"). | ||
Where("user_id IN (?)", userIDs). | ||
UpdateColumns(map[string]any{ | ||
"user_name": clause.ExprCase{ | ||
Cases: userNameCases, | ||
Else: &clause.ExprCaseElse{ | ||
Then: "user_name", | ||
Vars: nil, | ||
}, | ||
}, | ||
}) | ||
}) | ||
|
||
expectedSQLQuery := "UPDATE `users` SET `user_name`=CASE WHEN user_id=\"user-001\" THEN \"user-name-001\" WHEN user_id=\"user-002\" THEN \"user-name-002\" ELSE user_name END WHERE user_id IN (\"user-001\",\"user-002\")" | ||
if sqlQuery != expectedSQLQuery { | ||
t.Errorf("SQLQuery is mismatch actual: %v expected:%v\n", sqlQuery, expectedSQLQuery) | ||
} | ||
} |