diff --git a/go/test/endtoend/vtgate/queries/union/union_test.go b/go/test/endtoend/vtgate/queries/union/union_test.go index 52aa94bda51..2365e7a2fc9 100644 --- a/go/test/endtoend/vtgate/queries/union/union_test.go +++ b/go/test/endtoend/vtgate/queries/union/union_test.go @@ -118,6 +118,7 @@ func TestUnionAll(t *testing.T) { func TestUnion(t *testing.T) { mcmp, closer := start(t) defer closer() + mcmp.Exec("insert into t1(id1, id2) values(1, 1), (2, 2)") mcmp.AssertMatches(`SELECT 1 UNION SELECT 1 UNION SELECT 1`, `[[INT64(1)]]`) mcmp.AssertMatches(`SELECT 1,'a' UNION SELECT 1,'a' UNION SELECT 1,'a' ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`) @@ -126,4 +127,7 @@ func TestUnion(t *testing.T) { mcmp.AssertMatches(`(SELECT 1,'a') UNION ALL (SELECT 1,'a') UNION ALL (SELECT 1,'a') ORDER BY 1`, `[[INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")]]`) mcmp.AssertMatches(`(SELECT 1,'a') ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`) mcmp.AssertMatches(`(SELECT 1,'a' order by 1) union (SELECT 1,'a' ORDER BY 1)`, `[[INT64(1) VARCHAR("a")]]`) + if utils.BinaryIsAtVersion(19, "vtgate") { + mcmp.AssertMatches(`(SELECT id2,'a' from t1 where id1 = 1) union (SELECT 'a',id2 from t1 where id1 = 2)`, `[[VARCHAR("1") VARCHAR("a")] [VARCHAR("a") VARCHAR("2")]]`) + } } diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 904a44ccb85..4df774f9f5d 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -19,11 +19,13 @@ package engine import ( "context" "sync" + "sync/atomic" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) // Concatenate Primitive is used to concatenate results from multiple sources. @@ -86,8 +88,8 @@ func formatTwoOptionsNicely(a, b string) string { var errWrongNumberOfColumnsInSelect = vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.WrongNumberOfColumnsInSelect, "The used SELECT statements have a different number of columns") // TryExecute performs a non-streaming exec. -func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - res, err := c.execSources(ctx, vcursor, bindVars, wantfields) +func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { + res, err := c.execSources(ctx, vcursor, bindVars, true) if err != nil { return nil, err } @@ -97,49 +99,79 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars return nil, err } - var rowsAffected uint64 var rows [][]sqltypes.Value - - for _, r := range res { - rowsAffected += r.RowsAffected - - if len(rows) > 0 && - len(r.Rows) > 0 && - len(rows[0]) != len(r.Rows[0]) { - return nil, errWrongNumberOfColumnsInSelect - } - - rows = append(rows, r.Rows...) + err = c.coerceAndVisitResults(res, fields, func(result *sqltypes.Result) error { + rows = append(rows, result.Rows...) + return nil + }) + if err != nil { + return nil, err } return &sqltypes.Result{ - Fields: fields, - RowsAffected: rowsAffected, - Rows: rows, + Fields: fields, + Rows: rows, }, nil } -func (c *Concatenate) getFields(res []*sqltypes.Result) ([]*querypb.Field, error) { +func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) error { + if len(row) != len(fields) { + return errWrongNumberOfColumnsInSelect + } + + for i, value := range row { + if _, found := c.NoNeedToTypeCheck[i]; found { + continue + } + if fields[i].Type != value.Type() { + newValue, err := evalengine.CoerceTo(value, fields[i].Type) + if err != nil { + return err + } + row[i] = newValue + } + } + return nil +} + +func (c *Concatenate) getFields(res []*sqltypes.Result) (resultFields []*querypb.Field, err error) { if len(res) == 0 { return nil, nil } - var fields []*querypb.Field - for _, r := range res { - if r.Fields == nil { - continue + resultFields = res[0].Fields + columns := make([][]sqltypes.Type, len(resultFields)) + + addFields := func(fields []*querypb.Field) error { + if len(fields) != len(columns) { + return errWrongNumberOfColumnsInSelect } - if fields == nil { - fields = r.Fields - continue + for idx, field := range fields { + columns[idx] = append(columns[idx], field.Type) } + return nil + } - err := c.compareFields(fields, r.Fields) + for _, r := range res { + if r == nil || r.Fields == nil { + continue + } + err := addFields(r.Fields) if err != nil { return nil, err } } - return fields, nil + + // The resulting column types need to be the coercion of all the input columns + for colIdx, t := range columns { + if _, found := c.NoNeedToTypeCheck[colIdx]; found { + continue + } + + resultFields[colIdx].Type = evalengine.AggregateTypes(t) + } + + return resultFields, nil } func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) { @@ -152,7 +184,7 @@ func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars return c.parallelExec(ctx, vcursor, bindVars, wantfields) } -func (c *Concatenate) parallelExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) { +func (c *Concatenate) parallelExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) ([]*sqltypes.Result, error) { results := make([]*sqltypes.Result, len(c.Sources)) var outerErr error @@ -166,7 +198,7 @@ func (c *Concatenate) parallelExec(ctx context.Context, vcursor VCursor, bindVar wg.Add(1) go func() { defer wg.Done() - result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields) + result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, true) if err != nil { outerErr = err cancel() @@ -178,12 +210,12 @@ func (c *Concatenate) parallelExec(ctx context.Context, vcursor VCursor, bindVar return results, outerErr } -func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) { +func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) ([]*sqltypes.Result, error) { results := make([]*sqltypes.Result, len(c.Sources)) for i, source := range c.Sources { currIndex, currSource := i, source vars := copyBindVars(bindVars) - result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields) + result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, true) if err != nil { return nil, err } @@ -193,24 +225,53 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV } // TryStreamExecute performs a streaming exec. -func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { +func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error { if vcursor.Session().InTransaction() { - // as we are in a transaction, we need to execute all queries inside a single transaction - // therefore it needs a sequential execution. - return c.sequentialStreamExec(ctx, vcursor, bindVars, wantfields, callback) + // as we are in a transaction, we need to execute all queries inside a single connection, + // which holds the single transaction we have + return c.sequentialStreamExec(ctx, vcursor, bindVars, callback) } // not in transaction, so execute in parallel. - return c.parallelStreamExec(ctx, vcursor, bindVars, wantfields, callback) + return c.parallelStreamExec(ctx, vcursor, bindVars, callback) } -func (c *Concatenate) parallelStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - var seenFields []*querypb.Field +func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error { + ctx, cancel := context.WithCancel(inCtx) + defer cancel() var outerErr error - var fieldsSent bool - var cbMu, fieldsMu sync.Mutex - var wg, fieldSendWg sync.WaitGroup - fieldSendWg.Add(1) + var cbMu sync.Mutex + var wg, fieldMu sync.WaitGroup + var fieldRec atomic.Int64 + fieldRec.Store(int64(len(c.Sources))) + fieldMu.Add(1) + + rest := make([]*sqltypes.Result, len(c.Sources)) + var fields []*querypb.Field + callback := func(res *sqltypes.Result, srcIdx int) error { + cbMu.Lock() + defer cbMu.Unlock() + + needsCoercion := false + for idx, field := range rest[srcIdx].Fields { + _, ok := c.NoNeedToTypeCheck[idx] + if !ok && fields[idx].Type != field.Type { + needsCoercion = true + break + } + } + if needsCoercion { + for _, row := range res.Rows { + err := c.coerceValuesTo(row, fields) + if err != nil { + return err + } + } + } + return in(res) + } + + once := sync.Once{} for i, source := range c.Sources { wg.Add(1) @@ -218,49 +279,43 @@ func (c *Concatenate) parallelStreamExec(ctx context.Context, vcursor VCursor, b go func() { defer wg.Done() - err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { + err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error { // if we have fields to compare, make sure all the fields are all the same - if currIndex == 0 { - fieldsMu.Lock() - if !fieldsSent { - defer fieldSendWg.Done() - defer fieldsMu.Unlock() - seenFields = resultChunk.Fields - fieldsSent = true - // No other call can happen before this call. - return callback(resultChunk) - } - fieldsMu.Unlock() - } - fieldSendWg.Wait() - if resultChunk.Fields != nil { - err := c.compareFields(seenFields, resultChunk.Fields) - if err != nil { - return err + if fieldRec.Load() > 0 && resultChunk.Fields != nil { + rest[currIndex] = resultChunk + res := fieldRec.Add(-1) + if res == 0 { + // We have received fields from all sources. We can now calculate the output types + var err error + fields, err = c.getFields(rest) + if err != nil { + return err + } + resultChunk.Fields = fields + defer once.Do(func() { + fieldMu.Done() + }) + + return callback(resultChunk, currIndex) + } else { + fieldMu.Wait() } } - // This to ensure only one send happens back to the client. - cbMu.Lock() - defer cbMu.Unlock() + + // If we get here, all the fields have been received select { case <-ctx.Done(): return nil default: - return callback(resultChunk) + return callback(resultChunk, currIndex) } }) - // This is to ensure other streams complete if the first stream failed to unlock the wait. - if currIndex == 0 { - fieldsMu.Lock() - if !fieldsSent { - fieldsSent = true - fieldSendWg.Done() - } - fieldsMu.Unlock() - } if err != nil { outerErr = err - ctx.Done() + cancel() + once.Do(func() { + fieldMu.Done() + }) } }() @@ -269,61 +324,110 @@ func (c *Concatenate) parallelStreamExec(ctx context.Context, vcursor VCursor, b return outerErr } -func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { +func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { // all the below fields ensure that the fields are sent only once. - var seenFields []*querypb.Field - var fieldsMu sync.Mutex - var fieldsSent bool + results := make([][]*sqltypes.Result, len(c.Sources)) + var mu sync.Mutex for idx, source := range c.Sources { - err := vcursor.StreamExecutePrimitive(ctx, source, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { - // if we have fields to compare, make sure all the fields are all the same - if idx == 0 { - fieldsMu.Lock() - defer fieldsMu.Unlock() - if !fieldsSent { - fieldsSent = true - seenFields = resultChunk.Fields - return callback(resultChunk) - } - } - if resultChunk.Fields != nil { - err := c.compareFields(seenFields, resultChunk.Fields) - if err != nil { - return err - } - } + err := vcursor.StreamExecutePrimitive(ctx, source, bindVars, true, func(resultChunk *sqltypes.Result) error { // check if context has expired. if ctx.Err() != nil { return ctx.Err() } - return callback(resultChunk) + mu.Lock() + defer mu.Unlock() + // This visitor will just accumulate all the results into slices + results[idx] = append(results[idx], resultChunk) + + return nil }) if err != nil { return err } } + + firsts := make([]*sqltypes.Result, len(c.Sources)) + for i, result := range results { + firsts[i] = result[0] + } + + fields, err := c.getFields(firsts) + if err != nil { + return err + } + for _, res := range results { + if err = c.coerceAndVisitResults(res, fields, callback); err != nil { + return err + } + } + + return nil +} + +func (c *Concatenate) coerceAndVisitResults( + res []*sqltypes.Result, + fields []*querypb.Field, + callback func(*sqltypes.Result) error, +) error { + for _, r := range res { + if len(r.Rows) > 0 && + len(fields) != len(r.Rows[0]) { + return errWrongNumberOfColumnsInSelect + } + + needsCoercion := false + for idx, field := range r.Fields { + if fields[idx].Type != field.Type { + needsCoercion = true + break + } + } + if needsCoercion { + for _, row := range r.Rows { + err := c.coerceValuesTo(row, fields) + if err != nil { + return err + } + } + } + err := callback(r) + if err != nil { + return err + } + } return nil } // GetFields fetches the field info. func (c *Concatenate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - // TODO: type coercions res, err := c.Sources[0].GetFields(ctx, vcursor, bindVars) if err != nil { return nil, err } + columns := make([][]sqltypes.Type, len(res.Fields)) + + addFields := func(fields []*querypb.Field) { + for idx, field := range fields { + columns[idx] = append(columns[idx], field.Type) + } + } + + addFields(res.Fields) + for i := 1; i < len(c.Sources); i++ { result, err := c.Sources[i].GetFields(ctx, vcursor, bindVars) if err != nil { return nil, err } - err = c.compareFields(res.Fields, result.Fields) - if err != nil { - return nil, err - } + addFields(result.Fields) + } + + // The resulting column types need to be the coercion of all the input columns + for colIdx, t := range columns { + res.Fields[colIdx].Type = evalengine.AggregateTypes(t) } return res, nil @@ -347,19 +451,3 @@ func (c *Concatenate) Inputs() ([]Primitive, []map[string]any) { func (c *Concatenate) description() PrimitiveDescription { return PrimitiveDescription{OperatorType: c.RouteType()} } - -func (c *Concatenate) compareFields(fields1 []*querypb.Field, fields2 []*querypb.Field) error { - if len(fields1) != len(fields2) { - return errWrongNumberOfColumnsInSelect - } - for i, field1 := range fields1 { - if _, found := c.NoNeedToTypeCheck[i]; found { - continue - } - field2 := fields2[i] - if field1.Type != field2.Type { - return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "merging field of different types is not supported, name: (%v, %v) types: (%v, %v)", field1.Name, field2.Name, field1.Type, field2.Type) - } - } - return nil -} diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index 4a6305e3a0a..b886d1312af 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -19,8 +19,13 @@ package engine import ( "context" "errors" + "fmt" + "strings" "testing" + "vitess.io/vitess/go/test/utils" + + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -68,7 +73,7 @@ func TestConcatenate_NoErrors(t *testing.T) { r("id|col1|col2", "int64|varbinary|varbinary", "1|a1|b1", "2|a2|b2"), r("id|col3|col4", "int64|varchar|varbinary", "1|a1|b1", "2|a2|b2"), }, - expectedError: "merging field of different types is not supported", + expectedResult: r("id|col1|col2", "int64|varbinary|varbinary", "1|a1|b1", "2|a2|b2", "1|a1|b1", "2|a2|b2", "1|a1|b1", "2|a2|b2"), }, { testName: "ignored field types - ignored", inputs: []*sqltypes.Result{ @@ -95,35 +100,42 @@ func TestConcatenate_NoErrors(t *testing.T) { }} for _, tc := range testCases { - var sources []Primitive - for _, input := range tc.inputs { - // input is added twice, since the first one is used by execute and the next by stream execute - sources = append(sources, &fakePrimitive{results: []*sqltypes.Result{input, input}}) - } - - concatenate := NewConcatenate(sources, tc.ignoreTypes) - - t.Run(tc.testName+"-Execute", func(t *testing.T) { - qr, err := concatenate.TryExecute(context.Background(), &noopVCursor{}, nil, true) - if tc.expectedError == "" { - require.NoError(t, err) - require.Equal(t, tc.expectedResult, qr) - } else { - require.Error(t, err) - require.Contains(t, err.Error(), tc.expectedError) + for _, tx := range []bool{false, true} { + var sources []Primitive + for _, input := range tc.inputs { + // input is added twice, since the first one is used by execute and the next by stream execute + sources = append(sources, &fakePrimitive{results: []*sqltypes.Result{input, input}}) } - }) - t.Run(tc.testName+"-StreamExecute", func(t *testing.T) { - qr, err := wrapStreamExecute(concatenate, &noopVCursor{}, nil, true) - if tc.expectedError == "" { - require.NoError(t, err) - require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows)) - } else { - require.Error(t, err) - require.Contains(t, err.Error(), tc.expectedError) + concatenate := NewConcatenate(sources, tc.ignoreTypes) + vcursor := &noopVCursor{inTx: tx} + txStr := "InTx" + if !tx { + txStr = "NotInTx" } - }) + t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) { + qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true) + if tc.expectedError == "" { + require.NoError(t, err) + utils.MustMatch(t, tc.expectedResult.Fields, qr.Fields, "fields") + utils.MustMatch(t, tc.expectedResult.Rows, qr.Rows) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectedError) + } + }) + + t.Run(fmt.Sprintf("%s-%s-StreamExec", txStr, tc.testName), func(t *testing.T) { + qr, err := wrapStreamExecute(concatenate, vcursor, nil, true) + if tc.expectedError == "" { + require.NoError(t, err) + require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows)) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectedError) + } + }) + } } } @@ -156,3 +168,36 @@ func TestConcatenate_WithErrors(t *testing.T) { _, err = wrapStreamExecute(concatenate, &noopVCursor{}, nil, true) require.EqualError(t, err, strFailed) } + +func TestConcatenateTypes(t *testing.T) { + tests := []struct { + t1, t2, expected string + }{ + {t1: "int32", t2: "int64", expected: "int64"}, + {t1: "int32", t2: "int32", expected: "int32"}, + {t1: "int32", t2: "varchar", expected: "varchar"}, + {t1: "int32", t2: "decimal", expected: "decimal"}, + {t1: "hexval", t2: "uint64", expected: "varchar"}, + {t1: "varchar", t2: "varbinary", expected: "varbinary"}, + } + + for _, test := range tests { + name := fmt.Sprintf("%s - %s", test.t1, test.t2) + t.Run(name, func(t *testing.T) { + in1 := r("id", test.t1, "1") + in2 := r("id", test.t2, "1") + concatenate := NewConcatenate( + []Primitive{ + &fakePrimitive{results: []*sqltypes.Result{in1}}, + &fakePrimitive{results: []*sqltypes.Result{in2}}, + }, nil, + ) + + res, err := concatenate.GetFields(context.Background(), &noopVCursor{}, nil) + require.NoError(t, err) + + expected := fmt.Sprintf(`[name:"id" type:%s]`, test.expected) + assert.Equal(t, expected, strings.ToLower(fmt.Sprintf("%v", res.Fields))) + }) + } +} diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index b51eebf0eb8..c2418c73560 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -49,6 +49,7 @@ var _ SessionActions = (*noopVCursor)(nil) // noopVCursor is used to build other vcursors. type noopVCursor struct { + inTx bool } func (t *noopVCursor) Commit(ctx context.Context) error { @@ -61,7 +62,7 @@ func (t *noopVCursor) GetUDV(key string) *querypb.BindVariable { } func (t *noopVCursor) InTransaction() bool { - return false + return t.inTx } func (t *noopVCursor) SetCommitOrder(co vtgatepb.CommitOrder) { diff --git a/go/vt/vtgate/evalengine/api_coerce.go b/go/vt/vtgate/evalengine/api_coerce.go new file mode 100644 index 00000000000..130727d8f31 --- /dev/null +++ b/go/vt/vtgate/evalengine/api_coerce.go @@ -0,0 +1,30 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" +) + +func CoerceTo(value sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) { + cast, err := valueToEvalCast(value, value.Type(), collations.Unknown) + if err != nil { + return sqltypes.Value{}, err + } + return evalToSQLValueWithType(cast, typ), nil +} diff --git a/go/vt/vtgate/evalengine/api_type_aggregation.go b/go/vt/vtgate/evalengine/api_type_aggregation.go new file mode 100644 index 00000000000..a5cdc688858 --- /dev/null +++ b/go/vt/vtgate/evalengine/api_type_aggregation.go @@ -0,0 +1,205 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import "vitess.io/vitess/go/sqltypes" + +type typeAggregation struct { + double uint16 + decimal uint16 + signed uint16 + unsigned uint16 + + signedMax sqltypes.Type + unsignedMax sqltypes.Type + + bit uint16 + year uint16 + char uint16 + binary uint16 + charother uint16 + json uint16 + + date uint16 + time uint16 + timestamp uint16 + datetime uint16 + + geometry uint16 + blob uint16 + total uint16 +} + +func AggregateTypes(types []sqltypes.Type) sqltypes.Type { + var typeAgg typeAggregation + for _, typ := range types { + var flag typeFlag + if typ == sqltypes.HexVal || typ == sqltypes.HexNum { + typ = sqltypes.Binary + flag = flagHex + } + typeAgg.add(typ, flag) + } + return typeAgg.result() +} + +func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) { + switch tt { + case sqltypes.Float32, sqltypes.Float64: + ta.double++ + case sqltypes.Decimal: + ta.decimal++ + case sqltypes.Int8, sqltypes.Int16, sqltypes.Int24, sqltypes.Int32, sqltypes.Int64: + ta.signed++ + if tt > ta.signedMax { + ta.signedMax = tt + } + case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: + ta.unsigned++ + if tt > ta.unsignedMax { + ta.unsignedMax = tt + } + case sqltypes.Bit: + ta.bit++ + case sqltypes.Year: + ta.year++ + case sqltypes.Char, sqltypes.VarChar, sqltypes.Set, sqltypes.Enum: + if f&flagExplicitCollation != 0 { + ta.charother++ + } + ta.char++ + case sqltypes.Binary, sqltypes.VarBinary: + if f&flagHex != 0 { + ta.charother++ + } + ta.binary++ + case sqltypes.TypeJSON: + ta.json++ + case sqltypes.Date: + ta.date++ + case sqltypes.Datetime: + ta.datetime++ + case sqltypes.Time: + ta.time++ + case sqltypes.Timestamp: + ta.timestamp++ + case sqltypes.Geometry: + ta.geometry++ + case sqltypes.Blob: + ta.blob++ + default: + return + } + ta.total++ +} + +func (ta *typeAggregation) result() sqltypes.Type { + /* + If all types are numeric, the aggregated type is also numeric: + If at least one argument is double precision, the result is double precision. + Otherwise, if at least one argument is DECIMAL, the result is DECIMAL. + Otherwise, the result is an integer type (with one exception): + If all integer types are all signed or all unsigned, the result is the same sign and the precision is the highest of all specified integer types (that is, TINYINT, SMALLINT, MEDIUMINT, INT, or BIGINT). + If there is a combination of signed and unsigned integer types, the result is signed and the precision may be higher. For example, if the types are signed INT and unsigned INT, the result is signed BIGINT. + The exception is unsigned BIGINT combined with any signed integer type. The result is DECIMAL with sufficient precision and scale 0. + If all types are BIT, the result is BIT. Otherwise, BIT arguments are treated similar to BIGINT. + If all types are YEAR, the result is YEAR. Otherwise, YEAR arguments are treated similar to INT. + If all types are character string (CHAR or VARCHAR), the result is VARCHAR with maximum length determined by the longest character length of the operands. + If all types are character or binary string, the result is VARBINARY. + SET and ENUM are treated similar to VARCHAR; the result is VARCHAR. + If all types are JSON, the result is JSON. + If all types are temporal, the result is temporal: + If all temporal types are DATE, TIME, or TIMESTAMP, the result is DATE, TIME, or TIMESTAMP, respectively. + Otherwise, for a mix of temporal types, the result is DATETIME. + If all types are GEOMETRY, the result is GEOMETRY. + If any type is BLOB, the result is BLOB. + For all other type combinations, the result is VARCHAR. + Literal NULL operands are ignored for type aggregation. + */ + + if ta.bit == ta.total { + return sqltypes.Bit + } else if ta.bit > 0 { + ta.signed += ta.bit + ta.signedMax = sqltypes.Int64 + } + + if ta.year == ta.total { + return sqltypes.Year + } else if ta.year > 0 { + ta.signed += ta.year + if sqltypes.Int32 > ta.signedMax { + ta.signedMax = sqltypes.Int32 + } + } + + if ta.double+ta.decimal+ta.signed+ta.unsigned == ta.total { + if ta.double > 0 { + return sqltypes.Float64 + } + if ta.decimal > 0 { + return sqltypes.Decimal + } + if ta.signed == ta.total { + return ta.signedMax + } + if ta.unsigned == ta.total { + return ta.unsignedMax + } + if ta.unsignedMax == sqltypes.Uint64 && ta.signed > 0 { + return sqltypes.Decimal + } + // TODO + return sqltypes.Uint64 + } + + if ta.char == ta.total { + return sqltypes.VarChar + } + if ta.char+ta.binary == ta.total { + // HACK: this is not in the official documentation, but groups of strings where + // one of the strings is not directly a VARCHAR or VARBINARY (e.g. a hex literal, + // or a VARCHAR that has been explicitly collated) will result in VARCHAR when + // aggregated + if ta.charother > 0 { + return sqltypes.VarChar + } + return sqltypes.VarBinary + } + if ta.json == ta.total { + return sqltypes.TypeJSON + } + if ta.date+ta.time+ta.timestamp+ta.datetime == ta.total { + if ta.date == ta.total { + return sqltypes.Date + } + if ta.time == ta.total { + return sqltypes.Time + } + if ta.timestamp == ta.total { + return sqltypes.Timestamp + } + return sqltypes.Datetime + } + if ta.geometry == ta.total { + return sqltypes.Geometry + } + if ta.blob > 0 { + return sqltypes.Blob + } + return sqltypes.VarChar +} diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index ee4f61cb596..276e6caa5f1 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -423,176 +423,3 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { } return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected argument for GREATEST/LEAST") } - -type typeAggregation struct { - double uint16 - decimal uint16 - signed uint16 - unsigned uint16 - - signedMax sqltypes.Type - unsignedMax sqltypes.Type - - bit uint16 - year uint16 - char uint16 - binary uint16 - charother uint16 - json uint16 - - date uint16 - time uint16 - timestamp uint16 - datetime uint16 - - geometry uint16 - blob uint16 - total uint16 -} - -func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) { - switch tt { - case sqltypes.Float32, sqltypes.Float64: - ta.double++ - case sqltypes.Decimal: - ta.decimal++ - case sqltypes.Int8, sqltypes.Int16, sqltypes.Int24, sqltypes.Int32, sqltypes.Int64: - ta.signed++ - if tt > ta.signedMax { - ta.signedMax = tt - } - case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: - ta.unsigned++ - if tt > ta.unsignedMax { - ta.unsignedMax = tt - } - case sqltypes.Bit: - ta.bit++ - case sqltypes.Year: - ta.year++ - case sqltypes.Char, sqltypes.VarChar, sqltypes.Set, sqltypes.Enum: - if f&flagExplicitCollation != 0 { - ta.charother++ - } - ta.char++ - case sqltypes.Binary, sqltypes.VarBinary: - if f&flagHex != 0 { - ta.charother++ - } - ta.binary++ - case sqltypes.TypeJSON: - ta.json++ - case sqltypes.Date: - ta.date++ - case sqltypes.Datetime: - ta.datetime++ - case sqltypes.Time: - ta.time++ - case sqltypes.Timestamp: - ta.timestamp++ - case sqltypes.Geometry: - ta.geometry++ - case sqltypes.Blob: - ta.blob++ - default: - return - } - ta.total++ -} - -func (ta *typeAggregation) result() sqltypes.Type { - /* - If all types are numeric, the aggregated type is also numeric: - If at least one argument is double precision, the result is double precision. - Otherwise, if at least one argument is DECIMAL, the result is DECIMAL. - Otherwise, the result is an integer type (with one exception): - If all integer types are all signed or all unsigned, the result is the same sign and the precision is the highest of all specified integer types (that is, TINYINT, SMALLINT, MEDIUMINT, INT, or BIGINT). - If there is a combination of signed and unsigned integer types, the result is signed and the precision may be higher. For example, if the types are signed INT and unsigned INT, the result is signed BIGINT. - The exception is unsigned BIGINT combined with any signed integer type. The result is DECIMAL with sufficient precision and scale 0. - If all types are BIT, the result is BIT. Otherwise, BIT arguments are treated similar to BIGINT. - If all types are YEAR, the result is YEAR. Otherwise, YEAR arguments are treated similar to INT. - If all types are character string (CHAR or VARCHAR), the result is VARCHAR with maximum length determined by the longest character length of the operands. - If all types are character or binary string, the result is VARBINARY. - SET and ENUM are treated similar to VARCHAR; the result is VARCHAR. - If all types are JSON, the result is JSON. - If all types are temporal, the result is temporal: - If all temporal types are DATE, TIME, or TIMESTAMP, the result is DATE, TIME, or TIMESTAMP, respectively. - Otherwise, for a mix of temporal types, the result is DATETIME. - If all types are GEOMETRY, the result is GEOMETRY. - If any type is BLOB, the result is BLOB. - For all other type combinations, the result is VARCHAR. - Literal NULL operands are ignored for type aggregation. - */ - - if ta.bit == ta.total { - return sqltypes.Bit - } else if ta.bit > 0 { - ta.signed += ta.bit - ta.signedMax = sqltypes.Int64 - } - - if ta.year == ta.total { - return sqltypes.Year - } else if ta.year > 0 { - ta.signed += ta.year - if sqltypes.Int32 > ta.signedMax { - ta.signedMax = sqltypes.Int32 - } - } - - if ta.double+ta.decimal+ta.signed+ta.unsigned == ta.total { - if ta.double > 0 { - return sqltypes.Float64 - } - if ta.decimal > 0 { - return sqltypes.Decimal - } - if ta.signed == ta.total { - return ta.signedMax - } - if ta.unsigned == ta.total { - return ta.unsignedMax - } - if ta.unsignedMax == sqltypes.Uint64 && ta.signed > 0 { - return sqltypes.Decimal - } - // TODO - return sqltypes.Uint64 - } - - if ta.char == ta.total { - return sqltypes.VarChar - } - if ta.char+ta.binary == ta.total { - // HACK: this is not in the official documentation, but groups of strings where - // one of the strings is not directly a VARCHAR or VARBINARY (e.g. a hex literal, - // or a VARCHAR that has been explicitly collated) will result in VARCHAR when - // aggregated - if ta.charother > 0 { - return sqltypes.VarChar - } - return sqltypes.VarBinary - } - if ta.json == ta.total { - return sqltypes.TypeJSON - } - if ta.date+ta.time+ta.timestamp+ta.datetime == ta.total { - if ta.date == ta.total { - return sqltypes.Date - } - if ta.time == ta.total { - return sqltypes.Time - } - if ta.timestamp == ta.total { - return sqltypes.Timestamp - } - return sqltypes.Datetime - } - if ta.geometry == ta.total { - return sqltypes.Geometry - } - if ta.blob > 0 { - return sqltypes.Blob - } - return sqltypes.VarChar -} diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 562a477c9d4..45ccb041ddd 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -249,8 +249,8 @@ func (a *Aggregator) ShortDescription() string { return fmt.Sprintf("%s%s group by %s", org, strings.Join(columns, ", "), strings.Join(grouping, ",")) } -func (a *Aggregator) GetOrdering() []ops.OrderBy { - return a.Source.GetOrdering() +func (a *Aggregator) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return a.Source.GetOrdering(ctx) } func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) { diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index af8be7adf32..138c17f2da7 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -175,8 +175,8 @@ func (aj *ApplyJoin) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser. return transformColumnsToSelectExprs(ctx, aj) } -func (aj *ApplyJoin) GetOrdering() []ops.OrderBy { - return aj.LHS.GetOrdering() +func (aj *ApplyJoin) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return aj.LHS.GetOrdering(ctx) } func joinColumnToAliasedExpr(c JoinColumn) *sqlparser.AliasedExpr { diff --git a/go/vt/vtgate/planbuilder/operators/comments.go b/go/vt/vtgate/planbuilder/operators/comments.go index b480df9ee66..9ede4b9e0da 100644 --- a/go/vt/vtgate/planbuilder/operators/comments.go +++ b/go/vt/vtgate/planbuilder/operators/comments.go @@ -76,6 +76,6 @@ func (l *LockAndComment) ShortDescription() string { return strings.Join(s, " ") } -func (l *LockAndComment) GetOrdering() []ops.OrderBy { - return l.Source.GetOrdering() +func (l *LockAndComment) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return l.Source.GetOrdering(ctx) } diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index 2bf5a65a893..2455bb958b5 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -62,7 +62,7 @@ func (d *Delete) TablesUsed() []string { return nil } -func (d *Delete) GetOrdering() []ops.OrderBy { +func (d *Delete) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index b141706e847..d6bbdff8088 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -114,8 +114,8 @@ func (d *Distinct) ShortDescription() string { return "Performance" } -func (d *Distinct) GetOrdering() []ops.OrderBy { - return d.Source.GetOrdering() +func (d *Distinct) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return d.Source.GetOrdering(ctx) } func (d *Distinct) setTruncateColumnCount(offset int) { diff --git a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index 6e531693752..ed43910b75d 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -101,8 +101,8 @@ func (f *Filter) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.Sele return f.Source.GetSelectExprs(ctx) } -func (f *Filter) GetOrdering() []ops.OrderBy { - return f.Source.GetOrdering() +func (f *Filter) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return f.Source.GetOrdering(ctx) } func (f *Filter) Compact(*plancontext.PlanningContext) (ops.Operator, *rewrite.ApplyResult, error) { diff --git a/go/vt/vtgate/planbuilder/operators/fk_cascade.go b/go/vt/vtgate/planbuilder/operators/fk_cascade.go index e37280b0d08..90c797d55e8 100644 --- a/go/vt/vtgate/planbuilder/operators/fk_cascade.go +++ b/go/vt/vtgate/planbuilder/operators/fk_cascade.go @@ -20,6 +20,7 @@ import ( "slices" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) // FkChild is used to represent a foreign key child table operation @@ -96,7 +97,7 @@ func (fkc *FkCascade) Clone(inputs []ops.Operator) ops.Operator { } // GetOrdering implements the Operator interface -func (fkc *FkCascade) GetOrdering() []ops.OrderBy { +func (fkc *FkCascade) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/fk_verify.go b/go/vt/vtgate/planbuilder/operators/fk_verify.go index 1bac01ca7d0..39e1092c8d9 100644 --- a/go/vt/vtgate/planbuilder/operators/fk_verify.go +++ b/go/vt/vtgate/planbuilder/operators/fk_verify.go @@ -18,6 +18,7 @@ package operators import ( "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) // VerifyOp keeps the information about the foreign key verification operation. @@ -70,7 +71,7 @@ func (fkv *FkVerify) Clone(inputs []ops.Operator) ops.Operator { } // GetOrdering implements the Operator interface -func (fkv *FkVerify) GetOrdering() []ops.OrderBy { +func (fkv *FkVerify) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index 3057553984b..919767d550f 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -175,9 +175,12 @@ func (h *Horizon) GetSelectExprs(*plancontext.PlanningContext) sqlparser.SelectE return sqlparser.GetFirstSelect(h.Query).SelectExprs } -func (h *Horizon) GetOrdering() []ops.OrderBy { +func (h *Horizon) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { if h.QP == nil { - panic(vterrors.VT13001("QP should already be here")) + _, err := h.getQP(ctx) + if err != nil { + panic(err) + } } return h.QP.OrderExprs } diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 3afeb79d88a..8bdee0a11a8 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -96,7 +96,7 @@ func (i *Insert) ShortDescription() string { return i.VTable.String() } -func (i *Insert) GetOrdering() []ops.OrderBy { +func (i *Insert) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 5574e859953..7cc29b42deb 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -48,7 +48,7 @@ func (j *Join) Clone(inputs []ops.Operator) ops.Operator { } } -func (j *Join) GetOrdering() []ops.OrderBy { +func (j *Join) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/limit.go b/go/vt/vtgate/planbuilder/operators/limit.go index 12929c69e7d..a6ea925b135 100644 --- a/go/vt/vtgate/planbuilder/operators/limit.go +++ b/go/vt/vtgate/planbuilder/operators/limit.go @@ -68,8 +68,8 @@ func (l *Limit) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.Selec return l.Source.GetSelectExprs(ctx) } -func (l *Limit) GetOrdering() []ops.OrderBy { - return l.Source.GetOrdering() +func (l *Limit) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return l.Source.GetOrdering(ctx) } func (l *Limit) ShortDescription() string { diff --git a/go/vt/vtgate/planbuilder/operators/ops/op.go b/go/vt/vtgate/planbuilder/operators/ops/op.go index 379011d99d9..1117b947814 100644 --- a/go/vt/vtgate/planbuilder/operators/ops/op.go +++ b/go/vt/vtgate/planbuilder/operators/ops/op.go @@ -53,7 +53,7 @@ type ( ShortDescription() string - GetOrdering() []OrderBy + GetOrdering(ctx *plancontext.PlanningContext) []OrderBy } // OrderBy contains the expression to used in order by and also if ordering is needed at VTGate level then what the weight_string function expression to be sent down for evaluation. diff --git a/go/vt/vtgate/planbuilder/operators/ordering.go b/go/vt/vtgate/planbuilder/operators/ordering.go index 813f091acbe..b3d0310eadb 100644 --- a/go/vt/vtgate/planbuilder/operators/ordering.go +++ b/go/vt/vtgate/planbuilder/operators/ordering.go @@ -74,7 +74,7 @@ func (o *Ordering) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.Se return o.Source.GetSelectExprs(ctx) } -func (o *Ordering) GetOrdering() []ops.OrderBy { +func (o *Ordering) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return o.Order } diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index 0dcc859055b..ba13a828d0b 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -165,7 +165,7 @@ func needsOrdering(ctx *plancontext.PlanningContext, in *Aggregator) (bool, erro if len(requiredOrder) == 0 { return false, nil } - srcOrdering := in.Source.GetOrdering() + srcOrdering := in.Source.GetOrdering(ctx) if len(srcOrdering) < len(requiredOrder) { return true, nil } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 027c95a6e10..2d4630bd87a 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -412,8 +412,8 @@ func (p *Projection) GetSelectExprs(*plancontext.PlanningContext) sqlparser.Sele } } -func (p *Projection) GetOrdering() []ops.OrderBy { - return p.Source.GetOrdering() +func (p *Projection) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return p.Source.GetOrdering(ctx) } // AllOffsets returns a slice of integer offsets for all columns in the Projection diff --git a/go/vt/vtgate/planbuilder/operators/querygraph.go b/go/vt/vtgate/planbuilder/operators/querygraph.go index c30bfb8fd21..b0e6b4440be 100644 --- a/go/vt/vtgate/planbuilder/operators/querygraph.go +++ b/go/vt/vtgate/planbuilder/operators/querygraph.go @@ -176,7 +176,7 @@ func (qg *QueryGraph) Clone([]ops.Operator) ops.Operator { return result } -func (qg *QueryGraph) GetOrdering() []ops.OrderBy { +func (qg *QueryGraph) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index 47b42f124d6..d4b2c43ecff 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -18,7 +18,6 @@ package operators import ( "fmt" - "strings" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/slice" @@ -658,8 +657,8 @@ func (r *Route) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.Selec return r.Source.GetSelectExprs(ctx) } -func (r *Route) GetOrdering() []ops.OrderBy { - return r.Source.GetOrdering() +func (r *Route) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return r.Source.GetOrdering(ctx) } // TablesUsed returns tables used by MergedWith routes, which are not included @@ -690,7 +689,7 @@ func (r *Route) planOffsets(ctx *plancontext.PlanningContext) { // if we are getting results from multiple shards, we need to do a merge-sort // between them to get the final output correctly sorted - ordering := r.Source.GetOrdering() + ordering := r.Source.GetOrdering(ctx) if len(ordering) == 0 { return } @@ -735,15 +734,6 @@ func (r *Route) ShortDescription() string { first += " " + info.extraInfo() } - orderBy := r.Source.GetOrdering() - ordering := "" - if len(orderBy) > 0 { - var oo []string - for _, o := range orderBy { - oo = append(oo, sqlparser.String(o.Inner)) - } - ordering = " order by " + strings.Join(oo, ",") - } comments := "" if r.Comments != nil { comments = " comments: " + sqlparser.String(r.Comments) @@ -752,7 +742,7 @@ func (r *Route) ShortDescription() string { if r.Lock != sqlparser.NoLock { lock = " lock: " + r.Lock.ToString() } - return first + ordering + comments + lock + return first + comments + lock } func (r *Route) setTruncateColumnCount(offset int) { diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 958b873e9e9..e06e595f689 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -125,8 +125,8 @@ func (sq *SubQuery) Clone(inputs []ops.Operator) ops.Operator { return &klone } -func (sq *SubQuery) GetOrdering() []ops.OrderBy { - return sq.Outer.GetOrdering() +func (sq *SubQuery) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return sq.Outer.GetOrdering(ctx) } // Inputs implements the Operator interface diff --git a/go/vt/vtgate/planbuilder/operators/subquery_container.go b/go/vt/vtgate/planbuilder/operators/subquery_container.go index fc2fc823fb4..ab8d1104623 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_container.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_container.go @@ -49,8 +49,8 @@ func (sqc *SubQueryContainer) Clone(inputs []ops.Operator) ops.Operator { return result } -func (sqc *SubQueryContainer) GetOrdering() []ops.OrderBy { - return sqc.Outer.GetOrdering() +func (sqc *SubQueryContainer) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return sqc.Outer.GetOrdering(ctx) } // Inputs implements the Operator interface diff --git a/go/vt/vtgate/planbuilder/operators/table.go b/go/vt/vtgate/planbuilder/operators/table.go index 4809a60c988..09a99170932 100644 --- a/go/vt/vtgate/planbuilder/operators/table.go +++ b/go/vt/vtgate/planbuilder/operators/table.go @@ -92,7 +92,7 @@ func (to *Table) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.Sele return transformColumnsToSelectExprs(ctx, to) } -func (to *Table) GetOrdering() []ops.OrderBy { +func (to *Table) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/union.go b/go/vt/vtgate/planbuilder/operators/union.go index b926aefdd04..b3d866a00a3 100644 --- a/go/vt/vtgate/planbuilder/operators/union.go +++ b/go/vt/vtgate/planbuilder/operators/union.go @@ -58,7 +58,7 @@ func (u *Union) Clone(inputs []ops.Operator) ops.Operator { return &newOp } -func (u *Union) GetOrdering() []ops.OrderBy { +func (u *Union) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 3f049fd6257..c76baabe7dd 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -76,7 +76,7 @@ func (u *Update) Clone([]ops.Operator) ops.Operator { return &upd } -func (u *Update) GetOrdering() []ops.OrderBy { +func (u *Update) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/vindex.go b/go/vt/vtgate/planbuilder/operators/vindex.go index 49a0ffd1409..2fe2bf4d3e5 100644 --- a/go/vt/vtgate/planbuilder/operators/vindex.go +++ b/go/vt/vtgate/planbuilder/operators/vindex.go @@ -101,7 +101,7 @@ func (v *Vindex) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.Sele return transformColumnsToSelectExprs(ctx, v) } -func (v *Vindex) GetOrdering() []ops.OrderBy { +func (v *Vindex) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return nil } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 6491fe8c493..d5bd132cfaa 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -5988,5 +5988,87 @@ "user.user" ] } + }, + { + "comment": "Add two counts together", + "query": "SELECT (select count(*) from user) + (select count(*) from user_extra)", + "plan": { + "QueryType": "SELECT", + "Original": "SELECT (select count(*) from user) + (select count(*) from user_extra)", + "Instructions": { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq2" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count_star(0) AS count(*)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from user_extra where 1 != 1", + "Query": "select count(*) from user_extra", + "Table": "user_extra" + } + ] + }, + { + "InputName": "Outer", + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count_star(0) AS count(*)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from `user` where 1 != 1", + "Query": "select count(*) from `user`", + "Table": "`user`" + } + ] + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select :__sq1 + :__sq2 as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual where 1 != 1", + "Query": "select :__sq1 + :__sq2 as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual", + "Table": "dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user", + "user.user_extra" + ] + } } ]