diff --git a/go/test/endtoend/vtgate/queries/dml/insert_test.go b/go/test/endtoend/vtgate/queries/dml/insert_test.go index ce052b7b2ba..dfb5961d887 100644 --- a/go/test/endtoend/vtgate/queries/dml/insert_test.go +++ b/go/test/endtoend/vtgate/queries/dml/insert_test.go @@ -54,6 +54,27 @@ func TestSimpleInsertSelect(t *testing.T) { utils.AssertMatches(t, mcmp.VtConn, `select num from num_vdx_tbl order by num`, `[[INT64(2)] [INT64(4)] [INT64(40)] [INT64(42)] [INT64(80)] [INT64(84)]]`) } +// TestInsertOnDup test the insert on duplicate key update feature with argument and list argument. +func TestInsertOnDup(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into order_tbl(oid, region_id, cust_no) values (1,2,3),(3,4,5)") + + for _, mode := range []string{"oltp", "olap"} { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { + utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) + + mcmp.Exec(`insert into order_tbl(oid, region_id, cust_no) values (2,2,3),(4,4,5) on duplicate key update cust_no = if(values(cust_no) in (1, 2, 3), region_id, values(cust_no))`) + mcmp.Exec(`select oid, region_id, cust_no from order_tbl order by oid, region_id`) + mcmp.Exec(`insert into order_tbl(oid, region_id, cust_no) values (7,2,2) on duplicate key update cust_no = 10 + values(cust_no)`) + mcmp.Exec(`select oid, region_id, cust_no from order_tbl order by oid, region_id`) + }) + } +} + func TestFailureInsertSelect(t *testing.T) { if clusterInstance.HasPartialKeyspaces { t.Skip("don't run on partial keyspaces") diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 18f2ad44a7f..de1fdc868ad 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -388,6 +388,15 @@ func TestNormalize(t *testing.T) { "bv2": sqltypes.Int64BindVariable(2), "bv3": sqltypes.Int64BindVariable(3), }, + }, { + // list in on duplicate key update + in: "insert into t(a, b) values (1, 2) on duplicate key update b = if(values(b) in (1, 2), b, values(b))", + outstmt: "insert into t(a, b) values (:bv1 /* INT64 */, :bv2 /* INT64 */) on duplicate key update b = if(values(b) in ::bv3, b, values(b))", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + "bv2": sqltypes.Int64BindVariable(2), + "bv3": sqltypes.TestBindVariable([]any{1, 2}), + }, }} parser := NewTestParser() for _, tc := range testcases { diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 332ccc92098..be0bb889083 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -265,13 +265,20 @@ func (ins *Insert) getInsertShardedQueries( index, _ := strconv.ParseInt(string(indexValue.Value), 0, 64) if keyspaceIDs[index] != nil { walkFunc := func(node sqlparser.SQLNode) (kontinue bool, err error) { - if arg, ok := node.(*sqlparser.Argument); ok { - bv, exists := bindVars[arg.Name] - if !exists { - return false, vterrors.VT03026(arg.Name) - } - shardBindVars[arg.Name] = bv + var arg string + switch argType := node.(type) { + case *sqlparser.Argument: + arg = argType.Name + case sqlparser.ListArg: + arg = string(argType) + default: + return true, nil } + bv, exists := bindVars[arg] + if !exists { + return false, vterrors.VT03026(arg) + } + shardBindVars[arg] = bv return true, nil } mids = append(mids, sqlparser.String(ins.Mid[index])) diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index 762c68a83dc..af6eb4f51b2 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -356,13 +356,22 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) { {&sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, }, sqlparser.OnDup{ - &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix"), Expr: &sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, - }, + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix1"), Expr: &sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix2"), Expr: &sqlparser.FuncExpr{ + Name: sqlparser.NewIdentifierCI("if"), + Exprs: sqlparser.SelectExprs{ + sqlparser.NewAliasedExpr(sqlparser.NewComparisonExpr(sqlparser.InOp, &sqlparser.ValuesFuncExpr{Name: sqlparser.NewColName("col")}, sqlparser.ListArg("_id_1"), nil), ""), + sqlparser.NewAliasedExpr(sqlparser.NewColName("col"), ""), + sqlparser.NewAliasedExpr(&sqlparser.ValuesFuncExpr{Name: sqlparser.NewColName("col")}, ""), + }, + }}}, ) vc := newDMLTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} - _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{ + "_id_1": sqltypes.TestBindVariable([]int{1, 2}), + }, false) if err != nil { t.Fatal(err) } @@ -371,7 +380,10 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) { `ResolveDestinations sharded [value:"0"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6)`, // Row 2 will go to -20, rows 1 & 3 will go to 20- `ExecuteMultiShard ` + - `sharded.20-: prefix(:_id_0 /* INT64 */) on duplicate key update suffix = :_id_0 /* INT64 */ {_id_0: type:INT64 value:"1"} ` + + `sharded.20-: prefix(:_id_0 /* INT64 */) on duplicate key update ` + + `suffix1 = :_id_0 /* INT64 */, suffix2 = if(values(col) in ::_id_1, col, values(col)) ` + + `{_id_0: type:INT64 value:"1" ` + + `_id_1: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"2"}} ` + `true true`, })