Skip to content

Commit

Permalink
Ensure hexval and int don't share BindVar after Normalization (#14451)
Browse files Browse the repository at this point in the history
Signed-off-by: William Martin <[email protected]>
Signed-off-by: Arthur Schreiber <[email protected]>
  • Loading branch information
williammartin authored and arthurschreiber committed Nov 11, 2023
1 parent 12d93c7 commit 0b47a94
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 22 deletions.
30 changes: 8 additions & 22 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*quer
type normalizer struct {
bindVars map[string]*querypb.BindVariable
reserved *ReservedVars
vals map[string]string
vals map[Literal]string
err error
inDerived bool
}
Expand All @@ -53,7 +53,7 @@ func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVari
return &normalizer{
bindVars: bindVars,
reserved: reserved,
vals: make(map[string]string),
vals: make(map[Literal]string),
}
}

Expand Down Expand Up @@ -190,31 +190,18 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) {
}

// Check if there's a bindvar for that value already.
key := keyFor(bval, node)
bvname, ok := nz.vals[key]
bvname, ok := nz.vals[*node]
if !ok {
// If there's no such bindvar, make a new one.
bvname = nz.reserved.nextUnusedVar()
nz.vals[key] = bvname
nz.vals[*node] = bvname
nz.bindVars[bvname] = bval
}

// Modify the AST node to a bindvar.
cursor.Replace(NewArgument(bvname))
}

func keyFor(bval *querypb.BindVariable, lit *Literal) string {
if bval.Type != sqltypes.VarBinary && bval.Type != sqltypes.VarChar {
return lit.Val
}

// Prefixing strings with "'" ensures that a string
// and number that have the same representation don't
// collide.
return "'" + lit.Val

}

// convertLiteral converts an Literal without the dedup.
func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) {
err := validateLiteral(node)
Expand Down Expand Up @@ -273,15 +260,14 @@ func (nz *normalizer) parameterize(left, right Expr) Expr {
if bval == nil {
return nil
}
key := keyFor(bval, lit)
bvname := nz.decideBindVarName(key, lit, col, bval)
bvname := nz.decideBindVarName(lit, col, bval)
return Argument(bvname)
}

func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, bval *querypb.BindVariable) string {
func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string {
if len(lit.Val) <= 256 {
// first we check if we already have a bindvar for this value. if we do, we re-use that bindvar name
bvname, ok := nz.vals[key]
bvname, ok := nz.vals[*lit]
if ok {
return bvname
}
Expand All @@ -291,7 +277,7 @@ func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName,
// Big values are most likely not for vindexes.
// We save a lot of CPU because we avoid building
bvname := nz.reserved.ReserveColName(col)
nz.vals[key] = bvname
nz.vals[*lit] = bvname
nz.bindVars[bvname] = bval

return bvname
Expand Down
8 changes: 8 additions & 0 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,14 @@ func TestNormalize(t *testing.T) {
in: `select * from (select 12) as t`,
outstmt: `select * from (select 12 from dual) as t`,
outbv: map[string]*querypb.BindVariable{},
}, {
// HexVal and Int should not share a bindvar just because they have the same value
in: `select * from t where v1 = x'31' and v2 = 31`,
outstmt: `select * from t where v1 = :v1 and v2 = :v2`,
outbv: map[string]*querypb.BindVariable{
"v1": sqltypes.HexValBindVariable([]byte("x'31'")),
"v2": sqltypes.Int64BindVariable(31),
},
}}
for _, tc := range testcases {
t.Run(tc.in, func(t *testing.T) {
Expand Down

0 comments on commit 0b47a94

Please sign in to comment.