From b0521ad8c2aefba47f929803741b734eb8106d8f Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 23 Apr 2024 09:10:24 +0200 Subject: [PATCH 1/3] evalengine: Add support for enum and set The evalengine currently doesn't handle enum and set types properly. The comparison function always returns 0 at the moment and we don't consider ordering of elements etc. Here we add two new native types to the evalengine for set and enum and instantiate those appropriately. We also ensure we can compare them correctly. In case we don't have the schema information with the values, we do a best effort case of depending on the string representation. This is not correct always of course, but at least makes equality comparison work for those cases and only ordering is off in that scenario. Signed-off-by: Dirkjan Bussink --- go/mysql/json/helpers.go | 4 + go/sqltypes/testing.go | 36 +++ go/sqltypes/type.go | 10 + go/sqltypes/value.go | 10 + .../endtoend/vtgate/queries/misc/misc_test.go | 15 +- .../endtoend/vtgate/queries/misc/schema.sql | 8 + .../endtoend/vtgate/queries/misc/vschema.json | 8 + go/vt/vtexplain/vtexplain_vttablet.go | 2 +- go/vt/vtgate/engine/aggregations.go | 9 +- go/vt/vtgate/engine/cached_size.go | 29 +- go/vt/vtgate/engine/distinct.go | 4 +- go/vt/vtgate/engine/distinct_test.go | 2 +- go/vt/vtgate/engine/hash_join.go | 13 +- go/vt/vtgate/engine/opcode/constants.go | 2 +- go/vt/vtgate/engine/ordered_aggregate.go | 4 +- go/vt/vtgate/evalengine/api_aggregation.go | 7 +- .../vtgate/evalengine/api_aggregation_test.go | 4 +- go/vt/vtgate/evalengine/api_coerce.go | 4 +- go/vt/vtgate/evalengine/api_compare.go | 14 +- go/vt/vtgate/evalengine/api_compare_test.go | 54 +++- go/vt/vtgate/evalengine/api_hash.go | 20 +- go/vt/vtgate/evalengine/api_hash_test.go | 24 +- go/vt/vtgate/evalengine/api_literal.go | 1 + .../vtgate/evalengine/api_type_aggregation.go | 2 +- go/vt/vtgate/evalengine/arena.go | 32 +++ go/vt/vtgate/evalengine/cached_size.go | 80 +++++- go/vt/vtgate/evalengine/compare.go | 37 +++ go/vt/vtgate/evalengine/compiler.go | 31 ++- go/vt/vtgate/evalengine/compiler_asm_push.go | 36 +++ go/vt/vtgate/evalengine/eval.go | 30 +- go/vt/vtgate/evalengine/eval_enum.go | 37 +++ go/vt/vtgate/evalengine/eval_numeric.go | 16 ++ go/vt/vtgate/evalengine/eval_set.go | 49 ++++ go/vt/vtgate/evalengine/expr_bvar.go | 4 +- go/vt/vtgate/evalengine/expr_column.go | 12 +- go/vt/vtgate/evalengine/expr_compare.go | 14 +- go/vt/vtgate/evalengine/expr_env.go | 2 +- go/vt/vtgate/evalengine/expr_tuple_bvar.go | 2 +- go/vt/vtgate/evalengine/translate.go | 4 +- go/vt/vtgate/evalengine/weights.go | 25 +- go/vt/vtgate/evalengine/weights_test.go | 38 +-- .../planbuilder/operator_transformers.go | 1 + go/vt/vtgate/semantics/semantic_state.go | 2 +- go/vt/vtgate/vindexes/consistent_lookup.go | 2 +- go/vt/vtgate/vindexes/vschema.go | 2 +- .../tabletmanager/vdiff/table_differ.go | 2 +- .../vreplication/replicator_plan.go | 2 +- .../tabletserver/schema/load_table.go | 2 +- .../tabletserver/vstreamer/planbuilder.go | 2 +- go/vt/wrangler/vdiff.go | 18 +- go/vt/wrangler/vdiff_test.go | 260 +++++++++++------- 51 files changed, 796 insertions(+), 232 deletions(-) create mode 100644 go/vt/vtgate/evalengine/eval_enum.go create mode 100644 go/vt/vtgate/evalengine/eval_set.go diff --git a/go/mysql/json/helpers.go b/go/mysql/json/helpers.go index 1df38b2d769..760d59c5624 100644 --- a/go/mysql/json/helpers.go +++ b/go/mysql/json/helpers.go @@ -106,6 +106,10 @@ func NewFromSQL(v sqltypes.Value) (*Value, error) { return NewDate(v.RawStr()), nil case v.IsTime(): return NewTime(v.RawStr()), nil + case v.IsEnum(): + return NewString(v.RawStr()), nil + case v.IsSet(): + return NewString(v.RawStr()), nil default: return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot coerce %v as a JSON type", v) } diff --git a/go/sqltypes/testing.go b/go/sqltypes/testing.go index 2fd9ee9c2be..f67cd1c6deb 100644 --- a/go/sqltypes/testing.go +++ b/go/sqltypes/testing.go @@ -279,6 +279,12 @@ var RandomGenerators = map[Type]RandomGenerator{ } return v }, + Enum: func() Value { + return MakeTrusted(Enum, randEnum()) + }, + Set: func() Value { + return MakeTrusted(Set, randSet()) + }, } func randTime() time.Time { @@ -289,3 +295,33 @@ func randTime() time.Time { sec := rand.Int64N(delta) + min return time.Unix(sec, 0) } + +func randEnum() []byte { + enums := []string{ + "xxsmall", + "xsmall", + "small", + "medium", + "large", + "xlarge", + "xxlarge", + } + return []byte(enums[rand.IntN(len(enums))]) +} + +func randSet() []byte { + set := []string{ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + } + rand.Shuffle(len(set), func(i, j int) { + set[i], set[j] = set[j], set[i] + }) + set = set[:rand.IntN(len(set))] + return []byte(strings.Join(set, ",")) +} diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index 964dd6b5d83..4090dd0107a 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -119,6 +119,16 @@ func IsNull(t querypb.Type) bool { return t == Null } +// IsEnum returns true if the type is Enum type +func IsEnum(t querypb.Type) bool { + return t == Enum +} + +// IsSet returns true if the type is Set type +func IsSet(t querypb.Type) bool { + return t == Set +} + // Vitess data types. These are idiomatically named synonyms for the querypb.Type values. // Although these constants are interchangeable, they should be treated as different from querypb.Type. // Use the synonyms only to refer to the type in Value. For proto variables, use the querypb.Type constants instead. diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index b8f05e02db3..99a0a43828e 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -568,6 +568,16 @@ func (v Value) IsDecimal() bool { return IsDecimal(v.Type()) } +// IsEnum returns true if Value is time. +func (v Value) IsEnum() bool { + return v.Type() == querypb.Type_ENUM +} + +// IsSet returns true if Value is time. +func (v Value) IsSet() bool { + return v.Type() == querypb.Type_SET +} + // IsComparable returns true if the Value is null safe comparable without collation information. func (v *Value) IsComparable() bool { if v.Type() == Null || IsNumber(v.Type()) || IsBinary(v.Type()) { diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index c10cb4c9b71..d0c610084cd 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -37,7 +37,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { require.NoError(t, err) deleteAll := func() { - tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "uks.unsharded"} + tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "tbl_enum_set", "uks.unsharded"} for _, table := range tables { _, _ = mcmp.ExecAndIgnore("delete from " + table) } @@ -452,3 +452,16 @@ func TestStraightJoin(t *testing.T) { require.NoError(t, err) require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl") } + +func TestEnumSetVals(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "tbl_enum_set", clusterInstance.VtgateProcess.ReadVSchema)) + + mcmp.Exec("insert into tbl_enum_set(id, enum_col, set_col) values (1, 'medium', 'a,b,e'), (2, 'small', 'e,f,g'), (3, 'large', 'c'), (4, 'xsmall', 'a,b'), (5, 'medium', 'a,d')") + + mcmp.AssertMatches("select id, enum_col, cast(enum_col as signed) from tbl_enum_set order by enum_col, id", `[[INT64(4) ENUM("xsmall") INT64(1)] [INT64(2) ENUM("small") INT64(2)] [INT64(1) ENUM("medium") INT64(3)] [INT64(5) ENUM("medium") INT64(3)] [INT64(3) ENUM("large") INT64(4)]]`) + mcmp.AssertMatches("select id, set_col, cast(set_col as unsigned) from tbl_enum_set order by set_col, id", `[[INT64(4) SET("a,b") UINT64(3)] [INT64(3) SET("c") UINT64(4)] [INT64(5) SET("a,d") UINT64(9)] [INT64(1) SET("a,b,e") UINT64(19)] [INT64(2) SET("e,f,g") UINT64(112)]]`) +} diff --git a/go/test/endtoend/vtgate/queries/misc/schema.sql b/go/test/endtoend/vtgate/queries/misc/schema.sql index 6fd57b9183d..685500ec809 100644 --- a/go/test/endtoend/vtgate/queries/misc/schema.sql +++ b/go/test/endtoend/vtgate/queries/misc/schema.sql @@ -27,3 +27,11 @@ create table tbl primary key (id), unique (unq_col) ) Engine = InnoDB; + +create table tbl_enum_set +( + id bigint, + enum_col enum('xsmall', 'small', 'medium', 'large', 'xlarge'), + set_col set('a', 'b', 'c', 'd', 'e', 'f', 'g'), + primary key (id) +) Engine = InnoDB; diff --git a/go/test/endtoend/vtgate/queries/misc/vschema.json b/go/test/endtoend/vtgate/queries/misc/vschema.json index f56b1fc1b36..d3d7c3b7935 100644 --- a/go/test/endtoend/vtgate/queries/misc/vschema.json +++ b/go/test/endtoend/vtgate/queries/misc/vschema.json @@ -53,6 +53,14 @@ } ] }, + "tbl_enum_set": { + "column_vindexes": [ + { + "column": "id", + "name": "hash" + } + ] + }, "unq_idx": { "column_vindexes": [ { diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index 53e09445c17..6f28cd99ec0 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -755,7 +755,7 @@ func (t *explainTablet) analyzeWhere(selStmt *sqlparser.Select, tableColumnMap m // Check if we have a duplicate value isNewValue := true for _, v := range inVal { - result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset()) + result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset(), nil) if err != nil { return "", nil, 0, nil, err } diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index ea10267a7e6..b033e9fbb0e 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -107,6 +107,7 @@ type aggregatorDistinct struct { last sqltypes.Value coll collations.ID collationEnv *collations.Environment + values []string } func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) { @@ -115,7 +116,7 @@ func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) { next := row[a.column] if !last.IsNull() { if last.TinyWeightCmp(next) == 0 { - cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll) + cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll, a.values) if err != nil { return true, err } @@ -386,6 +387,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg column: distinct, coll: aggr.Type.Collation(), collationEnv: aggr.CollationEnv, + values: aggr.Type.Values(), }, } @@ -405,6 +407,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg column: distinct, coll: aggr.Type.Collation(), collationEnv: aggr.CollationEnv, + values: aggr.Type.Values(), }, } @@ -412,7 +415,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg ag = &aggregatorMin{ aggregatorMinMax{ from: aggr.Col, - minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()), + minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()), }, } @@ -420,7 +423,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg ag = &aggregatorMax{ aggregatorMinMax{ from: aggr.Col, - minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()), + minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()), }, } diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 5ff7a7c96ce..410f024149c 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -35,8 +35,10 @@ func (cached *AggregateParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(112) + size += int64(128) } + // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type + size += cached.Type.CachedSize(false) // field Alias string size += hack.RuntimeAllocSize(int64(len(cached.Alias))) // field Expr vitess.io/vitess/go/vt/sqlparser.Expr @@ -69,10 +71,12 @@ func (cached *CheckCol) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(64) } // field WsCol *int size += hack.RuntimeAllocSize(int64(8)) + // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type + size += cached.Type.CachedSize(false) // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) return size @@ -235,7 +239,7 @@ func (cached *Distinct) CachedSize(alloc bool) int64 { } // field CheckCols []vitess.io/vitess/go/vt/vtgate/engine.CheckCol { - size += hack.RuntimeAllocSize(int64(cap(cached.CheckCols)) * int64(40)) + size += hack.RuntimeAllocSize(int64(cap(cached.CheckCols)) * int64(64)) for _, elem := range cached.CheckCols { size += elem.CachedSize(false) } @@ -382,12 +386,14 @@ func (cached *GroupByParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(64) + size += int64(96) } // field Expr vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.Expr.(cachedObject); ok { size += cc.CachedSize(true) } + // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type + size += cached.Type.CachedSize(false) // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) return size @@ -398,7 +404,7 @@ func (cached *HashJoin) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(112) + size += int64(144) } // field Left vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Left.(cachedObject); ok { @@ -418,6 +424,13 @@ func (cached *HashJoin) CachedSize(alloc bool) int64 { } // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) + // field Values []string + { + size += hack.RuntimeAllocSize(int64(cap(cached.Values)) * int64(16)) + for _, elem := range cached.Values { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } return size } func (cached *Insert) CachedSize(alloc bool) int64 { @@ -657,7 +670,7 @@ func (cached *MemorySort) CachedSize(alloc bool) int64 { } // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(48)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(72)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } @@ -687,7 +700,7 @@ func (cached *MergeSort) CachedSize(alloc bool) int64 { } // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(48)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(72)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } @@ -897,7 +910,7 @@ func (cached *Route) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.FieldQuery))) // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(48)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(72)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index c47cf6be8d1..189440611c3 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -74,14 +74,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (vthash.Hash, error) return vthash.Hash{}, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code") } col := inputRow[checkCol.Col] - err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode) + err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values()) if err != nil { if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil { return vthash.Hash{}, err } checkCol = checkCol.SwitchToWeightString() pt.checkCols[i] = checkCol - err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode) + err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values()) if err != nil { return vthash.Hash{}, err } diff --git a/go/vt/vtgate/engine/distinct_test.go b/go/vt/vtgate/engine/distinct_test.go index cb414d8de28..d7fe8786158 100644 --- a/go/vt/vtgate/engine/distinct_test.go +++ b/go/vt/vtgate/engine/distinct_test.go @@ -90,7 +90,7 @@ func TestDistinct(t *testing.T) { } checkCols = append(checkCols, CheckCol{ Col: i, - Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0), + Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0, nil), CollationEnv: collations.MySQL8(), }) } diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index f7c9d87e1fb..89dbf1190ae 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -67,6 +67,9 @@ type ( ComparisonType querypb.Type CollationEnv *collations.Environment + + // Values for enum and set types + Values []string } hashJoinProbeTable struct { @@ -78,6 +81,7 @@ type ( cols []int hasher vthash.Hasher sqlmode evalengine.SQLMode + values []string } probeTableEntry struct { @@ -94,7 +98,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma return nil, err } - pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols) + pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values) // build the probe table from the LHS result for _, row := range lresult.Rows { err := pt.addLeftRow(row) @@ -130,7 +134,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma // TryStreamExecute implements the Primitive interface func (hj *HashJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { // build the probe table from the LHS result - pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols) + pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values) var lfields []*querypb.Field var mu sync.Mutex err := vcursor.StreamExecutePrimitive(ctx, hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error { @@ -260,7 +264,7 @@ func (hj *HashJoin) description() PrimitiveDescription { } } -func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int) *hashJoinProbeTable { +func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int, values []string) *hashJoinProbeTable { return &hashJoinProbeTable{ innerMap: map[vthash.Hash]*probeTableEntry{}, coll: coll, @@ -269,6 +273,7 @@ func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey rhsKey: rhsKey, cols: cols, hasher: vthash.New(), + values: values, } } @@ -286,7 +291,7 @@ func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error { } func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) { - err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode) + err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode, pt.values) if err != nil { return vthash.Hash{}, err } diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 1bdbe61fd65..28c09de0fd6 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -180,7 +180,7 @@ func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Envir if code == AggregateAvg { scale += 4 } - return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale) + return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale, t.Values()) } func (code AggregateOpcode) NeedsComparableValues() bool { diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index ade8cd00299..5a72bdf4501 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -344,14 +344,14 @@ func (oa *OrderedAggregate) nextGroupBy(currentKey, nextRow []sqltypes.Value) (n return nextRow, true, nil } - cmp, err := evalengine.NullsafeCompare(v1, v2, oa.CollationEnv, gb.Type.Collation()) + cmp, err := evalengine.NullsafeCompare(v1, v2, oa.CollationEnv, gb.Type.Collation(), gb.Type.Values()) if err != nil { _, isCollationErr := err.(evalengine.UnsupportedCollationError) if !isCollationErr || gb.WeightStringCol == -1 { return nil, false, err } gb.KeyCol = gb.WeightStringCol - cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], oa.CollationEnv, gb.Type.Collation()) + cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], oa.CollationEnv, gb.Type.Collation(), gb.Type.Values()) if err != nil { return nil, false, err } diff --git a/go/vt/vtgate/evalengine/api_aggregation.go b/go/vt/vtgate/evalengine/api_aggregation.go index 0566f477a3c..8584261b654 100644 --- a/go/vt/vtgate/evalengine/api_aggregation.go +++ b/go/vt/vtgate/evalengine/api_aggregation.go @@ -448,6 +448,7 @@ type aggregationMinMax struct { current sqltypes.Value collation collations.ID collationEnv *collations.Environment + values []string } func (a *aggregationMinMax) minmax(value sqltypes.Value, max bool) (err error) { @@ -458,7 +459,7 @@ func (a *aggregationMinMax) minmax(value sqltypes.Value, max bool) (err error) { a.current = value return nil } - n, err := compare(a.current, value, a.collationEnv, a.collation) + n, err := compare(a.current, value, a.collationEnv, a.collation, a.values) if err != nil { return err } @@ -484,7 +485,7 @@ func (a *aggregationMinMax) Reset() { a.current = sqltypes.NULL } -func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environment, collation collations.ID) MinMax { +func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environment, collation collations.ID, values []string) MinMax { switch { case sqltypes.IsSigned(typ): return &aggregationInt{t: typ} @@ -495,6 +496,6 @@ func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environmen case sqltypes.IsDecimal(typ): return &aggregationDecimal{} default: - return &aggregationMinMax{collation: collation, collationEnv: collationEnv} + return &aggregationMinMax{collation: collation, collationEnv: collationEnv, values: values} } } diff --git a/go/vt/vtgate/evalengine/api_aggregation_test.go b/go/vt/vtgate/evalengine/api_aggregation_test.go index e5dae47017e..05884b4bb4b 100644 --- a/go/vt/vtgate/evalengine/api_aggregation_test.go +++ b/go/vt/vtgate/evalengine/api_aggregation_test.go @@ -137,7 +137,7 @@ func TestMinMax(t *testing.T) { for i, tcase := range tcases { t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run("Min", func(t *testing.T) { - agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll) + agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll, nil) for _, v := range tcase.values { err := agg.Min(v) @@ -153,7 +153,7 @@ func TestMinMax(t *testing.T) { }) t.Run("Max", func(t *testing.T) { - agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll) + agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll, nil) for _, v := range tcase.values { err := agg.Max(v) diff --git a/go/vt/vtgate/evalengine/api_coerce.go b/go/vt/vtgate/evalengine/api_coerce.go index 907c578df8a..eef83c58422 100644 --- a/go/vt/vtgate/evalengine/api_coerce.go +++ b/go/vt/vtgate/evalengine/api_coerce.go @@ -24,7 +24,7 @@ import ( ) func CoerceTo(value sqltypes.Value, typ Type, sqlmode SQLMode) (sqltypes.Value, error) { - cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, sqlmode) + cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, typ.values, sqlmode) if err != nil { return sqltypes.Value{}, err } @@ -33,7 +33,7 @@ func CoerceTo(value sqltypes.Value, typ Type, sqlmode SQLMode) (sqltypes.Value, // CoerceTypes takes two input types, and decides how they should be coerced before compared func CoerceTypes(v1, v2 Type, collationEnv *collations.Environment) (out Type, err error) { - if v1 == v2 { + if v1.Equal(&v2) { return v1, nil } if sqltypes.IsNull(v1.Type()) || sqltypes.IsNull(v2.Type()) { diff --git a/go/vt/vtgate/evalengine/api_compare.go b/go/vt/vtgate/evalengine/api_compare.go index c6278264a47..e890e7c83fd 100644 --- a/go/vt/vtgate/evalengine/api_compare.go +++ b/go/vt/vtgate/evalengine/api_compare.go @@ -43,7 +43,7 @@ func (err UnsupportedCollationError) Error() string { // UnsupportedCollationHashError is returned when we try to get the hash value and are missing the collation to use var UnsupportedCollationHashError = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "text type with an unknown/unsupported collation cannot be hashed") -func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID) (int, error) { +func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID, values []string) (int, error) { v1t := v1.Type() // We have a fast path here for the case where both values are @@ -115,7 +115,7 @@ func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collat Collation: collationID, Coercibility: collations.CoerceImplicit, Repertoire: collations.RepertoireUnicode, - }) + }, values) if err != nil { return 0, err } @@ -124,7 +124,7 @@ func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collat Collation: collationID, Coercibility: collations.CoerceImplicit, Repertoire: collations.RepertoireUnicode, - }) + }, values) if err != nil { return 0, err } @@ -147,7 +147,7 @@ func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collat // numeric, then a numeric comparison is performed after // necessary conversions. If none are numeric, then it's // a simple binary comparison. Uncomparable values return an error. -func NullsafeCompare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID) (int, error) { +func NullsafeCompare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID, values []string) (int, error) { // Based on the categorization defined for the types, // we're going to allow comparison of the following: // Null, isNumber, IsBinary. This will exclude IsQuoted @@ -161,7 +161,7 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationEnv *collations.Environment if v2.IsNull() { return 1, nil } - return compare(v1, v2, collationEnv, collationID) + return compare(v1, v2, collationEnv, collationID, values) } // OrderByParams specifies the parameters for ordering. @@ -213,7 +213,7 @@ func (obp *OrderByParams) Compare(r1, r2 []sqltypes.Value) int { if cmp == 0 { var err error - cmp, err = NullsafeCompare(v1, v2, obp.CollationEnv, obp.Type.Collation()) + cmp, err = NullsafeCompare(v1, v2, obp.CollationEnv, obp.Type.Collation(), obp.Type.values) if err != nil { _, isCollationErr := err.(UnsupportedCollationError) if !isCollationErr || obp.WeightStringCol == -1 { @@ -222,7 +222,7 @@ func (obp *OrderByParams) Compare(r1, r2 []sqltypes.Value) int { // in case of a comparison or collation error switch to using the weight string column for ordering obp.Col = obp.WeightStringCol obp.WeightStringCol = -1 - cmp, err = NullsafeCompare(r1[obp.Col], r2[obp.Col], obp.CollationEnv, obp.Type.Collation()) + cmp, err = NullsafeCompare(r1[obp.Col], r2[obp.Col], obp.CollationEnv, obp.Type.Collation(), obp.Type.values) if err != nil { panic(err) } diff --git a/go/vt/vtgate/evalengine/api_compare_test.go b/go/vt/vtgate/evalengine/api_compare_test.go index aa039537240..778a252e2d8 100644 --- a/go/vt/vtgate/evalengine/api_compare_test.go +++ b/go/vt/vtgate/evalengine/api_compare_test.go @@ -1109,11 +1109,12 @@ func TestNullComparisons(t *testing.T) { } func TestNullsafeCompare(t *testing.T) { - collation := collationEnv.LookupByName("utf8mb4_general_ci") + collation := collations.ID(collations.CollationUtf8mb4ID) tcases := []struct { v1, v2 sqltypes.Value out int err error + values []string }{ { v1: NULL, @@ -1140,23 +1141,60 @@ func TestNullsafeCompare(t *testing.T) { v2: TestValue(sqltypes.VarChar, " 6736380880502626304.000000 aa"), out: -1, }, + { + v1: TestValue(sqltypes.Enum, "foo"), + v2: TestValue(sqltypes.Enum, "bar"), + out: -1, + values: []string{"'foo'", "'bar'"}, + }, { v1: TestValue(sqltypes.Enum, "foo"), v2: TestValue(sqltypes.Enum, "bar"), out: 1, }, + { + v1: TestValue(sqltypes.Enum, "foo"), + v2: TestValue(sqltypes.VarChar, "bar"), + out: 1, + values: []string{"'foo'", "'bar'"}, + }, + { + v1: TestValue(sqltypes.VarChar, "foo"), + v2: TestValue(sqltypes.Enum, "bar"), + out: 1, + }, + { + v1: TestValue(sqltypes.Set, "bar"), + v2: TestValue(sqltypes.Set, "foo,bar"), + out: -1, + values: []string{"'foo'", "'bar'"}, + }, + { + v1: TestValue(sqltypes.Set, "bar"), + v2: TestValue(sqltypes.Set, "foo,bar"), + out: -1, + }, + { + v1: TestValue(sqltypes.VarChar, "bar"), + v2: TestValue(sqltypes.Set, "foo,bar"), + out: -1, + values: []string{"'foo'", "'bar'"}, + }, + { + v1: TestValue(sqltypes.Set, "bar"), + v2: TestValue(sqltypes.VarChar, "foo,bar"), + out: -1, + }, } for _, tcase := range tcases { t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { - got, err := NullsafeCompare(tcase.v1, tcase.v2, collations.MySQL8(), collation) + got, err := NullsafeCompare(tcase.v1, tcase.v2, collations.MySQL8(), collation, tcase.values) if tcase.err != nil { require.EqualError(t, err, tcase.err.Error()) return } require.NoError(t, err) - if got != tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), got, tcase.out) - } + assert.Equal(t, tcase.out, got) }) } } @@ -1237,7 +1275,7 @@ func TestNullsafeCompareCollate(t *testing.T) { } for _, tcase := range tcases { t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { - got, err := NullsafeCompare(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), collations.MySQL8(), tcase.collation) + got, err := NullsafeCompare(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), collations.MySQL8(), tcase.collation, nil) if tcase.err == nil { require.NoError(t, err) } else { @@ -1288,7 +1326,7 @@ func BenchmarkNullSafeComparison(b *testing.B) { for i := 0; i < b.N; i++ { for _, lhs := range inputs { for _, rhs := range inputs { - _, _ = NullsafeCompare(lhs, rhs, collations.MySQL8(), collid) + _, _ = NullsafeCompare(lhs, rhs, collations.MySQL8(), collid, nil) } } } @@ -1318,7 +1356,7 @@ func BenchmarkNullSafeComparison(b *testing.B) { for i := 0; i < b.N; i++ { for _, lhs := range inputs { for _, rhs := range inputs { - _, _ = NullsafeCompare(lhs, rhs, collations.MySQL8(), collations.CollationUtf8mb4ID) + _, _ = NullsafeCompare(lhs, rhs, collations.MySQL8(), collations.CollationUtf8mb4ID, nil) } } } diff --git a/go/vt/vtgate/evalengine/api_hash.go b/go/vt/vtgate/evalengine/api_hash.go index 2d3bc2d3b56..0ed3e0c4146 100644 --- a/go/vt/vtgate/evalengine/api_hash.go +++ b/go/vt/vtgate/evalengine/api_hash.go @@ -34,8 +34,8 @@ type HashCode = uint64 // NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same // for two values that are considered equal by `NullsafeCompare`. -func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, sqlmode SQLMode) (HashCode, error) { - e, err := valueToEvalCast(v, coerceType, collation, sqlmode) +func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, sqlmode SQLMode, values []string) (HashCode, error) { + e, err := valueToEvalCast(v, coerceType, collation, values, sqlmode) if err != nil { return 0, err } @@ -75,7 +75,7 @@ var ErrHashCoercionIsNotExact = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, " // for two values that are considered equal by `NullsafeCompare`. // This can be used to avoid having to do comparison checks after a hash, // since we consider the 128 bits of entropy enough to guarantee uniqueness. -func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode) error { +func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode, values []string) error { switch { case v.IsNull(), sqltypes.IsNull(coerceTo): hash.Write16(hashPrefixNil) @@ -97,7 +97,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat case v.IsText(), v.IsBinary(): f, _ = fastparse.ParseFloat64(v.RawStr()) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } if err != nil { return err @@ -137,7 +137,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat } neg = i < 0 default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } if err != nil { return err @@ -180,7 +180,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat u, err = uint64(fval), nil } default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } if err != nil { return err @@ -223,20 +223,20 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } hash.Write16(hashPrefixDecimal) dec.Hash(hash) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode, values) } return nil } -func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode) error { +func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode, values []string) error { // Slow path to handle all other types. This uses the generic // logic for value casting to ensure we match MySQL here. - e, err := valueToEvalCast(v, coerceTo, collation, sqlmode) + e, err := valueToEvalCast(v, coerceTo, collation, values, sqlmode) if err != nil { return err } diff --git a/go/vt/vtgate/evalengine/api_hash_test.go b/go/vt/vtgate/evalengine/api_hash_test.go index 7a680892712..bb2652ec6f2 100644 --- a/go/vt/vtgate/evalengine/api_hash_test.go +++ b/go/vt/vtgate/evalengine/api_hash_test.go @@ -52,14 +52,14 @@ func TestHashCodes(t *testing.T) { for _, tc := range cases { t.Run(fmt.Sprintf("%v %s %v", tc.static, equality(tc.equal).Operator(), tc.dynamic), func(t *testing.T) { - cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.MySQL8(), collations.CollationUtf8mb4ID) + cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.MySQL8(), collations.CollationUtf8mb4ID, nil) require.NoError(t, err) require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) - h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0) + h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0, nil) require.NoError(t, err) - h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0) + h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0, nil) require.ErrorIs(t, err, tc.err) assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal)) @@ -77,14 +77,14 @@ func TestHashCodesRandom(t *testing.T) { for time.Now().Before(endTime) { tested++ v1, v2 := sqltypes.TestRandomValues() - cmp, err := NullsafeCompare(v1, v2, collations.MySQL8(), collation) + cmp, err := NullsafeCompare(v1, v2, collations.MySQL8(), collation, nil) require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) typ, err := coerceTo(v1.Type(), v2.Type()) require.NoError(t, err) - hash1, err := NullsafeHashcode(v1, collation, typ, 0) + hash1, err := NullsafeHashcode(v1, collation, typ, 0, nil) require.NoError(t, err) - hash2, err := NullsafeHashcode(v2, collation, typ, 0) + hash2, err := NullsafeHashcode(v2, collation, typ, 0, nil) require.NoError(t, err) if cmp == 0 { equal++ @@ -137,16 +137,16 @@ func TestHashCodes128(t *testing.T) { for _, tc := range cases { t.Run(fmt.Sprintf("%v %s %v", tc.static, equality(tc.equal).Operator(), tc.dynamic), func(t *testing.T) { - cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.MySQL8(), collations.CollationUtf8mb4ID) + cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.MySQL8(), collations.CollationUtf8mb4ID, nil) require.NoError(t, err) require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0) + err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0, nil) require.NoError(t, err) hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0) + err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0, nil) require.ErrorIs(t, err, tc.err) h1 := hasher1.Sum128() @@ -166,16 +166,16 @@ func TestHashCodesRandom128(t *testing.T) { for time.Now().Before(endTime) { tested++ v1, v2 := sqltypes.TestRandomValues() - cmp, err := NullsafeCompare(v1, v2, collations.MySQL8(), collation) + cmp, err := NullsafeCompare(v1, v2, collations.MySQL8(), collation, nil) require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) typ, err := coerceTo(v1.Type(), v2.Type()) require.NoError(t, err) hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, v1, collation, typ, 0) + err = NullsafeHashcode128(&hasher1, v1, collation, typ, 0, nil) require.NoError(t, err) hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, v2, collation, typ, 0) + err = NullsafeHashcode128(&hasher2, v2, collation, typ, 0, nil) require.NoError(t, err) if cmp == 0 { equal++ diff --git a/go/vt/vtgate/evalengine/api_literal.go b/go/vt/vtgate/evalengine/api_literal.go index 64d0cf5c1c3..16897650362 100644 --- a/go/vt/vtgate/evalengine/api_literal.go +++ b/go/vt/vtgate/evalengine/api_literal.go @@ -228,6 +228,7 @@ func NewColumn(offset int, typ Type, original sqlparser.Expr) *Column { Collation: typedCoercionCollation(typ.Type(), typ.Collation()), Original: original, Nullable: typ.nullable, + Values: typ.values, dynamicTypeOffset: -1, } } diff --git a/go/vt/vtgate/evalengine/api_type_aggregation.go b/go/vt/vtgate/evalengine/api_type_aggregation.go index 326f1397369..04622e5a212 100644 --- a/go/vt/vtgate/evalengine/api_type_aggregation.go +++ b/go/vt/vtgate/evalengine/api_type_aggregation.go @@ -80,7 +80,7 @@ func (ta *TypeAggregator) Type() Type { if ta.invalid > 0 || ta.types.empty() { return Type{} } - return NewTypeEx(ta.types.result(), ta.collations.result().Collation, ta.types.nullable, ta.size, ta.scale) + return NewTypeEx(ta.types.result(), ta.collations.result().Collation, ta.types.nullable, ta.size, ta.scale, nil) } func (ta *TypeAggregator) Field(name string) *query.Field { diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go index 590dc3b02c7..0b01a485dc3 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -17,6 +17,8 @@ limitations under the License. package evalengine import ( + "slices" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/mysql/decimal" @@ -32,6 +34,8 @@ type Arena struct { aFloat64 []evalFloat aDecimal []evalDecimal aBytes []evalBytes + aEnum []evalEnum + aSet []evalSet } func (a *Arena) reset() { @@ -40,6 +44,8 @@ func (a *Arena) reset() { a.aFloat64 = a.aFloat64[:0] a.aDecimal = a.aDecimal[:0] a.aBytes = a.aBytes[:0] + a.aEnum = a.aEnum[:0] + a.aSet = a.aSet[:0] } func (a *Arena) newEvalDecimalWithPrec(dec decimal.Decimal, prec int32) *evalDecimal { @@ -61,6 +67,32 @@ func (a *Arena) newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal { return a.newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) } +func (a *Arena) newEvalEnum(raw []byte, values []string) *evalEnum { + if cap(a.aEnum) > len(a.aEnum) { + a.aEnum = a.aEnum[:len(a.aEnum)+1] + } else { + a.aEnum = append(a.aEnum, evalEnum{}) + } + val := &a.aEnum[len(a.aInt64)-1] + s := string(raw) + val.string = s + val.value = slices.Index(values, s) + return val +} + +func (a *Arena) newEvalSet(raw []byte, values []string) *evalSet { + if cap(a.aSet) > len(a.aSet) { + a.aSet = a.aSet[:len(a.aSet)+1] + } else { + a.aSet = append(a.aSet, evalSet{}) + } + val := &a.aSet[len(a.aInt64)-1] + s := string(raw) + val.string = s + val.set = evalSetBits(values, s) + return val +} + func (a *Arena) newEvalBool(b bool) *evalInt64 { if b { return a.newEvalInt64(1) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index e7563e8f258..abe7bdc473f 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -159,12 +159,19 @@ func (cached *Column) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(64) + size += int64(80) } // field Original vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.Original.(cachedObject); ok { size += cc.CachedSize(true) } + // field Values []string + { + size += hack.RuntimeAllocSize(int64(cap(cached.Values)) * int64(16)) + for _, elem := range cached.Values { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } return size } func (cached *ComparisonExpr) CachedSize(alloc bool) int64 { @@ -189,12 +196,14 @@ func (cached *CompiledExpr) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(80) + size += int64(96) } // field code []vitess.io/vitess/go/vt/vtgate/evalengine.frame { size += hack.RuntimeAllocSize(int64(cap(cached.code)) * int64(8)) } + // field typed vitess.io/vitess/go/vt/vtgate/evalengine.ctype + size += cached.typed.CachedSize(false) // field ir vitess.io/vitess/go/vt/vtgate/evalengine.IR if cc, ok := cached.ir.(cachedObject); ok { size += cc.CachedSize(true) @@ -361,8 +370,10 @@ func (cached *OrderByParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(80) } + // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type + size += cached.Type.CachedSize(false) // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) return size @@ -379,6 +390,23 @@ func (cached *TupleBindVariable) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.Key))) return size } +func (cached *Type) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field values []string + { + size += hack.RuntimeAllocSize(int64(cap(cached.values)) * int64(16)) + for _, elem := range cached.values { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } + return size +} func (cached *UnaryExpr) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -1899,6 +1927,23 @@ func (cached *builtinYearWeek) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *ctype) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field Values []string + { + size += hack.RuntimeAllocSize(int64(cap(cached.Values)) * int64(16)) + for _, elem := range cached.Values { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } + return size +} func (cached *evalBytes) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -1925,6 +1970,18 @@ func (cached *evalDecimal) CachedSize(alloc bool) int64 { size += cached.dec.CachedSize(false) return size } +func (cached *evalEnum) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field string string + size += hack.RuntimeAllocSize(int64(len(cached.string))) + return size +} func (cached *evalFloat) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -1945,6 +2002,18 @@ func (cached *evalInt64) CachedSize(alloc bool) int64 { } return size } +func (cached *evalSet) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field string string + size += hack.RuntimeAllocSize(int64(len(cached.string))) + return size +} func (cached *evalTemporal) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -1994,7 +2063,10 @@ func (cached *typedExpr) CachedSize(alloc bool) int64 { } // field types []vitess.io/vitess/go/vt/vtgate/evalengine.ctype { - size += hack.RuntimeAllocSize(int64(cap(cached.types)) * int64(20)) + size += hack.RuntimeAllocSize(int64(cap(cached.types)) * int64(48)) + for _, elem := range cached.types { + size += elem.CachedSize(false) + } } // field compiled *vitess.io/vitess/go/vt/vtgate/evalengine.CompiledExpr size += cached.compiled.CachedSize(true) diff --git a/go/vt/vtgate/evalengine/compare.go b/go/vt/vtgate/evalengine/compare.go index 102d6142321..836ca7c5043 100644 --- a/go/vt/vtgate/evalengine/compare.go +++ b/go/vt/vtgate/evalengine/compare.go @@ -18,6 +18,7 @@ package evalengine import ( "bytes" + "strings" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/colldata" @@ -122,6 +123,42 @@ func compareDates(l, r *evalTemporal) int { return l.dt.Compare(r.dt) } +func compareEnums(l, r *evalEnum) int { + if l.value == -1 || r.value == -1 { + // If the values are equal normally the strings + // are equal too. In case we didn't find the proper + // value in the enum we return the string comparison. + // This is not always correct, but a best effort and still + // works for the cases where we only care about + // equality. + return strings.Compare(l.string, r.string) + } + if l.value == r.value { + return 0 + } + if l.value < r.value { + return -1 + } + return 1 +} + +func compareSets(l, r *evalSet) int { + if l.set == r.set { + if l.set == 0 && (len(l.string) != 0 || len(r.string) != 0) { + // In this case we didn't have the proper values passed + // in when creating the evalSet. We can't compare the set + // values then, but fall back to string comparison to at + // least compare something and to handle equality checks. + return strings.Compare(l.string, r.string) + } + return 0 + } + if l.set < r.set { + return -1 + } + return 1 +} + func compareDateAndString(l, r eval) int { if tt, ok := l.(*evalTemporal); ok { return tt.dt.Compare(r.(*evalBytes).toDateBestEffort()) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 21d13119804..344798f6abb 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -17,6 +17,8 @@ limitations under the License. package evalengine import ( + "slices" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/mysql/collations/colldata" @@ -55,6 +57,7 @@ type ctype struct { Flag typeFlag Size, Scale int32 Col collations.TypedCollation + Values []string } type Type struct { @@ -63,14 +66,15 @@ type Type struct { nullable bool init bool size, scale int32 + values []string } func NewType(t sqltypes.Type, collation collations.ID) Type { // New types default to being nullable - return NewTypeEx(t, collation, true, 0, 0) + return NewTypeEx(t, collation, true, 0, 0, nil) } -func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, scale int32) Type { +func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, scale int32, values []string) Type { return Type{ typ: t, collation: collation, @@ -78,6 +82,7 @@ func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, sc init: true, size: size, scale: scale, + values: values, } } @@ -139,10 +144,32 @@ func (t *Type) Nullable() bool { return true // nullable by default for unknown types } +func (t *Type) Values() []string { + return t.values +} + func (t *Type) Valid() bool { return t.init } +func (t *Type) Equal(other *Type) bool { + return t.typ == other.typ && + t.collation == other.collation && + t.nullable == other.nullable && + t.size == other.size && + t.scale == other.scale && + slices.Equal(t.values, other.values) +} + +func (ct ctype) equal(other ctype) bool { + return ct.Type == other.Type && + ct.Flag == other.Flag && + ct.Size == other.Size && + ct.Scale == other.Scale && + ct.Col == other.Col && + slices.Equal(ct.Values, other.Values) +} + func (ct ctype) nullable() bool { return ct.Flag&flagNullable != 0 } diff --git a/go/vt/vtgate/evalengine/compiler_asm_push.go b/go/vt/vtgate/evalengine/compiler_asm_push.go index ab1371f1e11..ff8adb168ff 100644 --- a/go/vt/vtgate/evalengine/compiler_asm_push.go +++ b/go/vt/vtgate/evalengine/compiler_asm_push.go @@ -105,6 +105,18 @@ func push_d(env *ExpressionEnv, raw []byte) int { return 1 } +func push_enum(env *ExpressionEnv, raw []byte, values []string) int { + env.vm.stack[env.vm.sp] = env.vm.arena.newEvalEnum(raw, values) + env.vm.sp++ + return 1 +} + +func push_set(env *ExpressionEnv, raw []byte, values []string) int { + env.vm.stack[env.vm.sp] = env.vm.arena.newEvalSet(raw, values) + env.vm.sp++ + return 1 +} + func (asm *assembler) PushColumn_d(offset int) { asm.adjustStack(1) @@ -117,6 +129,30 @@ func (asm *assembler) PushColumn_d(offset int) { }, "PUSH DECIMAL(:%d)", offset) } +func (asm *assembler) PushColumn_enum(offset int, values []string) { + asm.adjustStack(1) + + asm.emit(func(env *ExpressionEnv) int { + col := env.Row[offset] + if col.IsNull() { + return push_null(env) + } + return push_enum(env, col.Raw(), values) + }, "PUSH ENUM(:%d)", offset) +} + +func (asm *assembler) PushColumn_set(offset int, values []string) { + asm.adjustStack(1) + + asm.emit(func(env *ExpressionEnv) int { + col := env.Row[offset] + if col.IsNull() { + return push_null(env) + } + return push_set(env, col.Raw(), values) + }, "PUSH SET(:%d)", offset) +} + func (asm *assembler) PushBVar_d(key string) { asm.adjustStack(1) diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 36ce482d967..eeefa351894 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -212,7 +212,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all } } -func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID, sqlmode SQLMode) (eval, error) { +func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID, values []string, sqlmode SQLMode) (eval, error) { switch { case typ == sqltypes.Null: return nil, nil @@ -232,7 +232,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I fval, _ := fastparse.ParseFloat64(v.RawStr()) return newEvalFloat(fval), nil default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -259,7 +259,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -279,7 +279,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I i, err := fastparse.ParseInt64(v.RawStr(), 10) return newEvalInt64(i), err default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -298,7 +298,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I u, err := fastparse.ParseUint64(v.RawStr(), 10) return newEvalUint64(u), err default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -311,13 +311,13 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I case v.IsText() || v.IsBinary(): return newEvalRaw(v.Type(), v.Raw(), typedCoercionCollation(v.Type(), collation)), nil case sqltypes.IsText(typ): - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } return evalToVarchar(e, collation, true) default: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -327,7 +327,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I case typ == sqltypes.TypeJSON: return json.NewFromSQL(v) case typ == sqltypes.Date: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -338,7 +338,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I } return d, nil case typ == sqltypes.Datetime || typ == sqltypes.Timestamp: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -349,7 +349,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I } return dt, nil case typ == sqltypes.Time: - e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation), values) if err != nil { return nil, err } @@ -359,11 +359,15 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, nil } return t, nil + case typ == sqltypes.Enum: + return newEvalEnum(v.Raw(), values), nil + case typ == sqltypes.Set: + return newEvalSet(v.Raw(), values), nil } return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value: %v", v) } -func valueToEval(value sqltypes.Value, collation collations.TypedCollation) (eval, error) { +func valueToEval(value sqltypes.Value, collation collations.TypedCollation, values []string) (eval, error) { wrap := func(err error) error { if err == nil { return nil @@ -384,6 +388,10 @@ func valueToEval(value sqltypes.Value, collation collations.TypedCollation) (eva case tt == sqltypes.Decimal: dec, err := decimal.NewFromMySQL(value.Raw()) return newEvalDecimal(dec, 0, 0), wrap(err) + case tt == sqltypes.Enum: + return newEvalEnum(value.Raw(), values), nil + case tt == sqltypes.Set: + return newEvalSet(value.Raw(), values), nil case sqltypes.IsText(tt): if tt == sqltypes.HexNum { raw, err := parseHexNumber(value.Raw()) diff --git a/go/vt/vtgate/evalengine/eval_enum.go b/go/vt/vtgate/evalengine/eval_enum.go new file mode 100644 index 00000000000..b89abd15429 --- /dev/null +++ b/go/vt/vtgate/evalengine/eval_enum.go @@ -0,0 +1,37 @@ +package evalengine + +import ( + "vitess.io/vitess/go/hack" + "vitess.io/vitess/go/sqltypes" +) + +type evalEnum struct { + value int + string string +} + +func newEvalEnum(val []byte, values []string) *evalEnum { + s := string(val) + return &evalEnum{ + value: valueIdx(values, s), + string: s, + } +} + +func (e *evalEnum) ToRawBytes() []byte { + return hack.StringBytes(e.string) +} + +func (e *evalEnum) SQLType() sqltypes.Type { + return sqltypes.Enum +} + +func valueIdx(values []string, value string) int { + for i, v := range values { + v, _ = sqltypes.DecodeStringSQL(v) + if v == value { + return i + } + } + return -1 +} diff --git a/go/vt/vtgate/evalengine/eval_numeric.go b/go/vt/vtgate/evalengine/eval_numeric.go index fb34caab85d..64f5477a3fc 100644 --- a/go/vt/vtgate/evalengine/eval_numeric.go +++ b/go/vt/vtgate/evalengine/eval_numeric.go @@ -149,6 +149,10 @@ func evalToNumeric(e eval, preciseDatetime bool) evalNumeric { return newEvalDecimalWithPrec(e.toDecimal(), int32(e.prec)) } return &evalFloat{f: e.toFloat()} + case *evalEnum: + return &evalFloat{f: float64(e.value)} + case *evalSet: + return &evalFloat{f: float64(e.set)} default: panic("unsupported") } @@ -205,6 +209,10 @@ func evalToFloat(e eval) (*evalFloat, bool) { } case *evalTemporal: return &evalFloat{f: e.toFloat()}, true + case *evalEnum: + return &evalFloat{f: float64(e.value)}, e.value != -1 + case *evalSet: + return &evalFloat{f: float64(e.set)}, true default: panic(fmt.Sprintf("unsupported type %T", e)) } @@ -269,6 +277,10 @@ func evalToDecimal(e eval, m, d int32) *evalDecimal { } case *evalTemporal: return newEvalDecimal(e.toDecimal(), m, d) + case *evalEnum: + return newEvalDecimal(decimal.NewFromInt(int64(e.value)), m, d) + case *evalSet: + return newEvalDecimal(decimal.NewFromUint(e.set), m, d) default: panic("unsupported") } @@ -332,6 +344,10 @@ func evalToInt64(e eval) *evalInt64 { } case *evalTemporal: return newEvalInt64(e.toInt64()) + case *evalEnum: + return newEvalInt64(int64(e.value)) + case *evalSet: + return newEvalInt64(int64(e.set)) default: panic(fmt.Sprintf("unsupported type: %T", e)) } diff --git a/go/vt/vtgate/evalengine/eval_set.go b/go/vt/vtgate/evalengine/eval_set.go new file mode 100644 index 00000000000..47fe29607df --- /dev/null +++ b/go/vt/vtgate/evalengine/eval_set.go @@ -0,0 +1,49 @@ +package evalengine + +import ( + "strings" + + "vitess.io/vitess/go/hack" + "vitess.io/vitess/go/sqltypes" +) + +type evalSet struct { + set uint64 + string string +} + +func newEvalSet(val []byte, values []string) *evalSet { + value := string(val) + + return &evalSet{ + set: evalSetBits(values, value), + string: value, + } +} + +func (e *evalSet) ToRawBytes() []byte { + return hack.StringBytes(e.string) +} + +func (e *evalSet) SQLType() sqltypes.Type { + return sqltypes.Set +} + +func evalSetBits(values []string, value string) uint64 { + if len(values) > 64 { + // This never would happen as MySQL limits SET + // to 64 elements. Safeguard here just in case though. + panic("too many values for set") + } + + set := uint64(0) + for _, val := range strings.Split(value, ",") { + idx := valueIdx(values, val) + if idx == -1 { + continue + } + set |= 1 << idx + } + + return set +} diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index b21ded90189..0fffe3140a2 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -70,7 +70,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { tuple := make([]eval, 0, len(bvar.Values)) for _, value := range bvar.Values { - e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), typedCoercionCollation(value.Type, collations.CollationForType(value.Type, bv.Collation))) + e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), typedCoercionCollation(value.Type, collations.CollationForType(value.Type, bv.Collation)), nil) if err != nil { return nil, err } @@ -86,7 +86,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { if bv.typed() { typ = bv.Type } - return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation))) + return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)), nil) } } diff --git a/go/vt/vtgate/evalengine/expr_column.go b/go/vt/vtgate/evalengine/expr_column.go index 8663370f819..cbdb1775f88 100644 --- a/go/vt/vtgate/evalengine/expr_column.go +++ b/go/vt/vtgate/evalengine/expr_column.go @@ -34,6 +34,7 @@ type ( Collation collations.TypedCollation Original sqlparser.Expr Nullable bool + Values []string // For ENUM and SET types // dynamicTypeOffset is set when the type of this column cannot be calculated // at translation time. Since expressions with dynamic types cannot be compiled ahead of time, @@ -54,7 +55,7 @@ func (c *Column) IsExpr() {} // eval implements the expression interface func (c *Column) eval(env *ExpressionEnv) (eval, error) { - return valueToEval(env.Row[c.Offset], c.Collation) + return valueToEval(env.Row[c.Offset], c.Collation, c.Values) } func (c *Column) typeof(env *ExpressionEnv) (ctype, error) { @@ -63,7 +64,7 @@ func (c *Column) typeof(env *ExpressionEnv) (ctype, error) { if c.Nullable { nullable = flagNullable } - return ctype{Type: c.Type, Size: c.Size, Scale: c.Scale, Flag: nullable, Col: c.Collation}, nil + return ctype{Type: c.Type, Size: c.Size, Scale: c.Scale, Flag: nullable, Col: c.Collation, Values: c.Values}, nil } if c.Offset < len(env.Fields) { field := env.Fields[c.Offset] @@ -83,7 +84,7 @@ func (c *Column) typeof(env *ExpressionEnv) (ctype, error) { } if c.Offset < len(env.Row) { value := env.Row[c.Offset] - return ctype{Type: value.Type(), Flag: 0, Col: c.Collation}, nil + return ctype{Type: value.Type(), Flag: 0, Col: c.Collation, Values: c.Values}, nil } return ctype{}, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "no column at offset %d", c.Offset) } @@ -99,6 +100,7 @@ func (column *Column) compile(c *compiler) (ctype, error) { } typ.Size = column.Size typ.Scale = column.Scale + typ.Values = column.Values } else if c.dynamicTypes != nil { typ = c.dynamicTypes[column.dynamicTypeOffset] } else { @@ -121,6 +123,10 @@ func (column *Column) compile(c *compiler) (ctype, error) { typ.Type = sqltypes.Float64 case sqltypes.IsDecimal(tt): c.asm.PushColumn_d(column.Offset) + case tt == sqltypes.Enum: + c.asm.PushColumn_enum(column.Offset, column.Values) + case tt == sqltypes.Set: + c.asm.PushColumn_set(column.Offset, column.Values) case sqltypes.IsText(tt): if tt == sqltypes.HexNum { c.asm.PushColumn_hexnum(column.Offset) diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index ca4cdd75f74..f3bd44588ee 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -114,7 +114,7 @@ func (compareNullSafeEQ) compare(collationEnv *collations.Environment, left, rig } func typeIsTextual(tt sqltypes.Type) bool { - return sqltypes.IsTextOrBinary(tt) || tt == sqltypes.Time + return sqltypes.IsTextOrBinary(tt) || tt == sqltypes.Time || tt == sqltypes.Enum || tt == sqltypes.Set } func compareAsStrings(l, r sqltypes.Type) bool { @@ -143,6 +143,14 @@ func compareAsDates(l, r sqltypes.Type) bool { return sqltypes.IsDateOrTime(l) && sqltypes.IsDateOrTime(r) } +func compareAsEnums(l, r sqltypes.Type) bool { + return sqltypes.IsEnum(l) && sqltypes.IsEnum(r) +} + +func compareAsSets(l, r sqltypes.Type) bool { + return sqltypes.IsSet(l) && sqltypes.IsSet(r) +} + func compareAsDateAndString(l, r sqltypes.Type) bool { return (sqltypes.IsDate(l) && typeIsTextual(r)) || (typeIsTextual(l) && sqltypes.IsDate(r)) } @@ -223,6 +231,10 @@ func evalCompare(left, right eval, collationEnv *collations.Environment) (comp i switch { case compareAsDates(lt, rt): return compareDates(left.(*evalTemporal), right.(*evalTemporal)), nil + case compareAsEnums(lt, rt): + return compareEnums(left.(*evalEnum), right.(*evalEnum)), nil + case compareAsSets(lt, rt): + return compareSets(left.(*evalSet), right.(*evalSet)), nil case compareAsStrings(lt, rt): return compareStrings(left, right, collationEnv) case compareAsSameNumericType(lt, rt) || compareAsDecimal(lt, rt): diff --git a/go/vt/vtgate/evalengine/expr_env.go b/go/vt/vtgate/evalengine/expr_env.go index 6e09b03cffb..38a65f9b4e0 100644 --- a/go/vt/vtgate/evalengine/expr_env.go +++ b/go/vt/vtgate/evalengine/expr_env.go @@ -104,7 +104,7 @@ func (env *ExpressionEnv) TypeOf(expr Expr) (Type, error) { if err != nil { return Type{}, err } - return NewTypeEx(ty.Type, ty.Col.Collation, ty.Flag&flagNullable != 0, ty.Size, ty.Scale), nil + return NewTypeEx(ty.Type, ty.Col.Collation, ty.Flag&flagNullable != 0, ty.Size, ty.Scale, ty.Values), nil } func (env *ExpressionEnv) SetTime(now time.Time) { diff --git a/go/vt/vtgate/evalengine/expr_tuple_bvar.go b/go/vt/vtgate/evalengine/expr_tuple_bvar.go index 3b2553f25ba..14cfbd95a8b 100644 --- a/go/vt/vtgate/evalengine/expr_tuple_bvar.go +++ b/go/vt/vtgate/evalengine/expr_tuple_bvar.go @@ -71,7 +71,7 @@ func (bv *TupleBindVariable) eval(env *ExpressionEnv) (eval, error) { return } found = true - e, err := valueToEval(val, typedCoercionCollation(val.Type(), collations.CollationForType(val.Type(), bv.Collation))) + e, err := valueToEval(val, typedCoercionCollation(val.Type(), collations.CollationForType(val.Type(), bv.Collation)), nil) if err != nil { evalErr = err return diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index d1c32b113c2..99ffd956513 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -686,7 +686,9 @@ func (u *UntypedExpr) loadTypedExpression(env *ExpressionEnv) (*typedExpr, error defer u.mu.Unlock() for _, typed := range u.typed { - if slices.Equal(typed.types, dynamicTypes) { + if slices.EqualFunc(typed.types, dynamicTypes, func(a, b ctype) bool { + return a.equal(b) + }) { return typed, nil } } diff --git a/go/vt/vtgate/evalengine/weights.go b/go/vt/vtgate/evalengine/weights.go index 2a9d6c9f93e..37286af7ddc 100644 --- a/go/vt/vtgate/evalengine/weights.go +++ b/go/vt/vtgate/evalengine/weights.go @@ -41,11 +41,11 @@ import ( // externally communicates with the `WEIGHT_STRING` function, so that we // can also use this to order / sort other types like Float and Decimal // as well. -func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, sqlmode SQLMode) ([]byte, bool, error) { +func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, values []string, sqlmode SQLMode) ([]byte, bool, error) { // We optimize here for the case where we already have the desired type. // Otherwise, we fall back to the general evalengine conversion logic. if v.Type() != coerceTo { - return fallbackWeightString(dst, v, coerceTo, col, length, precision, sqlmode) + return fallbackWeightString(dst, v, coerceTo, col, length, precision, values, sqlmode) } switch { @@ -116,13 +116,17 @@ func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col coll return dst, false, err } return j.WeightString(dst), false, nil + case coerceTo == sqltypes.Enum: + return evalWeightString(dst, newEvalEnum(v.Raw(), values), length, precision) + case coerceTo == sqltypes.Set: + return evalWeightString(dst, newEvalSet(v.Raw(), values), length, precision) default: - return fallbackWeightString(dst, v, coerceTo, col, length, precision, sqlmode) + return fallbackWeightString(dst, v, coerceTo, col, length, precision, values, sqlmode) } } -func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, sqlmode SQLMode) ([]byte, bool, error) { - e, err := valueToEvalCast(v, coerceTo, col, sqlmode) +func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, values []string, sqlmode SQLMode) ([]byte, bool, error) { + e, err := valueToEvalCast(v, coerceTo, col, values, sqlmode) if err != nil { return dst, false, err } @@ -174,6 +178,14 @@ func evalWeightString(dst []byte, e eval, length, precision int) ([]byte, bool, return e.dt.WeightString(dst), true, nil case *evalJSON: return e.WeightString(dst), false, nil + case *evalEnum: + raw := uint64(e.value) + raw = raw ^ (1 << 63) + return binary.BigEndian.AppendUint64(dst, raw), true, nil + case *evalSet: + raw := e.set + raw = raw ^ (1 << 63) + return binary.BigEndian.AppendUint64(dst, raw), true, nil } return dst, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", e.SQLType()) @@ -192,7 +204,7 @@ func TinyWeighter(f *querypb.Field, collation collations.ID) func(v *sqltypes.Va case sqltypes.IsNull(f.Type): return nil - case sqltypes.IsSigned(f.Type): + case sqltypes.IsSigned(f.Type), f.Type == sqltypes.Enum, f.Type == sqltypes.Set: return func(v *sqltypes.Value) { i, err := v.ToInt64() if err != nil { @@ -301,7 +313,6 @@ func TinyWeighter(f *querypb.Field, collation collations.ID) func(v *sqltypes.Va copy(w32[:4], j.WeightString(nil)) v.SetTinyWeight(binary.BigEndian.Uint32(w32[:4])) } - default: return nil } diff --git a/go/vt/vtgate/evalengine/weights_test.go b/go/vt/vtgate/evalengine/weights_test.go index 9a34e6e9e81..b059142163a 100644 --- a/go/vt/vtgate/evalengine/weights_test.go +++ b/go/vt/vtgate/evalengine/weights_test.go @@ -32,11 +32,12 @@ func TestTinyWeightStrings(t *testing.T) { const Length = 10000 var cases = []struct { - typ sqltypes.Type - gen func() sqltypes.Value - col collations.ID - len int - prec int + typ sqltypes.Type + gen func() sqltypes.Value + col collations.ID + len int + prec int + values []string }{ {typ: sqltypes.Int32, gen: sqltypes.RandomGenerators[sqltypes.Int32], col: collations.CollationBinaryID}, {typ: sqltypes.Int64, gen: sqltypes.RandomGenerators[sqltypes.Int64], col: collations.CollationBinaryID}, @@ -47,6 +48,8 @@ func TestTinyWeightStrings(t *testing.T) { {typ: sqltypes.VarBinary, gen: sqltypes.RandomGenerators[sqltypes.VarBinary], col: collations.CollationBinaryID}, {typ: sqltypes.Decimal, gen: sqltypes.RandomGenerators[sqltypes.Decimal], col: collations.CollationBinaryID, len: 20, prec: 10}, {typ: sqltypes.TypeJSON, gen: sqltypes.RandomGenerators[sqltypes.TypeJSON], col: collations.CollationBinaryID}, + {typ: sqltypes.Enum, gen: sqltypes.RandomGenerators[sqltypes.Enum], col: collations.CollationBinaryID, values: []string{"'xxsmall'", "'xsmall'", "'small'", "'medium'", "'large'", "'xlarge'", "'xxlarge'"}}, + {typ: sqltypes.Set, gen: sqltypes.RandomGenerators[sqltypes.Set], col: collations.CollationBinaryID, values: []string{"'a'", "'b'", "'c'", "'d'", "'e'", "'f'", "'g'"}}, } for _, tc := range cases { @@ -77,7 +80,7 @@ func TestTinyWeightStrings(t *testing.T) { return cmp } - cmp, err := NullsafeCompare(a, b, collations.MySQL8(), tc.col) + cmp, err := NullsafeCompare(a, b, collations.MySQL8(), tc.col, tc.values) require.NoError(t, err) fullComparisons++ @@ -88,7 +91,7 @@ func TestTinyWeightStrings(t *testing.T) { a := items[i] b := items[i+1] - cmp, err := NullsafeCompare(a, b, collations.MySQL8(), tc.col) + cmp, err := NullsafeCompare(a, b, collations.MySQL8(), tc.col, tc.values) require.NoError(t, err) if cmp > 0 { @@ -110,12 +113,13 @@ func TestWeightStrings(t *testing.T) { } var cases = []struct { - name string - gen func() sqltypes.Value - types []sqltypes.Type - col collations.ID - len int - prec int + name string + gen func() sqltypes.Value + types []sqltypes.Type + col collations.ID + len int + prec int + values []string }{ {name: "int64", gen: sqltypes.RandomGenerators[sqltypes.Int64], types: []sqltypes.Type{sqltypes.Int64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "uint64", gen: sqltypes.RandomGenerators[sqltypes.Uint64], types: []sqltypes.Type{sqltypes.Uint64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, @@ -128,6 +132,8 @@ func TestWeightStrings(t *testing.T) { {name: "datetime", gen: sqltypes.RandomGenerators[sqltypes.Datetime], types: []sqltypes.Type{sqltypes.Datetime, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "timestamp", gen: sqltypes.RandomGenerators[sqltypes.Timestamp], types: []sqltypes.Type{sqltypes.Timestamp, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "time", gen: sqltypes.RandomGenerators[sqltypes.Time], types: []sqltypes.Type{sqltypes.Time, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "enum", gen: sqltypes.RandomGenerators[sqltypes.Enum], types: []sqltypes.Type{sqltypes.Enum, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, values: []string{"'xxsmall'", "'xsmall'", "'small'", "'medium'", "'large'", "'xlarge'", "'xxlarge'"}}, + {name: "set", gen: sqltypes.RandomGenerators[sqltypes.Set], types: []sqltypes.Type{sqltypes.Set, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, values: []string{"'a'", "'b'", "'c'", "'d'", "'e'", "'f'", "'g'"}}, } for _, tc := range cases { @@ -136,7 +142,7 @@ func TestWeightStrings(t *testing.T) { items := make([]item, 0, Length) for i := 0; i < Length; i++ { v := tc.gen() - w, _, err := WeightString(nil, v, typ, tc.col, tc.len, tc.prec, 0) + w, _, err := WeightString(nil, v, typ, tc.col, tc.len, tc.prec, tc.values, 0) require.NoError(t, err) items = append(items, item{value: v, weight: string(w)}) @@ -156,9 +162,9 @@ func TestWeightStrings(t *testing.T) { a := items[i] b := items[i+1] - v1, err := valueToEvalCast(a.value, typ, tc.col, 0) + v1, err := valueToEvalCast(a.value, typ, tc.col, tc.values, 0) require.NoError(t, err) - v2, err := valueToEvalCast(b.value, typ, tc.col, 0) + v2, err := valueToEvalCast(b.value, typ, tc.col, tc.values, 0) require.NoError(t, err) cmp, err := evalCompareNullSafe(v1, v2, collations.MySQL8()) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 572afa42f72..2a7f37a258f 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -918,6 +918,7 @@ func transformHashJoin(ctx *plancontext.PlanningContext, op *operators.HashJoin) Collation: comparisonType.Collation(), ComparisonType: comparisonType.Type(), CollationEnv: ctx.VSchema.Environment().CollationEnv(), + Values: comparisonType.Values(), }, }, nil } diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 6c6e495b33d..6c89b2bb999 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -667,7 +667,7 @@ func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { ws, isWS := e.(*sqlparser.WeightStringFuncExpr) if isWS { wt, _ := st.TypeForExpr(ws.Expr) - return evalengine.NewTypeEx(sqltypes.VarBinary, collations.CollationBinaryID, wt.Nullable(), 0, 0), true + return evalengine.NewTypeEx(sqltypes.VarBinary, collations.CollationBinaryID, wt.Nullable(), 0, 0, nil), true } return evalengine.Type{}, false diff --git a/go/vt/vtgate/vindexes/consistent_lookup.go b/go/vt/vtgate/vindexes/consistent_lookup.go index f32adc0f772..d231f358a37 100644 --- a/go/vt/vtgate/vindexes/consistent_lookup.go +++ b/go/vt/vtgate/vindexes/consistent_lookup.go @@ -412,7 +412,7 @@ func (lu *clCommon) Delete(ctx context.Context, vcursor VCursor, rowsColValues [ func (lu *clCommon) Update(ctx context.Context, vcursor VCursor, oldValues []sqltypes.Value, ksid []byte, newValues []sqltypes.Value) error { equal := true for i := range oldValues { - result, err := evalengine.NullsafeCompare(oldValues[i], newValues[i], vcursor.Environment().CollationEnv(), vcursor.ConnCollation()) + result, err := evalengine.NullsafeCompare(oldValues[i], newValues[i], vcursor.Environment().CollationEnv(), vcursor.ConnCollation(), nil) // errors from NullsafeCompare can be ignored. if they are real problems, we'll see them in the Create/Update if err != nil || result != 0 { equal = false diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index 8dc889fc848..6506cdee09c 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -233,7 +233,7 @@ func (col *Column) ToEvalengineType(collationEnv *collations.Environment) evalen } else { collation = collations.CollationForType(col.Type, collationEnv.DefaultConnectionCharset()) } - return evalengine.NewTypeEx(col.Type, collation, col.Nullable, col.Size, col.Scale) + return evalengine.NewTypeEx(col.Type, collation, col.Nullable, col.Size, col.Scale, col.Values) } // KeyspaceSchema contains the schema(table) for a keyspace. diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go index 1b64662e551..20e795d2804 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_differ.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_differ.go @@ -701,7 +701,7 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []com if collationID == collations.Unknown { collationID = collations.CollationBinaryID } - c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.wd.collationEnv, collationID) + c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.wd.collationEnv, collationID, nil) if err != nil { return 0, err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go b/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go index 424daad4871..d4b733b4c0b 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go +++ b/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go @@ -303,7 +303,7 @@ func (tp *TablePlan) isOutsidePKRange(bindvars map[string]*querypb.BindVariable, rowVal, _ := sqltypes.BindVariableToValue(bindvar) // TODO(king-11) make collation aware - result, err := evalengine.NullsafeCompare(rowVal, tp.Lastpk.Rows[0][0], tp.CollationEnv, collations.Unknown) + result, err := evalengine.NullsafeCompare(rowVal, tp.Lastpk.Rows[0][0], tp.CollationEnv, collations.Unknown, nil) // If rowVal is > last pk, transaction will be a noop, so don't apply this statement if err == nil && result > 0 { tp.Stats.NoopQueryCount.Add(stmtType, 1) diff --git a/go/vt/vttablet/tabletserver/schema/load_table.go b/go/vt/vttablet/tabletserver/schema/load_table.go index e4e464f3fce..6022f8724eb 100644 --- a/go/vt/vttablet/tabletserver/schema/load_table.go +++ b/go/vt/vttablet/tabletserver/schema/load_table.go @@ -215,7 +215,7 @@ func getSpecifiedMessageFields(tableFields []*querypb.Field, specifiedCols []str fields := make([]*querypb.Field, 0, len(specifiedCols)) for _, col := range specifiedCols { for _, field := range tableFields { - if res, _ := evalengine.NullsafeCompare(sqltypes.NewVarChar(field.Name), sqltypes.NewVarChar(strings.TrimSpace(col)), collationEnv, collationEnv.DefaultConnectionCharset()); res == 0 { + if res, _ := evalengine.NullsafeCompare(sqltypes.NewVarChar(field.Name), sqltypes.NewVarChar(strings.TrimSpace(col)), collationEnv, collationEnv.DefaultConnectionCharset(), nil); res == 0 { fields = append(fields, field) break } diff --git a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go index c3e1975c0a1..ad2f218f8d1 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go +++ b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go @@ -172,7 +172,7 @@ func compare(comparison Opcode, columnValue, filterValue sqltypes.Value, collati } // at this point neither values can be null // NullsafeCompare returns 0 if values match, -1 if columnValue < filterValue, 1 if columnValue > filterValue - result, err := evalengine.NullsafeCompare(columnValue, filterValue, collationEnv, charset) + result, err := evalengine.NullsafeCompare(columnValue, filterValue, collationEnv, charset, nil) if err != nil { return false, err } diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 2196152b122..a698bad290a 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -118,6 +118,7 @@ type vdiff struct { type compareColInfo struct { colIndex int // index of the column in the filter's select collation collations.ID // is the collation of the column, if any + values []string // is the list of enum or set values for the column, if any isPK bool // is this column part of the primary key } @@ -492,7 +493,7 @@ func (df *vdiff) buildVDiffPlan(filter *binlogdatapb.Filter, schm *tabletmanager // findPKs identifies PKs, determines any collations to be used for // them, and removes them from the columns used for data comparison. func findPKs(env *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition, targetSelect *sqlparser.Select, td *tableDiffer) (sqlparser.OrderBy, error) { - columnCollations, err := getColumnCollations(env, table) + columnCollations, columnValues, err := getColumnCollations(env, table) if err != nil { return nil, err } @@ -513,6 +514,7 @@ func findPKs(env *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition, if strings.EqualFold(pk, colname) { td.compareCols[i].isPK = true td.compareCols[i].collation = columnCollations[strings.ToLower(colname)] + td.compareCols[i].values = columnValues[strings.ToLower(colname)] td.comparePKs = append(td.comparePKs, td.compareCols[i]) td.selectPks = append(td.selectPks, i) // We'll be comparing pks separately. So, remove them from compareCols. @@ -536,19 +538,19 @@ func findPKs(env *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition, // getColumnCollations determines the proper collation to use for each // column in the table definition leveraging MySQL's collation inheritance // rules. -func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition) (map[string]collations.ID, error) { +func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition) (map[string]collations.ID, map[string][]string, error) { createstmt, err := venv.Parser().Parse(table.Schema) if err != nil { - return nil, err + return nil, nil, err } createtable, ok := createstmt.(*sqlparser.CreateTable) if !ok { - return nil, vterrors.Wrapf(err, "invalid table schema %s for table %s", table.Schema, table.Name) + return nil, nil, vterrors.Wrapf(err, "invalid table schema %s for table %s", table.Schema, table.Name) } env := schemadiff.NewEnv(venv, venv.CollationEnv().DefaultConnectionCharset()) tableschema, err := schemadiff.NewCreateTableEntity(env, createtable) if err != nil { - return nil, vterrors.Wrapf(err, "invalid table schema %s for table %s", table.Schema, table.Name) + return nil, nil, vterrors.Wrapf(err, "invalid table schema %s for table %s", table.Schema, table.Name) } tableCharset := tableschema.GetCharset() tableCollation := tableschema.GetCollation() @@ -579,6 +581,7 @@ func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.Tab } columnCollations := make(map[string]collations.ID) + columnValues := make(map[string][]string) for _, column := range tableschema.TableSpec.Columns { // If it's not a character based type then no collation is used. if !sqltypes.IsQuoted(column.Type.SQLType()) { @@ -586,8 +589,9 @@ func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.Tab continue } columnCollations[column.Name.Lowered()] = getColumnCollation(column) + columnValues[column.Name.Lowered()] = column.Type.EnumValues } - return columnCollations, nil + return columnCollations, columnValues, nil } // If SourceTimeZone is defined in the BinlogSource, the VReplication workflow would have converted the datetime @@ -1318,7 +1322,7 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []com if col.collation == collations.Unknown { collationID = collations.CollationBinaryID } - c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.collationEnv, collationID) + c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.collationEnv, collationID, col.values) if err != nil { return 0, err } diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index 1b0071ebed7..5f98e11ed72 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -18,7 +18,6 @@ package wrangler import ( "context" - "reflect" "strings" "testing" "time" @@ -94,12 +93,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -113,12 +112,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -132,12 +131,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -151,12 +150,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c2, c1 from t1 order by c1 asc", targetExpression: "select c2, c1 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collations.Unknown, true}}, - comparePKs: []compareColInfo{{1, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collations.Unknown, nil, true}}, + comparePKs: []compareColInfo{{1, collations.Unknown, nil, true}}, pkCols: []int{1}, selectPks: []int{1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -170,12 +169,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c0 as c1, c2 from t2 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -190,12 +189,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "nonpktext", sourceExpression: "select c1, textcol from nonpktext order by c1 asc", targetExpression: "select c1, textcol from nonpktext order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -210,12 +209,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "nonpktext", sourceExpression: "select textcol, c1 from nonpktext order by c1 asc", targetExpression: "select textcol, c1 from nonpktext order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collations.Unknown, true}}, - comparePKs: []compareColInfo{{1, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collations.Unknown, nil, true}}, + comparePKs: []compareColInfo{{1, collations.Unknown, nil, true}}, pkCols: []int{1}, selectPks: []int{1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -230,12 +229,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "pktext", sourceExpression: "select textcol, c2 from pktext order by textcol asc", targetExpression: "select textcol, c2 from pktext order by textcol asc", - compareCols: []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), true}}, + compareCols: []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -250,12 +249,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "pktext", sourceExpression: "select c2, textcol from pktext order by textcol asc", targetExpression: "select c2, textcol from pktext order by textcol asc", - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collationEnv.DefaultConnectionCharset(), true}}, - comparePKs: []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collationEnv.DefaultConnectionCharset(), nil, true}}, + comparePKs: []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, true}}, pkCols: []int{1}, selectPks: []int{1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -270,12 +269,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "pktext", sourceExpression: "select c2, a + b as textcol from pktext order by textcol asc", targetExpression: "select c2, textcol from pktext order by textcol asc", - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collationEnv.DefaultConnectionCharset(), true}}, - comparePKs: []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collationEnv.DefaultConnectionCharset(), nil, true}}, + comparePKs: []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, true}}, pkCols: []int{1}, selectPks: []int{1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), false}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -288,12 +287,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "multipk", sourceExpression: "select c1, c2 from multipk order by c1 asc, c2 asc", targetExpression: "select c1, c2 from multipk order by c1 asc, c2 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, true}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, pkCols: []int{0, 1}, selectPks: []int{0, 1}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -308,12 +307,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -329,12 +328,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 where c2 = 2 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -350,12 +349,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 where c2 = 2 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -371,12 +370,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 where c2 = 2 and c1 = 1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -392,12 +391,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 where c2 = 2 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -412,12 +411,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "t1", sourceExpression: "select c1, c2 from t1 group by c1 order by c1 asc", targetExpression: "select c1, c2 from t1 order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -432,8 +431,8 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "aggr", sourceExpression: "select c1, c2, count(*) as c3, sum(c4) as c4 from t1 group by c1 order by c1 asc", targetExpression: "select c1, c2, c3, c4 from aggr order by c1 asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}, {2, collations.Unknown, false}, {3, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}, {2, collations.Unknown, nil, false}, {3, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, sourcePrimitive: &engine.OrderedAggregate{ @@ -442,10 +441,10 @@ func TestVDiffPlanSuccess(t *testing.T) { engine.NewAggregateParam(opcode.AggregateSum, 3, "", collationEnv), }, GroupByKeys: []*engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1, CollationEnv: collations.MySQL8()}}, - Input: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + Input: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), CollationEnv: collationEnv, }, - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -459,12 +458,12 @@ func TestVDiffPlanSuccess(t *testing.T) { targetTable: "datze", sourceExpression: "select id, dt from datze order by id asc", targetExpression: "select id, convert_tz(dt, 'UTC', 'US/Pacific') as dt from datze order by id asc", - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, - sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), - targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, true}}, collationEnv), + sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), + targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), collationEnv: collationEnv, parser: parser, }, @@ -1078,13 +1077,13 @@ func TestVDiffFindPKs(t *testing.T) { }, }, tdIn: &tableDiffer{ - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collations.Unknown, false}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collations.Unknown, nil, false}}, comparePKs: []compareColInfo{}, pkCols: []int{}, }, tdOut: &tableDiffer{ - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}}, pkCols: []int{0}, selectPks: []int{0}, }, @@ -1106,13 +1105,13 @@ func TestVDiffFindPKs(t *testing.T) { }, }, tdIn: &tableDiffer{ - compareCols: []compareColInfo{{0, collations.Unknown, false}, {1, collations.Unknown, false}, {2, collations.Unknown, false}, {3, collations.Unknown, false}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, false}, {1, collations.Unknown, nil, false}, {2, collations.Unknown, nil, false}, {3, collations.Unknown, nil, false}}, comparePKs: []compareColInfo{}, pkCols: []int{}, }, tdOut: &tableDiffer{ - compareCols: []compareColInfo{{0, collations.Unknown, true}, {1, collations.Unknown, false}, {2, collations.Unknown, false}, {3, collations.Unknown, true}}, - comparePKs: []compareColInfo{{0, collations.Unknown, true}, {3, collations.Unknown, true}}, + compareCols: []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, false}, {2, collations.Unknown, nil, false}, {3, collations.Unknown, nil, true}}, + comparePKs: []compareColInfo{{0, collations.Unknown, nil, true}, {3, collations.Unknown, nil, true}}, pkCols: []int{0, 3}, selectPks: []int{0, 3}, }, @@ -1184,10 +1183,11 @@ func TestVDiffPlanInclude(t *testing.T) { func TestGetColumnCollations(t *testing.T) { collationEnv := collations.MySQL8() tests := []struct { - name string - table *tabletmanagerdatapb.TableDefinition - want map[string]collations.ID - wantErr bool + name string + table *tabletmanagerdatapb.TableDefinition + wantCols map[string]collations.ID + wantValues map[string][]string + wantErr bool }{ { name: "invalid schema", @@ -1201,94 +1201,152 @@ func TestGetColumnCollations(t *testing.T) { table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 int, name varchar(10), primary key(c1))", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collations.Unknown, "name": collationEnv.DefaultConnectionCharset(), }, + wantValues: map[string][]string{ + "name": nil, + }, }, { name: "char pk with global default collation", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10), name varchar(10), primary key(c1))", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.DefaultConnectionCharset(), "name": collationEnv.DefaultConnectionCharset(), }, + wantValues: map[string][]string{ + "c1": nil, + "name": nil, + }, }, { name: "compound char int pk with global default collation", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 int, name varchar(10), primary key(c1, name))", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collations.Unknown, "name": collationEnv.DefaultConnectionCharset(), }, + wantValues: map[string][]string{ + "name": nil, + }, }, { name: "char pk with table default charset", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10), name varchar(10), primary key(c1)) default character set ucs2", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.DefaultCollationForCharset("ucs2"), "name": collationEnv.DefaultCollationForCharset("ucs2"), }, + wantValues: map[string][]string{ + "c1": nil, + "name": nil, + }, }, { name: "char pk with table default collation", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10), name varchar(10), primary key(c1)) charset=utf32 collate=utf32_icelandic_ci", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.LookupByName("utf32_icelandic_ci"), "name": collationEnv.LookupByName("utf32_icelandic_ci"), }, + wantValues: map[string][]string{ + "c1": nil, + "name": nil, + }, }, { name: "char pk with column charset override", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10) charset sjis, name varchar(10), primary key(c1)) character set=utf8", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.DefaultCollationForCharset("sjis"), "name": collationEnv.DefaultCollationForCharset("utf8mb3"), }, + wantValues: map[string][]string{ + "c1": nil, + "name": nil, + }, }, { name: "char pk with column collation override", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10) collate hebrew_bin, name varchar(10), primary key(c1)) charset=hebrew", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.LookupByName("hebrew_bin"), "name": collationEnv.DefaultCollationForCharset("hebrew"), }, + wantValues: map[string][]string{ + "c1": nil, + "name": nil, + }, }, { name: "compound char int pk with column collation override", table: &tabletmanagerdatapb.TableDefinition{ Schema: "create table t1 (c1 varchar(10) collate utf16_turkish_ci, c2 int, name varchar(10), primary key(c1, c2)) charset=utf16 collate=utf16_icelandic_ci", }, - want: map[string]collations.ID{ + wantCols: map[string]collations.ID{ "c1": collationEnv.LookupByName("utf16_turkish_ci"), "c2": collations.Unknown, "name": collationEnv.LookupByName("utf16_icelandic_ci"), }, + wantValues: map[string][]string{ + "c1": nil, + "name": nil, + }, + }, + { + name: "col with enum values", + table: &tabletmanagerdatapb.TableDefinition{ + Schema: "create table t1 (c1 varchar(10), size enum('small', 'medium', 'large'), primary key(c1))", + }, + wantCols: map[string]collations.ID{ + "c1": collationEnv.DefaultConnectionCharset(), + "size": collationEnv.DefaultConnectionCharset(), + }, + wantValues: map[string][]string{ + "c1": nil, + "size": {"'small'", "'medium'", "'large'"}, + }, + }, + { + name: "col with set values", + table: &tabletmanagerdatapb.TableDefinition{ + Schema: "create table t1 (c1 varchar(10), size set('small', 'medium', 'large'), primary key(c1))", + }, + wantCols: map[string]collations.ID{ + "c1": collationEnv.DefaultConnectionCharset(), + "size": collationEnv.DefaultConnectionCharset(), + }, + wantValues: map[string][]string{ + "c1": nil, + "size": {"'small'", "'medium'", "'large'"}, + }, }, } env := vtenv.NewTestEnv() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := getColumnCollations(env, tt.table) - if (err != nil) != tt.wantErr { - t.Errorf("getColumnCollations() error = %v, wantErr = %t", err, tt.wantErr) + gotCols, gotValues, err := getColumnCollations(env, tt.table) + if tt.wantErr { + require.Error(t, err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("getColumnCollations() = %+v, want %+v", got, tt.want) - } + require.NoError(t, err) + require.Equal(t, tt.wantCols, gotCols) + require.Equal(t, tt.wantValues, gotValues) }) } } From 595ec8fd9cb708a239ca72038f7185d0b3d9473f Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 23 Apr 2024 16:40:24 +0200 Subject: [PATCH 2/3] Introduce type and use pointer to reduce size Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/engine/aggregations.go | 2 +- go/vt/vtgate/engine/cached_size.go | 25 ++++++----- go/vt/vtgate/engine/hash_join.go | 6 +-- go/vt/vtgate/evalengine/api_aggregation.go | 4 +- go/vt/vtgate/evalengine/api_compare.go | 4 +- go/vt/vtgate/evalengine/api_compare_test.go | 16 +++---- go/vt/vtgate/evalengine/api_hash.go | 6 +-- go/vt/vtgate/evalengine/arena.go | 8 ++-- go/vt/vtgate/evalengine/cached_size.go | 39 ++++++++-------- go/vt/vtgate/evalengine/compiler.go | 32 ++++++++----- go/vt/vtgate/evalengine/compiler_asm_push.go | 8 ++-- go/vt/vtgate/evalengine/eval.go | 4 +- go/vt/vtgate/evalengine/eval_enum.go | 9 ++-- go/vt/vtgate/evalengine/eval_set.go | 6 +-- go/vt/vtgate/evalengine/expr_column.go | 2 +- go/vt/vtgate/evalengine/weights.go | 4 +- go/vt/vtgate/evalengine/weights_test.go | 12 ++--- go/vt/vtgate/vindexes/vschema.go | 3 +- go/vt/wrangler/vdiff.go | 18 +++++--- go/vt/wrangler/vdiff_test.go | 47 +++++--------------- 20 files changed, 126 insertions(+), 129 deletions(-) diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index b033e9fbb0e..4673a2717e5 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -107,7 +107,7 @@ type aggregatorDistinct struct { last sqltypes.Value coll collations.ID collationEnv *collations.Environment - values []string + values *evalengine.EnumSetValues } func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) { diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 410f024149c..22b3a38a990 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -35,7 +35,7 @@ func (cached *AggregateParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(128) + size += int64(112) } // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type size += cached.Type.CachedSize(false) @@ -71,7 +71,7 @@ func (cached *CheckCol) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(64) + size += int64(48) } // field WsCol *int size += hack.RuntimeAllocSize(int64(8)) @@ -239,7 +239,7 @@ func (cached *Distinct) CachedSize(alloc bool) int64 { } // field CheckCols []vitess.io/vitess/go/vt/vtgate/engine.CheckCol { - size += hack.RuntimeAllocSize(int64(cap(cached.CheckCols)) * int64(64)) + size += hack.RuntimeAllocSize(int64(cap(cached.CheckCols)) * int64(48)) for _, elem := range cached.CheckCols { size += elem.CachedSize(false) } @@ -386,7 +386,7 @@ func (cached *GroupByParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(96) + size += int64(80) } // field Expr vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.Expr.(cachedObject); ok { @@ -404,7 +404,7 @@ func (cached *HashJoin) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(144) + size += int64(128) } // field Left vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Left.(cachedObject); ok { @@ -424,10 +424,11 @@ func (cached *HashJoin) CachedSize(alloc bool) int64 { } // field CollationEnv *vitess.io/vitess/go/mysql/collations.Environment size += cached.CollationEnv.CachedSize(true) - // field Values []string - { - size += hack.RuntimeAllocSize(int64(cap(cached.Values)) * int64(16)) - for _, elem := range cached.Values { + // field Values *vitess.io/vitess/go/vt/vtgate/evalengine.EnumSetValues + if cached.Values != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.Values)) * int64(16)) + for _, elem := range *cached.Values { size += hack.RuntimeAllocSize(int64(len(elem))) } } @@ -670,7 +671,7 @@ func (cached *MemorySort) CachedSize(alloc bool) int64 { } // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(72)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(56)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } @@ -700,7 +701,7 @@ func (cached *MergeSort) CachedSize(alloc bool) int64 { } // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(72)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(56)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } @@ -910,7 +911,7 @@ func (cached *Route) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.FieldQuery))) // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(72)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(56)) for _, elem := range cached.OrderBy { size += elem.CachedSize(false) } diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index 89dbf1190ae..6ac34e1ab79 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -69,7 +69,7 @@ type ( CollationEnv *collations.Environment // Values for enum and set types - Values []string + Values *evalengine.EnumSetValues } hashJoinProbeTable struct { @@ -81,7 +81,7 @@ type ( cols []int hasher vthash.Hasher sqlmode evalengine.SQLMode - values []string + values *evalengine.EnumSetValues } probeTableEntry struct { @@ -264,7 +264,7 @@ func (hj *HashJoin) description() PrimitiveDescription { } } -func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int, values []string) *hashJoinProbeTable { +func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int, values *evalengine.EnumSetValues) *hashJoinProbeTable { return &hashJoinProbeTable{ innerMap: map[vthash.Hash]*probeTableEntry{}, coll: coll, diff --git a/go/vt/vtgate/evalengine/api_aggregation.go b/go/vt/vtgate/evalengine/api_aggregation.go index 8584261b654..78ab8335d6d 100644 --- a/go/vt/vtgate/evalengine/api_aggregation.go +++ b/go/vt/vtgate/evalengine/api_aggregation.go @@ -448,7 +448,7 @@ type aggregationMinMax struct { current sqltypes.Value collation collations.ID collationEnv *collations.Environment - values []string + values *EnumSetValues } func (a *aggregationMinMax) minmax(value sqltypes.Value, max bool) (err error) { @@ -485,7 +485,7 @@ func (a *aggregationMinMax) Reset() { a.current = sqltypes.NULL } -func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environment, collation collations.ID, values []string) MinMax { +func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environment, collation collations.ID, values *EnumSetValues) MinMax { switch { case sqltypes.IsSigned(typ): return &aggregationInt{t: typ} diff --git a/go/vt/vtgate/evalengine/api_compare.go b/go/vt/vtgate/evalengine/api_compare.go index e890e7c83fd..6873ad40143 100644 --- a/go/vt/vtgate/evalengine/api_compare.go +++ b/go/vt/vtgate/evalengine/api_compare.go @@ -43,7 +43,7 @@ func (err UnsupportedCollationError) Error() string { // UnsupportedCollationHashError is returned when we try to get the hash value and are missing the collation to use var UnsupportedCollationHashError = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "text type with an unknown/unsupported collation cannot be hashed") -func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID, values []string) (int, error) { +func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID, values *EnumSetValues) (int, error) { v1t := v1.Type() // We have a fast path here for the case where both values are @@ -147,7 +147,7 @@ func compare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collat // numeric, then a numeric comparison is performed after // necessary conversions. If none are numeric, then it's // a simple binary comparison. Uncomparable values return an error. -func NullsafeCompare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID, values []string) (int, error) { +func NullsafeCompare(v1, v2 sqltypes.Value, collationEnv *collations.Environment, collationID collations.ID, values *EnumSetValues) (int, error) { // Based on the categorization defined for the types, // we're going to allow comparison of the following: // Null, isNumber, IsBinary. This will exclude IsQuoted diff --git a/go/vt/vtgate/evalengine/api_compare_test.go b/go/vt/vtgate/evalengine/api_compare_test.go index 778a252e2d8..106b111cafc 100644 --- a/go/vt/vtgate/evalengine/api_compare_test.go +++ b/go/vt/vtgate/evalengine/api_compare_test.go @@ -30,14 +30,12 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" - - "vitess.io/vitess/go/sqltypes" - - querypb "vitess.io/vitess/go/vt/proto/query" ) type testCase struct { @@ -1114,7 +1112,7 @@ func TestNullsafeCompare(t *testing.T) { v1, v2 sqltypes.Value out int err error - values []string + values *EnumSetValues }{ { v1: NULL, @@ -1145,7 +1143,7 @@ func TestNullsafeCompare(t *testing.T) { v1: TestValue(sqltypes.Enum, "foo"), v2: TestValue(sqltypes.Enum, "bar"), out: -1, - values: []string{"'foo'", "'bar'"}, + values: &EnumSetValues{"'foo'", "'bar'"}, }, { v1: TestValue(sqltypes.Enum, "foo"), @@ -1156,7 +1154,7 @@ func TestNullsafeCompare(t *testing.T) { v1: TestValue(sqltypes.Enum, "foo"), v2: TestValue(sqltypes.VarChar, "bar"), out: 1, - values: []string{"'foo'", "'bar'"}, + values: &EnumSetValues{"'foo'", "'bar'"}, }, { v1: TestValue(sqltypes.VarChar, "foo"), @@ -1167,7 +1165,7 @@ func TestNullsafeCompare(t *testing.T) { v1: TestValue(sqltypes.Set, "bar"), v2: TestValue(sqltypes.Set, "foo,bar"), out: -1, - values: []string{"'foo'", "'bar'"}, + values: &EnumSetValues{"'foo'", "'bar'"}, }, { v1: TestValue(sqltypes.Set, "bar"), @@ -1178,7 +1176,7 @@ func TestNullsafeCompare(t *testing.T) { v1: TestValue(sqltypes.VarChar, "bar"), v2: TestValue(sqltypes.Set, "foo,bar"), out: -1, - values: []string{"'foo'", "'bar'"}, + values: &EnumSetValues{"'foo'", "'bar'"}, }, { v1: TestValue(sqltypes.Set, "bar"), diff --git a/go/vt/vtgate/evalengine/api_hash.go b/go/vt/vtgate/evalengine/api_hash.go index 0ed3e0c4146..a5e5d1778dd 100644 --- a/go/vt/vtgate/evalengine/api_hash.go +++ b/go/vt/vtgate/evalengine/api_hash.go @@ -34,7 +34,7 @@ type HashCode = uint64 // NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same // for two values that are considered equal by `NullsafeCompare`. -func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, sqlmode SQLMode, values []string) (HashCode, error) { +func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, sqlmode SQLMode, values *EnumSetValues) (HashCode, error) { e, err := valueToEvalCast(v, coerceType, collation, values, sqlmode) if err != nil { return 0, err @@ -75,7 +75,7 @@ var ErrHashCoercionIsNotExact = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, " // for two values that are considered equal by `NullsafeCompare`. // This can be used to avoid having to do comparison checks after a hash, // since we consider the 128 bits of entropy enough to guarantee uniqueness. -func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode, values []string) error { +func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode, values *EnumSetValues) error { switch { case v.IsNull(), sqltypes.IsNull(coerceTo): hash.Write16(hashPrefixNil) @@ -233,7 +233,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat return nil } -func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode, values []string) error { +func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode, values *EnumSetValues) error { // Slow path to handle all other types. This uses the generic // logic for value casting to ensure we match MySQL here. e, err := valueToEvalCast(v, coerceTo, collation, values, sqlmode) diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go index 0b01a485dc3..ccfe63f514f 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -17,8 +17,6 @@ limitations under the License. package evalengine import ( - "slices" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/mysql/decimal" @@ -67,7 +65,7 @@ func (a *Arena) newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal { return a.newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) } -func (a *Arena) newEvalEnum(raw []byte, values []string) *evalEnum { +func (a *Arena) newEvalEnum(raw []byte, values *EnumSetValues) *evalEnum { if cap(a.aEnum) > len(a.aEnum) { a.aEnum = a.aEnum[:len(a.aEnum)+1] } else { @@ -76,11 +74,11 @@ func (a *Arena) newEvalEnum(raw []byte, values []string) *evalEnum { val := &a.aEnum[len(a.aInt64)-1] s := string(raw) val.string = s - val.value = slices.Index(values, s) + val.value = valueIdx(values, s) return val } -func (a *Arena) newEvalSet(raw []byte, values []string) *evalSet { +func (a *Arena) newEvalSet(raw []byte, values *EnumSetValues) *evalSet { if cap(a.aSet) > len(a.aSet) { a.aSet = a.aSet[:len(a.aSet)+1] } else { diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index abe7bdc473f..4854795779e 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -159,16 +159,17 @@ func (cached *Column) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(80) + size += int64(64) } // field Original vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.Original.(cachedObject); ok { size += cc.CachedSize(true) } - // field Values []string - { - size += hack.RuntimeAllocSize(int64(cap(cached.Values)) * int64(16)) - for _, elem := range cached.Values { + // field Values *vitess.io/vitess/go/vt/vtgate/evalengine.EnumSetValues + if cached.Values != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.Values)) * int64(16)) + for _, elem := range *cached.Values { size += hack.RuntimeAllocSize(int64(len(elem))) } } @@ -196,7 +197,7 @@ func (cached *CompiledExpr) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(96) + size += int64(80) } // field code []vitess.io/vitess/go/vt/vtgate/evalengine.frame { @@ -370,7 +371,7 @@ func (cached *OrderByParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(80) + size += int64(64) } // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type size += cached.Type.CachedSize(false) @@ -396,12 +397,13 @@ func (cached *Type) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(24) } - // field values []string - { - size += hack.RuntimeAllocSize(int64(cap(cached.values)) * int64(16)) - for _, elem := range cached.values { + // field values *vitess.io/vitess/go/vt/vtgate/evalengine.EnumSetValues + if cached.values != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.values)) * int64(16)) + for _, elem := range *cached.values { size += hack.RuntimeAllocSize(int64(len(elem))) } } @@ -1933,12 +1935,13 @@ func (cached *ctype) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(32) } - // field Values []string - { - size += hack.RuntimeAllocSize(int64(cap(cached.Values)) * int64(16)) - for _, elem := range cached.Values { + // field Values *vitess.io/vitess/go/vt/vtgate/evalengine.EnumSetValues + if cached.Values != nil { + size += int64(24) + size += hack.RuntimeAllocSize(int64(cap(*cached.Values)) * int64(16)) + for _, elem := range *cached.Values { size += hack.RuntimeAllocSize(int64(len(elem))) } } @@ -2063,7 +2066,7 @@ func (cached *typedExpr) CachedSize(alloc bool) int64 { } // field types []vitess.io/vitess/go/vt/vtgate/evalengine.ctype { - size += hack.RuntimeAllocSize(int64(cap(cached.types)) * int64(48)) + size += hack.RuntimeAllocSize(int64(cap(cached.types)) * int64(32)) for _, elem := range cached.types { size += elem.CachedSize(false) } diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 344798f6abb..d9de15aa571 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -52,12 +52,14 @@ type compiledCoercion struct { right colldata.Coercion } +type EnumSetValues []string + type ctype struct { Type sqltypes.Type Flag typeFlag Size, Scale int32 Col collations.TypedCollation - Values []string + Values *EnumSetValues } type Type struct { @@ -66,7 +68,17 @@ type Type struct { nullable bool init bool size, scale int32 - values []string + values *EnumSetValues +} + +func (v *EnumSetValues) Equal(other *EnumSetValues) bool { + if v == nil && other == nil { + return true + } + if v == nil || other == nil { + return false + } + return slices.Equal(*v, *other) } func NewType(t sqltypes.Type, collation collations.ID) Type { @@ -74,7 +86,7 @@ func NewType(t sqltypes.Type, collation collations.ID) Type { return NewTypeEx(t, collation, true, 0, 0, nil) } -func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, scale int32, values []string) Type { +func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, scale int32, values *EnumSetValues) Type { return Type{ typ: t, collation: collation, @@ -144,7 +156,7 @@ func (t *Type) Nullable() bool { return true // nullable by default for unknown types } -func (t *Type) Values() []string { +func (t *Type) Values() *EnumSetValues { return t.values } @@ -158,27 +170,27 @@ func (t *Type) Equal(other *Type) bool { t.nullable == other.nullable && t.size == other.size && t.scale == other.scale && - slices.Equal(t.values, other.values) + t.values.Equal(other.values) } -func (ct ctype) equal(other ctype) bool { +func (ct *ctype) equal(other ctype) bool { return ct.Type == other.Type && ct.Flag == other.Flag && ct.Size == other.Size && ct.Scale == other.Scale && ct.Col == other.Col && - slices.Equal(ct.Values, other.Values) + ct.Values.Equal(other.Values) } -func (ct ctype) nullable() bool { +func (ct *ctype) nullable() bool { return ct.Flag&flagNullable != 0 } -func (ct ctype) isTextual() bool { +func (ct *ctype) isTextual() bool { return sqltypes.IsTextOrBinary(ct.Type) } -func (ct ctype) isHexOrBitLiteral() bool { +func (ct *ctype) isHexOrBitLiteral() bool { return ct.Flag&flagBit != 0 || ct.Flag&flagHex != 0 } diff --git a/go/vt/vtgate/evalengine/compiler_asm_push.go b/go/vt/vtgate/evalengine/compiler_asm_push.go index ff8adb168ff..87d2ee9af9b 100644 --- a/go/vt/vtgate/evalengine/compiler_asm_push.go +++ b/go/vt/vtgate/evalengine/compiler_asm_push.go @@ -105,13 +105,13 @@ func push_d(env *ExpressionEnv, raw []byte) int { return 1 } -func push_enum(env *ExpressionEnv, raw []byte, values []string) int { +func push_enum(env *ExpressionEnv, raw []byte, values *EnumSetValues) int { env.vm.stack[env.vm.sp] = env.vm.arena.newEvalEnum(raw, values) env.vm.sp++ return 1 } -func push_set(env *ExpressionEnv, raw []byte, values []string) int { +func push_set(env *ExpressionEnv, raw []byte, values *EnumSetValues) int { env.vm.stack[env.vm.sp] = env.vm.arena.newEvalSet(raw, values) env.vm.sp++ return 1 @@ -129,7 +129,7 @@ func (asm *assembler) PushColumn_d(offset int) { }, "PUSH DECIMAL(:%d)", offset) } -func (asm *assembler) PushColumn_enum(offset int, values []string) { +func (asm *assembler) PushColumn_enum(offset int, values *EnumSetValues) { asm.adjustStack(1) asm.emit(func(env *ExpressionEnv) int { @@ -141,7 +141,7 @@ func (asm *assembler) PushColumn_enum(offset int, values []string) { }, "PUSH ENUM(:%d)", offset) } -func (asm *assembler) PushColumn_set(offset int, values []string) { +func (asm *assembler) PushColumn_set(offset int, values *EnumSetValues) { asm.adjustStack(1) asm.emit(func(env *ExpressionEnv) int { diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index eeefa351894..90b1add541a 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -212,7 +212,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all } } -func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID, values []string, sqlmode SQLMode) (eval, error) { +func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID, values *EnumSetValues, sqlmode SQLMode) (eval, error) { switch { case typ == sqltypes.Null: return nil, nil @@ -367,7 +367,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value: %v", v) } -func valueToEval(value sqltypes.Value, collation collations.TypedCollation, values []string) (eval, error) { +func valueToEval(value sqltypes.Value, collation collations.TypedCollation, values *EnumSetValues) (eval, error) { wrap := func(err error) error { if err == nil { return nil diff --git a/go/vt/vtgate/evalengine/eval_enum.go b/go/vt/vtgate/evalengine/eval_enum.go index b89abd15429..a0d349314da 100644 --- a/go/vt/vtgate/evalengine/eval_enum.go +++ b/go/vt/vtgate/evalengine/eval_enum.go @@ -10,7 +10,7 @@ type evalEnum struct { string string } -func newEvalEnum(val []byte, values []string) *evalEnum { +func newEvalEnum(val []byte, values *EnumSetValues) *evalEnum { s := string(val) return &evalEnum{ value: valueIdx(values, s), @@ -26,8 +26,11 @@ func (e *evalEnum) SQLType() sqltypes.Type { return sqltypes.Enum } -func valueIdx(values []string, value string) int { - for i, v := range values { +func valueIdx(values *EnumSetValues, value string) int { + if values == nil { + return -1 + } + for i, v := range *values { v, _ = sqltypes.DecodeStringSQL(v) if v == value { return i diff --git a/go/vt/vtgate/evalengine/eval_set.go b/go/vt/vtgate/evalengine/eval_set.go index 47fe29607df..6a9de2eff14 100644 --- a/go/vt/vtgate/evalengine/eval_set.go +++ b/go/vt/vtgate/evalengine/eval_set.go @@ -12,7 +12,7 @@ type evalSet struct { string string } -func newEvalSet(val []byte, values []string) *evalSet { +func newEvalSet(val []byte, values *EnumSetValues) *evalSet { value := string(val) return &evalSet{ @@ -29,8 +29,8 @@ func (e *evalSet) SQLType() sqltypes.Type { return sqltypes.Set } -func evalSetBits(values []string, value string) uint64 { - if len(values) > 64 { +func evalSetBits(values *EnumSetValues, value string) uint64 { + if values != nil && len(*values) > 64 { // This never would happen as MySQL limits SET // to 64 elements. Safeguard here just in case though. panic("too many values for set") diff --git a/go/vt/vtgate/evalengine/expr_column.go b/go/vt/vtgate/evalengine/expr_column.go index cbdb1775f88..d53585ceb8b 100644 --- a/go/vt/vtgate/evalengine/expr_column.go +++ b/go/vt/vtgate/evalengine/expr_column.go @@ -34,7 +34,7 @@ type ( Collation collations.TypedCollation Original sqlparser.Expr Nullable bool - Values []string // For ENUM and SET types + Values *EnumSetValues // For ENUM and SET types // dynamicTypeOffset is set when the type of this column cannot be calculated // at translation time. Since expressions with dynamic types cannot be compiled ahead of time, diff --git a/go/vt/vtgate/evalengine/weights.go b/go/vt/vtgate/evalengine/weights.go index 37286af7ddc..3eb9aa290c5 100644 --- a/go/vt/vtgate/evalengine/weights.go +++ b/go/vt/vtgate/evalengine/weights.go @@ -41,7 +41,7 @@ import ( // externally communicates with the `WEIGHT_STRING` function, so that we // can also use this to order / sort other types like Float and Decimal // as well. -func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, values []string, sqlmode SQLMode) ([]byte, bool, error) { +func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, values *EnumSetValues, sqlmode SQLMode) ([]byte, bool, error) { // We optimize here for the case where we already have the desired type. // Otherwise, we fall back to the general evalengine conversion logic. if v.Type() != coerceTo { @@ -125,7 +125,7 @@ func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col coll } } -func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, values []string, sqlmode SQLMode) ([]byte, bool, error) { +func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, values *EnumSetValues, sqlmode SQLMode) ([]byte, bool, error) { e, err := valueToEvalCast(v, coerceTo, col, values, sqlmode) if err != nil { return dst, false, err diff --git a/go/vt/vtgate/evalengine/weights_test.go b/go/vt/vtgate/evalengine/weights_test.go index b059142163a..95764d3c3a4 100644 --- a/go/vt/vtgate/evalengine/weights_test.go +++ b/go/vt/vtgate/evalengine/weights_test.go @@ -37,7 +37,7 @@ func TestTinyWeightStrings(t *testing.T) { col collations.ID len int prec int - values []string + values *EnumSetValues }{ {typ: sqltypes.Int32, gen: sqltypes.RandomGenerators[sqltypes.Int32], col: collations.CollationBinaryID}, {typ: sqltypes.Int64, gen: sqltypes.RandomGenerators[sqltypes.Int64], col: collations.CollationBinaryID}, @@ -48,8 +48,8 @@ func TestTinyWeightStrings(t *testing.T) { {typ: sqltypes.VarBinary, gen: sqltypes.RandomGenerators[sqltypes.VarBinary], col: collations.CollationBinaryID}, {typ: sqltypes.Decimal, gen: sqltypes.RandomGenerators[sqltypes.Decimal], col: collations.CollationBinaryID, len: 20, prec: 10}, {typ: sqltypes.TypeJSON, gen: sqltypes.RandomGenerators[sqltypes.TypeJSON], col: collations.CollationBinaryID}, - {typ: sqltypes.Enum, gen: sqltypes.RandomGenerators[sqltypes.Enum], col: collations.CollationBinaryID, values: []string{"'xxsmall'", "'xsmall'", "'small'", "'medium'", "'large'", "'xlarge'", "'xxlarge'"}}, - {typ: sqltypes.Set, gen: sqltypes.RandomGenerators[sqltypes.Set], col: collations.CollationBinaryID, values: []string{"'a'", "'b'", "'c'", "'d'", "'e'", "'f'", "'g'"}}, + {typ: sqltypes.Enum, gen: sqltypes.RandomGenerators[sqltypes.Enum], col: collations.CollationBinaryID, values: &EnumSetValues{"'xxsmall'", "'xsmall'", "'small'", "'medium'", "'large'", "'xlarge'", "'xxlarge'"}}, + {typ: sqltypes.Set, gen: sqltypes.RandomGenerators[sqltypes.Set], col: collations.CollationBinaryID, values: &EnumSetValues{"'a'", "'b'", "'c'", "'d'", "'e'", "'f'", "'g'"}}, } for _, tc := range cases { @@ -119,7 +119,7 @@ func TestWeightStrings(t *testing.T) { col collations.ID len int prec int - values []string + values *EnumSetValues }{ {name: "int64", gen: sqltypes.RandomGenerators[sqltypes.Int64], types: []sqltypes.Type{sqltypes.Int64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "uint64", gen: sqltypes.RandomGenerators[sqltypes.Uint64], types: []sqltypes.Type{sqltypes.Uint64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, @@ -132,8 +132,8 @@ func TestWeightStrings(t *testing.T) { {name: "datetime", gen: sqltypes.RandomGenerators[sqltypes.Datetime], types: []sqltypes.Type{sqltypes.Datetime, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "timestamp", gen: sqltypes.RandomGenerators[sqltypes.Timestamp], types: []sqltypes.Type{sqltypes.Timestamp, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, {name: "time", gen: sqltypes.RandomGenerators[sqltypes.Time], types: []sqltypes.Type{sqltypes.Time, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "enum", gen: sqltypes.RandomGenerators[sqltypes.Enum], types: []sqltypes.Type{sqltypes.Enum, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, values: []string{"'xxsmall'", "'xsmall'", "'small'", "'medium'", "'large'", "'xlarge'", "'xxlarge'"}}, - {name: "set", gen: sqltypes.RandomGenerators[sqltypes.Set], types: []sqltypes.Type{sqltypes.Set, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, values: []string{"'a'", "'b'", "'c'", "'d'", "'e'", "'f'", "'g'"}}, + {name: "enum", gen: sqltypes.RandomGenerators[sqltypes.Enum], types: []sqltypes.Type{sqltypes.Enum, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, values: &EnumSetValues{"'xxsmall'", "'xsmall'", "'small'", "'medium'", "'large'", "'xlarge'", "'xxlarge'"}}, + {name: "set", gen: sqltypes.RandomGenerators[sqltypes.Set], types: []sqltypes.Type{sqltypes.Set, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, values: &EnumSetValues{"'a'", "'b'", "'c'", "'d'", "'e'", "'f'", "'g'"}}, } for _, tc := range cases { diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index 6506cdee09c..8e5e8b547a6 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "vitess.io/vitess/go/ptr" "vitess.io/vitess/go/vt/topotools" "vitess.io/vitess/go/json2" @@ -233,7 +234,7 @@ func (col *Column) ToEvalengineType(collationEnv *collations.Environment) evalen } else { collation = collations.CollationForType(col.Type, collationEnv.DefaultConnectionCharset()) } - return evalengine.NewTypeEx(col.Type, collation, col.Nullable, col.Size, col.Scale, col.Values) + return evalengine.NewTypeEx(col.Type, collation, col.Nullable, col.Size, col.Scale, ptr.Of(evalengine.EnumSetValues(col.Values))) } // KeyspaceSchema contains the schema(table) for a keyspace. diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index a698bad290a..95c5c1fb32c 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -31,6 +31,7 @@ import ( "vitess.io/vitess/go/mysql/replication" "vitess.io/vitess/go/mysql/sqlerror" + "vitess.io/vitess/go/ptr" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/mysql/collations" @@ -116,10 +117,10 @@ type vdiff struct { // compareColInfo contains the metadata for a column of the table being diffed type compareColInfo struct { - colIndex int // index of the column in the filter's select - collation collations.ID // is the collation of the column, if any - values []string // is the list of enum or set values for the column, if any - isPK bool // is this column part of the primary key + colIndex int // index of the column in the filter's select + collation collations.ID // is the collation of the column, if any + values *evalengine.EnumSetValues // is the list of enum or set values for the column, if any + isPK bool // is this column part of the primary key } // tableDiffer performs a diff for one table in the workflow. @@ -538,7 +539,7 @@ func findPKs(env *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition, // getColumnCollations determines the proper collation to use for each // column in the table definition leveraging MySQL's collation inheritance // rules. -func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition) (map[string]collations.ID, map[string][]string, error) { +func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.TableDefinition) (map[string]collations.ID, map[string]*evalengine.EnumSetValues, error) { createstmt, err := venv.Parser().Parse(table.Schema) if err != nil { return nil, nil, err @@ -581,7 +582,7 @@ func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.Tab } columnCollations := make(map[string]collations.ID) - columnValues := make(map[string][]string) + columnValues := make(map[string]*evalengine.EnumSetValues) for _, column := range tableschema.TableSpec.Columns { // If it's not a character based type then no collation is used. if !sqltypes.IsQuoted(column.Type.SQLType()) { @@ -589,7 +590,10 @@ func getColumnCollations(venv *vtenv.Environment, table *tabletmanagerdatapb.Tab continue } columnCollations[column.Name.Lowered()] = getColumnCollation(column) - columnValues[column.Name.Lowered()] = column.Type.EnumValues + if len(column.Type.EnumValues) == 0 { + continue + } + columnValues[column.Name.Lowered()] = ptr.Of(evalengine.EnumSetValues(column.Type.EnumValues)) } return columnCollations, columnValues, nil } diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index 5f98e11ed72..87988c5fd7e 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -34,6 +34,7 @@ import ( "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) func TestVDiffPlanSuccess(t *testing.T) { @@ -1186,7 +1187,7 @@ func TestGetColumnCollations(t *testing.T) { name string table *tabletmanagerdatapb.TableDefinition wantCols map[string]collations.ID - wantValues map[string][]string + wantValues map[string]*evalengine.EnumSetValues wantErr bool }{ { @@ -1205,9 +1206,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collations.Unknown, "name": collationEnv.DefaultConnectionCharset(), }, - wantValues: map[string][]string{ - "name": nil, - }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with global default collation", @@ -1218,10 +1217,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collationEnv.DefaultConnectionCharset(), "name": collationEnv.DefaultConnectionCharset(), }, - wantValues: map[string][]string{ - "c1": nil, - "name": nil, - }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "compound char int pk with global default collation", @@ -1232,9 +1228,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collations.Unknown, "name": collationEnv.DefaultConnectionCharset(), }, - wantValues: map[string][]string{ - "name": nil, - }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with table default charset", @@ -1245,10 +1239,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collationEnv.DefaultCollationForCharset("ucs2"), "name": collationEnv.DefaultCollationForCharset("ucs2"), }, - wantValues: map[string][]string{ - "c1": nil, - "name": nil, - }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with table default collation", @@ -1259,10 +1250,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collationEnv.LookupByName("utf32_icelandic_ci"), "name": collationEnv.LookupByName("utf32_icelandic_ci"), }, - wantValues: map[string][]string{ - "c1": nil, - "name": nil, - }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with column charset override", @@ -1273,10 +1261,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collationEnv.DefaultCollationForCharset("sjis"), "name": collationEnv.DefaultCollationForCharset("utf8mb3"), }, - wantValues: map[string][]string{ - "c1": nil, - "name": nil, - }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "char pk with column collation override", @@ -1287,10 +1272,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collationEnv.LookupByName("hebrew_bin"), "name": collationEnv.DefaultCollationForCharset("hebrew"), }, - wantValues: map[string][]string{ - "c1": nil, - "name": nil, - }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "compound char int pk with column collation override", @@ -1302,10 +1284,7 @@ func TestGetColumnCollations(t *testing.T) { "c2": collations.Unknown, "name": collationEnv.LookupByName("utf16_icelandic_ci"), }, - wantValues: map[string][]string{ - "c1": nil, - "name": nil, - }, + wantValues: map[string]*evalengine.EnumSetValues{}, }, { name: "col with enum values", @@ -1316,8 +1295,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collationEnv.DefaultConnectionCharset(), "size": collationEnv.DefaultConnectionCharset(), }, - wantValues: map[string][]string{ - "c1": nil, + wantValues: map[string]*evalengine.EnumSetValues{ "size": {"'small'", "'medium'", "'large'"}, }, }, @@ -1330,8 +1308,7 @@ func TestGetColumnCollations(t *testing.T) { "c1": collationEnv.DefaultConnectionCharset(), "size": collationEnv.DefaultConnectionCharset(), }, - wantValues: map[string][]string{ - "c1": nil, + wantValues: map[string]*evalengine.EnumSetValues{ "size": {"'small'", "'medium'", "'large'"}, }, }, From 89ffe520b60dce513519759f75b7f779fe56745b Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 24 Apr 2024 09:02:07 +0200 Subject: [PATCH 3/3] Fix comments Signed-off-by: Dirkjan Bussink --- go/sqltypes/value.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 99a0a43828e..bb4e26d15e3 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -568,12 +568,12 @@ func (v Value) IsDecimal() bool { return IsDecimal(v.Type()) } -// IsEnum returns true if Value is time. +// IsEnum returns true if Value is enum. func (v Value) IsEnum() bool { return v.Type() == querypb.Type_ENUM } -// IsSet returns true if Value is time. +// IsSet returns true if Value is set. func (v Value) IsSet() bool { return v.Type() == querypb.Type_SET }