diff --git a/enginetest/queries/load_queries.go b/enginetest/queries/load_queries.go index b0e1f3f6ef..5c0d288a79 100644 --- a/enginetest/queries/load_queries.go +++ b/enginetest/queries/load_queries.go @@ -261,6 +261,8 @@ var LoadDataScripts = []ScriptTest{ "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt3 set i = '123', j = '456', k = '789'", "create table lt4(i text, j text, k text);", "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt4 set i = '123', i = '321'", + "create table lt5(i text, j text, k text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt5 set j = concat(j, j)", }, Assertions: []ScriptTestAssertion{ { @@ -291,6 +293,14 @@ var LoadDataScripts = []ScriptTest{ {"321", "mno", "pqr"}, }, }, + { + Skip: true, // self references are problematic + Query: "select * from lt5 order by i, j, k", + Expected: []sql.Row{ + {"321", "defdef", "ghi"}, + {"321", "mnomno", "pqr"}, + }, + }, }, }, { @@ -336,6 +346,156 @@ var LoadDataScripts = []ScriptTest{ }, }, }, + { + Name: "LOAD DATA assign to static User Variables", + SetUpScript: []string{ + "set @i = '123';", + "set @j = '456';", + "set @k = '789';", + "create table lt(i text, j text, k text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt set i = @i", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt set i = @i, j = @j", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt set i = @i, j = @j, k = @k", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from lt order by i, j, k", + Expected: []sql.Row{ + {"123", "456", "789"}, + {"123", "456", "789"}, + {"123", "456", "ghi"}, + {"123", "456", "pqr"}, + {"123", "def", "ghi"}, + {"123", "mno", "pqr"}, + }, + }, + }, + }, + { + Name: "LOAD DATA assign to User Variables", + SetUpScript: []string{ + "create table lt1(i text, j text, k text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt1 (@i, j, k)", + "create table lt2(i text, j text, k text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO 
TABLE lt2 (i, @j, k)", + "create table lt3(i text, j text, k text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt3 (i, j, @k)", + "create table lt4(i text, j text, k text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt4 (@ii, @jj, @kk)", + "create table lt5(i text, j text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt5 (i, j, @trash1)", + "create table lt6(j text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt6 (@trash2, j, @trash2)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from lt1 order by i, j, k", + Expected: []sql.Row{ + {nil, "def", "ghi"}, + {nil, "mno", "pqr"}, + }, + }, + { + Query: "select * from lt2 order by i, j, k", + Expected: []sql.Row{ + {"abc", nil, "ghi"}, + {"jkl", nil, "pqr"}, + }, + }, + { + Query: "select * from lt3 order by i, j, k", + Expected: []sql.Row{ + {"abc", "def", nil}, + {"jkl", "mno", nil}, + }, + }, + { + Query: "select @i, @j, @k", + Expected: []sql.Row{ + {"jkl", "mno", "pqr"}, + }, + }, + { + Query: "select * from lt4 order by i, j, k", + Expected: []sql.Row{ + {nil, nil, nil}, + {nil, nil, nil}, + }, + }, + { + Query: "select @ii, @jj, @kk", + Expected: []sql.Row{ + {"jkl", "mno", "pqr"}, + }, + }, + { + Query: "select * from lt5 order by i, j", + Expected: []sql.Row{ + {"abc", "def"}, + {"jkl", "mno"}, + }, + }, + { + Query: "select @trash1", + Expected: []sql.Row{ + {"pqr"}, + }, + }, + { + Query: "select * from lt6 order by j", + Expected: []sql.Row{ + {"def"}, + {"mno"}, + }, + }, + { + Query: "select @trash2", + Expected: []sql.Row{ + {"pqr"}, + }, + }, + }, + }, + { + Name: "LOAD DATA with user vars and set expressions", + SetUpScript: []string{ + "create table lt1(i text, j text, k text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt1 (k, @j, i) set j = @j", + "create table lt2(i text, j text);", + "LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt2 (i, j, @k) set j = concat(@k, @k)", + "create table lt3(i text, j text);", + 
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt3 (i, @j, @k) set j = concat(@j, @k)", + + // TODO: fix these + //"create table lt3(i text, j text);", + //"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt3 (i, j, @k) set j = concat(j, @k)", + //"create table lt4(i text, j text);", + //"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt4 (@i, @j) set i = @j, j = @i", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from lt1 order by i, j, k", + Expected: []sql.Row{ + {"ghi", "def", "abc"}, + {"pqr", "mno", "jkl"}, + }, + }, + { + Query: "select * from lt2 order by i, j", + Expected: []sql.Row{ + {"abc", "ghighi"}, + {"jkl", "pqrpqr"}, + }, + }, + { + Query: "select * from lt3 order by i, j", + Expected: []sql.Row{ + {"abc", "defghi"}, + {"jkl", "mnopqr"}, + }, + }, + }, + }, { Name: "LOAD DATA with set columns errors", SetUpScript: []string{ diff --git a/sql/plan/load_data.go b/sql/plan/load_data.go index e9c311097c..e38cb14e32 100644 --- a/sql/plan/load_data.go +++ b/sql/plan/load_data.go @@ -28,6 +28,7 @@ type LoadData struct { DestSch sql.Schema ColumnNames []string SetExprs []sql.Expression + UserSetFields []sql.Expression ResponsePacketSent bool IgnoreNum int64 IsIgnore bool @@ -107,17 +108,18 @@ func (*LoadData) CollationCoercibility(ctx *sql.Context) (collation sql.Collatio return sql.Collation_binary, 7 } -func NewLoadData(local bool, file string, destSch sql.Schema, cols []string, ignoreNum int64, ignoreOrReplace string) *LoadData { +func NewLoadData(local bool, file string, destSch sql.Schema, cols []string, userSetFields []sql.Expression, ignoreNum int64, ignoreOrReplace string) *LoadData { isReplace := ignoreOrReplace == sqlparser.ReplaceStr isIgnore := ignoreOrReplace == sqlparser.IgnoreStr || (local && !isReplace) return &LoadData{ - Local: local, - File: file, - DestSch: destSch, - ColumnNames: cols, - IgnoreNum: ignoreNum, - IsIgnore: isIgnore, - IsReplace: isReplace, + Local: local, + File: file, + DestSch: destSch, + 
ColumnNames:        cols,
+		UserSetFields:      userSetFields,
+		IgnoreNum:          ignoreNum,
+		IsIgnore:           isIgnore,
+		IsReplace:          isReplace,
 
 		FieldsTerminatedBy: defaultFieldsTerminatedBy,
 		FieldsEnclosedBy:   defaultFieldsEnclosedBy,
diff --git a/sql/planbuilder/load.go b/sql/planbuilder/load.go
index 06abe57596..5fafbe80b2 100644
--- a/sql/planbuilder/load.go
+++ b/sql/planbuilder/load.go
@@ -16,7 +16,8 @@ package planbuilder
 
 import (
 	"fmt"
-	"strings"
+	"github.com/dolthub/go-mysql-server/sql/types"
+	"strings"
 
 	ast "github.com/dolthub/vitess/go/vt/sqlparser"
@@ -66,7 +67,32 @@ func (b *Builder) buildLoad(inScope *scope, d *ast.Load) (outScope *scope) {
 		sch = b.resolveSchemaDefaults(destScope, rt.Schema())
 	}
 
-	ld := plan.NewLoadData(bool(d.Local), d.Infile, sch, columnsToStrings(d.Columns), ignoreNumVal, d.IgnoreOrReplace)
+	// Split d.Columns into real destination columns and @user variables.
+	// A user variable in the column list captures the raw field value at load time.
+	// TODO: handle weird edge case where column names have @ in them
+	colsOrVars := columnsToStrings(d.Columns)
+	colNames := make([]string, 0, len(d.Columns))
+	userSetFields := make([]sql.Expression, max(len(sch), len(d.Columns)))
+	for i, name := range colsOrVars {
+		varName, varScope, _, err := ast.VarScope(name)
+		if err != nil {
+			b.handleErr(err)
+		}
+		switch varScope {
+		case ast.SetScope_None:
+			colNames = append(colNames, name)
+			userSetFields[i] = nil
+		case ast.SetScope_User:
+			userVar := expression.NewUserVar(varName)
+			getField := expression.NewGetField(i, types.Text, name, true)
+			userSetFields[i] = expression.NewSetField(userVar, getField)
+		default:
+			// @@system variables (and anything else) in the column list are a syntax error
+			b.handleErr(sql.ErrSyntaxError.New(fmt.Sprintf("syntax error near '%s'", name)))
+		}
+	}
+
+	ld := plan.NewLoadData(bool(d.Local), d.Infile, sch, colNames, userSetFields, ignoreNumVal, d.IgnoreOrReplace)
 	if d.Charset != "" {
 		// TODO: deal with charset; ignore for now
 		ld.Charset = d.Charset
@@ -130,6 +156,7 @@ func (b *Builder) buildLoad(inScope *scope, d *ast.Load) (outScope *scope) {
 			}
 			if !exists {
ld.ColumnNames = append(ld.ColumnNames, colName)
+				ld.UserSetFields = append(ld.UserSetFields, nil)
 			}
 		}
 	}
diff --git a/sql/rowexec/ddl.go b/sql/rowexec/ddl.go
index ecaef27d91..661e182417 100644
--- a/sql/rowexec/ddl.go
+++ b/sql/rowexec/ddl.go
@@ -63,7 +63,6 @@ func (b *BaseBuilder) buildDropTrigger(ctx *sql.Context, n *plan.DropTrigger, ro
 func (b *BaseBuilder) buildLoadData(ctx *sql.Context, n *plan.LoadData, row sql.Row) (sql.RowIter, error) {
 	var reader io.ReadCloser
 	var err error
-
 	if n.Local {
 		_, localInfile, ok := sql.SystemVariables.GetGlobal("local_infile")
 		if !ok {
@@ -95,8 +94,6 @@ func (b *BaseBuilder) buildLoadData(ctx *sql.Context, n *plan.LoadData, row sql.
 	}
 
 	scanner := bufio.NewScanner(reader)
-
-	// Set the split function for lines.
 	scanner.Split(n.SplitLines)
 
 	// Skip through the lines that need to be ignored.
@@ -112,17 +109,24 @@ func (b *BaseBuilder) buildLoadData(ctx *sql.Context, n *plan.LoadData, row sql.
 	sch := n.Schema()
 	source := sch[0].Source // Schema will always have at least one column
 
-	columnNames := n.ColumnNames
-	if len(columnNames) == 0 {
-		columnNames = make([]string, len(sch))
+	colNames := n.ColumnNames
+	if len(colNames) == 0 {
+		colNames = make([]string, len(sch))
 		for i, col := range sch {
-			columnNames[i] = col.Name
+			colNames[i] = col.Name
 		}
 	}
 
-	fieldToColumnMap := make([]int, len(sch))
-	for fieldIndex, columnName := range columnNames {
-		fieldToColumnMap[fieldIndex] = sch.IndexOf(columnName, source)
+	// Map each input field to its destination column index; -1 marks fields that
+	// are captured into user variables or have no destination column at all.
+	fieldToColMap := make([]int, len(n.UserSetFields))
+	for fieldIdx, colIdx := 0, 0; fieldIdx < len(n.UserSetFields); fieldIdx++ {
+		if n.UserSetFields[fieldIdx] != nil || colIdx >= len(colNames) {
+			fieldToColMap[fieldIdx] = -1
+			continue
+		}
+		fieldToColMap[fieldIdx] = sch.IndexOf(colNames[colIdx], source)
+		colIdx++
 	}
 
 	return &loadDataIter{
@@ -130,8 +134,9 @@ func (b *BaseBuilder) buildLoadData(ctx *sql.Context, n *plan.LoadData, row sql.
 		reader:             reader,
 		scanner:            scanner,
 		columnCount:        len(n.ColumnNames), // Needs to be the original column count
-		fieldToColumnMap:   fieldToColumnMap,
+		fieldToColumnMap:   fieldToColMap,
 		setExprs:           n.SetExprs,
+		userSetFields:      n.UserSetFields,
 		fieldsTerminatedBy: n.FieldsTerminatedBy,
 		fieldsEnclosedBy:   n.FieldsEnclosedBy,
diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go
index 0912aa64c2..440a641e7f 100644
--- a/sql/rowexec/ddl_iters.go
+++ b/sql/rowexec/ddl_iters.go
@@ -38,12 +38,14 @@ import (
 )
 
 type loadDataIter struct {
-	scanner *bufio.Scanner
+	scanner *bufio.Scanner
+	reader  io.ReadCloser
+	destSch sql.Schema
 
-	reader           io.ReadCloser
 	columnCount      int
 	fieldToColumnMap []int
 	setExprs         []sql.Expression
+	userSetFields    []sql.Expression
 
 	fieldsTerminatedBy string
 	fieldsEnclosedBy   string
@@ -142,7 +144,7 @@ func (l loadDataIter) parseFields(ctx *sql.Context, line string) ([]sql.Expressi
 		}
 	}
 
-	//Step 4: Handle the ESCAPED BY parameter.
+	// Step 4: Handle the ESCAPED BY parameter.
 	if l.fieldsEscapedBy != "" {
 		for i, field := range fields {
 			if field == "\\N" {
@@ -157,50 +159,76 @@ func (l loadDataIter) parseFields(ctx *sql.Context, line string) ([]sql.Expressi
 		}
 	}
 
-	exprs := make([]sql.Expression, len(l.destSch))
-
-	limit := len(exprs)
-	if len(fields) < limit {
-		limit = len(fields)
+	fieldRow := make(sql.Row, len(fields))
+	for i, field := range fields {
+		fieldRow[i] = field
 	}
 
-	destSch := l.destSch
-	for i := 0; i < limit; i++ {
-		if l.setExprs != nil {
-			setExpr := l.setExprs[l.fieldToColumnMap[i]]
-			if setExpr != nil {
-				exprs[i] = setExpr
-				continue
+	exprs := make([]sql.Expression, len(l.destSch))
+	// Bound fieldIdx by userSetFields as well: a malformed line may carry more
+	// fields than were declared, and those extras must be discarded, not indexed.
+	for fieldIdx, exprIdx := 0, 0; fieldIdx < len(fields) && fieldIdx < len(l.userSetFields); fieldIdx++ {
+		if l.userSetFields[fieldIdx] != nil {
+			setField := l.userSetFields[fieldIdx].(*expression.SetField)
+			userVar := setField.LeftChild.(*expression.UserVar)
+			err := setUserVar(ctx, userVar, setField.RightChild, fieldRow)
+			if err != nil {
+				return nil, err
 			}
+			continue
 		}
-		field := fields[i]
-		destCol := destSch[l.fieldToColumnMap[i]]
-		// Replace the empty string with defaults
-		if field == "" {
-			_, ok := destCol.Type.(sql.StringType)
-			if !ok {
+
+		// don't check for `exprIdx < len(exprs)` in for loop
+		// because we still need to assign trailing user variables
+		if exprIdx >= len(exprs) {
+			continue
+		}
+
+		colIdx := l.fieldToColumnMap[fieldIdx]
+		if colIdx < 0 {
+			// Extra input field with no destination column; discard it.
+			continue
+		}
+
+		field := fields[fieldIdx]
+		destCol := l.destSch[colIdx]
+		switch field {
+		case "":
+			// Replace the empty string with defaults if exists, otherwise NULL
+			if _, ok := destCol.Type.(sql.StringType); ok {
+				exprs[exprIdx] = expression.NewLiteral(field, types.LongText)
+			} else {
 				if destCol.Default != nil {
-					exprs[i] = destCol.Default
+					exprs[exprIdx] = destCol.Default
 				} else {
-					exprs[i] = expression.NewLiteral(nil, types.Null)
+					exprs[exprIdx] = expression.NewLiteral(nil, types.Null)
 				}
-			} else {
-				exprs[i] = expression.NewLiteral(field, types.LongText)
 			}
-		} else if field == "NULL" {
-			exprs[i] = expression.NewLiteral(nil, types.Null)
-		} else {
-			exprs[i] = expression.NewLiteral(field, types.LongText)
+		case "NULL":
+			exprs[exprIdx] = expression.NewLiteral(nil, types.Null)
+		default:
+			exprs[exprIdx] = expression.NewLiteral(field, types.LongText)
+		}
+		exprIdx++
+	}
+
+	// Apply Set Expressions by replacing the corresponding field expression with the set expression
+	for fieldIdx, exprIdx := 0, 0; len(l.setExprs) > 0 && fieldIdx < len(l.fieldToColumnMap) && exprIdx < len(exprs); fieldIdx++ {
+		setIdx := l.fieldToColumnMap[fieldIdx]
+		if setIdx == -1 {
+			continue
+		}
+		setExpr := l.setExprs[setIdx]
+		if setExpr != nil {
+			exprs[exprIdx] = setExpr
 		}
+		exprIdx++
 	}
 
 	// Due to how projections work, if no columns are provided (each row may have a variable number of values), the
 	// projection will not insert default values, so we must do it here.
if l.columnCount == 0 { for i, expr := range exprs { - if expr == nil && destSch[i].Default != nil { - f := destSch[i] + if expr == nil && l.destSch[i].Default != nil { + f := l.destSch[i] if !f.Nullable && f.Default == nil && !f.AutoIncrement { return nil, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(f.Name) }