Skip to content

Commit

Permalink
feat: add failing case test and fix it
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 committed Apr 24, 2024
1 parent c1bddd7 commit f422372
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 8 deletions.
17 changes: 15 additions & 2 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ package union
import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"

"github.com/stretchr/testify/require"
)

func start(t *testing.T) (utils.MySQLCompare, func()) {
Expand Down Expand Up @@ -161,6 +161,19 @@ group by
order by
value desc;`,
},
{
name: "Q14 without decimal literal",
query: `select sum(case
when p_type like 'PROMO%'
then l_extendedprice * (1 - l_discount)
else 0
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
}

for _, testcase := range testcases {
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
result := &sqltypes.Result{}
if len(lresult.Rows) == 0 && wantfields {
for k, col := range jn.Vars {
joinVars[k] = bindvarForType(lresult.Fields[col].Type)
joinVars[k] = bindvarForType(lresult.Fields[col])
}
rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
if err != nil {
Expand Down Expand Up @@ -95,19 +95,19 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
return result, nil
}

func bindvarForType(t querypb.Type) *querypb.BindVariable {
func bindvarForType(field *querypb.Field) *querypb.BindVariable {
bv := &querypb.BindVariable{
Type: t,
Type: field.Type,
Value: nil,
}
switch t {
switch field.Type {
case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, querypb.Type_UINT16,
querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64:
bv.Value = []byte("0")
case querypb.Type_FLOAT32, querypb.Type_FLOAT64:
bv.Value = []byte("0e0")
case querypb.Type_DECIMAL:
bv.Value = []byte("0.0")
bv.Value = []byte(fmt.Sprintf("%s.%s", strings.Repeat("0", max(1, int(field.ColumnLength-field.Decimals))), strings.Repeat("0", max(1, int(field.Decimals)))))
default:
return sqltypes.NullBindVariable
}
Expand Down
17 changes: 17 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/olekukonko/tablewriter"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -168,6 +169,7 @@ func TestCompilerSingle(t *testing.T) {
values []sqltypes.Value
result string
collation collations.ID
typeWanted evalengine.Type
}{
{
expression: "1 + column0",
Expand Down Expand Up @@ -675,6 +677,15 @@ func TestCompilerSingle(t *testing.T) {
expression: `1 * unix_timestamp(time('1.0000'))`,
result: `DECIMAL(1698098401.0000)`,
},
{
expression: `(case
when 'PROMOTION' like 'PROMO%'
then 0.01
else 0
end) * 0.01`,
result: `DECIMAL(0.0001)`,
typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4),
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down Expand Up @@ -715,6 +726,12 @@ func TestCompilerSingle(t *testing.T) {
t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation)
}

if tc.typeWanted.Type() != sqltypes.Unknown {
typ, err := env.TypeOf(converted)
require.NoError(t, err)
require.EqualValues(t, tc.typeWanted, typ)
}

// re-run the same evaluation multiple times to ensure results are always consistent
for i := 0; i < 8; i++ {
res, err := env.Evaluate(converted)
Expand Down
7 changes: 6 additions & 1 deletion go/vt/vtgate/evalengine/expr_logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ func (c *CaseExpr) simplify(env *ExpressionEnv) error {
func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
var ca collationAggregation
var ta typeAggregation
var scale, size int32

for _, wt := range cs.cases {
when, err := wt.when.compile(c)
Expand All @@ -691,6 +692,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
}

ta.add(then.Type, then.Flag)
scale = max(scale, then.Scale)
size = max(size, then.Size)
if err := ca.add(then.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand All @@ -703,6 +706,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
}

ta.add(els.Type, els.Flag)
scale = max(scale, els.Scale)
size = max(size, els.Size)
if err := ca.add(els.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand All @@ -712,7 +717,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
if ta.nullable {
f |= flagNullable
}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: scale, Size: size}
c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate())
return ct, nil
}
Expand Down

0 comments on commit f422372

Please sign in to comment.