From d62a5c5652de99c9b4f156157878c52c86c71e66 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 27 Dec 2023 08:36:59 +0100 Subject: [PATCH] evalengine: Internal cleanup and consistency fixes (#14854) Signed-off-by: Dirkjan Bussink --- go/mysql/datetime/datetime.go | 50 +++++++++---------- go/mysql/datetime/helpers.go | 4 +- go/mysql/datetime/timeparts.go | 4 +- go/sqltypes/value.go | 8 ++- go/test/endtoend/onlineddl/vtgate_util.go | 2 +- go/vt/mysqlctl/azblobbackupstorage/azblob.go | 5 +- go/vt/throttler/throttler.go | 4 +- go/vt/vtgate/engine/memory_sort.go | 2 +- .../vtgate/evalengine/api_arithmetic_test.go | 34 ++++++------- go/vt/vtgate/evalengine/arithmetic.go | 2 +- go/vt/vtgate/evalengine/compiler.go | 8 ++- go/vt/vtgate/evalengine/compiler_asm.go | 28 ++++++++--- go/vt/vtgate/evalengine/compiler_test.go | 12 +++++ go/vt/vtgate/evalengine/expr_bit.go | 14 +++--- go/vt/vtgate/evalengine/expr_compare.go | 18 +++++-- go/vt/vtgate/evalengine/fn_time.go | 46 ++++++++++++++--- go/vt/vtgate/evalengine/testcases/inputs.go | 3 +- go/vt/vttablet/onlineddl/vrepl.go | 12 ++--- 18 files changed, 167 insertions(+), 89 deletions(-) diff --git a/go/mysql/datetime/datetime.go b/go/mysql/datetime/datetime.go index debc21cff6d..67191e5c48e 100644 --- a/go/mysql/datetime/datetime.go +++ b/go/mysql/datetime/datetime.go @@ -440,12 +440,12 @@ func (t Time) AddInterval(itv *Interval, stradd bool) (Time, uint8, bool) { return dt.Time, itv.precision(stradd), ok } -func (t Time) toSeconds() int { - tsecs := t.Hour()*secondsPerHour + t.Minute()*secondsPerMinute + t.Second() +func (t Time) toDuration() time.Duration { + dur := time.Duration(t.hour)*time.Hour + time.Duration(t.minute)*time.Minute + time.Duration(t.second)*time.Second + time.Duration(t.nanosecond)*time.Nanosecond if t.Neg() { - return -tsecs + return -dur } - return tsecs + return dur } func (d Date) ToStdTime(loc *time.Location) (out time.Time) { @@ -577,8 +577,12 @@ func (dt DateTime) Round(p int) (r DateTime) { return r } -func (dt DateTime) toSeconds() int { - return (dt.Date.Day()-1)*secondsPerDay + dt.Time.toSeconds() +func (dt DateTime) toDuration() time.Duration { + dur := dt.Time.toDuration() + if !dt.Date.IsZero() { + dur += time.Duration(dt.Date.Day()-1) * durationPerDay + } + return dur } func (dt *DateTime) addInterval(itv *Interval) bool { @@ -588,29 +592,25 @@ func (dt *DateTime) addInterval(itv *Interval) bool { return false } - nsec := dt.Time.Nanosecond() + itv.nsec - sec := dt.toSeconds() + itv.toSeconds() + (nsec / int(time.Second)) - nsec = nsec % int(time.Second) - - if nsec < 0 { - nsec += int(time.Second) - sec-- - } + dur := dt.toDuration() + dur += itv.toDuration() + days := time.Duration(0) + if !dt.Date.IsZero() { + days = dur / durationPerDay + dur -= days * durationPerDay - days := sec / secondsPerDay - sec -= days * secondsPerDay - - if sec < 0 { - sec += secondsPerDay - days-- + if dur < 0 { + dur += durationPerDay + days-- + } } - dt.Time.nanosecond = uint32(nsec) - dt.Time.second = uint8(sec % secondsPerMinute) - dt.Time.minute = uint8((sec / secondsPerMinute) % secondsPerMinute) - dt.Time.hour = uint16(sec / secondsPerHour) + dt.Time.nanosecond = uint32((dur % time.Second) / time.Nanosecond) + dt.Time.second = uint8((dur % time.Minute) / time.Second) + dt.Time.minute = uint8((dur % time.Hour) / time.Minute) + dt.Time.hour = uint16(dur / time.Hour) - daynum := mysqlDayNumber(dt.Date.Year(), dt.Date.Month(), 1) + days + daynum := mysqlDayNumber(dt.Date.Year(), dt.Date.Month(), 1) + int(days) if daynum < 0 || daynum > maxDay { return false } diff --git a/go/mysql/datetime/helpers.go b/go/mysql/datetime/helpers.go index 33d673782fc..c199844df19 100644 --- a/go/mysql/datetime/helpers.go +++ b/go/mysql/datetime/helpers.go @@ -285,7 +285,5 @@ func parseNanoseconds[bytes []byte | string](value bytes, nbytes int) (ns int, l } const ( - secondsPerMinute = 60 - secondsPerHour = 60 * secondsPerMinute - secondsPerDay = 24 * secondsPerHour + durationPerDay = 24 * time.Hour ) diff --git a/go/mysql/datetime/timeparts.go b/go/mysql/datetime/timeparts.go index a774099a93a..ccc0d0a3640 100644 --- a/go/mysql/datetime/timeparts.go +++ b/go/mysql/datetime/timeparts.go @@ -87,6 +87,6 @@ func (tp *timeparts) isZero() bool { return tp.year == 0 && tp.month == 0 && tp.day == 0 && tp.hour == 0 && tp.min == 0 && tp.sec == 0 && tp.nsec == 0 } -func (tp *timeparts) toSeconds() int { - return tp.day*secondsPerDay + tp.hour*3600 + tp.min*60 + tp.sec +func (tp *timeparts) toDuration() time.Duration { + return time.Duration(tp.day)*durationPerDay + time.Duration(tp.hour)*time.Hour + time.Duration(tp.min)*time.Minute + time.Duration(tp.sec)*time.Second + time.Duration(tp.nsec)*time.Nanosecond } diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 45415814700..20a30cbc1c1 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -733,7 +733,13 @@ func (v Value) TinyWeightCmp(other Value) int { if v.flags&other.flags&flagTinyWeight == 0 { return 0 } - return int(int64(v.tinyweight) - int64(other.tinyweight)) + if v.tinyweight == other.tinyweight { + return 0 + } + if v.tinyweight < other.tinyweight { + return -1 + } + return 1 } func (v Value) TinyWeight() uint32 { diff --git a/go/test/endtoend/onlineddl/vtgate_util.go b/go/test/endtoend/onlineddl/vtgate_util.go index 7d51f3365ba..0e5c8af5bd9 100644 --- a/go/test/endtoend/onlineddl/vtgate_util.go +++ b/go/test/endtoend/onlineddl/vtgate_util.go @@ -57,7 +57,7 @@ func VtgateExecQuery(t *testing.T, vtParams *mysql.ConnParams, query string, exp require.Nil(t, err) defer conn.Close() - qr, err := conn.ExecuteFetch(query, math.MaxInt64, true) + qr, err := conn.ExecuteFetch(query, math.MaxInt, true) if expectError == "" { require.NoError(t, err) } else { diff --git a/go/vt/mysqlctl/azblobbackupstorage/azblob.go b/go/vt/mysqlctl/azblobbackupstorage/azblob.go index 7058745d6c6..3ba6b187a2f 100644 --- a/go/vt/mysqlctl/azblobbackupstorage/azblob.go +++ b/go/vt/mysqlctl/azblobbackupstorage/azblob.go @@ -239,8 +239,9 @@ func (bh *AZBlobBackupHandle) AddFile(ctx context.Context, filename string, file return nil, fmt.Errorf("AddFile cannot be called on read-only backup") } // Error out if the file size it too large ( ~4.75 TB) - if filesize > azblob.BlockBlobMaxStageBlockBytes*azblob.BlockBlobMaxBlocks { - return nil, fmt.Errorf("filesize (%v) is too large to upload to az blob (max size %v)", filesize, azblob.BlockBlobMaxStageBlockBytes*azblob.BlockBlobMaxBlocks) + maxSize := int64(azblob.BlockBlobMaxStageBlockBytes * azblob.BlockBlobMaxBlocks) + if filesize > maxSize { + return nil, fmt.Errorf("filesize (%v) is too large to upload to az blob (max size %v)", filesize, maxSize) } obj := objName(bh.dir, bh.name, filename) diff --git a/go/vt/throttler/throttler.go b/go/vt/throttler/throttler.go index 83a1c52225e..3e81ed5b902 100644 --- a/go/vt/throttler/throttler.go +++ b/go/vt/throttler/throttler.go @@ -50,7 +50,7 @@ const ( // MaxRateModuleDisabled can be set in NewThrottler() to disable throttling // by a fixed rate. - MaxRateModuleDisabled = math.MaxInt64 + MaxRateModuleDisabled = int64(math.MaxInt64) // InvalidMaxRate is a constant which will fail in a NewThrottler() call. // It should be used when returning maxRate in an error case. @@ -58,7 +58,7 @@ const ( // ReplicationLagModuleDisabled can be set in NewThrottler() to disable // throttling based on the MySQL replication lag. - ReplicationLagModuleDisabled = math.MaxInt64 + ReplicationLagModuleDisabled = int64(math.MaxInt64) // InvalidMaxReplicationLag is a constant which will fail in a NewThrottler() // call. It should be used when returning maxReplicationlag in an error case. diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index 8a4cd9188ac..4e222498f26 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -143,7 +143,7 @@ func (ms *MemorySort) NeedsTransaction() bool { func (ms *MemorySort) fetchCount(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (int, error) { if ms.UpperLimit == nil { - return math.MaxInt64, nil + return math.MaxInt, nil } env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) resolved, err := env.Evaluate(ms.UpperLimit) diff --git a/go/vt/vtgate/evalengine/api_arithmetic_test.go b/go/vt/vtgate/evalengine/api_arithmetic_test.go index 40373423aa5..28f6c6b55d0 100644 --- a/go/vt/vtgate/evalengine/api_arithmetic_test.go +++ b/go/vt/vtgate/evalengine/api_arithmetic_test.go @@ -85,11 +85,11 @@ func TestArithmetics(t *testing.T) { // testing for int64 overflow with min negative value v1: NewInt64(math.MinInt64), v2: NewInt64(1), - err: dataOutOfRangeError(math.MinInt64, 1, "BIGINT", "-").Error(), + err: dataOutOfRangeError(int64(math.MinInt64), int64(1), "BIGINT", "-").Error(), }, { v1: NewUint64(4), v2: NewInt64(5), - err: dataOutOfRangeError(4, 5, "BIGINT UNSIGNED", "-").Error(), + err: dataOutOfRangeError(uint64(4), int64(5), "BIGINT UNSIGNED", "-").Error(), }, { // testing uint - int v1: NewUint64(7), @@ -103,7 +103,7 @@ func TestArithmetics(t *testing.T) { // testing for int64 overflow v1: NewInt64(math.MinInt64), v2: NewUint64(0), - err: dataOutOfRangeError(math.MinInt64, 0, "BIGINT UNSIGNED", "-").Error(), + err: dataOutOfRangeError(int64(math.MinInt64), uint64(0), "BIGINT UNSIGNED", "-").Error(), }, { v1: TestValue(sqltypes.VarChar, "c"), v2: NewInt64(1), @@ -140,7 +140,7 @@ func TestArithmetics(t *testing.T) { }, { v1: NewInt64(-1), v2: NewUint64(2), - err: dataOutOfRangeError(-1, 2, "BIGINT UNSIGNED", "-").Error(), + err: dataOutOfRangeError(int64(-1), int64(2), "BIGINT UNSIGNED", "-").Error(), }, { v1: NewInt64(2), v2: NewUint64(1), @@ -169,7 +169,7 @@ func TestArithmetics(t *testing.T) { // testing uint - uint if v2 > v1 v1: NewUint64(2), v2: NewUint64(4), - err: dataOutOfRangeError(2, 4, "BIGINT UNSIGNED", "-").Error(), + err: dataOutOfRangeError(uint64(2), uint64(4), "BIGINT UNSIGNED", "-").Error(), }, { // testing uint - (- int) v1: NewUint64(1), @@ -207,7 +207,7 @@ func TestArithmetics(t *testing.T) { }, { v1: NewInt64(-2), v2: NewUint64(1), - err: dataOutOfRangeError(1, -2, "BIGINT UNSIGNED", "+").Error(), + err: dataOutOfRangeError(uint64(1), int64(-2), "BIGINT UNSIGNED", "+").Error(), }, { v1: NewInt64(math.MaxInt64), v2: NewInt64(-2), @@ -221,12 +221,12 @@ func TestArithmetics(t *testing.T) { // testing for overflow uint64 v1: NewUint64(maxUint64), v2: NewUint64(2), - err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "+").Error(), + err: dataOutOfRangeError(maxUint64, uint64(2), "BIGINT UNSIGNED", "+").Error(), }, { // int64 underflow v1: NewInt64(math.MinInt64), v2: NewInt64(-2), - err: dataOutOfRangeError(math.MinInt64, -2, "BIGINT", "+").Error(), + err: dataOutOfRangeError(int64(math.MinInt64), int64(-2), "BIGINT", "+").Error(), }, { // checking int64 max value can be returned v1: NewInt64(math.MaxInt64), @@ -261,7 +261,7 @@ func TestArithmetics(t *testing.T) { // testing for uint64 overflow with max uint64 + int value v1: NewUint64(maxUint64), v2: NewInt64(2), - err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "+").Error(), + err: dataOutOfRangeError(maxUint64, int64(2), "BIGINT UNSIGNED", "+").Error(), }, { v1: sqltypes.NewHexNum([]byte("0x9")), v2: NewInt64(1), @@ -309,7 +309,7 @@ func TestArithmetics(t *testing.T) { // Lower bound for int64 v1: NewInt64(math.MinInt64), v2: NewInt64(1), - out: NewDecimal(strconv.Itoa(math.MinInt64) + ".0000"), + out: NewDecimal(strconv.FormatInt(math.MinInt64, 10) + ".0000"), }, { // upper bound for uint64 v1: NewUint64(math.MaxUint64), @@ -413,12 +413,12 @@ func TestArithmetics(t *testing.T) { // testing for overflow of int64 v1: NewInt64(math.MaxInt64), v2: NewInt64(2), - err: dataOutOfRangeError(math.MaxInt64, 2, "BIGINT", "*").Error(), + err: dataOutOfRangeError(int64(math.MaxInt64), int64(2), "BIGINT", "*").Error(), }, { // testing for underflow of uint64*max.uint64 v1: NewInt64(2), v2: NewUint64(maxUint64), - err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "*").Error(), + err: dataOutOfRangeError(maxUint64, int64(2), "BIGINT UNSIGNED", "*").Error(), }, { v1: NewUint64(math.MaxUint64), v2: NewUint64(1), @@ -427,7 +427,7 @@ func TestArithmetics(t *testing.T) { // Checking whether maxInt value can be passed as uint value v1: NewUint64(math.MaxInt64), v2: NewInt64(3), - err: dataOutOfRangeError(math.MaxInt64, 3, "BIGINT UNSIGNED", "*").Error(), + err: dataOutOfRangeError(uint64(math.MaxInt64), int64(3), "BIGINT UNSIGNED", "*").Error(), }}, }} @@ -492,7 +492,7 @@ func TestNullSafeAdd(t *testing.T) { }, { v1: NewInt64(-100), v2: NewUint64(10), - err: dataOutOfRangeError(10, -100, "BIGINT UNSIGNED", "+"), + err: dataOutOfRangeError(uint64(10), int64(-100), "BIGINT UNSIGNED", "+"), }, { // Make sure underlying error is returned while converting. v1: NewFloat64(1), @@ -594,12 +594,12 @@ func TestAddNumeric(t *testing.T) { // Int64 overflow. v1: newEvalInt64(9223372036854775807), v2: newEvalInt64(2), - err: dataOutOfRangeError(9223372036854775807, 2, "BIGINT", "+"), + err: dataOutOfRangeError(int64(9223372036854775807), int64(2), "BIGINT", "+"), }, { // Int64 underflow. v1: newEvalInt64(-9223372036854775807), v2: newEvalInt64(-2), - err: dataOutOfRangeError(-9223372036854775807, -2, "BIGINT", "+"), + err: dataOutOfRangeError(int64(-9223372036854775807), int64(-2), "BIGINT", "+"), }, { v1: newEvalInt64(-1), v2: newEvalUint64(2), @@ -608,7 +608,7 @@ func TestAddNumeric(t *testing.T) { // Uint64 overflow. v1: newEvalUint64(18446744073709551615), v2: newEvalUint64(2), - err: dataOutOfRangeError(uint64(18446744073709551615), 2, "BIGINT UNSIGNED", "+"), + err: dataOutOfRangeError(uint64(18446744073709551615), uint64(2), "BIGINT UNSIGNED", "+"), }} for _, tcase := range tcases { got, err := addNumericWithError(tcase.v1, tcase.v2) diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index c258dab1672..031b387d275 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -25,7 +25,7 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -func dataOutOfRangeError[N1, N2 int | int64 | uint64 | float64](v1 N1, v2 N2, typ, sign string) error { +func dataOutOfRangeError[N1, N2 int64 | uint64 | float64](v1 N1, v2 N2, typ, sign string) error { return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1, sign, v2) } diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 6df13dc4ccf..011e1641cb4 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -159,8 +159,12 @@ func (c *compiler) compileToNumeric(ct ctype, offset int, fallback sqltypes.Type if sqltypes.IsDateOrTime(ct.Type) { if preciseDatetime { - c.asm.Convert_Ti(offset) - return ctype{Type: sqltypes.Int64, Flag: ct.Flag, Col: collationNumeric} + if ct.Size == 0 { + c.asm.Convert_Ti(offset) + return ctype{Type: sqltypes.Int64, Flag: ct.Flag, Col: collationNumeric} + } + c.asm.Convert_Td(offset) + return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Size: ct.Size} } c.asm.Convert_Tf(offset) return ctype{Type: sqltypes.Float64, Flag: ct.Flag, Col: collationNumeric} diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index affbee664a8..3d8eb0023bf 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -288,13 +288,13 @@ func (asm *assembler) BitShiftLeft_bu() { r := env.vm.stack[env.vm.sp-1].(*evalUint64) var ( - bits = int(r.u & 7) - bytes = int(r.u >> 3) - length = len(l.bytes) + bits = int64(r.u & 7) + bytes = int64(r.u >> 3) + length = int64(len(l.bytes)) out = make([]byte, length) ) - for i := 0; i < length; i++ { + for i := int64(0); i < length; i++ { pos := i + bytes + 1 switch { case pos < length: @@ -332,9 +332,9 @@ func (asm *assembler) BitShiftRight_bu() { r := env.vm.stack[env.vm.sp-1].(*evalUint64) var ( - bits = int(r.u & 7) - bytes = int(r.u >> 3) - length = len(l.bytes) + bits = int64(r.u & 7) + bytes = int64(r.u >> 3) + length = int64(len(l.bytes)) out = make([]byte, length) ) @@ -904,7 +904,7 @@ func (asm *assembler) Convert_Ti(offset int) { asm.emit(func(env *ExpressionEnv) int { v := env.vm.stack[env.vm.sp-offset].(*evalTemporal) if v.prec != 0 { - env.vm.err = errDeoptimize + env.vm.err = vterrors.NewErrorf(vtrpc.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "temporal type with non-zero precision") return 1 } env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalInt64(v.toInt64()) @@ -920,6 +920,18 @@ func (asm *assembler) Convert_Tf(offset int) { }, "CONV SQLTIME(SP-%d), FLOAT64", offset) } +func (asm *assembler) Convert_Td(offset int) { + asm.emit(func(env *ExpressionEnv) int { + v := env.vm.stack[env.vm.sp-offset].(*evalTemporal) + if v.prec == 0 { + env.vm.err = vterrors.NewErrorf(vtrpc.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "temporal type with zero precision") + return 1 + } + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalDecimalWithPrec(v.toDecimal(), int32(v.prec)) + return 1 + }, "CONV SQLTIME(SP-%d), DECIMAL", offset) +} + func (asm *assembler) Convert_iB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 0e9219c648a..6cd27b043a3 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -590,6 +590,18 @@ func TestCompilerSingle(t *testing.T) { expression: `week('2024-12-31', 5)`, result: `INT64(53)`, }, + { + expression: `FROM_UNIXTIME(time '10:04:58.5')`, + result: `DATETIME("1970-01-02 04:54:18.5")`, + }, + { + expression: `0 = time '10:04:58.1'`, + result: `INT64(0)`, + }, + { + expression: `CAST(time '32:34:58.5' AS TIME)`, + result: `TIME("32:34:59")`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/expr_bit.go b/go/vt/vtgate/evalengine/expr_bit.go index 9c4dbafe2a6..6200875d1fc 100644 --- a/go/vt/vtgate/evalengine/expr_bit.go +++ b/go/vt/vtgate/evalengine/expr_bit.go @@ -104,9 +104,9 @@ func (o opBitShr) numeric(num, shift uint64) uint64 { return num >> shift } func (o opBitShr) binary(num []byte, shift uint64) []byte { var ( - bits = int(shift % 8) - bytes = int(shift / 8) - length = len(num) + bits = int64(shift % 8) + bytes = int64(shift / 8) + length = int64(len(num)) out = make([]byte, length) ) @@ -127,13 +127,13 @@ func (o opBitShl) numeric(num, shift uint64) uint64 { return num << shift } func (o opBitShl) binary(num []byte, shift uint64) []byte { var ( - bits = int(shift % 8) - bytes = int(shift / 8) - length = len(num) + bits = int64(shift % 8) + bytes = int64(shift / 8) + length = int64(len(num)) out = make([]byte, length) ) - for i := 0; i < length; i++ { + for i := int64(0); i < length; i++ { pos := i + bytes + 1 switch { case pos < length: diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index 84f40abb9c0..91c7d9f6c42 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -389,12 +389,22 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) { c.asm.CmpDateString() case compareAsDateAndNumeric(lt.Type, rt.Type): if sqltypes.IsDateOrTime(lt.Type) { - c.asm.Convert_Ti(2) - lt.Type = sqltypes.Int64 + if lt.Size == 0 { + c.asm.Convert_Ti(2) + lt.Type = sqltypes.Int64 + } else { + c.asm.Convert_Tf(2) + lt.Type = sqltypes.Float64 + } } if sqltypes.IsDateOrTime(rt.Type) { - c.asm.Convert_Ti(1) - rt.Type = sqltypes.Int64 + if rt.Size == 0 { + c.asm.Convert_Ti(1) + rt.Type = sqltypes.Int64 + } else { + c.asm.Convert_Tf(1) + rt.Type = sqltypes.Float64 + } } swapped = c.compareNumericTypes(lt, rt) case compareAsJSON(lt.Type, rt.Type): diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 49a328a852f..8236f104499 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -559,35 +559,65 @@ func (b *builtinFromUnixtime) eval(env *ExpressionEnv) (eval, error) { switch ts := ts.(type) { case *evalInt64: + if ts.i < 0 || ts.i >= maxUnixtime { + return nil, nil + } sec = ts.i case *evalUint64: + if ts.u >= maxUnixtime { + return nil, nil + } sec = int64(ts.u) case *evalFloat: + if ts.f < 0 || ts.f >= maxUnixtime { + return nil, nil + } sf, ff := math.Modf(ts.f) sec = int64(sf) frac = int64(ff * 1e9) prec = maxTimePrec case *evalDecimal: + if ts.dec.Sign() < 0 { + return nil, nil + } sd, fd := ts.dec.QuoRem(decimal.New(1, 0), 0) sec, _ = sd.Int64() + if sec >= maxUnixtime { + return nil, nil + } frac, _ = fd.Mul(decimal.New(1, 9)).Int64() prec = int(ts.length) case *evalTemporal: if ts.prec == 0 { sec = ts.toInt64() + if sec < 0 || sec >= maxUnixtime { + return nil, nil + } } else { dec := ts.toDecimal() + if dec.Sign() < 0 { + return nil, nil + } sd, fd := dec.QuoRem(decimal.New(1, 0), 0) sec, _ = sd.Int64() + if sec >= maxUnixtime { + return nil, nil + } frac, _ = fd.Mul(decimal.New(1, 9)).Int64() prec = int(ts.prec) } case *evalBytes: if ts.isHexOrBitLiteral() { u, _ := ts.toNumericHex() + if u.u >= maxUnixtime { + return nil, nil + } sec = int64(u.u) } else { f, _ := evalToFloat(ts) + if f.f < 0 || f.f >= maxUnixtime { + return nil, nil + } sf, ff := math.Modf(f.f) sec = int64(sf) frac = int64(ff * 1e9) @@ -595,16 +625,15 @@ func (b *builtinFromUnixtime) eval(env *ExpressionEnv) (eval, error) { } default: f, _ := evalToFloat(ts) + if f.f < 0 || f.f >= maxUnixtime { + return nil, nil + } sf, ff := math.Modf(f.f) sec = int64(sf) frac = int64(ff * 1e9) prec = maxTimePrec } - if sec < 0 || sec >= maxUnixtime { - return nil, nil - } - t := time.Unix(sec, frac) if tz := env.currentTimezone(); tz != nil { t = t.In(tz) @@ -645,8 +674,13 @@ func (call *builtinFromUnixtime) compile(c *compiler) (ctype, error) { case sqltypes.Decimal: c.asm.Fn_FROM_UNIXTIME_d() case sqltypes.Datetime, sqltypes.Date, sqltypes.Time: - c.asm.Convert_Ti(1) - c.asm.Fn_FROM_UNIXTIME_i() + if arg.Size == 0 { + c.asm.Convert_Ti(1) + c.asm.Fn_FROM_UNIXTIME_i() + } else { + c.asm.Convert_Td(1) + c.asm.Fn_FROM_UNIXTIME_d() + } case sqltypes.VarChar, sqltypes.VarBinary: if arg.isHexOrBitLiteral() { c.asm.Convert_xu(1) diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index afdb5d6e225..f5fa75854e0 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -100,7 +100,8 @@ var inputConversions = []string{ "18446744073709540000e0", "-18446744073709540000e0", "JSON_OBJECT()", "JSON_ARRAY()", - "time '10:04:58'", "time '31:34:58'", "time '32:34:58'", "time '130:34:58'", "time '5 10:34:58'", "date '2000-01-01'", + "time '10:04:58'", "time '31:34:58'", "time '32:34:58'", "time '130:34:58'", "time '5 10:34:58'", + "time '10:04:58.1'", "time '31:34:58.4'", "time '32:34:58.5'", "time '130:34:58.6'", "time '5 10:34:58.9'", "date '2000-01-01'", "timestamp '2000-01-01 10:34:58'", "timestamp '2000-01-01 10:34:58.123456'", "timestamp '2000-01-01 10:34:58.978654'", "20000101103458", "20000101103458.1234", "20000101103458.123456", "20000101", "103458", "103458.123456", "'20000101103458'", "'20000101103458.1234'", "'20000101103458.123456'", "'20000101'", "'103458'", "'103458.123456'", diff --git a/go/vt/vttablet/onlineddl/vrepl.go b/go/vt/vttablet/onlineddl/vrepl.go index 1f9b422563d..c4e3075ff38 100644 --- a/go/vt/vttablet/onlineddl/vrepl.go +++ b/go/vt/vttablet/onlineddl/vrepl.go @@ -185,7 +185,7 @@ func (v *VRepl) readAutoIncrement(ctx context.Context, conn *dbconnpool.DBConnec return 0, err } - rs, err := conn.ExecuteFetch(query, math.MaxInt64, true) + rs, err := conn.ExecuteFetch(query, math.MaxInt, true) if err != nil { return 0, err } @@ -199,7 +199,7 @@ func (v *VRepl) readAutoIncrement(ctx context.Context, conn *dbconnpool.DBConnec // readTableColumns reads column list from given table func (v *VRepl) readTableColumns(ctx context.Context, conn *dbconnpool.DBConnection, tableName string) (columns *vrepl.ColumnList, virtualColumns *vrepl.ColumnList, pkColumns *vrepl.ColumnList, err error) { parsed := sqlparser.BuildParsedQuery(sqlShowColumnsFrom, tableName) - rs, err := conn.ExecuteFetch(parsed.Query, math.MaxInt64, true) + rs, err := conn.ExecuteFetch(parsed.Query, math.MaxInt, true) if err != nil { return nil, nil, nil, err } @@ -237,7 +237,7 @@ func (v *VRepl) readTableUniqueKeys(ctx context.Context, conn *dbconnpool.DBConn if err != nil { return nil, err } - rs, err := conn.ExecuteFetch(query, math.MaxInt64, true) + rs, err := conn.ExecuteFetch(query, math.MaxInt, true) if err != nil { return nil, err } @@ -260,7 +260,7 @@ func (v *VRepl) readTableUniqueKeys(ctx context.Context, conn *dbconnpool.DBConn // When `fast_analyze_table=1`, an `ANALYZE TABLE` command only analyzes the clustering index (normally the `PRIMARY KEY`). // This is useful when you want to get a better estimate of the number of table rows, as fast as possible. func (v *VRepl) isFastAnalyzeTableSupported(ctx context.Context, conn *dbconnpool.DBConnection) (isSupported bool, err error) { - rs, err := conn.ExecuteFetch(sqlShowVariablesLikeFastAnalyzeTable, math.MaxInt64, true) + rs, err := conn.ExecuteFetch(sqlShowVariablesLikeFastAnalyzeTable, math.MaxInt, true) if err != nil { return false, err } @@ -295,7 +295,7 @@ func (v *VRepl) executeAnalyzeTable(ctx context.Context, conn *dbconnpool.DBConn // readTableStatus reads table status information func (v *VRepl) readTableStatus(ctx context.Context, conn *dbconnpool.DBConnection, tableName string) (tableRows int64, err error) { parsed := sqlparser.BuildParsedQuery(sqlShowTableStatus, tableName) - rs, err := conn.ExecuteFetch(parsed.Query, math.MaxInt64, true) + rs, err := conn.ExecuteFetch(parsed.Query, math.MaxInt, true) if err != nil { return 0, err } @@ -316,7 +316,7 @@ func (v *VRepl) applyColumnTypes(ctx context.Context, conn *dbconnpool.DBConnect if err != nil { return err } - rs, err := conn.ExecuteFetch(query, math.MaxInt64, true) + rs, err := conn.ExecuteFetch(query, math.MaxInt, true) if err != nil { return err }