Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

evalengine: Proper support for bit literals #14374

Merged
merged 4 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go/mysql/collations/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ func Default() ID {
}

func DefaultCollationForType(t sqltypes.Type) ID {
return CollationForType(t, Default())
}

func CollationForType(t sqltypes.Type, fallback ID) ID {
switch {
case sqltypes.IsText(t):
return Default()
return fallback
case t == sqltypes.TypeJSON:
return CollationUtf8mb4ID
default:
Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func TestBitVals(t *testing.T) {

mcmp.AssertMatches(`select b'1001', 0x9, B'010011011010'`, `[[VARBINARY("\t") VARBINARY("\t") VARBINARY("\x04\xda")]]`)
mcmp.AssertMatches(`select b'1001', 0x9, B'010011011010' from t1`, `[[VARBINARY("\t") VARBINARY("\t") VARBINARY("\x04\xda")]]`)
mcmp.AssertMatchesNoCompare(`select 1 + b'1001', 2 + 0x9, 3 + B'010011011010'`, `[[INT64(10) UINT64(11) INT64(1245)]]`, `[[UINT64(10) UINT64(11) UINT64(1245)]]`)
mcmp.AssertMatchesNoCompare(`select 1 + b'1001', 2 + 0x9, 3 + B'010011011010' from t1`, `[[INT64(10) UINT64(11) INT64(1245)]]`, `[[UINT64(10) UINT64(11) UINT64(1245)]]`)
mcmp.AssertMatchesNoCompare(`select 1 + b'1001', 2 + 0x9, 3 + B'010011011010'`, `[[INT64(10) UINT64(11) INT64(1245)]]`, `[[INT64(10) UINT64(11) INT64(1245)]]`)
mcmp.AssertMatchesNoCompare(`select 1 + b'1001', 2 + 0x9, 3 + B'010011011010' from t1`, `[[INT64(10) UINT64(11) INT64(1245)]]`, `[[INT64(10) UINT64(11) INT64(1245)]]`)
}

func TestHexVals(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestNormalizeAllFields(t *testing.T) {
defer conn.Close()

insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* HEXNUM */, :vtg16 /* HEXNUM */)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
selectQuery := "select * from t1"
utils.Exec(t, conn, insertQuery)
qr := utils.Exec(t, conn, selectQuery)
Expand Down
4 changes: 1 addition & 3 deletions go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -1314,12 +1314,10 @@ func (node *Literal) Format(buf *TrackedBuffer) {
switch node.Type {
case StrVal:
sqltypes.MakeTrusted(sqltypes.VarBinary, node.Bytes()).EncodeSQL(buf)
case IntVal, FloatVal, DecimalVal, HexNum:
case IntVal, FloatVal, DecimalVal, HexNum, BitNum:
vmg marked this conversation as resolved.
Show resolved Hide resolved
buf.astPrintf(node, "%#s", node.Val)
case HexVal:
buf.astPrintf(node, "X'%#s'", node.Val)
case BitVal:
buf.astPrintf(node, "B'%#s'", node.Val)
case DateVal:
buf.astPrintf(node, "date'%#s'", node.Val)
case TimeVal:
Expand Down
6 changes: 1 addition & 5 deletions go/vt/sqlparser/ast_format_fast.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ const (
FloatVal
HexNum
HexVal
BitVal
BitNum
DateVal
TimeVal
TimestampVal
Expand Down Expand Up @@ -515,9 +515,9 @@ func NewHexLiteral(in string) *Literal {
return &Literal{Type: HexVal, Val: in}
}

// NewBitLiteral builds a new BitVal containing a bit literal.
// NewBitLiteral builds a new BitNum containing a bit literal.
func NewBitLiteral(in string) *Literal {
return &Literal{Type: BitVal, Val: in}
return &Literal{Type: BitNum, Val: in}
}

// NewDateLiteral builds a new Date.
Expand Down Expand Up @@ -583,8 +583,8 @@ func (node *Literal) SQLType() sqltypes.Type {
return sqltypes.HexNum
case HexVal:
return sqltypes.HexVal
case BitVal:
return sqltypes.HexNum
case BitNum:
return sqltypes.BitNum
case DateVal:
return sqltypes.Date
case TimeVal:
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func LiteralToValue(lit *Literal) (sqltypes.Value, error) {
return parseHexLiteral(b[1:])
case HexVal:
return parseHexLiteral(lit.Bytes())
case BitVal:
case BitNum:
return parseBitLiteral(lit.Bytes())
case DateVal:
d, ok := datetime.ParseDate(lit.Val)
Expand Down
20 changes: 5 additions & 15 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ package sqlparser

import (
"bytes"
"math/big"

"vitess.io/vitess/go/mysql/datetime"
"vitess.io/vitess/go/mysql/hex"
"vitess.io/vitess/go/sqltypes"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -365,19 +363,11 @@ func SQLToBindvar(node SQLNode) *querypb.BindVariable {
buf = append(buf, bytes.ToUpper(node.Bytes())...)
buf = append(buf, '\'')
v, err = sqltypes.NewValue(sqltypes.HexVal, buf)
case BitVal:
// Convert bit value to hex number in parameterized query format
var i big.Int
_, ok := i.SetString(string(node.Bytes()), 2)
if !ok {
return nil
}

buf := i.Bytes()
out := make([]byte, 0, (len(buf)*2)+2)
out = append(out, '0', 'x')
out = append(out, hex.EncodeBytes(buf)...)
v, err = sqltypes.NewValue(sqltypes.HexNum, out)
case BitNum:
out := make([]byte, 0, len(node.Bytes())+2)
out = append(out, '0', 'b')
out = append(out, node.Bytes()[2:]...)
v, err = sqltypes.NewValue(sqltypes.BitNum, out)
case DateVal:
v, err = sqltypes.NewValue(sqltypes.Date, node.Bytes())
case TimeVal:
Expand Down
24 changes: 12 additions & 12 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,23 +191,23 @@ func TestNormalize(t *testing.T) {
}, {
// Bin values work fine
in: "select * from t where foo = b'11'",
outstmt: "select * from t where foo = :foo /* HEXNUM */",
outstmt: "select * from t where foo = :foo /* BITNUM */",
outbv: map[string]*querypb.BindVariable{
"foo": sqltypes.HexNumBindVariable([]byte("0x03")),
"foo": sqltypes.BitNumBindVariable([]byte("0b11")),
},
}, {
// Large bin values work fine
in: "select * from t where foo = b'11101010100101010010101010101010101010101000100100100100100101001101010101010101000001'",
outstmt: "select * from t where foo = :foo /* HEXNUM */",
outstmt: "select * from t where foo = :foo /* BITNUM */",
outbv: map[string]*querypb.BindVariable{
"foo": sqltypes.HexNumBindVariable([]byte("0x3AA54AAAAAA24925355541")),
"foo": sqltypes.BitNumBindVariable([]byte("0b11101010100101010010101010101010101010101000100100100100100101001101010101010101000001")),
},
}, {
// Bin value does not convert for DMLs
in: "update a set v1 = b'11'",
outstmt: "update a set v1 = :v1 /* HEXNUM */",
outstmt: "update a set v1 = :v1 /* BITNUM */",
outbv: map[string]*querypb.BindVariable{
"v1": sqltypes.HexNumBindVariable([]byte("0x03")),
"v1": sqltypes.BitNumBindVariable([]byte("0b11")),
},
}, {
// ORDER BY column_position
Expand Down Expand Up @@ -308,14 +308,14 @@ func TestNormalize(t *testing.T) {
"bv3": sqltypes.Int64BindVariable(3),
},
}, {
// BitVal should also be normalized
// BitNum should also be normalized
in: `select b'1', 0b01, b'1010', 0b1111111`,
outstmt: `select :bv1 /* HEXNUM */, :bv2 /* HEXNUM */, :bv3 /* HEXNUM */, :bv4 /* HEXNUM */ from dual`,
outstmt: `select :bv1 /* BITNUM */, :bv2 /* BITNUM */, :bv3 /* BITNUM */, :bv4 /* BITNUM */ from dual`,
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.HexNumBindVariable([]byte("0x01")),
"bv2": sqltypes.HexNumBindVariable([]byte("0x01")),
"bv3": sqltypes.HexNumBindVariable([]byte("0x0A")),
"bv4": sqltypes.HexNumBindVariable([]byte("0x7F")),
"bv1": sqltypes.BitNumBindVariable([]byte("0b1")),
"bv2": sqltypes.BitNumBindVariable([]byte("0b01")),
"bv3": sqltypes.BitNumBindVariable([]byte("0b1010")),
"bv4": sqltypes.BitNumBindVariable([]byte("0b1111111")),
},
}, {
// DateVal should also be normalized
Expand Down
9 changes: 5 additions & 4 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ var (
output: "select `name`, numbers from (select * from users) as x(`name`, numbers)",
}, {
input: "select 0b010, 0b0111, b'0111', b'011'",
output: "select B'010', B'0111', B'0111', B'011' from dual",
output: "select 0b010, 0b0111, 0b0111, 0b011 from dual",
}, {
input: "select 0x010, 0x0111, x'0111'",
output: "select 0x010, 0x0111, X'0111' from dual",
Expand Down Expand Up @@ -1120,9 +1120,10 @@ var (
input: "select /* hex caps */ X'F0a1' from t",
}, {
input: "select /* bit literal */ b'0101' from t",
output: "select /* bit literal */ B'0101' from t",
output: "select /* bit literal */ 0b0101 from t",
}, {
input: "select /* bit literal caps */ B'010011011010' from t",
input: "select /* bit literal caps */ B'010011011010' from t",
output: "select /* bit literal caps */ 0b010011011010 from t",
}, {
input: "select /* 0x */ 0xf0 from t",
}, {
Expand Down Expand Up @@ -5004,7 +5005,7 @@ func TestCreateTable(t *testing.T) {
` + "`" + `s3` + "`" + ` varchar default null,
s4 timestamp default current_timestamp(),
s41 timestamp default now(),
s5 bit(1) default B'0'
s5 bit(1) default 0b0
)`,
}, {
// test non_reserved word in column name
Expand Down
8 changes: 4 additions & 4 deletions go/vt/sqlparser/sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions go/vt/sqlparser/sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -1689,27 +1689,27 @@ text_literal
}
| BITNUM
{
$$ = NewBitLiteral($1[2:])
$$ = NewBitLiteral($1)
}
| BIT_LITERAL
{
$$ = NewBitLiteral($1)
$$ = NewBitLiteral("0b" + $1)
}
| VALUE_ARG
{
$$ = parseBindVariable(yylex, $1[1:])
}
| underscore_charsets BIT_LITERAL %prec UNARY
{
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewBitLiteral($2)}
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewBitLiteral("0b" + $2)}
}
| underscore_charsets HEXNUM %prec UNARY
{
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewHexNumLiteral($2)}
}
| underscore_charsets BITNUM %prec UNARY
{
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewBitLiteral($2[2:])}
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewBitLiteral($2)}
}
| underscore_charsets HEX %prec UNARY
{
Expand Down
4 changes: 2 additions & 2 deletions go/vt/sqlparser/testdata/select_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14654,7 +14654,7 @@ INPUT
select hex(_utf8mb4 B'001111111111');
END
OUTPUT
select hex(_utf8mb4 B'001111111111') from dual
select hex(_utf8mb4 0b001111111111) from dual
END
INPUT
select NULLIF(1,NULL), NULLIF(1.0, NULL), NULLIF("test", NULL);
Expand Down Expand Up @@ -18968,7 +18968,7 @@ INPUT
select hex(_utf8 B'001111111111');
END
OUTPUT
select hex(_utf8mb3 B'001111111111') from dual
select hex(_utf8mb3 0b001111111111) from dual
END
INPUT
select right('hello', -18446744073709551615);
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtexplain/vtexplain_vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ func inferColTypeFromExpr(node sqlparser.Expr, tableColumnMap map[sqlparser.Iden
fallthrough
case sqlparser.HexVal:
fallthrough
case sqlparser.BitVal:
case sqlparser.BitNum:
colTypes = append(colTypes, querypb.Type_INT32)
case sqlparser.StrVal:
colTypes = append(colTypes, querypb.Type_VARCHAR)
Expand Down
13 changes: 9 additions & 4 deletions go/vt/vtgate/evalengine/api_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"vitess.io/vitess/go/mysql/fastparse"
"vitess.io/vitess/go/mysql/hex"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

// NullExpr is just what you are lead to believe
Expand Down Expand Up @@ -156,11 +158,14 @@ func parseHexNumber(val []byte) ([]byte, error) {
return parseHexLiteral(val[1:])
}

func parseBitLiteral(val []byte) ([]byte, error) {
func parseBitNum(val []byte) ([]byte, error) {
if val[0] != '0' || val[1] != 'b' {
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "malformed Bit literal: %q (missing 0b prefix)", val)
}
var i big.Int
_, ok := i.SetString(string(val), 2)
_, ok := i.SetString(hack.String(val)[2:], 2)
if !ok {
panic("malformed bit literal from parser")
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "malformed Bit literal: %q (not base 2)", val)
}
return i.Bytes(), nil
}
Expand All @@ -186,7 +191,7 @@ func NewLiteralBinaryFromHexNum(val []byte) (*Literal, error) {
}

func NewLiteralBinaryFromBit(val []byte) (*Literal, error) {
raw, err := parseBitLiteral(val)
raw, err := parseBitNum(val)
if err != nil {
return nil, err
}
Expand Down
15 changes: 15 additions & 0 deletions go/vt/vtgate/evalengine/api_type_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ func AggregateTypes(types []sqltypes.Type) sqltypes.Type {
return typeAgg.result()
}

func (ta *typeAggregation) addEval(e eval) {
var t sqltypes.Type
var f typeFlag
switch e := e.(type) {
case nil:
t = sqltypes.Null
case *evalBytes:
t = sqltypes.Type(e.tt)
f = e.flag
default:
t = e.SQLType()
vmg marked this conversation as resolved.
Show resolved Hide resolved
}
ta.add(t, f)
}

func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) {
switch tt {
case sqltypes.Float32, sqltypes.Float64:
Expand Down
Loading
Loading