Skip to content

Commit

Permalink
evalengine: Internal cleanup and consistency fixes (#14854)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Dec 27, 2023
1 parent ab37170 commit d62a5c5
Show file tree
Hide file tree
Showing 18 changed files with 167 additions and 89 deletions.
50 changes: 25 additions & 25 deletions go/mysql/datetime/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,12 @@ func (t Time) AddInterval(itv *Interval, stradd bool) (Time, uint8, bool) {
return dt.Time, itv.precision(stradd), ok
}

func (t Time) toSeconds() int {
tsecs := t.Hour()*secondsPerHour + t.Minute()*secondsPerMinute + t.Second()
func (t Time) toDuration() time.Duration {
dur := time.Duration(t.hour)*time.Hour + time.Duration(t.minute)*time.Minute + time.Duration(t.second)*time.Second + time.Duration(t.nanosecond)*time.Nanosecond
if t.Neg() {
return -tsecs
return -dur
}
return tsecs
return dur
}

func (d Date) ToStdTime(loc *time.Location) (out time.Time) {
Expand Down Expand Up @@ -577,8 +577,12 @@ func (dt DateTime) Round(p int) (r DateTime) {
return r
}

func (dt DateTime) toSeconds() int {
return (dt.Date.Day()-1)*secondsPerDay + dt.Time.toSeconds()
func (dt DateTime) toDuration() time.Duration {
dur := dt.Time.toDuration()
if !dt.Date.IsZero() {
dur += time.Duration(dt.Date.Day()-1) * durationPerDay
}
return dur
}

func (dt *DateTime) addInterval(itv *Interval) bool {
Expand All @@ -588,29 +592,25 @@ func (dt *DateTime) addInterval(itv *Interval) bool {
return false
}

nsec := dt.Time.Nanosecond() + itv.nsec
sec := dt.toSeconds() + itv.toSeconds() + (nsec / int(time.Second))
nsec = nsec % int(time.Second)

if nsec < 0 {
nsec += int(time.Second)
sec--
}
dur := dt.toDuration()
dur += itv.toDuration()
days := time.Duration(0)
if !dt.Date.IsZero() {
days = dur / durationPerDay
dur -= days * durationPerDay

days := sec / secondsPerDay
sec -= days * secondsPerDay

if sec < 0 {
sec += secondsPerDay
days--
if dur < 0 {
dur += durationPerDay
days--
}
}

dt.Time.nanosecond = uint32(nsec)
dt.Time.second = uint8(sec % secondsPerMinute)
dt.Time.minute = uint8((sec / secondsPerMinute) % secondsPerMinute)
dt.Time.hour = uint16(sec / secondsPerHour)
dt.Time.nanosecond = uint32((dur % time.Second) / time.Nanosecond)
dt.Time.second = uint8((dur % time.Minute) / time.Second)
dt.Time.minute = uint8((dur % time.Hour) / time.Minute)
dt.Time.hour = uint16(dur / time.Hour)

daynum := mysqlDayNumber(dt.Date.Year(), dt.Date.Month(), 1) + days
daynum := mysqlDayNumber(dt.Date.Year(), dt.Date.Month(), 1) + int(days)
if daynum < 0 || daynum > maxDay {
return false
}
Expand Down
4 changes: 1 addition & 3 deletions go/mysql/datetime/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,5 @@ func parseNanoseconds[bytes []byte | string](value bytes, nbytes int) (ns int, l
}

const (
secondsPerMinute = 60
secondsPerHour = 60 * secondsPerMinute
secondsPerDay = 24 * secondsPerHour
durationPerDay = 24 * time.Hour
)
4 changes: 2 additions & 2 deletions go/mysql/datetime/timeparts.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,6 @@ func (tp *timeparts) isZero() bool {
return tp.year == 0 && tp.month == 0 && tp.day == 0 && tp.hour == 0 && tp.min == 0 && tp.sec == 0 && tp.nsec == 0
}

func (tp *timeparts) toSeconds() int {
return tp.day*secondsPerDay + tp.hour*3600 + tp.min*60 + tp.sec
func (tp *timeparts) toDuration() time.Duration {
return time.Duration(tp.day)*durationPerDay + time.Duration(tp.hour)*time.Hour + time.Duration(tp.min)*time.Minute + time.Duration(tp.sec)*time.Second + time.Duration(tp.nsec)*time.Nanosecond
}
8 changes: 7 additions & 1 deletion go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,13 @@ func (v Value) TinyWeightCmp(other Value) int {
if v.flags&other.flags&flagTinyWeight == 0 {
return 0
}
return int(int64(v.tinyweight) - int64(other.tinyweight))
if v.tinyweight == other.tinyweight {
return 0
}
if v.tinyweight < other.tinyweight {
return -1
}
return 1
}

func (v Value) TinyWeight() uint32 {
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/onlineddl/vtgate_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func VtgateExecQuery(t *testing.T, vtParams *mysql.ConnParams, query string, exp
require.Nil(t, err)
defer conn.Close()

qr, err := conn.ExecuteFetch(query, math.MaxInt64, true)
qr, err := conn.ExecuteFetch(query, math.MaxInt, true)
if expectError == "" {
require.NoError(t, err)
} else {
Expand Down
5 changes: 3 additions & 2 deletions go/vt/mysqlctl/azblobbackupstorage/azblob.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ func (bh *AZBlobBackupHandle) AddFile(ctx context.Context, filename string, file
return nil, fmt.Errorf("AddFile cannot be called on read-only backup")
}
// Error out if the file size it too large ( ~4.75 TB)
if filesize > azblob.BlockBlobMaxStageBlockBytes*azblob.BlockBlobMaxBlocks {
return nil, fmt.Errorf("filesize (%v) is too large to upload to az blob (max size %v)", filesize, azblob.BlockBlobMaxStageBlockBytes*azblob.BlockBlobMaxBlocks)
maxSize := int64(azblob.BlockBlobMaxStageBlockBytes * azblob.BlockBlobMaxBlocks)
if filesize > maxSize {
return nil, fmt.Errorf("filesize (%v) is too large to upload to az blob (max size %v)", filesize, maxSize)
}

obj := objName(bh.dir, bh.name, filename)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/throttler/throttler.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ const (

// MaxRateModuleDisabled can be set in NewThrottler() to disable throttling
// by a fixed rate.
MaxRateModuleDisabled = math.MaxInt64
MaxRateModuleDisabled = int64(math.MaxInt64)

// InvalidMaxRate is a constant which will fail in a NewThrottler() call.
// It should be used when returning maxRate in an error case.
InvalidMaxRate = -1

// ReplicationLagModuleDisabled can be set in NewThrottler() to disable
// throttling based on the MySQL replication lag.
ReplicationLagModuleDisabled = math.MaxInt64
ReplicationLagModuleDisabled = int64(math.MaxInt64)

// InvalidMaxReplicationLag is a constant which will fail in a NewThrottler()
// call. It should be used when returning maxReplicationlag in an error case.
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/memory_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (ms *MemorySort) NeedsTransaction() bool {

func (ms *MemorySort) fetchCount(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (int, error) {
if ms.UpperLimit == nil {
return math.MaxInt64, nil
return math.MaxInt, nil
}
env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor)
resolved, err := env.Evaluate(ms.UpperLimit)
Expand Down
34 changes: 17 additions & 17 deletions go/vt/vtgate/evalengine/api_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ func TestArithmetics(t *testing.T) {
// testing for int64 overflow with min negative value
v1: NewInt64(math.MinInt64),
v2: NewInt64(1),
err: dataOutOfRangeError(math.MinInt64, 1, "BIGINT", "-").Error(),
err: dataOutOfRangeError(int64(math.MinInt64), int64(1), "BIGINT", "-").Error(),
}, {
v1: NewUint64(4),
v2: NewInt64(5),
err: dataOutOfRangeError(4, 5, "BIGINT UNSIGNED", "-").Error(),
err: dataOutOfRangeError(uint64(4), int64(5), "BIGINT UNSIGNED", "-").Error(),
}, {
// testing uint - int
v1: NewUint64(7),
Expand All @@ -103,7 +103,7 @@ func TestArithmetics(t *testing.T) {
// testing for int64 overflow
v1: NewInt64(math.MinInt64),
v2: NewUint64(0),
err: dataOutOfRangeError(math.MinInt64, 0, "BIGINT UNSIGNED", "-").Error(),
err: dataOutOfRangeError(int64(math.MinInt64), uint64(0), "BIGINT UNSIGNED", "-").Error(),
}, {
v1: TestValue(sqltypes.VarChar, "c"),
v2: NewInt64(1),
Expand Down Expand Up @@ -140,7 +140,7 @@ func TestArithmetics(t *testing.T) {
}, {
v1: NewInt64(-1),
v2: NewUint64(2),
err: dataOutOfRangeError(-1, 2, "BIGINT UNSIGNED", "-").Error(),
err: dataOutOfRangeError(int64(-1), int64(2), "BIGINT UNSIGNED", "-").Error(),
}, {
v1: NewInt64(2),
v2: NewUint64(1),
Expand Down Expand Up @@ -169,7 +169,7 @@ func TestArithmetics(t *testing.T) {
// testing uint - uint if v2 > v1
v1: NewUint64(2),
v2: NewUint64(4),
err: dataOutOfRangeError(2, 4, "BIGINT UNSIGNED", "-").Error(),
err: dataOutOfRangeError(uint64(2), uint64(4), "BIGINT UNSIGNED", "-").Error(),
}, {
// testing uint - (- int)
v1: NewUint64(1),
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestArithmetics(t *testing.T) {
}, {
v1: NewInt64(-2),
v2: NewUint64(1),
err: dataOutOfRangeError(1, -2, "BIGINT UNSIGNED", "+").Error(),
err: dataOutOfRangeError(uint64(1), int64(-2), "BIGINT UNSIGNED", "+").Error(),
}, {
v1: NewInt64(math.MaxInt64),
v2: NewInt64(-2),
Expand All @@ -221,12 +221,12 @@ func TestArithmetics(t *testing.T) {
// testing for overflow uint64
v1: NewUint64(maxUint64),
v2: NewUint64(2),
err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "+").Error(),
err: dataOutOfRangeError(maxUint64, uint64(2), "BIGINT UNSIGNED", "+").Error(),
}, {
// int64 underflow
v1: NewInt64(math.MinInt64),
v2: NewInt64(-2),
err: dataOutOfRangeError(math.MinInt64, -2, "BIGINT", "+").Error(),
err: dataOutOfRangeError(int64(math.MinInt64), int64(-2), "BIGINT", "+").Error(),
}, {
// checking int64 max value can be returned
v1: NewInt64(math.MaxInt64),
Expand Down Expand Up @@ -261,7 +261,7 @@ func TestArithmetics(t *testing.T) {
// testing for uint64 overflow with max uint64 + int value
v1: NewUint64(maxUint64),
v2: NewInt64(2),
err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "+").Error(),
err: dataOutOfRangeError(maxUint64, int64(2), "BIGINT UNSIGNED", "+").Error(),
}, {
v1: sqltypes.NewHexNum([]byte("0x9")),
v2: NewInt64(1),
Expand Down Expand Up @@ -309,7 +309,7 @@ func TestArithmetics(t *testing.T) {
// Lower bound for int64
v1: NewInt64(math.MinInt64),
v2: NewInt64(1),
out: NewDecimal(strconv.Itoa(math.MinInt64) + ".0000"),
out: NewDecimal(strconv.FormatInt(math.MinInt64, 10) + ".0000"),
}, {
// upper bound for uint64
v1: NewUint64(math.MaxUint64),
Expand Down Expand Up @@ -413,12 +413,12 @@ func TestArithmetics(t *testing.T) {
// testing for overflow of int64
v1: NewInt64(math.MaxInt64),
v2: NewInt64(2),
err: dataOutOfRangeError(math.MaxInt64, 2, "BIGINT", "*").Error(),
err: dataOutOfRangeError(int64(math.MaxInt64), int64(2), "BIGINT", "*").Error(),
}, {
// testing for underflow of uint64*max.uint64
v1: NewInt64(2),
v2: NewUint64(maxUint64),
err: dataOutOfRangeError(maxUint64, 2, "BIGINT UNSIGNED", "*").Error(),
err: dataOutOfRangeError(maxUint64, int64(2), "BIGINT UNSIGNED", "*").Error(),
}, {
v1: NewUint64(math.MaxUint64),
v2: NewUint64(1),
Expand All @@ -427,7 +427,7 @@ func TestArithmetics(t *testing.T) {
// Checking whether maxInt value can be passed as uint value
v1: NewUint64(math.MaxInt64),
v2: NewInt64(3),
err: dataOutOfRangeError(math.MaxInt64, 3, "BIGINT UNSIGNED", "*").Error(),
err: dataOutOfRangeError(uint64(math.MaxInt64), int64(3), "BIGINT UNSIGNED", "*").Error(),
}},
}}

Expand Down Expand Up @@ -492,7 +492,7 @@ func TestNullSafeAdd(t *testing.T) {
}, {
v1: NewInt64(-100),
v2: NewUint64(10),
err: dataOutOfRangeError(10, -100, "BIGINT UNSIGNED", "+"),
err: dataOutOfRangeError(uint64(10), int64(-100), "BIGINT UNSIGNED", "+"),
}, {
// Make sure underlying error is returned while converting.
v1: NewFloat64(1),
Expand Down Expand Up @@ -594,12 +594,12 @@ func TestAddNumeric(t *testing.T) {
// Int64 overflow.
v1: newEvalInt64(9223372036854775807),
v2: newEvalInt64(2),
err: dataOutOfRangeError(9223372036854775807, 2, "BIGINT", "+"),
err: dataOutOfRangeError(int64(9223372036854775807), int64(2), "BIGINT", "+"),
}, {
// Int64 underflow.
v1: newEvalInt64(-9223372036854775807),
v2: newEvalInt64(-2),
err: dataOutOfRangeError(-9223372036854775807, -2, "BIGINT", "+"),
err: dataOutOfRangeError(int64(-9223372036854775807), int64(-2), "BIGINT", "+"),
}, {
v1: newEvalInt64(-1),
v2: newEvalUint64(2),
Expand All @@ -608,7 +608,7 @@ func TestAddNumeric(t *testing.T) {
// Uint64 overflow.
v1: newEvalUint64(18446744073709551615),
v2: newEvalUint64(2),
err: dataOutOfRangeError(uint64(18446744073709551615), 2, "BIGINT UNSIGNED", "+"),
err: dataOutOfRangeError(uint64(18446744073709551615), uint64(2), "BIGINT UNSIGNED", "+"),
}}
for _, tcase := range tcases {
got, err := addNumericWithError(tcase.v1, tcase.v2)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"vitess.io/vitess/go/vt/vterrors"
)

func dataOutOfRangeError[N1, N2 int | int64 | uint64 | float64](v1 N1, v2 N2, typ, sign string) error {
func dataOutOfRangeError[N1, N2 int64 | uint64 | float64](v1 N1, v2 N2, typ, sign string) error {
return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "%s value is out of range in '(%v %s %v)'", typ, v1, sign, v2)
}

Expand Down
8 changes: 6 additions & 2 deletions go/vt/vtgate/evalengine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,12 @@ func (c *compiler) compileToNumeric(ct ctype, offset int, fallback sqltypes.Type

if sqltypes.IsDateOrTime(ct.Type) {
if preciseDatetime {
c.asm.Convert_Ti(offset)
return ctype{Type: sqltypes.Int64, Flag: ct.Flag, Col: collationNumeric}
if ct.Size == 0 {
c.asm.Convert_Ti(offset)
return ctype{Type: sqltypes.Int64, Flag: ct.Flag, Col: collationNumeric}
}
c.asm.Convert_Td(offset)
return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Size: ct.Size}
}
c.asm.Convert_Tf(offset)
return ctype{Type: sqltypes.Float64, Flag: ct.Flag, Col: collationNumeric}
Expand Down
28 changes: 20 additions & 8 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,13 @@ func (asm *assembler) BitShiftLeft_bu() {
r := env.vm.stack[env.vm.sp-1].(*evalUint64)

var (
bits = int(r.u & 7)
bytes = int(r.u >> 3)
length = len(l.bytes)
bits = int64(r.u & 7)
bytes = int64(r.u >> 3)
length = int64(len(l.bytes))
out = make([]byte, length)
)

for i := 0; i < length; i++ {
for i := int64(0); i < length; i++ {
pos := i + bytes + 1
switch {
case pos < length:
Expand Down Expand Up @@ -332,9 +332,9 @@ func (asm *assembler) BitShiftRight_bu() {
r := env.vm.stack[env.vm.sp-1].(*evalUint64)

var (
bits = int(r.u & 7)
bytes = int(r.u >> 3)
length = len(l.bytes)
bits = int64(r.u & 7)
bytes = int64(r.u >> 3)
length = int64(len(l.bytes))
out = make([]byte, length)
)

Expand Down Expand Up @@ -904,7 +904,7 @@ func (asm *assembler) Convert_Ti(offset int) {
asm.emit(func(env *ExpressionEnv) int {
v := env.vm.stack[env.vm.sp-offset].(*evalTemporal)
if v.prec != 0 {
env.vm.err = errDeoptimize
env.vm.err = vterrors.NewErrorf(vtrpc.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "temporal type with non-zero precision")
return 1
}
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalInt64(v.toInt64())
Expand All @@ -920,6 +920,18 @@ func (asm *assembler) Convert_Tf(offset int) {
}, "CONV SQLTIME(SP-%d), FLOAT64", offset)
}

func (asm *assembler) Convert_Td(offset int) {
asm.emit(func(env *ExpressionEnv) int {
v := env.vm.stack[env.vm.sp-offset].(*evalTemporal)
if v.prec == 0 {
env.vm.err = vterrors.NewErrorf(vtrpc.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "temporal type with zero precision")
return 1
}
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalDecimalWithPrec(v.toDecimal(), int32(v.prec))
return 1
}, "CONV SQLTIME(SP-%d), DECIMAL", offset)
}

func (asm *assembler) Convert_iB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
Expand Down
12 changes: 12 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,18 @@ func TestCompilerSingle(t *testing.T) {
expression: `week('2024-12-31', 5)`,
result: `INT64(53)`,
},
{
expression: `FROM_UNIXTIME(time '10:04:58.5')`,
result: `DATETIME("1970-01-02 04:54:18.5")`,
},
{
expression: `0 = time '10:04:58.1'`,
result: `INT64(0)`,
},
{
expression: `CAST(time '32:34:58.5' AS TIME)`,
result: `TIME("32:34:59")`,
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down
Loading

0 comments on commit d62a5c5

Please sign in to comment.