diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go index 513aea94a86..70e0c5e1edd 100644 --- a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -19,10 +19,10 @@ package union import ( "testing" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" - - "github.com/stretchr/testify/require" ) func start(t *testing.T) (utils.MySQLCompare, func()) { @@ -161,6 +161,19 @@ group by order by value desc;`, }, + { + name: "Q14 without decimal literal", + query: `select sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from lineitem, + part +where l_partkey = p_partkey + and l_shipdate >= '1996-12-01' + and l_shipdate < date_add('1996-12-01', interval '1' month);`, + }, } for _, testcase := range testcases { diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 45b0d182dd7..5c5259dfa83 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -61,7 +61,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st result := &sqltypes.Result{} if len(lresult.Rows) == 0 && wantfields { for k, col := range jn.Vars { - joinVars[k] = bindvarForType(lresult.Fields[col].Type) + joinVars[k] = bindvarForType(lresult.Fields[col]) } rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars)) if err != nil { @@ -95,19 +95,19 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st return result, nil } -func bindvarForType(t querypb.Type) *querypb.BindVariable { +func bindvarForType(field *querypb.Field) *querypb.BindVariable { bv := &querypb.BindVariable{ - Type: t, + Type: field.Type, Value: nil, } - switch t { + switch field.Type { case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, querypb.Type_UINT16, querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64: bv.Value = []byte("0") case querypb.Type_FLOAT32, querypb.Type_FLOAT64: bv.Value = []byte("0e0") case querypb.Type_DECIMAL: - bv.Value = []byte("0.0") + bv.Value = []byte(fmt.Sprintf("%s.%s", strings.Repeat("0", max(1, int(field.ColumnLength-field.Decimals))), strings.Repeat("0", max(1, int(field.Decimals))))) default: return sqltypes.NullBindVariable } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 3d5283db415..ebcfbf84a7c 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/olekukonko/tablewriter" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" @@ -168,6 +169,7 @@ func TestCompilerSingle(t *testing.T) { values []sqltypes.Value result string collation collations.ID + typeWanted evalengine.Type }{ { expression: "1 + column0", @@ -675,6 +677,15 @@ func TestCompilerSingle(t *testing.T) { expression: `1 * unix_timestamp(time('1.0000'))`, result: `DECIMAL(1698098401.0000)`, }, + { + expression: `(case + when 'PROMOTION' like 'PROMO%' + then 0.01 + else 0 + end) * 0.01`, + result: `DECIMAL(0.0001)`, + typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4), + }, } tz, _ := time.LoadLocation("Europe/Madrid") @@ -715,6 +726,12 @@ func TestCompilerSingle(t *testing.T) { t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation) } + if tc.typeWanted.Type() != sqltypes.Unknown { + typ, err := env.TypeOf(converted) + require.NoError(t, err) + require.EqualValues(t, tc.typeWanted, typ) + } + // re-run the same evaluation multiple times to ensure results are always consistent for i := 0; i < 8; i++ { res, err := env.Evaluate(converted) diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index ef59616b97c..c4ddba4ce0c 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -674,6 +674,7 @@ func (c *CaseExpr) simplify(env *ExpressionEnv) error { func (cs *CaseExpr) compile(c *compiler) (ctype, error) { var ca collationAggregation var ta typeAggregation + var scale, size int32 for _, wt := range cs.cases { when, err := wt.when.compile(c) @@ -691,6 +692,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { } ta.add(then.Type, then.Flag) + scale = max(scale, then.Scale) + size = max(size, then.Size) if err := ca.add(then.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } @@ -703,6 +706,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { } ta.add(els.Type, els.Flag) + scale = max(scale, els.Scale) + size = max(size, els.Size) if err := ca.add(els.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } @@ -712,7 +717,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { if ta.nullable { f |= flagNullable } - ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()} + ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: scale, Size: size} c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate()) return ct, nil }