From 464ed1acb4b0b952d518ec481c60fa41d6d0b98c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 25 Apr 2024 10:05:03 +0200 Subject: [PATCH 1/8] feat: handle INSERT with RowAlias Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast.go | 2 +- go/vt/sqlparser/ast_clone.go | 2 +- go/vt/sqlparser/ast_copy_on_rewrite.go | 6 +- go/vt/sqlparser/ast_equals.go | 4 +- go/vt/sqlparser/ast_rewrite.go | 8 +- go/vt/sqlparser/ast_visit.go | 4 +- go/vt/sqlparser/cached_size.go | 4 +- .../planbuilder/testdata/dml_cases.json | 104 ++++++++++++++++++ go/vt/vtgate/semantics/analyzer_dml_test.go | 17 +++ go/vt/vtgate/semantics/analyzer_test.go | 10 ++ go/vt/vtgate/semantics/table_collector.go | 92 ++++++++++++++++ 11 files changed, 238 insertions(+), 15 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index cfef9923530..c503fb1fa8e 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -341,8 +341,8 @@ type ( Partitions Partitions Columns Columns Rows InsertRows - OnDup OnDup RowAlias *RowAlias + OnDup OnDup } // Ignore represents whether ignore was specified or not diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 252dc3c72a8..7c9f3757eac 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -1634,8 +1634,8 @@ func CloneRefOfInsert(n *Insert) *Insert { out.Partitions = ClonePartitions(n.Partitions) out.Columns = CloneColumns(n.Columns) out.Rows = CloneInsertRows(n.Rows) - out.OnDup = CloneOnDup(n.OnDup) out.RowAlias = CloneRefOfRowAlias(n.RowAlias) + out.OnDup = CloneOnDup(n.OnDup) return &out } diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index 899e3370acc..daaf781aae0 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -2816,17 +2816,17 @@ func (c *cow) copyOnRewriteRefOfInsert(n *Insert, parent SQLNode) (out SQLNode, _Partitions, changedPartitions := c.copyOnRewritePartitions(n.Partitions, n) _Columns, changedColumns := c.copyOnRewriteColumns(n.Columns, n) _Rows, changedRows := c.copyOnRewriteInsertRows(n.Rows, n) - _OnDup, changedOnDup := c.copyOnRewriteOnDup(n.OnDup, n) _RowAlias, changedRowAlias := c.copyOnRewriteRefOfRowAlias(n.RowAlias, n) - if changedComments || changedTable || changedPartitions || changedColumns || changedRows || changedOnDup || changedRowAlias { + _OnDup, changedOnDup := c.copyOnRewriteOnDup(n.OnDup, n) + if changedComments || changedTable || changedPartitions || changedColumns || changedRows || changedRowAlias || changedOnDup { res := *n res.Comments, _ = _Comments.(*ParsedComments) res.Table, _ = _Table.(*AliasedTableExpr) res.Partitions, _ = _Partitions.(Partitions) res.Columns, _ = _Columns.(Columns) res.Rows, _ = _Rows.(InsertRows) - res.OnDup, _ = _OnDup.(OnDup) res.RowAlias, _ = _RowAlias.(*RowAlias) + res.OnDup, _ = _OnDup.(OnDup) out = &res if c.cloned != nil { c.cloned(n, out) diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index c4066218859..9875fddffb5 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -2896,8 +2896,8 @@ func (cmp *Comparator) RefOfInsert(a, b *Insert) bool { cmp.Partitions(a.Partitions, b.Partitions) && cmp.Columns(a.Columns, b.Columns) && cmp.InsertRows(a.Rows, b.Rows) && - cmp.OnDup(a.OnDup, b.OnDup) && - cmp.RefOfRowAlias(a.RowAlias, b.RowAlias) + cmp.RefOfRowAlias(a.RowAlias, b.RowAlias) && + cmp.OnDup(a.OnDup, b.OnDup) } // RefOfInsertExpr does deep equals between the two objects. diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 56cb5ffd251..e7ac74dc06b 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -3905,13 +3905,13 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer }) { return false } - if !a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { - parent.(*Insert).OnDup = newNode.(OnDup) + if !a.rewriteRefOfRowAlias(node, node.RowAlias, func(newNode, parent SQLNode) { + parent.(*Insert).RowAlias = newNode.(*RowAlias) }) { return false } - if !a.rewriteRefOfRowAlias(node, node.RowAlias, func(newNode, parent SQLNode) { - parent.(*Insert).RowAlias = newNode.(*RowAlias) + if !a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { + parent.(*Insert).OnDup = newNode.(OnDup) }) { return false } diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index f89eb23be6b..f4829192cce 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -2003,10 +2003,10 @@ func VisitRefOfInsert(in *Insert, f Visit) error { if err := VisitInsertRows(in.Rows, f); err != nil { return err } - if err := VisitOnDup(in.OnDup, f); err != nil { + if err := VisitRefOfRowAlias(in.RowAlias, f); err != nil { return err } - if err := VisitRefOfRowAlias(in.RowAlias, f); err != nil { + if err := VisitOnDup(in.OnDup, f); err != nil { return err } return nil diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 361888727b2..9998fb86f66 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -1854,6 +1854,8 @@ func (cached *Insert) CachedSize(alloc bool) int64 { if cc, ok := cached.Rows.(cachedObject); ok { size += cc.CachedSize(true) } + // field RowAlias *vitess.io/vitess/go/vt/sqlparser.RowAlias + size += cached.RowAlias.CachedSize(true) // field OnDup vitess.io/vitess/go/vt/sqlparser.OnDup { size += hack.RuntimeAllocSize(int64(cap(cached.OnDup)) * int64(8)) @@ -1861,8 +1863,6 @@ func (cached *Insert) CachedSize(alloc bool) int64 { size += elem.CachedSize(true) } } - // field RowAlias *vitess.io/vitess/go/vt/sqlparser.RowAlias - size += cached.RowAlias.CachedSize(true) return size } func (cached *InsertExpr) CachedSize(alloc bool) int64 { diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index abd467fbfd0..3c1f202bd8d 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -6752,5 +6752,109 @@ "user.user" ] } + }, + { + "comment": "RowAlias in INSERT", + "query": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1", + "plan": { + "QueryType": "INSERT", + "Original": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1", + "Instructions": { + "OperatorType": "Insert", + "Variant": "Sharded", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "InsertIgnore": true, + "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new on duplicate key update col2 = new.user_id + new.col1", + "TableName": "authoritative", + "VindexValues": { + "user_index": "1, 4" + } + }, + "TablesUsed": [ + "user.authoritative" + ] + } + }, + { + "comment": "RowAlias with explicit columns in INSERT", + "query": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c", + "plan": { + "QueryType": "INSERT", + "Original": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c", + "Instructions": { + "OperatorType": "Insert", + "Variant": "Sharded", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "InsertIgnore": true, + "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new (a, b, c) on duplicate key update col1 = a + c", + "TableName": "authoritative", + "VindexValues": { + "user_index": "1, 4" + } + }, + "TablesUsed": [ + "user.authoritative" + ] + } + }, + { + "comment": "RowAlias in INSERT (no column list)", + "query": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1", + "plan": { + "QueryType": "INSERT", + "Original": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1", + "Instructions": { + "OperatorType": "Insert", + "Variant": "Sharded", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "InsertIgnore": true, + "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new on duplicate key update col2 = new.user_id + new.col1", + "TableName": "authoritative", + "VindexValues": { + "user_index": "1, 4" + } + }, + "TablesUsed": [ + "user.authoritative" + ] + } + }, + { + "comment": "RowAlias with explicit columns in INSERT (no column list)", + "query": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c", + "plan": { + "QueryType": "INSERT", + "Original": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c", + "Instructions": { + "OperatorType": "Insert", + "Variant": "Sharded", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "InsertIgnore": true, + "Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new (a, b, c) on duplicate key update col1 = a + c", + "TableName": "authoritative", + "VindexValues": { + "user_index": "1, 4" + } + }, + "TablesUsed": [ + "user.authoritative" + ] + } } ] diff --git a/go/vt/vtgate/semantics/analyzer_dml_test.go b/go/vt/vtgate/semantics/analyzer_dml_test.go index c792b2301a0..a5885dda6e8 100644 --- a/go/vt/vtgate/semantics/analyzer_dml_test.go +++ b/go/vt/vtgate/semantics/analyzer_dml_test.go @@ -87,3 +87,20 @@ func TestUpdBindingExpr(t *testing.T) { func extractFromUpdateSet(in *sqlparser.Update, idx int) *sqlparser.UpdateExpr { return in.Exprs[idx] } + +func TestInsertBindingColName(t *testing.T) { + queries := []string{ + "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new on duplicate key update texcol = new.uid + new.name", + "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new(x, y, z) on duplicate key update texcol = x + y", + "insert into t2 values (1,'foo','bar') as new(x, y, z) on duplicate key update texcol = x + y", + } + for _, query := range queries { + t.Run(query, func(t *testing.T) { + stmt, semTable := parseAndAnalyzeStrict(t, query, "d") + ins, _ := stmt.(*sqlparser.Insert) + ue := ins.OnDup[0] + ts := semTable.RecursiveDeps(ue.Expr) + assert.Equal(t, SingleTableSet(0), ts) + }) + } +} diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index deb84538740..31975475a61 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -1290,6 +1290,16 @@ func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, * return parse, semTable } +func parseAndAnalyzeStrict(t *testing.T, query, dbName string) (sqlparser.Statement, *SemTable) { + t.Helper() + parse, err := sqlparser.NewTestParser().Parse(query) + require.NoError(t, err) + + semTable, err := AnalyzeStrict(parse, dbName, fakeSchemaInfo()) + require.NoError(t, err) + return parse, semTable +} + func TestSingleUnshardedKeyspace(t *testing.T) { tests := []struct { query string diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index e17a75044ba..1bc79a02566 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -19,6 +19,10 @@ package semantics import ( "fmt" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -64,6 +68,7 @@ func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) { etc.withTables[cte.ID] = nil } } + } func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) { @@ -104,12 +109,99 @@ func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sq etc.Tables = append(etc.Tables, tableInfo) } +func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias) error { + origTableInfo := tc.Tables[0] + + var colNames []string + var types []evalengine.Type + switch { + case len(rowAlias.Columns) > 0 && len(ins.Columns) > 0: + // we have explicit column list on the row alias and the insert statement + if len(rowAlias.Columns) != len(ins.Columns) { + panic("column count mismatch") + } + origCols := origTableInfo.getColumns() + for1: + for idx, column := range rowAlias.Columns { + colNames = append(colNames, column.String()) + col := ins.Columns[idx] + for _, origCol := range origCols { + if col.EqualString(origCol.Name) { + types = append(types, origCol.Type) + continue for1 + } + } + return vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", col) + } + case len(rowAlias.Columns) > 0: + if !origTableInfo.authoritative() { + return vterrors.VT09015() + } + // TODO: we need to handle invisible columns here :sigh: + if len(rowAlias.Columns) != len(origTableInfo.getColumns()) { + panic("column count mismatch") + } + origCols := origTableInfo.getColumns() + for idx, column := range rowAlias.Columns { + colNames = append(colNames, column.String()) + types = append(types, origCols[idx].Type) + } + case len(ins.Columns) > 0: + origCols := origTableInfo.getColumns() + for2: + for _, column := range ins.Columns { + colNames = append(colNames, column.String()) + for _, origCol := range origCols { + if column.EqualString(origCol.Name) { + types = append(types, origCol.Type) + continue for2 + } + } + types = append(types, evalengine.NewType(sqltypes.Unknown, collations.Unknown)) + } + default: + if !origTableInfo.authoritative() { + return vterrors.VT09015() + } + for _, column := range origTableInfo.getColumns() { + colNames = append(colNames, column.Name) + types = append(types, column.Type) + } + } + deps := make([]TableSet, len(colNames)) + for i := range colNames { + deps[i] = SingleTableSet(0) + } + + derivedTable := &DerivedTable{ + tableName: rowAlias.TableName.String(), + ASTNode: &sqlparser.AliasedTableExpr{ + Expr: sqlparser.NewTableName(rowAlias.TableName.String()), + }, + columnNames: colNames, + tables: SingleTableSet(0), + recursive: deps, + isAuthoritative: true, + types: types, + } + + tc.Tables = append(tc.Tables, derivedTable) + current := tc.scoper.currentScope() + return current.addTable(derivedTable) +} + func (tc *tableCollector) up(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case *sqlparser.AliasedTableExpr: return tc.visitAliasedTableExpr(node) case *sqlparser.Union: return tc.visitUnion(node) + case *sqlparser.RowAlias: + ins, ok := cursor.Parent().(*sqlparser.Insert) + if !ok { + return vterrors.VT13001("RowAlias is expected to hang off an Insert statement") + } + return tc.visitRowAlias(ins, node) default: return nil } From c0bc78ab02cd667b2f1c8edc3f967c997baf902a Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 25 Apr 2024 19:13:19 +0530 Subject: [PATCH 2/8] handle invisible columns for insert derived table column list Signed-off-by: Harshit Gangal --- go/vt/vtgate/semantics/analyzer_dml_test.go | 46 +++++++++++-- go/vt/vtgate/semantics/analyzer_test.go | 26 ++++++++ go/vt/vtgate/semantics/derived_table.go | 8 +-- go/vt/vtgate/semantics/early_rewriter.go | 13 +--- go/vt/vtgate/semantics/real_table.go | 71 ++++++++++----------- go/vt/vtgate/semantics/semantic_state.go | 2 +- go/vt/vtgate/semantics/table_collector.go | 11 ++-- go/vt/vtgate/semantics/vindex_table.go | 4 +- go/vt/vtgate/semantics/vtable.go | 2 +- 9 files changed, 115 insertions(+), 68 deletions(-) diff --git a/go/vt/vtgate/semantics/analyzer_dml_test.go b/go/vt/vtgate/semantics/analyzer_dml_test.go index a5885dda6e8..147594d1e80 100644 --- a/go/vt/vtgate/semantics/analyzer_dml_test.go +++ b/go/vt/vtgate/semantics/analyzer_dml_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/vt/sqlparser" ) @@ -90,17 +91,50 @@ func extractFromUpdateSet(in *sqlparser.Update, idx int) *sqlparser.UpdateExpr { func TestInsertBindingColName(t *testing.T) { queries := []string{ - "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new on duplicate key update texcol = new.uid + new.name", - "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new(x, y, z) on duplicate key update texcol = x + y", - "insert into t2 values (1,'foo','bar') as new(x, y, z) on duplicate key update texcol = x + y", + "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new on duplicate key update textcol = new.uid + new.name", + "insert into t2 (uid, name, textcol) values (1,'foo','bar') as new(x, y, z) on duplicate key update textcol = x + y", + "insert into t2 values (1,'foo','bar') as new(x, y, z) on duplicate key update textcol = x + y", + "insert into t3(uid, name, invcol) values (1,'foo','bar') as new on duplicate key update textcol = new.invcol", + "insert into t3 values (1,'foo','bar') as new on duplicate key update textcol = new.uid+new.name+new.textcol", + "insert into t3 values (1,'foo','bar') as new on duplicate key update textcol = new.uid+new.name+new.textcol, uid = new.name", } for _, query := range queries { t.Run(query, func(t *testing.T) { stmt, semTable := parseAndAnalyzeStrict(t, query, "d") ins, _ := stmt.(*sqlparser.Insert) - ue := ins.OnDup[0] - ts := semTable.RecursiveDeps(ue.Expr) - assert.Equal(t, SingleTableSet(0), ts) + for _, ue := range ins.OnDup { + // check deps on the column + ts := semTable.RecursiveDeps(ue.Name) + assert.Equal(t, SingleTableSet(0), ts) + // check deps on the expression + ts = semTable.RecursiveDeps(ue.Expr) + assert.Equal(t, SingleTableSet(0), ts) + } + }) + } +} + +func TestInsertBindingColNameErrorCases(t *testing.T) { + tcases := []struct { + query string + expErr string + }{{ + "insert into t2 values (1,'foo','bar') as new on duplicate key update textcol = new.unknowncol", + "column 'new.unknowncol' not found", + }, { + "insert into t3 values (1,'foo','bar', 'baz') as new on duplicate key update textcol = new.invcol", + "column 'new.invcol' not found", + }, { + "insert into t3(uid, name) values (1,'foo') as new(x, y, z) on duplicate key update textcol = x + y", + "column 'new.invcol' not found", + }} + for _, tc := range tcases { + t.Run(tc.query, func(t *testing.T) { + parse, err := sqlparser.NewTestParser().Parse(tc.query) + require.NoError(t, err) + + _, err = AnalyzeStrict(parse, "d", fakeSchemaInfo()) + require.ErrorContains(t, err, tc.expErr) }) } } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 31975475a61..38cbd80aa9d 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -1403,6 +1403,7 @@ func fakeSchemaInfo() *FakeSI { "t": tableT(), "t1": tableT1(), "t2": tableT2(), + "t3": tableT3(), }, } return si @@ -1447,3 +1448,28 @@ func tableT2() *vindexes.Table { Keyspace: ks3, } } + +func tableT3() *vindexes.Table { + return &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t3"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("uid"), + Type: querypb.Type_INT64, + }, { + Name: sqlparser.NewIdentifierCI("name"), + Type: querypb.Type_VARCHAR, + CollationName: "utf8_bin", + }, { + Name: sqlparser.NewIdentifierCI("textcol"), + Type: querypb.Type_VARCHAR, + CollationName: "big5_bin", + }, { + Name: sqlparser.NewIdentifierCI("invcol"), + Type: querypb.Type_VARCHAR, + CollationName: "big5_bin", + Invisible: true, + }}, + ColumnListAuthoritative: true, + Keyspace: ks3, + } +} diff --git a/go/vt/vtgate/semantics/derived_table.go b/go/vt/vtgate/semantics/derived_table.go index 0425d78ed93..aabbe9f0b22 100644 --- a/go/vt/vtgate/semantics/derived_table.go +++ b/go/vt/vtgate/semantics/derived_table.go @@ -116,7 +116,7 @@ func (dt *DerivedTable) dependencies(colName string, org originable) (dependenci return createCertain(directDeps, recursiveDeps, qt), nil } - if !dt.hasStar() { + if dt.authoritative() { return ¬hing{}, nil } @@ -154,7 +154,7 @@ func (dt *DerivedTable) GetVindexTable() *vindexes.Table { return nil } -func (dt *DerivedTable) getColumns() []ColumnInfo { +func (dt *DerivedTable) getColumns(bool) []ColumnInfo { cols := make([]ColumnInfo, 0, len(dt.columnNames)) for _, col := range dt.columnNames { cols = append(cols, ColumnInfo{ @@ -164,10 +164,6 @@ func (dt *DerivedTable) getColumns() []ColumnInfo { return cols } -func (dt *DerivedTable) hasStar() bool { - return dt.tables.NotEmpty() -} - // GetTables implements the TableInfo interface func (dt *DerivedTable) getTableSet(_ originable) TableSet { return dt.tables diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 61abd9c3fa7..10466234798 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -1031,7 +1031,7 @@ func findOnlyOneTableInfoThatHasColumn(b *binder, tbl sqlparser.TableExpr, colum case *sqlparser.AliasedTableExpr: ts := b.tc.tableSetFor(tbl) tblInfo := b.tc.Tables[ts.TableOffset()] - for _, info := range tblInfo.getColumns() { + for _, info := range tblInfo.getColumns(false /* ignoreInvisibleCol */) { if column.EqualString(info.Name) { return []TableInfo{tblInfo}, nil } @@ -1188,10 +1188,7 @@ func (e *expanderState) processColumnsFor(tbl TableInfo) error { outer: // in this first loop we just find columns used in any JOIN USING used on this table - for _, col := range tbl.getColumns() { - if col.Invisible { - continue - } + for _, col := range tbl.getColumns(true /* ignoreInvisibleCol */) { ts, found := usingCols[col.Name] if found { for i, ts := range ts.Constituents() { @@ -1207,11 +1204,7 @@ outer: } // and this time around we are printing any columns not involved in any JOIN USING - for _, col := range tbl.getColumns() { - if col.Invisible { - continue - } - + for _, col := range tbl.getColumns(true /* ignoreInvisibleCol */) { if ts, found := usingCols[col.Name]; found && currTable.IsSolvedBy(ts) { continue } diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index a8c3d699b59..4f1639d0897 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -41,7 +41,7 @@ var _ TableInfo = (*RealTable)(nil) // dependencies implements the TableInfo interface func (r *RealTable) dependencies(colName string, org originable) (dependencies, error) { ts := org.tableSetFor(r.ASTNode) - for _, info := range r.getColumns() { + for _, info := range r.getColumns(false /* ignoreInvisbleCol */) { if strings.EqualFold(info.Name, colName) { return createCertain(ts, ts, info.Type), nil } @@ -69,8 +69,40 @@ func (r *RealTable) IsInfSchema() bool { } // GetColumns implements the TableInfo interface -func (r *RealTable) getColumns() []ColumnInfo { - return vindexTableToColumnInfo(r.Table, r.collationEnv) +func (r *RealTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo { + if r.Table == nil { + return nil + } + nameMap := map[string]any{} + cols := make([]ColumnInfo, 0, len(r.Table.Columns)) + for _, col := range r.Table.Columns { + if col.Invisible && ignoreInvisbleCol { + continue + } + cols = append(cols, ColumnInfo{ + Name: col.Name.String(), + Type: col.ToEvalengineType(r.collationEnv), + Invisible: col.Invisible, + }) + nameMap[col.Name.String()] = nil + } + // If table is authoritative, we do not need ColumnVindexes to help in resolving the unqualified columns. + if r.Table.ColumnListAuthoritative { + return cols + } + for _, vindex := range r.Table.ColumnVindexes { + for _, column := range vindex.Columns { + name := column.String() + if _, exists := nameMap[name]; exists { + continue + } + cols = append(cols, ColumnInfo{ + Name: name, + }) + nameMap[name] = nil + } + } + return cols } // GetExpr implements the TableInfo interface @@ -122,36 +154,3 @@ func (r *RealTable) authoritative() bool { func (r *RealTable) matches(name sqlparser.TableName) bool { return (name.Qualifier.IsEmpty() || name.Qualifier.String() == r.dbName) && r.tableName == name.Name.String() } - -func vindexTableToColumnInfo(tbl *vindexes.Table, collationEnv *collations.Environment) []ColumnInfo { - if tbl == nil { - return nil - } - nameMap := map[string]any{} - cols := make([]ColumnInfo, 0, len(tbl.Columns)) - for _, col := range tbl.Columns { - cols = append(cols, ColumnInfo{ - Name: col.Name.String(), - Type: col.ToEvalengineType(collationEnv), - Invisible: col.Invisible, - }) - nameMap[col.Name.String()] = nil - } - // If table is authoritative, we do not need ColumnVindexes to help in resolving the unqualified columns. - if tbl.ColumnListAuthoritative { - return cols - } - for _, vindex := range tbl.ColumnVindexes { - for _, column := range vindex.Columns { - name := column.String() - if _, exists := nameMap[name]; exists { - continue - } - cols = append(cols, ColumnInfo{ - Name: name, - }) - nameMap[name] = nil - } - } - return cols -} diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 6c89b2bb999..66b0f99035a 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -58,7 +58,7 @@ type ( canShortCut() shortCut // getColumns returns the known column information for this table - getColumns() []ColumnInfo + getColumns(ignoreInvisibleCol bool) []ColumnInfo dependencies(colName string, org originable) (dependencies, error) getExprFor(s string) (sqlparser.Expr, error) diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 1bc79a02566..945f652bf21 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -120,7 +120,7 @@ func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlpars if len(rowAlias.Columns) != len(ins.Columns) { panic("column count mismatch") } - origCols := origTableInfo.getColumns() + origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */) for1: for idx, column := range rowAlias.Columns { colNames = append(colNames, column.String()) @@ -137,17 +137,16 @@ func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlpars if !origTableInfo.authoritative() { return vterrors.VT09015() } - // TODO: we need to handle invisible columns here :sigh: - if len(rowAlias.Columns) != len(origTableInfo.getColumns()) { + origCols := origTableInfo.getColumns(true /* ignoreInvisibleCol */) + if len(rowAlias.Columns) != len(origCols) { panic("column count mismatch") } - origCols := origTableInfo.getColumns() for idx, column := range rowAlias.Columns { colNames = append(colNames, column.String()) types = append(types, origCols[idx].Type) } case len(ins.Columns) > 0: - origCols := origTableInfo.getColumns() + origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */) for2: for _, column := range ins.Columns { colNames = append(colNames, column.String()) @@ -163,7 +162,7 @@ func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlpars if !origTableInfo.authoritative() { return vterrors.VT09015() } - for _, column := range origTableInfo.getColumns() { + for _, column := range origTableInfo.getColumns(true /* ignoreInvisibleCol */) { colNames = append(colNames, column.Name) types = append(types, column.Type) } diff --git a/go/vt/vtgate/semantics/vindex_table.go b/go/vt/vtgate/semantics/vindex_table.go index fba8f8ab9a0..b598c93f36a 100644 --- a/go/vt/vtgate/semantics/vindex_table.go +++ b/go/vt/vtgate/semantics/vindex_table.go @@ -76,8 +76,8 @@ func (v *VindexTable) canShortCut() shortCut { } // GetColumns implements the TableInfo interface -func (v *VindexTable) getColumns() []ColumnInfo { - return v.Table.getColumns() +func (v *VindexTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo { + return v.Table.getColumns(ignoreInvisbleCol) } // IsInfSchema implements the TableInfo interface diff --git a/go/vt/vtgate/semantics/vtable.go b/go/vt/vtgate/semantics/vtable.go index 81f81de3813..14519a7e938 100644 --- a/go/vt/vtgate/semantics/vtable.go +++ b/go/vt/vtgate/semantics/vtable.go @@ -104,7 +104,7 @@ func (v *vTableInfo) GetVindexTable() *vindexes.Table { return nil } -func (v *vTableInfo) getColumns() []ColumnInfo { +func (v *vTableInfo) getColumns(bool) []ColumnInfo { cols := make([]ColumnInfo, 0, len(v.columnNames)) for _, col := range v.columnNames { cols = append(cols, ColumnInfo{ From f710bc9912147db17186ad8b361665d982404580 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 25 Apr 2024 19:24:30 +0530 Subject: [PATCH 3/8] return error instead of panic for wrong column counts Signed-off-by: Harshit Gangal --- go/mysql/sqlerror/constants.go | 1 + go/mysql/sqlerror/sql_error.go | 1 + go/vt/vterrors/code.go | 2 ++ go/vt/vterrors/state.go | 1 + go/vt/vtgate/semantics/analyzer_dml_test.go | 2 +- go/vt/vtgate/semantics/table_collector.go | 4 ++-- 6 files changed, 8 insertions(+), 3 deletions(-) diff --git a/go/mysql/sqlerror/constants.go b/go/mysql/sqlerror/constants.go index fdec64588c1..a247ca15aa4 100644 --- a/go/mysql/sqlerror/constants.go +++ b/go/mysql/sqlerror/constants.go @@ -235,6 +235,7 @@ const ( ERUnknownTimeZone = ErrorCode(1298) ERInvalidCharacterString = ErrorCode(1300) ERQueryInterrupted = ErrorCode(1317) + ERViewWrongList = ErrorCode(1353) ERTruncatedWrongValueForField = ErrorCode(1366) ERIllegalValueForType = ErrorCode(1367) ERDataTooLong = ErrorCode(1406) diff --git a/go/mysql/sqlerror/sql_error.go b/go/mysql/sqlerror/sql_error.go index bebd9e41ca7..935fd77a12f 100644 --- a/go/mysql/sqlerror/sql_error.go +++ b/go/mysql/sqlerror/sql_error.go @@ -216,6 +216,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{ vterrors.OperandColumns: {num: EROperandColumns, state: SSWrongNumberOfColumns}, vterrors.WrongValueCountOnRow: {num: ERWrongValueCountOnRow, state: SSWrongValueCountOnRow}, vterrors.WrongArguments: {num: ERWrongArguments, state: SSUnknownSQLState}, + vterrors.ViewWrongList: {num: ERViewWrongList, state: SSUnknownSQLState}, vterrors.UnknownStmtHandler: {num: ERUnknownStmtHandler, state: SSUnknownSQLState}, vterrors.KeyDoesNotExist: {num: ERKeyDoesNotExist, state: SSClientError}, vterrors.UnknownTimeZone: {num: ERUnknownTimeZone, state: SSUnknownSQLState}, diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index ffce4fc553d..d485c930b77 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -58,6 +58,7 @@ var ( VT03030 = errorWithState("VT03030", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "lookup column count does not match value count with the row (columns, count): (%v, %d)", "The number of columns you want to insert do not match the number of columns of your SELECT query.") VT03031 = errorWithoutState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, "EXPLAIN is only supported for single keyspace", "EXPLAIN has to be sent down as a single query to the underlying MySQL, and this is not possible if it uses tables from multiple keyspaces") VT03032 = errorWithState("VT03032", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.") + VT03033 = errorWithState("VT03033", vtrpcpb.Code_INVALID_ARGUMENT, ViewWrongList, "In definition of view, derived table or common table expression, SELECT list and column names list have different column counts", "The table column list and derived column list have different column counts.") VT05001 = errorWithState("VT05001", vtrpcpb.Code_NOT_FOUND, DbDropExists, "cannot drop database '%s'; database does not exists", "The given database does not exist; Vitess cannot drop it.") VT05002 = errorWithState("VT05002", vtrpcpb.Code_NOT_FOUND, BadDb, "cannot alter database '%s'; unknown database", "The given database does not exist; Vitess cannot alter it.") @@ -146,6 +147,7 @@ var ( VT03030, VT03031, VT03032, + VT03033, VT05001, VT05002, VT05003, diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 2b0ada0bc6d..8223405fc92 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -49,6 +49,7 @@ const ( WrongArguments BadNullError InvalidGroupFuncUse + ViewWrongList // failed precondition NoDB diff --git a/go/vt/vtgate/semantics/analyzer_dml_test.go b/go/vt/vtgate/semantics/analyzer_dml_test.go index 147594d1e80..3e50f98f77a 100644 --- a/go/vt/vtgate/semantics/analyzer_dml_test.go +++ b/go/vt/vtgate/semantics/analyzer_dml_test.go @@ -126,7 +126,7 @@ func TestInsertBindingColNameErrorCases(t *testing.T) { "column 'new.invcol' not found", }, { "insert into t3(uid, name) values (1,'foo') as new(x, y, z) on duplicate key update textcol = x + y", - "column 'new.invcol' not found", + "VT03033: In definition of view, derived table or common table expression, SELECT list and column names list have different column counts", }} for _, tc := range tcases { t.Run(tc.query, func(t *testing.T) { diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 945f652bf21..4740b1e5bf6 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -118,7 +118,7 @@ func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlpars case len(rowAlias.Columns) > 0 && len(ins.Columns) > 0: // we have explicit column list on the row alias and the insert statement if len(rowAlias.Columns) != len(ins.Columns) { - panic("column count mismatch") + return vterrors.VT03033() } origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */) for1: @@ -139,7 +139,7 @@ func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlpars } origCols := origTableInfo.getColumns(true /* ignoreInvisibleCol */) if len(rowAlias.Columns) != len(origCols) { - panic("column count mismatch") + return vterrors.VT03033() } for idx, column := range rowAlias.Columns { colNames = append(colNames, column.String()) From f9a607cd3a2b71f6b62c2b2077361cb9ec75dc0c Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 25 Apr 2024 21:39:00 +0530 Subject: [PATCH 4/8] notes: added release notes Signed-off-by: Harshit Gangal --- changelog/20.0/20.0.0/summary.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/changelog/20.0/20.0.0/summary.md b/changelog/20.0/20.0.0/summary.md index 9421018bc9c..8b560fe283f 100644 --- a/changelog/20.0/20.0.0/summary.md +++ b/changelog/20.0/20.0.0/summary.md @@ -22,6 +22,7 @@ - [Delete with Subquery Support](#delete-subquery) - [Delete with Multi Target Support](#delete-multi-target) - [User Defined Functions Support](#udf-support) + - [Insert Row Alias Support](#insert-row-alias-support) - **[Query Timeout](#query-timeout)** - **[Flag changes](#flag-changes)** - [`pprof-http` default change](#pprof-http-default) @@ -197,6 +198,16 @@ Without this flag, VTGate will not be aware that there might be aggregating user More details about how to load UDFs is available in [MySQL Docs](https://dev.mysql.com/doc/extending-mysql/8.0/en/adding-loadable-function.html) +#### Insert Row Alias Support + +Support is added to have row alias in Insert statement to be used with `on duplicate key update`. + +Example: +- `insert into user(id, name, email) valies (100, 'Alice', 'alice@mail.com') as new on duplicate key update name = new.name, email = new.email` +- `insert into user(id, name, email) valies (100, 'Alice', 'alice@mail.com') as new(m, n, p) on duplicate key update name = n, email = p` + +More details about how it works is available in [MySQL Docs](https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html) + ### Query Timeout On a query timeout, Vitess closed the connection using the `kill connection` statement. This leads to connection churn which is not desirable in some cases. To avoid this, Vitess now uses the `kill query` statement to cancel the query. From 70e21140c1155daf85f4e97cd8d174fefaf3589a Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 25 Apr 2024 23:23:09 +0530 Subject: [PATCH 5/8] fix: insert sharded query to include alias in the formatted query Signed-off-by: Harshit Gangal --- .../vtgate/queries/dml/insert_test.go | 23 +++++++++++++++++++ go/vt/vtgate/engine/insert.go | 18 ++++++++++++++- .../planbuilder/operator_transformers.go | 5 +++- 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/dml/insert_test.go b/go/test/endtoend/vtgate/queries/dml/insert_test.go index ce052b7b2ba..91309339e02 100644 --- a/go/test/endtoend/vtgate/queries/dml/insert_test.go +++ b/go/test/endtoend/vtgate/queries/dml/insert_test.go @@ -462,3 +462,26 @@ func TestMixedCases(t *testing.T) { // final check count on the lookup vindex table. utils.AssertMatches(t, mcmp.VtConn, "select count(*) from lkp_mixed_idx", "[[INT64(12)]]") } + +// TestInsertAlias test the alias feature in insert statement. +func TestInsertAlias(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + // initial record + mcmp.Exec("insert into user_tbl(id, region_id, name) values (1, 1,'foo'),(2, 2,'bar'),(3, 3,'baz'),(4, 4,'buzz')") + + qr := mcmp.Exec("insert into user_tbl(id, region_id, name) values (2, 2, 'foo') as new on duplicate key update name = new.name") + assert.EqualValues(t, 2, qr.RowsAffected) + + // this validates the record. + mcmp.Exec("select id, region_id, name from user_tbl order by id") + + qr = mcmp.Exec("insert into user_tbl(id, region_id, name) values (3, 3, 'foo') as new(m, n, p) on duplicate key update name = p") + assert.EqualValues(t, 2, qr.RowsAffected) + + // this validates the record. + mcmp.Exec("select id, region_id, name from user_tbl order by id") +} diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 332ccc92098..c23c85d132f 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -55,6 +55,9 @@ type Insert struct { // Mid is the row values for the sharded insert plans. Mid sqlparser.Values + + // Alias represents the row alias with columns if specified in the query. + Alias string } // newQueryInsert creates an Insert with a query string. @@ -287,7 +290,7 @@ func (ins *Insert) getInsertShardedQueries( } } } - rewritten := ins.Prefix + strings.Join(mids, ",") + sqlparser.String(ins.Suffix) + rewritten := ins.Prefix + strings.Join(mids, ",") + ins.Alias + sqlparser.String(ins.Suffix) queries[i] = &querypb.BoundQuery{ Sql: rewritten, BindVariables: shardBindVars, @@ -363,6 +366,19 @@ func (ins *Insert) description() PrimitiveDescription { other["VindexValues"] = valuesOffsets } + // This is a check to ensure we send the correct query to the database. + // "ActualQuery" should not be part of the plan output, if it does, it means the query was not rewritten correctly. + if ins.Mid != nil { + var mids []string + for _, n := range ins.Mid { + mids = append(mids, sqlparser.String(n)) + } + shardedQuery := ins.Prefix + strings.Join(mids, ", ") + ins.Alias + sqlparser.String(ins.Suffix) + if shardedQuery != ins.Query { + other["ActualQuery"] = shardedQuery + } + } + return PrimitiveDescription{ OperatorType: "Insert", Keyspace: ins.Keyspace, diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 2a7f37a258f..a324592af8e 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -619,6 +619,9 @@ func buildInsertLogicalPlan( // when unsharded query with autoincrement for that there is no input operator. if eins.Opcode != engine.InsertUnsharded { eins.Prefix, eins.Mid, eins.Suffix = generateInsertShardedQuery(ins.AST) + if ins.AST.RowAlias != nil { + eins.Alias = sqlparser.String(ins.AST.RowAlias) + } } eins.Query = generateQuery(stmt) @@ -660,7 +663,7 @@ func generateInsertShardedQuery(ins *sqlparser.Insert) (prefix string, mids sqlp prefixBuf := sqlparser.NewTrackedBuffer(dmlFormatter) prefixBuf.Myprintf(prefixFormat, ins.Comments, ins.Ignore.ToString(), - ins.Table, ins.Columns) + ins.Table, ins.Columns, ins.RowAlias) prefix = prefixBuf.String() suffix = sqlparser.CopyOnRewrite(ins.OnDup, nil, func(cursor *sqlparser.CopyOnWriteCursor) { From 5b43495b4cfc556d7fedcd309ba803153cf4d4d8 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 26 Apr 2024 10:18:52 +0530 Subject: [PATCH 6/8] test: skip test for vttablet and vtgate not on v20 Signed-off-by: Harshit Gangal --- go/test/endtoend/vtgate/queries/dml/insert_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/test/endtoend/vtgate/queries/dml/insert_test.go b/go/test/endtoend/vtgate/queries/dml/insert_test.go index 91309339e02..771eb64ab02 100644 --- a/go/test/endtoend/vtgate/queries/dml/insert_test.go +++ b/go/test/endtoend/vtgate/queries/dml/insert_test.go @@ -466,6 +466,7 @@ func TestMixedCases(t *testing.T) { // TestInsertAlias test the alias feature in insert statement. func TestInsertAlias(t *testing.T) { utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + utils.SkipIfBinaryIsBelowVersion(t, 20, "vttablet") mcmp, closer := start(t) defer closer() From e7782d1605b3381ca1c73ca118f4dde1ad114329 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 26 Apr 2024 10:28:50 +0530 Subject: [PATCH 7/8] sizgen run Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/cached_size.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 22b3a38a990..18e22c00378 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -440,7 +440,7 @@ func (cached *Insert) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(208) + size += int64(224) } // field InsertCommon vitess.io/vitess/go/vt/vtgate/engine.InsertCommon size += cached.InsertCommon.CachedSize(false) @@ -479,6 +479,8 @@ func (cached *Insert) CachedSize(alloc bool) int64 { } } } + // field Alias string + size += hack.RuntimeAllocSize(int64(len(cached.Alias))) return size } func (cached *InsertCommon) CachedSize(alloc bool) int64 { From d1c396bf022a77aec051f55a6beda4554ffae9f0 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 26 Apr 2024 13:09:39 +0530 Subject: [PATCH 8/8] split the rowalias table collector method to smaller methods Signed-off-by: Harshit Gangal --- go/vt/vtgate/semantics/table_collector.go | 215 +++++++++++++--------- 1 file changed, 128 insertions(+), 87 deletions(-) diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 4740b1e5bf6..ae107cc070c 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -109,86 +109,6 @@ func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sq etc.Tables = append(etc.Tables, tableInfo) } -func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias) error { - origTableInfo := tc.Tables[0] - - var colNames []string - var types []evalengine.Type - switch { - case len(rowAlias.Columns) > 0 && len(ins.Columns) > 0: - // we have explicit column list on the row alias and the insert statement - if len(rowAlias.Columns) != len(ins.Columns) { - return vterrors.VT03033() - } - origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */) - for1: - for idx, column := range rowAlias.Columns { - colNames = append(colNames, column.String()) - col := ins.Columns[idx] - for _, origCol := range origCols { - if col.EqualString(origCol.Name) { - types = append(types, origCol.Type) - continue for1 - } - } - return vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", col) - } - case len(rowAlias.Columns) > 0: - if !origTableInfo.authoritative() { - return vterrors.VT09015() - } - origCols := origTableInfo.getColumns(true /* ignoreInvisibleCol */) - if len(rowAlias.Columns) != len(origCols) { - return vterrors.VT03033() - } - for idx, column := range rowAlias.Columns { - colNames = append(colNames, column.String()) - types = append(types, origCols[idx].Type) - } - case len(ins.Columns) > 0: - origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */) - for2: - for _, column := range ins.Columns { - colNames = append(colNames, column.String()) - for _, origCol := range origCols { - if column.EqualString(origCol.Name) { - types = append(types, origCol.Type) - continue for2 - } - } - types = append(types, evalengine.NewType(sqltypes.Unknown, collations.Unknown)) - } - default: - if !origTableInfo.authoritative() { - return vterrors.VT09015() - } - for _, column := range origTableInfo.getColumns(true /* ignoreInvisibleCol */) { - colNames = append(colNames, column.Name) - types = append(types, column.Type) - } - } - deps := make([]TableSet, len(colNames)) - for i := range colNames { - deps[i] = SingleTableSet(0) - } - - derivedTable := &DerivedTable{ - tableName: rowAlias.TableName.String(), - ASTNode: &sqlparser.AliasedTableExpr{ - Expr: sqlparser.NewTableName(rowAlias.TableName.String()), - }, - columnNames: colNames, - tables: SingleTableSet(0), - recursive: deps, - isAuthoritative: true, - types: types, - } - - tc.Tables = append(tc.Tables, derivedTable) - current := tc.scoper.currentScope() - return current.addTable(derivedTable) -} - func (tc *tableCollector) up(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case *sqlparser.AliasedTableExpr: @@ -206,6 +126,16 @@ func (tc *tableCollector) up(cursor *sqlparser.Cursor) error { } } +func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr) error { + switch t := node.Expr.(type) { + case *sqlparser.DerivedTable: + return tc.handleDerivedTable(node, t) + case sqlparser.TableName: + return tc.handleTableName(node, t) + } + return nil +} + func (tc *tableCollector) visitUnion(union *sqlparser.Union) error { firstSelect := sqlparser.GetFirstSelect(union) expanded, selectExprs := getColumnNames(firstSelect.SelectExprs) @@ -248,15 +178,126 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error { return nil } -func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr) error { - switch t := node.Expr.(type) { - case *sqlparser.DerivedTable: - return tc.handleDerivedTable(node, t) +func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias) error { + origTableInfo := tc.Tables[0] - case sqlparser.TableName: - return tc.handleTableName(node, t) + colNames, types, err := tc.getColumnNamesAndTypes(ins, rowAlias, origTableInfo) + if err != nil { + return err } - return nil + + derivedTable := buildDerivedTable(colNames, rowAlias, types) + tc.Tables = append(tc.Tables, derivedTable) + current := tc.scoper.currentScope() + return current.addTable(derivedTable) +} + +func (tc *tableCollector) getColumnNamesAndTypes(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias, origTableInfo TableInfo) (colNames []string, types []evalengine.Type, err error) { + switch { + case len(rowAlias.Columns) > 0 && len(ins.Columns) > 0: + return tc.handleExplicitColumns(ins, rowAlias, origTableInfo) + case len(rowAlias.Columns) > 0: + return tc.handleRowAliasColumns(origTableInfo, rowAlias) + case len(ins.Columns) > 0: + colNames, types = tc.handleInsertColumns(ins, origTableInfo) + return colNames, types, nil + default: + return tc.handleDefaultColumns(origTableInfo) + } +} + +// handleDefaultColumns have no explicit column list on the insert statement and no column list on the row alias +func (tc *tableCollector) handleDefaultColumns(origTableInfo TableInfo) ([]string, []evalengine.Type, error) { + if !origTableInfo.authoritative() { + return nil, nil, vterrors.VT09015() + } + var colNames []string + var types []evalengine.Type + for _, column := range origTableInfo.getColumns(true /* ignoreInvisibleCol */) { + colNames = append(colNames, column.Name) + types = append(types, column.Type) + } + return colNames, types, nil +} + +// handleInsertColumns have explicit column list on the insert statement and no column list on the row alias +func (tc *tableCollector) handleInsertColumns(ins *sqlparser.Insert, origTableInfo TableInfo) ([]string, []evalengine.Type) { + var colNames []string + var types []evalengine.Type + origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */) +for2: + for _, column := range ins.Columns { + colNames = append(colNames, column.String()) + for _, origCol := range origCols { + if column.EqualString(origCol.Name) { + types = append(types, origCol.Type) + continue for2 + } + } + types = append(types, evalengine.NewType(sqltypes.Unknown, collations.Unknown)) + } + return colNames, types +} + +// handleRowAliasColumns have explicit column list on the row alias and no column list on the insert statement +func (tc *tableCollector) handleRowAliasColumns(origTableInfo TableInfo, rowAlias *sqlparser.RowAlias) ([]string, []evalengine.Type, error) { + if !origTableInfo.authoritative() { + return nil, nil, vterrors.VT09015() + } + origCols := origTableInfo.getColumns(true /* ignoreInvisibleCol */) + if len(rowAlias.Columns) != len(origCols) { + return nil, nil, vterrors.VT03033() + } + var colNames []string + var types []evalengine.Type + for idx, column := range rowAlias.Columns { + colNames = append(colNames, column.String()) + types = append(types, origCols[idx].Type) + } + return colNames, types, nil +} + +// handleExplicitColumns have explicit column list on the row alias and the insert statement +func (tc *tableCollector) handleExplicitColumns(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias, origTableInfo TableInfo) ([]string, []evalengine.Type, error) { + if len(rowAlias.Columns) != len(ins.Columns) { + return nil, nil, vterrors.VT03033() + } + var colNames []string + var types []evalengine.Type + origCols := origTableInfo.getColumns(false /* ignoreInvisbleCol */) +for1: + for idx, column := range rowAlias.Columns { + colNames = append(colNames, column.String()) + col := ins.Columns[idx] + for _, origCol := range origCols { + if col.EqualString(origCol.Name) { + types = append(types, origCol.Type) + continue for1 + } + } + return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", col) + } + return colNames, types, nil +} + +func buildDerivedTable(colNames []string, rowAlias *sqlparser.RowAlias, types []evalengine.Type) *DerivedTable { + deps := make([]TableSet, len(colNames)) + for i := range colNames { + deps[i] = SingleTableSet(0) + } + + derivedTable := &DerivedTable{ + tableName: rowAlias.TableName.String(), + ASTNode: &sqlparser.AliasedTableExpr{ + Expr: sqlparser.NewTableName(rowAlias.TableName.String()), + }, + columnNames: colNames, + tables: SingleTableSet(0), + recursive: deps, + isAuthoritative: true, + types: types, + } + return derivedTable } func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) (err error) {