diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 20a30cbc1c1..b8f05e02db3 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -51,6 +51,8 @@ var ( // ErrIncompatibleTypeCast indicates a casting problem ErrIncompatibleTypeCast = errors.New("Cannot convert value to desired type") + + ErrInvalidEncodedString = errors.New("invalid SQL encoded string") ) const ( @@ -861,6 +863,56 @@ var encodeRef = map[byte]byte{ '\\': '\\', } +// BufDecodeStringSQL decodes the string into a strings.Builder +func BufDecodeStringSQL(buf *strings.Builder, val string) error { + if len(val) < 2 || val[0] != '\'' || val[len(val)-1] != '\'' { + return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString) + } + in := hack.StringBytes(val[1 : len(val)-1]) + idx := 0 + for { + if idx >= len(in) { + return nil + } + ch := in[idx] + if ch == '\'' { + idx++ + if idx >= len(in) { + return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString) + } + if in[idx] != '\'' { + return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString) + } + buf.WriteByte(ch) + idx++ + continue + } + if ch == '\\' { + idx++ + if idx >= len(in) { + return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString) + } + decoded := SQLDecodeMap[in[idx]] + if decoded == DontEscape { + return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString) + } + buf.WriteByte(decoded) + idx++ + continue + } + + buf.WriteByte(ch) + idx++ + } +} + +// DecodeStringSQL encodes the string as a SQL string. +func DecodeStringSQL(val string) (string, error) { + var buf strings.Builder + err := BufDecodeStringSQL(&buf, val) + return buf.String(), err +} + func init() { for i := range SQLEncodeMap { SQLEncodeMap[i] = DontEscape diff --git a/go/sqltypes/value_test.go b/go/sqltypes/value_test.go index 10a46e09a9e..d6a9b510b9e 100644 --- a/go/sqltypes/value_test.go +++ b/go/sqltypes/value_test.go @@ -512,3 +512,60 @@ func TestHexAndBitToBytes(t *testing.T) { }) } } + +func TestEncodeStringSQL(t *testing.T) { + testcases := []struct { + in string + out string + }{ + { + in: "", + out: "''", + }, + { + in: "\x00'\"\b\n\r\t\x1A\\", + out: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'", + }, + } + for _, tcase := range testcases { + out := EncodeStringSQL(tcase.in) + assert.Equal(t, tcase.out, out) + } +} + +func TestDecodeStringSQL(t *testing.T) { + testcases := []struct { + in string + out string + err string + }{ + { + in: "", + err: ": invalid SQL encoded string", + }, { + in: "''", + err: "", + }, + { + in: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'", + out: "\x00'\"\b\n\r\t\x1A\\", + }, + { + in: "'light ''green\\r\\n, \\nfoo'", + out: "light 'green\r\n, \nfoo", + }, + { + in: "'foo \\\\ % _bar'", + out: "foo \\ % _bar", + }, + } + for _, tcase := range testcases { + out, err := DecodeStringSQL(tcase.in) + if tcase.err != "" { + assert.EqualError(t, err, tcase.err) + } else { + require.NoError(t, err) + assert.Equal(t, tcase.out, out) + } + } +} diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/enum-whitespace/alter b/go/test/endtoend/onlineddl/vrepl_suite/testdata/enum-whitespace/alter new file mode 100644 index 00000000000..39c00aa3903 --- /dev/null +++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/enum-whitespace/alter @@ -0,0 +1 @@ +change e e enum('red', 'light green', 'blue', 'orange', 'yellow') collate 'utf8_bin' null default null diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/enum-whitespace/create.sql b/go/test/endtoend/onlineddl/vrepl_suite/testdata/enum-whitespace/create.sql new file mode 100644 index 00000000000..741b06e9040 --- /dev/null +++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/enum-whitespace/create.sql @@ -0,0 +1,27 @@ +drop table if exists onlineddl_test; +create table onlineddl_test ( + id int auto_increment, + i int not null, + e enum('red', 'light green', 'blue', 'orange') null default null collate 'utf8_bin', + primary key(id) +) auto_increment=1; + +drop event if exists onlineddl_test; +delimiter ;; +create event onlineddl_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into onlineddl_test values (null, 11, 'red'); + insert into onlineddl_test values (null, 13, 'light green'); + insert into onlineddl_test values (null, 17, 'blue'); + set @last_insert_id := last_insert_id(); + update onlineddl_test set e='orange' where id = @last_insert_id; + insert into onlineddl_test values (null, 23, null); + set @last_insert_id := last_insert_id(); + update onlineddl_test set i=i+1, e=null where id = @last_insert_id; +end ;; diff --git a/go/vt/schema/parser.go b/go/vt/schema/parser.go index 6bf15057d13..7ed820a3687 100644 --- a/go/vt/schema/parser.go +++ b/go/vt/schema/parser.go @@ -22,7 +22,7 @@ import ( "strconv" "strings" - "vitess.io/vitess/go/textutil" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" ) @@ -113,22 +113,61 @@ func ParseSetValues(setColumnType string) string { // returns the (unquoted) text values // Expected input: `'x-small','small','medium','large','x-large'` // Unexpected input: `enum('x-small','small','medium','large','x-large')` -func parseEnumOrSetTokens(enumOrSetValues string) (tokens []string) { - if submatch := enumValuesRegexp.FindStringSubmatch(enumOrSetValues); len(submatch) > 0 { - // input should not contain `enum(...)` column definition, just the comma delimited list - return tokens - } - if submatch := setValuesRegexp.FindStringSubmatch(enumOrSetValues); len(submatch) > 0 { - // input should not contain `enum(...)` column definition, just the comma delimited list - return tokens - } - tokens = textutil.SplitDelimitedList(enumOrSetValues) - for i := range tokens { - if strings.HasPrefix(tokens[i], `'`) && strings.HasSuffix(tokens[i], `'`) { - tokens[i] = strings.Trim(tokens[i], `'`) +func parseEnumOrSetTokens(enumOrSetValues string) []string { + // We need to track both the start of the current value and current + // position, since there might be quoted quotes inside the value + // which we need to handle. + start := 0 + pos := 1 + var tokens []string + for { + // If the input does not start with a quote, it's not a valid enum/set definition + if enumOrSetValues[start] != '\'' { + return nil + } + i := strings.IndexByte(enumOrSetValues[pos:], '\'') + // If there's no closing quote, we have invalid input + if i < 0 { + return nil + } + // We're at the end here of the last quoted value, + // so we add the last token and return them. + if i == len(enumOrSetValues[pos:])-1 { + tok, err := sqltypes.DecodeStringSQL(enumOrSetValues[start:]) + if err != nil { + return nil + } + tokens = append(tokens, tok) + return tokens + } + // MySQL double quotes things as escape value, so if we see another + // single quote, we skip the character and remove it from the input. + if enumOrSetValues[pos+i+1] == '\'' { + pos = pos + i + 2 + continue } + // Next value needs to be a comma as a separator, otherwise + // the data is invalid so we return nil. + if enumOrSetValues[pos+i+1] != ',' { + return nil + } + // If we're at the end of the input here, it's invalid + // since we have a trailing comma which is not what MySQL + // returns. + if pos+i+1 == len(enumOrSetValues) { + return nil + } + + tok, err := sqltypes.DecodeStringSQL(enumOrSetValues[start : pos+i+1]) + if err != nil { + return nil + } + + tokens = append(tokens, tok) + // We add 2 to the position to skip the closing quote & comma + start = pos + i + 2 + pos = start + 1 } - return tokens } // ParseEnumOrSetTokensMap parses the comma delimited part of an enum column definition diff --git a/go/vt/schema/parser_test.go b/go/vt/schema/parser_test.go index 5154411e829..d251a195d1d 100644 --- a/go/vt/schema/parser_test.go +++ b/go/vt/schema/parser_test.go @@ -89,6 +89,19 @@ func TestParseEnumValues(t *testing.T) { assert.Equal(t, input, enumValues) } } + + { + inputs := []string{ + ``, + `abc`, + `func('x small','small','medium','large','x large')`, + `set('x small','small','medium','large','x large')`, + } + for _, input := range inputs { + enumValues := ParseEnumValues(input) + assert.Equal(t, input, enumValues) + } + } } func TestParseSetValues(t *testing.T) { @@ -125,6 +138,18 @@ func TestParseEnumTokens(t *testing.T) { expect := []string{"x-small", "small", "medium", "large", "x-large"} assert.Equal(t, expect, enumTokens) } + { + input := `'x small','small','medium','large','x large'` + enumTokens := parseEnumOrSetTokens(input) + expect := []string{"x small", "small", "medium", "large", "x large"} + assert.Equal(t, expect, enumTokens) + } + { + input := `'with '' quote','and \n newline'` + enumTokens := parseEnumOrSetTokens(input) + expect := []string{"with ' quote", "and \n newline"} + assert.Equal(t, expect, enumTokens) + } { input := `enum('x-small','small','medium','large','x-large')` enumTokens := parseEnumOrSetTokens(input) diff --git a/go/vt/vttablet/onlineddl/vrepl/columns_test.go b/go/vt/vttablet/onlineddl/vrepl/columns_test.go index 201ffe55201..32efd104cc1 100644 --- a/go/vt/vttablet/onlineddl/vrepl/columns_test.go +++ b/go/vt/vttablet/onlineddl/vrepl/columns_test.go @@ -333,11 +333,11 @@ func TestGetExpandedColumnNames(t *testing.T) { "expand enum", Column{ Type: EnumColumnType, - EnumValues: "'a', 'b'", + EnumValues: "'a','b'", }, Column{ Type: EnumColumnType, - EnumValues: "'a', 'x'", + EnumValues: "'a','x'", }, true, }, @@ -345,11 +345,11 @@ func TestGetExpandedColumnNames(t *testing.T) { "expand enum", Column{ Type: EnumColumnType, - EnumValues: "'a', 'b'", + EnumValues: "'a','b'", }, Column{ Type: EnumColumnType, - EnumValues: "'a', 'b', 'c'", + EnumValues: "'a','b','c'", }, true, }, @@ -357,11 +357,11 @@ func TestGetExpandedColumnNames(t *testing.T) { "reduce enum", Column{ Type: EnumColumnType, - EnumValues: "'a', 'b', 'c'", + EnumValues: "'a','b','c'", }, Column{ Type: EnumColumnType, - EnumValues: "'a', 'b'", + EnumValues: "'a','b'", }, false, },