diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index fed519ec58b..be547618b98 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1318,8 +1318,8 @@ func (node *Literal) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%#s", node.Val) case HexVal: buf.astPrintf(node, "X'%#s'", node.Val) - case BitVal: - buf.astPrintf(node, "B'%#s'", node.Val) + case BitNum: + buf.astPrintf(node, "0b%#s", node.Val) case DateVal: buf.astPrintf(node, "date'%#s'", node.Val) case TimeVal: diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 16d52cac9d2..10a2ef0a559 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1733,10 +1733,9 @@ func (node *Literal) FormatFast(buf *TrackedBuffer) { buf.WriteString("X'") buf.WriteString(node.Val) buf.WriteByte('\'') - case BitVal: - buf.WriteString("B'") + case BitNum: + buf.WriteString("0b") buf.WriteString(node.Val) - buf.WriteByte('\'') case DateVal: buf.WriteString("date'") buf.WriteString(node.Val) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index c9f778df0b6..d2a8a8008be 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -163,7 +163,7 @@ const ( FloatVal HexNum HexVal - BitVal + BitNum DateVal TimeVal TimestampVal @@ -515,9 +515,9 @@ func NewHexLiteral(in string) *Literal { return &Literal{Type: HexVal, Val: in} } -// NewBitLiteral builds a new BitVal containing a bit literal. +// NewBitLiteral builds a new BitNum containing a bit literal. func NewBitLiteral(in string) *Literal { - return &Literal{Type: BitVal, Val: in} + return &Literal{Type: BitNum, Val: in} } // NewDateLiteral builds a new Date. @@ -583,8 +583,8 @@ func (node *Literal) SQLType() sqltypes.Type { return sqltypes.HexNum case HexVal: return sqltypes.HexVal - case BitVal: - return sqltypes.HexNum + case BitNum: + return sqltypes.BitNum case DateVal: return sqltypes.Date case TimeVal: diff --git a/go/vt/sqlparser/literal.go b/go/vt/sqlparser/literal.go index 24613ff6e05..71fed3d7d16 100644 --- a/go/vt/sqlparser/literal.go +++ b/go/vt/sqlparser/literal.go @@ -77,7 +77,7 @@ func LiteralToValue(lit *Literal) (sqltypes.Value, error) { return parseHexLiteral(b[1:]) case HexVal: return parseHexLiteral(lit.Bytes()) - case BitVal: + case BitNum: return parseBitLiteral(lit.Bytes()) case DateVal: d, ok := datetime.ParseDate(lit.Val) diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index a9d70d5e190..464283c1e2b 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -18,10 +18,8 @@ package sqlparser import ( "bytes" - "math/big" "vitess.io/vitess/go/mysql/datetime" - "vitess.io/vitess/go/mysql/hex" "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -365,19 +363,11 @@ func SQLToBindvar(node SQLNode) *querypb.BindVariable { buf = append(buf, bytes.ToUpper(node.Bytes())...) buf = append(buf, '\'') v, err = sqltypes.NewValue(sqltypes.HexVal, buf) - case BitVal: - // Convert bit value to hex number in parameterized query format - var i big.Int - _, ok := i.SetString(string(node.Bytes()), 2) - if !ok { - return nil - } - - buf := i.Bytes() - out := make([]byte, 0, (len(buf)*2)+2) - out = append(out, '0', 'x') - out = append(out, hex.EncodeBytes(buf)...) - v, err = sqltypes.NewValue(sqltypes.HexNum, out) + case BitNum: + out := make([]byte, 0, len(node.Bytes())+2) + out = append(out, '0', 'b') + out = append(out, node.Bytes()...) + v, err = sqltypes.NewValue(sqltypes.BitNum, out) case DateVal: v, err = sqltypes.NewValue(sqltypes.Date, node.Bytes()) case TimeVal: diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 2b0a4b52122..f4f5f89d99d 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -191,23 +191,23 @@ func TestNormalize(t *testing.T) { }, { // Bin values work fine in: "select * from t where foo = b'11'", - outstmt: "select * from t where foo = :foo /* HEXNUM */", + outstmt: "select * from t where foo = :foo /* BITNUM */", outbv: map[string]*querypb.BindVariable{ - "foo": sqltypes.HexNumBindVariable([]byte("0x03")), + "foo": sqltypes.BitNumBindVariable([]byte("0b11")), }, }, { // Large bin values work fine in: "select * from t where foo = b'11101010100101010010101010101010101010101000100100100100100101001101010101010101000001'", - outstmt: "select * from t where foo = :foo /* HEXNUM */", + outstmt: "select * from t where foo = :foo /* BITNUM */", outbv: map[string]*querypb.BindVariable{ - "foo": sqltypes.HexNumBindVariable([]byte("0x3AA54AAAAAA24925355541")), + "foo": sqltypes.BitNumBindVariable([]byte("0b11101010100101010010101010101010101010101000100100100100100101001101010101010101000001")), }, }, { // Bin value does not convert for DMLs in: "update a set v1 = b'11'", - outstmt: "update a set v1 = :v1 /* HEXNUM */", + outstmt: "update a set v1 = :v1 /* BITNUM */", outbv: map[string]*querypb.BindVariable{ - "v1": sqltypes.HexNumBindVariable([]byte("0x03")), + "v1": sqltypes.BitNumBindVariable([]byte("0b11")), }, }, { // ORDER BY column_position @@ -308,14 +308,14 @@ func TestNormalize(t *testing.T) { "bv3": sqltypes.Int64BindVariable(3), }, }, { - // BitVal should also be normalized + // BitNum should also be normalized in: `select b'1', 0b01, b'1010', 0b1111111`, - outstmt: `select :bv1 /* HEXNUM */, :bv2 /* HEXNUM */, :bv3 /* HEXNUM */, :bv4 /* HEXNUM */ from dual`, + outstmt: `select :bv1 /* BITNUM */, :bv2 /* BITNUM */, :bv3 /* BITNUM */, :bv4 /* BITNUM */ from dual`, outbv: map[string]*querypb.BindVariable{ - "bv1": sqltypes.HexNumBindVariable([]byte("0x01")), - "bv2": sqltypes.HexNumBindVariable([]byte("0x01")), - "bv3": sqltypes.HexNumBindVariable([]byte("0x0A")), - "bv4": sqltypes.HexNumBindVariable([]byte("0x7F")), + "bv1": sqltypes.BitNumBindVariable([]byte("0b1")), + "bv2": sqltypes.BitNumBindVariable([]byte("0b01")), + "bv3": sqltypes.BitNumBindVariable([]byte("0b1010")), + "bv4": sqltypes.BitNumBindVariable([]byte("0b1111111")), }, }, { // DateVal should also be normalized diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go index f4167249f66..28fa6fc4a0f 100644 --- a/go/vt/sqlparser/parse_test.go +++ b/go/vt/sqlparser/parse_test.go @@ -527,7 +527,7 @@ var ( output: "select `name`, numbers from (select * from users) as x(`name`, numbers)", }, { input: "select 0b010, 0b0111, b'0111', b'011'", - output: "select B'010', B'0111', B'0111', B'011' from dual", + output: "select 0b010, 0b0111, 0b0111, 0b011 from dual", }, { input: "select 0x010, 0x0111, x'0111'", output: "select 0x010, 0x0111, X'0111' from dual", @@ -1120,9 +1120,10 @@ var ( input: "select /* hex caps */ X'F0a1' from t", }, { input: "select /* bit literal */ b'0101' from t", - output: "select /* bit literal */ B'0101' from t", + output: "select /* bit literal */ 0b0101 from t", }, { - input: "select /* bit literal caps */ B'010011011010' from t", + input: "select /* bit literal caps */ B'010011011010' from t", + output: "select /* bit literal caps */ 0b010011011010 from t", }, { input: "select /* 0x */ 0xf0 from t", }, { @@ -5004,7 +5005,7 @@ func TestCreateTable(t *testing.T) { ` + "`" + `s3` + "`" + ` varchar default null, s4 timestamp default current_timestamp(), s41 timestamp default now(), - s5 bit(1) default B'0' + s5 bit(1) default 0b0 )`, }, { // test non_reserved word in column name diff --git a/go/vt/sqlparser/testdata/select_cases.txt b/go/vt/sqlparser/testdata/select_cases.txt index 1112593cd13..661045add7d 100644 --- a/go/vt/sqlparser/testdata/select_cases.txt +++ b/go/vt/sqlparser/testdata/select_cases.txt @@ -14654,7 +14654,7 @@ INPUT select hex(_utf8mb4 B'001111111111'); END OUTPUT -select hex(_utf8mb4 B'001111111111') from dual +select hex(_utf8mb4 0b001111111111) from dual END INPUT select NULLIF(1,NULL), NULLIF(1.0, NULL), NULLIF("test", NULL); @@ -18968,7 +18968,7 @@ INPUT select hex(_utf8 B'001111111111'); END OUTPUT -select hex(_utf8mb3 B'001111111111') from dual +select hex(_utf8mb3 0b001111111111) from dual END INPUT select right('hello', -18446744073709551615); diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index 1f8ebd969d7..85aa64037a7 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -846,7 +846,7 @@ func inferColTypeFromExpr(node sqlparser.Expr, tableColumnMap map[sqlparser.Iden fallthrough case sqlparser.HexVal: fallthrough - case sqlparser.BitVal: + case sqlparser.BitNum: colTypes = append(colTypes, querypb.Type_INT32) case sqlparser.StrVal: colTypes = append(colTypes, querypb.Type_VARCHAR) diff --git a/go/vt/vtgate/evalengine/api_literal.go b/go/vt/vtgate/evalengine/api_literal.go index 6b1390e3a41..e87a197adc2 100644 --- a/go/vt/vtgate/evalengine/api_literal.go +++ b/go/vt/vtgate/evalengine/api_literal.go @@ -28,6 +28,8 @@ import ( "vitess.io/vitess/go/mysql/fastparse" "vitess.io/vitess/go/mysql/hex" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" ) // NullExpr is just what you are lead to believe @@ -156,11 +158,18 @@ func parseHexNumber(val []byte) ([]byte, error) { return parseHexLiteral(val[1:]) } +func parseBitNum(val []byte) ([]byte, error) { + if val[0] != '0' || val[1] != 'b' { + return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "malformed Bit literal: %q (missing 0b prefix)", val) + } + return parseBitLiteral(val[2:]) +} + func parseBitLiteral(val []byte) ([]byte, error) { var i big.Int - _, ok := i.SetString(string(val), 2) + _, ok := i.SetString(hack.String(val), 2) if !ok { - panic("malformed bit literal from parser") + return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "malformed Bit literal: %q (not base 2)", val) } return i.Bytes(), nil } diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index 37effb2ab12..c258dab1672 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -133,12 +133,21 @@ func integerDivideConvert(arg eval) evalNumeric { return dec } - if b1, ok := arg.(*evalBytes); ok && b1.isHexLiteral() { - hex, ok := b1.toNumericHex() - if !ok { - return newEvalDecimal(decimal.Zero, 0, 0) + if b1, ok := arg.(*evalBytes); ok { + if b1.isHexLiteral() { + hex, ok := b1.toNumericHex() + if !ok { + return newEvalDecimal(decimal.Zero, 0, 0) + } + return hex + } + if b1.isBitLiteral() { + bit, ok := b1.toNumericBit() + if !ok { + return newEvalDecimal(decimal.Zero, 0, 0) + } + return bit } - return hex } return evalToDecimal(arg, 0, 0) } diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 69c39249fb9..0f4aafadf0c 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1699,7 +1699,7 @@ func (cached *evalInt64) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(8) + size += int64(16) } return size } diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 84caa2d7690..32b83b2cd1b 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -91,9 +91,15 @@ func (c *compiler) compileToNumeric(ct ctype, offset int, fallback sqltypes.Type if sqltypes.IsNumber(ct.Type) { return ct } - if ct.Type == sqltypes.VarBinary && (ct.Flag&flagHex) != 0 { - c.asm.Convert_hex(offset) - return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} + if ct.Type == sqltypes.VarBinary { + if (ct.Flag & flagHex) != 0 { + c.asm.Convert_hex(offset) + return ctype{sqltypes.Uint64, ct.Flag, collationNumeric} + } + if (ct.Flag & flagBit) != 0 { + c.asm.Convert_bit(offset) + return ctype{sqltypes.Int64, ct.Flag, collationNumeric} + } } if sqltypes.IsDateOrTime(ct.Type) { diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 47a7e30d17b..5bf71e1a7c1 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -883,6 +883,17 @@ func (asm *assembler) Convert_hex(offset int) { }, "CONV VARBINARY(SP-%d), HEX", offset) } +func (asm *assembler) Convert_bit(offset int) { + asm.emit(func(env *ExpressionEnv) int { + var ok bool + env.vm.stack[env.vm.sp-offset], ok = env.vm.stack[env.vm.sp-offset].(*evalBytes).toNumericBit() + if !ok { + env.vm.err = errDeoptimize + } + return 1 + }, "CONV VARBINARY(SP-%d), BIT", offset) +} + func (asm *assembler) Convert_Ti(offset int) { asm.emit(func(env *ExpressionEnv) int { v := env.vm.stack[env.vm.sp-offset].(*evalTemporal) @@ -3094,6 +3105,14 @@ func (asm *assembler) Neg_hex() { }, "NEG HEX(SP-1)") } +func (asm *assembler) Neg_bit() { + asm.emit(func(env *ExpressionEnv) int { + arg := env.vm.stack[env.vm.sp-1].(*evalInt64) + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalFloat(-float64(arg.i)) + return 1 + }, "NEG BIT(SP-1)") +} + func (asm *assembler) Neg_i() { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-1].(*evalInt64) diff --git a/go/vt/vtgate/evalengine/compiler_asm_push.go b/go/vt/vtgate/evalengine/compiler_asm_push.go index 17537973215..395e6261b37 100644 --- a/go/vt/vtgate/evalengine/compiler_asm_push.go +++ b/go/vt/vtgate/evalengine/compiler_asm_push.go @@ -141,6 +141,26 @@ func (asm *assembler) PushBVar_f(key string) { }, "PUSH FLOAT64(:%q)", key) } +func push_bitnum(env *ExpressionEnv, raw []byte) int { + raw, env.vm.err = parseBitNum(raw) + env.vm.stack[env.vm.sp] = newEvalBytesBit(raw) + env.vm.sp++ + return 1 +} + +func (asm *assembler) PushBVar_bitnum(key string) { + asm.adjustStack(1) + + asm.emit(func(env *ExpressionEnv) int { + var bvar *querypb.BindVariable + bvar, env.vm.err = env.lookupBindVar(key) + if env.vm.err != nil { + return 0 + } + return push_bitnum(env, bvar.Value) + }, "PUSH BITNUM(:%q)", key) +} + func push_hexnum(env *ExpressionEnv, raw []byte) int { raw, env.vm.err = parseHexNumber(raw) env.vm.stack[env.vm.sp] = newEvalBytesHex(raw) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index efbe3d0ed0c..25bf35129b6 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -455,6 +455,56 @@ func TestCompilerSingle(t *testing.T) { expression: `WEIGHT_STRING('foobar' as char(3))`, result: `VARBINARY("\x1c\xe5\x1d\xdd\x1d\xdd")`, }, + { + expression: `CAST(time '5 10:34:58' AS DATETIME)`, + result: `DATETIME("2023-10-29 10:34:58")`, + }, + { + expression: `CAST(time '130:34:58' AS DATETIME)`, + result: `DATETIME("2023-10-29 10:34:58")`, + }, + { + expression: `UNIX_TIMESTAMP(time '5 10:34:58')`, + result: `INT64(1698572098)`, + }, + { + expression: `CONV(-1, -1.5e0, 3.141592653589793)`, + result: `VARCHAR("11112220022122120101211020120210210211220")`, + }, + { + expression: `column0 between 10 and 20`, + values: []sqltypes.Value{sqltypes.NewInt16(15)}, + result: `INT64(1)`, + }, + { + expression: `column0 between 10 and 20`, + values: []sqltypes.Value{sqltypes.NULL}, + result: `NULL`, + }, + { + expression: `1 + 0b1001`, + result: `INT64(10)`, + }, + { + expression: `1 + 0x6`, + result: `UINT64(7)`, + }, + { + expression: `0 DIV 0b1001`, + result: `INT64(0)`, + }, + { + expression: `0 & 0b1001`, + result: `UINT64(0)`, + }, + { + expression: `CAST(0b1001 AS DECIMAL)`, + result: `DECIMAL(9)`, + }, + { + expression: `-0b1001`, + result: `FLOAT64(-9)`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index d74014f772a..ee09f96cded 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -157,6 +157,14 @@ func evalIsTruthy(e eval) boolean { } return makeboolean(hex.u != 0) } + if e.isBitLiteral() { + bit, ok := e.toNumericBit() + if !ok { + // overflow + return makeboolean(true) + } + return makeboolean(bit.i != 0) + } f, _ := fastparse.ParseFloat64(e.string()) return makeboolean(f != 0.0) case *evalJSON: @@ -368,7 +376,7 @@ func valueToEvalNumeric(v sqltypes.Value) (eval, error) { if err != nil { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return &evalInt64{ival}, nil + return &evalInt64{i: ival}, nil case v.IsUnsigned(): var uval uint64 uval, err := v.ToUint64() @@ -383,7 +391,7 @@ func valueToEvalNumeric(v sqltypes.Value) (eval, error) { } ival, err := strconv.ParseInt(v.RawStr(), 10, 64) if err == nil { - return &evalInt64{ival}, nil + return &evalInt64{i: ival}, nil } return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", v.RawStr()) } @@ -418,6 +426,9 @@ func valueToEval(value sqltypes.Value, collation collations.TypedCollation) (eva hex := value.Raw() raw, err := parseHexLiteral(hex[2 : len(hex)-1]) return newEvalBytesHex(raw), wrap(err) + } else if tt == sqltypes.BitNum { + raw, err := parseBitNum(value.Raw()) + return newEvalBytesBit(raw), wrap(err) } else { return newEvalText(value.Raw(), collation), nil } diff --git a/go/vt/vtgate/evalengine/eval_bytes.go b/go/vt/vtgate/evalengine/eval_bytes.go index a3bcfeafa6d..caa516acbe4 100644 --- a/go/vt/vtgate/evalengine/eval_bytes.go +++ b/go/vt/vtgate/evalengine/eval_bytes.go @@ -178,22 +178,39 @@ func (e *evalBytes) toDateBestEffort() datetime.DateTime { return datetime.DateTime{} } -func (e *evalBytes) toNumericHex() (*evalUint64, bool) { +func (e *evalBytes) parseNumericBytes(number *[8]byte) bool { raw := e.bytes if l := len(raw); l > 8 { for _, b := range raw[:l-8] { if b != 0 { - return nil, false // overflow + return false // overflow } } raw = raw[l-8:] } - - var number [8]byte for i, b := range raw { number[8-len(raw)+i] = b } + return true +} + +func (e *evalBytes) toNumericHex() (*evalUint64, bool) { + var number [8]byte + if !e.parseNumericBytes(&number) { + return nil, false + } + hex := newEvalUint64(binary.BigEndian.Uint64(number[:])) hex.hexLiteral = true return hex, true } + +func (e *evalBytes) toNumericBit() (*evalInt64, bool) { + var number [8]byte + if !e.parseNumericBytes(&number) { + return nil, false + } + bit := newEvalInt64(int64(binary.BigEndian.Uint64(number[:]))) + bit.bitLiteral = true + return bit, true +} diff --git a/go/vt/vtgate/evalengine/eval_numeric.go b/go/vt/vtgate/evalengine/eval_numeric.go index 5c2db2da71a..8584fa4a714 100644 --- a/go/vt/vtgate/evalengine/eval_numeric.go +++ b/go/vt/vtgate/evalengine/eval_numeric.go @@ -41,7 +41,8 @@ type ( } evalInt64 struct { - i int64 + i int64 + bitLiteral bool } evalUint64 struct { @@ -64,8 +65,8 @@ var _ evalNumeric = (*evalUint64)(nil) var _ evalNumeric = (*evalFloat)(nil) var _ evalNumeric = (*evalDecimal)(nil) -var evalBoolTrue = &evalInt64{1} -var evalBoolFalse = &evalInt64{0} +var evalBoolTrue = &evalInt64{i: 1} +var evalBoolFalse = &evalInt64{i: 0} func newEvalUint64(u uint64) *evalUint64 { return &evalUint64{u: u} @@ -110,6 +111,14 @@ func evalToNumeric(e eval, preciseDatetime bool) evalNumeric { } return hex } + if e.isBitLiteral() { + bit, ok := e.toNumericBit() + if !ok { + // overflow + return newEvalFloat(0) + } + return bit + } f, _ := fastparse.ParseFloat64(e.string()) return &evalFloat{f: f} case *evalJSON: @@ -160,6 +169,18 @@ func evalToFloat(e eval) (*evalFloat, bool) { } return f, true } + if e.isBitLiteral() { + bit, ok := e.toNumericBit() + if !ok { + // overflow + return newEvalFloat(0), false + } + f, ok := bit.toFloat() + if !ok { + return newEvalFloat(0), false + } + return f, true + } val, err := fastparse.ParseFloat64(e.string()) return &evalFloat{f: val}, err == nil case *evalJSON: @@ -198,6 +219,14 @@ func evalToDecimal(e eval, m, d int32) *evalDecimal { } return hex.toDecimal(m, d) } + if e.isBitLiteral() { + bit, ok := e.toNumericBit() + if !ok { + // overflow + return newEvalDecimal(decimal.Zero, m, d) + } + return bit.toDecimal(m, d) + } dec, _ := decimal.NewFromString(e.string()) return newEvalDecimal(dec, m, d) case *evalJSON: @@ -256,6 +285,14 @@ func evalToInt64(e eval) *evalInt64 { } return hex.toInt64() } + if e.isBitLiteral() { + bit, ok := e.toNumericBit() + if !ok { + // overflow + return newEvalInt64(0) + } + return bit + } i, _ := fastparse.ParseInt64(e.string(), 10) return newEvalInt64(i) case *evalJSON: @@ -314,6 +351,9 @@ func (e *evalInt64) ToRawBytes() []byte { } func (e *evalInt64) negate() evalNumeric { + if e.bitLiteral { + return newEvalFloat(-float64(e.i)) + } if e.i == math.MinInt64 { return newEvalDecimalWithPrec(decimal.NewFromInt(e.i).NegInPlace(), 0) } diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index 50326c9eb3c..01fdb673f8d 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -583,8 +583,13 @@ func (expr *NegateExpr) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Int64: - neg = sqltypes.Int64 - c.asm.Neg_i() + if arg.Flag&flagBit != 0 { + neg = sqltypes.Float64 + c.asm.Neg_bit() + } else { + neg = sqltypes.Int64 + c.asm.Neg_i() + } case sqltypes.Uint64: if arg.Flag&flagHex != 0 { neg = sqltypes.Float64 diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 9172f8abc3c..d75a290036d 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -57,7 +57,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { tuple := make([]eval, 0, len(bvar.Values)) for _, value := range bvar.Values { - e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), defaultCoercionCollation(collations.DefaultCollationForType(value.Type))) + e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), defaultCoercionCollation(collations.CollationForType(value.Type, bv.Collation.Collation))) if err != nil { return nil, err } @@ -73,7 +73,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { if bv.typed() { typ = bv.Type } - return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), defaultCoercionCollation(collations.DefaultCollationForType(typ))) + return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), defaultCoercionCollation(collations.CollationForType(typ, bv.Collation.Collation))) } } @@ -92,6 +92,8 @@ func (bv *BindVariable) typeof(env *ExpressionEnv, _ []*querypb.Field) (sqltypes return sqltypes.Null, flagNull | flagNullable case sqltypes.HexNum, sqltypes.HexVal: return sqltypes.VarBinary, flagHex + case sqltypes.BitNum: + return sqltypes.VarBinary, flagBit default: return tt, 0 } @@ -102,7 +104,13 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) { return ctype{}, c.unsupported(bvar) } - switch tt := bvar.Type; { + typ := ctype{ + Type: bvar.Type, + Flag: 0, + Col: bvar.Collation, + } + + switch tt := typ.Type; { case sqltypes.IsSigned(tt): c.asm.PushBVar_i(bvar.Key) case sqltypes.IsUnsigned(tt): @@ -114,10 +122,18 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) { case sqltypes.IsText(tt): if tt == sqltypes.HexNum { c.asm.PushBVar_hexnum(bvar.Key) + typ.Type = sqltypes.VarBinary + typ.Flag |= flagHex } else if tt == sqltypes.HexVal { c.asm.PushBVar_hexval(bvar.Key) + typ.Type = sqltypes.VarBinary + typ.Flag |= flagHex + } else if tt == sqltypes.BitNum { + c.asm.PushBVar_bitnum(bvar.Key) + typ.Type = sqltypes.VarBinary + typ.Flag |= flagBit } else { - c.asm.PushBVar_text(bvar.Key, bvar.Collation) + c.asm.PushBVar_text(bvar.Key, typ.Col) } case sqltypes.IsBinary(tt): c.asm.PushBVar_bin(bvar.Key) @@ -128,11 +144,7 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) { default: return ctype{}, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Type is not supported: %s", tt) } - - return ctype{ - Type: bvar.Type, - Col: bvar.Collation, - }, nil + return typ, nil } func (bvar *BindVariable) typed() bool { diff --git a/go/vt/vtgate/evalengine/expr_literal.go b/go/vt/vtgate/evalengine/expr_literal.go index 392dcd25288..e17fe63ee32 100644 --- a/go/vt/vtgate/evalengine/expr_literal.go +++ b/go/vt/vtgate/evalengine/expr_literal.go @@ -51,6 +51,9 @@ func (l *Literal) typeof(*ExpressionEnv, []*querypb.Field) (sqltypes.Type, typeF if e == evalBoolTrue || e == evalBoolFalse { f |= flagIsBoolean } + if e.bitLiteral { + f |= flagBit + } case *evalUint64: if e.hexLiteral { f |= flagHex diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index f1a85f64778..6f34f5bcd12 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -1375,7 +1375,7 @@ func dateTimeUnixTimestamp(env *ExpressionEnv, date eval) evalNumeric { case *evalDecimal: prec = d.length case *evalBytes: - if d.isHexLiteral() { + if d.isHexOrBitLiteral() { return newEvalInt64(0) } prec = 6 diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index 245318529c3..dd97ce0fbe6 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -91,6 +91,7 @@ var inputConversions = []string{ "0.0", "0.000", "1.5", "-1.5", "1.1", "1.7", "-1.1", "-1.7", "'1.5'", "'-1.5'", `'foobar'`, `_utf8 'foobar'`, `''`, `_binary 'foobar'`, `0x0`, `0x1`, `0xff`, `X'00'`, `X'01'`, `X'ff'`, + `0b1001`, `b'1001'`, `0x9`, `x'09'`, "NULL", "true", "false", "0xFF666F6F626172FF", "0x666F6F626172FF", "0xFF666F6F626172", "9223372036854775807", "-9223372036854775808", "18446744073709551615", diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 84d8a2c4f10..de58e061fc8 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -246,7 +246,7 @@ func translateLiteral(lit *sqlparser.Literal, collation collations.ID) (*Literal return NewLiteralBinaryFromHexNum(lit.Bytes()) case sqlparser.HexVal: return NewLiteralBinaryFromHex(lit.Bytes()) - case sqlparser.BitVal: + case sqlparser.BitNum: return NewLiteralBinaryFromBit(lit.Bytes()) case sqlparser.DateVal: return NewLiteralDateFromBytes(lit.Bytes()) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 9c8ae647d39..087bf2b14e4 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -4079,7 +4079,7 @@ func TestSelectHexAndBit(t *testing.T) { qr, err = executor.Execute(context.Background(), nil, "TestSelectHexAndBit", session, "select 1 + 0b1001, 1 + b'1001', 1 + 0x9, 1 + x'09'", nil) require.NoError(t, err) - require.Equal(t, `[[UINT64(10) UINT64(10) UINT64(10) UINT64(10)]]`, fmt.Sprintf("%v", qr.Rows)) + require.Equal(t, `[[INT64(10) INT64(10) UINT64(10) UINT64(10)]]`, fmt.Sprintf("%v", qr.Rows)) } // TestSelectCFC tests validates that cfc vindex plan gets cached and same plan is getting reused.