From 7ca2b8116b91c818b6932572a0918075fadd2e26 Mon Sep 17 00:00:00 2001 From: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com> Date: Thu, 25 Apr 2024 13:14:48 +0530 Subject: [PATCH] Fix Scale and length handling in `CASE` and JOIN bind variables (#15787) Signed-off-by: Manan Gupta Signed-off-by: Dirkjan Bussink Co-authored-by: Dirkjan Bussink --- go/mysql/json/parser.go | 8 +++++ .../endtoend/vtgate/queries/tpch/tpch_test.go | 17 ++++++++-- go/vt/vtgate/engine/join.go | 13 +++++--- .../vtgate/evalengine/api_type_aggregation.go | 20 ++++++++---- go/vt/vtgate/evalengine/compiler_asm.go | 6 ++-- go/vt/vtgate/evalengine/compiler_test.go | 30 +++++++++++++++++ go/vt/vtgate/evalengine/eval.go | 10 +++--- go/vt/vtgate/evalengine/eval_bytes.go | 8 +++++ go/vt/vtgate/evalengine/eval_enum.go | 8 +++++ go/vt/vtgate/evalengine/eval_numeric.go | 32 +++++++++++++++++++ go/vt/vtgate/evalengine/eval_set.go | 8 +++++ go/vt/vtgate/evalengine/eval_temporal.go | 8 +++++ go/vt/vtgate/evalengine/eval_tuple.go | 8 +++++ go/vt/vtgate/evalengine/expr_logical.go | 10 +++--- go/vt/vtgate/evalengine/fn_compare.go | 2 +- 15 files changed, 162 insertions(+), 26 deletions(-) diff --git a/go/mysql/json/parser.go b/go/mysql/json/parser.go index 707d890df93..b7a87c25756 100644 --- a/go/mysql/json/parser.go +++ b/go/mysql/json/parser.go @@ -669,6 +669,14 @@ type Value struct { n NumberType } +func (v *Value) Size() int32 { + return 0 +} + +func (v *Value) Scale() int32 { + return 0 +} + func (v *Value) MarshalDate() string { if d, ok := v.Date(); ok { return d.ToStdTime(time.Local).Format("2006-01-02") diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go index 513aea94a86..70e0c5e1edd 100644 --- a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -19,10 +19,10 @@ package union import ( "testing" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" - - "github.com/stretchr/testify/require" ) func start(t *testing.T) (utils.MySQLCompare, func()) { @@ -161,6 +161,19 @@ group by order by value desc;`, }, + { + name: "Q14 without decimal literal", + query: `select sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from lineitem, + part +where l_partkey = p_partkey + and l_shipdate >= '1996-12-01' + and l_shipdate < date_add('1996-12-01', interval '1' month);`, + }, } for _, testcase := range testcases { diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 45b0d182dd7..dc952673cfe 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -17,6 +17,7 @@ limitations under the License. package engine import ( + "bytes" "context" "fmt" "strings" @@ -61,7 +62,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st result := &sqltypes.Result{} if len(lresult.Rows) == 0 && wantfields { for k, col := range jn.Vars { - joinVars[k] = bindvarForType(lresult.Fields[col].Type) + joinVars[k] = bindvarForType(lresult.Fields[col]) } rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars)) if err != nil { @@ -95,19 +96,21 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st return result, nil } -func bindvarForType(t querypb.Type) *querypb.BindVariable { +func bindvarForType(field *querypb.Field) *querypb.BindVariable { bv := &querypb.BindVariable{ - Type: t, + Type: field.Type, Value: nil, } - switch t { + switch field.Type { case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, querypb.Type_UINT16, querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64: bv.Value = []byte("0") case querypb.Type_FLOAT32, querypb.Type_FLOAT64: bv.Value = []byte("0e0") case querypb.Type_DECIMAL: - bv.Value = []byte("0.0") + size := max(1, int(field.ColumnLength-field.Decimals)) + scale := max(1, int(field.Decimals)) + bv.Value = append(append(bytes.Repeat([]byte{'0'}, size), byte('.')), bytes.Repeat([]byte{'0'}, scale)...) default: return sqltypes.NullBindVariable } diff --git a/go/vt/vtgate/evalengine/api_type_aggregation.go b/go/vt/vtgate/evalengine/api_type_aggregation.go index 04622e5a212..45c0377bca4 100644 --- a/go/vt/vtgate/evalengine/api_type_aggregation.go +++ b/go/vt/vtgate/evalengine/api_type_aggregation.go @@ -47,7 +47,8 @@ type typeAggregation struct { blob uint16 total uint16 - nullable bool + nullable bool + scale, size int32 } type TypeAggregator struct { @@ -63,7 +64,7 @@ func (ta *TypeAggregator) Add(typ Type, env *collations.Environment) error { return nil } - ta.types.addNullable(typ.typ, typ.nullable) + ta.types.addNullable(typ.typ, typ.nullable, typ.size, typ.scale) if err := ta.collations.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil { return err } @@ -95,6 +96,7 @@ func (ta *typeAggregation) empty() bool { func (ta *typeAggregation) addEval(e eval) { var t sqltypes.Type var f typeFlag + var size, scale int32 switch e := e.(type) { case nil: t = sqltypes.Null @@ -102,13 +104,17 @@ func (ta *typeAggregation) addEval(e eval) { case *evalBytes: t = sqltypes.Type(e.tt) f = e.flag + size = e.Size() + scale = e.Scale() default: t = e.SQLType() + size = e.Size() + scale = e.Scale() } - ta.add(t, f) + ta.add(t, f, size, scale) } -func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) { +func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool, size, scale int32) { var flag typeFlag if typ == sqltypes.HexVal || typ == sqltypes.HexNum { typ = sqltypes.Binary @@ -117,13 +123,15 @@ func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) { if nullable { flag |= flagNullable } - ta.add(typ, flag) + ta.add(typ, flag, size, scale) } -func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) { +func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag, size, scale int32) { if f&flagNullable != 0 { ta.nullable = true } + ta.size = max(ta.size, size) + ta.scale = max(ta.scale, scale) switch tt { case sqltypes.Float32, sqltypes.Float64: ta.double++ diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 07c302ac6ec..2cda3ecb348 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -516,7 +516,7 @@ func (asm *assembler) Cmp_ne_n() { }, "CMPFLAG NE [NULL]") } -func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation, allowZeroDate bool) { +func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, size, scale int32, cc collations.TypedCollation, allowZeroDate bool) { elseOffset := 0 if hasElse { elseOffset = 1 @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll end := env.vm.sp - elseOffset for sp := env.vm.sp - stackDepth; sp < end; sp += 2 { if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now, allowZeroDate) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, size, scale, cc.Collation, env.now, allowZeroDate) goto done } } if elseOffset != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now, allowZeroDate) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, size, scale, cc.Collation, env.now, allowZeroDate) } else { env.vm.stack[env.vm.sp-stackDepth] = nil } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 3d5283db415..b2d4ff0c2f0 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/olekukonko/tablewriter" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" @@ -168,6 +169,7 @@ func TestCompilerSingle(t *testing.T) { values []sqltypes.Value result string collation collations.ID + typeWanted evalengine.Type }{ { expression: "1 + column0", @@ -675,6 +677,28 @@ func TestCompilerSingle(t *testing.T) { expression: `1 * unix_timestamp(time('1.0000'))`, result: `DECIMAL(1698098401.0000)`, }, + { + expression: `(case + when 'PROMOTION' like 'PROMO%' + then 0.01 + else 0 + end) * 0.01`, + result: `DECIMAL(0.0001)`, + typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4, nil), + }, + { + expression: `case when true then 0.02 else 1.000 end`, + result: `DECIMAL(0.02)`, + }, + { + expression: `case + when false + then timestamp'2023-10-24 12:00:00.123456' + else timestamp'2023-10-24 12:00:00' + end`, + result: `DATETIME("2023-10-24 12:00:00.000000")`, + typeWanted: evalengine.NewTypeEx(sqltypes.Datetime, collations.CollationBinaryID, false, 6, 0, nil), + }, } tz, _ := time.LoadLocation("Europe/Madrid") @@ -715,6 +739,12 @@ func TestCompilerSingle(t *testing.T) { t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation) } + if tc.typeWanted.Type() != sqltypes.Unknown { + typ, err := env.TypeOf(converted) + require.NoError(t, err) + require.True(t, tc.typeWanted.Equal(&typ)) + } + // re-run the same evaluation multiple times to ensure results are always consistent for i := 0; i < 8; i++ { res, err := env.Evaluate(converted) diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 90b1add541a..49423979379 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -72,6 +72,8 @@ func (f typeFlag) Nullable() bool { type eval interface { ToRawBytes() []byte SQLType() sqltypes.Type + Size() int32 + Scale() int32 } type hashable interface { @@ -170,7 +172,7 @@ func evalIsTruthy(e eval) boolean { } } -func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, allowZero bool) (eval, error) { +func evalCoerce(e eval, typ sqltypes.Type, size, scale int32, col collations.ID, now time.Time, allowZero bool) (eval, error) { if e == nil { return nil, nil } @@ -181,7 +183,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all // if we have an explicit VARCHAR coercion, always force it so the collation is replaced in the target return evalToVarchar(e, col, false) } - if e.SQLType() == typ { + if e.SQLType() == typ && e.Size() == size && e.Scale() == scale { // nothing to be done here return e, nil } @@ -204,9 +206,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all case sqltypes.Date: return evalToDate(e, now, allowZero), nil case sqltypes.Datetime, sqltypes.Timestamp: - return evalToDateTime(e, -1, now, allowZero), nil + return evalToDateTime(e, int(size), now, allowZero), nil case sqltypes.Time: - return evalToTime(e, -1), nil + return evalToTime(e, int(size)), nil default: return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s", typ.String()) } diff --git a/go/vt/vtgate/evalengine/eval_bytes.go b/go/vt/vtgate/evalengine/eval_bytes.go index caa516acbe4..027c4bb652d 100644 --- a/go/vt/vtgate/evalengine/eval_bytes.go +++ b/go/vt/vtgate/evalengine/eval_bytes.go @@ -138,6 +138,14 @@ func (e *evalBytes) SQLType() sqltypes.Type { return sqltypes.Type(e.tt) } +func (e *evalBytes) Size() int32 { + return 0 +} + +func (e *evalBytes) Scale() int32 { + return 0 +} + func (e *evalBytes) ToRawBytes() []byte { return e.bytes } diff --git a/go/vt/vtgate/evalengine/eval_enum.go b/go/vt/vtgate/evalengine/eval_enum.go index a0d349314da..fa9675d7c0e 100644 --- a/go/vt/vtgate/evalengine/eval_enum.go +++ b/go/vt/vtgate/evalengine/eval_enum.go @@ -26,6 +26,14 @@ func (e *evalEnum) SQLType() sqltypes.Type { return sqltypes.Enum } +func (e *evalEnum) Size() int32 { + return 0 +} + +func (e *evalEnum) Scale() int32 { + return 0 +} + func valueIdx(values *EnumSetValues, value string) int { if values == nil { return -1 diff --git a/go/vt/vtgate/evalengine/eval_numeric.go b/go/vt/vtgate/evalengine/eval_numeric.go index 64f5477a3fc..04f844566b1 100644 --- a/go/vt/vtgate/evalengine/eval_numeric.go +++ b/go/vt/vtgate/evalengine/eval_numeric.go @@ -366,6 +366,14 @@ func (e *evalInt64) SQLType() sqltypes.Type { return sqltypes.Int64 } +func (e *evalInt64) Size() int32 { + return 0 +} + +func (e *evalInt64) Scale() int32 { + return 0 +} + func (e *evalInt64) ToRawBytes() []byte { return strconv.AppendInt(nil, e.i, 10) } @@ -409,6 +417,14 @@ func (e *evalUint64) SQLType() sqltypes.Type { return sqltypes.Uint64 } +func (e *evalUint64) Size() int32 { + return 0 +} + +func (e *evalUint64) Scale() int32 { + return 0 +} + func (e *evalUint64) ToRawBytes() []byte { return strconv.AppendUint(nil, e.u, 10) } @@ -452,6 +468,14 @@ func (e *evalFloat) SQLType() sqltypes.Type { return sqltypes.Float64 } +func (e *evalFloat) Size() int32 { + return 0 +} + +func (e *evalFloat) Scale() int32 { + return 0 +} + func (e *evalFloat) ToRawBytes() []byte { return format.FormatFloat(e.f) } @@ -528,6 +552,14 @@ func (e *evalDecimal) SQLType() sqltypes.Type { return sqltypes.Decimal } +func (e *evalDecimal) Size() int32 { + return e.length +} + +func (e *evalDecimal) Scale() int32 { + return -e.dec.Exponent() +} + func (e *evalDecimal) ToRawBytes() []byte { return e.dec.FormatMySQL(e.length) } diff --git a/go/vt/vtgate/evalengine/eval_set.go b/go/vt/vtgate/evalengine/eval_set.go index 6a9de2eff14..bc75a527edc 100644 --- a/go/vt/vtgate/evalengine/eval_set.go +++ b/go/vt/vtgate/evalengine/eval_set.go @@ -29,6 +29,14 @@ func (e *evalSet) SQLType() sqltypes.Type { return sqltypes.Set } +func (e *evalSet) Size() int32 { + return 0 +} + +func (e *evalSet) Scale() int32 { + return 0 +} + func evalSetBits(values *EnumSetValues, value string) uint64 { if values != nil && len(*values) > 64 { // This never would happen as MySQL limits SET diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index 7706ec36e64..d73485441c3 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -42,6 +42,14 @@ func (e *evalTemporal) SQLType() sqltypes.Type { return e.t } +func (e *evalTemporal) Size() int32 { + return int32(e.prec) +} + +func (e *evalTemporal) Scale() int32 { + return 0 +} + func (e *evalTemporal) toInt64() int64 { switch e.SQLType() { case sqltypes.Date: diff --git a/go/vt/vtgate/evalengine/eval_tuple.go b/go/vt/vtgate/evalengine/eval_tuple.go index 73e7fcc2051..81fa3317977 100644 --- a/go/vt/vtgate/evalengine/eval_tuple.go +++ b/go/vt/vtgate/evalengine/eval_tuple.go @@ -33,3 +33,11 @@ func (e *evalTuple) ToRawBytes() []byte { func (e *evalTuple) SQLType() sqltypes.Type { return sqltypes.Tuple } + +func (e *evalTuple) Size() int32 { + return 0 +} + +func (e *evalTuple) Scale() int32 { + return 0 +} diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index ef59616b97c..561915f600c 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -631,7 +631,7 @@ func (c *CaseExpr) eval(env *ExpressionEnv) (eval, error) { if !matched { return nil, nil } - return evalCoerce(result, ta.result(), ca.result().Collation, env.now, env.sqlmode.AllowZeroDate()) + return evalCoerce(result, ta.result(), ta.size, ta.scale, ca.result().Collation, env.now, env.sqlmode.AllowZeroDate()) } func (c *CaseExpr) constant() bool { @@ -690,7 +690,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { return ctype{}, err } - ta.add(then.Type, then.Flag) + ta.add(then.Type, then.Flag, then.Size, then.Scale) if err := ca.add(then.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } @@ -702,7 +702,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { return ctype{}, err } - ta.add(els.Type, els.Flag) + ta.add(els.Type, els.Flag, els.Size, els.Scale) if err := ca.add(els.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } @@ -712,8 +712,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { if ta.nullable { f |= flagNullable } - ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()} - c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate()) + ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: ta.scale, Size: ta.size} + c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Size, ct.Scale, ct.Col, c.sqlmode.AllowZeroDate()) return ct, nil } diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index c102f5e5ef5..1deec6752ef 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -71,7 +71,7 @@ func (b *builtinCoalesce) compile(c *compiler) (ctype, error) { if !tt.nullable() { f = 0 } - ta.add(tt.Type, tt.Flag) + ta.add(tt.Type, tt.Flag, tt.Size, tt.Scale) if err := ca.add(tt.Col, c.env.CollationEnv()); err != nil { return ctype{}, err }