Skip to content

Commit

Permalink
coerce the fields and row values on concatenation
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Oct 11, 2023
1 parent 6c15016 commit 1365b7d
Showing 1 changed file with 101 additions and 66 deletions.
167 changes: 101 additions & 66 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package engine
import (
"context"
"sync"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -88,7 +89,7 @@ var errWrongNumberOfColumnsInSelect = vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRE

// TryExecute performs a non-streaming exec.
func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
res, err := c.execSources(ctx, vcursor, bindVars, wantfields)
res, err := c.execSources(ctx, vcursor, bindVars, true)
if err != nil {
return nil, err
}
Expand All @@ -98,35 +99,42 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
return nil, err
}

var rowsAffected uint64
var rows [][]sqltypes.Value

for _, r := range res {
rowsAffected += r.RowsAffected

if len(rows) > 0 &&
len(r.Rows) > 0 &&
len(rows[0]) != len(r.Rows[0]) {
return nil, errWrongNumberOfColumnsInSelect
}

for _, row := range r.Rows {
newRow, err := c.coerceValuesTo(row, fields)
if err != nil {
return nil, err
needsCoercion := false
for idx, field := range r.Fields {
if fields[idx].Type != field.Type {
needsCoercion = true
break
}
}
if needsCoercion {
for _, row := range r.Rows {
err := c.coerceValuesTo(row, fields)
if err != nil {
return nil, err
}
rows = append(rows, row)
}
rows = append(rows, newRow)
} else {
rows = append(rows, r.Rows...)
}
}

return &sqltypes.Result{
Fields: fields,
RowsAffected: rowsAffected,
Rows: rows,
Fields: fields,
Rows: rows,
}, nil
}

func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) (sqltypes.Row, error) {
func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) error {
if len(row) != len(fields) {
panic("wrong number of fields")
}
Expand All @@ -135,14 +143,15 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field)
if _, found := c.NoNeedToTypeCheck[i]; found {
continue
}

newValue, err := evalengine.CoerceTo(value, fields[i].Type)
if err != nil {
return nil, err
if fields[i].Type != value.Type() {
newValue, err := evalengine.CoerceTo(value, fields[i].Type)
if err != nil {
return err
}
row[i] = newValue
}
row[i] = newValue
}
return row, nil
return nil
}

func (c *Concatenate) getFields(res []*sqltypes.Result) (resultFields []*querypb.Field, err error) {
Expand All @@ -153,21 +162,24 @@ func (c *Concatenate) getFields(res []*sqltypes.Result) (resultFields []*querypb
resultFields = res[0].Fields
columns := make([][]sqltypes.Type, len(resultFields))

addFields := func(fields []*querypb.Field) {
addFields := func(fields []*querypb.Field) error {
if len(fields) != len(columns) {
err = errWrongNumberOfColumnsInSelect
return
return errWrongNumberOfColumnsInSelect
}
for idx, field := range fields {
columns[idx] = append(columns[idx], field.Type)
}
return nil
}

for _, r := range res {
if r.Fields == nil {
if r == nil || r.Fields == nil {
continue
}
addFields(r.Fields)
err := addFields(r.Fields)
if err != nil {
return nil, err
}
}

// The resulting column types need to be the coercion of all the input columns
Expand Down Expand Up @@ -206,7 +218,7 @@ func (c *Concatenate) parallelExec(ctx context.Context, vcursor VCursor, bindVar
wg.Add(1)
go func() {
defer wg.Done()
result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields)
result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, true)
if err != nil {
outerErr = err
cancel()
Expand All @@ -223,7 +235,7 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV
for i, source := range c.Sources {
currIndex, currSource := i, source
vars := copyBindVars(bindVars)
result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields)
result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, true)
if err != nil {
return nil, err
}
Expand All @@ -235,72 +247,95 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV
// TryStreamExecute performs a streaming exec.
func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
if vcursor.Session().InTransaction() {
// as we are in a transaction, we need to execute all queries inside a single transaction
// therefore it needs a sequential execution.
// as we are in a transaction, we need to execute all queries inside a single connection,
// which holds the single transaction we have
return c.sequentialStreamExec(ctx, vcursor, bindVars, wantfields, callback)
}
// not in transaction, so execute in parallel.
return c.parallelStreamExec(ctx, vcursor, bindVars, wantfields, callback)
}

func (c *Concatenate) parallelStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var seenFields []*querypb.Field
func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, in func(*sqltypes.Result) error) error {
ctx, cancel := context.WithCancel(inCtx)
defer cancel()
var outerErr error

var fieldsSent bool
var cbMu, fieldsMu sync.Mutex
var wg, fieldSendWg sync.WaitGroup
fieldSendWg.Add(1)
var cbMu sync.Mutex
var wg, fieldMu sync.WaitGroup
var fieldRec atomic.Int64
fieldRec.Store(int64(len(c.Sources)))
fieldMu.Add(1)

rest := make([]*sqltypes.Result, len(c.Sources))
var fields []*querypb.Field
callback := func(res *sqltypes.Result, srcIdx int) error {
cbMu.Lock()
defer cbMu.Unlock()

needsCoercion := false
for idx, field := range rest[srcIdx].Fields {
_, ok := c.NoNeedToTypeCheck[idx]
if !ok && fields[idx].Type != field.Type {
needsCoercion = true
break
}
}
if needsCoercion {
for _, row := range res.Rows {
err := c.coerceValuesTo(row, fields)
if err != nil {
return err
}
}
}
return in(res)
}

once := sync.Once{}

for i, source := range c.Sources {
wg.Add(1)
currIndex, currSource := i, source

go func() {
defer wg.Done()
err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, wantfields, func(resultChunk *sqltypes.Result) error {
err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error {
// if we have fields to compare, make sure all the fields are all the same
if currIndex == 0 {
fieldsMu.Lock()
if !fieldsSent {
defer fieldSendWg.Done()
defer fieldsMu.Unlock()
seenFields = resultChunk.Fields
fieldsSent = true
// No other call can happen before this call.
return callback(resultChunk)
if fieldRec.Load() > 0 && resultChunk.Fields != nil {
rest[currIndex] = resultChunk
res := fieldRec.Add(-1)
if res == 0 {
// We have received fields from all sources. We can now calculate the output types //WALKING THE DOGS! push and THANKS!
var err error
fields, err = c.getFields(rest)
if err != nil {
return err
}
resultChunk.Fields = fields
defer once.Do(func() {
fieldMu.Done()
})

return callback(resultChunk, currIndex)
} else {
fieldMu.Wait()
}
fieldsMu.Unlock()
}
fieldSendWg.Wait()
if resultChunk.Fields != nil {
err := c.compareFields(seenFields, resultChunk.Fields)
if err != nil {
return err
}
}
// This to ensure only one send happens back to the client.
cbMu.Lock()
defer cbMu.Unlock()

// If we get here, all the fields have been received
select {
case <-ctx.Done():
return nil
default:
return callback(resultChunk)
return callback(resultChunk, currIndex)
}
})
// This is to ensure other streams complete if the first stream failed to unlock the wait.
if currIndex == 0 {
fieldsMu.Lock()
if !fieldsSent {
fieldsSent = true
fieldSendWg.Done()
}
fieldsMu.Unlock()
}
if err != nil {
outerErr = err
ctx.Done()
cancel()
once.Do(func() {
fieldMu.Done()
})
}
}()

Expand Down

0 comments on commit 1365b7d

Please sign in to comment.