Skip to content

Commit

Permalink
Support delimiter for mysql-tester (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
Defined2014 authored Dec 24, 2024
1 parent 314107b commit 0d83955
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 82 deletions.
107 changes: 75 additions & 32 deletions src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const (
type query struct {
firstWord string
Query string
delimiter string
Line int
tp int
}
Expand Down Expand Up @@ -148,6 +149,9 @@ type tester struct {

// replace output result through --replace_regex /\.dll/.so/
replaceRegex []*ReplaceRegex

// the delimter for TiDB, default value is ";"
delimiter string
}

func newTester(name string) *tester {
Expand All @@ -161,6 +165,7 @@ func newTester(name string) *tester {
t.enableWarning = false
t.enableConcurrent = false
t.enableInfo = false
t.delimiter = ";"

return t
}
Expand Down Expand Up @@ -462,10 +467,7 @@ func (t *tester) Run() error {
t.replaceColumn = append(t.replaceColumn, ReplaceColumn{col: colNr, replace: []byte(cols[i+1])})
}
case Q_CONNECT:
q.Query = strings.TrimSpace(q.Query)
if q.Query[len(q.Query)-1] == ';' {
q.Query = q.Query[:len(q.Query)-1]
}
q.Query = strings.TrimSuffix(strings.TrimSpace(q.Query), q.delimiter)
q.Query = q.Query[1 : len(q.Query)-1]
args := strings.Split(q.Query, ",")
for i := range args {
Expand All @@ -476,16 +478,10 @@ func (t *tester) Run() error {
}
t.addConnection(args[0], args[1], args[2], args[3], args[4])
case Q_CONNECTION:
q.Query = strings.TrimSpace(q.Query)
if q.Query[len(q.Query)-1] == ';' {
q.Query = q.Query[:len(q.Query)-1]
}
q.Query = strings.TrimSuffix(strings.TrimSpace(q.Query), q.delimiter)
t.switchConnection(q.Query)
case Q_DISCONNECT:
q.Query = strings.TrimSpace(q.Query)
if q.Query[len(q.Query)-1] == ';' {
q.Query = q.Query[:len(q.Query)-1]
}
q.Query = strings.TrimSuffix(strings.TrimSpace(q.Query), q.delimiter)
t.disconnect(q.Query)
case Q_LET:
q.Query = strings.TrimSpace(q.Query)
Expand Down Expand Up @@ -622,7 +618,7 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure
return
}

err := tt.stmtExecute(query.Query)
err := tt.stmtExecute(query)
if err != nil && len(t.expectedErrs) > 0 {
for _, tStr := range t.expectedErrs {
if strings.Contains(err.Error(), tStr) {
Expand Down Expand Up @@ -650,43 +646,90 @@ func (t *tester) loadQueries() ([]query, error) {

seps := bytes.Split(data, []byte("\n"))
queries := make([]query, 0, len(seps))
newStmt := true
buffer := ""
for i, v := range seps {
v := bytes.TrimSpace(v)
s := string(v)
// we will skip # comment here
if strings.HasPrefix(s, "#") {
newStmt = true
if len(buffer) != 0 {
return nil, errors.Errorf("Has remained message(%s) before COMMENTS", buffer)
}
continue
} else if strings.HasPrefix(s, "--") {
queries = append(queries, query{Query: s, Line: i + 1})
newStmt = true
if len(buffer) != 0 {
return nil, errors.Errorf("Has remained message(%s) before COMMANDS", buffer)
}
q, err := ParseQuery(query{Query: s, Line: i + 1, delimiter: t.delimiter})
if err != nil {
return nil, err
}
if q == nil {
continue
}
if q.tp == Q_DELIMITER {
tokens := strings.Split(strings.TrimSpace(q.Query), " ")
if len(tokens) == 0 {
return nil, errors.Errorf("DELIMITER must be followed by a 'delimiter' character or string")
}
t.delimiter = tokens[0]
} else {
queries = append(queries, *q)
}
continue
} else if strings.HasPrefix(strings.ToLower(strings.TrimSpace(s)), "delimiter ") {
if len(buffer) != 0 {
return nil, errors.Errorf("Has remained message(%s) before DELIMITER COMMAND", buffer)
}
tokens := strings.Split(strings.TrimSpace(s), " ")
if len(tokens) <= 1 {
return nil, errors.Errorf("DELIMITER must be followed by a 'delimiter' character or string")
}
t.delimiter = tokens[1]
continue
} else if len(s) == 0 {
continue
}

if newStmt {
queries = append(queries, query{Query: s, Line: i + 1})
} else {
lastQuery := queries[len(queries)-1]
lastQuery = query{Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), Line: lastQuery.Line}
queries[len(queries)-1] = lastQuery
if len(buffer) != 0 {
buffer += "\n"
}
buffer += s
for {
idx := strings.LastIndex(buffer, t.delimiter)
if idx == -1 {
break
}

// if the line has a ; in the end, we will treat new line as the new statement.
newStmt = strings.HasSuffix(s, ";")
queryStr := buffer[:idx+len(t.delimiter)]
buffer = buffer[idx+len(t.delimiter):]
q, err := ParseQuery(query{Query: strings.TrimSpace(queryStr), Line: i + 1, delimiter: t.delimiter})
if err != nil {
return nil, err
}
if q == nil {
continue
}
queries = append(queries, *q)
}
// If has remained comments, ignore them.
if len(buffer) != 0 && strings.HasPrefix(strings.TrimSpace(buffer), "#") {
buffer = ""
}
}

return ParseQueries(queries...)
if len(buffer) != 0 {
return nil, errors.Errorf("Has remained text(%s) in file", buffer)
}
return queries, nil
}

func (t *tester) stmtExecute(query string) (err error) {
func (t *tester) stmtExecute(query query) (err error) {
if t.enableQueryLog {
t.buf.WriteString(query)
t.buf.WriteString(query.Query)
t.buf.WriteString("\n")
}
return t.executeStmt(query)

return t.executeStmt(strings.TrimSuffix(query.Query, query.delimiter))
}

// checkExpectedError check if error was expected
Expand Down Expand Up @@ -784,7 +827,7 @@ func (t *tester) execute(query query) error {
}

offset := t.buf.Len()
err := t.stmtExecute(query.Query)
err := t.stmtExecute(query)

err = t.checkExpectedError(query, err)
if err != nil {
Expand Down Expand Up @@ -967,7 +1010,7 @@ func (t *tester) executeStmt(query string) error {
}

if t.enableWarning {
raw, err := t.curr.conn.QueryContext(context.Background(), "show warnings;")
raw, err := t.curr.conn.QueryContext(context.Background(), "show warnings")
if err != nil {
return errors.Trace(err)
}
Expand Down
83 changes: 42 additions & 41 deletions src/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,56 +126,57 @@ const (
Q_EMPTY_LINE
)

// ParseQueries parses an array of string into an array of query object.
// ParseQuery parses an array of string into an array of query object.
// Note: a query statement may reside in several lines.
func ParseQueries(qs ...query) ([]query, error) {
queries := make([]query, 0, len(qs))
for _, rs := range qs {
realS := rs.Query
s := rs.Query
q := query{}
q.tp = Q_UNKNOWN
q.Line = rs.Line
// a valid query's length should be at least 3.
if len(s) < 3 {
continue
}
// we will skip #comment and line with zero characters here
if s[0] == '#' {
q.tp = Q_COMMENT
} else if s[0:2] == "--" {
q.tp = Q_COMMENT_WITH_COMMAND
if s[2] == ' ' {
s = s[3:]
} else {
s = s[2:]
}
} else if s[0] == '\n' {
q.tp = Q_EMPTY_LINE
func ParseQuery(rs query) (*query, error) {
realS := rs.Query
s := rs.Query
q := query{delimiter: rs.delimiter, Line: rs.Line}
q.tp = Q_UNKNOWN
// a valid query's length should be at least 3.
if len(s) < 3 {
return nil, nil
}
// we will skip #comment and line with zero characters here
if s[0] == '#' {
q.tp = Q_COMMENT
} else if s[0:2] == "--" {
q.tp = Q_COMMENT_WITH_COMMAND
if s[2] == ' ' {
s = s[3:]
} else {
s = s[2:]
}
} else if s[0] == '\n' {
q.tp = Q_EMPTY_LINE
}

if q.tp != Q_COMMENT {
// Calculate first word length(the command), terminated
// by 'space' , '(' or 'delimiter'
var i int
for i = 0; i < len(s) && s[i] != '(' && s[i] != ' ' && s[i] != ';' && s[i] != '\n'; i++ {
if q.tp != Q_COMMENT {
// Calculate first word length(the command), terminated
// by 'space' , '(' and delimiter
var i int
for i = 0; i < len(s); i++ {
if s[i] == '(' || s[i] == ' ' || s[i] == '\n' {
break
}
if i > 0 {
q.firstWord = s[:i]
if i+len(rs.delimiter) <= len(s) && s[i:i+len(rs.delimiter)] == rs.delimiter {
break
}
s = s[i:]
}
if i > 0 {
q.firstWord = s[:i]
}
s = s[i:]

q.Query = s
if q.tp == Q_UNKNOWN || q.tp == Q_COMMENT_WITH_COMMAND {
if err := q.getQueryType(realS); err != nil {
return nil, err
}
q.Query = s
if q.tp == Q_UNKNOWN || q.tp == Q_COMMENT_WITH_COMMAND {
if err := q.getQueryType(realS); err != nil {
return nil, err
}
}

queries = append(queries, q)
}
return queries, nil

return &q, nil
}

// for a single query, it has some prefix. Prefix mapps to a query type.
Expand Down
77 changes: 68 additions & 9 deletions src/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ package main

import (
"fmt"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func assertEqual(t *testing.T, a interface{}, b interface{}, message string) {
Expand All @@ -31,28 +35,83 @@ func assertEqual(t *testing.T, a interface{}, b interface{}, message string) {
func TestParseQueryies(t *testing.T) {
sql := "select * from t;"

if q, err := ParseQueries(query{Query: sql, Line: 1}); err == nil {
assertEqual(t, q[0].tp, Q_QUERY, fmt.Sprintf("Expected: %d, got: %d", Q_QUERY, q[0].tp))
assertEqual(t, q[0].Query, sql, fmt.Sprintf("Expected: %s, got: %s", sql, q[0].Query))
if q, err := ParseQuery(query{Query: sql, Line: 1, delimiter: ";"}); err == nil {
assertEqual(t, q.tp, Q_QUERY, fmt.Sprintf("Expected: %d, got: %d", Q_QUERY, q.tp))
assertEqual(t, q.Query, sql, fmt.Sprintf("Expected: %s, got: %s", sql, q.Query))
} else {
t.Fatalf("error is not nil. %v", err)
}

sql = "--sorted_result select * from t;"
if q, err := ParseQueries(query{Query: sql, Line: 1}); err == nil {
assertEqual(t, q[0].tp, Q_SORTED_RESULT, "sorted_result")
assertEqual(t, q[0].Query, "select * from t;", fmt.Sprintf("Expected: '%s', got '%s'", "select * from t;", q[0].Query))
if q, err := ParseQuery(query{Query: sql, Line: 1, delimiter: ";"}); err == nil {
assertEqual(t, q.tp, Q_SORTED_RESULT, "sorted_result")
assertEqual(t, q.Query, "select * from t;", fmt.Sprintf("Expected: '%s', got '%s'", "select * from t;", q.Query))
} else {
t.Fatalf("error is not nil. %s", err)
}

// invalid comment command style
sql = "--abc select * from t;"
_, err := ParseQueries(query{Query: sql, Line: 1})
_, err := ParseQuery(query{Query: sql, Line: 1, delimiter: ";"})
assertEqual(t, err, ErrInvalidCommand, fmt.Sprintf("Expected: %v, got %v", ErrInvalidCommand, err))

sql = "--let $foo=`SELECT 1`"
if q, err := ParseQueries(query{Query: sql, Line: 1}); err == nil {
assertEqual(t, q[0].tp, Q_LET, fmt.Sprintf("Expected: %d, got: %d", Q_LET, q[0].tp))
if q, err := ParseQuery(query{Query: sql, Line: 1, delimiter: ";"}); err == nil {
assertEqual(t, q.tp, Q_LET, fmt.Sprintf("Expected: %d, got: %d", Q_LET, q.tp))
}
}

func TestLoadQueries(t *testing.T) {
dir := t.TempDir()
err := os.Chdir(dir)
assert.NoError(t, err)

err = os.Mkdir("t", 0755)
assert.NoError(t, err)

testCases := []struct {
input string
queries []query
}{
{
input: "delimiter |\n do something; select something; |\n delimiter ; \nselect 1;",
queries: []query{
{Query: "do something; select something; |", tp: Q_QUERY, delimiter: "|"},
{Query: "select 1;", tp: Q_QUERY, delimiter: ";"},
},
},
{
input: "delimiter |\ndrop procedure if exists scopel\ncreate procedure scope(a int, b float)\nbegin\ndeclare b int;\ndeclare c float;\nbegin\ndeclare c int;\nend;\nend |\ndrop procedure scope|\ndelimiter ;\n",
queries: []query{
{Query: "drop procedure if exists scopel\ncreate procedure scope(a int, b float)\nbegin\ndeclare b int;\ndeclare c float;\nbegin\ndeclare c int;\nend;\nend |", tp: Q_QUERY, delimiter: "|"},
{Query: "drop procedure scope|", tp: Q_QUERY, delimiter: "|"},
},
},
{
input: "--error 1054\nselect 1;",
queries: []query{
{Query: " 1054", tp: Q_ERROR, delimiter: ";"},
{Query: "select 1;", tp: Q_QUERY, delimiter: ";"},
},
},
}

for _, testCase := range testCases {
fileName := filepath.Join("t", "test.test")
f, err := os.Create(fileName)
assert.NoError(t, err)

f.WriteString(testCase.input)
f.Close()

test := newTester("test")
queries, err := test.loadQueries()
assert.NoError(t, err)
assert.Len(t, queries, len(testCase.queries))
for i, query := range testCase.queries {
assert.Equal(t, queries[i].Query, query.Query)
assert.Equal(t, queries[i].tp, query.tp)
assert.Equal(t, queries[i].delimiter, query.delimiter)
}
}
}

0 comments on commit 0d83955

Please sign in to comment.