From 1365b7d002041d43556eda7ecd96324cc80788f2 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Wed, 11 Oct 2023 14:54:30 +0530 Subject: [PATCH] coerce the fields and row values on concatenation Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/concatenate.go | 167 +++++++++++++++++------------ 1 file changed, 101 insertions(+), 66 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index deee4657afe..70eeb7eee91 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -19,6 +19,7 @@ package engine import ( "context" "sync" + "sync/atomic" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -88,7 +89,7 @@ var errWrongNumberOfColumnsInSelect = vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRE // TryExecute performs a non-streaming exec. func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - res, err := c.execSources(ctx, vcursor, bindVars, wantfields) + res, err := c.execSources(ctx, vcursor, bindVars, true) if err != nil { return nil, err } @@ -98,35 +99,42 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars return nil, err } - var rowsAffected uint64 var rows [][]sqltypes.Value for _, r := range res { - rowsAffected += r.RowsAffected - if len(rows) > 0 && len(r.Rows) > 0 && len(rows[0]) != len(r.Rows[0]) { return nil, errWrongNumberOfColumnsInSelect } - for _, row := range r.Rows { - newRow, err := c.coerceValuesTo(row, fields) - if err != nil { - return nil, err + needsCoercion := false + for idx, field := range r.Fields { + if fields[idx].Type != field.Type { + needsCoercion = true + break + } + } + if needsCoercion { + for _, row := range r.Rows { + err := c.coerceValuesTo(row, fields) + if err != nil { + return nil, err + } + rows = append(rows, row) } - rows = append(rows, newRow) + } else { + rows = append(rows, 
r.Rows...) } } return &sqltypes.Result{ - Fields: fields, - RowsAffected: rowsAffected, - Rows: rows, + Fields: fields, + Rows: rows, }, nil } -func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) (sqltypes.Row, error) { +func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) error { if len(row) != len(fields) { panic("wrong number of fields") } @@ -135,14 +143,15 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) if _, found := c.NoNeedToTypeCheck[i]; found { continue } - - newValue, err := evalengine.CoerceTo(value, fields[i].Type) - if err != nil { - return nil, err + if fields[i].Type != value.Type() { + newValue, err := evalengine.CoerceTo(value, fields[i].Type) + if err != nil { + return err + } + row[i] = newValue } - row[i] = newValue } - return row, nil + return nil } func (c *Concatenate) getFields(res []*sqltypes.Result) (resultFields []*querypb.Field, err error) { @@ -153,21 +162,24 @@ func (c *Concatenate) getFields(res []*sqltypes.Result) (resultFields []*querypb resultFields = res[0].Fields columns := make([][]sqltypes.Type, len(resultFields)) - addFields := func(fields []*querypb.Field) { + addFields := func(fields []*querypb.Field) error { if len(fields) != len(columns) { - err = errWrongNumberOfColumnsInSelect - return + return errWrongNumberOfColumnsInSelect } for idx, field := range fields { columns[idx] = append(columns[idx], field.Type) } + return nil } for _, r := range res { - if r.Fields == nil { + if r == nil || r.Fields == nil { continue } - addFields(r.Fields) + err := addFields(r.Fields) + if err != nil { + return nil, err + } } // The resulting column types need to be the coercion of all the input columns @@ -206,7 +218,7 @@ func (c *Concatenate) parallelExec(ctx context.Context, vcursor VCursor, bindVar wg.Add(1) go func() { defer wg.Done() - result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields) + result, err := 
vcursor.ExecutePrimitive(ctx, currSource, vars, true) if err != nil { outerErr = err cancel() @@ -223,7 +235,7 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV for i, source := range c.Sources { currIndex, currSource := i, source vars := copyBindVars(bindVars) - result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields) + result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, true) if err != nil { return nil, err } @@ -235,22 +247,51 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV // TryStreamExecute performs a streaming exec. func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { if vcursor.Session().InTransaction() { - // as we are in a transaction, we need to execute all queries inside a single transaction - // therefore it needs a sequential execution. + // as we are in a transaction, we need to execute all queries inside a single connection, + // which holds the single transaction we have return c.sequentialStreamExec(ctx, vcursor, bindVars, wantfields, callback) } // not in transaction, so execute in parallel. 
return c.parallelStreamExec(ctx, vcursor, bindVars, wantfields, callback) } -func (c *Concatenate) parallelStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - var seenFields []*querypb.Field +func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, in func(*sqltypes.Result) error) error { + ctx, cancel := context.WithCancel(inCtx) + defer cancel() var outerErr error - var fieldsSent bool - var cbMu, fieldsMu sync.Mutex - var wg, fieldSendWg sync.WaitGroup - fieldSendWg.Add(1) + var cbMu sync.Mutex + var wg, fieldMu sync.WaitGroup + var fieldRec atomic.Int64 + fieldRec.Store(int64(len(c.Sources))) + fieldMu.Add(1) + + rest := make([]*sqltypes.Result, len(c.Sources)) + var fields []*querypb.Field + callback := func(res *sqltypes.Result, srcIdx int) error { + cbMu.Lock() + defer cbMu.Unlock() + + needsCoercion := false + for idx, field := range rest[srcIdx].Fields { + _, ok := c.NoNeedToTypeCheck[idx] + if !ok && fields[idx].Type != field.Type { + needsCoercion = true + break + } + } + if needsCoercion { + for _, row := range res.Rows { + err := c.coerceValuesTo(row, fields) + if err != nil { + return err + } + } + } + return in(res) + } + + once := sync.Once{} for i, source := range c.Sources { wg.Add(1) @@ -258,49 +299,43 @@ func (c *Concatenate) parallelStreamExec(ctx context.Context, vcursor VCursor, b go func() { defer wg.Done() - err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { + err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error { // if we have fields to compare, make sure all the fields are all the same - if currIndex == 0 { - fieldsMu.Lock() - if !fieldsSent { - defer fieldSendWg.Done() - defer fieldsMu.Unlock() - seenFields = 
resultChunk.Fields - fieldsSent = true - // No other call can happen before this call. - return callback(resultChunk) + if fieldRec.Load() > 0 && resultChunk.Fields != nil { + rest[currIndex] = resultChunk + res := fieldRec.Add(-1) + if res == 0 { + // We have received fields from all sources. We can now calculate the output types + var err error + fields, err = c.getFields(rest) + if err != nil { + return err + } + resultChunk.Fields = fields + defer once.Do(func() { + fieldMu.Done() + }) + + return callback(resultChunk, currIndex) + } else { + fieldMu.Wait() } - fieldsMu.Unlock() } - fieldSendWg.Wait() - if resultChunk.Fields != nil { - err := c.compareFields(seenFields, resultChunk.Fields) - if err != nil { - return err - } - } - // This to ensure only one send happens back to the client. - cbMu.Lock() - defer cbMu.Unlock() + + // If we get here, all the fields have been received select { case <-ctx.Done(): return nil default: - return callback(resultChunk) + return callback(resultChunk, currIndex) } }) - // This is to ensure other streams complete if the first stream failed to unlock the wait. - if currIndex == 0 { - fieldsMu.Lock() - if !fieldsSent { - fieldsSent = true - fieldSendWg.Done() - } - fieldsMu.Unlock() - } if err != nil { outerErr = err - ctx.Done() + cancel() + once.Do(func() { + fieldMu.Done() + }) } }()