diff --git a/changelog/20.0/20.0.0/summary.md b/changelog/20.0/20.0.0/summary.md
index 9421018bc9c..8b560fe283f 100644
--- a/changelog/20.0/20.0.0/summary.md
+++ b/changelog/20.0/20.0.0/summary.md
@@ -22,6 +22,7 @@
- [Delete with Subquery Support](#delete-subquery)
- [Delete with Multi Target Support](#delete-multi-target)
- [User Defined Functions Support](#udf-support)
+ - [Insert Row Alias Support](#insert-row-alias-support)
- **[Query Timeout](#query-timeout)**
- **[Flag changes](#flag-changes)**
- [`pprof-http` default change](#pprof-http-default)
@@ -197,6 +198,16 @@ Without this flag, VTGate will not be aware that there might be aggregating user
More details about how to load UDFs is available in [MySQL Docs](https://dev.mysql.com/doc/extending-mysql/8.0/en/adding-loadable-function.html)
+#### Insert Row Alias Support
+
+Support is added to have row alias in Insert statement to be used with `on duplicate key update`.
+
+Example:
+- `insert into user(id, name, email) valies (100, 'Alice', 'alice@mail.com') as new on duplicate key update name = new.name, email = new.email`
+- `insert into user(id, name, email) valies (100, 'Alice', 'alice@mail.com') as new(m, n, p) on duplicate key update name = n, email = p`
+
+More details about how it works is available in [MySQL Docs](https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html)
+
### Query Timeout
On a query timeout, Vitess closed the connection using the `kill connection` statement. This leads to connection churn
which is not desirable in some cases. To avoid this, Vitess now uses the `kill query` statement to cancel the query.
diff --git a/go/mysql/sqlerror/constants.go b/go/mysql/sqlerror/constants.go
index fdec64588c1..a247ca15aa4 100644
--- a/go/mysql/sqlerror/constants.go
+++ b/go/mysql/sqlerror/constants.go
@@ -235,6 +235,7 @@ const (
ERUnknownTimeZone = ErrorCode(1298)
ERInvalidCharacterString = ErrorCode(1300)
ERQueryInterrupted = ErrorCode(1317)
+ ERViewWrongList = ErrorCode(1353)
ERTruncatedWrongValueForField = ErrorCode(1366)
ERIllegalValueForType = ErrorCode(1367)
ERDataTooLong = ErrorCode(1406)
diff --git a/go/mysql/sqlerror/sql_error.go b/go/mysql/sqlerror/sql_error.go
index bebd9e41ca7..935fd77a12f 100644
--- a/go/mysql/sqlerror/sql_error.go
+++ b/go/mysql/sqlerror/sql_error.go
@@ -216,6 +216,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{
vterrors.OperandColumns: {num: EROperandColumns, state: SSWrongNumberOfColumns},
vterrors.WrongValueCountOnRow: {num: ERWrongValueCountOnRow, state: SSWrongValueCountOnRow},
vterrors.WrongArguments: {num: ERWrongArguments, state: SSUnknownSQLState},
+ vterrors.ViewWrongList: {num: ERViewWrongList, state: SSUnknownSQLState},
vterrors.UnknownStmtHandler: {num: ERUnknownStmtHandler, state: SSUnknownSQLState},
vterrors.KeyDoesNotExist: {num: ERKeyDoesNotExist, state: SSClientError},
vterrors.UnknownTimeZone: {num: ERUnknownTimeZone, state: SSUnknownSQLState},
diff --git a/go/test/endtoend/vtgate/queries/dml/insert_test.go b/go/test/endtoend/vtgate/queries/dml/insert_test.go
index ce052b7b2ba..771eb64ab02 100644
--- a/go/test/endtoend/vtgate/queries/dml/insert_test.go
+++ b/go/test/endtoend/vtgate/queries/dml/insert_test.go
@@ -462,3 +462,27 @@ func TestMixedCases(t *testing.T) {
// final check count on the lookup vindex table.
utils.AssertMatches(t, mcmp.VtConn, "select count(*) from lkp_mixed_idx", "[[INT64(12)]]")
}
+
+// TestInsertAlias test the alias feature in insert statement.
+func TestInsertAlias(t *testing.T) {
+ utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")
+ utils.SkipIfBinaryIsBelowVersion(t, 20, "vttablet")
+
+ mcmp, closer := start(t)
+ defer closer()
+
+ // initial record
+ mcmp.Exec("insert into user_tbl(id, region_id, name) values (1, 1,'foo'),(2, 2,'bar'),(3, 3,'baz'),(4, 4,'buzz')")
+
+ qr := mcmp.Exec("insert into user_tbl(id, region_id, name) values (2, 2, 'foo') as new on duplicate key update name = new.name")
+ assert.EqualValues(t, 2, qr.RowsAffected)
+
+ // this validates the record.
+ mcmp.Exec("select id, region_id, name from user_tbl order by id")
+
+ qr = mcmp.Exec("insert into user_tbl(id, region_id, name) values (3, 3, 'foo') as new(m, n, p) on duplicate key update name = p")
+ assert.EqualValues(t, 2, qr.RowsAffected)
+
+ // this validates the record.
+ mcmp.Exec("select id, region_id, name from user_tbl order by id")
+}
diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go
index cfef9923530..c503fb1fa8e 100644
--- a/go/vt/sqlparser/ast.go
+++ b/go/vt/sqlparser/ast.go
@@ -341,8 +341,8 @@ type (
Partitions Partitions
Columns Columns
Rows InsertRows
- OnDup OnDup
RowAlias *RowAlias
+ OnDup OnDup
}
// Ignore represents whether ignore was specified or not
diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go
index 252dc3c72a8..7c9f3757eac 100644
--- a/go/vt/sqlparser/ast_clone.go
+++ b/go/vt/sqlparser/ast_clone.go
@@ -1634,8 +1634,8 @@ func CloneRefOfInsert(n *Insert) *Insert {
out.Partitions = ClonePartitions(n.Partitions)
out.Columns = CloneColumns(n.Columns)
out.Rows = CloneInsertRows(n.Rows)
- out.OnDup = CloneOnDup(n.OnDup)
out.RowAlias = CloneRefOfRowAlias(n.RowAlias)
+ out.OnDup = CloneOnDup(n.OnDup)
return &out
}
diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go
index 899e3370acc..daaf781aae0 100644
--- a/go/vt/sqlparser/ast_copy_on_rewrite.go
+++ b/go/vt/sqlparser/ast_copy_on_rewrite.go
@@ -2816,17 +2816,17 @@ func (c *cow) copyOnRewriteRefOfInsert(n *Insert, parent SQLNode) (out SQLNode,
_Partitions, changedPartitions := c.copyOnRewritePartitions(n.Partitions, n)
_Columns, changedColumns := c.copyOnRewriteColumns(n.Columns, n)
_Rows, changedRows := c.copyOnRewriteInsertRows(n.Rows, n)
- _OnDup, changedOnDup := c.copyOnRewriteOnDup(n.OnDup, n)
_RowAlias, changedRowAlias := c.copyOnRewriteRefOfRowAlias(n.RowAlias, n)
- if changedComments || changedTable || changedPartitions || changedColumns || changedRows || changedOnDup || changedRowAlias {
+ _OnDup, changedOnDup := c.copyOnRewriteOnDup(n.OnDup, n)
+ if changedComments || changedTable || changedPartitions || changedColumns || changedRows || changedRowAlias || changedOnDup {
res := *n
res.Comments, _ = _Comments.(*ParsedComments)
res.Table, _ = _Table.(*AliasedTableExpr)
res.Partitions, _ = _Partitions.(Partitions)
res.Columns, _ = _Columns.(Columns)
res.Rows, _ = _Rows.(InsertRows)
- res.OnDup, _ = _OnDup.(OnDup)
res.RowAlias, _ = _RowAlias.(*RowAlias)
+ res.OnDup, _ = _OnDup.(OnDup)
out = &res
if c.cloned != nil {
c.cloned(n, out)
diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go
index c4066218859..9875fddffb5 100644
--- a/go/vt/sqlparser/ast_equals.go
+++ b/go/vt/sqlparser/ast_equals.go
@@ -2896,8 +2896,8 @@ func (cmp *Comparator) RefOfInsert(a, b *Insert) bool {
cmp.Partitions(a.Partitions, b.Partitions) &&
cmp.Columns(a.Columns, b.Columns) &&
cmp.InsertRows(a.Rows, b.Rows) &&
- cmp.OnDup(a.OnDup, b.OnDup) &&
- cmp.RefOfRowAlias(a.RowAlias, b.RowAlias)
+ cmp.RefOfRowAlias(a.RowAlias, b.RowAlias) &&
+ cmp.OnDup(a.OnDup, b.OnDup)
}
// RefOfInsertExpr does deep equals between the two objects.
diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go
index 56cb5ffd251..e7ac74dc06b 100644
--- a/go/vt/sqlparser/ast_rewrite.go
+++ b/go/vt/sqlparser/ast_rewrite.go
@@ -3905,13 +3905,13 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer
}) {
return false
}
- if !a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) {
- parent.(*Insert).OnDup = newNode.(OnDup)
+ if !a.rewriteRefOfRowAlias(node, node.RowAlias, func(newNode, parent SQLNode) {
+ parent.(*Insert).RowAlias = newNode.(*RowAlias)
}) {
return false
}
- if !a.rewriteRefOfRowAlias(node, node.RowAlias, func(newNode, parent SQLNode) {
- parent.(*Insert).RowAlias = newNode.(*RowAlias)
+ if !a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) {
+ parent.(*Insert).OnDup = newNode.(OnDup)
}) {
return false
}
diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go
index f89eb23be6b..f4829192cce 100644
--- a/go/vt/sqlparser/ast_visit.go
+++ b/go/vt/sqlparser/ast_visit.go
@@ -2003,10 +2003,10 @@ func VisitRefOfInsert(in *Insert, f Visit) error {
if err := VisitInsertRows(in.Rows, f); err != nil {
return err
}
- if err := VisitOnDup(in.OnDup, f); err != nil {
+ if err := VisitRefOfRowAlias(in.RowAlias, f); err != nil {
return err
}
- if err := VisitRefOfRowAlias(in.RowAlias, f); err != nil {
+ if err := VisitOnDup(in.OnDup, f); err != nil {
return err
}
return nil
diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go
index 361888727b2..9998fb86f66 100644
--- a/go/vt/sqlparser/cached_size.go
+++ b/go/vt/sqlparser/cached_size.go
@@ -1854,6 +1854,8 @@ func (cached *Insert) CachedSize(alloc bool) int64 {
if cc, ok := cached.Rows.(cachedObject); ok {
size += cc.CachedSize(true)
}
+ // field RowAlias *vitess.io/vitess/go/vt/sqlparser.RowAlias
+ size += cached.RowAlias.CachedSize(true)
// field OnDup vitess.io/vitess/go/vt/sqlparser.OnDup
{
size += hack.RuntimeAllocSize(int64(cap(cached.OnDup)) * int64(8))
@@ -1861,8 +1863,6 @@ func (cached *Insert) CachedSize(alloc bool) int64 {
size += elem.CachedSize(true)
}
}
- // field RowAlias *vitess.io/vitess/go/vt/sqlparser.RowAlias
- size += cached.RowAlias.CachedSize(true)
return size
}
func (cached *InsertExpr) CachedSize(alloc bool) int64 {
diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go
index ffce4fc553d..d485c930b77 100644
--- a/go/vt/vterrors/code.go
+++ b/go/vt/vterrors/code.go
@@ -58,6 +58,7 @@ var (
VT03030 = errorWithState("VT03030", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "lookup column count does not match value count with the row (columns, count): (%v, %d)", "The number of columns you want to insert do not match the number of columns of your SELECT query.")
VT03031 = errorWithoutState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, "EXPLAIN is only supported for single keyspace", "EXPLAIN has to be sent down as a single query to the underlying MySQL, and this is not possible if it uses tables from multiple keyspaces")
VT03032 = errorWithState("VT03032", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.")
+ VT03033 = errorWithState("VT03033", vtrpcpb.Code_INVALID_ARGUMENT, ViewWrongList, "In definition of view, derived table or common table expression, SELECT list and column names list have different column counts", "The table column list and derived column list have different column counts.")
VT05001 = errorWithState("VT05001", vtrpcpb.Code_NOT_FOUND, DbDropExists, "cannot drop database '%s'; database does not exists", "The given database does not exist; Vitess cannot drop it.")
VT05002 = errorWithState("VT05002", vtrpcpb.Code_NOT_FOUND, BadDb, "cannot alter database '%s'; unknown database", "The given database does not exist; Vitess cannot alter it.")
@@ -146,6 +147,7 @@ var (
VT03030,
VT03031,
VT03032,
+ VT03033,
VT05001,
VT05002,
VT05003,
diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go
index 2b0ada0bc6d..8223405fc92 100644
--- a/go/vt/vterrors/state.go
+++ b/go/vt/vterrors/state.go
@@ -49,6 +49,7 @@ const (
WrongArguments
BadNullError
InvalidGroupFuncUse
+ ViewWrongList
// failed precondition
NoDB
diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go
index 22b3a38a990..18e22c00378 100644
--- a/go/vt/vtgate/engine/cached_size.go
+++ b/go/vt/vtgate/engine/cached_size.go
@@ -440,7 +440,7 @@ func (cached *Insert) CachedSize(alloc bool) int64 {
}
size := int64(0)
if alloc {
- size += int64(208)
+ size += int64(224)
}
// field InsertCommon vitess.io/vitess/go/vt/vtgate/engine.InsertCommon
size += cached.InsertCommon.CachedSize(false)
@@ -479,6 +479,8 @@ func (cached *Insert) CachedSize(alloc bool) int64 {
}
}
}
+ // field Alias string
+ size += hack.RuntimeAllocSize(int64(len(cached.Alias)))
return size
}
func (cached *InsertCommon) CachedSize(alloc bool) int64 {
diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go
index 332ccc92098..c23c85d132f 100644
--- a/go/vt/vtgate/engine/insert.go
+++ b/go/vt/vtgate/engine/insert.go
@@ -55,6 +55,9 @@ type Insert struct {
// Mid is the row values for the sharded insert plans.
Mid sqlparser.Values
+
+ // Alias represents the row alias with columns if specified in the query.
+ Alias string
}
// newQueryInsert creates an Insert with a query string.
@@ -287,7 +290,7 @@ func (ins *Insert) getInsertShardedQueries(
}
}
}
- rewritten := ins.Prefix + strings.Join(mids, ",") + sqlparser.String(ins.Suffix)
+ rewritten := ins.Prefix + strings.Join(mids, ",") + ins.Alias + sqlparser.String(ins.Suffix)
queries[i] = &querypb.BoundQuery{
Sql: rewritten,
BindVariables: shardBindVars,
@@ -363,6 +366,19 @@ func (ins *Insert) description() PrimitiveDescription {
other["VindexValues"] = valuesOffsets
}
+ // This is a check to ensure we send the correct query to the database.
+ // "ActualQuery" should not be part of the plan output, if it does, it means the query was not rewritten correctly.
+ if ins.Mid != nil {
+ var mids []string
+ for _, n := range ins.Mid {
+ mids = append(mids, sqlparser.String(n))
+ }
+ shardedQuery := ins.Prefix + strings.Join(mids, ", ") + ins.Alias + sqlparser.String(ins.Suffix)
+ if shardedQuery != ins.Query {
+ other["ActualQuery"] = shardedQuery
+ }
+ }
+
return PrimitiveDescription{
OperatorType: "Insert",
Keyspace: ins.Keyspace,
diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go
index 2a7f37a258f..a324592af8e 100644
--- a/go/vt/vtgate/planbuilder/operator_transformers.go
+++ b/go/vt/vtgate/planbuilder/operator_transformers.go
@@ -619,6 +619,9 @@ func buildInsertLogicalPlan(
// when unsharded query with autoincrement for that there is no input operator.
if eins.Opcode != engine.InsertUnsharded {
eins.Prefix, eins.Mid, eins.Suffix = generateInsertShardedQuery(ins.AST)
+ if ins.AST.RowAlias != nil {
+ eins.Alias = sqlparser.String(ins.AST.RowAlias)
+ }
}
eins.Query = generateQuery(stmt)
@@ -660,7 +663,7 @@ func generateInsertShardedQuery(ins *sqlparser.Insert) (prefix string, mids sqlp
prefixBuf := sqlparser.NewTrackedBuffer(dmlFormatter)
prefixBuf.Myprintf(prefixFormat,
ins.Comments, ins.Ignore.ToString(),
- ins.Table, ins.Columns)
+ ins.Table, ins.Columns, ins.RowAlias)
prefix = prefixBuf.String()
suffix = sqlparser.CopyOnRewrite(ins.OnDup, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json
index abd467fbfd0..3c1f202bd8d 100644
--- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json
+++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json
@@ -6752,5 +6752,109 @@
"user.user"
]
}
+ },
+ {
+ "comment": "RowAlias in INSERT",
+ "query": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1",
+ "plan": {
+ "QueryType": "INSERT",
+ "Original": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1",
+ "Instructions": {
+ "OperatorType": "Insert",
+ "Variant": "Sharded",
+ "Keyspace": {
+ "Name": "user",
+ "Sharded": true
+ },
+ "TargetTabletType": "PRIMARY",
+ "InsertIgnore": true,
+ "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new on duplicate key update col2 = new.user_id + new.col1",
+ "TableName": "authoritative",
+ "VindexValues": {
+ "user_index": "1, 4"
+ }
+ },
+ "TablesUsed": [
+ "user.authoritative"
+ ]
+ }
+ },
+ {
+ "comment": "RowAlias with explicit columns in INSERT",
+ "query": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c",
+ "plan": {
+ "QueryType": "INSERT",
+ "Original": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c",
+ "Instructions": {
+ "OperatorType": "Insert",
+ "Variant": "Sharded",
+ "Keyspace": {
+ "Name": "user",
+ "Sharded": true
+ },
+ "TargetTabletType": "PRIMARY",
+ "InsertIgnore": true,
+ "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new (a, b, c) on duplicate key update col1 = a + c",
+ "TableName": "authoritative",
+ "VindexValues": {
+ "user_index": "1, 4"
+ }
+ },
+ "TablesUsed": [
+ "user.authoritative"
+ ]
+ }
+ },
+ {
+ "comment": "RowAlias in INSERT (no column list)",
+ "query": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1",
+ "plan": {
+ "QueryType": "INSERT",
+ "Original": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1",
+ "Instructions": {
+ "OperatorType": "Insert",
+ "Variant": "Sharded",
+ "Keyspace": {
+ "Name": "user",
+ "Sharded": true
+ },
+ "TargetTabletType": "PRIMARY",
+ "InsertIgnore": true,
+ "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new on duplicate key update col2 = new.user_id + new.col1",
+ "TableName": "authoritative",
+ "VindexValues": {
+ "user_index": "1, 4"
+ }
+ },
+ "TablesUsed": [
+ "user.authoritative"
+ ]
+ }
+ },
+ {
+ "comment": "RowAlias with explicit columns in INSERT (no column list)",
+ "query": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c",
+ "plan": {
+ "QueryType": "INSERT",
+ "Original": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c",
+ "Instructions": {
+ "OperatorType": "Insert",
+ "Variant": "Sharded",
+ "Keyspace": {
+ "Name": "user",
+ "Sharded": true
+ },
+ "TargetTabletType": "PRIMARY",
+ "InsertIgnore": true,
+ "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new (a, b, c) on duplicate key update col1 = a + c",
+ "TableName": "authoritative",
+ "VindexValues": {
+ "user_index": "1, 4"
+ }
+ },
+ "TablesUsed": [
+ "user.authoritative"
+ ]
+ }
}
]
diff --git a/go/vt/vtgate/semantics/analyzer_dml_test.go b/go/vt/vtgate/semantics/analyzer_dml_test.go
index c792b2301a0..3e50f98f77a 100644
--- a/go/vt/vtgate/semantics/analyzer_dml_test.go
+++ b/go/vt/vtgate/semantics/analyzer_dml_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"vitess.io/vitess/go/vt/sqlparser"
)
@@ -87,3 +88,53 @@ func TestUpdBindingExpr(t *testing.T) {
func extractFromUpdateSet(in *sqlparser.Update, idx int) *sqlparser.UpdateExpr {
return in.Exprs[idx]
}
+
+func TestInsertBindingColName(t *testing.T) {
+ queries := []string{
+ "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new on duplicate key update textcol = new.uid + new.name",
+ "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new(x, y, z) on duplicate key update textcol = x + y",
+ "insert into t2 values (1,'foo','bar') as new(x, y, z) on duplicate key update textcol = x + y",
+ "insert into t3(uid, name, invcol) values (1,'foo','bar') as new on duplicate key update textcol = new.invcol",
+ "insert into t3 values (1,'foo','bar') as new on duplicate key update textcol = new.uid+new.name+new.textcol",
+ "insert into t3 values (1,'foo','bar') as new on duplicate key update textcol = new.uid+new.name+new.textcol, uid = new.name",
+ }
+ for _, query := range queries {
+ t.Run(query, func(t *testing.T) {
+ stmt, semTable := parseAndAnalyzeStrict(t, query, "d")
+ ins, _ := stmt.(*sqlparser.Insert)
+ for _, ue := range ins.OnDup {
+ // check deps on the column
+ ts := semTable.RecursiveDeps(ue.Name)
+ assert.Equal(t, SingleTableSet(0), ts)
+ // check deps on the expression
+ ts = semTable.RecursiveDeps(ue.Expr)
+ assert.Equal(t, SingleTableSet(0), ts)
+ }
+ })
+ }
+}
+
+func TestInsertBindingColNameErrorCases(t *testing.T) {
+ tcases := []struct {
+ query string
+ expErr string
+ }{{
+ "insert into t2 values (1,'foo','bar') as new on duplicate key update textcol = new.unknowncol",
+ "column 'new.unknowncol' not found",
+ }, {
+ "insert into t3 values (1,'foo','bar', 'baz') as new on duplicate key update textcol = new.invcol",
+ "column 'new.invcol' not found",
+ }, {
+ "insert into t3(uid, name) values (1,'foo') as new(x, y, z) on duplicate key update textcol = x + y",
+ "VT03033: In definition of view, derived table or common table expression, SELECT list and column names list have different column counts",
+ }}
+ for _, tc := range tcases {
+ t.Run(tc.query, func(t *testing.T) {
+ parse, err := sqlparser.NewTestParser().Parse(tc.query)
+ require.NoError(t, err)
+
+ _, err = AnalyzeStrict(parse, "d", fakeSchemaInfo())
+ require.ErrorContains(t, err, tc.expErr)
+ })
+ }
+}
diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go
index deb84538740..38cbd80aa9d 100644
--- a/go/vt/vtgate/semantics/analyzer_test.go
+++ b/go/vt/vtgate/semantics/analyzer_test.go
@@ -1290,6 +1290,16 @@ func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, *
return parse, semTable
}
+func parseAndAnalyzeStrict(t *testing.T, query, dbName string) (sqlparser.Statement, *SemTable) {
+ t.Helper()
+ parse, err := sqlparser.NewTestParser().Parse(query)
+ require.NoError(t, err)
+
+ semTable, err := AnalyzeStrict(parse, dbName, fakeSchemaInfo())
+ require.NoError(t, err)
+ return parse, semTable
+}
+
func TestSingleUnshardedKeyspace(t *testing.T) {
tests := []struct {
query string
@@ -1393,6 +1403,7 @@ func fakeSchemaInfo() *FakeSI {
"t": tableT(),
"t1": tableT1(),
"t2": tableT2(),
+ "t3": tableT3(),
},
}
return si
@@ -1437,3 +1448,28 @@ func tableT2() *vindexes.Table {
Keyspace: ks3,
}
}
+
+func tableT3() *vindexes.Table {
+ return &vindexes.Table{
+ Name: sqlparser.NewIdentifierCS("t3"),
+ Columns: []vindexes.Column{{
+ Name: sqlparser.NewIdentifierCI("uid"),
+ Type: querypb.Type_INT64,
+ }, {
+ Name: sqlparser.NewIdentifierCI("name"),
+ Type: querypb.Type_VARCHAR,
+ CollationName: "utf8_bin",
+ }, {
+ Name: sqlparser.NewIdentifierCI("textcol"),
+ Type: querypb.Type_VARCHAR,
+ CollationName: "big5_bin",
+ }, {
+ Name: sqlparser.NewIdentifierCI("invcol"),
+ Type: querypb.Type_VARCHAR,
+ CollationName: "big5_bin",
+ Invisible: true,
+ }},
+ ColumnListAuthoritative: true,
+ Keyspace: ks3,
+ }
+}
diff --git a/go/vt/vtgate/semantics/derived_table.go b/go/vt/vtgate/semantics/derived_table.go
index 0425d78ed93..aabbe9f0b22 100644
--- a/go/vt/vtgate/semantics/derived_table.go
+++ b/go/vt/vtgate/semantics/derived_table.go
@@ -116,7 +116,7 @@ func (dt *DerivedTable) dependencies(colName string, org originable) (dependenci
return createCertain(directDeps, recursiveDeps, qt), nil
}
- if !dt.hasStar() {
+ if dt.authoritative() {
return ¬hing{}, nil
}
@@ -154,7 +154,7 @@ func (dt *DerivedTable) GetVindexTable() *vindexes.Table {
return nil
}
-func (dt *DerivedTable) getColumns() []ColumnInfo {
+func (dt *DerivedTable) getColumns(bool) []ColumnInfo {
cols := make([]ColumnInfo, 0, len(dt.columnNames))
for _, col := range dt.columnNames {
cols = append(cols, ColumnInfo{
@@ -164,10 +164,6 @@ func (dt *DerivedTable) getColumns() []ColumnInfo {
return cols
}
-func (dt *DerivedTable) hasStar() bool {
- return dt.tables.NotEmpty()
-}
-
// GetTables implements the TableInfo interface
func (dt *DerivedTable) getTableSet(_ originable) TableSet {
return dt.tables
diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go
index 61abd9c3fa7..10466234798 100644
--- a/go/vt/vtgate/semantics/early_rewriter.go
+++ b/go/vt/vtgate/semantics/early_rewriter.go
@@ -1031,7 +1031,7 @@ func findOnlyOneTableInfoThatHasColumn(b *binder, tbl sqlparser.TableExpr, colum
case *sqlparser.AliasedTableExpr:
ts := b.tc.tableSetFor(tbl)
tblInfo := b.tc.Tables[ts.TableOffset()]
- for _, info := range tblInfo.getColumns() {
+ for _, info := range tblInfo.getColumns(false /* ignoreInvisibleCol */) {
if column.EqualString(info.Name) {
return []TableInfo{tblInfo}, nil
}
@@ -1188,10 +1188,7 @@ func (e *expanderState) processColumnsFor(tbl TableInfo) error {
outer:
// in this first loop we just find columns used in any JOIN USING used on this table
- for _, col := range tbl.getColumns() {
- if col.Invisible {
- continue
- }
+ for _, col := range tbl.getColumns(true /* ignoreInvisibleCol */) {
ts, found := usingCols[col.Name]
if found {
for i, ts := range ts.Constituents() {
@@ -1207,11 +1204,7 @@ outer:
}
// and this time around we are printing any columns not involved in any JOIN USING
- for _, col := range tbl.getColumns() {
- if col.Invisible {
- continue
- }
-
+ for _, col := range tbl.getColumns(true /* ignoreInvisibleCol */) {
if ts, found := usingCols[col.Name]; found && currTable.IsSolvedBy(ts) {
continue
}
diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go
index a8c3d699b59..4f1639d0897 100644
--- a/go/vt/vtgate/semantics/real_table.go
+++ b/go/vt/vtgate/semantics/real_table.go
@@ -41,7 +41,7 @@ var _ TableInfo = (*RealTable)(nil)
// dependencies implements the TableInfo interface
func (r *RealTable) dependencies(colName string, org originable) (dependencies, error) {
ts := org.tableSetFor(r.ASTNode)
- for _, info := range r.getColumns() {
+ for _, info := range r.getColumns(false /* ignoreInvisbleCol */) {
if strings.EqualFold(info.Name, colName) {
return createCertain(ts, ts, info.Type), nil
}
@@ -69,8 +69,40 @@ func (r *RealTable) IsInfSchema() bool {
}
// GetColumns implements the TableInfo interface
-func (r *RealTable) getColumns() []ColumnInfo {
- return vindexTableToColumnInfo(r.Table, r.collationEnv)
+func (r *RealTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo {
+ if r.Table == nil {
+ return nil
+ }
+ nameMap := map[string]any{}
+ cols := make([]ColumnInfo, 0, len(r.Table.Columns))
+ for _, col := range r.Table.Columns {
+ if col.Invisible && ignoreInvisbleCol {
+ continue
+ }
+ cols = append(cols, ColumnInfo{
+ Name: col.Name.String(),
+ Type: col.ToEvalengineType(r.collationEnv),
+ Invisible: col.Invisible,
+ })
+ nameMap[col.Name.String()] = nil
+ }
+ // If table is authoritative, we do not need ColumnVindexes to help in resolving the unqualified columns.
+ if r.Table.ColumnListAuthoritative {
+ return cols
+ }
+ for _, vindex := range r.Table.ColumnVindexes {
+ for _, column := range vindex.Columns {
+ name := column.String()
+ if _, exists := nameMap[name]; exists {
+ continue
+ }
+ cols = append(cols, ColumnInfo{
+ Name: name,
+ })
+ nameMap[name] = nil
+ }
+ }
+ return cols
}
// GetExpr implements the TableInfo interface
@@ -122,36 +154,3 @@ func (r *RealTable) authoritative() bool {
func (r *RealTable) matches(name sqlparser.TableName) bool {
return (name.Qualifier.IsEmpty() || name.Qualifier.String() == r.dbName) && r.tableName == name.Name.String()
}
-
-func vindexTableToColumnInfo(tbl *vindexes.Table, collationEnv *collations.Environment) []ColumnInfo {
- if tbl == nil {
- return nil
- }
- nameMap := map[string]any{}
- cols := make([]ColumnInfo, 0, len(tbl.Columns))
- for _, col := range tbl.Columns {
- cols = append(cols, ColumnInfo{
- Name: col.Name.String(),
- Type: col.ToEvalengineType(collationEnv),
- Invisible: col.Invisible,
- })
- nameMap[col.Name.String()] = nil
- }
- // If table is authoritative, we do not need ColumnVindexes to help in resolving the unqualified columns.
- if tbl.ColumnListAuthoritative {
- return cols
- }
- for _, vindex := range tbl.ColumnVindexes {
- for _, column := range vindex.Columns {
- name := column.String()
- if _, exists := nameMap[name]; exists {
- continue
- }
- cols = append(cols, ColumnInfo{
- Name: name,
- })
- nameMap[name] = nil
- }
- }
- return cols
-}
diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go
index 6c89b2bb999..66b0f99035a 100644
--- a/go/vt/vtgate/semantics/semantic_state.go
+++ b/go/vt/vtgate/semantics/semantic_state.go
@@ -58,7 +58,7 @@ type (
canShortCut() shortCut
// getColumns returns the known column information for this table
- getColumns() []ColumnInfo
+ getColumns(ignoreInvisibleCol bool) []ColumnInfo
dependencies(colName string, org originable) (dependencies, error)
getExprFor(s string) (sqlparser.Expr, error)
diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go
index e17a75044ba..ae107cc070c 100644
--- a/go/vt/vtgate/semantics/table_collector.go
+++ b/go/vt/vtgate/semantics/table_collector.go
@@ -19,6 +19,10 @@ package semantics
import (
"fmt"
+ "vitess.io/vitess/go/mysql/collations"
+ "vitess.io/vitess/go/sqltypes"
+ vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
+
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
@@ -64,6 +68,7 @@ func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) {
etc.withTables[cte.ID] = nil
}
}
+
}
func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) {
@@ -110,11 +115,27 @@ func (tc *tableCollector) up(cursor *sqlparser.Cursor) error {
return tc.visitAliasedTableExpr(node)
case *sqlparser.Union:
return tc.visitUnion(node)
+ case *sqlparser.RowAlias:
+ ins, ok := cursor.Parent().(*sqlparser.Insert)
+ if !ok {
+ return vterrors.VT13001("RowAlias is expected to hang off an Insert statement")
+ }
+ return tc.visitRowAlias(ins, node)
default:
return nil
}
}
+func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr) error {
+ switch t := node.Expr.(type) {
+ case *sqlparser.DerivedTable:
+ return tc.handleDerivedTable(node, t)
+ case sqlparser.TableName:
+ return tc.handleTableName(node, t)
+ }
+ return nil
+}
+
func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
firstSelect := sqlparser.GetFirstSelect(union)
expanded, selectExprs := getColumnNames(firstSelect.SelectExprs)
@@ -157,15 +178,126 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
return nil
}
-func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr) error {
- switch t := node.Expr.(type) {
- case *sqlparser.DerivedTable:
- return tc.handleDerivedTable(node, t)
+func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias) error {
+ origTableInfo := tc.Tables[0]
- case sqlparser.TableName:
- return tc.handleTableName(node, t)
+ colNames, types, err := tc.getColumnNamesAndTypes(ins, rowAlias, origTableInfo)
+ if err != nil {
+ return err
}
- return nil
+
+ derivedTable := buildDerivedTable(colNames, rowAlias, types)
+ tc.Tables = append(tc.Tables, derivedTable)
+ current := tc.scoper.currentScope()
+ return current.addTable(derivedTable)
+}
+
+func (tc *tableCollector) getColumnNamesAndTypes(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias, origTableInfo TableInfo) (colNames []string, types []evalengine.Type, err error) {
+ switch {
+ case len(rowAlias.Columns) > 0 && len(ins.Columns) > 0:
+ return tc.handleExplicitColumns(ins, rowAlias, origTableInfo)
+ case len(rowAlias.Columns) > 0:
+ return tc.handleRowAliasColumns(origTableInfo, rowAlias)
+ case len(ins.Columns) > 0:
+ colNames, types = tc.handleInsertColumns(ins, origTableInfo)
+ return colNames, types, nil
+ default:
+ return tc.handleDefaultColumns(origTableInfo)
+ }
+}
+
+// handleDefaultColumns have no explicit column list on the insert statement and no column list on the row alias
+func (tc *tableCollector) handleDefaultColumns(origTableInfo TableInfo) ([]string, []evalengine.Type, error) {
+ if !origTableInfo.authoritative() {
+ return nil, nil, vterrors.VT09015()
+ }
+ var colNames []string
+ var types []evalengine.Type
+ for _, column := range origTableInfo.getColumns(true /* ignoreInvisibleCol */) {
+ colNames = append(colNames, column.Name)
+ types = append(types, column.Type)
+ }
+ return colNames, types, nil
+}
+
+// handleInsertColumns have explicit column list on the insert statement and no column list on the row alias
+func (tc *tableCollector) handleInsertColumns(ins *sqlparser.Insert, origTableInfo TableInfo) ([]string, []evalengine.Type) {
+ var colNames []string
+ var types []evalengine.Type
+ origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */)
+for2:
+ for _, column := range ins.Columns {
+ colNames = append(colNames, column.String())
+ for _, origCol := range origCols {
+ if column.EqualString(origCol.Name) {
+ types = append(types, origCol.Type)
+ continue for2
+ }
+ }
+ types = append(types, evalengine.NewType(sqltypes.Unknown, collations.Unknown))
+ }
+ return colNames, types
+}
+
+// handleRowAliasColumns have explicit column list on the row alias and no column list on the insert statement
+func (tc *tableCollector) handleRowAliasColumns(origTableInfo TableInfo, rowAlias *sqlparser.RowAlias) ([]string, []evalengine.Type, error) {
+ if !origTableInfo.authoritative() {
+ return nil, nil, vterrors.VT09015()
+ }
+ origCols := origTableInfo.getColumns(true /* ignoreInvisibleCol */)
+ if len(rowAlias.Columns) != len(origCols) {
+ return nil, nil, vterrors.VT03033()
+ }
+ var colNames []string
+ var types []evalengine.Type
+ for idx, column := range rowAlias.Columns {
+ colNames = append(colNames, column.String())
+ types = append(types, origCols[idx].Type)
+ }
+ return colNames, types, nil
+}
+
+// handleExplicitColumns have explicit column list on the row alias and the insert statement
+func (tc *tableCollector) handleExplicitColumns(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias, origTableInfo TableInfo) ([]string, []evalengine.Type, error) {
+ if len(rowAlias.Columns) != len(ins.Columns) {
+ return nil, nil, vterrors.VT03033()
+ }
+ var colNames []string
+ var types []evalengine.Type
+ origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */)
+for1:
+ for idx, column := range rowAlias.Columns {
+ colNames = append(colNames, column.String())
+ col := ins.Columns[idx]
+ for _, origCol := range origCols {
+ if col.EqualString(origCol.Name) {
+ types = append(types, origCol.Type)
+ continue for1
+ }
+ }
+ return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", col)
+ }
+ return colNames, types, nil
+}
+
+func buildDerivedTable(colNames []string, rowAlias *sqlparser.RowAlias, types []evalengine.Type) *DerivedTable {
+ deps := make([]TableSet, len(colNames))
+ for i := range colNames {
+ deps[i] = SingleTableSet(0)
+ }
+
+ derivedTable := &DerivedTable{
+ tableName: rowAlias.TableName.String(),
+ ASTNode: &sqlparser.AliasedTableExpr{
+ Expr: sqlparser.NewTableName(rowAlias.TableName.String()),
+ },
+ columnNames: colNames,
+ tables: SingleTableSet(0),
+ recursive: deps,
+ isAuthoritative: true,
+ types: types,
+ }
+ return derivedTable
}
func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) (err error) {
diff --git a/go/vt/vtgate/semantics/vindex_table.go b/go/vt/vtgate/semantics/vindex_table.go
index fba8f8ab9a0..b598c93f36a 100644
--- a/go/vt/vtgate/semantics/vindex_table.go
+++ b/go/vt/vtgate/semantics/vindex_table.go
@@ -76,8 +76,8 @@ func (v *VindexTable) canShortCut() shortCut {
}
// GetColumns implements the TableInfo interface
-func (v *VindexTable) getColumns() []ColumnInfo {
- return v.Table.getColumns()
+func (v *VindexTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo {
+ return v.Table.getColumns(ignoreInvisbleCol)
}
// IsInfSchema implements the TableInfo interface
diff --git a/go/vt/vtgate/semantics/vtable.go b/go/vt/vtgate/semantics/vtable.go
index 81f81de3813..14519a7e938 100644
--- a/go/vt/vtgate/semantics/vtable.go
+++ b/go/vt/vtgate/semantics/vtable.go
@@ -104,7 +104,7 @@ func (v *vTableInfo) GetVindexTable() *vindexes.Table {
return nil
}
-func (v *vTableInfo) getColumns() []ColumnInfo {
+func (v *vTableInfo) getColumns(bool) []ColumnInfo {
cols := make([]ColumnInfo, 0, len(v.columnNames))
for _, col := range v.columnNames {
cols = append(cols, ColumnInfo{