From 464ed1acb4b0b952d518ec481c60fa41d6d0b98c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 25 Apr 2024 10:05:03 +0200 Subject: [PATCH] feat: handle INSERT with RowAlias Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast.go | 2 +- go/vt/sqlparser/ast_clone.go | 2 +- go/vt/sqlparser/ast_copy_on_rewrite.go | 6 +- go/vt/sqlparser/ast_equals.go | 4 +- go/vt/sqlparser/ast_rewrite.go | 8 +- go/vt/sqlparser/ast_visit.go | 4 +- go/vt/sqlparser/cached_size.go | 4 +- .../planbuilder/testdata/dml_cases.json | 104 ++++++++++++++++++ go/vt/vtgate/semantics/analyzer_dml_test.go | 17 +++ go/vt/vtgate/semantics/analyzer_test.go | 10 ++ go/vt/vtgate/semantics/table_collector.go | 92 ++++++++++++++++ 11 files changed, 238 insertions(+), 15 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index cfef9923530..c503fb1fa8e 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -341,8 +341,8 @@ type ( Partitions Partitions Columns Columns Rows InsertRows - OnDup OnDup RowAlias *RowAlias + OnDup OnDup } // Ignore represents whether ignore was specified or not diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 252dc3c72a8..7c9f3757eac 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -1634,8 +1634,8 @@ func CloneRefOfInsert(n *Insert) *Insert { out.Partitions = ClonePartitions(n.Partitions) out.Columns = CloneColumns(n.Columns) out.Rows = CloneInsertRows(n.Rows) - out.OnDup = CloneOnDup(n.OnDup) out.RowAlias = CloneRefOfRowAlias(n.RowAlias) + out.OnDup = CloneOnDup(n.OnDup) return &out } diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index 899e3370acc..daaf781aae0 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -2816,17 +2816,17 @@ func (c *cow) copyOnRewriteRefOfInsert(n *Insert, parent SQLNode) (out SQLNode, _Partitions, changedPartitions := c.copyOnRewritePartitions(n.Partitions, n) _Columns, changedColumns := c.copyOnRewriteColumns(n.Columns, n) _Rows, changedRows := c.copyOnRewriteInsertRows(n.Rows, n) - _OnDup, changedOnDup := c.copyOnRewriteOnDup(n.OnDup, n) _RowAlias, changedRowAlias := c.copyOnRewriteRefOfRowAlias(n.RowAlias, n) - if changedComments || changedTable || changedPartitions || changedColumns || changedRows || changedOnDup || changedRowAlias { + _OnDup, changedOnDup := c.copyOnRewriteOnDup(n.OnDup, n) + if changedComments || changedTable || changedPartitions || changedColumns || changedRows || changedRowAlias || changedOnDup { res := *n res.Comments, _ = _Comments.(*ParsedComments) res.Table, _ = _Table.(*AliasedTableExpr) res.Partitions, _ = _Partitions.(Partitions) res.Columns, _ = _Columns.(Columns) res.Rows, _ = _Rows.(InsertRows) - res.OnDup, _ = _OnDup.(OnDup) res.RowAlias, _ = _RowAlias.(*RowAlias) + res.OnDup, _ = _OnDup.(OnDup) out = &res if c.cloned != nil { c.cloned(n, out) diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index c4066218859..9875fddffb5 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -2896,8 +2896,8 @@ func (cmp *Comparator) RefOfInsert(a, b *Insert) bool { cmp.Partitions(a.Partitions, b.Partitions) && cmp.Columns(a.Columns, b.Columns) && cmp.InsertRows(a.Rows, b.Rows) && - cmp.OnDup(a.OnDup, b.OnDup) && - cmp.RefOfRowAlias(a.RowAlias, b.RowAlias) + cmp.RefOfRowAlias(a.RowAlias, b.RowAlias) && + cmp.OnDup(a.OnDup, b.OnDup) } // RefOfInsertExpr does deep equals between the two objects. diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 56cb5ffd251..e7ac74dc06b 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -3905,13 +3905,13 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer }) { return false } - if !a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { - parent.(*Insert).OnDup = newNode.(OnDup) + if !a.rewriteRefOfRowAlias(node, node.RowAlias, func(newNode, parent SQLNode) { + parent.(*Insert).RowAlias = newNode.(*RowAlias) }) { return false } - if !a.rewriteRefOfRowAlias(node, node.RowAlias, func(newNode, parent SQLNode) { - parent.(*Insert).RowAlias = newNode.(*RowAlias) + if !a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { + parent.(*Insert).OnDup = newNode.(OnDup) }) { return false } diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index f89eb23be6b..f4829192cce 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -2003,10 +2003,10 @@ func VisitRefOfInsert(in *Insert, f Visit) error { if err := VisitInsertRows(in.Rows, f); err != nil { return err } - if err := VisitOnDup(in.OnDup, f); err != nil { + if err := VisitRefOfRowAlias(in.RowAlias, f); err != nil { return err } - if err := VisitRefOfRowAlias(in.RowAlias, f); err != nil { + if err := VisitOnDup(in.OnDup, f); err != nil { return err } return nil diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 361888727b2..9998fb86f66 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -1854,6 +1854,8 @@ func (cached *Insert) CachedSize(alloc bool) int64 { if cc, ok := cached.Rows.(cachedObject); ok { size += cc.CachedSize(true) } + // field RowAlias *vitess.io/vitess/go/vt/sqlparser.RowAlias + size += cached.RowAlias.CachedSize(true) // field OnDup vitess.io/vitess/go/vt/sqlparser.OnDup { size += hack.RuntimeAllocSize(int64(cap(cached.OnDup)) * int64(8)) @@ -1861,8 +1863,6 @@ func (cached *Insert) CachedSize(alloc bool) int64 { size += elem.CachedSize(true) } } - // field RowAlias *vitess.io/vitess/go/vt/sqlparser.RowAlias - size += cached.RowAlias.CachedSize(true) return size } func (cached *InsertExpr) CachedSize(alloc bool) int64 { diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index abd467fbfd0..3c1f202bd8d 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -6752,5 +6752,109 @@ "user.user" ] } + }, + { + "comment": "RowAlias in INSERT", + "query": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1", + "plan": { + "QueryType": "INSERT", + "Original": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1", + "Instructions": { + "OperatorType": "Insert", + "Variant": "Sharded", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "InsertIgnore": true, + "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new on duplicate key update col2 = new.user_id + new.col1", + "TableName": "authoritative", + "VindexValues": { + "user_index": "1, 4" + } + }, + "TablesUsed": [ + "user.authoritative" + ] + } + }, + { + "comment": "RowAlias with explicit columns in INSERT", + "query": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c", + "plan": { + "QueryType": "INSERT", + "Original": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c", + "Instructions": { + "OperatorType": "Insert", + "Variant": "Sharded", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "InsertIgnore": true, + "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new (a, b, c) on duplicate key update col1 = a + c", + "TableName": "authoritative", + "VindexValues": { + "user_index": "1, 4" + } + }, + "TablesUsed": [ + "user.authoritative" + ] + } + }, + { + "comment": "RowAlias in INSERT (no column list)", + "query": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1", + "plan": { + "QueryType": "INSERT", + "Original": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1", + "Instructions": { + "OperatorType": "Insert", + "Variant": "Sharded", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "InsertIgnore": true, + "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new on duplicate key update col2 = new.user_id + new.col1", + "TableName": "authoritative", + "VindexValues": { + "user_index": "1, 4" + } + }, + "TablesUsed": [ + "user.authoritative" + ] + } + }, + { + "comment": "RowAlias with explicit columns in INSERT (no column list)", + "query": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c", + "plan": { + "QueryType": "INSERT", + "Original": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c", + "Instructions": { + "OperatorType": "Insert", + "Variant": "Sharded", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "InsertIgnore": true, + "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new (a, b, c) on duplicate key update col1 = a + c", + "TableName": "authoritative", + "VindexValues": { + "user_index": "1, 4" + } + }, + "TablesUsed": [ + "user.authoritative" + ] + } } ] diff --git a/go/vt/vtgate/semantics/analyzer_dml_test.go b/go/vt/vtgate/semantics/analyzer_dml_test.go index c792b2301a0..a5885dda6e8 100644 --- a/go/vt/vtgate/semantics/analyzer_dml_test.go +++ b/go/vt/vtgate/semantics/analyzer_dml_test.go @@ -87,3 +87,20 @@ func TestUpdBindingExpr(t *testing.T) { func extractFromUpdateSet(in *sqlparser.Update, idx int) *sqlparser.UpdateExpr { return in.Exprs[idx] } + +func TestInsertBindingColName(t *testing.T) { + queries := []string{ + "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new on duplicate key update texcol = new.uid + new.name", + "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new(x, y, z) on duplicate key update texcol = x + y", + "insert into t2 values (1,'foo','bar') as new(x, y, z) on duplicate key update texcol = x + y", + } + for _, query := range queries { + t.Run(query, func(t *testing.T) { + stmt, semTable := parseAndAnalyzeStrict(t, query, "d") + ins, _ := stmt.(*sqlparser.Insert) + ue := ins.OnDup[0] + ts := semTable.RecursiveDeps(ue.Expr) + assert.Equal(t, SingleTableSet(0), ts) + }) + } +} diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index deb84538740..31975475a61 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -1290,6 +1290,16 @@ func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, * return parse, semTable } +func parseAndAnalyzeStrict(t *testing.T, query, dbName string) (sqlparser.Statement, *SemTable) { + t.Helper() + parse, err := sqlparser.NewTestParser().Parse(query) + require.NoError(t, err) + + semTable, err := AnalyzeStrict(parse, dbName, fakeSchemaInfo()) + require.NoError(t, err) + return parse, semTable +} + func TestSingleUnshardedKeyspace(t *testing.T) { tests := []struct { query string diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index e17a75044ba..1bc79a02566 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -19,6 +19,10 @@ package semantics import ( "fmt" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -64,6 +68,7 @@ func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) { etc.withTables[cte.ID] = nil } } + } func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) { @@ -104,12 +109,99 @@ func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sq etc.Tables = append(etc.Tables, tableInfo) } +func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias) error { + origTableInfo := tc.Tables[0] + + var colNames []string + var types []evalengine.Type + switch { + case len(rowAlias.Columns) > 0 && len(ins.Columns) > 0: + // we have explicit column list on the row alias and the insert statement + if len(rowAlias.Columns) != len(ins.Columns) { + panic("column count mismatch") + } + origCols := origTableInfo.getColumns() + for1: + for idx, column := range rowAlias.Columns { + colNames = append(colNames, column.String()) + col := ins.Columns[idx] + for _, origCol := range origCols { + if col.EqualString(origCol.Name) { + types = append(types, origCol.Type) + continue for1 + } + } + return vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", col) + } + case len(rowAlias.Columns) > 0: + if !origTableInfo.authoritative() { + return vterrors.VT09015() + } + // TODO: we need to handle invisible columns here :sigh: + if len(rowAlias.Columns) != len(origTableInfo.getColumns()) { + panic("column count mismatch") + } + origCols := origTableInfo.getColumns() + for idx, column := range rowAlias.Columns { + colNames = append(colNames, column.String()) + types = append(types, origCols[idx].Type) + } + case len(ins.Columns) > 0: + origCols := origTableInfo.getColumns() + for2: + for _, column := range ins.Columns { + colNames = append(colNames, column.String()) + for _, origCol := range origCols { + if column.EqualString(origCol.Name) { + types = append(types, origCol.Type) + continue for2 + } + } + types = append(types, evalengine.NewType(sqltypes.Unknown, collations.Unknown)) + } + default: + if !origTableInfo.authoritative() { + return vterrors.VT09015() + } + for _, column := range origTableInfo.getColumns() { + colNames = append(colNames, column.Name) + types = append(types, column.Type) + } + } + deps := make([]TableSet, len(colNames)) + for i := range colNames { + deps[i] = SingleTableSet(0) + } + + derivedTable := &DerivedTable{ + tableName: rowAlias.TableName.String(), + ASTNode: &sqlparser.AliasedTableExpr{ + Expr: sqlparser.NewTableName(rowAlias.TableName.String()), + }, + columnNames: colNames, + tables: SingleTableSet(0), + recursive: deps, + isAuthoritative: true, + types: types, + } + + tc.Tables = append(tc.Tables, derivedTable) + current := tc.scoper.currentScope() + return current.addTable(derivedTable) +} + func (tc *tableCollector) up(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case *sqlparser.AliasedTableExpr: return tc.visitAliasedTableExpr(node) case *sqlparser.Union: return tc.visitUnion(node) + case *sqlparser.RowAlias: + ins, ok := cursor.Parent().(*sqlparser.Insert) + if !ok { + return vterrors.VT13001("RowAlias is expected to hang off an Insert statement") + } + return tc.visitRowAlias(ins, node) default: return nil }