From 0d83955ea569706e5296cd3e2f54efb7f1206d0b Mon Sep 17 00:00:00 2001 From: Hangjie Mo Date: Tue, 24 Dec 2024 14:44:58 +0800 Subject: [PATCH] Support delimiter for mysql-tester (#135) --- src/main.go | 107 ++++++++++++++++++++++++++++++++-------------- src/query.go | 83 +++++++++++++++++------------------ src/query_test.go | 77 +++++++++++++++++++++++++++++---- 3 files changed, 185 insertions(+), 82 deletions(-) diff --git a/src/main.go b/src/main.go index 3da256a..cec17c6 100644 --- a/src/main.go +++ b/src/main.go @@ -72,6 +72,7 @@ const ( type query struct { firstWord string Query string + delimiter string Line int tp int } @@ -148,6 +149,9 @@ type tester struct { // replace output result through --replace_regex /\.dll/.so/ replaceRegex []*ReplaceRegex + + // the delimter for TiDB, default value is ";" + delimiter string } func newTester(name string) *tester { @@ -161,6 +165,7 @@ func newTester(name string) *tester { t.enableWarning = false t.enableConcurrent = false t.enableInfo = false + t.delimiter = ";" return t } @@ -462,10 +467,7 @@ func (t *tester) Run() error { t.replaceColumn = append(t.replaceColumn, ReplaceColumn{col: colNr, replace: []byte(cols[i+1])}) } case Q_CONNECT: - q.Query = strings.TrimSpace(q.Query) - if q.Query[len(q.Query)-1] == ';' { - q.Query = q.Query[:len(q.Query)-1] - } + q.Query = strings.TrimSuffix(strings.TrimSpace(q.Query), q.delimiter) q.Query = q.Query[1 : len(q.Query)-1] args := strings.Split(q.Query, ",") for i := range args { @@ -476,16 +478,10 @@ func (t *tester) Run() error { } t.addConnection(args[0], args[1], args[2], args[3], args[4]) case Q_CONNECTION: - q.Query = strings.TrimSpace(q.Query) - if q.Query[len(q.Query)-1] == ';' { - q.Query = q.Query[:len(q.Query)-1] - } + q.Query = strings.TrimSuffix(strings.TrimSpace(q.Query), q.delimiter) t.switchConnection(q.Query) case Q_DISCONNECT: - q.Query = strings.TrimSpace(q.Query) - if q.Query[len(q.Query)-1] == ';' { - q.Query = q.Query[:len(q.Query)-1] - } + q.Query = strings.TrimSuffix(strings.TrimSpace(q.Query), q.delimiter) t.disconnect(q.Query) case Q_LET: q.Query = strings.TrimSpace(q.Query) @@ -622,7 +618,7 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure return } - err := tt.stmtExecute(query.Query) + err := tt.stmtExecute(query) if err != nil && len(t.expectedErrs) > 0 { for _, tStr := range t.expectedErrs { if strings.Contains(err.Error(), tStr) { @@ -650,43 +646,90 @@ func (t *tester) loadQueries() ([]query, error) { seps := bytes.Split(data, []byte("\n")) queries := make([]query, 0, len(seps)) - newStmt := true + buffer := "" for i, v := range seps { v := bytes.TrimSpace(v) s := string(v) // we will skip # comment here if strings.HasPrefix(s, "#") { - newStmt = true + if len(buffer) != 0 { + return nil, errors.Errorf("Has remained message(%s) before COMMENTS", buffer) + } continue } else if strings.HasPrefix(s, "--") { - queries = append(queries, query{Query: s, Line: i + 1}) - newStmt = true + if len(buffer) != 0 { + return nil, errors.Errorf("Has remained message(%s) before COMMANDS", buffer) + } + q, err := ParseQuery(query{Query: s, Line: i + 1, delimiter: t.delimiter}) + if err != nil { + return nil, err + } + if q == nil { + continue + } + if q.tp == Q_DELIMITER { + tokens := strings.Split(strings.TrimSpace(q.Query), " ") + if len(tokens) == 0 { + return nil, errors.Errorf("DELIMITER must be followed by a 'delimiter' character or string") + } + t.delimiter = tokens[0] + } else { + queries = append(queries, *q) + } + continue + } else if strings.HasPrefix(strings.ToLower(strings.TrimSpace(s)), "delimiter ") { + if len(buffer) != 0 { + return nil, errors.Errorf("Has remained message(%s) before DELIMITER COMMAND", buffer) + } + tokens := strings.Split(strings.TrimSpace(s), " ") + if len(tokens) <= 1 { + return nil, errors.Errorf("DELIMITER must be followed by a 'delimiter' character or string") + } + t.delimiter = tokens[1] continue } else if len(s) == 0 { continue } - if newStmt { - queries = append(queries, query{Query: s, Line: i + 1}) - } else { - lastQuery := queries[len(queries)-1] - lastQuery = query{Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), Line: lastQuery.Line} - queries[len(queries)-1] = lastQuery + if len(buffer) != 0 { + buffer += "\n" } + buffer += s + for { + idx := strings.LastIndex(buffer, t.delimiter) + if idx == -1 { + break + } - // if the line has a ; in the end, we will treat new line as the new statement. - newStmt = strings.HasSuffix(s, ";") + queryStr := buffer[:idx+len(t.delimiter)] + buffer = buffer[idx+len(t.delimiter):] + q, err := ParseQuery(query{Query: strings.TrimSpace(queryStr), Line: i + 1, delimiter: t.delimiter}) + if err != nil { + return nil, err + } + if q == nil { + continue + } + queries = append(queries, *q) + } + // If has remained comments, ignore them. + if len(buffer) != 0 && strings.HasPrefix(strings.TrimSpace(buffer), "#") { + buffer = "" + } } - - return ParseQueries(queries...) + if len(buffer) != 0 { + return nil, errors.Errorf("Has remained text(%s) in file", buffer) + } + return queries, nil } -func (t *tester) stmtExecute(query string) (err error) { +func (t *tester) stmtExecute(query query) (err error) { if t.enableQueryLog { - t.buf.WriteString(query) + t.buf.WriteString(query.Query) t.buf.WriteString("\n") } - return t.executeStmt(query) + + return t.executeStmt(strings.TrimSuffix(query.Query, query.delimiter)) } // checkExpectedError check if error was expected @@ -784,7 +827,7 @@ func (t *tester) execute(query query) error { } offset := t.buf.Len() - err := t.stmtExecute(query.Query) + err := t.stmtExecute(query) err = t.checkExpectedError(query, err) if err != nil { @@ -967,7 +1010,7 @@ func (t *tester) executeStmt(query string) error { } if t.enableWarning { - raw, err := t.curr.conn.QueryContext(context.Background(), "show warnings;") + raw, err := t.curr.conn.QueryContext(context.Background(), "show warnings") if err != nil { return errors.Trace(err) } diff --git a/src/query.go b/src/query.go index 6a128d8..fdf0dab 100644 --- a/src/query.go +++ b/src/query.go @@ -126,56 +126,57 @@ const ( Q_EMPTY_LINE ) -// ParseQueries parses an array of string into an array of query object. +// ParseQuery parses an array of string into an array of query object. // Note: a query statement may reside in several lines. -func ParseQueries(qs ...query) ([]query, error) { - queries := make([]query, 0, len(qs)) - for _, rs := range qs { - realS := rs.Query - s := rs.Query - q := query{} - q.tp = Q_UNKNOWN - q.Line = rs.Line - // a valid query's length should be at least 3. - if len(s) < 3 { - continue - } - // we will skip #comment and line with zero characters here - if s[0] == '#' { - q.tp = Q_COMMENT - } else if s[0:2] == "--" { - q.tp = Q_COMMENT_WITH_COMMAND - if s[2] == ' ' { - s = s[3:] - } else { - s = s[2:] - } - } else if s[0] == '\n' { - q.tp = Q_EMPTY_LINE +func ParseQuery(rs query) (*query, error) { + realS := rs.Query + s := rs.Query + q := query{delimiter: rs.delimiter, Line: rs.Line} + q.tp = Q_UNKNOWN + // a valid query's length should be at least 3. + if len(s) < 3 { + return nil, nil + } + // we will skip #comment and line with zero characters here + if s[0] == '#' { + q.tp = Q_COMMENT + } else if s[0:2] == "--" { + q.tp = Q_COMMENT_WITH_COMMAND + if s[2] == ' ' { + s = s[3:] + } else { + s = s[2:] } + } else if s[0] == '\n' { + q.tp = Q_EMPTY_LINE + } - if q.tp != Q_COMMENT { - // Calculate first word length(the command), terminated - // by 'space' , '(' or 'delimiter' - var i int - for i = 0; i < len(s) && s[i] != '(' && s[i] != ' ' && s[i] != ';' && s[i] != '\n'; i++ { + if q.tp != Q_COMMENT { + // Calculate first word length(the command), terminated + // by 'space' , '(' and delimiter + var i int + for i = 0; i < len(s); i++ { + if s[i] == '(' || s[i] == ' ' || s[i] == '\n' { + break } - if i > 0 { - q.firstWord = s[:i] + if i+len(rs.delimiter) <= len(s) && s[i:i+len(rs.delimiter)] == rs.delimiter { + break } - s = s[i:] + } + if i > 0 { + q.firstWord = s[:i] + } + s = s[i:] - q.Query = s - if q.tp == Q_UNKNOWN || q.tp == Q_COMMENT_WITH_COMMAND { - if err := q.getQueryType(realS); err != nil { - return nil, err - } + q.Query = s + if q.tp == Q_UNKNOWN || q.tp == Q_COMMENT_WITH_COMMAND { + if err := q.getQueryType(realS); err != nil { + return nil, err } } - - queries = append(queries, q) } - return queries, nil + + return &q, nil } // for a single query, it has some prefix. Prefix mapps to a query type. diff --git a/src/query_test.go b/src/query_test.go index 825d645..4b25812 100644 --- a/src/query_test.go +++ b/src/query_test.go @@ -15,7 +15,11 @@ package main import ( "fmt" + "os" + "path/filepath" "testing" + + "github.com/stretchr/testify/assert" ) func assertEqual(t *testing.T, a interface{}, b interface{}, message string) { @@ -31,28 +35,83 @@ func assertEqual(t *testing.T, a interface{}, b interface{}, message string) { func TestParseQueryies(t *testing.T) { sql := "select * from t;" - if q, err := ParseQueries(query{Query: sql, Line: 1}); err == nil { - assertEqual(t, q[0].tp, Q_QUERY, fmt.Sprintf("Expected: %d, got: %d", Q_QUERY, q[0].tp)) - assertEqual(t, q[0].Query, sql, fmt.Sprintf("Expected: %s, got: %s", sql, q[0].Query)) + if q, err := ParseQuery(query{Query: sql, Line: 1, delimiter: ";"}); err == nil { + assertEqual(t, q.tp, Q_QUERY, fmt.Sprintf("Expected: %d, got: %d", Q_QUERY, q.tp)) + assertEqual(t, q.Query, sql, fmt.Sprintf("Expected: %s, got: %s", sql, q.Query)) } else { t.Fatalf("error is not nil. %v", err) } sql = "--sorted_result select * from t;" - if q, err := ParseQueries(query{Query: sql, Line: 1}); err == nil { - assertEqual(t, q[0].tp, Q_SORTED_RESULT, "sorted_result") - assertEqual(t, q[0].Query, "select * from t;", fmt.Sprintf("Expected: '%s', got '%s'", "select * from t;", q[0].Query)) + if q, err := ParseQuery(query{Query: sql, Line: 1, delimiter: ";"}); err == nil { + assertEqual(t, q.tp, Q_SORTED_RESULT, "sorted_result") + assertEqual(t, q.Query, "select * from t;", fmt.Sprintf("Expected: '%s', got '%s'", "select * from t;", q.Query)) } else { t.Fatalf("error is not nil. %s", err) } // invalid comment command style sql = "--abc select * from t;" - _, err := ParseQueries(query{Query: sql, Line: 1}) + _, err := ParseQuery(query{Query: sql, Line: 1, delimiter: ";"}) assertEqual(t, err, ErrInvalidCommand, fmt.Sprintf("Expected: %v, got %v", ErrInvalidCommand, err)) sql = "--let $foo=`SELECT 1`" - if q, err := ParseQueries(query{Query: sql, Line: 1}); err == nil { - assertEqual(t, q[0].tp, Q_LET, fmt.Sprintf("Expected: %d, got: %d", Q_LET, q[0].tp)) + if q, err := ParseQuery(query{Query: sql, Line: 1, delimiter: ";"}); err == nil { + assertEqual(t, q.tp, Q_LET, fmt.Sprintf("Expected: %d, got: %d", Q_LET, q.tp)) + } +} + +func TestLoadQueries(t *testing.T) { + dir := t.TempDir() + err := os.Chdir(dir) + assert.NoError(t, err) + + err = os.Mkdir("t", 0755) + assert.NoError(t, err) + + testCases := []struct { + input string + queries []query + }{ + { + input: "delimiter |\n do something; select something; |\n delimiter ; \nselect 1;", + queries: []query{ + {Query: "do something; select something; |", tp: Q_QUERY, delimiter: "|"}, + {Query: "select 1;", tp: Q_QUERY, delimiter: ";"}, + }, + }, + { + input: "delimiter |\ndrop procedure if exists scopel\ncreate procedure scope(a int, b float)\nbegin\ndeclare b int;\ndeclare c float;\nbegin\ndeclare c int;\nend;\nend |\ndrop procedure scope|\ndelimiter ;\n", + queries: []query{ + {Query: "drop procedure if exists scopel\ncreate procedure scope(a int, b float)\nbegin\ndeclare b int;\ndeclare c float;\nbegin\ndeclare c int;\nend;\nend |", tp: Q_QUERY, delimiter: "|"}, + {Query: "drop procedure scope|", tp: Q_QUERY, delimiter: "|"}, + }, + }, + { + input: "--error 1054\nselect 1;", + queries: []query{ + {Query: " 1054", tp: Q_ERROR, delimiter: ";"}, + {Query: "select 1;", tp: Q_QUERY, delimiter: ";"}, + }, + }, + } + + for _, testCase := range testCases { + fileName := filepath.Join("t", "test.test") + f, err := os.Create(fileName) + assert.NoError(t, err) + + f.WriteString(testCase.input) + f.Close() + + test := newTester("test") + queries, err := test.loadQueries() + assert.NoError(t, err) + assert.Len(t, queries, len(testCase.queries)) + for i, query := range testCase.queries { + assert.Equal(t, queries[i].Query, query.Query) + assert.Equal(t, queries[i].tp, query.tp) + assert.Equal(t, queries[i].delimiter, query.delimiter) + } } }