Skip to content

Commit

Permalink
evalengine: Handle zero dates correctly (#14610)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
Signed-off-by: Vicent Marti <[email protected]>
Co-authored-by: Vicent Marti <[email protected]>
  • Loading branch information
dbussink and vmg authored Nov 27, 2023
1 parent f8348c7 commit 5e23ddc
Show file tree
Hide file tree
Showing 30 changed files with 244 additions and 193 deletions.
9 changes: 1 addition & 8 deletions config/mycnf/test-suite.cnf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This sets some unsafe settings specifically for
# the test-suite which is currently MySQL 5.7 based
# the test-suite which is currently MySQL 8.0 based
# In future it should be renamed testsuite.cnf

innodb_buffer_pool_size = 32M
Expand All @@ -14,13 +14,6 @@ key_buffer_size = 2M
sync_binlog=0
innodb_doublewrite=0

# These two settings are required for the testsuite to pass,
# but enabling them does not spark joy. They should be removed
# in the future. See:
# https://github.com/vitessio/vitess/issues/5396

sql_mode = STRICT_TRANS_TABLES

# set a short heartbeat interval in order to detect failures quickly
slave_net_timeout = 4
# Disabling `super-read-only`. `test-suite` is mainly used for `vttestserver`. Since `vttestserver` uses a single MySQL for primary and replicas,
Expand Down
3 changes: 3 additions & 0 deletions go/mysql/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package config

const DefaultSQLMode = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION"
4 changes: 4 additions & 0 deletions go/mysql/endtoend/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,10 @@ func TestRowReplicationTypes(t *testing.T) {
t.Fatal(err)
}
defer dConn.Close()
// We have tests for zero dates, so we need to allow that for this session.
if _, err := dConn.ExecuteFetch("SET @@session.sql_mode=REPLACE(REPLACE(@@session.sql_mode, 'NO_ZERO_DATE', ''), 'NO_ZERO_IN_DATE', '')", 0, false); err != nil {
t.Fatal(err)
}

// Set the connection time zone for execution of the
// statements to PST. That way we're sure to test the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"time"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/config"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/schema"
"vitess.io/vitess/go/vt/sqlparser"
Expand Down Expand Up @@ -58,8 +59,7 @@ var (
)

const (
testDataPath = "testdata"
defaultSQLMode = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION"
testDataPath = "testdata"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -178,7 +178,7 @@ func testSingle(t *testing.T, testName string) {
}
}

sqlMode := defaultSQLMode
sqlMode := config.DefaultSQLMode
if overrideSQLMode, exists := readTestFile(t, testName, "sql_mode"); exists {
sqlMode = overrideSQLMode
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ var (
)

const (
testDataPath = "../../onlineddl/vrepl_suite/testdata"
defaultSQLMode = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION"
testDataPath = "../../onlineddl/vrepl_suite/testdata"
sqlModeAllowsZeroDate = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION"
)

type testTableSchema struct {
Expand Down Expand Up @@ -202,7 +202,7 @@ func testSingle(t *testing.T, testName string) {
return
}

sqlModeQuery := fmt.Sprintf("set @@global.sql_mode='%s'", defaultSQLMode)
sqlModeQuery := fmt.Sprintf("set @@global.sql_mode='%s'", sqlModeAllowsZeroDate)
_ = mysqlExec(t, sqlModeQuery, "")
_ = mysqlExec(t, "set @@global.event_scheduler=0", "")

Expand Down
3 changes: 2 additions & 1 deletion go/vt/sidecardb/sidecardb.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"vitess.io/vitess/go/constants/sidecar"
"vitess.io/vitess/go/history"
"vitess.io/vitess/go/mysql/config"
"vitess.io/vitess/go/mysql/sqlerror"

"vitess.io/vitess/go/mysql/fakesqldb"
Expand Down Expand Up @@ -485,7 +486,7 @@ func AddSchemaInitQueries(db *fakesqldb.DB, populateTables bool) {
sqlModeResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"sql_mode",
"varchar"),
"ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION",
config.DefaultSQLMode,
)
db.AddQuery("select @@session.sql_mode as sql_mode", sqlModeResult)

Expand Down
22 changes: 12 additions & 10 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
err = c.coerceAndVisitResults(res, fields, func(result *sqltypes.Result) error {
rows = append(rows, result.Rows...)
return nil
})
}, evalengine.ParseSQLMode(vcursor.SQLMode()))
if err != nil {
return nil, err
}
Expand All @@ -116,7 +116,7 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
}, nil
}

func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) error {
func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field, sqlmode evalengine.SQLMode) error {
if len(row) != len(fields) {
return errWrongNumberOfColumnsInSelect
}
Expand All @@ -126,7 +126,7 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field)
continue
}
if fields[i].Type != value.Type() {
newValue, err := evalengine.CoerceTo(value, fields[i].Type)
newValue, err := evalengine.CoerceTo(value, fields[i].Type, sqlmode)
if err != nil {
return err
}
Expand Down Expand Up @@ -228,16 +228,17 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV

// TryStreamExecute performs a streaming exec.
func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error {
sqlmode := evalengine.ParseSQLMode(vcursor.SQLMode())
if vcursor.Session().InTransaction() {
// as we are in a transaction, we need to execute all queries inside a single connection,
// which holds the single transaction we have
return c.sequentialStreamExec(ctx, vcursor, bindVars, callback)
return c.sequentialStreamExec(ctx, vcursor, bindVars, callback, sqlmode)
}
// not in transaction, so execute in parallel.
return c.parallelStreamExec(ctx, vcursor, bindVars, callback)
return c.parallelStreamExec(ctx, vcursor, bindVars, callback, sqlmode)
}

func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error {
func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error, sqlmode evalengine.SQLMode) error {
// Scoped context; any early exit triggers cancel() to clean up ongoing work.
ctx, cancel := context.WithCancel(inCtx)
defer cancel()
Expand Down Expand Up @@ -271,7 +272,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
// Apply type coercion if needed.
if needsCoercion {
for _, row := range res.Rows {
if err := c.coerceValuesTo(row, fields); err != nil {
if err := c.coerceValuesTo(row, fields, sqlmode); err != nil {
return err
}
}
Expand Down Expand Up @@ -340,7 +341,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
return wg.Wait()
}

func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error {
func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error, sqlmode evalengine.SQLMode) error {
// all the below fields ensure that the fields are sent only once.
results := make([][]*sqltypes.Result, len(c.Sources))

Expand Down Expand Up @@ -374,7 +375,7 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor,
return err
}
for _, res := range results {
if err = c.coerceAndVisitResults(res, fields, callback); err != nil {
if err = c.coerceAndVisitResults(res, fields, callback, sqlmode); err != nil {
return err
}
}
Expand All @@ -386,6 +387,7 @@ func (c *Concatenate) coerceAndVisitResults(
res []*sqltypes.Result,
fields []*querypb.Field,
callback func(*sqltypes.Result) error,
sqlmode evalengine.SQLMode,
) error {
for _, r := range res {
if len(r.Rows) > 0 &&
Expand All @@ -402,7 +404,7 @@ func (c *Concatenate) coerceAndVisitResults(
}
if needsCoercion {
for _, row := range r.Rows {
err := c.coerceValuesTo(row, fields)
err := c.coerceValuesTo(row, fields, sqlmode)
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type (
probeTable struct {
seenRows map[evalengine.HashCode][]sqltypes.Row
checkCols []CheckCol
sqlmode evalengine.SQLMode
}
)

Expand Down Expand Up @@ -119,14 +120,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (evalengine.HashCode
return 0, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code")
}
col := inputRow[checkCol.Col]
hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Collation(), col.Type())
hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Collation(), col.Type(), pt.sqlmode)
if err != nil {
if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil {
return 0, err
}
checkCol = checkCol.SwitchToWeightString()
pt.checkCols[i] = checkCol
hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Collation(), col.Type())
hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Collation(), col.Type(), pt.sqlmode)
if err != nil {
return 0, err
}
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/google/go-cmp/cmp"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/config"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/utils"
"vitess.io/vitess/go/vt/key"
Expand Down Expand Up @@ -135,6 +136,10 @@ func (t *noopVCursor) TimeZone() *time.Location {
return nil
}

func (t *noopVCursor) SQLMode() string {
return config.DefaultSQLMode
}

func (t *noopVCursor) ExecutePrimitive(ctx context.Context, primitive Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
return primitive.TryExecute(ctx, t, bindVars, wantfields)
}
Expand Down
3 changes: 2 additions & 1 deletion go/vt/vtgate/engine/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ type (
lhsKey, rhsKey int
cols []int
hasher vthash.Hasher
sqlmode evalengine.SQLMode
}

probeTableEntry struct {
Expand Down Expand Up @@ -283,7 +284,7 @@ func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error {
}

func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) {
err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ)
err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode)
if err != nil {
return vthash.Hash{}, err
}
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ type (

ConnCollation() collations.ID
TimeZone() *time.Location
SQLMode() string

ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error)

Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/api_coerce.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
"vitess.io/vitess/go/vt/vterrors"
)

func CoerceTo(value sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) {
cast, err := valueToEvalCast(value, value.Type(), collations.Unknown)
func CoerceTo(value sqltypes.Value, typ sqltypes.Type, sqlmode SQLMode) (sqltypes.Value, error) {
cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, sqlmode)
if err != nil {
return sqltypes.Value{}, err
}
Expand Down
20 changes: 10 additions & 10 deletions go/vt/vtgate/evalengine/api_hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ type HashCode = uint64

// NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same
// for two values that are considered equal by `NullsafeCompare`.
func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type) (HashCode, error) {
e, err := valueToEvalCast(v, coerceType, collation)
func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, sqlmode SQLMode) (HashCode, error) {
e, err := valueToEvalCast(v, coerceType, collation, sqlmode)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -75,7 +75,7 @@ var ErrHashCoercionIsNotExact = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "
// for two values that are considered equal by `NullsafeCompare`.
// This can be used to avoid having to do comparison checks after a hash,
// since we consider the 128 bits of entropy enough to guarantee uniqueness.
func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type) error {
func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode) error {
switch {
case v.IsNull(), sqltypes.IsNull(coerceTo):
hash.Write16(hashPrefixNil)
Expand All @@ -97,7 +97,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat
case v.IsText(), v.IsBinary():
f, _ = fastparse.ParseFloat64(v.RawStr())
default:
return nullsafeHashcode128Default(hash, v, collation, coerceTo)
return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode)
}
if err != nil {
return err
Expand Down Expand Up @@ -137,7 +137,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat
}
neg = i < 0
default:
return nullsafeHashcode128Default(hash, v, collation, coerceTo)
return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode)
}
if err != nil {
return err
Expand Down Expand Up @@ -180,7 +180,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat
u, err = uint64(fval), nil
}
default:
return nullsafeHashcode128Default(hash, v, collation, coerceTo)
return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode)
}
if err != nil {
return err
Expand Down Expand Up @@ -223,20 +223,20 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat
fval, _ := fastparse.ParseFloat64(v.RawStr())
dec = decimal.NewFromFloat(fval)
default:
return nullsafeHashcode128Default(hash, v, collation, coerceTo)
return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode)
}
hash.Write16(hashPrefixDecimal)
dec.Hash(hash)
default:
return nullsafeHashcode128Default(hash, v, collation, coerceTo)
return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode)
}
return nil
}

func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type) error {
func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode) error {
// Slow path to handle all other types. This uses the generic
// logic for value casting to ensure we match MySQL here.
e, err := valueToEvalCast(v, coerceTo, collation)
e, err := valueToEvalCast(v, coerceTo, collation, sqlmode)
if err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions go/vt/vtgate/evalengine/api_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func TestHashCodes(t *testing.T) {
require.NoError(t, err)
require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal))

h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type())
h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0)
require.NoError(t, err)

h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type())
h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0)
require.ErrorIs(t, err, tc.err)

assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal))
Expand All @@ -82,9 +82,9 @@ func TestHashCodesRandom(t *testing.T) {
typ, err := coerceTo(v1.Type(), v2.Type())
require.NoError(t, err)

hash1, err := NullsafeHashcode(v1, collation, typ)
hash1, err := NullsafeHashcode(v1, collation, typ, 0)
require.NoError(t, err)
hash2, err := NullsafeHashcode(v2, collation, typ)
hash2, err := NullsafeHashcode(v2, collation, typ, 0)
require.NoError(t, err)
if cmp == 0 {
equal++
Expand Down Expand Up @@ -142,11 +142,11 @@ func TestHashCodes128(t *testing.T) {
require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal))

hasher1 := vthash.New()
err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type())
err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0)
require.NoError(t, err)

hasher2 := vthash.New()
err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type())
err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0)
require.ErrorIs(t, err, tc.err)

h1 := hasher1.Sum128()
Expand All @@ -172,10 +172,10 @@ func TestHashCodesRandom128(t *testing.T) {
require.NoError(t, err)

hasher1 := vthash.New()
err = NullsafeHashcode128(&hasher1, v1, collation, typ)
err = NullsafeHashcode128(&hasher1, v1, collation, typ, 0)
require.NoError(t, err)
hasher2 := vthash.New()
err = NullsafeHashcode128(&hasher2, v2, collation, typ)
err = NullsafeHashcode128(&hasher2, v2, collation, typ, 0)
require.NoError(t, err)
if cmp == 0 {
equal++
Expand Down
Loading

0 comments on commit 5e23ddc

Please sign in to comment.