diff --git a/go/vt/vtgate/evalengine/api_type_aggregation.go b/go/vt/vtgate/evalengine/api_type_aggregation.go index 70640c935b4..83703d4532c 100644 --- a/go/vt/vtgate/evalengine/api_type_aggregation.go +++ b/go/vt/vtgate/evalengine/api_type_aggregation.go @@ -16,7 +16,10 @@ limitations under the License. package evalengine -import "vitess.io/vitess/go/sqltypes" +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" +) type typeAggregation struct { double uint16 @@ -46,15 +49,25 @@ type typeAggregation struct { nullable bool } -func AggregateTypes(types []sqltypes.Type) sqltypes.Type { +func AggregateEvalTypes(types []Type, env *collations.Environment) (Type, error) { var typeAgg typeAggregation + var collAgg collationAggregation + var size, scale int32 for _, typ := range types { - var flag typeFlag - if typ == sqltypes.HexVal || typ == sqltypes.HexNum { - typ = sqltypes.Binary - flag = flagHex + typeAgg.addNullable(typ.typ, typ.nullable) + if err := collAgg.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil { + return Type{}, err } - typeAgg.add(typ, flag) + size = max(typ.size, size) + scale = max(typ.scale, scale) + } + return NewTypeEx(typeAgg.result(), collAgg.result().Collation, typeAgg.nullable, size, scale), nil +} + +func AggregateTypes(types []sqltypes.Type) sqltypes.Type { + var typeAgg typeAggregation + for _, typ := range types { + typeAgg.addNullable(typ, false) } return typeAgg.result() } @@ -75,6 +88,18 @@ func (ta *typeAggregation) addEval(e eval) { ta.add(t, f) } +func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) { + var flag typeFlag + if typ == sqltypes.HexVal || typ == sqltypes.HexNum { + typ = sqltypes.Binary + flag |= flagHex + } + if nullable { + flag |= flagNullable + } + ta.add(typ, flag) +} + func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) { if f&flagNullable != 0 { ta.nullable = true @@ -128,6 +153,23 @@ func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) { ta.total++ } +func nextSignedTypeForUnsigned(t sqltypes.Type) sqltypes.Type { + switch t { + case sqltypes.Uint8: + return sqltypes.Int16 + case sqltypes.Uint16: + return sqltypes.Int24 + case sqltypes.Uint24: + return sqltypes.Int32 + case sqltypes.Uint32: + return sqltypes.Int64 + case sqltypes.Uint64: + return sqltypes.Decimal + default: + panic("bad unsigned integer type") + } +} + func (ta *typeAggregation) result() sqltypes.Type { /* If all types are numeric, the aggregated type is also numeric: @@ -181,11 +223,14 @@ func (ta *typeAggregation) result() sqltypes.Type { if ta.unsigned == ta.total { return ta.unsignedMax } - if ta.unsignedMax == sqltypes.Uint64 && ta.signed > 0 { - return sqltypes.Decimal + if ta.signed == 0 { + panic("bad type aggregation for signed/unsigned types") + } + agtype := nextSignedTypeForUnsigned(ta.unsignedMax) + if sqltypes.IsSigned(agtype) { + return max(agtype, ta.signedMax) } - // TODO - return sqltypes.Uint64 + return agtype } if ta.char == ta.total { diff --git a/go/vt/vtgate/evalengine/api_type_aggregation_test.go b/go/vt/vtgate/evalengine/api_type_aggregation_test.go new file mode 100644 index 00000000000..1bf29eaffb3 --- /dev/null +++ b/go/vt/vtgate/evalengine/api_type_aggregation_test.go @@ -0,0 +1,78 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" +) + +var aggregationCases = []struct { + types []sqltypes.Type + result sqltypes.Type +}{ + {[]sqltypes.Type{sqltypes.Int64, sqltypes.Int32, sqltypes.Float64}, sqltypes.Float64}, + {[]sqltypes.Type{sqltypes.Int64, sqltypes.Decimal, sqltypes.Float64}, sqltypes.Float64}, + {[]sqltypes.Type{sqltypes.Int64, sqltypes.Int32, sqltypes.Decimal}, sqltypes.Decimal}, + {[]sqltypes.Type{sqltypes.Int64, sqltypes.Int32, sqltypes.Int64}, sqltypes.Int64}, + {[]sqltypes.Type{sqltypes.Int32, sqltypes.Int16, sqltypes.Int8}, sqltypes.Int32}, + {[]sqltypes.Type{sqltypes.Int32, sqltypes.Uint16, sqltypes.Uint8}, sqltypes.Int32}, + {[]sqltypes.Type{sqltypes.Int32, sqltypes.Uint16, sqltypes.Uint32}, sqltypes.Int64}, + {[]sqltypes.Type{sqltypes.Int32, sqltypes.Uint16, sqltypes.Uint64}, sqltypes.Decimal}, + {[]sqltypes.Type{sqltypes.Bit, sqltypes.Bit, sqltypes.Bit}, sqltypes.Bit}, + {[]sqltypes.Type{sqltypes.Bit, sqltypes.Int32, sqltypes.Float64}, sqltypes.Float64}, + {[]sqltypes.Type{sqltypes.Bit, sqltypes.Decimal, sqltypes.Float64}, sqltypes.Float64}, + {[]sqltypes.Type{sqltypes.Bit, sqltypes.Int32, sqltypes.Decimal}, sqltypes.Decimal}, + {[]sqltypes.Type{sqltypes.Bit, sqltypes.Int32, sqltypes.Int64}, sqltypes.Int64}, + {[]sqltypes.Type{sqltypes.Char, sqltypes.VarChar}, sqltypes.VarChar}, + {[]sqltypes.Type{sqltypes.Char, sqltypes.Char}, sqltypes.VarChar}, + {[]sqltypes.Type{sqltypes.Char, sqltypes.VarChar, sqltypes.VarBinary}, sqltypes.VarBinary}, + {[]sqltypes.Type{sqltypes.Char, sqltypes.Char, sqltypes.Set, sqltypes.Enum}, sqltypes.VarChar}, + {[]sqltypes.Type{sqltypes.TypeJSON, sqltypes.TypeJSON}, sqltypes.TypeJSON}, + {[]sqltypes.Type{sqltypes.Geometry, sqltypes.Geometry}, sqltypes.Geometry}, +} + +func TestTypeAggregations(t *testing.T) { + for i, tc := range aggregationCases { + t.Run(fmt.Sprintf("%d.%v", i, tc.result), func(t *testing.T) { + res := AggregateTypes(tc.types) + require.Equalf(t, tc.result, res, "expected aggregate(%v) = %v, got %v", tc.types, tc.result, res) + }) + } +} + +func TestEvalengineTypeAggregations(t *testing.T) { + for i, tc := range aggregationCases { + t.Run(fmt.Sprintf("%d.%v", i, tc.result), func(t *testing.T) { + var types []Type + for _, tt := range tc.types { + // this test only aggregates binary collations because textual collation + // aggregation is tested in the `mysql/collations` package + types = append(types, NewType(tt, collations.CollationBinaryID)) + } + + res, err := AggregateEvalTypes(types, collations.MySQL8()) + require.NoError(t, err) + require.Equalf(t, tc.result, res.Type(), "expected aggregate(%v) = %v, got %v", tc.types, tc.result, res) + }) + } +}