Skip to content

Commit

Permalink
Add foreign key support for insert on duplicate key update (#14638)
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal authored Dec 12, 2023
1 parent 548c7d8 commit c680c16
Show file tree
Hide file tree
Showing 9 changed files with 1,161 additions and 60 deletions.
31 changes: 31 additions & 0 deletions go/test/endtoend/vtgate/foreignkey/fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,37 @@ func TestReplaceWithFK(t *testing.T) {
utils.AssertMatches(t, conn, `select * from u_t2`, `[[INT64(1) NULL] [INT64(2) NULL]]`)
}

// TestInsertWithFKOnDup tests that insertion with on duplicate key update works as expected.
func TestInsertWithFKOnDup(t *testing.T) {
mcmp, closer := start(t)
defer closer()

utils.Exec(t, mcmp.VtConn, "use `uks`")

// insert some data.
mcmp.Exec(`insert into u_t1(id, col1) values (100, 1), (200, 2), (300, 3), (400, 4)`)
mcmp.Exec(`insert into u_t2(id, col2) values (1000, 1), (2000, 2), (3000, 3), (4000, 4)`)

// updating child to an existing value in parent.
mcmp.Exec(`insert into u_t2(id, col2) values (4000, 50) on duplicate key update col2 = 1`)
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) INT64(1)] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) INT64(1)]]`)

// updating parent, value not referred in child.
mcmp.Exec(`insert into u_t1(id, col1) values (400, 50) on duplicate key update col1 = values(col1)`)
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(1)] [INT64(200) INT64(2)] [INT64(300) INT64(3)] [INT64(400) INT64(50)]]`)
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) INT64(1)] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) INT64(1)]]`)

// updating parent, child updated to null.
mcmp.Exec(`insert into u_t1(id, col1) values (100, 75) on duplicate key update col1 = values(col1)`)
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(75)] [INT64(200) INT64(2)] [INT64(300) INT64(3)] [INT64(400) INT64(50)]]`)
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) NULL] [INT64(2000) INT64(2)] [INT64(3000) INT64(3)] [INT64(4000) NULL]]`)

// inserting multiple rows in parent, some child rows updated to null.
mcmp.Exec(`insert into u_t1(id, col1) values (100, 42),(600, 2),(300, 24),(200, 2) on duplicate key update col1 = values(col1)`)
mcmp.AssertMatches(`select * from u_t1 order by id`, `[[INT64(100) INT64(42)] [INT64(200) INT64(2)] [INT64(300) INT64(24)] [INT64(400) INT64(50)] [INT64(600) INT64(2)]]`)
mcmp.AssertMatches(`select * from u_t2 order by id`, `[[INT64(1000) NULL] [INT64(2000) INT64(2)] [INT64(3000) NULL] [INT64(4000) NULL]]`)
}

// TestDDLFk tests that table is created with fk constraint when foreign_key_checks is off.
func TestDDLFk(t *testing.T) {
mcmp, closer := start(t)
Expand Down
35 changes: 35 additions & 0 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

138 changes: 138 additions & 0 deletions go/vt/vtgate/engine/upsert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
Copyright 2023 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"context"
"fmt"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

var _ Primitive = (*Upsert)(nil)

// Upsert Primitive will execute the insert primitive first and
// if there is `Duplicate Key` error, it executes the update primitive.
type Upsert struct {
Upserts []upsert

txNeeded
}

type upsert struct {
Insert Primitive
Update Primitive
}

// AddUpsert appends to the Upsert Primitive.
func (u *Upsert) AddUpsert(ins, upd Primitive) {
u.Upserts = append(u.Upserts, upsert{
Insert: ins,
Update: upd,
})
}

// RouteType implements Primitive interface type.
func (u *Upsert) RouteType() string {
return "UPSERT"
}

// GetKeyspaceName implements Primitive interface type.
func (u *Upsert) GetKeyspaceName() string {
if len(u.Upserts) > 0 {
return u.Upserts[0].Insert.GetKeyspaceName()
}
return ""
}

// GetTableName implements Primitive interface type.
func (u *Upsert) GetTableName() string {
if len(u.Upserts) > 0 {
return u.Upserts[0].Insert.GetTableName()
}
return ""
}

// GetFields implements Primitive interface type.
func (u *Upsert) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.VT13001("unexpected to receive GetFields call for insert on duplicate key update query")
}

// TryExecute implements Primitive interface type.
func (u *Upsert) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
result := &sqltypes.Result{}
for _, up := range u.Upserts {
qr, err := execOne(ctx, vcursor, bindVars, wantfields, up)
if err != nil {
return nil, err
}
result.RowsAffected += qr.RowsAffected
}
return result, nil
}

func execOne(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, up upsert) (*sqltypes.Result, error) {
insQr, err := vcursor.ExecutePrimitive(ctx, up.Insert, bindVars, wantfields)
if err == nil {
return insQr, nil
}
if vterrors.Code(err) != vtrpcpb.Code_ALREADY_EXISTS {
return nil, err
}
updQr, err := vcursor.ExecutePrimitive(ctx, up.Update, bindVars, wantfields)
if err != nil {
return nil, err
}
// To match mysql, need to report +1 on rows affected if there is any change.
if updQr.RowsAffected > 0 {
updQr.RowsAffected += 1
}
return updQr, nil
}

// TryStreamExecute implements Primitive interface type.
func (u *Upsert) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
qr, err := u.TryExecute(ctx, vcursor, bindVars, wantfields)
if err != nil {
return err
}
return callback(qr)
}

// Inputs implements Primitive interface type.
func (u *Upsert) Inputs() ([]Primitive, []map[string]any) {
var inputs []Primitive
var inputsMap []map[string]any
for i, up := range u.Upserts {
inputs = append(inputs, up.Insert, up.Update)
inputsMap = append(inputsMap,
map[string]any{inputName: fmt.Sprintf("Insert-%d", i+1)},
map[string]any{inputName: fmt.Sprintf("Update-%d", i+1)})
}
return inputs, inputsMap
}

func (u *Upsert) description() PrimitiveDescription {
return PrimitiveDescription{
OperatorType: "Upsert",
TargetTabletType: topodatapb.TabletType_PRIMARY,
}
}
27 changes: 27 additions & 0 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op operators.Opera
return transformFkVerify(ctx, op)
case *operators.InsertSelection:
return transformInsertionSelection(ctx, op)
case *operators.Upsert:
return transformUpsert(ctx, op)
case *operators.HashJoin:
return transformHashJoin(ctx, op)
case *operators.Sequential:
Expand All @@ -75,6 +77,31 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op operators.Opera
return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToLogicalPlan)", op))
}

func transformUpsert(ctx *plancontext.PlanningContext, op *operators.Upsert) (logicalPlan, error) {
u := &upsert{}
for _, source := range op.Sources {
iLp, uLp, err := transformOneUpsert(ctx, source)
if err != nil {
return nil, err
}
u.insert = append(u.insert, iLp)
u.update = append(u.update, uLp)
}
return u, nil
}

func transformOneUpsert(ctx *plancontext.PlanningContext, source operators.UpsertSource) (iLp, uLp logicalPlan, err error) {
iLp, err = transformToLogicalPlan(ctx, source.Insert)
if err != nil {
return
}
if ins, ok := iLp.(*insert); ok {
ins.eInsert.PreventAutoCommit = true
}
uLp, err = transformToLogicalPlan(ctx, source.Update)
return
}

func transformSequential(ctx *plancontext.PlanningContext, op *operators.Sequential) (logicalPlan, error) {
var lps []logicalPlan
for _, source := range op.Sources {
Expand Down
Loading

0 comments on commit c680c16

Please sign in to comment.