Skip to content

Commit

Permalink
Add parser for enum / set values
Browse files Browse the repository at this point in the history
This handles cases like nested quotes etc. in the input.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed Mar 15, 2024
1 parent 8dcf432 commit 53b09f7
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 14 deletions.
52 changes: 52 additions & 0 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ var (

// ErrIncompatibleTypeCast indicates a casting problem
ErrIncompatibleTypeCast = errors.New("Cannot convert value to desired type")

ErrInvalidEncodedString = errors.New("invalid SQL encoded string")
)

const (
Expand Down Expand Up @@ -861,6 +863,56 @@ var encodeRef = map[byte]byte{
'\\': '\\',
}

// BufDecodeStringSQL decodes the single-quoted SQL string literal val and
// appends the decoded bytes to buf.
//
// val must start and end with a single quote. Inside the quotes, a doubled
// single quote ('') decodes to one single quote, and a backslash followed by
// a byte is translated through SQLDecodeMap. Every other byte is copied as is.
//
// An error wrapping ErrInvalidEncodedString is returned when val is not
// enclosed in single quotes, when a single quote inside the literal is not
// doubled, when a backslash is the last byte of the literal, or when
// SQLDecodeMap yields DontEscape for the byte following a backslash. Bytes
// decoded before the error was detected have already been written to buf.
func BufDecodeStringSQL(buf *strings.Builder, val string) error {
	if len(val) < 2 || val[0] != '\'' || val[len(val)-1] != '\'' {
		return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
	}
	// Work on the content between the enclosing quotes only.
	body := hack.StringBytes(val[1 : len(val)-1])
	for i := 0; i < len(body); i++ {
		c := body[i]
		switch c {
		case '\'':
			// A quote inside the literal is only valid as the first half
			// of a doubled quote; consume the second half here.
			i++
			if i >= len(body) || body[i] != '\'' {
				return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
			}
			buf.WriteByte(c)
		case '\\':
			// Backslash escape: the next byte selects the decoded value.
			i++
			if i >= len(body) {
				return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
			}
			decoded := SQLDecodeMap[body[i]]
			if decoded == DontEscape {
				return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
			}
			buf.WriteByte(decoded)
		default:
			buf.WriteByte(c)
		}
	}
	return nil
}

// DecodeStringSQL decodes a single-quoted SQL string literal and returns the
// decoded value. It is a convenience wrapper around BufDecodeStringSQL that
// allocates its own strings.Builder.
//
// On failure it returns an empty string together with the error reported by
// BufDecodeStringSQL, rather than whatever had been decoded before the
// invalid input was detected.
func DecodeStringSQL(val string) (string, error) {
	var buf strings.Builder
	if err := BufDecodeStringSQL(&buf, val); err != nil {
		return "", err
	}
	return buf.String(), nil
}

func init() {
for i := range SQLEncodeMap {
SQLEncodeMap[i] = DontEscape
Expand Down
57 changes: 57 additions & 0 deletions go/sqltypes/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,60 @@ func TestHexAndBitToBytes(t *testing.T) {
})
}
}

// TestEncodeStringSQL checks EncodeStringSQL against a table of raw inputs
// and the SQL string literal expected for each of them.
func TestEncodeStringSQL(t *testing.T) {
	type testcase struct {
		in  string
		out string
	}
	cases := []testcase{
		// The empty string encodes to an empty pair of quotes.
		{in: "", out: "''"},
		// Every byte that has a dedicated escape sequence.
		{in: "\x00'\"\b\n\r\t\x1A\\", out: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'"},
	}
	for i := range cases {
		got := EncodeStringSQL(cases[i].in)
		assert.Equal(t, cases[i].out, got)
	}
}

// TestDecodeStringSQL checks DecodeStringSQL against a table of SQL string
// literals. A case with a non-empty err expects exactly that error message;
// every other case expects success and the decoded value in out.
func TestDecodeStringSQL(t *testing.T) {
	type testcase struct {
		in  string
		out string
		err string
	}
	cases := []testcase{
		// Input without enclosing quotes is rejected.
		{in: "", err: ": invalid SQL encoded string"},
		// An empty literal decodes to the empty string.
		{in: "''", err: ""},
		// Every backslash escape sequence.
		{in: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'", out: "\x00'\"\b\n\r\t\x1A\\"},
		// Doubled quote mixed with backslash escapes.
		{in: "'light ''green\\r\\n, \\nfoo'", out: "light 'green\r\n, \nfoo"},
		// Escaped backslash next to unescaped % and _.
		{in: "'foo \\\\ % _bar'", out: "foo \\ % _bar"},
	}
	for _, tc := range cases {
		got, err := DecodeStringSQL(tc.in)
		if tc.err != "" {
			assert.EqualError(t, err, tc.err)
			continue
		}
		require.NoError(t, err)
		assert.Equal(t, tc.out, got)
	}
}
68 changes: 54 additions & 14 deletions go/vt/schema/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"strconv"
"strings"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/sqlparser"
)

Expand Down Expand Up @@ -112,22 +113,61 @@ func ParseSetValues(setColumnType string) string {
// returns the (unquoted) text values
// Expected input: `'x-small','small','medium','large','x-large'`
// Unexpected input: `enum('x-small','small','medium','large','x-large')`
func parseEnumOrSetTokens(enumOrSetValues string) (tokens []string) {
if submatch := enumValuesRegexp.FindStringSubmatch(enumOrSetValues); len(submatch) > 0 {
// input should not contain `enum(...)` column definition, just the comma delimited list
return tokens
}
if submatch := setValuesRegexp.FindStringSubmatch(enumOrSetValues); len(submatch) > 0 {
// input should not contain `enum(...)` column definition, just the comma delimited list
return tokens
}
tokens = strings.Split(enumOrSetValues, ",")
for i := range tokens {
if strings.HasPrefix(tokens[i], `'`) && strings.HasSuffix(tokens[i], `'`) {
tokens[i] = strings.Trim(tokens[i], `'`)
func parseEnumOrSetTokens(enumOrSetValues string) []string {
// We need to track both the start of the current value and current
// position, since there might be quoted quotes inside the value
// which we need to handle.
start := 0
pos := 1
var tokens []string
for {
// If the input does not start with a quote, it's not a valid enum/set definition
if enumOrSetValues[start] != '\'' {
return nil
}
i := strings.IndexByte(enumOrSetValues[pos:], '\'')
// If there's no closing quote, we have invalid input
if i < 0 {
return nil
}
// We're at the end here of the last quoted value,
// so we add the last token and return them.
if i == len(enumOrSetValues[pos:])-1 {
tok, err := sqltypes.DecodeStringSQL(enumOrSetValues[start:])
if err != nil {
return nil
}
tokens = append(tokens, tok)
return tokens
}
// MySQL double quotes things as escape value, so if we see another
// single quote, we skip the character and remove it from the input.
if enumOrSetValues[pos+i+1] == '\'' {
pos = pos + i + 2
continue
}
// Next value needs to be a comma as a separator, otherwise
// the data is invalid so we return nil.
if enumOrSetValues[pos+i+1] != ',' {
return nil
}
// If we're at the end of the input here, it's invalid
// since we have a trailing comma which is not what MySQL
// returns.
if pos+i+1 == len(enumOrSetValues) {
return nil
}

tok, err := sqltypes.DecodeStringSQL(enumOrSetValues[start : pos+i+1])
if err != nil {
return nil
}

tokens = append(tokens, tok)
// We add 2 to the position to skip the closing quote & comma
start = pos + i + 2
pos = start + 1
}
return tokens
}

// ParseEnumOrSetTokensMap parses the comma delimited part of an enum column definition
Expand Down
6 changes: 6 additions & 0 deletions go/vt/schema/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ func TestParseEnumTokens(t *testing.T) {
expect := []string{"x small", "small", "medium", "large", "x large"}
assert.Equal(t, expect, enumTokens)
}
{
input := `'with '' quote','and \n newline'`
enumTokens := parseEnumOrSetTokens(input)
expect := []string{"with ' quote", "and \n newline"}
assert.Equal(t, expect, enumTokens)
}
{
input := `enum('x-small','small','medium','large','x-large')`
enumTokens := parseEnumOrSetTokens(input)
Expand Down

0 comments on commit 53b09f7

Please sign in to comment.