From f592150fc1d1f889895e36ced9d2f6ad127477a9 Mon Sep 17 00:00:00 2001 From: Maximilian Hoffman Date: Mon, 26 Aug 2024 11:53:24 -0700 Subject: [PATCH] [commands] faster sqldump query scanner (#8291) * [commands] faster sqldump query scanner * imports seem to work * remove logs * more perf, don't parse twice * [ga-format-pr] Run go/utils/repofmt/format_repo.sh and go/Godeps/update.sh * fix delimiters * delete old code * more comments * simplify a bit * add line numbers back * delimiter line num * picky line nums * more comments * zach comment * more comments --------- Co-authored-by: max-hoffman --- go/cmd/dolt/cli/command.go | 3 + go/cmd/dolt/commands/engine/sqlengine.go | 6 + go/cmd/dolt/commands/filter-branch.go | 2 +- go/cmd/dolt/commands/sql.go | 12 +- go/cmd/dolt/commands/sql_statement_scanner.go | 333 +++++++++++++----- .../commands/sql_statement_scanner_test.go | 15 +- .../dolt/commands/sqlserver/queryist_utils.go | 6 + 7 files changed, 278 insertions(+), 99 deletions(-) diff --git a/go/cmd/dolt/cli/command.go b/go/cmd/dolt/cli/command.go index a641e27371..f8b23fe783 100644 --- a/go/cmd/dolt/cli/command.go +++ b/go/cmd/dolt/cli/command.go @@ -23,6 +23,8 @@ import ( "syscall" "github.com/dolthub/go-mysql-server/sql" + querypb "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/fatih/color" eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1" @@ -86,6 +88,7 @@ type SignalCommand interface { // SQL. The Queryist can be obtained from the CliContext passed into the Exec method by calling the QueryEngine method. type Queryist interface { Query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) + QueryWithBindings(ctx *sql.Context, query string, parsed sqlparser.Statement, bindings map[string]*querypb.BindVariable, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) } // This type is to store the content of a documented command, elsewhere we can transform this struct into diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index d856166ff7..f9b470671f 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -30,6 +30,8 @@ import ( "github.com/dolthub/go-mysql-server/sql/mysql_db" "github.com/dolthub/go-mysql-server/sql/rowexec" _ "github.com/dolthub/go-mysql-server/sql/variables" + querypb "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/sirupsen/logrus" "github.com/dolthub/dolt/go/cmd/dolt/cli" @@ -310,6 +312,10 @@ func (se *SqlEngine) Query(ctx *sql.Context, query string) (sql.Schema, sql.RowI return se.engine.Query(ctx, query) } +func (se *SqlEngine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlparser.Statement, bindings map[string]*querypb.BindVariable, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + return se.engine.QueryWithBindings(ctx, query, parsed, bindings, qFlags) +} + // Analyze analyzes a node. func (se *SqlEngine) Analyze(ctx *sql.Context, n sql.Node, qFlags *sql.QueryFlags) (sql.Node, error) { return se.engine.Analyzer.Analyze(ctx, n, nil, qFlags) diff --git a/go/cmd/dolt/commands/filter-branch.go b/go/cmd/dolt/commands/filter-branch.go index f491075602..26dccb6622 100644 --- a/go/cmd/dolt/commands/filter-branch.go +++ b/go/cmd/dolt/commands/filter-branch.go @@ -282,7 +282,7 @@ func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, root doltdb.Root return nil, err } - scanner := NewSqlStatementScanner(strings.NewReader(query)) + scanner := newStreamScanner(strings.NewReader(query)) if err != nil { return nil, err } diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index 6646f5062a..1e07ac3fc5 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -612,7 +612,7 @@ func saveQuery(ctx *sql.Context, root doltdb.RootValue, query string, name strin // execBatchMode runs all the queries in the input reader func execBatchMode(ctx *sql.Context, qryist cli.Queryist, input io.Reader, continueOnErr bool, format engine.PrintResultFormat) error { - scanner := NewSqlStatementScanner(input) + scanner := newStreamScanner(input) var query string for scanner.Scan() { if fileReadProg != nil { @@ -630,7 +630,7 @@ func execBatchMode(ctx *sql.Context, qryist cli.Queryist, input io.Reader, conti if err == sqlparser.ErrEmpty { continue } else if err != nil { - err = buildBatchSqlErr(scanner.statementStartLine, query, err) + err = buildBatchSqlErr(scanner.state.statementStartLine, query, err) if !continueOnErr { return err } else { @@ -642,7 +642,7 @@ func execBatchMode(ctx *sql.Context, qryist cli.Queryist, input io.Reader, conti ctx.SetQueryTime(time.Now()) sqlSch, rowIter, _, err := processParsedQuery(ctx, query, qryist, sqlStatement) if err != nil { - err = buildBatchSqlErr(scanner.statementStartLine, query, err) + err = buildBatchSqlErr(scanner.state.statementStartLine, query, err) if !continueOnErr { return err } else { @@ -661,7 +661,7 @@ func execBatchMode(ctx *sql.Context, qryist cli.Queryist, input io.Reader, conti } err = engine.PrettyPrintResults(ctx, format, sqlSch, rowIter) if err != nil { - err = buildBatchSqlErr(scanner.statementStartLine, query, err) + err = buildBatchSqlErr(scanner.state.statementStartLine, query, err) if !continueOnErr { return err } else { @@ -673,7 +673,7 @@ func execBatchMode(ctx *sql.Context, qryist cli.Queryist, input io.Reader, conti } if err := scanner.Err(); err != nil { - return buildBatchSqlErr(scanner.statementStartLine, query, err) + return buildBatchSqlErr(scanner.state.statementStartLine, query, err) } return nil @@ -1122,7 +1122,7 @@ func processParsedQuery(ctx *sql.Context, query string, qryist cli.Queryist, sql } return qryist.Query(ctx, query) default: - return qryist.Query(ctx, query) + return qryist.QueryWithBindings(ctx, query, sqlStatement, nil, nil) } } diff --git a/go/cmd/dolt/commands/sql_statement_scanner.go b/go/cmd/dolt/commands/sql_statement_scanner.go index 51919b17d4..c21d239a78 100755 --- a/go/cmd/dolt/commands/sql_statement_scanner.go +++ b/go/cmd/dolt/commands/sql_statement_scanner.go @@ -16,8 +16,9 @@ package commands import ( "bufio" + "bytes" + "fmt" "io" - "regexp" "unicode" ) @@ -29,23 +30,8 @@ type statementScanner struct { Delimiter string } -const maxStatementBufferBytes = 100 * 1024 * 1024 - -func NewSqlStatementScanner(input io.Reader) *statementScanner { - scanner := bufio.NewScanner(input) - const initialCapacity = 512 * 1024 - buf := make([]byte, initialCapacity) - scanner.Buffer(buf, maxStatementBufferBytes) - - s := &statementScanner{ - Scanner: scanner, - lineNum: 1, - Delimiter: ";", - } - scanner.Split(s.scanStatements) - - return s -} +const maxStatementBufferBytes = 100*1024*1024 + 4096 +const pageSize = 2 << 11 const ( sQuote byte = '\'' @@ -54,89 +40,272 @@ const ( backtick = '`' ) -var scannerDelimiterRegex = regexp.MustCompile(`(?i)^\s*DELIMITER\s+(\S+)\s*`) +const delimPrefixLen = 10 + +var delimPrefix = []byte("delimiter ") + +// streamScanner is an iterator that reads bytes from |inp| until either +// (1) we match a DELIMITER statement, (2) we match the |delimiter| token, +// or (3) we EOF the file. After each Scan() call, the valid token will +// span from the buffer beginning to |state.end|. +type streamScanner struct { + inp io.Reader + buf []byte + maxSize int + i int // leading byte + fill int + err error + isEOF bool + delimiter []byte + lineNum int + state *qState +} + +func newStreamScanner(r io.Reader) *streamScanner { + return &streamScanner{inp: r, buf: make([]byte, pageSize), maxSize: maxStatementBufferBytes, delimiter: []byte(";"), state: new(qState)} +} + +type qState struct { + start int + end int // token end, usually i - len(delimiter) + quoteChar byte // the opening quote character of the current quote being parsed, or 0 if the current parse location isn't inside a quoted string + lastChar byte // the last character parsed + ignoreNextChar bool // whether to ignore the next character + numConsecutiveBackslashes int // the number of consecutive backslashes encountered + seenNonWhitespaceChar bool // whether we have encountered a non-whitespace character since we returned the last token + numConsecutiveDelimiterMatches int // the consecutive number of characters that have been matched to the delimiter + statementStartLine int +} -// ScanStatements is a split function for a Scanner that returns each SQL statement in the input as a token. -func (s *statementScanner) scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil +func (s *streamScanner) Scan() bool { + // truncate last query + s.truncate() + s.resetState() + + if s.i >= s.fill { + // initialize buffer + if err := s.read(); err != nil { + s.err = err + return false + } } - var ( - quoteChar byte // the opening quote character of the current quote being parsed, or 0 if the current parse location isn't inside a quoted string - lastChar byte // the last character parsed - ignoreNextChar bool // whether to ignore the next character - numConsecutiveBackslashes int // the number of consecutive backslashes encountered - seenNonWhitespaceChar bool // whether we have encountered a non-whitespace character since we returned the last token - numConsecutiveDelimiterMatches int // the consecutive number of characters that have been matched to the delimiter - ) - - s.startLineNum = s.lineNum - - if idxs := scannerDelimiterRegex.FindIndex(data); len(idxs) == 2 { - s.Delimiter = scannerDelimiterRegex.FindStringSubmatch(string(data))[1] - // Returning a nil token is interpreted as an error condition, so we return an empty token instead - return idxs[1], []byte{}, nil + if s.isEOF || s.i == s.fill { + // no token + return false } - for i := 0; i < len(data); i++ { - if !ignoreNextChar { + // discard leading whitespace + for ; unicode.IsSpace(rune(s.buf[s.i])); s.i++ { + if s.buf[s.i] == '\n' { + s.lineNum++ + } + if s.i >= s.fill { + if err := s.read(); err != nil { + s.err = err + return false + } + } + } + s.truncate() + + s.state.statementStartLine = s.lineNum + 1 + + if err, ok := s.isDelimiterExpr(); err != nil { + s.err = err + return false + } else if ok { + // empty token acks DELIMITER + return true + } + + for { + if err, ok := s.seekDelimiter(); err != nil { + s.err = err + return false + } else if ok { + // delimiter found, scanner holds valid token state + return true + } else if s.isEOF && s.i == s.fill { + // token terminates with file + s.state.end = s.fill + return true + } + // haven't found delimiter yet, keep reading + if err := s.read(); err != nil { + s.err = err + return false + } + } +} + +func (s *streamScanner) truncate() { + // copy size should be 4k or less + s.state.start = s.i + s.state.end = s.i +} + +func (s *streamScanner) resetState() { + s.state = &qState{} +} + +func (s *streamScanner) read() error { + if s.fill >= s.maxSize { + // if script exceeds buffer that's OK, if + // a single query exceeds buffer that's not OK + if s.state.start == 0 { + return fmt.Errorf("exceeded max query size") + } + // discard previous queries, resulting buffer will start + // at the current |start| + s.fill -= s.state.start + s.i -= s.state.start + s.state.end = s.state.start + copy(s.buf[:], s.buf[s.state.start:]) + s.state.start = 0 + return s.read() + } + if s.fill == len(s.buf) { + newBufSize := min(len(s.buf)*2, s.maxSize) + newBuf := make([]byte, newBufSize) + copy(newBuf, s.buf) + s.buf = newBuf + } + n, err := s.inp.Read(s.buf[s.fill:]) + if err == io.EOF { + s.isEOF = true + } else if err != nil { + return err + } + s.fill += n + return nil +} + +func (s *streamScanner) Err() error { + return s.err +} + +func (s *streamScanner) Bytes() []byte { + return s.buf[s.state.start:s.state.end] +} + +// Text returns the most recent token generated by a call to [Scanner.Scan] +// as a newly allocated string holding its bytes. +func (s *streamScanner) Text() string { + return string(s.Bytes()) +} + +func (s *streamScanner) isDelimiterExpr() (error, bool) { + if s.i == 0 && s.fill-s.i < delimPrefixLen { + // need to see first |delimPrefixLen| characters + if err := s.read(); err != nil { + s.err = err + return err, false + } + } + + // valid delimiter state machine check + // "DELIMITER " -> 0+ spaces -> -> 1 space + if s.fill-s.i >= delimPrefixLen && bytes.EqualFold(s.buf[s.i:s.i+delimPrefixLen], delimPrefix) { + delimTokenIdx := s.i + s.i += delimPrefixLen + for ; !s.isEOF && unicode.IsSpace(rune(s.buf[s.i])); s.i++ { + if s.i >= s.fill { + if err := s.read(); err != nil { + s.err = err + return err, false + } + } + } + if s.isEOF { + // invalid delimiter + s.i = delimTokenIdx + return nil, false + } + delimStart := s.i + for ; !s.isEOF && !unicode.IsSpace(rune(s.buf[s.i])); s.i++ { + if s.i >= s.fill { + if err := s.read(); err != nil { + s.err = err + return err, false + } + } + } + delimEnd := s.i + s.delimiter = make([]byte, delimEnd-delimStart) + copy(s.delimiter, s.buf[delimStart:delimEnd]) + + // discard delimiter token, return empty token + s.truncate() + return nil, true + } + return nil, false +} + +func (s *streamScanner) seekDelimiter() (error, bool) { + if s.i >= s.fill { + return nil, false + } + + for ; s.i < s.fill; s.i++ { + i := s.i + if !s.state.ignoreNextChar { // this doesn't handle unicode characters correctly and will break on some things, but it's only used for line // number reporting. - if !seenNonWhitespaceChar && !unicode.IsSpace(rune(data[i])) { - seenNonWhitespaceChar = true - s.statementStartLine = s.lineNum + if !s.state.seenNonWhitespaceChar && !unicode.IsSpace(rune(s.buf[i])) { + s.state.seenNonWhitespaceChar = true } + // check if we've matched the delimiter string - if quoteChar == 0 && data[i] == s.Delimiter[numConsecutiveDelimiterMatches] { - numConsecutiveDelimiterMatches++ - if numConsecutiveDelimiterMatches == len(s.Delimiter) { - s.startLineNum = s.lineNum - _, _, _ = s.resetState() - removalLength := len(s.Delimiter) - 1 // We remove the delimiter so it depends on the length - return i + 1, data[0 : i-removalLength], nil + if s.state.quoteChar == 0 && s.buf[i] == s.delimiter[s.state.numConsecutiveDelimiterMatches] { + s.state.numConsecutiveDelimiterMatches++ + if s.state.numConsecutiveDelimiterMatches == len(s.delimiter) { + s.state.end = s.i - len(s.delimiter) + 1 + s.i++ + return nil, true } - lastChar = data[i] + s.state.lastChar = s.buf[i] continue } else { - numConsecutiveDelimiterMatches = 0 + s.state.numConsecutiveDelimiterMatches = 0 } - switch data[i] { + switch s.buf[i] { case '\n': s.lineNum++ case backslash: - numConsecutiveBackslashes++ + s.state.numConsecutiveBackslashes++ case sQuote, dQuote, backtick: - prevNumConsecutiveBackslashes := numConsecutiveBackslashes - numConsecutiveBackslashes = 0 + prevNumConsecutiveBackslashes := s.state.numConsecutiveBackslashes + s.state.numConsecutiveBackslashes = 0 // escaped quote character - if lastChar == backslash && prevNumConsecutiveBackslashes%2 == 1 { + if s.state.lastChar == backslash && prevNumConsecutiveBackslashes%2 == 1 { break } // currently in a quoted string - if quoteChar != 0 { + if s.state.quoteChar != 0 { + if i+1 >= s.fill { + // require lookahead or EOF + if err := s.read(); err != nil { + return err, false + } + } // end quote or two consecutive quote characters (a form of escaping quote chars) - if quoteChar == data[i] { + if s.state.quoteChar == s.buf[i] { var nextChar byte = 0 - if i+1 < len(data) { - nextChar = data[i+1] + if i+1 < s.fill { + nextChar = s.buf[i+1] } - if nextChar == quoteChar { + if nextChar == s.state.quoteChar { // escaped quote. skip the next character - ignoreNextChar = true - break - } else if atEOF || i+1 < len(data) { - // end quote - quoteChar = 0 - break + s.state.ignoreNextChar = true } else { - // need more data to make a decision - return s.resetState() + // end quote + s.state.quoteChar = 0 } } @@ -145,29 +314,15 @@ func (s *statementScanner) scanStatements(data []byte, atEOF bool) (advance int, } // open quote - quoteChar = data[i] + s.state.quoteChar = s.buf[i] default: - numConsecutiveBackslashes = 0 + s.state.numConsecutiveBackslashes = 0 } } else { - ignoreNextChar = false + s.state.ignoreNextChar = false } - lastChar = data[i] - } - - // If we're at EOF, we have a final, non-terminated line. Return it. - if atEOF { - return len(data), data, nil + s.state.lastChar = s.buf[i] } - - // Request more data. - return s.resetState() -} - -// resetState resets the internal state of the scanner and returns the "more data" response for a split function -func (s *statementScanner) resetState() (advance int, token []byte, err error) { - // rewind the line number to where we started parsing this token - s.lineNum = s.startLineNum - return 0, nil, nil + return nil, false } diff --git a/go/cmd/dolt/commands/sql_statement_scanner_test.go b/go/cmd/dolt/commands/sql_statement_scanner_test.go index 507ef4047d..7217e7b97b 100755 --- a/go/cmd/dolt/commands/sql_statement_scanner_test.go +++ b/go/cmd/dolt/commands/sql_statement_scanner_test.go @@ -178,20 +178,29 @@ primary key (a))`, 1, 2, 6, }, }, + { + input: `DELIMITER | +insert into foo values (1,2,3)|`, + statements: []string{ + "", + "insert into foo values (1,2,3)", + }, + lineNums: []int{1, 2}, + }, } for _, tt := range testcases { t.Run(tt.input, func(t *testing.T) { reader := strings.NewReader(tt.input) - scanner := NewSqlStatementScanner(reader) + scanner := newStreamScanner(reader) var i int for scanner.Scan() { require.True(t, i < len(tt.statements)) assert.Equal(t, tt.statements[i], strings.TrimSpace(scanner.Text())) if tt.lineNums != nil { - assert.Equal(t, tt.lineNums[i], scanner.statementStartLine) + assert.Equal(t, tt.lineNums[i], scanner.state.statementStartLine) } else { - assert.Equal(t, 1, scanner.statementStartLine) + assert.Equal(t, 1, scanner.state.statementStartLine) } i++ } diff --git a/go/cmd/dolt/commands/sqlserver/queryist_utils.go b/go/cmd/dolt/commands/sqlserver/queryist_utils.go index 2c838203e5..b15672498f 100644 --- a/go/cmd/dolt/commands/sqlserver/queryist_utils.go +++ b/go/cmd/dolt/commands/sqlserver/queryist_utils.go @@ -22,6 +22,8 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" + querypb "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/go-sql-driver/mysql" "github.com/gocraft/dbr/v2" "github.com/gocraft/dbr/v2/dialect" @@ -92,6 +94,10 @@ func (c ConnectionQueryist) Query(ctx *sql.Context, query string) (sql.Schema, s return rowIter.Schema(), rowIter, nil, nil } +func (c ConnectionQueryist) QueryWithBindings(ctx *sql.Context, query string, _ sqlparser.Statement, _ map[string]*querypb.BindVariable, _ *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + return c.Query(ctx, query) +} + type MysqlRowWrapper struct { rows *sql2.Rows schema sql.Schema