Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

datetime: obey the evalengine's environment time #14358

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions go/mysql/datetime/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func (t Time) FormatDecimal() decimal.Decimal {
return dec
}

func (t Time) ToDateTime() (out DateTime) {
return NewDateTimeFromStd(t.ToStdTime(time.Local))
func (t Time) ToDateTime(now time.Time) (out DateTime) {
return NewDateTimeFromStd(t.ToStdTime(now))
}

func (t Time) IsZero() bool {
Expand Down Expand Up @@ -421,9 +421,9 @@ func (t Time) toStdTime(year int, month time.Month, day int, loc *time.Location)
return time.Date(year, month, day, hours, minutes, secs, nsecs, loc)
}

func (t Time) ToStdTime(loc *time.Location) (out time.Time) {
year, month, day := time.Now().Date()
return t.toStdTime(year, month, day, loc)
func (t Time) ToStdTime(now time.Time) (out time.Time) {
year, month, day := now.Date()
return t.toStdTime(year, month, day, now.Location())
}

func (t Time) AddInterval(itv *Interval, stradd bool) (Time, uint8, bool) {
Expand All @@ -444,20 +444,20 @@ func (d Date) ToStdTime(loc *time.Location) (out time.Time) {
return time.Date(d.Year(), time.Month(d.Month()), d.Day(), 0, 0, 0, 0, loc)
}

func (dt DateTime) ToStdTime(loc *time.Location) time.Time {
func (dt DateTime) ToStdTime(now time.Time) time.Time {
zerodate := dt.Date.IsZero()
zerotime := dt.Time.IsZero()

switch {
case zerodate && zerotime:
return time.Time{}
case zerodate:
return dt.Time.ToStdTime(loc)
return dt.Time.ToStdTime(now)
case zerotime:
return dt.Date.ToStdTime(loc)
return dt.Date.ToStdTime(now.Location())
default:
year, month, day := dt.Date.Year(), time.Month(dt.Date.Month()), dt.Date.Day()
return dt.Time.toStdTime(year, month, day, loc)
return dt.Time.toStdTime(year, month, day, now.Location())
}
}

Expand Down Expand Up @@ -527,7 +527,10 @@ func (dt DateTime) Compare(dt2 DateTime) int {
// if we're comparing a time to a datetime, we need to normalize them
// both into datetimes; this normalization is not trivial because negative
// times result in a date change, so let the standard library handle this
return dt.ToStdTime(time.Local).Compare(dt2.ToStdTime(time.Local))

// Using the current time is OK here since the comparison is relative
now := time.Now()
return dt.ToStdTime(now).Compare(dt2.ToStdTime(now))
}
if cmp := dt.Date.Compare(dt2.Date); cmp != 0 {
return cmp
Expand Down Expand Up @@ -559,9 +562,10 @@ func (dt DateTime) Round(p int) (r DateTime) {
r = dt
if n == 1e9 {
r.Time.nanosecond = 0
return NewDateTimeFromStd(r.ToStdTime(time.Local).Add(time.Second))
r.addInterval(&Interval{timeparts: timeparts{sec: 1}, unit: IntervalSecond})
} else {
r.Time.nanosecond = uint32(n)
}
r.Time.nanosecond = uint32(n)
return r
}

Expand Down
2 changes: 1 addition & 1 deletion go/mysql/json/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ func (v *Value) MarshalDate() string {

func (v *Value) MarshalDateTime() string {
if dt, ok := v.DateTime(); ok {
return dt.ToStdTime(time.Local).Format("2006-01-02 15:04:05.000000")
return dt.ToStdTime(time.Now()).Format("2006-01-02 15:04:05.000000")
}
return ""
}
Expand Down
16 changes: 8 additions & 8 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll
end := env.vm.sp - elseOffset
for sp := env.vm.sp - stackDepth; sp < end; sp += 2 {
if env.vm.stack[sp].(*evalInt64).i != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation)
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now)
goto done
}
}
if elseOffset != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation)
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now)
} else {
env.vm.stack[env.vm.sp-stackDepth] = nil
}
Expand Down Expand Up @@ -1110,7 +1110,7 @@ func (asm *assembler) Convert_xD(offset int) {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
d := evalToDate(env.vm.stack[env.vm.sp-offset])
d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now)
if d == nil {
env.vm.stack[env.vm.sp-offset] = nil
} else {
Expand All @@ -1125,7 +1125,7 @@ func (asm *assembler) Convert_xD_nz(offset int) {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
d := evalToDate(env.vm.stack[env.vm.sp-offset])
d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now)
if d == nil || d.isZero() {
env.vm.stack[env.vm.sp-offset] = nil
} else {
Expand All @@ -1140,7 +1140,7 @@ func (asm *assembler) Convert_xDT(offset, prec int) {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec)
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now)
if dt == nil {
env.vm.stack[env.vm.sp-offset] = nil
} else {
Expand All @@ -1155,7 +1155,7 @@ func (asm *assembler) Convert_xDT_nz(offset, prec int) {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec)
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now)
if dt == nil || dt.isZero() {
env.vm.stack[env.vm.sp-offset] = nil
} else {
Expand Down Expand Up @@ -4252,7 +4252,7 @@ func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) {
}

tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{})
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}, env.now)
env.vm.sp--
return 1
}, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)")
Expand All @@ -4274,7 +4274,7 @@ func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col col
goto baddate
}

env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col, env.now)
env.vm.sp--
return 1

Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ func TestCompilerSingle(t *testing.T) {
},
}

tz, _ := time.LoadLocation("Europe/Madrid")

for _, tc := range testCases {
t.Run(tc.expression, func(t *testing.T) {
expr, err := sqlparser.ParseExpr(tc.expression)
Expand All @@ -478,6 +480,7 @@ func TestCompilerSingle(t *testing.T) {
}

env := evalengine.EmptyExpressionEnv()
env.SetTime(time.Date(2023, 10, 24, 12, 0, 0, 0, tz))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbussink: the tests we wrote yesterday break if they're not run yesterday 😅

env.Row = tc.values

expected, err := env.Evaluate(evalengine.Deoptimize(converted))
Expand Down
11 changes: 6 additions & 5 deletions go/vt/vtgate/evalengine/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package evalengine

import (
"strconv"
"time"
"unicode/utf8"

"vitess.io/vitess/go/hack"
Expand Down Expand Up @@ -167,7 +168,7 @@ func evalIsTruthy(e eval) boolean {
}
}

func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) {
func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (eval, error) {
if e == nil {
return nil, nil
}
Expand Down Expand Up @@ -199,9 +200,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) {
case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64:
return evalToInt64(e).toUint64(), nil
case sqltypes.Date:
return evalToDate(e), nil
return evalToDate(e, now), nil
case sqltypes.Datetime, sqltypes.Timestamp:
return evalToDateTime(e, -1), nil
return evalToDateTime(e, -1, now), nil
case sqltypes.Time:
return evalToTime(e, -1), nil
default:
Expand Down Expand Up @@ -329,7 +330,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
return nil, err
}
// Separate return here to avoid nil wrapped in interface type
d := evalToDate(e)
d := evalToDate(e, time.Now())
if d == nil {
return nil, nil
}
Expand All @@ -340,7 +341,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
return nil, err
}
// Separate return here to avoid nil wrapped in interface type
dt := evalToDateTime(e, -1)
dt := evalToDateTime(e, -1, time.Now())
if dt == nil {
return nil, nil
}
Expand Down
22 changes: 12 additions & 10 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package evalengine

import (
"time"

"vitess.io/vitess/go/hack"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/datetime"
Expand Down Expand Up @@ -92,12 +94,12 @@ func (e *evalTemporal) toJSON() *evalJSON {
}
}

func (e *evalTemporal) toDateTime(l int) *evalTemporal {
func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal {
switch e.SQLType() {
case sqltypes.Datetime, sqltypes.Date:
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)}
case sqltypes.Time:
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(), prec: uint8(l)}
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)}
default:
panic("unreachable")
}
Expand All @@ -118,15 +120,15 @@ func (e *evalTemporal) toTime(l int) *evalTemporal {
}
}

func (e *evalTemporal) toDate() *evalTemporal {
func (e *evalTemporal) toDate(now time.Time) *evalTemporal {
switch e.SQLType() {
case sqltypes.Datetime:
dt := datetime.DateTime{Date: e.dt.Date}
return &evalTemporal{t: sqltypes.Date, dt: dt}
case sqltypes.Date:
return e
case sqltypes.Time:
dt := e.dt.Time.ToDateTime()
dt := e.dt.Time.ToDateTime(now)
dt.Time = datetime.Time{}
return &evalTemporal{t: sqltypes.Date, dt: dt}
default:
Expand All @@ -138,7 +140,7 @@ func (e *evalTemporal) isZero() bool {
return e.dt.IsZero()
}

func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation) eval {
func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation, now time.Time) eval {
var tmp *evalTemporal
var ok bool

Expand All @@ -150,7 +152,7 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collatio
tmp = &evalTemporal{t: e.t}
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid())
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()):
tmp = e.toDateTime(int(e.prec))
tmp = e.toDateTime(int(e.prec), now)
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid())
}
if !ok {
Expand Down Expand Up @@ -324,10 +326,10 @@ func evalToTime(e eval, l int) *evalTemporal {
return nil
}

func evalToDateTime(e eval, l int) *evalTemporal {
func evalToDateTime(e eval, l int, now time.Time) *evalTemporal {
switch e := e.(type) {
case *evalTemporal:
return e.toDateTime(precision(l, int(e.prec)))
return e.toDateTime(precision(l, int(e.prec)), now)
case *evalBytes:
if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() {
return newEvalDateTime(t, l)
Expand Down Expand Up @@ -371,10 +373,10 @@ func evalToDateTime(e eval, l int) *evalTemporal {
return nil
}

func evalToDate(e eval) *evalTemporal {
func evalToDate(e eval, now time.Time) *evalTemporal {
switch e := e.(type) {
case *evalTemporal:
return e.toDate()
return e.toDate(now)
case *evalBytes:
if t, _ := datetime.ParseDate(e.string()); !t.IsZero() {
return newEvalDate(t)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/expr_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) {
case p > 6:
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p)
}
if dt := evalToDateTime(e, c.Length); dt != nil {
if dt := evalToDateTime(e, c.Length, env.now); dt != nil {
return dt, nil
}
return nil, nil
case "DATE":
if d := evalToDate(e); d != nil {
if d := evalToDate(e, env.now); d != nil {
return d, nil
}
return nil, nil
Expand Down
19 changes: 10 additions & 9 deletions go/vt/vtgate/evalengine/expr_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ func (env *ExpressionEnv) TypeOf(expr Expr, fields []*querypb.Field) (sqltypes.T
return ty, f, nil
}

func (env *ExpressionEnv) SetTime(now time.Time) {
// This function is called only once by NewExpressionEnv to ensure that all expressions in the same
// ExpressionEnv evaluate NOW() and similar SQL functions to the same value.
env.now = now
if tz := env.currentTimezone(); tz != nil {
env.now = env.now.In(tz)
}
}

// EmptyExpressionEnv returns a new ExpressionEnv with no bind vars or row
func EmptyExpressionEnv() *ExpressionEnv {
return NewExpressionEnv(context.Background(), nil, nil)
Expand All @@ -108,14 +117,6 @@ func EmptyExpressionEnv() *ExpressionEnv {
func NewExpressionEnv(ctx context.Context, bindVars map[string]*querypb.BindVariable, vc VCursor) *ExpressionEnv {
env := &ExpressionEnv{BindVars: bindVars, vc: vc}
env.user = callerid.ImmediateCallerIDFromContext(ctx)

// The current time for this ExpressionEnv is set only once, during creation.
// This is to ensure that all expressions in the same ExpressionEnv evaluate NOW()
// and similar SQL functions to the same value.
env.now = time.Now()

if tz := env.currentTimezone(); tz != nil {
env.now = env.now.In(tz)
}
env.SetTime(time.Now())
return env
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expr_logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ func (c *CaseExpr) eval(env *ExpressionEnv) (eval, error) {
return nil, nil
}
t, _ := c.typeof(env, nil)
return evalCoerce(result, t, ca.result().Collation)
return evalCoerce(result, t, ca.result().Collation, env.now)
}

func (c *CaseExpr) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) {
Expand Down
Loading
Loading