Skip to content

Commit

Permalink
handle user variables in load data
Browse files Browse the repository at this point in the history
  • Loading branch information
James Cor committed Sep 11, 2024
1 parent 01d3c5f commit cb21e48
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 52 deletions.
160 changes: 160 additions & 0 deletions enginetest/queries/load_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ var LoadDataScripts = []ScriptTest{
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt3 set i = '123', j = '456', k = '789'",
"create table lt4(i text, j text, k text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt4 set i = '123', i = '321'",
"create table lt5(i text, j text, k text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt5 set j = concat(j, j)",
},
Assertions: []ScriptTestAssertion{
{
Expand Down Expand Up @@ -291,6 +293,14 @@ var LoadDataScripts = []ScriptTest{
{"321", "mno", "pqr"},
},
},
{
Skip: true, // self references are problematic
Query: "select * from lt5 order by i, j, k",
Expected: []sql.Row{
{"321", "defdef", "ghi"},
{"321", "mnomno", "pqr"},
},
},
},
},
{
Expand Down Expand Up @@ -336,6 +346,156 @@ var LoadDataScripts = []ScriptTest{
},
},
},
{
Name: "LOAD DATA assign to static User Variables",
SetUpScript: []string{
"set @i = '123';",
"set @j = '456';",
"set @k = '789';",
"create table lt(i text, j text, k text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt set i = @i",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt set i = @i, j = @j",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt set i = @i, j = @j, k = @k",
},
Assertions: []ScriptTestAssertion{
{
Query: "select * from lt order by i, j, k",
Expected: []sql.Row{
{"123", "456", "789"},
{"123", "456", "789"},
{"123", "456", "ghi"},
{"123", "456", "pqr"},
{"123", "def", "ghi"},
{"123", "mno", "pqr"},
},
},
},
},
{
Name: "LOAD DATA assign to User Variables",
SetUpScript: []string{
"create table lt1(i text, j text, k text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt1 (@i, j, k)",
"create table lt2(i text, j text, k text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt2 (i, @j, k)",
"create table lt3(i text, j text, k text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt3 (i, j, @k)",
"create table lt4(i text, j text, k text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt4 (@ii, @jj, @kk)",
"create table lt5(i text, j text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt5 (i, j, @trash1)",
"create table lt6(j text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt6 (@trash2, j, @trash2)",
},
Assertions: []ScriptTestAssertion{
{
Query: "select * from lt1 order by i, j, k",
Expected: []sql.Row{
{nil, "def", "ghi"},
{nil, "mno", "pqr"},
},
},
{
Query: "select * from lt2 order by i, j, k",
Expected: []sql.Row{
{"abc", nil, "ghi"},
{"jkl", nil, "pqr"},
},
},
{
Query: "select * from lt3 order by i, j, k",
Expected: []sql.Row{
{"abc", "def", nil},
{"jkl", "mno", nil},
},
},
{
Query: "select @i, @j, @k",
Expected: []sql.Row{
{"jkl", "mno", "pqr"},
},
},
{
Query: "select * from lt4 order by i, j, k",
Expected: []sql.Row{
{nil, nil, nil},
{nil, nil, nil},
},
},
{
Query: "select @ii, @jj, @kk",
Expected: []sql.Row{
{"jkl", "mno", "pqr"},
},
},
{
Query: "select * from lt5 order by i, j",
Expected: []sql.Row{
{"abc", "def"},
{"jkl", "mno"},
},
},
{
Query: "select @trash1",
Expected: []sql.Row{
{"pqr"},
},
},
{
Query: "select * from lt6 order by j",
Expected: []sql.Row{
{"def"},
{"mno"},
},
},
{
Query: "select @trash2",
Expected: []sql.Row{
{"pqr"},
},
},
},
},
{
Name: "LOAD DATA with user vars and set expressions",
SetUpScript: []string{
"create table lt1(i text, j text, k text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt1 (k, @j, i) set j = @j",
"create table lt2(i text, j text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt2 (i, j, @k) set j = concat(@k, @k)",
"create table lt3(i text, j text);",
"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt3 (i, @j, @k) set j = concat(@j, @k)",

// TODO: fix these
//"create table lt3(i text, j text);",
//"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt3 (i, j, @k) set j = concat(j, @k)",
//"create table lt4(i text, j text);",
//"LOAD DATA INFILE './testdata/test9.txt' INTO TABLE lt4 (@i, @j) set i = @j, j = @i",
},
Assertions: []ScriptTestAssertion{
{
Query: "select * from lt1 order by i, j, k",
Expected: []sql.Row{
{"ghi", "def", "abc"},
{"pqr", "mno", "jkl"},
},
},
{
Query: "select * from lt2 order by i, j",
Expected: []sql.Row{
{"abc", "ghighi"},
{"jkl", "pqrpqr"},
},
},
{
Query: "select * from lt3 order by i, j",
Expected: []sql.Row{
{"abc", "defghi"},
{"jkl", "mnopqr"},
},
},
},
},
{
Name: "LOAD DATA with set columns errors",
SetUpScript: []string{
Expand Down
18 changes: 10 additions & 8 deletions sql/plan/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type LoadData struct {
DestSch sql.Schema
ColumnNames []string
SetExprs []sql.Expression
UserSetFields []sql.Expression
ResponsePacketSent bool
IgnoreNum int64
IsIgnore bool
Expand Down Expand Up @@ -107,17 +108,18 @@ func (*LoadData) CollationCoercibility(ctx *sql.Context) (collation sql.Collatio
return sql.Collation_binary, 7
}

func NewLoadData(local bool, file string, destSch sql.Schema, cols []string, ignoreNum int64, ignoreOrReplace string) *LoadData {
func NewLoadData(local bool, file string, destSch sql.Schema, cols []string, userSetFields []sql.Expression, ignoreNum int64, ignoreOrReplace string) *LoadData {
isReplace := ignoreOrReplace == sqlparser.ReplaceStr
isIgnore := ignoreOrReplace == sqlparser.IgnoreStr || (local && !isReplace)
return &LoadData{
Local: local,
File: file,
DestSch: destSch,
ColumnNames: cols,
IgnoreNum: ignoreNum,
IsIgnore: isIgnore,
IsReplace: isReplace,
Local: local,
File: file,
DestSch: destSch,
ColumnNames: cols,
UserSetFields: userSetFields,
IgnoreNum: ignoreNum,
IsIgnore: isIgnore,
IsReplace: isReplace,

FieldsTerminatedBy: defaultFieldsTerminatedBy,
FieldsEnclosedBy: defaultFieldsEnclosedBy,
Expand Down
30 changes: 28 additions & 2 deletions sql/planbuilder/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ package planbuilder

import (
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/sql/types"
"strings"

ast "github.com/dolthub/vitess/go/vt/sqlparser"

Expand Down Expand Up @@ -66,7 +67,31 @@ func (b *Builder) buildLoad(inScope *scope, d *ast.Load) (outScope *scope) {
sch = b.resolveSchemaDefaults(destScope, rt.Schema())
}

ld := plan.NewLoadData(bool(d.Local), d.Infile, sch, columnsToStrings(d.Columns), ignoreNumVal, d.IgnoreOrReplace)
// TODO: look through d.Columns and separate out the UserVars
// TODO: handle weird edge case where column names have @ in them
// TODO: @@variables are syntax error
colsOrVars := columnsToStrings(d.Columns)
colNames := make([]string, 0, len(d.Columns))
userSetFields := make([]sql.Expression, max(len(sch), len(d.Columns)))
for i, name := range colsOrVars {
varName, varScope, _, err := ast.VarScope(name)
if err != nil {
b.handleErr(err)
}
switch varScope {
case ast.SetScope_None:
colNames = append(colNames, name)
userSetFields[i] = nil
case ast.SetScope_User:
userVar := expression.NewUserVar(varName)
getField := expression.NewGetField(i, types.Text, name, true)
userSetFields[i] = expression.NewSetField(userVar, getField)
default:
b.handleErr(sql.ErrSyntaxError.New(fmt.Errorf("syntax error near '%s'", name)))
}
}

ld := plan.NewLoadData(bool(d.Local), d.Infile, sch, colNames, userSetFields, ignoreNumVal, d.IgnoreOrReplace)
if d.Charset != "" {
// TODO: deal with charset; ignore for now
ld.Charset = d.Charset
Expand Down Expand Up @@ -130,6 +155,7 @@ func (b *Builder) buildLoad(inScope *scope, d *ast.Load) (outScope *scope) {
}
if !exists {
ld.ColumnNames = append(ld.ColumnNames, colName)
ld.UserSetFields = append(ld.UserSetFields, nil)
}
}
}
Expand Down
26 changes: 15 additions & 11 deletions sql/rowexec/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ func (b *BaseBuilder) buildDropTrigger(ctx *sql.Context, n *plan.DropTrigger, ro
func (b *BaseBuilder) buildLoadData(ctx *sql.Context, n *plan.LoadData, row sql.Row) (sql.RowIter, error) {
var reader io.ReadCloser
var err error

if n.Local {
_, localInfile, ok := sql.SystemVariables.GetGlobal("local_infile")
if !ok {
Expand Down Expand Up @@ -95,8 +94,6 @@ func (b *BaseBuilder) buildLoadData(ctx *sql.Context, n *plan.LoadData, row sql.
}

scanner := bufio.NewScanner(reader)

// Set the split function for lines.
scanner.Split(n.SplitLines)

// Skip through the lines that need to be ignored.
Expand All @@ -112,26 +109,33 @@ func (b *BaseBuilder) buildLoadData(ctx *sql.Context, n *plan.LoadData, row sql.

sch := n.Schema()
source := sch[0].Source // Schema will always have at least one column
columnNames := n.ColumnNames
if len(columnNames) == 0 {
columnNames = make([]string, len(sch))
colNames := n.ColumnNames
if len(colNames) == 0 {
colNames = make([]string, len(sch))
for i, col := range sch {
columnNames[i] = col.Name
colNames[i] = col.Name
}
}

fieldToColumnMap := make([]int, len(sch))
for fieldIndex, columnName := range columnNames {
fieldToColumnMap[fieldIndex] = sch.IndexOf(columnName, source)
// TODO: account for offsets from user variables?
fieldToColMap := make([]int, len(n.UserSetFields))
for fieldIdx, colIdx := 0, 0; fieldIdx < len(n.UserSetFields) && colIdx < len(colNames); fieldIdx++ {
if n.UserSetFields[fieldIdx] != nil {
fieldToColMap[fieldIdx] = -1
continue
}
fieldToColMap[fieldIdx] = sch.IndexOf(colNames[colIdx], source)
colIdx++
}

return &loadDataIter{
destSch: n.DestSch,
reader: reader,
scanner: scanner,
columnCount: len(n.ColumnNames), // Needs to be the original column count
fieldToColumnMap: fieldToColumnMap,
fieldToColumnMap: fieldToColMap,
setExprs: n.SetExprs,
userSetFields: n.UserSetFields,

fieldsTerminatedBy: n.FieldsTerminatedBy,
fieldsEnclosedBy: n.FieldsEnclosedBy,
Expand Down
Loading

0 comments on commit cb21e48

Please sign in to comment.