Skip to content

Commit

Permalink
feat: handle INSERT with RowAlias
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Apr 25, 2024
1 parent 7ca2b81 commit 464ed1a
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 15 deletions.
2 changes: 1 addition & 1 deletion go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ type (
Partitions Partitions
Columns Columns
Rows InsertRows
OnDup OnDup
RowAlias *RowAlias
OnDup OnDup
}

// Ignore represents whether ignore was specified or not
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/ast_clone.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions go/vt/sqlparser/ast_copy_on_rewrite.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/sqlparser/ast_equals.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions go/vt/sqlparser/ast_rewrite.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/sqlparser/ast_visit.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/sqlparser/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 104 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/dml_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -6752,5 +6752,109 @@
"user.user"
]
}
},
{
"comment": "RowAlias in INSERT",
"query": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1",
"plan": {
"QueryType": "INSERT",
"Original": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1",
"Instructions": {
"OperatorType": "Insert",
"Variant": "Sharded",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"TargetTabletType": "PRIMARY",
"InsertIgnore": true,
"Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new on duplicate key update col2 = new.user_id + new.col1",
"TableName": "authoritative",
"VindexValues": {
"user_index": "1, 4"
}
},
"TablesUsed": [
"user.authoritative"
]
}
},
{
"comment": "RowAlias with explicit columns in INSERT",
"query": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c",
"plan": {
"QueryType": "INSERT",
"Original": "INSERT INTO authoritative (user_id,col1,col2) VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c",
"Instructions": {
"OperatorType": "Insert",
"Variant": "Sharded",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"TargetTabletType": "PRIMARY",
"InsertIgnore": true,
"Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new (a, b, c) on duplicate key update col1 = a + c",
"TableName": "authoritative",
"VindexValues": {
"user_index": "1, 4"
}
},
"TablesUsed": [
"user.authoritative"
]
}
},
{
"comment": "RowAlias in INSERT (no column list)",
"query": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1",
"plan": {
"QueryType": "INSERT",
"Original": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new ON DUPLICATE KEY UPDATE col2 = new.user_id+new.col1",
"Instructions": {
"OperatorType": "Insert",
"Variant": "Sharded",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"TargetTabletType": "PRIMARY",
"InsertIgnore": true,
"Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new on duplicate key update col2 = new.user_id + new.col1",
"TableName": "authoritative",
"VindexValues": {
"user_index": "1, 4"
}
},
"TablesUsed": [
"user.authoritative"
]
}
},
{
"comment": "RowAlias with explicit columns in INSERT (no column list)",
"query": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c",
"plan": {
"QueryType": "INSERT",
"Original": "INSERT INTO authoritative VALUES (1,'2',3),(4,'5',6) AS new(a,b,c) ON DUPLICATE KEY UPDATE col1 = a+c",
"Instructions": {
"OperatorType": "Insert",
"Variant": "Sharded",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"TargetTabletType": "PRIMARY",
"InsertIgnore": true,
"Query": "insert into authoritative(user_id, col1, col2) values (:_user_id_0, '2', 3), (:_user_id_1, '5', 6) as new (a, b, c) on duplicate key update col1 = a + c",
"TableName": "authoritative",
"VindexValues": {
"user_index": "1, 4"
}
},
"TablesUsed": [
"user.authoritative"
]
}
}
]
17 changes: 17 additions & 0 deletions go/vt/vtgate/semantics/analyzer_dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,20 @@ func TestUpdBindingExpr(t *testing.T) {
func extractFromUpdateSet(in *sqlparser.Update, idx int) *sqlparser.UpdateExpr {
return in.Exprs[idx]
}

func TestInsertBindingColName(t *testing.T) {
queries := []string{
"insert into t2 (uid, name, textcol) values (1,'foo','bar') as new on duplicate key update texcol = new.uid + new.name",
"insert into t2 (uid, name, textcol) values (1,'foo','bar') as new(x, y, z) on duplicate key update texcol = x + y",
"insert into t2 values (1,'foo','bar') as new(x, y, z) on duplicate key update texcol = x + y",
}
for _, query := range queries {
t.Run(query, func(t *testing.T) {
stmt, semTable := parseAndAnalyzeStrict(t, query, "d")
ins, _ := stmt.(*sqlparser.Insert)
ue := ins.OnDup[0]
ts := semTable.RecursiveDeps(ue.Expr)
assert.Equal(t, SingleTableSet(0), ts)
})
}
}
10 changes: 10 additions & 0 deletions go/vt/vtgate/semantics/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,16 @@ func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, *
return parse, semTable
}

func parseAndAnalyzeStrict(t *testing.T, query, dbName string) (sqlparser.Statement, *SemTable) {
t.Helper()
parse, err := sqlparser.NewTestParser().Parse(query)
require.NoError(t, err)

semTable, err := AnalyzeStrict(parse, dbName, fakeSchemaInfo())
require.NoError(t, err)
return parse, semTable
}

func TestSingleUnshardedKeyspace(t *testing.T) {
tests := []struct {
query string
Expand Down
92 changes: 92 additions & 0 deletions go/vt/vtgate/semantics/table_collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ package semantics
import (
"fmt"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"

querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -64,6 +68,7 @@ func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) {
etc.withTables[cte.ID] = nil
}
}

}

func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) {
Expand Down Expand Up @@ -104,12 +109,99 @@ func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sq
etc.Tables = append(etc.Tables, tableInfo)
}

func (tc *tableCollector) visitRowAlias(ins *sqlparser.Insert, rowAlias *sqlparser.RowAlias) error {
origTableInfo := tc.Tables[0]

var colNames []string
var types []evalengine.Type
switch {
case len(rowAlias.Columns) > 0 && len(ins.Columns) > 0:
// we have explicit column list on the row alias and the insert statement
if len(rowAlias.Columns) != len(ins.Columns) {
panic("column count mismatch")
}
origCols := origTableInfo.getColumns()
for1:
for idx, column := range rowAlias.Columns {
colNames = append(colNames, column.String())
col := ins.Columns[idx]
for _, origCol := range origCols {
if col.EqualString(origCol.Name) {
types = append(types, origCol.Type)
continue for1
}
}
return vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.BadFieldError, "Unknown column '%s' in 'field list'", col)
}
case len(rowAlias.Columns) > 0:
if !origTableInfo.authoritative() {
return vterrors.VT09015()
}
// TODO: we need to handle invisible columns here :sigh:
if len(rowAlias.Columns) != len(origTableInfo.getColumns()) {
panic("column count mismatch")
}
origCols := origTableInfo.getColumns()
for idx, column := range rowAlias.Columns {
colNames = append(colNames, column.String())
types = append(types, origCols[idx].Type)
}
case len(ins.Columns) > 0:
origCols := origTableInfo.getColumns()
for2:
for _, column := range ins.Columns {
colNames = append(colNames, column.String())
for _, origCol := range origCols {
if column.EqualString(origCol.Name) {
types = append(types, origCol.Type)
continue for2
}
}
types = append(types, evalengine.NewType(sqltypes.Unknown, collations.Unknown))
}
default:
if !origTableInfo.authoritative() {
return vterrors.VT09015()
}
for _, column := range origTableInfo.getColumns() {
colNames = append(colNames, column.Name)
types = append(types, column.Type)
}
}
deps := make([]TableSet, len(colNames))
for i := range colNames {
deps[i] = SingleTableSet(0)
}

derivedTable := &DerivedTable{
tableName: rowAlias.TableName.String(),
ASTNode: &sqlparser.AliasedTableExpr{
Expr: sqlparser.NewTableName(rowAlias.TableName.String()),
},
columnNames: colNames,
tables: SingleTableSet(0),
recursive: deps,
isAuthoritative: true,
types: types,
}

tc.Tables = append(tc.Tables, derivedTable)
current := tc.scoper.currentScope()
return current.addTable(derivedTable)
}

func (tc *tableCollector) up(cursor *sqlparser.Cursor) error {
switch node := cursor.Node().(type) {
case *sqlparser.AliasedTableExpr:
return tc.visitAliasedTableExpr(node)
case *sqlparser.Union:
return tc.visitUnion(node)
case *sqlparser.RowAlias:
ins, ok := cursor.Parent().(*sqlparser.Insert)
if !ok {
return vterrors.VT13001("RowAlias is expected to hang off an Insert statement")
}
return tc.visitRowAlias(ins, node)
default:
return nil
}
Expand Down

0 comments on commit 464ed1a

Please sign in to comment.