From 8cade461e393abf73bdaff3406409a5be4dceadc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Thu, 19 Oct 2023 15:46:28 +0200 Subject: [PATCH] refactor: introduce evalengine type and use it (#14292) Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/aggregations.go | 21 +- go/vt/vtgate/engine/cached_size.go | 8 +- go/vt/vtgate/engine/comparer.go | 2 +- go/vt/vtgate/engine/delete_test.go | 5 +- go/vt/vtgate/engine/distinct.go | 26 +-- go/vt/vtgate/engine/distinct_test.go | 26 ++- go/vt/vtgate/engine/limit_test.go | 10 +- go/vt/vtgate/engine/memory_sort_test.go | 24 +- go/vt/vtgate/engine/merge_sort_test.go | 7 +- go/vt/vtgate/engine/ordered_aggregate.go | 11 +- go/vt/vtgate/engine/ordered_aggregate_test.go | 8 +- go/vt/vtgate/engine/route.go | 9 +- go/vt/vtgate/engine/route_test.go | 14 +- go/vt/vtgate/engine/set_test.go | 3 +- go/vt/vtgate/engine/update_test.go | 3 +- go/vt/vtgate/evalengine/api_compare_test.go | 206 +++++++++--------- go/vt/vtgate/evalengine/api_literal.go | 12 +- go/vt/vtgate/evalengine/compiler.go | 10 + go/vt/vtgate/evalengine/translate.go | 32 ++- go/vt/vtgate/planbuilder/collations_test.go | 12 +- .../planbuilder/expression_converter.go | 3 +- .../planbuilder/expression_converter_test.go | 6 +- .../planbuilder/operator_transformers.go | 19 +- .../vtgate/planbuilder/operators/distinct.go | 9 +- .../planbuilder/operators/queryprojection.go | 13 +- go/vt/vtgate/semantics/analyzer.go | 13 +- go/vt/vtgate/semantics/analyzer_test.go | 4 +- go/vt/vtgate/semantics/binder.go | 9 +- go/vt/vtgate/semantics/dependencies.go | 8 +- go/vt/vtgate/semantics/derived_table.go | 7 +- go/vt/vtgate/semantics/real_table.go | 9 +- go/vt/vtgate/semantics/semantic_state.go | 19 +- go/vt/vtgate/semantics/table_collector.go | 5 +- go/vt/vtgate/semantics/typer.go | 33 +-- go/vt/vttablet/tabletmanager/vdiff/utils.go | 17 +- go/vt/wrangler/vdiff.go | 24 +- go/vt/wrangler/vdiff_test.go | 4 +- 37 files changed, 330 insertions(+), 321 deletions(-) diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index 8037dda37a9..33b80faab55 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -38,10 +38,9 @@ type AggregateParams struct { Col int // These are used only for distinct opcodes. - KeyCol int - WCol int - Type sqltypes.Type - CollationID collations.ID + KeyCol int + WCol int + Type evalengine.Type Alias string `json:",omitempty"` Expr sqlparser.Expr @@ -58,7 +57,7 @@ func NewAggregateParam(opcode AggregateOpcode, col int, alias string) *Aggregate Col: col, Alias: alias, WCol: -1, - Type: sqltypes.Unknown, + Type: evalengine.UnknownType(), } if opcode.NeedsComparableValues() { out.KeyCol = col @@ -75,8 +74,8 @@ func (ap *AggregateParams) String() string { if ap.WAssigned() { keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol) } - if sqltypes.IsText(ap.Type) && ap.CollationID != collations.Unknown { - keyCol += " COLLATE " + collations.Local().LookupName(ap.CollationID) + if sqltypes.IsText(ap.Type.Type) && ap.Type.Coll != collations.Unknown { + keyCol += " COLLATE " + collations.Local().LookupName(ap.Type.Coll) } dispOrigOp := "" if ap.OrigOpcode != AggregateUnassigned && ap.OrigOpcode != ap.Opcode { @@ -378,7 +377,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg from: aggr.Col, distinct: aggregatorDistinct{ column: distinct, - coll: aggr.CollationID, + coll: aggr.Type.Coll, }, } @@ -396,7 +395,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg sum: sum, distinct: aggregatorDistinct{ column: distinct, - coll: aggr.CollationID, + coll: aggr.Type.Coll, }, } @@ -404,7 +403,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg ag = &aggregatorMin{ aggregatorMinMax{ from: aggr.Col, - minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationID), + minmax: evalengine.NewAggregationMinMax(sourceType, aggr.Type.Coll), }, } @@ -412,7 +411,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg ag = &aggregatorMax{ aggregatorMinMax{ from: aggr.Col, - minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationID), + minmax: evalengine.NewAggregationMinMax(sourceType, aggr.Type.Coll), }, } diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 10d862ea3df..99b03124611 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -199,7 +199,7 @@ func (cached *Distinct) CachedSize(alloc bool) int64 { } // field CheckCols []vitess.io/vitess/go/vt/vtgate/engine.CheckCol { - size += hack.RuntimeAllocSize(int64(cap(cached.CheckCols)) * int64(22)) + size += hack.RuntimeAllocSize(int64(cap(cached.CheckCols)) * int64(23)) for _, elem := range cached.CheckCols { size += elem.CachedSize(false) } @@ -581,7 +581,7 @@ func (cached *MemorySort) CachedSize(alloc bool) int64 { } // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(38)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(39)) } // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Input.(cachedObject); ok { @@ -608,7 +608,7 @@ func (cached *MergeSort) CachedSize(alloc bool) int64 { } // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(38)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(39)) } return size } @@ -801,7 +801,7 @@ func (cached *Route) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.FieldQuery))) // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(38)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(39)) } // field RoutingParameters *vitess.io/vitess/go/vt/vtgate/engine.RoutingParameters size += cached.RoutingParameters.CachedSize(true) diff --git a/go/vt/vtgate/engine/comparer.go b/go/vt/vtgate/engine/comparer.go index f7728eb7f89..591b1cf2be0 100644 --- a/go/vt/vtgate/engine/comparer.go +++ b/go/vt/vtgate/engine/comparer.go @@ -71,7 +71,7 @@ func extractSlices(input []OrderByParams) []*comparer { weightString: order.WeightStringCol, desc: order.Desc, starColFixedIndex: order.StarColFixedIndex, - collationID: order.CollationID, + collationID: order.Type.Coll, }) } return result diff --git a/go/vt/vtgate/engine/delete_test.go b/go/vt/vtgate/engine/delete_test.go index 7312b4bd010..be67c7fc9e6 100644 --- a/go/vt/vtgate/engine/delete_test.go +++ b/go/vt/vtgate/engine/delete_test.go @@ -21,7 +21,6 @@ import ( "errors" "testing" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/evalengine" "github.com/stretchr/testify/require" @@ -90,7 +89,7 @@ func TestDeleteEqual(t *testing.T) { }) // Failure case - expr := evalengine.NewBindVar("aa", sqltypes.Unknown, collations.Unknown) + expr := evalengine.NewBindVar("aa", evalengine.UnknownType()) del.Values = []evalengine.Expr{expr} _, err = del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.EqualError(t, err, "query arguments missing for aa") @@ -122,7 +121,7 @@ func TestDeleteEqualMultiCol(t *testing.T) { }) // Failure case - expr := evalengine.NewBindVar("aa", sqltypes.Unknown, collations.Unknown) + expr := evalengine.NewBindVar("aa", evalengine.UnknownType()) del.Values = []evalengine.Expr{expr} _, err = del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.EqualError(t, err, "query arguments missing for aa") diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 8608aec0d98..477e0803c1b 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -38,10 +38,9 @@ type ( Truncate int } CheckCol struct { - Col int - WsCol *int - Type sqltypes.Type - Collation collations.ID + Col int + WsCol *int + Type evalengine.Type } probeTable struct { seenRows map[evalengine.HashCode][]sqltypes.Row @@ -119,14 +118,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (evalengine.HashCode return 0, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code") } col := inputRow[checkCol.Col] - hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Collation, col.Type()) + hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Coll, col.Type()) if err != nil { if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil { return 0, err } checkCol = checkCol.SwitchToWeightString() pt.checkCols[i] = checkCol - hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Collation, col.Type()) + hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Coll, col.Type()) if err != nil { return 0, err } @@ -138,7 +137,7 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (evalengine.HashCode func (pt *probeTable) equal(a, b sqltypes.Row) (bool, error) { for i, checkCol := range pt.checkCols { - cmp, err := evalengine.NullsafeCompare(a[i], b[i], checkCol.Collation) + cmp, err := evalengine.NullsafeCompare(a[i], b[i], checkCol.Type.Coll) if err != nil { _, isComparisonErr := err.(evalengine.UnsupportedComparisonError) if !isComparisonErr || checkCol.WsCol == nil { @@ -146,7 +145,7 @@ func (pt *probeTable) equal(a, b sqltypes.Row) (bool, error) { } checkCol = checkCol.SwitchToWeightString() pt.checkCols[i] = checkCol - cmp, err = evalengine.NullsafeCompare(a[i], b[i], checkCol.Collation) + cmp, err = evalengine.NullsafeCompare(a[i], b[i], checkCol.Type.Coll) if err != nil { return false, err } @@ -273,17 +272,16 @@ func (d *Distinct) description() PrimitiveDescription { // SwitchToWeightString returns a new CheckCol that works on the weight string column instead func (cc CheckCol) SwitchToWeightString() CheckCol { return CheckCol{ - Col: *cc.WsCol, - WsCol: nil, - Type: sqltypes.VarBinary, - Collation: collations.CollationBinaryID, + Col: *cc.WsCol, + WsCol: nil, + Type: evalengine.Type{Type: sqltypes.VarBinary, Coll: collations.CollationBinaryID}, } } func (cc CheckCol) String() string { var collation string - if sqltypes.IsText(cc.Type) && cc.Collation != collations.Unknown { - collation = ": " + collations.Local().LookupName(cc.Collation) + if sqltypes.IsText(cc.Type.Type) && cc.Type.Coll != collations.Unknown { + collation = ": " + collations.Local().LookupName(cc.Type.Coll) } var column string diff --git a/go/vt/vtgate/engine/distinct_test.go b/go/vt/vtgate/engine/distinct_test.go index e120c60bd3e..65f8e5d430c 100644 --- a/go/vt/vtgate/engine/distinct_test.go +++ b/go/vt/vtgate/engine/distinct_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/test/utils" @@ -86,10 +88,14 @@ func TestDistinct(t *testing.T) { if sqltypes.IsNumber(tc.inputs.Fields[i].Type) { collID = collations.CollationBinaryID } + t := evalengine.Type{ + Type: tc.inputs.Fields[i].Type, + Coll: collID, + Nullable: false, + } checkCols = append(checkCols, CheckCol{ - Col: i, - Type: tc.inputs.Fields[i].Type, - Collation: collID, + Col: i, + Type: t, }) } } @@ -132,10 +138,9 @@ func TestDistinct(t *testing.T) { func TestWeightStringFallBack(t *testing.T) { offsetOne := 1 checkCols := []CheckCol{{ - Col: 0, - WsCol: &offsetOne, - Type: sqltypes.Unknown, - Collation: collations.Unknown, + Col: 0, + WsCol: &offsetOne, + Type: evalengine.UnknownType(), }} input := r("myid|weightstring(myid)", "varchar|varbinary", @@ -158,9 +163,8 @@ func TestWeightStringFallBack(t *testing.T) { // the primitive must not change just because one run needed weight strings utils.MustMatch(t, []CheckCol{{ - Col: 0, - WsCol: &offsetOne, - Type: sqltypes.Unknown, - Collation: collations.Unknown, + Col: 0, + WsCol: &offsetOne, + Type: evalengine.UnknownType(), }}, distinct.CheckCols, "checkCols should not be updated") } diff --git a/go/vt/vtgate/engine/limit_test.go b/go/vt/vtgate/engine/limit_test.go index ba15306685a..d5c6602f820 100644 --- a/go/vt/vtgate/engine/limit_test.go +++ b/go/vt/vtgate/engine/limit_test.go @@ -130,7 +130,7 @@ func TestLimitExecute(t *testing.T) { results: []*sqltypes.Result{inputResult}, } l = &Limit{ - Count: evalengine.NewBindVar("l", sqltypes.Int64, collations.CollationBinaryID), + Count: evalengine.NewBindVar("l", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), Input: fp, } @@ -343,8 +343,8 @@ func TestLimitOffsetExecute(t *testing.T) { } l = &Limit{ - Count: evalengine.NewBindVar("l", sqltypes.Int64, collations.CollationBinaryID), - Offset: evalengine.NewBindVar("o", sqltypes.Int64, collations.CollationBinaryID), + Count: evalengine.NewBindVar("l", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), + Offset: evalengine.NewBindVar("o", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), Input: fp, } result, err = l.TryExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(1), "o": sqltypes.Int64BindVariable(1)}, false) @@ -396,7 +396,7 @@ func TestLimitStreamExecute(t *testing.T) { // Test with bind vars. fp.rewind() - l.Count = evalengine.NewBindVar("l", sqltypes.Int64, collations.CollationBinaryID) + l.Count = evalengine.NewBindVar("l", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}) results = nil err = l.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)}, true, func(qr *sqltypes.Result) error { results = append(results, qr) @@ -540,7 +540,7 @@ func TestLimitInputFail(t *testing.T) { func TestLimitInvalidCount(t *testing.T) { l := &Limit{ - Count: evalengine.NewBindVar("l", sqltypes.Int64, collations.CollationBinaryID), + Count: evalengine.NewBindVar("l", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), } _, _, err := l.getCountAndOffset(context.Background(), &noopVCursor{}, nil) assert.EqualError(t, err, "query arguments missing for l") diff --git a/go/vt/vtgate/engine/memory_sort_test.go b/go/vt/vtgate/engine/memory_sort_test.go index 3b53ef11250..2c73d49e74b 100644 --- a/go/vt/vtgate/engine/memory_sort_test.go +++ b/go/vt/vtgate/engine/memory_sort_test.go @@ -75,7 +75,7 @@ func TestMemorySortExecute(t *testing.T) { utils.MustMatch(t, wantResult, result) fp.rewind() - ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID) + ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}) bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)} result, err = ms.TryExecute(context.Background(), &noopVCursor{}, bv, false) @@ -136,7 +136,7 @@ func TestMemorySortStreamExecuteWeightString(t *testing.T) { t.Run("Limit test", func(t *testing.T) { fp.rewind() - ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID) + ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}) bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)} results = nil @@ -194,7 +194,7 @@ func TestMemorySortExecuteWeightString(t *testing.T) { utils.MustMatch(t, wantResult, result) fp.rewind() - ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID) + ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}) bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)} result, err = ms.TryExecute(context.Background(), &noopVCursor{}, bv, false) @@ -228,9 +228,8 @@ func TestMemorySortStreamExecuteCollation(t *testing.T) { collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") ms := &MemorySort{ OrderBy: []OrderByParams{{ - Col: 0, - Type: sqltypes.VarChar, - CollationID: collationID, + Col: 0, + Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID}, }}, Input: fp, } @@ -278,7 +277,7 @@ func TestMemorySortStreamExecuteCollation(t *testing.T) { t.Run("Limit test", func(t *testing.T) { fp.rewind() - ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID) + ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}) bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)} results = nil @@ -317,9 +316,8 @@ func TestMemorySortExecuteCollation(t *testing.T) { collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") ms := &MemorySort{ OrderBy: []OrderByParams{{ - Col: 0, - Type: sqltypes.VarChar, - CollationID: collationID, + Col: 0, + Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID}, }}, Input: fp, } @@ -338,7 +336,7 @@ func TestMemorySortExecuteCollation(t *testing.T) { utils.MustMatch(t, wantResult, result) fp.rewind() - ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID) + ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}) bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)} result, err = ms.TryExecute(context.Background(), &noopVCursor{}, bv, false) @@ -395,7 +393,7 @@ func TestMemorySortStreamExecute(t *testing.T) { utils.MustMatch(t, wantResults, results) fp.rewind() - ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID) + ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}) bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)} results = nil @@ -554,7 +552,7 @@ func TestMemorySortMultiColumn(t *testing.T) { utils.MustMatch(t, wantResult, result) fp.rewind() - ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID) + ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}) bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)} result, err = ms.TryExecute(context.Background(), &noopVCursor{}, bv, false) diff --git a/go/vt/vtgate/engine/merge_sort_test.go b/go/vt/vtgate/engine/merge_sort_test.go index e8823e9e6d5..be370c0e86b 100644 --- a/go/vt/vtgate/engine/merge_sort_test.go +++ b/go/vt/vtgate/engine/merge_sort_test.go @@ -21,6 +21,8 @@ import ( "errors" "testing" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/test/utils" @@ -179,9 +181,8 @@ func TestMergeSortCollation(t *testing.T) { collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") orderBy := []OrderByParams{{ - Col: 0, - Type: sqltypes.VarChar, - CollationID: collationID, + Col: 0, + Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID}, }} var results []*sqltypes.Result diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index acb958199d0..1982328a8a6 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -66,8 +66,7 @@ type GroupByParams struct { WeightStringCol int Expr sqlparser.Expr FromGroupBy bool - Type sqltypes.Type - CollationID collations.ID + Type evalengine.Type } // String returns a string. Used for plan descriptions @@ -79,8 +78,8 @@ func (gbp GroupByParams) String() string { out = fmt.Sprintf("(%d|%d)", gbp.KeyCol, gbp.WeightStringCol) } - if sqltypes.IsText(gbp.Type) && gbp.CollationID != collations.Unknown { - out += " COLLATE " + collations.Local().LookupName(gbp.CollationID) + if sqltypes.IsText(gbp.Type.Type) && gbp.Type.Coll != collations.Unknown { + out += " COLLATE " + collations.Local().LookupName(gbp.Type.Coll) } return out @@ -255,7 +254,7 @@ func (oa *OrderedAggregate) nextGroupBy(currentKey, nextRow []sqltypes.Value) (n } for _, gb := range oa.GroupByKeys { - cmp, err := evalengine.NullsafeCompare(currentKey[gb.KeyCol], nextRow[gb.KeyCol], gb.CollationID) + cmp, err := evalengine.NullsafeCompare(currentKey[gb.KeyCol], nextRow[gb.KeyCol], gb.Type.Coll) if err != nil { _, isComparisonErr := err.(evalengine.UnsupportedComparisonError) _, isCollationErr := err.(evalengine.UnsupportedCollationError) @@ -263,7 +262,7 @@ func (oa *OrderedAggregate) nextGroupBy(currentKey, nextRow []sqltypes.Value) (n return nil, false, err } gb.KeyCol = gb.WeightStringCol - cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], gb.CollationID) + cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], gb.Type.Coll) if err != nil { return nil, false, err } diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index 8aa0bf3c3b4..2eca4fc7ba9 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -22,6 +22,8 @@ import ( "fmt" "testing" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -905,7 +907,7 @@ func TestOrderedAggregateCollate(t *testing.T) { collationID, _ := collationEnv.LookupID("utf8mb4_0900_ai_ci") oa := &OrderedAggregate{ Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, - GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0, Type: evalengine.Type{Coll: collationID}}}, Input: fp, } @@ -943,7 +945,7 @@ func TestOrderedAggregateCollateAS(t *testing.T) { collationID, _ := collationEnv.LookupID("utf8mb4_0900_as_ci") oa := &OrderedAggregate{ Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, - GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0, Type: evalengine.Type{Coll: collationID}}}, Input: fp, } @@ -983,7 +985,7 @@ func TestOrderedAggregateCollateKS(t *testing.T) { collationID, _ := collationEnv.LookupID("utf8mb4_ja_0900_as_cs_ks") oa := &OrderedAggregate{ Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, - GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0, Type: evalengine.Type{Coll: collationID}}}, Input: fp, } diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 1f806867b70..312e04f98f2 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -119,10 +119,9 @@ type OrderByParams struct { WeightStringCol int Desc bool StarColFixedIndex int + // Type for knowing if the collation is relevant - Type querypb.Type - // Collation ID for comparison using collation - CollationID collations.ID + Type evalengine.Type } // String returns a string. Used for plan descriptions @@ -140,8 +139,8 @@ func (obp OrderByParams) String() string { val += " ASC" } - if sqltypes.IsText(obp.Type) && obp.CollationID != collations.Unknown { - val += " COLLATE " + collations.Local().LookupName(obp.CollationID) + if sqltypes.IsText(obp.Type.Type) && obp.Type.Coll != collations.Unknown { + val += " COLLATE " + collations.Local().LookupName(obp.Type.Coll) } return val } diff --git a/go/vt/vtgate/engine/route_test.go b/go/vt/vtgate/engine/route_test.go index 13fb0be656b..58e6fb4a9f1 100644 --- a/go/vt/vtgate/engine/route_test.go +++ b/go/vt/vtgate/engine/route_test.go @@ -1076,9 +1076,8 @@ func TestRouteSortCollation(t *testing.T) { collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") sel.OrderBy = []OrderByParams{{ - Col: 0, - Type: sqltypes.VarChar, - CollationID: collationID, + Col: 0, + Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID}, }} vc := &loggingVCursor{ @@ -1143,9 +1142,8 @@ func TestRouteSortCollation(t *testing.T) { t.Run("Error when Unknown Collation", func(t *testing.T) { sel.OrderBy = []OrderByParams{{ - Col: 0, - Type: sqltypes.Unknown, - CollationID: collations.Unknown, + Col: 0, + Type: evalengine.UnknownType(), }} vc := &loggingVCursor{ @@ -1170,8 +1168,8 @@ func TestRouteSortCollation(t *testing.T) { t.Run("Error when Unsupported Collation", func(t *testing.T) { sel.OrderBy = []OrderByParams{{ - Col: 0, - CollationID: 1111, + Col: 0, + Type: evalengine.Type{Coll: 1111}, }} vc := &loggingVCursor{ diff --git a/go/vt/vtgate/engine/set_test.go b/go/vt/vtgate/engine/set_test.go index 62ffa42b8d6..e9a5ef1a85e 100644 --- a/go/vt/vtgate/engine/set_test.go +++ b/go/vt/vtgate/engine/set_test.go @@ -22,7 +22,6 @@ import ( "fmt" "testing" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/srvtopo" @@ -108,7 +107,7 @@ func TestSetTable(t *testing.T) { setOps: []SetOp{ &UserDefinedVariable{ Name: "x", - Expr: evalengine.NewColumn(0, sqltypes.Unknown, collations.Unknown), + Expr: evalengine.NewColumn(0, evalengine.UnknownType()), }, }, qr: []*sqltypes.Result{sqltypes.MakeTestResult( diff --git a/go/vt/vtgate/engine/update_test.go b/go/vt/vtgate/engine/update_test.go index 313602668bc..22c2b90d60e 100644 --- a/go/vt/vtgate/engine/update_test.go +++ b/go/vt/vtgate/engine/update_test.go @@ -21,7 +21,6 @@ import ( "errors" "testing" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/evalengine" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -94,7 +93,7 @@ func TestUpdateEqual(t *testing.T) { }) // Failure case - upd.Values = []evalengine.Expr{evalengine.NewBindVar("aa", sqltypes.Unknown, collations.Unknown)} + upd.Values = []evalengine.Expr{evalengine.NewBindVar("aa", evalengine.UnknownType())} _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.EqualError(t, err, `query arguments missing for aa`) } diff --git a/go/vt/vtgate/evalengine/api_compare_test.go b/go/vt/vtgate/evalengine/api_compare_test.go index bd87363b7e8..6de805ac8f1 100644 --- a/go/vt/vtgate/evalengine/api_compare_test.go +++ b/go/vt/vtgate/evalengine/api_compare_test.go @@ -107,7 +107,7 @@ func TestCompareIntegers(t *testing.T) { tests := []testCase{ { name: "integers are equal (1)", - v1: NewColumn(0, sqltypes.Int64, collations.CollationBinaryID), v2: NewColumn(0, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), v2: NewColumn(0, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewInt64(18)}, }, @@ -128,25 +128,25 @@ func TestCompareIntegers(t *testing.T) { }, { name: "integers are not equal (3)", - v1: NewColumn(0, sqltypes.Int64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewInt64(18), sqltypes.NewInt64(98)}, }, { name: "unsigned integers are equal", - v1: NewColumn(0, sqltypes.Uint64, collations.CollationBinaryID), v2: NewColumn(0, sqltypes.Uint64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Uint64, Coll: collations.CollationBinaryID}), v2: NewColumn(0, Type{Type: sqltypes.Uint64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewUint64(18)}, }, { name: "unsigned integer and integer are equal", - v1: NewColumn(0, sqltypes.Uint64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Uint64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewUint64(18), sqltypes.NewInt64(18)}, }, { name: "unsigned integer and integer are not equal", - v1: NewColumn(0, sqltypes.Uint64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Uint64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewUint64(18), sqltypes.NewInt64(42)}, }, @@ -204,7 +204,7 @@ func TestCompareFloats(t *testing.T) { tests := []testCase{ { name: "floats are equal (1)", - v1: NewColumn(0, sqltypes.Float64, collations.CollationBinaryID), v2: NewColumn(0, sqltypes.Float64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), v2: NewColumn(0, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewFloat64(18)}, }, @@ -225,7 +225,7 @@ func TestCompareFloats(t *testing.T) { }, { name: "floats are not equal (3)", - v1: NewColumn(0, sqltypes.Float64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Float64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewFloat64(16516.84), sqltypes.NewFloat64(219541.01)}, }, @@ -283,37 +283,37 @@ func TestCompareDecimals(t *testing.T) { tests := []testCase{ { name: "decimals are equal", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("12.9019")}, }, { name: "decimals are not equal", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("12.9019"), sqltypes.NewDecimal("489.156849")}, }, { name: "decimal is greater than decimal", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDecimal("192.129"), sqltypes.NewDecimal("192.128")}, }, { name: "decimal is not greater than decimal", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDecimal("192.128"), sqltypes.NewDecimal("192.129")}, }, { name: "decimal is less than decimal", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewDecimal("192.128"), sqltypes.NewDecimal("192.129")}, }, { name: "decimal is not less than decimal", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewDecimal("192.129"), sqltypes.NewDecimal("192.128")}, }, @@ -331,151 +331,151 @@ func TestCompareNumerics(t *testing.T) { tests := []testCase{ { name: "decimal and float are equal", - v1: NewColumn(0, sqltypes.Float64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewFloat64(189.6), sqltypes.NewDecimal("189.6")}, }, { name: "decimal and float with negative values are equal", - v1: NewColumn(0, sqltypes.Float64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewFloat64(-98.1839), sqltypes.NewDecimal("-98.1839")}, }, { name: "decimal and float with negative values are not equal (1)", - v1: NewColumn(0, sqltypes.Float64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewFloat64(-98.9381), sqltypes.NewDecimal("-98.1839")}, }, { name: "decimal and float with negative values are not equal (2)", - v1: NewColumn(0, sqltypes.Float64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewFloat64(-98.9381), sqltypes.NewDecimal("-98.1839")}, }, { name: "decimal and integer are equal (1)", - v1: NewColumn(0, sqltypes.Int64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewInt64(8979), sqltypes.NewDecimal("8979")}, }, { name: "decimal and integer are equal (2)", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("8979.0000"), sqltypes.NewInt64(8979)}, }, { name: "decimal and unsigned integer are equal (1)", - v1: NewColumn(0, sqltypes.Uint64, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Decimal, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Uint64, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewUint64(901), sqltypes.NewDecimal("901")}, }, { name: "decimal and unsigned integer are equal (2)", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Uint64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Uint64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("901.00"), sqltypes.NewUint64(901)}, }, { name: "decimal and unsigned integer are not equal (1)", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Uint64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Uint64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("192.129"), sqltypes.NewUint64(192)}, }, { name: "decimal and unsigned integer are not equal (2)", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Uint64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Uint64, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("192.129"), sqltypes.NewUint64(192)}, }, { name: "decimal is greater than integer", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDecimal("1.01"), sqltypes.NewInt64(1)}, }, { name: "decimal is greater-equal to integer", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("1.00"), sqltypes.NewInt64(1)}, }, { name: "decimal is less than integer", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewDecimal(".99"), sqltypes.NewInt64(1)}, }, { name: "decimal is less-equal to integer", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("1.00"), sqltypes.NewInt64(1)}, }, { name: "decimal is greater than float", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Float64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDecimal("849.896"), sqltypes.NewFloat64(86.568)}, }, { name: "decimal is not greater than float", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Float64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDecimal("15.23"), sqltypes.NewFloat64(8689.5)}, }, { name: "decimal is greater-equal to float (1)", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Float64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("65"), sqltypes.NewFloat64(65)}, }, { name: "decimal is greater-equal to float (2)", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Float64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("65"), sqltypes.NewFloat64(60)}, }, { name: "decimal is less than float", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Float64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewDecimal("0.998"), sqltypes.NewFloat64(0.999)}, }, { name: "decimal is less-equal to float", - v1: NewColumn(0, sqltypes.Decimal, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Float64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Decimal, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Float64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewDecimal("1.000101"), sqltypes.NewFloat64(1.00101)}, }, { name: "different int types are equal for 8 bit", - v1: NewColumn(0, sqltypes.Int8, collations.CollationBinaryID), v2: NewLiteralInt(0), + v1: NewColumn(0, Type{Type: sqltypes.Int8, Coll: collations.CollationBinaryID}), v2: NewLiteralInt(0), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewInt8(0)}, }, { name: "different int types are equal for 32 bit", - v1: NewColumn(0, sqltypes.Int32, collations.CollationBinaryID), v2: NewLiteralInt(0), + v1: NewColumn(0, Type{Type: sqltypes.Int32, Coll: collations.CollationBinaryID}), v2: NewLiteralInt(0), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewInt32(0)}, }, { name: "different int types are equal for float32 bit", - v1: NewColumn(0, sqltypes.Float32, collations.CollationBinaryID), v2: NewLiteralFloat(1.0), + v1: NewColumn(0, Type{Type: sqltypes.Float32, Coll: collations.CollationBinaryID}), v2: NewLiteralFloat(1.0), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.MakeTrusted(sqltypes.Float32, []byte("1.0"))}, }, { name: "different unsigned int types are equal for 8 bit", - v1: NewColumn(0, sqltypes.Uint8, collations.CollationBinaryID), v2: NewLiteralInt(0), + v1: NewColumn(0, Type{Type: sqltypes.Uint8, Coll: collations.CollationBinaryID}), v2: NewLiteralInt(0), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.MakeTrusted(sqltypes.Uint8, []byte("0"))}, }, { name: "different unsigned int types are equal for 32 bit", - v1: NewColumn(0, sqltypes.Uint32, collations.CollationBinaryID), v2: NewLiteralInt(0), + v1: NewColumn(0, Type{Type: sqltypes.Uint32, Coll: collations.CollationBinaryID}), v2: NewLiteralInt(0), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewUint32(0)}, }, @@ -493,73 +493,73 @@ func TestCompareDatetime(t *testing.T) { tests := []testCase{ { name: "datetimes are equal", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-22 12:00:00")}, }, { name: "datetimes are not equal (1)", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-22 12:00:00"), sqltypes.NewDatetime("2020-10-22 12:00:00")}, }, { name: "datetimes are not equal (2)", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-22 12:00:00"), sqltypes.NewDatetime("2021-10-22 10:23:56")}, }, { name: "datetimes are not equal (3)", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-01 00:00:00"), sqltypes.NewDatetime("2021-02-01 00:00:00")}, }, { name: "datetime is greater than datetime", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-01 13:10:02")}, }, { name: "datetime is not greater than datetime", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-01 13:10:02"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, { name: "datetime is less than datetime", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-01 13:10:02"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, { name: "datetime is not less than datetime", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-01 13:10:02")}, }, { name: "datetime is greater-equal to datetime (1)", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, { name: "datetime is greater-equal to datetime (2)", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-01 13:10:02")}, }, { name: "datetime is less-equal to datetime (1)", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-30 10:42:50"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, { name: "datetime is less-equal to datetime (2)", - v1: NewColumn(0, sqltypes.Datetime, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewDatetime("2021-10-01 13:10:02"), sqltypes.NewDatetime("2021-10-30 10:42:50")}, }, @@ -577,73 +577,73 @@ func TestCompareTimestamp(t *testing.T) { tests := []testCase{ { name: "timestamps are equal", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-22 12:00:00")}, }, { name: "timestamps are not equal (1)", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-22 12:00:00"), sqltypes.NewTimestamp("2020-10-22 12:00:00")}, }, { name: "timestamps are not equal (2)", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-22 12:00:00"), sqltypes.NewTimestamp("2021-10-22 10:23:56")}, }, { name: "timestamps are not equal (3)", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-01 00:00:00"), sqltypes.NewTimestamp("2021-02-01 00:00:00")}, }, { name: "timestamp is greater than timestamp", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-01 13:10:02")}, }, { name: "timestamp is not greater than timestamp", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-01 13:10:02"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, { name: "timestamp is less than timestamp", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-01 13:10:02"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, { name: "timestamp is not less than timestamp", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-01 13:10:02")}, }, { name: "timestamp is greater-equal to timestamp (1)", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, { name: "timestamp is greater-equal to timestamp (2)", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-01 13:10:02")}, }, { name: "timestamp is less-equal to timestamp (1)", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-30 10:42:50"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, { name: "timestamp is less-equal to timestamp (2)", - v1: NewColumn(0, sqltypes.Timestamp, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewTimestamp("2021-10-01 13:10:02"), sqltypes.NewTimestamp("2021-10-30 10:42:50")}, }, @@ -661,67 +661,67 @@ func TestCompareDate(t *testing.T) { tests := []testCase{ { name: "dates are equal", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-22")}, }, { name: "dates are not equal (1)", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewDate("2020-10-21")}, }, { name: "dates are not equal (2)", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-01"), sqltypes.NewDate("2021-02-01")}, }, { name: "date is greater than date", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-01")}, }, { name: "date is not greater than date", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-01"), sqltypes.NewDate("2021-10-30")}, }, { name: "date is less than date", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-01"), sqltypes.NewDate("2021-10-30")}, }, { name: "date is not less than date", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-01")}, }, { name: "date is greater-equal to date (1)", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-30")}, }, { name: "date is greater-equal to date (2)", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-01")}, }, { name: "date is less-equal to date (1)", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-30"), sqltypes.NewDate("2021-10-30")}, }, { name: "date is less-equal to date (2)", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-01"), sqltypes.NewDate("2021-10-30")}, }, @@ -739,79 +739,79 @@ func TestCompareTime(t *testing.T) { tests := []testCase{ { name: "times are equal", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewTime("12:00:00")}, }, { name: "times are not equal (1)", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewTime("12:00:00"), sqltypes.NewTime("10:23:56")}, }, { name: "times are not equal (2)", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewTime("00:00:00"), sqltypes.NewTime("10:15:00")}, }, { name: "time is greater than time", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewTime("18:14:35"), sqltypes.NewTime("13:01:38")}, }, { name: "time is not greater than time", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewTime("02:46:02"), sqltypes.NewTime("10:42:50")}, }, { name: "time is greater than time", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewTime("101:14:35"), sqltypes.NewTime("13:01:38")}, }, { name: "time is not greater than time", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewTime("24:46:02"), sqltypes.NewTime("101:42:50")}, }, { name: "time is less than time", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewTime("04:30:00"), sqltypes.NewTime("09:23:48")}, }, { name: "time is not less than time", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.LessThanOp, row: []sqltypes.Value{sqltypes.NewTime("15:21:00"), sqltypes.NewTime("10:00:00")}, }, { name: "time is greater-equal to time (1)", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewTime("10:42:50"), sqltypes.NewTime("10:42:50")}, }, { name: "time is greater-equal to time (2)", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.GreaterEqualOp, row: []sqltypes.Value{sqltypes.NewTime("19:42:50"), sqltypes.NewTime("13:10:02")}, }, { name: "time is less-equal to time (1)", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewTime("10:42:50"), sqltypes.NewTime("10:42:50")}, }, { name: "time is less-equal to time (2)", - v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.LessEqualOp, row: []sqltypes.Value{sqltypes.NewTime("10:10:02"), sqltypes.NewTime("10:42:50")}, }, @@ -829,13 +829,13 @@ func TestCompareDates(t *testing.T) { tests := []testCase{ { name: "date equal datetime", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewDatetime("2021-10-22 00:00:00")}, }, { name: "date equal datetime through bind variables", - v1: NewBindVar("k1", sqltypes.Date, collations.CollationBinaryID), v2: NewBindVar("k2", sqltypes.Datetime, collations.CollationBinaryID), + v1: NewBindVar("k1", Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewBindVar("k2", Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, bv: map[string]*querypb.BindVariable{ "k1": {Type: sqltypes.Date, Value: []byte("2021-10-22")}, @@ -844,7 +844,7 @@ func TestCompareDates(t *testing.T) { }, { name: "date not equal datetime through bind variables", - v1: NewBindVar("k1", sqltypes.Date, collations.CollationBinaryID), v2: NewBindVar("k2", sqltypes.Datetime, collations.CollationBinaryID), + v1: NewBindVar("k1", Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewBindVar("k2", Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, bv: map[string]*querypb.BindVariable{ "k1": {Type: sqltypes.Date, Value: []byte("2021-02-20")}, @@ -853,73 +853,73 @@ func TestCompareDates(t *testing.T) { }, { name: "date not equal datetime", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewDatetime("2021-10-20 00:06:00")}, }, { name: "date equal timestamp", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewTimestamp("2021-10-22 00:00:00")}, }, { name: "date not equal timestamp", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-10-22"), sqltypes.NewTimestamp("2021-10-22 16:00:00")}, }, { name: "date equal time", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &F, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewDate(time.Now().Format("2006-01-02")), sqltypes.NewTime("00:00:00")}, }, { name: "date not equal time", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewDate(time.Now().Format("2006-01-02")), sqltypes.NewTime("12:00:00")}, }, { name: "string equal datetime", - v1: NewColumn(0, sqltypes.VarChar, collations.CollationUtf8mb4ID), v2: NewColumn(1, sqltypes.Datetime, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.VarChar, Coll: collations.CollationUtf8mb4ID}), v2: NewColumn(1, Type{Type: sqltypes.Datetime, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewVarChar("2021-10-22"), sqltypes.NewDatetime("2021-10-22 00:00:00")}, }, { name: "string equal timestamp", - v1: NewColumn(0, sqltypes.VarChar, collations.CollationUtf8mb4ID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.VarChar, Coll: collations.CollationUtf8mb4ID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewVarChar("2021-10-22 00:00:00"), sqltypes.NewTimestamp("2021-10-22 00:00:00")}, }, { name: "string not equal timestamp", - v1: NewColumn(0, sqltypes.VarChar, collations.CollationUtf8mb4ID), v2: NewColumn(1, sqltypes.Timestamp, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.VarChar, Coll: collations.CollationUtf8mb4ID}), v2: NewColumn(1, Type{Type: sqltypes.Timestamp, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewVarChar("2021-10-22 06:00:30"), sqltypes.NewTimestamp("2021-10-20 15:02:10")}, }, { name: "string equal time", - v1: NewColumn(0, sqltypes.VarChar, collations.CollationUtf8mb4ID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.VarChar, Coll: collations.CollationUtf8mb4ID}), v2: NewColumn(1, Type{Type: sqltypes.Time, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewVarChar("00:05:12"), sqltypes.NewTime("00:05:12")}, }, { name: "string equal date", - v1: NewColumn(0, sqltypes.VarChar, collations.CollationUtf8mb4ID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.VarChar, Coll: collations.CollationUtf8mb4ID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewVarChar("2021-02-22"), sqltypes.NewDate("2021-02-22")}, }, { name: "string not equal date (1, date on the RHS)", - v1: NewColumn(0, sqltypes.VarChar, collations.CollationUtf8mb4ID), v2: NewColumn(1, sqltypes.Date, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.VarChar, Coll: collations.CollationUtf8mb4ID}), v2: NewColumn(1, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewVarChar("2021-02-20"), sqltypes.NewDate("2021-03-30")}, }, { name: "string not equal date (2, date on the LHS)", - v1: NewColumn(0, sqltypes.Date, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.VarChar, collations.CollationUtf8mb4ID), + v1: NewColumn(0, Type{Type: sqltypes.Date, Coll: collations.CollationBinaryID}), v2: NewColumn(1, Type{Type: sqltypes.VarChar, Coll: collations.CollationUtf8mb4ID}), out: &T, op: sqlparser.NotEqualOp, row: []sqltypes.Value{sqltypes.NewDate("2021-03-30"), sqltypes.NewVarChar("2021-02-20")}, }, @@ -937,13 +937,13 @@ func TestCompareStrings(t *testing.T) { tests := []testCase{ { name: "string equal string", - v1: NewColumn(0, sqltypes.VarChar, collations.Default()), v2: NewColumn(1, sqltypes.VarChar, collations.Default()), + v1: NewColumn(0, Type{Type: sqltypes.VarChar, Coll: collations.Default()}), v2: NewColumn(1, Type{Type: sqltypes.VarChar, Coll: collations.Default()}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewVarChar("toto"), sqltypes.NewVarChar("toto")}, }, { name: "string equal number", - v1: NewColumn(0, sqltypes.VarChar, collations.Default()), v2: NewColumn(1, sqltypes.Int64, collations.CollationBinaryID), + v1: NewColumn(0, Type{Type: sqltypes.VarChar, Coll: collations.Default()}), v2: NewColumn(1, Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}), out: &T, op: sqlparser.EqualOp, row: []sqltypes.Value{sqltypes.NewVarChar("1"), sqltypes.NewInt64(1)}, }, diff --git a/go/vt/vtgate/evalengine/api_literal.go b/go/vt/vtgate/evalengine/api_literal.go index 1b2ba6e2da2..6b1390e3a41 100644 --- a/go/vt/vtgate/evalengine/api_literal.go +++ b/go/vt/vtgate/evalengine/api_literal.go @@ -194,11 +194,11 @@ func NewLiteralBinaryFromBit(val []byte) (*Literal, error) { } // NewBindVar returns a bind variable -func NewBindVar(key string, typ sqltypes.Type, col collations.ID) *BindVariable { +func NewBindVar(key string, typ Type) *BindVariable { return &BindVariable{ Key: key, - Type: typ, - Collation: defaultCoercionCollation(col), + Type: typ.Type, + Collation: defaultCoercionCollation(typ.Coll), } } @@ -212,11 +212,11 @@ func NewBindVarTuple(key string, col collations.ID) *BindVariable { } // NewColumn returns a column expression -func NewColumn(offset int, typ sqltypes.Type, col collations.ID) *Column { +func NewColumn(offset int, typ Type) *Column { return &Column{ Offset: offset, - Type: typ, - Collation: defaultCoercionCollation(col), + Type: typ.Type, + Collation: defaultCoercionCollation(typ.Coll), } } diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 23f7a9f10aa..84caa2d7690 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -50,6 +50,16 @@ type ctype struct { Col collations.TypedCollation } +type Type struct { + Type sqltypes.Type + Coll collations.ID + Nullable bool +} + +func UnknownType() Type { + return Type{Type: sqltypes.Unknown, Coll: collations.Unknown} +} + func (ct ctype) nullable() bool { return ct.Flag&flagNullable != 0 } diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 3af97a183e3..82d22083039 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -184,7 +184,7 @@ func (ast *astCompiler) translateIsExpr(left sqlparser.Expr, op sqlparser.IsExpr } func (ast *astCompiler) translateBindVar(arg *sqlparser.Argument) (Expr, error) { - bvar := NewBindVar(arg.Name, arg.Type, ast.cfg.Collation) + bvar := NewBindVar(arg.Name, Type{Type: arg.Type, Coll: ast.cfg.Collation}) if !bvar.typed() { ast.untyped++ @@ -193,16 +193,15 @@ func (ast *astCompiler) translateBindVar(arg *sqlparser.Argument) (Expr, error) } func (ast *astCompiler) translateColOffset(col *sqlparser.Offset) (Expr, error) { - var typ sqltypes.Type = sqltypes.Unknown - var coll collations.ID + var typ Type if ast.cfg.ResolveType != nil { - typ, coll, _ = ast.cfg.ResolveType(col.Original) + typ, _ = ast.cfg.ResolveType(col.Original) } - if coll == collations.Unknown { - coll = ast.cfg.Collation + if typ.Coll == collations.Unknown { + typ.Coll = ast.cfg.Collation } - column := NewColumn(col.V, typ, coll) + column := NewColumn(col.V, typ) if !column.typed() { ast.untyped++ } @@ -217,16 +216,15 @@ func (ast *astCompiler) translateColName(colname *sqlparser.ColName) (Expr, erro if err != nil { return nil, err } - var typ sqltypes.Type = sqltypes.Unknown - var coll collations.ID + var typ Type if ast.cfg.ResolveType != nil { - typ, coll, _ = ast.cfg.ResolveType(colname) + typ, _ = ast.cfg.ResolveType(colname) } - if coll == collations.Unknown { - coll = ast.cfg.Collation + if typ.Coll == collations.Unknown { + typ.Coll = ast.cfg.Collation } - column := NewColumn(idx, typ, coll) + column := NewColumn(idx, typ) if !column.typed() { ast.untyped++ @@ -550,7 +548,7 @@ type astCompiler struct { } type ColumnResolver func(name *sqlparser.ColName) (int, error) -type TypeResolver func(expr sqlparser.Expr) (sqltypes.Type, collations.ID, bool) +type TypeResolver func(expr sqlparser.Expr) (Type, bool) type OptimizationLevel int8 @@ -622,15 +620,15 @@ func (fields FieldResolver) Column(col *sqlparser.ColName) (int, error) { return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unknown column: %q", sqlparser.String(col)) } -func (fields FieldResolver) Type(expr sqlparser.Expr) (sqltypes.Type, collations.ID, bool) { +func (fields FieldResolver) Type(expr sqlparser.Expr) (Type, bool) { switch expr := expr.(type) { case *sqlparser.ColName: name := expr.CompliantName() for _, f := range fields { if f.Name == name { - return f.Type, collations.ID(f.Charset), true + return Type{Type: f.Type, Coll: collations.ID(f.Charset)}, true } } } - return sqltypes.Unknown, collations.Unknown, false + return UnknownType(), false } diff --git a/go/vt/vtgate/planbuilder/collations_test.go b/go/vt/vtgate/planbuilder/collations_test.go index 24fb038b4c2..7eaf3968f74 100644 --- a/go/vt/vtgate/planbuilder/collations_test.go +++ b/go/vt/vtgate/planbuilder/collations_test.go @@ -76,7 +76,7 @@ func TestOrderedAggregateCollations(t *testing.T) { check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { oa, isOA := primitive.(*engine.OrderedAggregate) require.True(t, isOA, "should be an OrderedAggregate") - require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].CollationID) + require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].Type.Coll) }, }, { @@ -85,7 +85,7 @@ func TestOrderedAggregateCollations(t *testing.T) { check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { distinct, isDistinct := primitive.(*engine.Distinct) require.True(t, isDistinct, "should be a distinct") - require.Equal(t, collid(colls[0].collationName), distinct.CheckCols[0].Collation) + require.Equal(t, collid(colls[0].collationName), distinct.CheckCols[0].Type.Coll) }, }, { @@ -97,8 +97,8 @@ func TestOrderedAggregateCollations(t *testing.T) { check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { oa, isOA := primitive.(*engine.OrderedAggregate) require.True(t, isOA, "should be an OrderedAggregate") - require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].CollationID) - require.Equal(t, collid(colls[1].collationName), oa.GroupByKeys[1].CollationID) + require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].Type.Coll) + require.Equal(t, collid(colls[1].collationName), oa.GroupByKeys[1].Type.Coll) }, }, { @@ -109,7 +109,7 @@ func TestOrderedAggregateCollations(t *testing.T) { check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { oa, isOA := primitive.(*engine.OrderedAggregate) require.True(t, isOA, "should be an OrderedAggregate") - require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].CollationID) + require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].Type.Coll) }, }, { @@ -122,7 +122,7 @@ func TestOrderedAggregateCollations(t *testing.T) { require.True(t, isMemSort, "should be a MemorySort") oa, isOA := memSort.Input.(*engine.OrderedAggregate) require.True(t, isOA, "should be an OrderedAggregate") - require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].CollationID) + require.Equal(t, collid(colls[0].collationName), oa.GroupByKeys[0].Type.Coll) }, }, } diff --git a/go/vt/vtgate/planbuilder/expression_converter.go b/go/vt/vtgate/planbuilder/expression_converter.go index 7a9dc374ea6..3720cfe7c24 100644 --- a/go/vt/vtgate/planbuilder/expression_converter.go +++ b/go/vt/vtgate/planbuilder/expression_converter.go @@ -20,7 +20,6 @@ import ( "fmt" "strings" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/mysql/collations" @@ -87,7 +86,7 @@ func (ec *expressionConverter) convert(astExpr sqlparser.Expr, boolean, identifi if !strings.Contains(err.Error(), evalengine.ErrTranslateExprNotSupported) { return nil, err } - evalExpr = evalengine.NewColumn(len(ec.tabletExpressions), sqltypes.Unknown, collations.Unknown) + evalExpr = evalengine.NewColumn(len(ec.tabletExpressions), evalengine.UnknownType()) ec.tabletExpressions = append(ec.tabletExpressions, astExpr) } return evalExpr, nil diff --git a/go/vt/vtgate/planbuilder/expression_converter_test.go b/go/vt/vtgate/planbuilder/expression_converter_test.go index e59df3c7fd1..9259d35becc 100644 --- a/go/vt/vtgate/planbuilder/expression_converter_test.go +++ b/go/vt/vtgate/planbuilder/expression_converter_test.go @@ -21,10 +21,6 @@ import ( "github.com/stretchr/testify/require" - "vitess.io/vitess/go/sqltypes" - - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/evalengine" ) @@ -44,7 +40,7 @@ func TestConversion(t *testing.T) { expressionsOut: e(evalengine.NewLiteralInt(1)), }, { expressionsIn: "@@foo", - expressionsOut: e(evalengine.NewColumn(0, sqltypes.Unknown, collations.Unknown)), + expressionsOut: e(evalengine.NewColumn(0, evalengine.UnknownType())), }} for _, tc := range queries { diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 381aeb704c5..8f8ba90a1d6 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -221,17 +221,16 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega aggrParam.Original = aggr.Original aggrParam.OrigOpcode = aggr.OriginalOpCode aggrParam.WCol = aggr.WSOffset - aggrParam.Type, aggrParam.CollationID = aggr.GetTypeCollation(ctx) + aggrParam.Type = aggr.GetTypeCollation(ctx) oa.aggregates = append(oa.aggregates, aggrParam) } for _, groupBy := range op.Grouping { - typ, col, _ := ctx.SemTable.TypeForExpr(groupBy.SimplifiedExpr) + typ, _ := ctx.SemTable.TypeForExpr(groupBy.SimplifiedExpr) oa.groupByKeys = append(oa.groupByKeys, &engine.GroupByParams{ KeyCol: groupBy.ColOffset, WeightStringCol: groupBy.WSOffset, Expr: groupBy.AsAliasedExpr().Expr, Type: typ, - CollationID: col, }) } @@ -269,14 +268,13 @@ func createMemorySort(ctx *plancontext.PlanningContext, src logicalPlan, orderin } for idx, order := range ordering.Order { - typ, collationID, _ := ctx.SemTable.TypeForExpr(order.SimplifiedExpr) + typ, _ := ctx.SemTable.TypeForExpr(order.SimplifiedExpr) ms.eMemorySort.OrderBy = append(ms.eMemorySort.OrderBy, engine.OrderByParams{ Col: ordering.Offset[idx], WeightStringCol: ordering.WOffset[idx], Desc: order.Inner.Direction == sqlparser.DescOrder, StarColFixedIndex: ordering.Offset[idx], Type: typ, - CollationID: collationID, }) } @@ -327,8 +325,8 @@ func getEvalEngingeExpr(ctx *plancontext.PlanningContext, pe *operators.ProjExpr case *operators.EvalEngine: return e.EExpr, nil case operators.Offset: - typ, col, _ := ctx.SemTable.TypeForExpr(pe.EvalExpr) - return evalengine.NewColumn(int(e), typ, col), nil + typ, _ := ctx.SemTable.TypeForExpr(pe.EvalExpr) + return evalengine.NewColumn(int(e), typ), nil default: return nil, vterrors.VT13001("project not planned for: %s", pe.String()) } @@ -501,13 +499,12 @@ func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route eroute, err := routeToEngineRoute(ctx, op, hints) for _, order := range op.Ordering { - typ, collation, _ := ctx.SemTable.TypeForExpr(order.AST) + typ, _ := ctx.SemTable.TypeForExpr(order.AST) eroute.OrderBy = append(eroute.OrderBy, engine.OrderByParams{ Col: order.Offset, WeightStringCol: order.WOffset, Desc: order.Direction == sqlparser.DescOrder, Type: typ, - CollationID: collation, }) } if err != nil { @@ -526,9 +523,7 @@ func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route } func buildInsertLogicalPlan( - rb *operators.Route, - op ops.Operator, - stmt *sqlparser.Insert, + rb *operators.Route, op ops.Operator, stmt *sqlparser.Insert, hints *queryHints, ) (logicalPlan, error) { ins := op.(*operators.Insert) diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index d6bbdff8088..d9121586185 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -51,7 +51,7 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) { for idx, col := range columns { e := d.QP.GetSimplifiedExpr(col.Expr) var wsCol *int - typ, coll, _ := ctx.SemTable.TypeForExpr(e) + typ, _ := ctx.SemTable.TypeForExpr(e) if ctx.SemTable.NeedsWeightString(e) { offset := d.Source.AddColumn(ctx, true, false, aeWrap(weightStringFor(e))) @@ -59,10 +59,9 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) { } d.Columns = append(d.Columns, engine.CheckCol{ - Col: idx, - WsCol: wsCol, - Type: typ, - Collation: coll, + Col: idx, + WsCol: wsCol, + Type: typ, }) } } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 50bbf3e1720..161245ab2fd 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -24,11 +24,10 @@ import ( "sort" "strings" - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" @@ -114,17 +113,17 @@ func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool { return aggr.OpCode.NeedsComparableValues() && ctx.SemTable.NeedsWeightString(aggr.Func.GetArg()) } -func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) (sqltypes.Type, collations.ID) { +func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.Type { if aggr.Func == nil { - return sqltypes.Unknown, collations.Unknown + return evalengine.UnknownType() } switch aggr.OpCode { case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: - typ, col, _ := ctx.SemTable.TypeForExpr(aggr.Func.GetArg()) - return typ, col + typ, _ := ctx.SemTable.TypeForExpr(aggr.Func.GetArg()) + return typ } - return sqltypes.Unknown, collations.Unknown + return evalengine.UnknownType() } // NewGroupBy creates a new group by from the given fields. diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 20bf4e50580..ba3344d70c0 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -21,6 +21,7 @@ import ( vschemapb "vitess.io/vitess/go/vt/proto/vschema" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -115,7 +116,7 @@ func (a *analyzer) newSemTable(statement sqlparser.Statement, coll collations.ID return &SemTable{ Recursive: a.binder.recursive, Direct: a.binder.direct, - ExprTypes: a.typer.exprTypes, + ExprTypes: a.typer.m, Tables: a.tables.Tables, NotSingleRouteErr: a.projErr, NotUnshardedErr: a.unshardedErr, @@ -271,17 +272,13 @@ func isParentSelectStatement(cursor *sqlparser.Cursor) bool { type originable interface { tableSetFor(t *sqlparser.AliasedTableExpr) TableSet - depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ *Type) + depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ evalengine.Type) } -func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ *Type) { +func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ evalengine.Type) { recursive = a.binder.recursive.dependencies(expr) direct = a.binder.direct.dependencies(expr) - qt, isFound := a.typer.exprTypes[expr] - if !isFound { - return - } - typ = &qt + typ = a.typer.exprType(expr) return } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index fdf795114a8..dfcc143d073 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -409,9 +409,9 @@ func TestUnknownColumnMap2(t *testing.T) { } else { require.NoError(t, err) require.NoError(t, tbl.NotSingleRouteErr) - typ, _, found := tbl.TypeForExpr(expr) + typ, found := tbl.TypeForExpr(expr) assert.True(t, found) - assert.Equal(t, test.typ, typ) + assert.Equal(t, test.typ, typ.Type) } }) } diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index e3fed7e5a68..0d70816488a 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -19,6 +19,7 @@ package semantics import ( "strings" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" ) @@ -87,8 +88,8 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { } b.recursive[node] = deps.recursive b.direct[node] = deps.direct - if deps.typ != nil { - b.typer.setTypeFor(node, *deps.typ) + if deps.typ.Type != sqltypes.Unknown { + b.typer.setTypeFor(node, deps.typ) } case *sqlparser.CountStar: b.bindCountStar(node) @@ -102,8 +103,8 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { for i, expr := range info.exprs { ae := expr.(*sqlparser.AliasedExpr) b.recursive[ae.Expr] = info.recursive[i] - if t := info.types[i]; t != nil { - b.typer.exprTypes[ae.Expr] = *t + if t := info.types[i]; t.Type != sqltypes.Unknown { + b.typer.m[ae.Expr] = t } } } diff --git a/go/vt/vtgate/semantics/dependencies.go b/go/vt/vtgate/semantics/dependencies.go index 8e5a481e17d..d93d895c8e3 100644 --- a/go/vt/vtgate/semantics/dependencies.go +++ b/go/vt/vtgate/semantics/dependencies.go @@ -20,6 +20,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) type ( @@ -33,7 +34,7 @@ type ( dependency struct { direct TableSet recursive TableSet - typ *Type + typ evalengine.Type } nothing struct{} certain struct { @@ -48,14 +49,15 @@ type ( var ambigousErr = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "ambiguous") -func createCertain(direct TableSet, recursive TableSet, qt *Type) *certain { +func createCertain(direct TableSet, recursive TableSet, qt evalengine.Type) *certain { c := &certain{ dependency: dependency{ direct: direct, recursive: recursive, + typ: evalengine.UnknownType(), }, } - if qt != nil && qt.Type != querypb.Type_NULL_TYPE { + if qt.Type != querypb.Type_NULL_TYPE { c.typ = qt } return c diff --git a/go/vt/vtgate/semantics/derived_table.go b/go/vt/vtgate/semantics/derived_table.go index a88f39cf8af..0498a26a429 100644 --- a/go/vt/vtgate/semantics/derived_table.go +++ b/go/vt/vtgate/semantics/derived_table.go @@ -22,6 +22,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -35,13 +36,13 @@ type DerivedTable struct { isAuthoritative bool recursive []TableSet - types []*Type + types []evalengine.Type } type unionInfo struct { isAuthoritative bool recursive []TableSet - types []*Type + types []evalengine.Type exprs sqlparser.SelectExprs } @@ -54,7 +55,7 @@ func createDerivedTableForExpressions( org originable, expanded bool, recursiveDeps []TableSet, - types []*Type, + types []evalengine.Type, ) *DerivedTable { vTbl := &DerivedTable{isAuthoritative: expanded, recursive: recursiveDeps, types: types} for i, selectExpr := range expressions { diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index bd57ab81474..9952e041378 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -24,6 +24,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -42,7 +43,7 @@ func (r *RealTable) dependencies(colName string, org originable) (dependencies, ts := org.tableSetFor(r.ASTNode) for _, info := range r.getColumns() { if strings.EqualFold(info.Name, colName) { - return createCertain(ts, ts, &info.Type), nil + return createCertain(ts, ts, info.Type), nil } } @@ -114,9 +115,9 @@ func vindexTableToColumnInfo(tbl *vindexes.Table) []ColumnInfo { cols = append(cols, ColumnInfo{ Name: col.Name.String(), - Type: Type{ - Type: col.Type, - Collation: collation, + Type: evalengine.Type{ + Type: col.Type, + Coll: collation, }, }) nameMap[col.Name.String()] = nil diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 4e31d7ebd8e..af37a9b34d1 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -27,6 +27,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -62,7 +63,7 @@ type ( // ColumnInfo contains information about columns ColumnInfo struct { Name string - Type Type + Type evalengine.Type } // ExprDependencies stores the tables that an expression depends on as a map @@ -88,7 +89,7 @@ type ( // from the connection's default collation. Collation collations.ID // ExprTypes maps expressions to their respective types in the query. - ExprTypes map[sqlparser.Expr]Type + ExprTypes map[sqlparser.Expr]evalengine.Type // NotSingleRouteErr stores errors related to missing schema information. // This typically occurs when a column's existence is uncertain. @@ -492,9 +493,9 @@ func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.Sel } // TypeForExpr returns the type of expressions in the query -func (st *SemTable) TypeForExpr(e sqlparser.Expr) (sqltypes.Type, collations.ID, bool) { +func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { if typ, found := st.ExprTypes[e]; found { - return typ.Type, typ.Collation, true + return typ, true } // We add a lot of WeightString() expressions to queries at late stages of the planning, @@ -502,10 +503,14 @@ func (st *SemTable) TypeForExpr(e sqlparser.Expr) (sqltypes.Type, collations.ID, // are VarBinary, since that's the only type that WeightString() can return. _, isWS := e.(*sqlparser.WeightStringFuncExpr) if isWS { - return sqltypes.VarBinary, collations.CollationBinaryID, true + return evalengine.Type{ + Type: sqltypes.VarBinary, + Coll: collations.CollationBinaryID, + Nullable: false, // TODO: we should check if the argument is nullable + }, true } - return sqltypes.Unknown, collations.Unknown, false + return evalengine.UnknownType(), false } // NeedsWeightString returns true if the given expression needs weight_string to do safe comparisons @@ -518,7 +523,7 @@ func (st *SemTable) NeedsWeightString(e sqlparser.Expr) bool { if !found { return true } - return typ.Collation == collations.Unknown && !sqltypes.IsNumber(typ.Type) + return typ.Coll == collations.Unknown && !sqltypes.IsNumber(typ.Type) } } diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index d6fd4c6efd6..3aa46c87bfa 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -21,6 +21,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -62,7 +63,7 @@ func (tc *tableCollector) up(cursor *sqlparser.Cursor) error { size := len(firstSelect.SelectExprs) info.recursive = make([]TableSet, size) - info.types = make([]*Type, size) + info.types = make([]evalengine.Type, size) _ = sqlparser.VisitAllSelects(node, func(s *sqlparser.Select, idx int) error { for i, expr := range s.SelectExprs { @@ -126,7 +127,7 @@ func (tc *tableCollector) addSelectDerivedTable(sel *sqlparser.Select, node *sql tables := tc.scoper.wScope[sel] size := len(sel.SelectExprs) deps := make([]TableSet, size) - types := make([]*Type, size) + types := make([]evalengine.Type, size) expanded := true for i, expr := range sel.SelectExprs { ae, ok := expr.(*sqlparser.AliasedExpr) diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index 6652f1a476b..b43ea49c4d1 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -19,36 +19,39 @@ package semantics import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" - querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) // typer is responsible for setting the type for expressions // it does it's work after visiting the children (up), since the children types is often needed to type a node. type typer struct { - exprTypes map[sqlparser.Expr]Type -} - -// Type is the normal querypb.Type with collation -type Type struct { - Type querypb.Type - Collation collations.ID + m map[sqlparser.Expr]evalengine.Type } func newTyper() *typer { return &typer{ - exprTypes: map[sqlparser.Expr]Type{}, + m: map[sqlparser.Expr]evalengine.Type{}, + } +} + +func (t *typer) exprType(expr sqlparser.Expr) evalengine.Type { + res, ok := t.m[expr] + if ok { + return res } + + return evalengine.UnknownType() } func (t *typer) up(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case *sqlparser.Literal: - t.exprTypes[node] = Type{Type: node.SQLType(), Collation: collations.DefaultCollationForType(node.SQLType())} + t.m[node] = evalengine.Type{Type: node.SQLType(), Coll: collations.DefaultCollationForType(node.SQLType())} case *sqlparser.Argument: if node.Type >= 0 { - t.exprTypes[node] = Type{Type: node.Type, Collation: collations.DefaultCollationForType(node.Type)} + t.m[node] = evalengine.Type{Type: node.Type, Coll: collations.DefaultCollationForType(node.Type)} } case sqlparser.AggrFunc: code, ok := opcode.SupportedAggregates[node.AggrName()] @@ -57,17 +60,17 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { } var inputType sqltypes.Type if arg := node.GetArg(); arg != nil { - t, ok := t.exprTypes[arg] + t, ok := t.m[arg] if ok { inputType = t.Type } } type_ := code.Type(inputType) - t.exprTypes[node] = Type{Type: type_, Collation: collations.DefaultCollationForType(type_)} + t.m[node] = evalengine.Type{Type: type_, Coll: collations.DefaultCollationForType(type_)} } return nil } -func (t *typer) setTypeFor(node *sqlparser.ColName, typ Type) { - t.exprTypes[node] = typ +func (t *typer) setTypeFor(node *sqlparser.ColName, typ evalengine.Type) { + t.m[node] = typ } diff --git a/go/vt/vttablet/tabletmanager/vdiff/utils.go b/go/vt/vttablet/tabletmanager/vdiff/utils.go index 12ea1e8a68c..d756e6f6984 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/utils.go +++ b/go/vt/vttablet/tabletmanager/vdiff/utils.go @@ -21,6 +21,8 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/binlog/binlogplayer" "vitess.io/vitess/go/vt/log" @@ -40,11 +42,14 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare for i, cpk := range comparePKs { weightStringCol := -1 // if the collation is nil or unknown, use binary collation to compare as bytes - if cpk.collation == collations.Unknown { - ob[i] = engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: sqltypes.Unknown, CollationID: collations.CollationBinaryID} - } else { - ob[i] = engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: sqltypes.Unknown, CollationID: cpk.collation} + t := evalengine.Type{ + Type: sqltypes.Unknown, + Coll: collations.CollationBinaryID, + } + if cpk.collation != collations.Unknown { + t.Coll = cpk.collation } + ob[i] = engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: t} } return &engine.MergeSort{ Primitives: prims, @@ -52,7 +57,7 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare } } -//----------------------------------------------------------------- +// ----------------------------------------------------------------- // Utility functions func encodeString(in string) string { @@ -64,7 +69,7 @@ func encodeString(in string) string { func pkColsToGroupByParams(pkCols []int) []*engine.GroupByParams { var res []*engine.GroupByParams for _, col := range pkCols { - res = append(res, &engine.GroupByParams{KeyCol: col, WeightStringCol: -1, Type: sqltypes.Unknown}) + res = append(res, &engine.GroupByParams{KeyCol: col, WeightStringCol: -1, Type: evalengine.UnknownType()}) } return res } diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 85c82bb3574..3311d376431 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -499,8 +499,8 @@ func findPKs(table *tabletmanagerdatapb.TableDefinition, targetSelect *sqlparser switch ct := expr.(type) { case *sqlparser.ColName: colname = ct.Name.String() - case *sqlparser.FuncExpr: //eg. weight_string() - //no-op + case *sqlparser.FuncExpr: // eg. weight_string() + // no-op default: log.Warningf("Not considering column %v for PK, type %v not handled", selExpr, ct) } @@ -769,7 +769,7 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer func pkColsToGroupByParams(pkCols []int) []*engine.GroupByParams { var res []*engine.GroupByParams for _, col := range pkCols { - res = append(res, &engine.GroupByParams{KeyCol: col, WeightStringCol: -1, Type: sqltypes.Unknown}) + res = append(res, &engine.GroupByParams{KeyCol: col, WeightStringCol: -1, Type: evalengine.UnknownType()}) } return res } @@ -784,11 +784,11 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare for _, cpk := range comparePKs { weightStringCol := -1 // if the collation is nil or unknown, use binary collation to compare as bytes - if cpk.collation == collations.Unknown { - ob = append(ob, engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: sqltypes.Unknown, CollationID: collations.CollationBinaryID}) - } else { - ob = append(ob, engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: sqltypes.Unknown, CollationID: cpk.collation}) + t := evalengine.Type{Type: sqltypes.Unknown, Coll: collations.CollationBinaryID} + if cpk.collation != collations.Unknown { + t.Coll = cpk.collation } + ob = append(ob, engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: t}) } return &engine.MergeSort{ Primitives: prims, @@ -1058,7 +1058,7 @@ func (df *vdiff) forAll(participants map[string]*shardStreamer, f func(string, * return allErrors.AggrError(vterrors.Aggregate) } -//----------------------------------------------------------------- +// ----------------------------------------------------------------- // primitiveExecutor // primitiveExecutor starts execution on the top level primitive @@ -1118,7 +1118,7 @@ func (pe *primitiveExecutor) drain(ctx context.Context) (int, error) { } } -//----------------------------------------------------------------- +// ----------------------------------------------------------------- // shardStreamer func (sm *shardStreamer) StreamExecute(ctx context.Context, vcursor engine.VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { @@ -1153,7 +1153,7 @@ func humanInt(n int64) string { // nolint return fmt.Sprintf("%s%s", s, unit) } -//----------------------------------------------------------------- +// ----------------------------------------------------------------- // tableDiffer func (td *tableDiffer) diff(ctx context.Context, rowsToCompare *int64, debug, onlyPks bool, maxExtraRowsToCompare int) (*DiffReport, error) { @@ -1375,7 +1375,7 @@ func (td *tableDiffer) genDebugQueryDiff(sel *sqlparser.Select, row []sqltypes.V return buf.String() } -//----------------------------------------------------------------- +// ----------------------------------------------------------------- // contextVCursor // contextVCursor satisfies VCursor interface @@ -1395,7 +1395,7 @@ func (vc *contextVCursor) StreamExecutePrimitive(ctx context.Context, primitive return primitive.TryStreamExecute(ctx, vc, bindVars, wantfields, callback) } -//----------------------------------------------------------------- +// ----------------------------------------------------------------- // Utility functions func removeKeyrange(where *sqlparser.Where) *sqlparser.Where { diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index ac57c9bcf68..28422b6cd4d 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -23,6 +23,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -404,7 +406,7 @@ func TestVDiffPlanSuccess(t *testing.T) { engine.NewAggregateParam(opcode.AggregateSum, 2, ""), engine.NewAggregateParam(opcode.AggregateSum, 3, ""), }, - GroupByKeys: []*engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1, Type: sqltypes.Unknown}}, + GroupByKeys: []*engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1, Type: evalengine.UnknownType()}}, Input: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}), }, targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}),