Skip to content

Commit

Permalink
addressed review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed May 22, 2024
1 parent be4fd87 commit 7d9d50f
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 15 deletions.
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa
return res, nil
}

// executeLiteralUpdate executes the primitive that can be executed with a single bind variable from the input result.
// The column updated have same value for all rows in the input result.
func executeLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int) (*sqltypes.Result, error) {
var bv *querypb.BindVariable
if len(outputCols) == 1 {
Expand Down Expand Up @@ -122,6 +124,8 @@ func getBVMulti(rows []sqltypes.Row, offsets []int) *querypb.BindVariable {
return bv
}

// executeNonLiteralUpdate executes the primitive that needs to be executed per row from the input result.
// The column updated might have different value for each row in the input result.
func executeNonLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int, vars map[string]int) (qr *sqltypes.Result, err error) {
var res *sqltypes.Result
for _, row := range inputRes.Rows {
Expand Down Expand Up @@ -175,7 +179,7 @@ func (dml *DMLWithInput) description() PrimitiveDescription {
if len(vars) == 0 {
continue
}
bvList = append(bvList, fmt.Sprintf("%d:%v", idx, vars))
bvList = append(bvList, fmt.Sprintf("%d:[%s]", idx, orderedStringIntMap(vars)))
}
if len(bvList) > 0 {
other["BindVars"] = bvList
Expand Down
9 changes: 9 additions & 0 deletions go/vt/vtgate/engine/plan_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"sort"
"strings"

"vitess.io/vitess/go/tools/graphviz"
"vitess.io/vitess/go/vt/key"
Expand Down Expand Up @@ -266,3 +267,11 @@ func (m orderedMap) MarshalJSON() ([]byte, error) {
buf.WriteString("}")
return buf.Bytes(), nil
}

func (m orderedMap) String() string {
var output []string
for _, val := range m {
output = append(output, fmt.Sprintf("%s:%v", val.key, val.val))
}
return strings.Join(output, " ")
}
21 changes: 10 additions & 11 deletions go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,31 +219,30 @@ func prepareUpdateExpressionList(ctx *plancontext.PlanningContext, upd *sqlparse
// E.g. UPDATE t1 join t2 on t1.col = t2.col SET t1.col = t2.col + 1 where t2.col = 10;
// SET t1.col = t2.col + 1 -> SET t1.col = :t2_col + 1 (t2_col is the bindvar column which will be provided from the input)
ueMap := make(map[semantics.TableSet]updList)
var dependentCols updList
for _, ue := range upd.Exprs {
target := ctx.SemTable.DirectDeps(ue.Name)
exprDeps := ctx.SemTable.RecursiveDeps(ue.Expr)
jc := breakExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target))
updCol := updColumn{ue.Name, jc}
ueMap[target] = append(ueMap[target], updCol)
dependentCols = append(dependentCols, updCol)
ueMap[target] = append(ueMap[target], updColumn{ue.Name, jc})
}

// Check if any of the dependent columns are updated in the same query.
// This can result in a mismatch of rows on how MySQL interprets it and how Vitess would have updated those rows.
// It is safe to fail for those cases.
errIfDependentColumnUpdated(ctx, upd, dependentCols)
errIfDependentColumnUpdated(ctx, upd, ueMap)

return ueMap
}

func errIfDependentColumnUpdated(ctx *plancontext.PlanningContext, upd *sqlparser.Update, dependentCols updList) {
func errIfDependentColumnUpdated(ctx *plancontext.PlanningContext, upd *sqlparser.Update, ueMap map[semantics.TableSet]updList) {
for _, ue := range upd.Exprs {
for _, dc := range dependentCols {
for _, bvExpr := range dc.jc.LHSExprs {
if ctx.SemTable.EqualsExprWithDeps(ue.Name, bvExpr.Expr) {
panic(vterrors.VT12001(
fmt.Sprintf("'%s' column referenced in update expression '%s' is itself updated", sqlparser.String(ue.Name), sqlparser.String(dc.jc.Original))))
for _, list := range ueMap {
for _, dc := range list {
for _, bvExpr := range dc.jc.LHSExprs {
if ctx.SemTable.EqualsExprWithDeps(ue.Name, bvExpr.Expr) {
panic(vterrors.VT12001(
fmt.Sprintf("'%s' column referenced in update expression '%s' is itself updated", sqlparser.String(ue.Name), sqlparser.String(dc.jc.Original))))
}
}
}
}
Expand Down
82 changes: 79 additions & 3 deletions go/vt/vtgate/planbuilder/testdata/dml_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -5829,7 +5829,7 @@
"OperatorType": "DMLWithInput",
"TargetTabletType": "PRIMARY",
"BindVars": [
"0:map[ue_col:1]"
"0:[ue_col:1]"
],
"Offset": [
"0:[0]"
Expand Down Expand Up @@ -5895,6 +5895,82 @@
]
}
},
{
"comment": "update with multi table join with single target having multiple dependent column update",
"query": "update user as u, user_extra as ue set u.col = ue.foo + ue.bar + u.baz where u.id = ue.id",
"plan": {
"QueryType": "UPDATE",
"Original": "update user as u, user_extra as ue set u.col = ue.foo + ue.bar + u.baz where u.id = ue.id",
"Instructions": {
"OperatorType": "DMLWithInput",
"TargetTabletType": "PRIMARY",
"BindVars": [
"0:[ue_bar:2 ue_foo:1]"
],
"Offset": [
"0:[0]"
],
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "R:0,L:0,L:1",
"JoinVars": {
"ue_id": 2
},
"TableName": "user_extra_`user`",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select ue.foo, ue.bar, ue.id from user_extra as ue where 1 != 1",
"Query": "select ue.foo, ue.bar, ue.id from user_extra as ue for update",
"Table": "user_extra"
},
{
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select u.id from `user` as u where 1 != 1",
"Query": "select u.id from `user` as u where u.id = :ue_id for update",
"Table": "`user`",
"Values": [
":ue_id"
],
"Vindex": "user_index"
}
]
},
{
"OperatorType": "Update",
"Variant": "IN",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"TargetTabletType": "PRIMARY",
"Query": "update `user` as u set u.col = :ue_foo + :ue_bar + u.baz where u.id in ::dml_vals",
"Table": "user",
"Values": [
"::dml_vals"
],
"Vindex": "user_index"
}
]
},
"TablesUsed": [
"user.user",
"user.user_extra"
]
}
},
{
"comment": "update with multi table join with multi target having dependent column update",
"query": "update user, user_extra ue set user.name = ue.id + 'foo', ue.bar = user.baz where user.id = ue.id and user.id = 1",
Expand All @@ -5905,8 +5981,8 @@
"OperatorType": "DMLWithInput",
"TargetTabletType": "PRIMARY",
"BindVars": [
"0:map[ue_id:1]",
"1:map[user_baz:3]"
"0:[ue_id:1]",
"1:[user_baz:3]"
],
"Offset": [
"0:[0]",
Expand Down

0 comments on commit 7d9d50f

Please sign in to comment.