Skip to content

Commit

Permalink
fix: insert on duplicate update to add list argument in the bind vari…
Browse files Browse the repository at this point in the history
…ables map (#15961)

Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
vitess-bot[bot] committed May 17, 2024
1 parent 31c5a7d commit e77759d
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 10 deletions.
21 changes: 21 additions & 0 deletions go/test/endtoend/vtgate/queries/dml/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ func TestSimpleInsertSelect(t *testing.T) {
utils.AssertMatches(t, mcmp.VtConn, `select num from num_vdx_tbl order by num`, `[[INT64(2)] [INT64(4)] [INT64(40)] [INT64(42)] [INT64(80)] [INT64(84)]]`)
}

// TestInsertOnDup test the insert on duplicate key update feature with argument and list argument.
func TestInsertOnDup(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")

mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into order_tbl(oid, region_id, cust_no) values (1,2,3),(3,4,5)")

for _, mode := range []string{"oltp", "olap"} {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

mcmp.Exec(`insert into order_tbl(oid, region_id, cust_no) values (2,2,3),(4,4,5) on duplicate key update cust_no = if(values(cust_no) in (1, 2, 3), region_id, values(cust_no))`)
mcmp.Exec(`select oid, region_id, cust_no from order_tbl order by oid, region_id`)
mcmp.Exec(`insert into order_tbl(oid, region_id, cust_no) values (7,2,2) on duplicate key update cust_no = 10 + values(cust_no)`)
mcmp.Exec(`select oid, region_id, cust_no from order_tbl order by oid, region_id`)
})
}
}

func TestFailureInsertSelect(t *testing.T) {
if clusterInstance.HasPartialKeyspaces {
t.Skip("don't run on partial keyspaces")
Expand Down
9 changes: 9 additions & 0 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,15 @@ func TestNormalize(t *testing.T) {
"bv2": sqltypes.Int64BindVariable(2),
"bv3": sqltypes.Int64BindVariable(3),
},
}, {
// list in on duplicate key update
in: "insert into t(a, b) values (1, 2) on duplicate key update b = if(values(b) in (1, 2), b, values(b))",
outstmt: "insert into t(a, b) values (:bv1 /* INT64 */, :bv2 /* INT64 */) on duplicate key update b = if(values(b) in ::bv3, b, values(b))",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(1),
"bv2": sqltypes.Int64BindVariable(2),
"bv3": sqltypes.TestBindVariable([]any{1, 2}),
},
}}
parser := NewTestParser()
for _, tc := range testcases {
Expand Down
19 changes: 13 additions & 6 deletions go/vt/vtgate/engine/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,20 @@ func (ins *Insert) getInsertShardedQueries(
index, _ := strconv.ParseInt(string(indexValue.Value), 0, 64)
if keyspaceIDs[index] != nil {
walkFunc := func(node sqlparser.SQLNode) (kontinue bool, err error) {
if arg, ok := node.(*sqlparser.Argument); ok {
bv, exists := bindVars[arg.Name]
if !exists {
return false, vterrors.VT03026(arg.Name)
}
shardBindVars[arg.Name] = bv
var arg string
switch argType := node.(type) {
case *sqlparser.Argument:
arg = argType.Name
case sqlparser.ListArg:
arg = string(argType)
default:
return true, nil
}
bv, exists := bindVars[arg]
if !exists {
return false, vterrors.VT03026(arg)
}
shardBindVars[arg] = bv
return true, nil
}
mids = append(mids, sqlparser.String(ins.Mid[index]))
Expand Down
20 changes: 16 additions & 4 deletions go/vt/vtgate/engine/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,22 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) {
{&sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}},
},
sqlparser.OnDup{
&sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix"), Expr: &sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}},
},
&sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix1"), Expr: &sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}},
&sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix2"), Expr: &sqlparser.FuncExpr{
Name: sqlparser.NewIdentifierCI("if"),
Exprs: sqlparser.Exprs{

Check failure on line 362 in go/vt/vtgate/engine/insert_test.go

View workflow job for this annotation

GitHub Actions / Unit Test (Race)

cannot use sqlparser.Exprs{…} (value of type sqlparser.Exprs) as sqlparser.SelectExprs value in struct literal

Check failure on line 362 in go/vt/vtgate/engine/insert_test.go

View workflow job for this annotation

GitHub Actions / Code Coverage

cannot use sqlparser.Exprs{…} (value of type sqlparser.Exprs) as sqlparser.SelectExprs value in struct literal

Check failure on line 362 in go/vt/vtgate/engine/insert_test.go

View workflow job for this annotation

GitHub Actions / Unit Test (mysql80)

cannot use sqlparser.Exprs{…} (value of type sqlparser.Exprs) as sqlparser.SelectExprs value in struct literal

Check failure on line 362 in go/vt/vtgate/engine/insert_test.go

View workflow job for this annotation

GitHub Actions / Static Code Checks Etc

cannot use sqlparser.Exprs{…} (value of type sqlparser.Exprs) as sqlparser.SelectExprs value in struct literal (typecheck)

Check failure on line 362 in go/vt/vtgate/engine/insert_test.go

View workflow job for this annotation

GitHub Actions / Code Coverage

cannot use sqlparser.Exprs{…} (value of type sqlparser.Exprs) as sqlparser.SelectExprs value in struct literal
sqlparser.NewComparisonExpr(sqlparser.InOp, &sqlparser.ValuesFuncExpr{Name: sqlparser.NewColName("col")}, sqlparser.ListArg("_id_1"), nil),
sqlparser.NewColName("col"),
&sqlparser.ValuesFuncExpr{Name: sqlparser.NewColName("col")},
},
}}},
)
vc := newDMLTestVCursor("-20", "20-")
vc.shardForKsid = []string{"20-", "-20", "20-"}

_, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false)
_, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{
"_id_1": sqltypes.TestBindVariable([]int{1, 2}),
}, false)
if err != nil {
t.Fatal(err)
}
Expand All @@ -371,7 +380,10 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) {
`ResolveDestinations sharded [value:"0"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6)`,
// Row 2 will go to -20, rows 1 & 3 will go to 20-
`ExecuteMultiShard ` +
`sharded.20-: prefix(:_id_0 /* INT64 */) on duplicate key update suffix = :_id_0 /* INT64 */ {_id_0: type:INT64 value:"1"} ` +
`sharded.20-: prefix(:_id_0 /* INT64 */) on duplicate key update ` +
`suffix1 = :_id_0 /* INT64 */, suffix2 = if(values(col) in ::_id_1, col, values(col)) ` +
`{_id_0: type:INT64 value:"1" ` +
`_id_1: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"2"}} ` +
`true true`,
})

Expand Down

0 comments on commit e77759d

Please sign in to comment.