Skip to content

Commit

Permalink
fix output type for DateAdd() and DateSub() functions (#2609)
Browse files Browse the repository at this point in the history
  • Loading branch information
jycor authored Jul 31, 2024
1 parent 5dedf37 commit 6ad97cc
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 49 deletions.
18 changes: 16 additions & 2 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -5302,14 +5302,28 @@ func TestPrepared(t *testing.T, harness Harness) {
},
{
Query: "SELECT DATE_ADD(TIMESTAMP(?), INTERVAL 1 DAY);",
Expected: []sql.Row{{time.Date(2022, time.October, 27, 13, 14, 15, 0, time.UTC)}},
Expected: []sql.Row{{"2022-10-27 13:14:15"}},
Bindings: map[string]*query.BindVariable{
"v1": mustBuildBindVariable(time.Date(2022, time.October, 26, 13, 14, 15, 0, time.UTC)),
},
},
{
Query: "SELECT DATE_ADD(TIMESTAMP(?), INTERVAL 1 DAY);",
Expected: []sql.Row{{"2022-10-27 13:14:15"}},
Bindings: map[string]*query.BindVariable{
"v1": mustBuildBindVariable("2022-10-26 13:14:15"),
},
},
{
Query: "SELECT DATE_ADD(?, INTERVAL 1 DAY);",
Expected: []sql.Row{{time.Date(2022, time.October, 27, 13, 14, 15, 0, time.UTC)}},
Expected: []sql.Row{{"2022-10-27 13:14:15"}},
Bindings: map[string]*query.BindVariable{
"v1": mustBuildBindVariable(time.Date(2022, time.October, 26, 13, 14, 15, 0, time.UTC)),
},
},
{
Query: "SELECT DATE_ADD(?, INTERVAL 1 DAY);",
Expected: []sql.Row{{"2022-10-27 13:14:15"}},
Bindings: map[string]*query.BindVariable{
"v1": mustBuildBindVariable("2022-10-26 13:14:15"),
},
Expand Down
4 changes: 2 additions & 2 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -6309,7 +6309,7 @@ Select * from (
},
{
Query: "SELECT DATE_ADD('2018-05-02', INTERVAL 1 day)",
Expected: []sql.Row{{time.Date(2018, time.May, 3, 0, 0, 0, 0, time.UTC)}},
Expected: []sql.Row{{"2018-05-03"}},
},
{
Query: "SELECT DATE_ADD(DATE('2018-05-02'), INTERVAL 1 day)",
Expand All @@ -6321,7 +6321,7 @@ Select * from (
},
{
Query: "SELECT DATE_SUB('2018-05-02', INTERVAL 1 DAY)",
Expected: []sql.Row{{time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC)}},
Expected: []sql.Row{{"2018-05-01"}},
},
{
Query: "SELECT DATE_SUB(DATE('2018-05-02'), INTERVAL 1 DAY)",
Expand Down
105 changes: 78 additions & 27 deletions sql/expression/function/date.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,43 @@ func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, nil
}

date, _, err = types.DatetimeMaxPrecision.Convert(date)
var dateVal interface{}
dateVal, _, err = types.DatetimeMaxPrecision.Convert(date)
if err != nil {
ctx.Warn(1292, err.Error())
return nil, nil
}

// return appropriate type
res := types.ValidateTime(delta.Add(date.(time.Time)))
res := types.ValidateTime(delta.Add(dateVal.(time.Time)))
if res == nil {
return nil, nil
}

resType := d.Type()
if types.IsText(resType) {
return res, nil
// If the input is a properly formatted date/datetime string, the output should also be a string
if dateStr, isStr := date.(string); isStr {
if res.(time.Time).Nanosecond() > 0 {
return res.(time.Time).Format(sql.DatetimeLayoutNoTrim), nil
}
if isHmsInterval(d.Interval) {
return res.(time.Time).Format(sql.TimestampDatetimeLayout), nil
}
for _, layout := range types.DateOnlyLayouts {
if _, pErr := time.Parse(layout, dateStr); pErr != nil {
continue
}
return res.(time.Time).Format(sql.DateLayout), nil
}
}
}

ret, _, err := resType.Convert(res)
return ret, err
if err != nil {
return nil, err
}
return ret, nil
}

func (d *DateAdd) String() string {
Expand Down Expand Up @@ -256,20 +279,43 @@ func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, nil
}

date, _, err = types.DatetimeMaxPrecision.Convert(date)
var dateVal interface{}
dateVal, _, err = types.DatetimeMaxPrecision.Convert(date)
if err != nil {
ctx.Warn(1292, err.Error())
return nil, nil
}

// return appropriate type
res := types.ValidateTime(delta.Sub(date.(time.Time)))
res := types.ValidateTime(delta.Sub(dateVal.(time.Time)))
if res == nil {
return nil, nil
}

resType := d.Type()
if types.IsText(resType) {
return res, nil
// If the input is a properly formatted date/datetime string, the output should also be a string
if dateStr, isStr := date.(string); isStr {
if res.(time.Time).Nanosecond() > 0 {
return res.(time.Time).Format(sql.DatetimeLayoutNoTrim), nil
}
if isHmsInterval(d.Interval) {
return res.(time.Time).Format(sql.TimestampDatetimeLayout), nil
}
for _, layout := range types.DateOnlyLayouts {
if _, pErr := time.Parse(layout, dateStr); pErr != nil {
continue
}
return res.(time.Time).Format(sql.DateLayout), nil
}
}
}

ret, _, err := resType.Convert(res)
return ret, err
if err != nil {
return nil, err
}
return ret, nil
}

func (d *DateSub) String() string {
Expand Down Expand Up @@ -734,6 +780,20 @@ func (c CurrDate) WithChildren(children ...sql.Expression) (sql.Expression, erro
return NoArgFuncWithChildren(c, children)
}

func isYmdInterval(interval *expression.Interval) bool {
return strings.Contains(interval.Unit, "YEAR") ||
strings.Contains(interval.Unit, "QUARTER") ||
strings.Contains(interval.Unit, "MONTH") ||
strings.Contains(interval.Unit, "WEEK") ||
strings.Contains(interval.Unit, "DAY")
}

func isHmsInterval(interval *expression.Interval) bool {
return strings.Contains(interval.Unit, "HOUR") ||
strings.Contains(interval.Unit, "MINUTE") ||
strings.Contains(interval.Unit, "SECOND")
}

// Determines the return type of a DateAdd/DateSub expression
// Logic is based on https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-add
func dateOffsetType(input sql.Expression, interval *expression.Interval) sql.Type {
Expand All @@ -747,31 +807,22 @@ func dateOffsetType(input sql.Expression, interval *expression.Interval) sql.Typ
return types.Null
}

if types.IsDatetimeType(inputType) || types.IsTimestampType(inputType) {
return types.DatetimeMaxPrecision
}

// set type flags
isInputDate := inputType == types.Date
isInputTime := inputType == types.Time
isInputDatetime := types.IsDatetimeType(inputType) || types.IsTimestampType(inputType)

// result is Datetime if expression is Datetime or Timestamp
if isInputDatetime {
return types.DatetimeMaxPrecision
}

// determine what kind of interval we're dealing with
isYmdInterval := strings.Contains(interval.Unit, "YEAR") ||
strings.Contains(interval.Unit, "QUARTER") ||
strings.Contains(interval.Unit, "MONTH") ||
strings.Contains(interval.Unit, "WEEK") ||
strings.Contains(interval.Unit, "DAY")

isHmsInterval := strings.Contains(interval.Unit, "HOUR") ||
strings.Contains(interval.Unit, "MINUTE") ||
strings.Contains(interval.Unit, "SECOND")
isMixedInterval := isYmdInterval && isHmsInterval
isYmd := isYmdInterval(interval)
isHms := isHmsInterval(interval)
isMixed := isYmd && isHms

// handle input of Date type
if isInputDate {
if isHmsInterval || isMixedInterval {
if isHms || isMixed {
// if interval contains time components, result is Datetime
return types.DatetimeMaxPrecision
} else {
Expand All @@ -782,7 +833,7 @@ func dateOffsetType(input sql.Expression, interval *expression.Interval) sql.Typ

// handle input of Time type
if isInputTime {
if isYmdInterval || isMixedInterval {
if isYmd || isMixed {
// if interval contains date components, result is Datetime
return types.DatetimeMaxPrecision
} else {
Expand All @@ -793,7 +844,7 @@ func dateOffsetType(input sql.Expression, interval *expression.Interval) sql.Typ

// handle dynamic input type
if types.IsDeferredType(inputType) {
if isYmdInterval && !isHmsInterval {
if isYmd && !isHms {
// if interval contains only date components, result is Date
return types.Date
} else {
Expand Down
Loading

0 comments on commit 6ad97cc

Please sign in to comment.