From 277d7863e4adc4585af9eac863525035130fa0c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Wed, 29 Nov 2023 18:15:12 +0100 Subject: [PATCH] refactor the INSERT engine primitive (#14606) Signed-off-by: Andres Taylor Signed-off-by: Harshit Gangal Co-authored-by: Harshit Gangal --- go/vt/vtgate/engine/cached_size.go | 64 +- go/vt/vtgate/engine/insert.go | 889 ++---------------- go/vt/vtgate/engine/insert_common.go | 480 ++++++++++ go/vt/vtgate/engine/insert_select.go | 423 +++++++++ go/vt/vtgate/engine/insert_test.go | 273 +++--- go/vt/vtgate/planbuilder/insert.go | 29 +- .../planbuilder/operator_transformers.go | 56 +- .../planbuilder/testdata/dml_cases.json | 4 +- 8 files changed, 1224 insertions(+), 994 deletions(-) create mode 100644 go/vt/vtgate/engine/insert_common.go create mode 100644 go/vt/vtgate/engine/insert_select.go diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 657e5323f0..807c760441 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -386,10 +386,10 @@ func (cached *Insert) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(240) + size += int64(192) } - // field Keyspace *vitess.io/vitess/go/vt/vtgate/vindexes.Keyspace - size += cached.Keyspace.CachedSize(true) + // field InsertCommon vitess.io/vitess/go/vt/vtgate/engine.InsertCommon + size += cached.InsertCommon.CachedSize(false) // field Query string size += hack.RuntimeAllocSize(int64(len(cached.Query))) // field VindexValues [][][]vitess.io/vitess/go/vt/vtgate/evalengine.Expr @@ -411,19 +411,6 @@ func (cached *Insert) CachedSize(alloc bool) int64 { } } } - // field ColVindexes []*vitess.io/vitess/go/vt/vtgate/vindexes.ColumnVindex - { - size += hack.RuntimeAllocSize(int64(cap(cached.ColVindexes)) * int64(8)) - for _, elem := range cached.ColVindexes { - size += elem.CachedSize(true) - } - } - // field TableName string - size += hack.RuntimeAllocSize(int64(len(cached.TableName))) - // field Generate *vitess.io/vitess/go/vt/vtgate/engine.Generate - size += cached.Generate.CachedSize(true) - // field Prefix string - size += hack.RuntimeAllocSize(int64(len(cached.Prefix))) // field Mid vitess.io/vitess/go/vt/sqlparser.Values { size += hack.RuntimeAllocSize(int64(cap(cached.Mid)) * int64(24)) @@ -438,8 +425,49 @@ func (cached *Insert) CachedSize(alloc bool) int64 { } } } + return size +} +func (cached *InsertCommon) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(128) + } + // field Keyspace *vitess.io/vitess/go/vt/vtgate/vindexes.Keyspace + size += cached.Keyspace.CachedSize(true) + // field TableName string + size += hack.RuntimeAllocSize(int64(len(cached.TableName))) + // field Generate *vitess.io/vitess/go/vt/vtgate/engine.Generate + size += cached.Generate.CachedSize(true) + // field ColVindexes []*vitess.io/vitess/go/vt/vtgate/vindexes.ColumnVindex + { + size += hack.RuntimeAllocSize(int64(cap(cached.ColVindexes)) * int64(8)) + for _, elem := range cached.ColVindexes { + size += elem.CachedSize(true) + } + } + // field Prefix string + size += hack.RuntimeAllocSize(int64(len(cached.Prefix))) // field Suffix string size += hack.RuntimeAllocSize(int64(len(cached.Suffix))) + return size +} +func (cached *InsertSelect) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(176) + } + // field InsertCommon vitess.io/vitess/go/vt/vtgate/engine.InsertCommon + size += cached.InsertCommon.CachedSize(false) + // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Input.(cachedObject); ok { + size += cc.CachedSize(true) + } // field VindexValueOffset [][]int { size += hack.RuntimeAllocSize(int64(cap(cached.VindexValueOffset)) * int64(24)) @@ -449,10 +477,6 @@ func (cached *Insert) CachedSize(alloc bool) int64 { } } } - // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive - if cc, ok := cached.Input.(cachedObject); ok { - size += cc.CachedSize(true) - } return size } diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 4fa8eeadce..4604341373 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -18,14 +18,10 @@ package engine import ( "context" - "encoding/json" "fmt" "strconv" "strings" - "sync" - "time" - "vitess.io/vitess/go/slice" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" querypb "vitess.io/vitess/go/vt/proto/query" @@ -40,95 +36,41 @@ import ( var _ Primitive = (*Insert)(nil) -type ( - // Insert represents the instructions to perform an insert operation. - Insert struct { - // Opcode is the execution opcode. - Opcode InsertOpcode +// Insert represents the instructions to perform an insert operation. +type Insert struct { + InsertCommon - // Ignore is for INSERT IGNORE and INSERT...ON DUPLICATE KEY constructs - // for sharded cases. - Ignore bool + // Query specifies the query to be executed. + // For InsertSharded plans, this value is unused, + // and Prefix, Mid and Suffix are used instead. + Query string - // Keyspace specifies the keyspace to send the query to. - Keyspace *vindexes.Keyspace + // VindexValues specifies values for all the vindex columns. + // This is a three-dimensional data structure: + // Insert.Values[i] represents the values to be inserted for the i'th colvindex (i < len(Insert.Table.ColumnVindexes)) + // Insert.Values[i].Values[j] represents values for the j'th column of the given colVindex (j < len(colVindex[i].Columns) + // Insert.Values[i].Values[j].Values[k] represents the value pulled from row k for that column: (k < len(ins.rows)) + VindexValues [][][]evalengine.Expr - // Query specifies the query to be executed. - // For InsertSharded plans, this value is unused, - // and Prefix, Mid and Suffix are used instead. - Query string + // Mid is the row values for the sharded insert plans. + Mid sqlparser.Values - // VindexValues specifies values for all the vindex columns. - // This is a three-dimensional data structure: - // Insert.Values[i] represents the values to be inserted for the i'th colvindex (i < len(Insert.Table.ColumnVindexes)) - // Insert.Values[i].Values[j] represents values for the j'th column of the given colVindex (j < len(colVindex[i].Columns) - // Insert.Values[i].Values[j].Values[k] represents the value pulled from row k for that column: (k < len(ins.rows)) - VindexValues [][][]evalengine.Expr - - // ColVindexes are the vindexes that will use the VindexValues - ColVindexes []*vindexes.ColumnVindex - - // TableName is the name of the table on which row will be inserted. - TableName string - - // Generate is only set for inserts where a sequence must be generated. - Generate *Generate - - // Prefix, Mid and Suffix are for sharded insert plans. - Prefix string - Mid sqlparser.Values - Suffix string - - // Option to override the standard behavior and allow a multi-shard insert - // to use single round trip autocommit. - // - // This is a clear violation of the SQL semantics since it means the statement - // is not atomic in the presence of PK conflicts on one shard and not another. - // However some application use cases would prefer that the statement partially - // succeed in order to get the performance benefits of autocommit. - MultiShardAutocommit bool - - // QueryTimeout contains the optional timeout (in milliseconds) to apply to this query - QueryTimeout int - - // VindexValueOffset stores the offset for each column in the ColumnVindex - // that will appear in the result set of the select query. - VindexValueOffset [][]int - - // Input is a select query plan to retrieve results for inserting data. - Input Primitive `json:",omitempty"` - - // ForceNonStreaming is true when the insert table and select table are same. - // This will avoid locking by the select table. - ForceNonStreaming bool - - PreventAutoCommit bool - - // Insert needs tx handling - txNeeded - } - - ksID = []byte -) - -func (ins *Insert) Inputs() ([]Primitive, []map[string]any) { - if ins.Input == nil { - return nil, nil - } - return []Primitive{ins.Input}, nil + noInputs } -// NewQueryInsert creates an Insert with a query string. -func NewQueryInsert(opcode InsertOpcode, keyspace *vindexes.Keyspace, query string) *Insert { +// newQueryInsert creates an Insert with a query string. +func newQueryInsert(opcode InsertOpcode, keyspace *vindexes.Keyspace, query string) *Insert { return &Insert{ - Opcode: opcode, - Keyspace: keyspace, - Query: query, + InsertCommon: InsertCommon{ + Opcode: opcode, + Keyspace: keyspace, + }, + Query: query, } } -// NewInsert creates a new Insert. -func NewInsert( +// newInsert creates a new Insert. +func newInsert( opcode InsertOpcode, ignore bool, keyspace *vindexes.Keyspace, @@ -139,13 +81,15 @@ func NewInsert( suffix string, ) *Insert { ins := &Insert{ - Opcode: opcode, - Ignore: ignore, - Keyspace: keyspace, + InsertCommon: InsertCommon{ + Opcode: opcode, + Keyspace: keyspace, + Ignore: ignore, + Prefix: prefix, + Suffix: suffix, + }, VindexValues: vindexValues, - Prefix: prefix, Mid: mid, - Suffix: suffix, } if table != nil { ins.TableName = table.Name.String() @@ -159,208 +103,59 @@ func NewInsert( return ins } -// Generate represents the instruction to generate -// a value from a sequence. -type Generate struct { - Keyspace *vindexes.Keyspace - Query string - // Values are the supplied values for the column, which - // will be stored as a list within the expression. New - // values will be generated based on how many were not - // supplied (NULL). - Values evalengine.Expr - // Insert using Select, offset for auto increment column - Offset int -} - -// InsertOpcode is a number representing the opcode -// for the Insert primitive. -type InsertOpcode int - -const ( - // InsertUnsharded is for routing an insert statement - // to an unsharded keyspace. - InsertUnsharded = InsertOpcode(iota) - // InsertSharded is for routing an insert statement - // to individual shards. Requires: A list of Values, one - // for each ColVindex. If the table has an Autoinc column, - // A Generate subplan must be created. - InsertSharded - // InsertSelect is for routing an insert statement - // based on rows returned from the select statement. - InsertSelect -) - -var insName = map[InsertOpcode]string{ - InsertUnsharded: "InsertUnsharded", - InsertSharded: "InsertSharded", - InsertSelect: "InsertSelect", -} - -// String returns the opcode -func (code InsertOpcode) String() string { - return strings.ReplaceAll(insName[code], "Insert", "") -} - -// MarshalJSON serializes the InsertOpcode as a JSON string. -// It's used for testing and diagnostics. -func (code InsertOpcode) MarshalJSON() ([]byte, error) { - return json.Marshal(insName[code]) -} - // RouteType returns a description of the query routing type used by the primitive func (ins *Insert) RouteType() string { return insName[ins.Opcode] } -// GetKeyspaceName specifies the Keyspace that this primitive routes to. -func (ins *Insert) GetKeyspaceName() string { - return ins.Keyspace.Name -} - -// GetTableName specifies the table that this primitive routes to. -func (ins *Insert) GetTableName() string { - return ins.TableName -} - // TryExecute performs a non-streaming exec. -func (ins *Insert) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { +func (ins *Insert) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { ctx, cancelFunc := addQueryTimeout(ctx, vcursor, ins.QueryTimeout) defer cancelFunc() switch ins.Opcode { case InsertUnsharded: - return ins.execInsertUnsharded(ctx, vcursor, bindVars) + return ins.insertIntoUnshardedTable(ctx, vcursor, bindVars) case InsertSharded: - return ins.execInsertSharded(ctx, vcursor, bindVars) - case InsertSelect: - return ins.execInsertFromSelect(ctx, vcursor, bindVars) + return ins.insertIntoShardedTable(ctx, vcursor, bindVars) default: - // Unreachable. - return nil, fmt.Errorf("unsupported query route: %v", ins) + return nil, vterrors.VT13001("unexpected query route: %v", ins.Opcode) } } // TryStreamExecute performs a streaming exec. func (ins *Insert) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - if ins.Input == nil || ins.ForceNonStreaming { - res, err := ins.TryExecute(ctx, vcursor, bindVars, wantfields) - if err != nil { - return err - } - return callback(res) - } - if ins.QueryTimeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(ins.QueryTimeout)*time.Millisecond) - defer cancel() - } - - unsharded := ins.Opcode == InsertUnsharded - var mu sync.Mutex - output := &sqltypes.Result{} - - err := vcursor.StreamExecutePrimitiveStandalone(ctx, ins.Input, bindVars, false, func(result *sqltypes.Result) error { - if len(result.Rows) == 0 { - return nil - } - - // should process only one chunk at a time. - // as parallel chunk insert will try to use the same transaction in the vttablet - // this will cause transaction in use error. - mu.Lock() - defer mu.Unlock() - - var insertID int64 - var qr *sqltypes.Result - var err error - if unsharded { - insertID, qr, err = ins.insertIntoUnshardedTable(ctx, vcursor, bindVars, result) - } else { - insertID, qr, err = ins.insertIntoShardedTable(ctx, vcursor, bindVars, result) - } - if err != nil { - return err - } - - output.RowsAffected += qr.RowsAffected - // InsertID needs to be updated to the least insertID value in sqltypes.Result - if output.InsertID == 0 || output.InsertID > uint64(insertID) { - output.InsertID = uint64(insertID) - } - return nil - }) + res, err := ins.TryExecute(ctx, vcursor, bindVars, wantfields) if err != nil { return err } - return callback(output) + return callback(res) } -func (ins *Insert) insertIntoShardedTable(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, result *sqltypes.Result) (int64, *sqltypes.Result, error) { - insertID, err := ins.processGenerateFromRows(ctx, vcursor, result.Rows) - if err != nil { - return 0, nil, err - } - - rss, queries, err := ins.getInsertSelectQueries(ctx, vcursor, bindVars, result.Rows) +func (ins *Insert) insertIntoUnshardedTable(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + insertID, err := ins.processGenerateFromValues(ctx, vcursor, ins, bindVars) if err != nil { - return 0, nil, err - } - - qr, err := ins.executeInsertQueries(ctx, vcursor, rss, queries, insertID) - if err != nil { - return 0, nil, err - } - return insertID, qr, nil -} - -// GetFields fetches the field info. -func (ins *Insert) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unreachable code for %q", ins.Query) -} - -func (ins *Insert) execInsertUnsharded(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - query := ins.Query - if ins.Input != nil { - result, err := vcursor.ExecutePrimitive(ctx, ins.Input, bindVars, false) - if err != nil { - return nil, err - } - if len(result.Rows) == 0 { - return &sqltypes.Result{}, nil - } - query = ins.getInsertQueryForUnsharded(result, bindVars) + return nil, err } - _, qr, err := ins.executeUnshardedTableQuery(ctx, vcursor, bindVars, query) - return qr, err -} - -func (ins *Insert) getInsertQueryForUnsharded(result *sqltypes.Result, bindVars map[string]*querypb.BindVariable) string { - var mids sqlparser.Values - for r, inputRow := range result.Rows { - row := sqlparser.ValTuple{} - for c, value := range inputRow { - bvName := insertVarOffset(r, c) - bindVars[bvName] = sqltypes.ValueBindVariable(value) - row = append(row, sqlparser.NewArgument(bvName)) - } - mids = append(mids, row) - } - return ins.Prefix + sqlparser.String(mids) + ins.Suffix + return ins.executeUnshardedTableQuery(ctx, vcursor, ins, bindVars, ins.Query, uint64(insertID)) } -func (ins *Insert) execInsertSharded(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - insertID, err := ins.processGenerateFromValues(ctx, vcursor, bindVars) +func (ins *Insert) insertIntoShardedTable( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, +) (*sqltypes.Result, error) { + insertID, err := ins.processGenerateFromValues(ctx, vcursor, ins, bindVars) if err != nil { return nil, err } - rss, queries, err := ins.getInsertShardedRoute(ctx, vcursor, bindVars) + rss, queries, err := ins.getInsertShardedQueries(ctx, vcursor, bindVars) if err != nil { return nil, err } - return ins.executeInsertQueries(ctx, vcursor, rss, queries, insertID) + return ins.executeInsertQueries(ctx, vcursor, rss, queries, uint64(insertID)) } func (ins *Insert) executeInsertQueries( @@ -368,7 +163,7 @@ func (ins *Insert) executeInsertQueries( vcursor VCursor, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, - insertID int64, + insertID uint64, ) (*sqltypes.Result, error) { autocommit := (len(rss) == 1 || ins.MultiShardAutocommit) && vcursor.AutocommitApproval() err := allowOnlyPrimary(rss...) @@ -381,347 +176,51 @@ func (ins *Insert) executeInsertQueries( } if insertID != 0 { - result.InsertID = uint64(insertID) + result.InsertID = insertID } return result, nil } -func (ins *Insert) getInsertSelectQueries( - ctx context.Context, - vcursor VCursor, - bindVars map[string]*querypb.BindVariable, - rows []sqltypes.Row, -) ([]*srvtopo.ResolvedShard, []*querypb.BoundQuery, error) { - colVindexes := ins.ColVindexes - if len(colVindexes) != len(ins.VindexValueOffset) { - return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "vindex value offsets and vindex info do not match") - } - - // Here we go over the incoming rows and extract values for the vindexes we need to update - shardingCols := make([][]sqltypes.Row, len(colVindexes)) - for _, inputRow := range rows { - for colIdx := range colVindexes { - offsets := ins.VindexValueOffset[colIdx] - row := make(sqltypes.Row, 0, len(offsets)) - for _, offset := range offsets { - if offset == -1 { // value not provided from select query - row = append(row, sqltypes.NULL) - continue - } - row = append(row, inputRow[offset]) - } - shardingCols[colIdx] = append(shardingCols[colIdx], row) - } - } - - keyspaceIDs, err := ins.processPrimary(ctx, vcursor, shardingCols[0], colVindexes[0]) - if err != nil { - return nil, nil, err - } - - for vIdx := 1; vIdx < len(colVindexes); vIdx++ { - colVindex := colVindexes[vIdx] - var err error - if colVindex.Owned { - err = ins.processOwned(ctx, vcursor, shardingCols[vIdx], colVindex, keyspaceIDs) - } else { - err = ins.processUnowned(ctx, vcursor, shardingCols[vIdx], colVindex, keyspaceIDs) - } - if err != nil { - return nil, nil, err - } - } - - var indexes []*querypb.Value - var destinations []key.Destination - for i, ksid := range keyspaceIDs { - if ksid != nil { - indexes = append(indexes, &querypb.Value{ - Value: strconv.AppendInt(nil, int64(i), 10), - }) - destinations = append(destinations, key.DestinationKeyspaceID(ksid)) - } - } - if len(destinations) == 0 { - // In this case, all we have is nil KeyspaceIds, we don't do - // anything at all. - return nil, nil, nil - } - - rss, indexesPerRss, err := vcursor.ResolveDestinations(ctx, ins.Keyspace.Name, indexes, destinations) - if err != nil { - return nil, nil, err - } - - queries := make([]*querypb.BoundQuery, len(rss)) - for i := range rss { - bvs := sqltypes.CopyBindVariables(bindVars) // we don't want to create one huge bindvars for all values - var mids sqlparser.Values - for _, indexValue := range indexesPerRss[i] { - index, _ := strconv.Atoi(string(indexValue.Value)) - if keyspaceIDs[index] != nil { - row := sqlparser.ValTuple{} - for colOffset, value := range rows[index] { - bvName := insertVarOffset(index, colOffset) - bvs[bvName] = sqltypes.ValueBindVariable(value) - row = append(row, sqlparser.NewArgument(bvName)) - } - mids = append(mids, row) - } - } - rewritten := ins.Prefix + sqlparser.String(mids) + ins.Suffix - queries[i] = &querypb.BoundQuery{ - Sql: rewritten, - BindVariables: bvs, - } - } - - return rss, queries, nil -} - -func (ins *Insert) execInsertFromSelect(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - // run the SELECT query - if ins.Input == nil { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "something went wrong planning INSERT SELECT") - } - - result, err := vcursor.ExecutePrimitive(ctx, ins.Input, bindVars, false) - if err != nil { - return nil, err - } - if len(result.Rows) == 0 { - return &sqltypes.Result{}, nil - } - - _, qr, err := ins.insertIntoShardedTable(ctx, vcursor, bindVars, result) - return qr, err -} - -// shouldGenerate determines if a sequence value should be generated for a given value -func shouldGenerate(v sqltypes.Value) bool { - if v.IsNull() { - return true - } - - // Unless the NO_AUTO_VALUE_ON_ZERO sql mode is active in mysql, it also - // treats 0 as a value that should generate a new sequence. - n, err := v.ToCastUint64() - if err == nil && n == 0 { - return true - } - - return false -} - -// processGenerateFromValues generates new values using a sequence if necessary. -// If no value was generated, it returns 0. Values are generated only -// for cases where none are supplied. -func (ins *Insert) processGenerateFromValues( - ctx context.Context, - vcursor VCursor, - bindVars map[string]*querypb.BindVariable, -) (insertID int64, err error) { - if ins.Generate == nil { - return 0, nil - } - - // Scan input values to compute the number of values to generate, and - // keep track of where they should be filled. - env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) - resolved, err := env.Evaluate(ins.Generate.Values) - if err != nil { - return 0, err - } - count := int64(0) - values := resolved.TupleValues() - for _, val := range values { - if shouldGenerate(val) { - count++ - } - } - - // If generation is needed, generate the requested number of values (as one call). - if count != 0 { - rss, _, err := vcursor.ResolveDestinations(ctx, ins.Generate.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) - if err != nil { - return 0, err - } - if len(rss) != 1 { - return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "auto sequence generation can happen through single shard only, it is getting routed to %d shards", len(rss)) - } - bindVars := map[string]*querypb.BindVariable{"n": sqltypes.Int64BindVariable(count)} - qr, err := vcursor.ExecuteStandalone(ctx, ins, ins.Generate.Query, bindVars, rss[0]) - if err != nil { - return 0, err - } - // If no rows are returned, it's an internal error, and the code - // must panic, which will be caught and reported. - insertID, err = qr.Rows[0][0].ToCastInt64() - if err != nil { - return 0, err - } - } - - // Fill the holes where no value was supplied. - cur := insertID - for i, v := range values { - if shouldGenerate(v) { - bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.Int64BindVariable(cur) - cur++ - } else { - bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.ValueBindVariable(v) - } - } - return insertID, nil -} - -// processGenerateFromRows generates new values using a sequence if necessary. -// If no value was generated, it returns 0. Values are generated only -// for cases where none are supplied. -func (ins *Insert) processGenerateFromRows( - ctx context.Context, - vcursor VCursor, - rows []sqltypes.Row, -) (insertID int64, err error) { - if ins.Generate == nil { - return 0, nil - } - var count int64 - offset := ins.Generate.Offset - genColPresent := offset < len(rows[0]) - if genColPresent { - for _, val := range rows { - if val[offset].IsNull() { - count++ - } - } - } else { - count = int64(len(rows)) - } - - if count == 0 { - return 0, nil - } - - // If generation is needed, generate the requested number of values (as one call). - rss, _, err := vcursor.ResolveDestinations(ctx, ins.Generate.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) - if err != nil { - return 0, err - } - if len(rss) != 1 { - return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "auto sequence generation can happen through single shard only, it is getting routed to %d shards", len(rss)) - } - bindVars := map[string]*querypb.BindVariable{"n": sqltypes.Int64BindVariable(count)} - qr, err := vcursor.ExecuteStandalone(ctx, ins, ins.Generate.Query, bindVars, rss[0]) - if err != nil { - return 0, err - } - // If no rows are returned, it's an internal error, and the code - // must panic, which will be caught and reported. - insertID, err = qr.Rows[0][0].ToCastInt64() - if err != nil { - return 0, err - } - - used := insertID - for idx, val := range rows { - if genColPresent { - if val[offset].IsNull() { - val[offset] = sqltypes.NewInt64(used) - used++ - } - } else { - rows[idx] = append(val, sqltypes.NewInt64(used)) - used++ - } - } - - return insertID, nil -} - -// getInsertShardedRoute performs all the vindex related work +// getInsertShardedQueries performs all the vindex related work // and returns a map of shard to queries. // Using the primary vindex, it computes the target keyspace ids. // For owned vindexes, it creates entries. // For unowned vindexes with no input values, it reverse maps. // For unowned vindexes with values, it validates. // If it's an IGNORE or ON DUPLICATE key insert, it drops unroutable rows. -func (ins *Insert) getInsertShardedRoute( +func (ins *Insert) getInsertShardedQueries( ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, ) ([]*srvtopo.ResolvedShard, []*querypb.BoundQuery, error) { + // vindexRowsValues builds the values of all vindex columns. // the 3-d structure indexes are colVindex, row, col. Note that // ins.Values indexes are colVindex, col, row. So, the conversion // involves a transpose. // The reason we need to transpose is that all the Vindex APIs // require inputs in that format. - vindexRowsValues := make([][]sqltypes.Row, len(ins.VindexValues)) - rowCount := 0 - env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) - colVindexes := ins.ColVindexes - for vIdx, vColValues := range ins.VindexValues { - if len(vColValues) != len(colVindexes[vIdx].Columns) { - return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] supplied vindex column values don't match vschema: %v %v", vColValues, colVindexes[vIdx].Columns) - } - for colIdx, colValues := range vColValues { - rowsResolvedValues := make(sqltypes.Row, 0, len(colValues)) - for _, colValue := range colValues { - result, err := env.Evaluate(colValue) - if err != nil { - return nil, nil, err - } - rowsResolvedValues = append(rowsResolvedValues, result.Value(vcursor.ConnCollation())) - } - // This is the first iteration: allocate for transpose. - if colIdx == 0 { - if len(rowsResolvedValues) == 0 { - return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] rowcount is zero for inserts: %v", rowsResolvedValues) - } - if rowCount == 0 { - rowCount = len(rowsResolvedValues) - } - if rowCount != len(rowsResolvedValues) { - return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] uneven row values for inserts: %d %d", rowCount, len(rowsResolvedValues)) - } - vindexRowsValues[vIdx] = make([]sqltypes.Row, rowCount) - } - // Perform the transpose. - for rowNum, colVal := range rowsResolvedValues { - vindexRowsValues[vIdx][rowNum] = append(vindexRowsValues[vIdx][rowNum], colVal) - } - } + vindexRowsValues, err := ins.buildVindexRowsValues(ctx, vcursor, bindVars) + if err != nil { + return nil, nil, err } // The output from the following 'process' functions is a list of // keyspace ids. For regular inserts, a failure to find a route // results in an error. For 'ignore' type inserts, the keyspace // id is returned as nil, which is used later to drop the corresponding rows. - if len(vindexRowsValues) == 0 || len(colVindexes) == 0 { + if len(vindexRowsValues) == 0 || len(ins.ColVindexes) == 0 { return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.RequiresPrimaryKey, vterrors.PrimaryVindexNotSet, ins.TableName) } - keyspaceIDs, err := ins.processPrimary(ctx, vcursor, vindexRowsValues[0], colVindexes[0]) + + keyspaceIDs, err := ins.processVindexes(ctx, vcursor, vindexRowsValues) if err != nil { return nil, nil, err } - for vIdx := 1; vIdx < len(colVindexes); vIdx++ { - colVindex := colVindexes[vIdx] - var err error - if colVindex.Owned { - err = ins.processOwned(ctx, vcursor, vindexRowsValues[vIdx], colVindex, keyspaceIDs) - } else { - err = ins.processUnowned(ctx, vcursor, vindexRowsValues[vIdx], colVindex, keyspaceIDs) - } - if err != nil { - return nil, nil, err - } - } - // Build 3-d bindvars. Skip rows with nil keyspace ids in case // we're executing an insert ignore. - for vIdx, colVindex := range colVindexes { + for vIdx, colVindex := range ins.ColVindexes { for rowNum, rowColumnKeys := range vindexRowsValues[vIdx] { if keyspaceIDs[rowNum] == nil { // InsertIgnore: skip the row. @@ -794,174 +293,50 @@ func (ins *Insert) getInsertShardedRoute( return rss, queries, nil } -// processPrimary maps the primary vindex values to the keyspace ids. -func (ins *Insert) processPrimary(ctx context.Context, vcursor VCursor, vindexColumnsKeys []sqltypes.Row, colVindex *vindexes.ColumnVindex) ([]ksID, error) { - destinations, err := vindexes.Map(ctx, colVindex.Vindex, vcursor, vindexColumnsKeys) - if err != nil { - return nil, err - } - - keyspaceIDs := make([]ksID, len(destinations)) - for i, destination := range destinations { - switch d := destination.(type) { - case key.DestinationKeyspaceID: - // This is a single keyspace id, we're good. - keyspaceIDs[i] = d - case key.DestinationNone: - // No valid keyspace id, we may return an error. - if !ins.Ignore { - return nil, fmt.Errorf("could not map %v to a keyspace id", vindexColumnsKeys[i]) - } - default: - return nil, fmt.Errorf("could not map %v to a unique keyspace id: %v", vindexColumnsKeys[i], destination) - } - } - - return keyspaceIDs, nil -} - -// processOwned creates vindex entries for the values of an owned column. -func (ins *Insert) processOwned(ctx context.Context, vcursor VCursor, vindexColumnsKeys []sqltypes.Row, colVindex *vindexes.ColumnVindex, ksids []ksID) error { - if !ins.Ignore { - return colVindex.Vindex.(vindexes.Lookup).Create(ctx, vcursor, vindexColumnsKeys, ksids, false /* ignoreMode */) - } - - // InsertIgnore - var createIndexes []int - var createKeys []sqltypes.Row - var createKsids []ksID - - for rowNum, rowColumnKeys := range vindexColumnsKeys { - if ksids[rowNum] == nil { - continue - } - createIndexes = append(createIndexes, rowNum) - createKeys = append(createKeys, rowColumnKeys) - createKsids = append(createKsids, ksids[rowNum]) - } - if createKeys == nil { - return nil - } - - err := colVindex.Vindex.(vindexes.Lookup).Create(ctx, vcursor, createKeys, createKsids, true) - if err != nil { - return err - } - // After creation, verify that the keys map to the keyspace ids. If not, remove - // those that don't map. - verified, err := vindexes.Verify(ctx, colVindex.Vindex, vcursor, createKeys, createKsids) - if err != nil { - return err - } - for i, v := range verified { - if !v { - ksids[createIndexes[i]] = nil - } - } - return nil -} - -// processUnowned either reverse maps or validates the values for an unowned column. -func (ins *Insert) processUnowned(ctx context.Context, vcursor VCursor, vindexColumnsKeys []sqltypes.Row, colVindex *vindexes.ColumnVindex, ksids []ksID) error { - var reverseIndexes []int - var reverseKsids []ksID - - var verifyIndexes []int - var verifyKeys []sqltypes.Row - var verifyKsids []ksID - - // Check if this VIndex is reversible or not. - reversibleVindex, isReversible := colVindex.Vindex.(vindexes.Reversible) - - for rowNum, rowColumnKeys := range vindexColumnsKeys { - // If we weren't able to determine a keyspace id from the primary VIndex, skip this row - if ksids[rowNum] == nil { - continue +func (ins *Insert) buildVindexRowsValues(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) ([][]sqltypes.Row, error) { + vindexRowsValues := make([][]sqltypes.Row, len(ins.VindexValues)) + rowCount := 0 + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) + colVindexes := ins.ColVindexes + for vIdx, vColValues := range ins.VindexValues { + if len(vColValues) != len(colVindexes[vIdx].Columns) { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] supplied vindex column values don't match vschema: %v %v", vColValues, colVindexes[vIdx].Columns) } - - if rowColumnKeys[0].IsNull() { - // If the value of the column is `NULL`, but this is a reversible VIndex, - // we will try to generate the value from the keyspace id generated by the primary VIndex. - if isReversible { - reverseIndexes = append(reverseIndexes, rowNum) - reverseKsids = append(reverseKsids, ksids[rowNum]) + for colIdx, colValues := range vColValues { + rowsResolvedValues := make(sqltypes.Row, 0, len(colValues)) + for _, colValue := range colValues { + result, err := env.Evaluate(colValue) + if err != nil { + return nil, err + } + rowsResolvedValues = append(rowsResolvedValues, result.Value(vcursor.ConnCollation())) } - - // Otherwise, don't do anything. Whether `NULL` is a valid value for this column will be - // handled by MySQL. - } else { - // If a value for this column was specified, the keyspace id values from the - // secondary VIndex need to be verified against the keyspace id from the primary VIndex - verifyIndexes = append(verifyIndexes, rowNum) - verifyKeys = append(verifyKeys, rowColumnKeys) - verifyKsids = append(verifyKsids, ksids[rowNum]) - } - } - - // Reverse map values for secondary VIndex columns from the primary VIndex's keyspace id. - if reverseKsids != nil { - reverseKeys, err := reversibleVindex.ReverseMap(vcursor, reverseKsids) - if err != nil { - return err - } - - for i, reverseKey := range reverseKeys { - // Fill the first column with the reverse-mapped value. - vindexColumnsKeys[reverseIndexes[i]][0] = reverseKey - } - } - - // Verify that the keyspace ids generated by the primary and secondary VIndexes match - if verifyIndexes != nil { - // If values were supplied, we validate against keyspace id. - verified, err := vindexes.Verify(ctx, colVindex.Vindex, vcursor, verifyKeys, verifyKsids) - if err != nil { - return err - } - - var mismatchVindexKeys []sqltypes.Row - for i, v := range verified { - rowNum := verifyIndexes[i] - if !v { - if !ins.Ignore { - mismatchVindexKeys = append(mismatchVindexKeys, vindexColumnsKeys[rowNum]) - continue + // This is the first iteration: allocate for transpose. + if colIdx == 0 { + if len(rowsResolvedValues) == 0 { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] rowcount is zero for inserts: %v", rowsResolvedValues) } - - // Skip the whole row if this is a `INSERT IGNORE` or `INSERT ... ON DUPLICATE KEY ...` statement - // but the keyspace ids didn't match. - ksids[verifyIndexes[i]] = nil + if rowCount == 0 { + rowCount = len(rowsResolvedValues) + } + if rowCount != len(rowsResolvedValues) { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] uneven row values for inserts: %d %d", rowCount, len(rowsResolvedValues)) + } + vindexRowsValues[vIdx] = make([]sqltypes.Row, rowCount) + } + // Perform the transpose. + for rowNum, colVal := range rowsResolvedValues { + vindexRowsValues[vIdx][rowNum] = append(vindexRowsValues[vIdx][rowNum], colVal) } - } - - if mismatchVindexKeys != nil { - return fmt.Errorf("values %v for column %v does not map to keyspace ids", mismatchVindexKeys, colVindex.Columns) } } - - return nil -} - -// InsertVarName returns a name for the bind var for this column. This method is used by the planner and engine, -// to make sure they both produce the same names -func InsertVarName(col sqlparser.IdentifierCI, rowNum int) string { - return fmt.Sprintf("_%s_%d", col.CompliantName(), rowNum) -} - -func insertVarOffset(rowNum, colOffset int) string { - return fmt.Sprintf("_c%d_%d", rowNum, colOffset) + return vindexRowsValues, nil } func (ins *Insert) description() PrimitiveDescription { - other := map[string]any{ - "Query": ins.Query, - "TableName": ins.GetTableName(), - "MultiShardAutocommit": ins.MultiShardAutocommit, - "QueryTimeout": ins.QueryTimeout, - "InsertIgnore": ins.Ignore, - "InputAsNonStreaming": ins.ForceNonStreaming, - "NoAutoCommit": ins.PreventAutoCommit, - } + other := ins.commonDesc() + other["Query"] = ins.Query + other["TableName"] = ins.GetTableName() if len(ins.VindexValues) > 0 { valuesOffsets := map[string]string{} @@ -984,35 +359,6 @@ func (ins *Insert) description() PrimitiveDescription { other["VindexValues"] = valuesOffsets } - if ins.Generate != nil { - if ins.Generate.Values == nil { - other["AutoIncrement"] = fmt.Sprintf("%s:Offset(%d)", ins.Generate.Query, ins.Generate.Offset) - } else { - other["AutoIncrement"] = fmt.Sprintf("%s:Values::%s", ins.Generate.Query, sqlparser.String(ins.Generate.Values)) - } - } - - if len(ins.VindexValueOffset) > 0 { - valuesOffsets := map[string]string{} - for idx, ints := range ins.VindexValueOffset { - if len(ins.ColVindexes) < idx { - panic("ins.ColVindexes and ins.VindexValueOffset do not line up") - } - vindex := ins.ColVindexes[idx] - marshal, _ := json.Marshal(ints) - valuesOffsets[vindex.Name] = string(marshal) - } - other["VindexOffsetFromSelect"] = valuesOffsets - } - if len(ins.Mid) > 0 { - mids := slice.Map(ins.Mid, func(from sqlparser.ValTuple) string { - return sqlparser.String(from) - }) - shardQuery := fmt.Sprintf("%s%s%s", ins.Prefix, strings.Join(mids, ", "), ins.Suffix) - if shardQuery != ins.Query { - other["ShardedQuery"] = shardQuery - } - } return PrimitiveDescription{ OperatorType: "Insert", Keyspace: ins.Keyspace, @@ -1022,41 +368,8 @@ func (ins *Insert) description() PrimitiveDescription { } } -func (ins *Insert) insertIntoUnshardedTable(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, result *sqltypes.Result) (int64, *sqltypes.Result, error) { - query := ins.getInsertQueryForUnsharded(result, bindVars) - return ins.executeUnshardedTableQuery(ctx, vcursor, bindVars, query) -} - -func (ins *Insert) executeUnshardedTableQuery(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, query string) (int64, *sqltypes.Result, error) { - insertID, err := ins.processGenerateFromValues(ctx, vcursor, bindVars) - if err != nil { - return 0, nil, err - } - - rss, _, err := vcursor.ResolveDestinations(ctx, ins.Keyspace.Name, nil, []key.Destination{key.DestinationAllShards{}}) - if err != nil { - return 0, nil, err - } - if len(rss) != 1 { - return 0, nil, vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "Keyspace does not have exactly one shard: %v", rss) - } - err = allowOnlyPrimary(rss...) - if err != nil { - return 0, nil, err - } - qr, err := execShard(ctx, ins, vcursor, query, bindVars, rss[0], true, !ins.PreventAutoCommit /* canAutocommit */) - if err != nil { - return 0, nil, err - } - - // If processGenerateFromValues generated new values, it supersedes - // any ids that MySQL might have generated. If both generated - // values, we don't return an error because this behavior - // is required to support migration. - if insertID != 0 { - qr.InsertID = uint64(insertID) - } else { - insertID = int64(qr.InsertID) - } - return insertID, qr, nil +// InsertVarName returns a name for the bind var for this column. This method is used by the planner and engine, +// to make sure they both produce the same names +func InsertVarName(col sqlparser.IdentifierCI, rowNum int) string { + return fmt.Sprintf("_%s_%d", col.CompliantName(), rowNum) } diff --git a/go/vt/vtgate/engine/insert_common.go b/go/vt/vtgate/engine/insert_common.go new file mode 100644 index 0000000000..d0a14feec2 --- /dev/null +++ b/go/vt/vtgate/engine/insert_common.go @@ -0,0 +1,480 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/key" + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +type ( + InsertCommon struct { + // Opcode is the execution opcode. + Opcode InsertOpcode + + // Keyspace specifies the keyspace to send the query to. + Keyspace *vindexes.Keyspace + + // Ignore is for INSERT IGNORE and INSERT...ON DUPLICATE KEY constructs + // for sharded cases. + Ignore bool + + // TableName is the name of the table on which row will be inserted. + TableName string + + // Option to override the standard behavior and allow a multi-shard insert + // to use single round trip autocommit. + // + // This is a clear violation of the SQL semantics since it means the statement + // is not atomic in the presence of PK conflicts on one shard and not another. + // However, some application use cases would prefer that the statement partially + // succeed in order to get the performance benefits of autocommit. + MultiShardAutocommit bool + + // QueryTimeout contains the optional timeout (in milliseconds) to apply to this query + QueryTimeout int + + // ForceNonStreaming is true when the insert table and select table are same. + // This will avoid locking by the select table. + ForceNonStreaming bool + + PreventAutoCommit bool + + // Generate is only set for inserts where a sequence must be generated. + Generate *Generate + + // ColVindexes are the vindexes that will use the VindexValues + ColVindexes []*vindexes.ColumnVindex + + // Prefix, Suffix are for sharded insert plans. + Prefix string + Suffix string + + // Insert needs tx handling + txNeeded + } + + ksID = []byte + + // Generate represents the instruction to generate + // a value from a sequence. + Generate struct { + Keyspace *vindexes.Keyspace + Query string + // Values are the supplied values for the column, which + // will be stored as a list within the expression. New + // values will be generated based on how many were not + // supplied (NULL). + Values evalengine.Expr + // Insert using Select, offset for auto increment column + Offset int + } + + // InsertOpcode is a number representing the opcode + // for the Insert primitive. + InsertOpcode int +) + +const nextValBV = "n" + +const ( + // InsertUnsharded is for routing an insert statement + // to an unsharded keyspace. + InsertUnsharded = InsertOpcode(iota) + // InsertSharded is for routing an insert statement + // to individual shards. Requires: A list of Values, one + // for each ColVindex. If the table has an Autoinc column, + // A Generate subplan must be created. + InsertSharded +) + +var insName = map[InsertOpcode]string{ + InsertUnsharded: "InsertUnsharded", + InsertSharded: "InsertSharded", +} + +// String returns the opcode +func (code InsertOpcode) String() string { + return strings.ReplaceAll(insName[code], "Insert", "") +} + +// MarshalJSON serializes the InsertOpcode as a JSON string. +// It's used for testing and diagnostics. +func (code InsertOpcode) MarshalJSON() ([]byte, error) { + return json.Marshal(insName[code]) +} + +// GetKeyspaceName specifies the Keyspace that this primitive routes to. +func (ic *InsertCommon) GetKeyspaceName() string { + return ic.Keyspace.Name +} + +// GetTableName specifies the table that this primitive routes to. +func (ic *InsertCommon) GetTableName() string { + return ic.TableName +} + +// GetFields fetches the field info. +func (ic *InsertCommon) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + return nil, vterrors.VT13001("unexpected fields call for insert query") +} + +func (ins *InsertCommon) executeUnshardedTableQuery(ctx context.Context, vcursor VCursor, loggingPrimitive Primitive, bindVars map[string]*querypb.BindVariable, query string, insertID uint64) (*sqltypes.Result, error) { + rss, _, err := vcursor.ResolveDestinations(ctx, ins.Keyspace.Name, nil, []key.Destination{key.DestinationAllShards{}}) + if err != nil { + return nil, err + } + if len(rss) != 1 { + return nil, vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "Keyspace does not have exactly one shard: %v", rss) + } + err = allowOnlyPrimary(rss...) + if err != nil { + return nil, err + } + qr, err := execShard(ctx, loggingPrimitive, vcursor, query, bindVars, rss[0], true, !ins.PreventAutoCommit /* canAutocommit */) + if err != nil { + return nil, err + } + + // If processGenerateFromValues generated new values, it supersedes + // any ids that MySQL might have generated. If both generated + // values, we don't return an error because this behavior + // is required to support migration. + if insertID != 0 { + qr.InsertID = insertID + } + return qr, nil +} + +func (ins *InsertCommon) processVindexes(ctx context.Context, vcursor VCursor, vindexRowsValues [][]sqltypes.Row) ([]ksID, error) { + colVindexes := ins.ColVindexes + keyspaceIDs, err := ins.processPrimary(ctx, vcursor, vindexRowsValues[0], colVindexes[0]) + if err != nil { + return nil, err + } + + for vIdx := 1; vIdx < len(colVindexes); vIdx++ { + colVindex := colVindexes[vIdx] + if colVindex.Owned { + err = ins.processOwned(ctx, vcursor, vindexRowsValues[vIdx], colVindex, keyspaceIDs) + } else { + err = ins.processUnowned(ctx, vcursor, vindexRowsValues[vIdx], colVindex, keyspaceIDs) + } + if err != nil { + return nil, err + } + } + return keyspaceIDs, nil +} + +// processPrimary maps the primary vindex values to the keyspace ids. +func (ic *InsertCommon) processPrimary(ctx context.Context, vcursor VCursor, vindexColumnsKeys []sqltypes.Row, colVindex *vindexes.ColumnVindex) ([]ksID, error) { + destinations, err := vindexes.Map(ctx, colVindex.Vindex, vcursor, vindexColumnsKeys) + if err != nil { + return nil, err + } + + keyspaceIDs := make([]ksID, len(destinations)) + for i, destination := range destinations { + switch d := destination.(type) { + case key.DestinationKeyspaceID: + // This is a single keyspace id, we're good. + keyspaceIDs[i] = d + case key.DestinationNone: + // No valid keyspace id, we may return an error. + if !ic.Ignore { + return nil, fmt.Errorf("could not map %v to a keyspace id", vindexColumnsKeys[i]) + } + default: + return nil, fmt.Errorf("could not map %v to a unique keyspace id: %v", vindexColumnsKeys[i], destination) + } + } + + return keyspaceIDs, nil +} + +// processOwned creates vindex entries for the values of an owned column. +func (ic *InsertCommon) processOwned(ctx context.Context, vcursor VCursor, vindexColumnsKeys []sqltypes.Row, colVindex *vindexes.ColumnVindex, ksids []ksID) error { + if !ic.Ignore { + return colVindex.Vindex.(vindexes.Lookup).Create(ctx, vcursor, vindexColumnsKeys, ksids, false /* ignoreMode */) + } + + // InsertIgnore + var createIndexes []int + var createKeys []sqltypes.Row + var createKsids []ksID + + for rowNum, rowColumnKeys := range vindexColumnsKeys { + if ksids[rowNum] == nil { + continue + } + createIndexes = append(createIndexes, rowNum) + createKeys = append(createKeys, rowColumnKeys) + createKsids = append(createKsids, ksids[rowNum]) + } + if createKeys == nil { + return nil + } + + err := colVindex.Vindex.(vindexes.Lookup).Create(ctx, vcursor, createKeys, createKsids, true) + if err != nil { + return err + } + // After creation, verify that the keys map to the keyspace ids. If not, remove + // those that don't map. + verified, err := vindexes.Verify(ctx, colVindex.Vindex, vcursor, createKeys, createKsids) + if err != nil { + return err + } + for i, v := range verified { + if !v { + ksids[createIndexes[i]] = nil + } + } + return nil +} + +// processUnowned either reverse maps or validates the values for an unowned column. +func (ic *InsertCommon) processUnowned(ctx context.Context, vcursor VCursor, vindexColumnsKeys []sqltypes.Row, colVindex *vindexes.ColumnVindex, ksids []ksID) error { + var reverseIndexes []int + var reverseKsids []ksID + + var verifyIndexes []int + var verifyKeys []sqltypes.Row + var verifyKsids []ksID + + // Check if this VIndex is reversible or not. + reversibleVindex, isReversible := colVindex.Vindex.(vindexes.Reversible) + + for rowNum, rowColumnKeys := range vindexColumnsKeys { + // If we weren't able to determine a keyspace id from the primary VIndex, skip this row + if ksids[rowNum] == nil { + continue + } + + if rowColumnKeys[0].IsNull() { + // If the value of the column is `NULL`, but this is a reversible VIndex, + // we will try to generate the value from the keyspace id generated by the primary VIndex. + if isReversible { + reverseIndexes = append(reverseIndexes, rowNum) + reverseKsids = append(reverseKsids, ksids[rowNum]) + } + + // Otherwise, don't do anything. Whether `NULL` is a valid value for this column will be + // handled by MySQL. + } else { + // If a value for this column was specified, the keyspace id values from the + // secondary VIndex need to be verified against the keyspace id from the primary VIndex + verifyIndexes = append(verifyIndexes, rowNum) + verifyKeys = append(verifyKeys, rowColumnKeys) + verifyKsids = append(verifyKsids, ksids[rowNum]) + } + } + + // Reverse map values for secondary VIndex columns from the primary VIndex's keyspace id. + if reverseKsids != nil { + reverseKeys, err := reversibleVindex.ReverseMap(vcursor, reverseKsids) + if err != nil { + return err + } + + for i, reverseKey := range reverseKeys { + // Fill the first column with the reverse-mapped value. + vindexColumnsKeys[reverseIndexes[i]][0] = reverseKey + } + } + + // Verify that the keyspace ids generated by the primary and secondary VIndexes match + if verifyIndexes != nil { + // If values were supplied, we validate against keyspace id. + verified, err := vindexes.Verify(ctx, colVindex.Vindex, vcursor, verifyKeys, verifyKsids) + if err != nil { + return err + } + + var mismatchVindexKeys []sqltypes.Row + for i, v := range verified { + rowNum := verifyIndexes[i] + if !v { + if !ic.Ignore { + mismatchVindexKeys = append(mismatchVindexKeys, vindexColumnsKeys[rowNum]) + continue + } + + // Skip the whole row if this is a `INSERT IGNORE` or `INSERT ... ON DUPLICATE KEY ...` statement + // but the keyspace ids didn't match. + ksids[verifyIndexes[i]] = nil + } + } + + if mismatchVindexKeys != nil { + return fmt.Errorf("values %v for column %v does not map to keyspace ids", mismatchVindexKeys, colVindex.Columns) + } + } + + return nil +} + +// processGenerateFromSelect generates new values using a sequence if necessary. +// If no value was generated, it returns 0. Values are generated only +// for cases where none are supplied. +func (ic *InsertCommon) processGenerateFromSelect( + ctx context.Context, + vcursor VCursor, + loggingPrimitive Primitive, + rows []sqltypes.Row, +) (insertID int64, err error) { + if ic.Generate == nil { + return 0, nil + } + var count int64 + offset := ic.Generate.Offset + genColPresent := offset < len(rows[0]) + if genColPresent { + for _, row := range rows { + if shouldGenerate(row[offset], evalengine.ParseSQLMode(vcursor.SQLMode())) { + count++ + } + } + } else { + count = int64(len(rows)) + } + + if count == 0 { + return 0, nil + } + + insertID, err = ic.execGenerate(ctx, vcursor, loggingPrimitive, count) + if err != nil { + return 0, err + } + + used := insertID + for idx, val := range rows { + if genColPresent { + if shouldGenerate(val[offset], evalengine.ParseSQLMode(vcursor.SQLMode())) { + val[offset] = sqltypes.NewInt64(used) + used++ + } + } else { + rows[idx] = append(val, sqltypes.NewInt64(used)) + used++ + } + } + + return insertID, nil +} + +// processGenerateFromValues generates new values using a sequence if necessary. +// If no value was generated, it returns 0. Values are generated only +// for cases where none are supplied. +func (ic *InsertCommon) processGenerateFromValues( + ctx context.Context, + vcursor VCursor, + loggingPrimitive Primitive, + bindVars map[string]*querypb.BindVariable, +) (insertID int64, err error) { + if ic.Generate == nil { + return 0, nil + } + + // Scan input values to compute the number of values to generate, and + // keep track of where they should be filled. + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) + resolved, err := env.Evaluate(ic.Generate.Values) + if err != nil { + return 0, err + } + count := int64(0) + values := resolved.TupleValues() + for _, val := range values { + if shouldGenerate(val, evalengine.ParseSQLMode(vcursor.SQLMode())) { + count++ + } + } + + // If generation is needed, generate the requested number of values (as one call). + if count != 0 { + insertID, err = ic.execGenerate(ctx, vcursor, loggingPrimitive, count) + if err != nil { + return 0, err + } + } + + // Fill the holes where no value was supplied. + cur := insertID + for i, v := range values { + if shouldGenerate(v, evalengine.ParseSQLMode(vcursor.SQLMode())) { + bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.Int64BindVariable(cur) + cur++ + } else { + bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.ValueBindVariable(v) + } + } + return insertID, nil +} + +func (ic *InsertCommon) execGenerate(ctx context.Context, vcursor VCursor, loggingPrimitive Primitive, count int64) (int64, error) { + // If generation is needed, generate the requested number of values (as one call). + rss, _, err := vcursor.ResolveDestinations(ctx, ic.Generate.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) + if err != nil { + return 0, err + } + if len(rss) != 1 { + return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "auto sequence generation can happen through single shard only, it is getting routed to %d shards", len(rss)) + } + bindVars := map[string]*querypb.BindVariable{nextValBV: sqltypes.Int64BindVariable(count)} + qr, err := vcursor.ExecuteStandalone(ctx, loggingPrimitive, ic.Generate.Query, bindVars, rss[0]) + if err != nil { + return 0, err + } + // If no rows are returned, it's an internal error, and the code + // must panic, which will be caught and reported. + return qr.Rows[0][0].ToCastInt64() +} + +// shouldGenerate determines if a sequence value should be generated for a given value +func shouldGenerate(v sqltypes.Value, sqlmode evalengine.SQLMode) bool { + if v.IsNull() { + return true + } + + // Unless the NO_AUTO_VALUE_ON_ZERO sql mode is active in mysql, it also + // treats 0 as a value that should generate a new sequence. + value, err := evalengine.CoerceTo(v, sqltypes.Uint64, sqlmode) + if err != nil { + return false + } + + id, err := value.ToCastUint64() + if err != nil { + return false + } + + return id == 0 +} diff --git a/go/vt/vtgate/engine/insert_select.go b/go/vt/vtgate/engine/insert_select.go new file mode 100644 index 0000000000..d36176922d --- /dev/null +++ b/go/vt/vtgate/engine/insert_select.go @@ -0,0 +1,423 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "sync" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/key" + querypb "vitess.io/vitess/go/vt/proto/query" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/srvtopo" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +var _ Primitive = (*InsertSelect)(nil) + +type ( + // InsertSelect represents the instructions to perform an insert operation with input rows from a select. + InsertSelect struct { + InsertCommon + + // Input is a select query plan to retrieve results for inserting data. + Input Primitive + + // VindexValueOffset stores the offset for each column in the ColumnVindex + // that will appear in the result set of the select query. + VindexValueOffset [][]int + } +) + +// newInsertSelect creates a new InsertSelect. +func newInsertSelect( + ignore bool, + keyspace *vindexes.Keyspace, + table *vindexes.Table, + prefix string, + suffix string, + vv [][]int, + input Primitive, +) *InsertSelect { + ins := &InsertSelect{ + InsertCommon: InsertCommon{ + Ignore: ignore, + Keyspace: keyspace, + Prefix: prefix, + Suffix: suffix, + }, + Input: input, + VindexValueOffset: vv, + } + if table != nil { + ins.TableName = table.Name.String() + for _, colVindex := range table.ColumnVindexes { + if colVindex.IsPartialVindex() { + continue + } + ins.ColVindexes = append(ins.ColVindexes, colVindex) + } + } + return ins +} + +func (ins *InsertSelect) Inputs() ([]Primitive, []map[string]any) { + return []Primitive{ins.Input}, nil +} + +// RouteType returns a description of the query routing type used by the primitive +func (ins *InsertSelect) RouteType() string { + return "InsertSelect" +} + +// TryExecute performs a non-streaming exec. +func (ins *InsertSelect) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { + ctx, cancelFunc := addQueryTimeout(ctx, vcursor, ins.QueryTimeout) + defer cancelFunc() + + if ins.Keyspace.Sharded { + return ins.execInsertSharded(ctx, vcursor, bindVars) + } + return ins.execInsertUnsharded(ctx, vcursor, bindVars) +} + +// TryStreamExecute performs a streaming exec. +func (ins *InsertSelect) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + if ins.ForceNonStreaming { + res, err := ins.TryExecute(ctx, vcursor, bindVars, wantfields) + if err != nil { + return err + } + return callback(res) + } + ctx, cancelFunc := addQueryTimeout(ctx, vcursor, ins.QueryTimeout) + defer cancelFunc() + + sharded := ins.Keyspace.Sharded + output := &sqltypes.Result{} + err := ins.execSelectStreaming(ctx, vcursor, bindVars, func(irr insertRowsResult) error { + if len(irr.rows) == 0 { + return nil + } + + var qr *sqltypes.Result + var err error + if sharded { + qr, err = ins.insertIntoShardedTable(ctx, vcursor, bindVars, irr) + } else { + qr, err = ins.insertIntoUnshardedTable(ctx, vcursor, bindVars, irr) + } + if err != nil { + return err + } + + output.RowsAffected += qr.RowsAffected + // InsertID needs to be updated to the least insertID value in sqltypes.Result + if output.InsertID == 0 || output.InsertID > qr.InsertID { + output.InsertID = qr.InsertID + } + return nil + }) + if err != nil { + return err + } + return callback(output) +} + +func (ins *InsertSelect) execInsertUnsharded(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + irr, err := ins.execSelect(ctx, vcursor, bindVars) + if err != nil { + return nil, err + } + if len(irr.rows) == 0 { + return &sqltypes.Result{}, nil + } + return ins.insertIntoUnshardedTable(ctx, vcursor, bindVars, irr) +} + +func (ins *InsertSelect) insertIntoUnshardedTable(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, irr insertRowsResult) (*sqltypes.Result, error) { + query := ins.getInsertUnshardedQuery(irr.rows, bindVars) + return ins.executeUnshardedTableQuery(ctx, vcursor, ins, bindVars, query, irr.insertID) +} + +func (ins *InsertSelect) getInsertUnshardedQuery(rows []sqltypes.Row, bindVars map[string]*querypb.BindVariable) string { + var mids sqlparser.Values + for r, inputRow := range rows { + row := sqlparser.ValTuple{} + for c, value := range inputRow { + bvName := insertVarOffset(r, c) + bindVars[bvName] = sqltypes.ValueBindVariable(value) + row = append(row, sqlparser.NewArgument(bvName)) + } + mids = append(mids, row) + } + return ins.Prefix + sqlparser.String(mids) + ins.Suffix +} + +func (ins *InsertSelect) insertIntoShardedTable( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, + irr insertRowsResult, +) (*sqltypes.Result, error) { + rss, queries, err := ins.getInsertShardedQueries(ctx, vcursor, bindVars, irr.rows) + if err != nil { + return nil, err + } + + qr, err := ins.executeInsertQueries(ctx, vcursor, rss, queries, irr.insertID) + if err != nil { + return nil, err + } + qr.InsertID = uint64(irr.insertID) + return qr, nil +} + +func (ins *InsertSelect) executeInsertQueries( + ctx context.Context, + vcursor VCursor, + rss []*srvtopo.ResolvedShard, + queries []*querypb.BoundQuery, + insertID uint64, +) (*sqltypes.Result, error) { + autocommit := (len(rss) == 1 || ins.MultiShardAutocommit) && vcursor.AutocommitApproval() + err := allowOnlyPrimary(rss...) + if err != nil { + return nil, err + } + result, errs := vcursor.ExecuteMultiShard(ctx, ins, rss, queries, true /* rollbackOnError */, autocommit) + if errs != nil { + return nil, vterrors.Aggregate(errs) + } + + if insertID != 0 { + result.InsertID = insertID + } + return result, nil +} + +func (ins *InsertSelect) getInsertShardedQueries( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, + rows []sqltypes.Row, +) ([]*srvtopo.ResolvedShard, []*querypb.BoundQuery, error) { + vindexRowsValues, err := ins.buildVindexRowsValues(rows) + if err != nil { + return nil, nil, err + } + + keyspaceIDs, err := ins.processVindexes(ctx, vcursor, vindexRowsValues) + if err != nil { + return nil, nil, err + } + + var indexes []*querypb.Value + var destinations []key.Destination + for i, ksid := range keyspaceIDs { + if ksid != nil { + indexes = append(indexes, &querypb.Value{ + Value: strconv.AppendInt(nil, int64(i), 10), + }) + destinations = append(destinations, key.DestinationKeyspaceID(ksid)) + } + } + if len(destinations) == 0 { + // In this case, all we have is nil KeyspaceIds, we don't do + // anything at all. + return nil, nil, nil + } + + rss, indexesPerRss, err := vcursor.ResolveDestinations(ctx, ins.Keyspace.Name, indexes, destinations) + if err != nil { + return nil, nil, err + } + + queries := make([]*querypb.BoundQuery, len(rss)) + for i := range rss { + bvs := sqltypes.CopyBindVariables(bindVars) // we don't want to create one huge bindvars for all values + var mids sqlparser.Values + for _, indexValue := range indexesPerRss[i] { + index, _ := strconv.Atoi(string(indexValue.Value)) + if keyspaceIDs[index] != nil { + row := sqlparser.ValTuple{} + for colOffset, value := range rows[index] { + bvName := insertVarOffset(index, colOffset) + bvs[bvName] = sqltypes.ValueBindVariable(value) + row = append(row, sqlparser.NewArgument(bvName)) + } + mids = append(mids, row) + } + } + rewritten := ins.Prefix + sqlparser.String(mids) + ins.Suffix + queries[i] = &querypb.BoundQuery{ + Sql: rewritten, + BindVariables: bvs, + } + } + + return rss, queries, nil +} + +func (ins *InsertSelect) buildVindexRowsValues(rows []sqltypes.Row) ([][]sqltypes.Row, error) { + colVindexes := ins.ColVindexes + if len(colVindexes) != len(ins.VindexValueOffset) { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "vindex value offsets and vindex info do not match") + } + + // Here we go over the incoming rows and extract values for the vindexes we need to update + vindexRowsValues := make([][]sqltypes.Row, len(colVindexes)) + for _, inputRow := range rows { + for colIdx := range colVindexes { + offsets := ins.VindexValueOffset[colIdx] + row := make(sqltypes.Row, 0, len(offsets)) + for _, offset := range offsets { + if offset == -1 { // value not provided from select query + row = append(row, sqltypes.NULL) + continue + } + row = append(row, inputRow[offset]) + } + vindexRowsValues[colIdx] = append(vindexRowsValues[colIdx], row) + } + } + return vindexRowsValues, nil +} + +func (ins *InsertSelect) execInsertSharded(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + result, err := ins.execSelect(ctx, vcursor, bindVars) + if err != nil { + return nil, err + } + if len(result.rows) == 0 { + return &sqltypes.Result{}, nil + } + + return ins.insertIntoShardedTable(ctx, vcursor, bindVars, result) +} + +func (ins *InsertSelect) description() PrimitiveDescription { + other := ins.commonDesc() + other["TableName"] = ins.GetTableName() + + if len(ins.VindexValueOffset) > 0 { + valuesOffsets := map[string]string{} + for idx, ints := range ins.VindexValueOffset { + if len(ins.ColVindexes) < idx { + panic("ins.ColVindexes and ins.VindexValueOffset do not line up") + } + vindex := ins.ColVindexes[idx] + marshal, _ := json.Marshal(ints) + valuesOffsets[vindex.Name] = string(marshal) + } + other["VindexOffsetFromSelect"] = valuesOffsets + } + + return PrimitiveDescription{ + OperatorType: "Insert", + Keyspace: ins.Keyspace, + Variant: "Select", + TargetTabletType: topodatapb.TabletType_PRIMARY, + Other: other, + } +} + +func (ic *InsertCommon) commonDesc() map[string]any { + other := map[string]any{ + "MultiShardAutocommit": ic.MultiShardAutocommit, + "QueryTimeout": ic.QueryTimeout, + "InsertIgnore": ic.Ignore, + "InputAsNonStreaming": ic.ForceNonStreaming, + "NoAutoCommit": ic.PreventAutoCommit, + } + + if ic.Generate != nil { + if ic.Generate.Values == nil { + other["AutoIncrement"] = fmt.Sprintf("%s:Offset(%d)", ic.Generate.Query, ic.Generate.Offset) + } else { + other["AutoIncrement"] = fmt.Sprintf("%s:Values::%s", ic.Generate.Query, sqlparser.String(ic.Generate.Values)) + } + } + return other +} + +func insertVarOffset(rowNum, colOffset int) string { + return fmt.Sprintf("_c%d_%d", rowNum, colOffset) +} + +type insertRowsResult struct { + rows []sqltypes.Row + insertID uint64 +} + +func (ins *InsertSelect) execSelect( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, +) (insertRowsResult, error) { + res, err := vcursor.ExecutePrimitive(ctx, ins.Input, bindVars, false) + if err != nil || len(res.Rows) == 0 { + return insertRowsResult{}, err + } + + insertID, err := ins.processGenerateFromSelect(ctx, vcursor, ins, res.Rows) + if err != nil { + return insertRowsResult{}, err + } + + return insertRowsResult{ + rows: res.Rows, + insertID: uint64(insertID), + }, nil +} + +func (ins *InsertSelect) execSelectStreaming( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, + callback func(irr insertRowsResult) error, +) error { + var mu sync.Mutex + return vcursor.StreamExecutePrimitiveStandalone(ctx, ins.Input, bindVars, false, func(result *sqltypes.Result) error { + if len(result.Rows) == 0 { + return nil + } + + // should process only one chunk at a time. + // as parallel chunk insert will try to use the same transaction in the vttablet + // this will cause transaction in use error out with "transaction in use" error. + mu.Lock() + defer mu.Unlock() + + insertID, err := ins.processGenerateFromSelect(ctx, vcursor, ins, result.Rows) + if err != nil { + return err + } + + return callback(insertRowsResult{ + rows: result.Rows, + insertID: uint64(insertID), + }) + }) +} diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index d08ef85627..4ee8431f08 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -33,7 +33,7 @@ import ( ) func TestInsertUnsharded(t *testing.T) { - ins := NewQueryInsert( + ins := newQueryInsert( InsertUnsharded, &vindexes.Keyspace{ Name: "ks", @@ -68,7 +68,7 @@ func TestInsertUnsharded(t *testing.T) { } func TestInsertUnshardedGenerate(t *testing.T) { - ins := NewQueryInsert( + ins := newQueryInsert( InsertUnsharded, &vindexes.Keyspace{ Name: "ks", @@ -121,7 +121,7 @@ func TestInsertUnshardedGenerate(t *testing.T) { } func TestInsertUnshardedGenerate_Zeros(t *testing.T) { - ins := NewQueryInsert( + ins := newQueryInsert( InsertUnsharded, &vindexes.Keyspace{ Name: "ks", @@ -198,7 +198,7 @@ func TestInsertShardedSimple(t *testing.T) { ks := vs.Keyspaces["sharded"] // A single row insert should be autocommitted - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -232,7 +232,7 @@ func TestInsertShardedSimple(t *testing.T) { }) // Multiple rows are not autocommitted by default - ins = NewInsert( + ins = newInsert( InsertSharded, false, ks.Keyspace, @@ -272,7 +272,7 @@ func TestInsertShardedSimple(t *testing.T) { }) // Optional flag overrides autocommit - ins = NewInsert( + ins = newInsert( InsertSharded, false, ks.Keyspace, @@ -344,7 +344,7 @@ func TestInsertShardedFail(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -394,7 +394,7 @@ func TestInsertShardedGenerate(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -513,7 +513,7 @@ func TestInsertShardedOwned(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -623,7 +623,7 @@ func TestInsertShardedOwnedWithNull(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -700,7 +700,7 @@ func TestInsertShardedGeo(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -806,7 +806,7 @@ func TestInsertShardedIgnoreOwned(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, true, ks.Keyspace, @@ -962,7 +962,7 @@ func TestInsertShardedIgnoreOwnedWithNull(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, true, ks.Keyspace, @@ -1060,7 +1060,7 @@ func TestInsertShardedUnownedVerify(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -1188,7 +1188,7 @@ func TestInsertShardedIgnoreUnownedVerify(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, true, ks.Keyspace, @@ -1294,7 +1294,7 @@ func TestInsertShardedIgnoreUnownedVerifyFail(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -1371,7 +1371,7 @@ func TestInsertShardedUnownedReverseMap(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -1485,7 +1485,7 @@ func TestInsertShardedUnownedReverseMapSuccess(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( + ins := newInsert( InsertSharded, false, ks.Keyspace, @@ -1533,21 +1533,13 @@ func TestInsertSelectSimple(t *testing.T) { ks := vs.Keyspaces["sharded"] // A single row insert should be autocommitted - ins := &Insert{ - Opcode: InsertSelect, - Keyspace: ks.Keyspace, - Query: "dummy_insert", - VindexValueOffset: [][]int{{1}}, - Input: &Route{ - Query: "dummy_select", - FieldQuery: "dummy_field_query", - RoutingParameters: &RoutingParameters{ - Opcode: Scatter, - Keyspace: ks.Keyspace}}} - - ins.ColVindexes = append(ins.ColVindexes, ks.Tables["t1"].ColumnVindexes...) - ins.Prefix = "prefix " - ins.Suffix = " suffix" + rb := &Route{ + Query: "dummy_select", + FieldQuery: "dummy_field_query", + RoutingParameters: &RoutingParameters{ + Opcode: Scatter, + Keyspace: ks.Keyspace}} + ins := newInsertSelect(false, ks.Keyspace, ks.Tables["t1"], "prefix ", " suffix", [][]int{{1}}, rb) vc := newDMLTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} @@ -1623,23 +1615,24 @@ func TestInsertSelectOwned(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := &Insert{ - Opcode: InsertSelect, - Keyspace: ks.Keyspace, - Query: "dummy_insert", - VindexValueOffset: [][]int{ + rb := &Route{ + Query: "dummy_select", + FieldQuery: "dummy_field_query", + RoutingParameters: &RoutingParameters{ + Opcode: Scatter, + Keyspace: ks.Keyspace}} + + ins := newInsertSelect( + false, + ks.Keyspace, + ks.Tables["t1"], + "prefix ", + " suffix", + [][]int{ {1}, // The primary vindex has a single column as sharding key {0}}, // the onecol vindex uses the 'name' column - Input: &Route{ - Query: "dummy_select", - FieldQuery: "dummy_field_query", - RoutingParameters: &RoutingParameters{ - Opcode: Scatter, - Keyspace: ks.Keyspace}}} - - ins.ColVindexes = append(ins.ColVindexes, ks.Tables["t1"].ColumnVindexes...) - ins.Prefix = "prefix " - ins.Suffix = " suffix" + rb, + ) vc := newDMLTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} @@ -1723,24 +1716,22 @@ func TestInsertSelectGenerate(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := NewInsert( - InsertSelect, - false, - ks.Keyspace, - nil, - ks.Tables["t1"], - "prefix ", - nil, - " suffix") - ins.Query = "dummy_insert" - ins.VindexValueOffset = [][]int{{1}} // The primary vindex has a single column as sharding key - ins.Input = &Route{ + rb := &Route{ Query: "dummy_select", FieldQuery: "dummy_field_query", RoutingParameters: &RoutingParameters{ Opcode: Scatter, Keyspace: ks.Keyspace}} + ins := newInsertSelect( + false, + ks.Keyspace, + ks.Tables["t1"], + "prefix ", + " suffix", + [][]int{{1}}, // The primary vindex has a single column as sharding key + rb, + ) ins.Generate = &Generate{ Keyspace: &vindexes.Keyspace{ Name: "ks2", @@ -1760,7 +1751,7 @@ func TestInsertSelectGenerate(t *testing.T) { "varchar|int64"), "a|1", "a|null", - "b|null"), + "b|0"), // This is the result for the sequence query sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -1817,20 +1808,23 @@ func TestStreamingInsertSelectGenerate(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := &Insert{ - Opcode: InsertSelect, - Keyspace: ks.Keyspace, - Query: "dummy_insert", - VindexValueOffset: [][]int{ - {1}}, // The primary vindex has a single column as sharding key - Input: &Route{ - Query: "dummy_select", - FieldQuery: "dummy_field_query", - RoutingParameters: &RoutingParameters{ - Opcode: Scatter, - Keyspace: ks.Keyspace}}} - ins.ColVindexes = ks.Tables["t1"].ColumnVindexes + rb := &Route{ + Query: "dummy_select", + FieldQuery: "dummy_field_query", + RoutingParameters: &RoutingParameters{ + Opcode: Scatter, + Keyspace: ks.Keyspace}} + ins := newInsertSelect( + false, + ks.Keyspace, + ks.Tables["t1"], + "prefix ", + " suffix", + [][]int{ + {1}}, // The primary vindex has a single column as sharding key + rb, + ) ins.Generate = &Generate{ Keyspace: &vindexes.Keyspace{ Name: "ks2", @@ -1839,8 +1833,6 @@ func TestStreamingInsertSelectGenerate(t *testing.T) { Query: "dummy_generate", Offset: 1, } - ins.Prefix = "prefix " - ins.Suffix = " suffix" vc := newDMLTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} @@ -1913,20 +1905,21 @@ func TestInsertSelectGenerateNotProvided(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := &Insert{ - Opcode: InsertSelect, - Keyspace: ks.Keyspace, - Query: "dummy_insert", - VindexValueOffset: [][]int{ - {1}}, // The primary vindex has a single column as sharding key - Input: &Route{ - Query: "dummy_select", - FieldQuery: "dummy_field_query", - RoutingParameters: &RoutingParameters{ - Opcode: Scatter, - Keyspace: ks.Keyspace}}} - - ins.ColVindexes = ks.Tables["t1"].ColumnVindexes + rb := &Route{ + Query: "dummy_select", + FieldQuery: "dummy_field_query", + RoutingParameters: &RoutingParameters{ + Opcode: Scatter, + Keyspace: ks.Keyspace}} + ins := newInsertSelect( + false, + ks.Keyspace, + ks.Tables["t1"], + "prefix ", + " suffix", + [][]int{{1}}, // The primary vindex has a single column as sharding key, + rb, + ) ins.Generate = &Generate{ Keyspace: &vindexes.Keyspace{ Name: "ks2", @@ -1935,8 +1928,6 @@ func TestInsertSelectGenerateNotProvided(t *testing.T) { Query: "dummy_generate", Offset: 2, } - ins.Prefix = "prefix " - ins.Suffix = " suffix" vc := newDMLTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} @@ -2001,20 +1992,21 @@ func TestStreamingInsertSelectGenerateNotProvided(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := &Insert{ - Opcode: InsertSelect, - Keyspace: ks.Keyspace, - Query: "dummy_insert", - VindexValueOffset: [][]int{ - {1}}, // The primary vindex has a single column as sharding key - Input: &Route{ - Query: "dummy_select", - FieldQuery: "dummy_field_query", - RoutingParameters: &RoutingParameters{ - Opcode: Scatter, - Keyspace: ks.Keyspace}}} - - ins.ColVindexes = ks.Tables["t1"].ColumnVindexes + rb := &Route{ + Query: "dummy_select", + FieldQuery: "dummy_field_query", + RoutingParameters: &RoutingParameters{ + Opcode: Scatter, + Keyspace: ks.Keyspace}} + ins := newInsertSelect( + false, + ks.Keyspace, + ks.Tables["t1"], + "prefix ", + " suffix", + [][]int{{1}}, // The primary vindex has a single column as sharding key, + rb, + ) ins.Generate = &Generate{ Keyspace: &vindexes.Keyspace{ Name: "ks2", @@ -2023,8 +2015,6 @@ func TestStreamingInsertSelectGenerateNotProvided(t *testing.T) { Query: "dummy_generate", Offset: 2, } - ins.Prefix = "prefix " - ins.Suffix = " suffix" vc := newDMLTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} @@ -2099,22 +2089,21 @@ func TestInsertSelectUnowned(t *testing.T) { vs := vindexes.BuildVSchema(invschema) ks := vs.Keyspaces["sharded"] - ins := &Insert{ - Opcode: InsertSelect, - Keyspace: ks.Keyspace, - Query: "dummy_insert", - VindexValueOffset: [][]int{ - {0}}, // the onecol vindex as unowned lookup sharding column - Input: &Route{ - Query: "dummy_select", - FieldQuery: "dummy_field_query", - RoutingParameters: &RoutingParameters{ - Opcode: Scatter, - Keyspace: ks.Keyspace}}} - - ins.ColVindexes = append(ins.ColVindexes, ks.Tables["t2"].ColumnVindexes...) - ins.Prefix = "prefix " - ins.Suffix = " suffix" + rb := &Route{ + Query: "dummy_select", + FieldQuery: "dummy_field_query", + RoutingParameters: &RoutingParameters{ + Opcode: Scatter, + Keyspace: ks.Keyspace}} + ins := newInsertSelect( + false, + ks.Keyspace, + ks.Tables["t2"], + "prefix ", + " suffix", + [][]int{{0}}, // // the onecol vindex as unowned lookup sharding column + rb, + ) vc := newDMLTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} @@ -2220,16 +2209,15 @@ func TestInsertSelectShardingCases(t *testing.T) { RoutingParameters: &RoutingParameters{Opcode: Unsharded, Keyspace: uks2.Keyspace}} // sks1 and sks2 - ins := &Insert{ - Opcode: InsertSelect, - Keyspace: sks1.Keyspace, - Query: "dummy_insert", - Prefix: "prefix ", - Suffix: " suffix", - ColVindexes: sks1.Tables["s1"].ColumnVindexes, - VindexValueOffset: [][]int{{0}}, - Input: sRoute, - } + ins := newInsertSelect( + false, + sks1.Keyspace, + sks1.Tables["s1"], + "prefix ", + " suffix", + [][]int{{0}}, + sRoute, + ) vc := &loggingVCursor{ resolvedTargetTabletType: topodatapb.TabletType_PRIMARY, @@ -2298,14 +2286,15 @@ func TestInsertSelectShardingCases(t *testing.T) { `ExecuteMultiShard sks1.-20: prefix values (:_c0_0) suffix {_c0_0: type:INT64 value:"1"} true true`}) // uks1 and sks2 - ins = &Insert{ - Opcode: InsertUnsharded, - Keyspace: uks1.Keyspace, - Query: "dummy_insert", - Prefix: "prefix ", - Suffix: " suffix", - Input: sRoute, - } + ins = newInsertSelect( + false, + uks1.Keyspace, + nil, + "prefix ", + " suffix", + nil, + sRoute, + ) vc.Rewind() _, err = ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) diff --git a/go/vt/vtgate/planbuilder/insert.go b/go/vt/vtgate/planbuilder/insert.go index 39144fc858..173c021307 100644 --- a/go/vt/vtgate/planbuilder/insert.go +++ b/go/vt/vtgate/planbuilder/insert.go @@ -22,7 +22,6 @@ import ( "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -91,28 +90,30 @@ func errOutIfPlanCannotBeConstructed(ctx *plancontext.PlanningContext, vTbl *vin } func insertUnshardedShortcut(stmt *sqlparser.Insert, ks *vindexes.Keyspace, tables []*vindexes.Table) logicalPlan { - eIns := &engine.Insert{} - eIns.Keyspace = ks - eIns.TableName = tables[0].Name.String() - eIns.Opcode = engine.InsertUnsharded + eIns := &engine.Insert{ + InsertCommon: engine.InsertCommon{ + Opcode: engine.InsertUnsharded, + Keyspace: ks, + TableName: tables[0].Name.String(), + }, + } eIns.Query = generateQuery(stmt) return &insert{eInsert: eIns} } type insert struct { - eInsert *engine.Insert - source logicalPlan + eInsert *engine.Insert + eInsertSelect *engine.InsertSelect + source logicalPlan } var _ logicalPlan = (*insert)(nil) func (i *insert) Primitive() engine.Primitive { - if i.source != nil { - i.eInsert.Input = i.source.Primitive() + if i.source == nil { + return i.eInsert } - return i.eInsert -} - -func (i *insert) ContainsTables() semantics.TableSet { - panic("does not expect insert to get contains tables call") + input := i.source.Primitive() + i.eInsertSelect.Input = input + return i.eInsertSelect } diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index d6732ada87..5f965b55ad 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -110,20 +110,20 @@ func transformInsertionSelection(ctx *plancontext.PlanningContext, op *operators } ins := dmlOp.(*operators.Insert) - eins := &engine.Insert{ - Opcode: mapToInsertOpCode(rb.Routing.OpCode(), true), - Keyspace: rb.Routing.Keyspace(), - TableName: ins.VTable.Name.String(), - Ignore: ins.Ignore, - ForceNonStreaming: op.ForceNonStreaming, - Generate: autoIncGenerate(ins.AutoIncrement), - ColVindexes: ins.ColVindexes, - VindexValues: ins.VindexValues, + eins := &engine.InsertSelect{ + InsertCommon: engine.InsertCommon{ + Keyspace: rb.Routing.Keyspace(), + TableName: ins.VTable.Name.String(), + Ignore: ins.Ignore, + ForceNonStreaming: op.ForceNonStreaming, + Generate: autoIncGenerate(ins.AutoIncrement), + ColVindexes: ins.ColVindexes, + }, VindexValueOffset: ins.VindexValueOffset, } - lp := &insert{eInsert: eins} + lp := &insert{eInsertSelect: eins} - eins.Prefix, eins.Mid, eins.Suffix = generateInsertShardedQuery(ins.AST) + eins.Prefix, _, eins.Suffix = generateInsertShardedQuery(ins.AST) selectionPlan, err := transformToLogicalPlan(ctx, op.Select) if err != nil { @@ -548,15 +548,23 @@ func buildInsertLogicalPlan( hints *queryHints, ) (logicalPlan, error) { ins := op.(*operators.Insert) + + ic := engine.InsertCommon{ + Opcode: mapToInsertOpCode(rb.Routing.OpCode()), + Keyspace: rb.Routing.Keyspace(), + TableName: ins.VTable.Name.String(), + Ignore: ins.Ignore, + Generate: autoIncGenerate(ins.AutoIncrement), + ColVindexes: ins.ColVindexes, + } + if hints != nil { + ic.MultiShardAutocommit = hints.multiShardAutocommit + ic.QueryTimeout = hints.queryTimeout + } + eins := &engine.Insert{ - Opcode: mapToInsertOpCode(rb.Routing.OpCode(), false), - Keyspace: rb.Routing.Keyspace(), - TableName: ins.VTable.Name.String(), - Ignore: ins.Ignore, - Generate: autoIncGenerate(ins.AutoIncrement), - ColVindexes: ins.ColVindexes, - VindexValues: ins.VindexValues, - VindexValueOffset: ins.VindexValueOffset, + InsertCommon: ic, + VindexValues: ins.VindexValues, } lp := &insert{eInsert: eins} @@ -566,22 +574,14 @@ func buildInsertLogicalPlan( eins.Prefix, eins.Mid, eins.Suffix = generateInsertShardedQuery(ins.AST) } - if hints != nil { - eins.MultiShardAutocommit = hints.multiShardAutocommit - eins.QueryTimeout = hints.queryTimeout - } - eins.Query = generateQuery(stmt) return lp, nil } -func mapToInsertOpCode(code engine.Opcode, insertSelect bool) engine.InsertOpcode { +func mapToInsertOpCode(code engine.Opcode) engine.InsertOpcode { if code == engine.Unsharded { return engine.InsertUnsharded } - if insertSelect { - return engine.InsertSelect - } return engine.InsertSharded } diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 5ec8210b12..ec39db0b15 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -4084,7 +4084,7 @@ "Original": "insert into unsharded(col) select col from unsharded_tab", "Instructions": { "OperatorType": "Insert", - "Variant": "Unsharded", + "Variant": "Select", "Keyspace": { "Name": "main", "Sharded": false @@ -4119,7 +4119,7 @@ "Original": "insert into unsharded(col) select col from t1", "Instructions": { "OperatorType": "Insert", - "Variant": "Unsharded", + "Variant": "Select", "Keyspace": { "Name": "main", "Sharded": false