Skip to content

Commit

Permalink
evalengine: Internal cleanup and consistency fixes
Browse files Browse the repository at this point in the history
While working on vitessio#14841 and
running some tests, I ran into some other issues / consistency problems
in the evalengine that we can clean up and fix.

First, we use `time.Duration` for intervals. It already provides the
constants we need, and we no longer have to handle nanoseconds separately
since they are part of the duration.

We clean up the tinyweight comparison function. It did a bunch of casting
that works but isn't very clear and would actually break on 32-bit
platforms (which we don't support anyway). It now also returns 0, 1 or -1,
which matches how other Go `Cmp` functions work.

We remove `int` from `dataOutOfRangeError` since the `evalengine` only
works with `int64` or `uint64` anyway, so any usage of `int` would
really be a bug (and we didn't deal with `uint` either so it was
inconsistent anyway).

The bit shift operations also need to operate on int64 explicitly, since
that's what the inputs are in the `evalengine`. So we should keep the
types consistent.

Next, we were missing a now-possible optimization: the size of temporal
types is known at compile time. This means we know whether we need to
convert to an integer or a decimal. We no longer hit the deoptimize path,
and we now return a hard error if that happens, since compilation is
broken in that case.

Lastly, we were not dealing with underflow / overflow checks correctly in
`FROM_UNIXTIME` between the evaluator and the compiler. We need to check
before converting, because float64 to int64 conversions specifically have
implementation-dependent results for float64 values that are out of range.
The result already differs between amd64, arm64 and i386, for example: some
convert large values to negative ints, others to positive or even other
values. By checking before casting we avoid this and behave consistently.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed Dec 22, 2023
1 parent d807985 commit 8a75ccb
Show file tree
Hide file tree
Showing 13 changed files with 141 additions and 76 deletions.
40 changes: 16 additions & 24 deletions go/mysql/datetime/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,12 +432,12 @@ func (t Time) AddInterval(itv *Interval, stradd bool) (Time, uint8, bool) {
return dt.Time, itv.precision(stradd), ok
}

func (t Time) toSeconds() int {
tsecs := t.Hour()*secondsPerHour + t.Minute()*secondsPerMinute + t.Second()
func (t Time) toDuration() time.Duration {
dur := time.Duration(t.hour)*time.Hour + time.Duration(t.minute)*time.Minute + time.Duration(t.second)*time.Second + time.Duration(t.nanosecond)*time.Nanosecond
if t.Neg() {
return -tsecs
return -dur
}
return tsecs
return dur
}

func (d Date) ToStdTime(loc *time.Location) (out time.Time) {
Expand Down Expand Up @@ -569,8 +569,8 @@ func (dt DateTime) Round(p int) (r DateTime) {
return r
}

func (dt DateTime) toSeconds() int {
return (dt.Date.Day()-1)*secondsPerDay + dt.Time.toSeconds()
func (dt DateTime) toDuration() time.Duration {
return time.Duration(dt.Date.Day()-1)*durationPerDay + dt.Time.toDuration()
}

func (dt *DateTime) addInterval(itv *Interval) bool {
Expand All @@ -580,29 +580,21 @@ func (dt *DateTime) addInterval(itv *Interval) bool {
return false
}

nsec := dt.Time.Nanosecond() + itv.nsec
sec := dt.toSeconds() + itv.toSeconds() + (nsec / int(time.Second))
nsec = nsec % int(time.Second)
dur := dt.toDuration() + itv.toDuration()
days := dur / durationPerDay
dur -= days * durationPerDay

if nsec < 0 {
nsec += int(time.Second)
sec--
}

days := sec / secondsPerDay
sec -= days * secondsPerDay

if sec < 0 {
sec += secondsPerDay
if dur < 0 {
dur += durationPerDay
days--
}

dt.Time.nanosecond = uint32(nsec)
dt.Time.second = uint8(sec % secondsPerMinute)
dt.Time.minute = uint8((sec / secondsPerMinute) % secondsPerMinute)
dt.Time.hour = uint16(sec / secondsPerHour)
dt.Time.nanosecond = uint32((dur % time.Second) / time.Nanosecond)
dt.Time.second = uint8((dur % time.Minute) / time.Second)
dt.Time.minute = uint8((dur % time.Hour) / time.Minute)
dt.Time.hour = uint16(dur / time.Hour)

daynum := mysqlDayNumber(dt.Date.Year(), dt.Date.Month(), 1) + days
daynum := mysqlDayNumber(dt.Date.Year(), dt.Date.Month(), 1) + int(days)
if daynum < 0 || daynum > maxDay {
return false
}
Expand Down
4 changes: 1 addition & 3 deletions go/mysql/datetime/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,5 @@ func parseNanoseconds[bytes []byte | string](value bytes, nbytes int) (ns int, l
}

const (
secondsPerMinute = 60
secondsPerHour = 60 * secondsPerMinute
secondsPerDay = 24 * secondsPerHour
durationPerDay = 24 * time.Hour
)
4 changes: 2 additions & 2 deletions go/mysql/datetime/timeparts.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,6 @@ func (tp *timeparts) isZero() bool {
return tp.year == 0 && tp.month == 0 && tp.day == 0 && tp.hour == 0 && tp.min == 0 && tp.sec == 0 && tp.nsec == 0
}

func (tp *timeparts) toSeconds() int {
return tp.day*secondsPerDay + tp.hour*3600 + tp.min*60 + tp.sec
func (tp *timeparts) toDuration() time.Duration {
return time.Duration(tp.day)*durationPerDay + time.Duration(tp.hour)*time.Hour + time.Duration(tp.min)*time.Minute + time.Duration(tp.sec)*time.Second + time.Duration(tp.nsec)*time.Nanosecond
}
8 changes: 7 additions & 1 deletion go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,13 @@ func (v Value) TinyWeightCmp(other Value) int {
if v.flags&other.flags&flagTinyWeight == 0 {
return 0
}
return int(int64(v.tinyweight) - int64(other.tinyweight))
if v.tinyweight == other.tinyweight {
return 0
}
if v.tinyweight < other.tinyweight {
return -1
}
return 1
}

func (v Value) TinyWeight() uint32 {
Expand Down
34 changes: 17 additions & 17 deletions go/vt/vtgate/evalengine/api_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ func TestArithmetics(t *testing.T) {
// testing for int64 overflow with min negative value
v1: NewInt64(math.MinInt64),
v2: NewInt64(1),
err: dataOutOfRangeError(math.MinInt64, 1, "BIGINT", "-").Error(),
err: dataOutOfRangeError(int64(math.MinInt64), int64(1), "BIGINT", "-").Error(),
}, {
v1: NewUint64(4),
v2: NewInt64(5),
err: dataOutOfRangeError(4, 5, "BIGINT UNSIGNED", "-").Error(),
err: dataOutOfRangeError(uint64(4), int64(5), "BIGINT UNSIGNED", "-").Error(),
}, {
// testing uint - int
v1: NewUint64(7),
Expand All @@ -103,7 +103,7 @@ func TestArithmetics(t *testing.T) {
// testing for int64 overflow
v1: NewInt64(math.MinInt64),
v2: NewUint64(0),
err: dataOutOfRangeError(math.MinInt64, 0, "BIGINT UNSIGNED", "-").Error(),
err: dataOutOfRangeError(int64(math.MinInt64), uint64(0), "BIGINT UNSIGNED", "-").Error(),
}, {
v1: TestValue(sqltypes.VarChar, "c"),
v2: NewInt64(1),
Expand Down Expand Up @@ -140,7 +140,7 @@ func TestArithmetics(t *testing.T) {
}, {
v1: NewInt64(-1),
v2: NewUint64(2),
err: dataOutOfRangeError(-1, 2, "BIGINT UNSIGNED", "-").Error(),
err: dataOutOfRangeError(int64(-1), int64(2), "BIGINT UNSIGNED", "-").Error(),
}, {
v1: NewInt64(2),
v2: NewUint64(1),
Expand Down Expand Up @@ -169,7 +169,7 @@ func TestArithmetics(t *testing.T) {
// testing uint - uint if v2 > v1
v1: NewUint64(2),
v2: NewUint64(4),
err: dataOutOfRangeError(2, 4, "BIGINT UNSIGNED", "-").Error(),
err: dataOutOfRangeError(uint64(2), uint64(4), "BIGINT UNSIGNED", "-").Error(),
}, {
// testing uint - (- int)
v1: NewUint64(1),
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestArithmetics(t *testing.T) {
}, {
v1: NewInt64(-2),
v2: NewUint64(1),
err: dataOutOfRangeError(1, -2, "BIGINT UNSIGNED", "+").Error(),
err: dataOutOfRangeError(uint64(1), int64(-2), "BIGINT UNSIGNED", "+").Error(),
}, {
v1: NewInt64(math.MaxInt64),
v2: NewInt64(-2),
Expand All @@ -221,12 +221,12 @@ func TestArithmetics(t *testing.T) {
// testing for overflow uint64
v1: NewUint64(maxUint64),
v2: NewUint64(2),
err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "+").Error(),
err: dataOutOfRangeError(maxUint64, uint64(2), "BIGINT UNSIGNED", "+").Error(),
}, {
// int64 underflow
v1: NewInt64(math.MinInt64),
v2: NewInt64(-2),
err: dataOutOfRangeError(math.MinInt64, -2, "BIGINT", "+").Error(),
err: dataOutOfRangeError(int64(math.MinInt64), int64(-2), "BIGINT", "+").Error(),
}, {
// checking int64 max value can be returned
v1: NewInt64(math.MaxInt64),
Expand Down Expand Up @@ -261,7 +261,7 @@ func TestArithmetics(t *testing.T) {
// testing for uint64 overflow with max uint64 + int value
v1: NewUint64(maxUint64),
v2: NewInt64(2),
err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "+").Error(),
err: dataOutOfRangeError(maxUint64, int64(2), "BIGINT UNSIGNED", "+").Error(),
}, {
v1: sqltypes.NewHexNum([]byte("0x9")),
v2: NewInt64(1),
Expand Down Expand Up @@ -309,7 +309,7 @@ func TestArithmetics(t *testing.T) {
// Lower bound for int64
v1: NewInt64(math.MinInt64),
v2: NewInt64(1),
out: NewDecimal(strconv.Itoa(math.MinInt64) + ".0000"),
out: NewDecimal(strconv.FormatInt(math.MinInt64, 10) + ".0000"),
}, {
// upper bound for uint64
v1: NewUint64(math.MaxUint64),
Expand Down Expand Up @@ -413,12 +413,12 @@ func TestArithmetics(t *testing.T) {
// testing for overflow of int64
v1: NewInt64(math.MaxInt64),
v2: NewInt64(2),
err: dataOutOfRangeError(math.MaxInt64, 2, "BIGINT", "*").Error(),
err: dataOutOfRangeError(int64(math.MaxInt64), int64(2), "BIGINT", "*").Error(),
}, {
// testing for underflow of uint64*max.uint64
v1: NewInt64(2),
v2: NewUint64(maxUint64),
err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "*").Error(),
err: dataOutOfRangeError(maxUint64, int64(2), "BIGINT UNSIGNED", "*").Error(),
}, {
v1: NewUint64(math.MaxUint64),
v2: NewUint64(1),
Expand All @@ -427,7 +427,7 @@ func TestArithmetics(t *testing.T) {
// Checking whether maxInt value can be passed as uint value
v1: NewUint64(math.MaxInt64),
v2: NewInt64(3),
err: dataOutOfRangeError(math.MaxInt64, 3, "BIGINT UNSIGNED", "*").Error(),
err: dataOutOfRangeError(uint64(math.MaxInt64), int64(3), "BIGINT UNSIGNED", "*").Error(),
}},
}}

Expand Down Expand Up @@ -492,7 +492,7 @@ func TestNullSafeAdd(t *testing.T) {
}, {
v1: NewInt64(-100),
v2: NewUint64(10),
err: dataOutOfRangeError(10, -100, "BIGINT UNSIGNED", "+"),
err: dataOutOfRangeError(uint64(10), int64(-100), "BIGINT UNSIGNED", "+"),
}, {
// Make sure underlying error is returned while converting.
v1: NewFloat64(1),
Expand Down Expand Up @@ -594,12 +594,12 @@ func TestAddNumeric(t *testing.T) {
// Int64 overflow.
v1: newEvalInt64(9223372036854775807),
v2: newEvalInt64(2),
err: dataOutOfRangeError(9223372036854775807, 2, "BIGINT", "+"),
err: dataOutOfRangeError(int64(9223372036854775807), int64(2), "BIGINT", "+"),
}, {
// Int64 underflow.
v1: newEvalInt64(-9223372036854775807),
v2: newEvalInt64(-2),
err: dataOutOfRangeError(-9223372036854775807, -2, "BIGINT", "+"),
err: dataOutOfRangeError(int64(-9223372036854775807), int64(-2), "BIGINT", "+"),
}, {
v1: newEvalInt64(-1),
v2: newEvalUint64(2),
Expand All @@ -608,7 +608,7 @@ func TestAddNumeric(t *testing.T) {
// Uint64 overflow.
v1: newEvalUint64(18446744073709551615),
v2: newEvalUint64(2),
err: dataOutOfRangeError(uint64(18446744073709551615), 2, "BIGINT UNSIGNED", "+"),
err: dataOutOfRangeError(uint64(18446744073709551615), uint64(2), "BIGINT UNSIGNED", "+"),
}}
for _, tcase := range tcases {
got, err := addNumericWithError(tcase.v1, tcase.v2)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"vitess.io/vitess/go/vt/vterrors"
)

func dataOutOfRangeError[N1, N2 int | int64 | uint64 | float64](v1 N1, v2 N2, typ, sign string) error {
func dataOutOfRangeError[N1, N2 int64 | uint64 | float64](v1 N1, v2 N2, typ, sign string) error {
return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1, sign, v2)
}

Expand Down
8 changes: 6 additions & 2 deletions go/vt/vtgate/evalengine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,12 @@ func (c *compiler) compileToNumeric(ct ctype, offset int, fallback sqltypes.Type

if sqltypes.IsDateOrTime(ct.Type) {
if preciseDatetime {
c.asm.Convert_Ti(offset)
return ctype{Type: sqltypes.Int64, Flag: ct.Flag, Col: collationNumeric}
if ct.Size == 0 {
c.asm.Convert_Ti(offset)
return ctype{Type: sqltypes.Int64, Flag: ct.Flag, Col: collationNumeric}
}
c.asm.Convert_Td(offset)
return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Size: ct.Size}
}
c.asm.Convert_Tf(offset)
return ctype{Type: sqltypes.Float64, Flag: ct.Flag, Col: collationNumeric}
Expand Down
28 changes: 20 additions & 8 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,13 @@ func (asm *assembler) BitShiftLeft_bu() {
r := env.vm.stack[env.vm.sp-1].(*evalUint64)

var (
bits = int(r.u & 7)
bytes = int(r.u >> 3)
length = len(l.bytes)
bits = int64(r.u & 7)
bytes = int64(r.u >> 3)
length = int64(len(l.bytes))
out = make([]byte, length)
)

for i := 0; i < length; i++ {
for i := int64(0); i < length; i++ {
pos := i + bytes + 1
switch {
case pos < length:
Expand Down Expand Up @@ -332,9 +332,9 @@ func (asm *assembler) BitShiftRight_bu() {
r := env.vm.stack[env.vm.sp-1].(*evalUint64)

var (
bits = int(r.u & 7)
bytes = int(r.u >> 3)
length = len(l.bytes)
bits = int64(r.u & 7)
bytes = int64(r.u >> 3)
length = int64(len(l.bytes))
out = make([]byte, length)
)

Expand Down Expand Up @@ -904,7 +904,7 @@ func (asm *assembler) Convert_Ti(offset int) {
asm.emit(func(env *ExpressionEnv) int {
v := env.vm.stack[env.vm.sp-offset].(*evalTemporal)
if v.prec != 0 {
env.vm.err = errDeoptimize
env.vm.err = vterrors.NewErrorf(vtrpc.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "temporal type with non-zero precision")
return 1
}
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalInt64(v.toInt64())
Expand All @@ -920,6 +920,18 @@ func (asm *assembler) Convert_Tf(offset int) {
}, "CONV SQLTIME(SP-%d), FLOAT64", offset)
}

// Convert_Td emits an instruction that converts the temporal value at stack
// position SP-offset into a decimal, replacing that stack slot in place.
//
// The emitted instruction expects the slot to hold an *evalTemporal with a
// non-zero precision; the resulting decimal is created with that same
// precision. If the precision is zero, no conversion happens and env.vm.err
// is set instead.
func (asm *assembler) Convert_Td(offset int) {
asm.emit(func(env *ExpressionEnv) int {
// Single-value type assertion: panics if the slot is not an *evalTemporal.
v := env.vm.stack[env.vm.sp-offset].(*evalTemporal)
if v.prec == 0 {
// NOTE(review): presumably the compiler only emits this instruction for
// temporals known to have a non-zero size, so reaching this branch would
// indicate a compilation bug — confirm against the call sites.
env.vm.err = vterrors.NewErrorf(vtrpc.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "temporal type with zero precision")
return 1
}
// Overwrite the temporal with its decimal form, keeping the precision.
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalDecimalWithPrec(v.toDecimal(), int32(v.prec))
return 1
}, "CONV SQLTIME(SP-%d), DECIMAL", offset)
}

func (asm *assembler) Convert_iB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,14 @@ func TestCompilerSingle(t *testing.T) {
expression: `DAYOFMONTH(0)`,
result: `INT64(0)`,
},
{
expression: `FROM_UNIXTIME(time '10:04:58.5')`,
result: `DATETIME("1970-01-02 04:54:18.5")`,
},
{
expression: `0 = time '10:04:58.1'`,
result: `INT64(0)`,
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down
14 changes: 7 additions & 7 deletions go/vt/vtgate/evalengine/expr_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ func (o opBitShr) numeric(num, shift uint64) uint64 { return num >> shift }

func (o opBitShr) binary(num []byte, shift uint64) []byte {
var (
bits = int(shift % 8)
bytes = int(shift / 8)
length = len(num)
bits = int64(shift % 8)
bytes = int64(shift / 8)
length = int64(len(num))
out = make([]byte, length)
)

Expand All @@ -127,13 +127,13 @@ func (o opBitShl) numeric(num, shift uint64) uint64 { return num << shift }

func (o opBitShl) binary(num []byte, shift uint64) []byte {
var (
bits = int(shift % 8)
bytes = int(shift / 8)
length = len(num)
bits = int64(shift % 8)
bytes = int64(shift / 8)
length = int64(len(num))
out = make([]byte, length)
)

for i := 0; i < length; i++ {
for i := int64(0); i < length; i++ {
pos := i + bytes + 1
switch {
case pos < length:
Expand Down
18 changes: 14 additions & 4 deletions go/vt/vtgate/evalengine/expr_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,22 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) {
c.asm.CmpDateString()
case compareAsDateAndNumeric(lt.Type, rt.Type):
if sqltypes.IsDateOrTime(lt.Type) {
c.asm.Convert_Ti(2)
lt.Type = sqltypes.Int64
if lt.Size == 0 {
c.asm.Convert_Ti(2)
lt.Type = sqltypes.Int64
} else {
c.asm.Convert_Tf(2)
lt.Type = sqltypes.Float64
}
}
if sqltypes.IsDateOrTime(rt.Type) {
c.asm.Convert_Ti(1)
rt.Type = sqltypes.Int64
if rt.Size == 0 {
c.asm.Convert_Ti(1)
rt.Type = sqltypes.Int64
} else {
c.asm.Convert_Tf(1)
rt.Type = sqltypes.Float64
}
}
swapped = c.compareNumericTypes(lt, rt)
case compareAsJSON(lt.Type, rt.Type):
Expand Down
Loading

0 comments on commit 8a75ccb

Please sign in to comment.