diff --git a/config/mycnf/test-suite.cnf b/config/mycnf/test-suite.cnf index e6d0992f6e6..28f4ac16e0d 100644 --- a/config/mycnf/test-suite.cnf +++ b/config/mycnf/test-suite.cnf @@ -1,5 +1,5 @@ # This sets some unsafe settings specifically for -# the test-suite which is currently MySQL 5.7 based +# the test-suite which is currently MySQL 8.0 based # In future it should be renamed testsuite.cnf innodb_buffer_pool_size = 32M @@ -14,13 +14,6 @@ key_buffer_size = 2M sync_binlog=0 innodb_doublewrite=0 -# These two settings are required for the testsuite to pass, -# but enabling them does not spark joy. They should be removed -# in the future. See: -# https://github.com/vitessio/vitess/issues/5396 - -sql_mode = STRICT_TRANS_TABLES - # set a short heartbeat interval in order to detect failures quickly slave_net_timeout = 4 # Disabling `super-read-only`. `test-suite` is mainly used for `vttestserver`. Since `vttestserver` uses a single MySQL for primary and replicas, diff --git a/go/mysql/config/config.go b/go/mysql/config/config.go new file mode 100644 index 00000000000..8abf9d7dc71 --- /dev/null +++ b/go/mysql/config/config.go @@ -0,0 +1,3 @@ +package config + +const DefaultSQLMode = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION" diff --git a/go/mysql/endtoend/replication_test.go b/go/mysql/endtoend/replication_test.go index 0c1fa006347..441209b35f7 100644 --- a/go/mysql/endtoend/replication_test.go +++ b/go/mysql/endtoend/replication_test.go @@ -908,6 +908,10 @@ func TestRowReplicationTypes(t *testing.T) { t.Fatal(err) } defer dConn.Close() + // We have tests for zero dates, so we need to allow that for this session. + if _, err := dConn.ExecuteFetch("SET @@session.sql_mode=REPLACE(REPLACE(@@session.sql_mode, 'NO_ZERO_DATE', ''), 'NO_ZERO_IN_DATE', '')", 0, false); err != nil { + t.Fatal(err) + } // Set the connection time zone for execution of the // statements to PST. That way we're sure to test the diff --git a/go/test/endtoend/onlineddl/vrepl_suite/onlineddl_vrepl_suite_test.go b/go/test/endtoend/onlineddl/vrepl_suite/onlineddl_vrepl_suite_test.go index c8b87215036..56818069e05 100644 --- a/go/test/endtoend/onlineddl/vrepl_suite/onlineddl_vrepl_suite_test.go +++ b/go/test/endtoend/onlineddl/vrepl_suite/onlineddl_vrepl_suite_test.go @@ -28,6 +28,7 @@ import ( "time" "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/mysql/config" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/schema" "vitess.io/vitess/go/vt/sqlparser" @@ -58,8 +59,7 @@ var ( ) const ( - testDataPath = "testdata" - defaultSQLMode = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION" + testDataPath = "testdata" ) func TestMain(m *testing.M) { @@ -178,7 +178,7 @@ func testSingle(t *testing.T, testName string) { } } - sqlMode := defaultSQLMode + sqlMode := config.DefaultSQLMode if overrideSQLMode, exists := readTestFile(t, testName, "sql_mode"); exists { sqlMode = overrideSQLMode } diff --git a/go/test/endtoend/schemadiff/vrepl/schemadiff_vrepl_suite_test.go b/go/test/endtoend/schemadiff/vrepl/schemadiff_vrepl_suite_test.go index 2dc79840018..055dc7a1df5 100644 --- a/go/test/endtoend/schemadiff/vrepl/schemadiff_vrepl_suite_test.go +++ b/go/test/endtoend/schemadiff/vrepl/schemadiff_vrepl_suite_test.go @@ -53,8 +53,8 @@ var ( ) const ( - testDataPath = "../../onlineddl/vrepl_suite/testdata" - defaultSQLMode = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION" + testDataPath = "../../onlineddl/vrepl_suite/testdata" + sqlModeAllowsZeroDate = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION" ) type testTableSchema struct { @@ -202,7 +202,7 @@ func testSingle(t *testing.T, testName string) { return } - sqlModeQuery := fmt.Sprintf("set @@global.sql_mode='%s'", defaultSQLMode) + sqlModeQuery := fmt.Sprintf("set @@global.sql_mode='%s'", sqlModeAllowsZeroDate) _ = mysqlExec(t, sqlModeQuery, "") _ = mysqlExec(t, "set @@global.event_scheduler=0", "") diff --git a/go/vt/sidecardb/sidecardb.go b/go/vt/sidecardb/sidecardb.go index 92d416f9d37..4b8c37039d7 100644 --- a/go/vt/sidecardb/sidecardb.go +++ b/go/vt/sidecardb/sidecardb.go @@ -29,6 +29,7 @@ import ( "vitess.io/vitess/go/constants/sidecar" "vitess.io/vitess/go/history" + "vitess.io/vitess/go/mysql/config" "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/mysql/fakesqldb" @@ -485,7 +486,7 @@ func AddSchemaInitQueries(db *fakesqldb.DB, populateTables bool) { sqlModeResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields( "sql_mode", "varchar"), - "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION", + config.DefaultSQLMode, ) db.AddQuery("select @@session.sql_mode as sql_mode", sqlModeResult) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 1e8cb655547..27b35c32aa8 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -105,7 +105,7 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars err = c.coerceAndVisitResults(res, fields, func(result *sqltypes.Result) error { rows = append(rows, result.Rows...) return nil - }) + }, evalengine.ParseSQLMode(vcursor.SQLMode())) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars }, nil } -func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) error { +func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field, sqlmode evalengine.SQLMode) error { if len(row) != len(fields) { return errWrongNumberOfColumnsInSelect } @@ -126,7 +126,7 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field) continue } if fields[i].Type != value.Type() { - newValue, err := evalengine.CoerceTo(value, fields[i].Type) + newValue, err := evalengine.CoerceTo(value, fields[i].Type, sqlmode) if err != nil { return err } @@ -228,16 +228,17 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV // TryStreamExecute performs a streaming exec. func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error { + sqlmode := evalengine.ParseSQLMode(vcursor.SQLMode()) if vcursor.Session().InTransaction() { // as we are in a transaction, we need to execute all queries inside a single connection, // which holds the single transaction we have - return c.sequentialStreamExec(ctx, vcursor, bindVars, callback) + return c.sequentialStreamExec(ctx, vcursor, bindVars, callback, sqlmode) } // not in transaction, so execute in parallel. - return c.parallelStreamExec(ctx, vcursor, bindVars, callback) + return c.parallelStreamExec(ctx, vcursor, bindVars, callback, sqlmode) } -func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error { +func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error, sqlmode evalengine.SQLMode) error { // Scoped context; any early exit triggers cancel() to clean up ongoing work. ctx, cancel := context.WithCancel(inCtx) defer cancel() @@ -271,7 +272,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Apply type coercion if needed. if needsCoercion { for _, row := range res.Rows { - if err := c.coerceValuesTo(row, fields); err != nil { + if err := c.coerceValuesTo(row, fields, sqlmode); err != nil { return err } } @@ -340,7 +341,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, return wg.Wait() } -func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { +func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error, sqlmode evalengine.SQLMode) error { // all the below fields ensure that the fields are sent only once. results := make([][]*sqltypes.Result, len(c.Sources)) @@ -374,7 +375,7 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, return err } for _, res := range results { - if err = c.coerceAndVisitResults(res, fields, callback); err != nil { + if err = c.coerceAndVisitResults(res, fields, callback, sqlmode); err != nil { return err } } @@ -386,6 +387,7 @@ func (c *Concatenate) coerceAndVisitResults( res []*sqltypes.Result, fields []*querypb.Field, callback func(*sqltypes.Result) error, + sqlmode evalengine.SQLMode, ) error { for _, r := range res { if len(r.Rows) > 0 && @@ -402,7 +404,7 @@ func (c *Concatenate) coerceAndVisitResults( } if needsCoercion { for _, row := range r.Rows { - err := c.coerceValuesTo(row, fields) + err := c.coerceValuesTo(row, fields, sqlmode) if err != nil { return err } diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 2d263464a2e..c7d6742c136 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -46,6 +46,7 @@ type ( probeTable struct { seenRows map[evalengine.HashCode][]sqltypes.Row checkCols []CheckCol + sqlmode evalengine.SQLMode } ) @@ -119,14 +120,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (evalengine.HashCode return 0, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code") } col := inputRow[checkCol.Col] - hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Collation(), col.Type()) + hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Collation(), col.Type(), pt.sqlmode) if err != nil { if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil { return 0, err } checkCol = checkCol.SwitchToWeightString() pt.checkCols[i] = checkCol - hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Collation(), col.Type()) + hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Collation(), col.Type(), pt.sqlmode) if err != nil { return 0, err } diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 0ee95a72e60..f9cedd74bfc 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -31,6 +31,7 @@ import ( "github.com/google/go-cmp/cmp" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/config" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/key" @@ -135,6 +136,10 @@ func (t *noopVCursor) TimeZone() *time.Location { return nil } +func (t *noopVCursor) SQLMode() string { + return config.DefaultSQLMode +} + func (t *noopVCursor) ExecutePrimitive(ctx context.Context, primitive Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { return primitive.TryExecute(ctx, t, bindVars, wantfields) } diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index 4f205f1bcdc..4e305e8b59d 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -75,6 +75,7 @@ type ( lhsKey, rhsKey int cols []int hasher vthash.Hasher + sqlmode evalengine.SQLMode } probeTableEntry struct { @@ -283,7 +284,7 @@ func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error { } func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) { - err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ) + err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode) if err != nil { return vthash.Hash{}, err } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index dc1259e2267..833f4cc3b45 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -88,6 +88,7 @@ type ( ConnCollation() collations.ID TimeZone() *time.Location + SQLMode() string ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) diff --git a/go/vt/vtgate/evalengine/api_coerce.go b/go/vt/vtgate/evalengine/api_coerce.go index 143d22ab78c..89b36458198 100644 --- a/go/vt/vtgate/evalengine/api_coerce.go +++ b/go/vt/vtgate/evalengine/api_coerce.go @@ -23,8 +23,8 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -func CoerceTo(value sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) { - cast, err := valueToEvalCast(value, value.Type(), collations.Unknown) +func CoerceTo(value sqltypes.Value, typ sqltypes.Type, sqlmode SQLMode) (sqltypes.Value, error) { + cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, sqlmode) if err != nil { return sqltypes.Value{}, err } diff --git a/go/vt/vtgate/evalengine/api_hash.go b/go/vt/vtgate/evalengine/api_hash.go index 209f766840d..3bce100839c 100644 --- a/go/vt/vtgate/evalengine/api_hash.go +++ b/go/vt/vtgate/evalengine/api_hash.go @@ -34,8 +34,8 @@ type HashCode = uint64 // NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same // for two values that are considered equal by `NullsafeCompare`. -func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type) (HashCode, error) { - e, err := valueToEvalCast(v, coerceType, collation) +func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type, sqlmode SQLMode) (HashCode, error) { + e, err := valueToEvalCast(v, coerceType, collation, sqlmode) if err != nil { return 0, err } @@ -75,7 +75,7 @@ var ErrHashCoercionIsNotExact = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, " // for two values that are considered equal by `NullsafeCompare`. // This can be used to avoid having to do comparison checks after a hash, // since we consider the 128 bits of entropy enough to guarantee uniqueness. -func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type) error { +func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode) error { switch { case v.IsNull(), sqltypes.IsNull(coerceTo): hash.Write16(hashPrefixNil) @@ -97,7 +97,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat case v.IsText(), v.IsBinary(): f, _ = fastparse.ParseFloat64(v.RawStr()) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) } if err != nil { return err @@ -137,7 +137,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat } neg = i < 0 default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) } if err != nil { return err @@ -180,7 +180,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat u, err = uint64(fval), nil } default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) } if err != nil { return err @@ -223,20 +223,20 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) } hash.Write16(hashPrefixDecimal) dec.Hash(hash) default: - return nullsafeHashcode128Default(hash, v, collation, coerceTo) + return nullsafeHashcode128Default(hash, v, collation, coerceTo, sqlmode) } return nil } -func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type) error { +func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type, sqlmode SQLMode) error { // Slow path to handle all other types. This uses the generic // logic for value casting to ensure we match MySQL here. - e, err := valueToEvalCast(v, coerceTo, collation) + e, err := valueToEvalCast(v, coerceTo, collation, sqlmode) if err != nil { return err } diff --git a/go/vt/vtgate/evalengine/api_hash_test.go b/go/vt/vtgate/evalengine/api_hash_test.go index 0add16de89d..c1e5d880bdd 100644 --- a/go/vt/vtgate/evalengine/api_hash_test.go +++ b/go/vt/vtgate/evalengine/api_hash_test.go @@ -56,10 +56,10 @@ func TestHashCodes(t *testing.T) { require.NoError(t, err) require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) - h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type()) + h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0) require.NoError(t, err) - h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type()) + h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0) require.ErrorIs(t, err, tc.err) assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal)) @@ -82,9 +82,9 @@ func TestHashCodesRandom(t *testing.T) { typ, err := coerceTo(v1.Type(), v2.Type()) require.NoError(t, err) - hash1, err := NullsafeHashcode(v1, collation, typ) + hash1, err := NullsafeHashcode(v1, collation, typ, 0) require.NoError(t, err) - hash2, err := NullsafeHashcode(v2, collation, typ) + hash2, err := NullsafeHashcode(v2, collation, typ, 0) require.NoError(t, err) if cmp == 0 { equal++ @@ -142,11 +142,11 @@ func TestHashCodes128(t *testing.T) { require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type()) + err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type(), 0) require.NoError(t, err) hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type()) + err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type(), 0) require.ErrorIs(t, err, tc.err) h1 := hasher1.Sum128() @@ -172,10 +172,10 @@ func TestHashCodesRandom128(t *testing.T) { require.NoError(t, err) hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, v1, collation, typ) + err = NullsafeHashcode128(&hasher1, v1, collation, typ, 0) require.NoError(t, err) hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, v2, collation, typ) + err = NullsafeHashcode128(&hasher2, v2, collation, typ, 0) require.NoError(t, err) if cmp == 0 { equal++ diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index e7d0e962fab..d757b3c3192 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -33,6 +33,7 @@ type compiler struct { collation collations.ID dynamicTypes []ctype asm assembler + sqlmode SQLMode } type CompilerLog interface { @@ -257,7 +258,7 @@ func (c *compiler) compileToDate(doct ctype, offset int) ctype { case sqltypes.Date: return doct default: - c.asm.Convert_xD(offset) + c.asm.Convert_xD(offset, c.sqlmode.AllowZeroDate()) } return ctype{Type: sqltypes.Date, Col: collationBinary, Flag: flagNullable} } @@ -268,7 +269,7 @@ func (c *compiler) compileToDateTime(doct ctype, offset, prec int) ctype { c.asm.Convert_tp(offset, prec) return doct default: - c.asm.Convert_xDT(offset, prec) + c.asm.Convert_xDT(offset, prec, c.sqlmode.AllowZeroDate()) } return ctype{Type: sqltypes.Datetime, Size: int32(prec), Col: collationBinary, Flag: flagNullable} } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 19d4c01a399..cbf9df9c57e 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -516,7 +516,7 @@ func (asm *assembler) Cmp_ne_n() { }, "CMPFLAG NE [NULL]") } -func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation) { +func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation, allowZeroDate bool) { elseOffset := 0 if hasElse { elseOffset = 1 @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll end := env.vm.sp - elseOffset for sp := env.vm.sp - stackDepth; sp < end; sp += 2 { if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now, allowZeroDate) goto done } } if elseOffset != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now, allowZeroDate) } else { env.vm.stack[env.vm.sp-stackDepth] = nil } @@ -1126,12 +1126,12 @@ func (asm *assembler) Convert_xu(offset int) { }, "CONV (SP-%d), UINT64", offset) } -func (asm *assembler) Convert_xD(offset int) { +func (asm *assembler) Convert_xD(offset int, allowZero bool) { asm.emit(func(env *ExpressionEnv) int { // Need to explicitly check here or we otherwise // store a nil wrapper in an interface vs. a direct // nil. - d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now) + d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now, allowZero) if d == nil { env.vm.stack[env.vm.sp-offset] = nil } else { @@ -1141,27 +1141,12 @@ func (asm *assembler) Convert_xD(offset int) { }, "CONV (SP-%d), DATE", offset) } -func (asm *assembler) Convert_xD_nz(offset int) { +func (asm *assembler) Convert_xDT(offset, prec int, allowZero bool) { asm.emit(func(env *ExpressionEnv) int { // Need to explicitly check here or we otherwise // store a nil wrapper in an interface vs. a direct // nil. - d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now) - if d == nil || d.isZero() { - env.vm.stack[env.vm.sp-offset] = nil - } else { - env.vm.stack[env.vm.sp-offset] = d - } - return 1 - }, "CONV (SP-%d), DATE(NOZERO)", offset) -} - -func (asm *assembler) Convert_xDT(offset, prec int) { - asm.emit(func(env *ExpressionEnv) int { - // Need to explicitly check here or we otherwise - // store a nil wrapper in an interface vs. a direct - // nil. - dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now) + dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now, allowZero) if dt == nil { env.vm.stack[env.vm.sp-offset] = nil } else { @@ -1171,21 +1156,6 @@ func (asm *assembler) Convert_xDT(offset, prec int) { }, "CONV (SP-%d), DATETIME", offset) } -func (asm *assembler) Convert_xDT_nz(offset, prec int) { - asm.emit(func(env *ExpressionEnv) int { - // Need to explicitly check here or we otherwise - // store a nil wrapper in an interface vs. a direct - // nil. - dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now) - if dt == nil || dt.isZero() { - env.vm.stack[env.vm.sp-offset] = nil - } else { - env.vm.stack[env.vm.sp-offset] = dt - } - return 1 - }, "CONV (SP-%d), DATETIME(NOZERO)", offset) -} - func (asm *assembler) Convert_xT(offset, prec int) { asm.emit(func(env *ExpressionEnv) int { t := evalToTime(env.vm.stack[env.vm.sp-offset], prec) @@ -4189,7 +4159,7 @@ func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col col goto baddate } - tmp = evalToTemporal(env.vm.stack[env.vm.sp-2]) + tmp = evalToTemporal(env.vm.stack[env.vm.sp-2], true) if tmp == nil { goto baddate } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 77174a01d71..e7b51b41748 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -564,6 +564,14 @@ func TestCompilerSingle(t *testing.T) { expression: `case when null is null then 23 else null end`, result: `INT64(23)`, }, + { + expression: `CAST(0 AS DATE)`, + result: `NULL`, + }, + { + expression: `DAYOFMONTH(0)`, + result: `INT64(0)`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 82f1ec688c7..33312cddc5f 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -176,7 +176,7 @@ func evalIsTruthy(e eval) boolean { } } -func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (eval, error) { +func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, allowZero bool) (eval, error) { if e == nil { return nil, nil } @@ -208,9 +208,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (ev case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64: return evalToInt64(e).toUint64(), nil case sqltypes.Date: - return evalToDate(e, now), nil + return evalToDate(e, now, allowZero), nil case sqltypes.Datetime, sqltypes.Timestamp: - return evalToDateTime(e, -1, now), nil + return evalToDateTime(e, -1, now, allowZero), nil case sqltypes.Time: return evalToTime(e, -1), nil default: @@ -218,7 +218,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (ev } } -func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID) (eval, error) { +func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID, sqlmode SQLMode) (eval, error) { switch { case typ == sqltypes.Null: return nil, nil @@ -338,7 +338,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, err } // Separate return here to avoid nil wrapped in interface type - d := evalToDate(e, time.Now()) + d := evalToDate(e, time.Now(), sqlmode.AllowZeroDate()) if d == nil { return nil, nil } @@ -349,7 +349,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, err } // Separate return here to avoid nil wrapped in interface type - dt := evalToDateTime(e, -1, time.Now()) + dt := evalToDateTime(e, -1, time.Now(), sqlmode.AllowZeroDate()) if dt == nil { return nil, nil } diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index d44839a6853..7952fc5aa5d 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -164,11 +164,17 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations. return tmp } -func newEvalDateTime(dt datetime.DateTime, l int) *evalTemporal { +func newEvalDateTime(dt datetime.DateTime, l int, allowZero bool) *evalTemporal { + if !allowZero && dt.IsZero() { + return nil + } return &evalTemporal{t: sqltypes.Datetime, dt: dt.Round(l), prec: uint8(l)} } -func newEvalDate(d datetime.Date) *evalTemporal { +func newEvalDate(d datetime.Date, allowZero bool) *evalTemporal { + if !allowZero && d.IsZero() { + return nil + } return &evalTemporal{t: sqltypes.Date, dt: datetime.DateTime{Date: d}} } @@ -185,7 +191,7 @@ func parseDate(s []byte) (*evalTemporal, error) { if !ok { return nil, errIncorrectTemporal("DATE", s) } - return newEvalDate(t), nil + return newEvalDate(t, true), nil } func parseDateTime(s []byte) (*evalTemporal, error) { @@ -193,7 +199,7 @@ func parseDateTime(s []byte) (*evalTemporal, error) { if !ok { return nil, errIncorrectTemporal("DATETIME", s) } - return newEvalDateTime(t, l), nil + return newEvalDateTime(t, l, true), nil } func parseTime(s []byte) (*evalTemporal, error) { @@ -211,56 +217,56 @@ func precision(req, got int) int { return req } -func evalToTemporal(e eval) *evalTemporal { +func evalToTemporal(e eval, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: return e case *evalBytes: if t, l, ok := datetime.ParseDateTime(e.string(), -1); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDate(e.string()); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if t, l, ok := datetime.ParseTime(e.string(), -1); ok { return newEvalTime(t, l) } case *evalInt64: if t, ok := datetime.ParseDateTimeInt64(e.i); ok { - return newEvalDateTime(t, 0) + return newEvalDateTime(t, 0, allowZero) } if d, ok := datetime.ParseDateInt64(e.i); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if t, ok := datetime.ParseTimeInt64(e.i); ok { return newEvalTime(t, 0) } case *evalUint64: if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok { - return newEvalDateTime(t, 0) + return newEvalDateTime(t, 0, allowZero) } if d, ok := datetime.ParseDateInt64(int64(e.u)); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if t, ok := datetime.ParseTimeInt64(int64(e.u)); ok { return newEvalTime(t, 0) } case *evalFloat: if t, l, ok := datetime.ParseDateTimeFloat(e.f, -1); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDateFloat(e.f); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if t, l, ok := datetime.ParseTimeFloat(e.f, -1); ok { return newEvalTime(t, l) } case *evalDecimal: if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, -1); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDateDecimal(e.dec); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } if d, l, ok := datetime.ParseTimeDecimal(e.dec, e.length, -1); ok { return newEvalTime(d, l) @@ -271,9 +277,9 @@ func evalToTemporal(e eval) *evalTemporal { return newEvalTime(dt.Time, datetime.DefaultPrecision) } if dt.Time.IsZero() { - return newEvalDate(dt.Date) + return newEvalDate(dt.Date, allowZero) } - return newEvalDateTime(dt, datetime.DefaultPrecision) + return newEvalDateTime(dt, datetime.DefaultPrecision, allowZero) } } return nil @@ -326,74 +332,74 @@ func evalToTime(e eval, l int) *evalTemporal { return nil } -func evalToDateTime(e eval, l int, now time.Time) *evalTemporal { +func evalToDateTime(e eval, l int, now time.Time, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: return e.toDateTime(precision(l, int(e.prec)), now) case *evalBytes: if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, _ := datetime.ParseDate(e.string()); !d.IsZero() { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalInt64: if t, ok := datetime.ParseDateTimeInt64(e.i); ok { - return newEvalDateTime(t, precision(l, 0)) + return newEvalDateTime(t, precision(l, 0), allowZero) } if d, ok := datetime.ParseDateInt64(e.i); ok { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalUint64: if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok { - return newEvalDateTime(t, precision(l, 0)) + return newEvalDateTime(t, precision(l, 0), allowZero) } if d, ok := datetime.ParseDateInt64(int64(e.u)); ok { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalFloat: if t, l, ok := datetime.ParseDateTimeFloat(e.f, l); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDateFloat(e.f); ok { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalDecimal: if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, l); ok { - return newEvalDateTime(t, l) + return newEvalDateTime(t, l, allowZero) } if d, ok := datetime.ParseDateDecimal(e.dec); ok { - return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0)) + return newEvalDateTime(datetime.DateTime{Date: d}, precision(l, 0), allowZero) } case *evalJSON: if dt, ok := e.DateTime(); ok { - return newEvalDateTime(dt, precision(l, datetime.DefaultPrecision)) + return newEvalDateTime(dt, precision(l, datetime.DefaultPrecision), allowZero) } } return nil } -func evalToDate(e eval, now time.Time) *evalTemporal { +func evalToDate(e eval, now time.Time, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: return e.toDate(now) case *evalBytes: if t, _ := datetime.ParseDate(e.string()); !t.IsZero() { - return newEvalDate(t) + return newEvalDate(t, allowZero) } if dt, _, _ := datetime.ParseDateTime(e.string(), -1); !dt.IsZero() { - return newEvalDate(dt.Date) + return newEvalDate(dt.Date, allowZero) } case evalNumeric: if t, ok := datetime.ParseDateInt64(e.toInt64().i); ok { - return newEvalDate(t) + return newEvalDate(t, allowZero) } if dt, ok := datetime.ParseDateTimeInt64(e.toInt64().i); ok { - return newEvalDate(dt.Date) + return newEvalDate(dt.Date, allowZero) } case *evalJSON: if d, ok := e.Date(); ok { - return newEvalDate(d) + return newEvalDate(d, allowZero) } } return nil diff --git a/go/vt/vtgate/evalengine/expr_convert.go b/go/vt/vtgate/evalengine/expr_convert.go index 900d4e37f8f..5b2d82b707f 100644 --- a/go/vt/vtgate/evalengine/expr_convert.go +++ b/go/vt/vtgate/evalengine/expr_convert.go @@ -124,12 +124,12 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) { case p > 6: return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p) } - if dt := evalToDateTime(e, c.Length, env.now); dt != nil { + if dt := evalToDateTime(e, c.Length, env.now, env.sqlmode.AllowZeroDate()); dt != nil { return dt, nil } return nil, nil case "DATE": - if d := evalToDate(e, env.now); d != nil { + if d := evalToDate(e, env.now, env.sqlmode.AllowZeroDate()); d != nil { return d, nil } return nil, nil diff --git a/go/vt/vtgate/evalengine/expr_env.go b/go/vt/vtgate/evalengine/expr_env.go index a6ca1411e74..1c92b0a45ee 100644 --- a/go/vt/vtgate/evalengine/expr_env.go +++ b/go/vt/vtgate/evalengine/expr_env.go @@ -30,6 +30,7 @@ import ( type VCursor interface { TimeZone() *time.Location GetKeyspace() string + SQLMode() string } type ( @@ -43,9 +44,10 @@ type ( Fields []*querypb.Field // internal state - now time.Time - vc VCursor - user *querypb.VTGateCallerID + now time.Time + vc VCursor + user *querypb.VTGateCallerID + sqlmode SQLMode } ) @@ -121,5 +123,32 @@ func NewExpressionEnv(ctx context.Context, bindVars map[string]*querypb.BindVari env := &ExpressionEnv{BindVars: bindVars, vc: vc} env.user = callerid.ImmediateCallerIDFromContext(ctx) env.SetTime(time.Now()) + if vc != nil { + env.sqlmode = ParseSQLMode(vc.SQLMode()) + } return env } + +const ( + sqlModeParsed = 1 << iota + sqlModeNoZeroDate +) + +type SQLMode uint32 + +func (mode SQLMode) AllowZeroDate() bool { + if mode == 0 { + // default: do not allow zero-date if the sqlmode is not set + return false + } + return (mode & sqlModeNoZeroDate) == 0 +} + +func ParseSQLMode(sqlmode string) SQLMode { + var mode SQLMode + if strings.Contains(sqlmode, "NO_ZERO_DATE") { + mode |= sqlModeNoZeroDate + } + mode |= sqlModeParsed + return mode +} diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index f22ca091acb..7fe836d7164 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -633,7 +633,7 @@ func (c *CaseExpr) eval(env *ExpressionEnv) (eval, error) { if !matched { return nil, nil } - return evalCoerce(result, ta.result(), ca.result().Collation, env.now) + return evalCoerce(result, ta.result(), ca.result().Collation, env.now, env.sqlmode.AllowZeroDate()) } func (c *CaseExpr) constant() bool { @@ -716,7 +716,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { f |= flagNullable } ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()} - c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col) + c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate()) return ct, nil } diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index b0e7366ce8f..49a328a852f 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -266,7 +266,7 @@ func (b *builtinDateFormat) eval(env *ExpressionEnv) (eval, error) { case *evalTemporal: t = e.toDateTime(datetime.DefaultPrecision, env.now) default: - t = evalToDateTime(date, datetime.DefaultPrecision, env.now) + t = evalToDateTime(date, datetime.DefaultPrecision, env.now, env.sqlmode.AllowZeroDate()) if t == nil || t.isZero() { return nil, nil } @@ -291,7 +291,7 @@ func (call *builtinDateFormat) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Datetime, sqltypes.Date: default: - c.asm.Convert_xDT_nz(1, datetime.DefaultPrecision) + c.asm.Convert_xDT(1, datetime.DefaultPrecision, false) } format, err := call.Arguments[1].compile(c) @@ -359,7 +359,7 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - dt := evalToDateTime(n, -1, env.now) + dt := evalToDateTime(n, -1, env.now, env.sqlmode.AllowZeroDate()) if dt == nil || dt.isZero() { return nil, nil } @@ -368,7 +368,7 @@ func (call *builtinConvertTz) eval(env *ExpressionEnv) (eval, error) { if !ok { return nil, nil } - return newEvalDateTime(out, int(dt.prec)), nil + return newEvalDateTime(out, int(dt.prec), env.sqlmode.AllowZeroDate()), nil } func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { @@ -402,7 +402,7 @@ func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { switch n.Type { case sqltypes.Datetime, sqltypes.Date: default: - c.asm.Convert_xDT_nz(3, -1) + c.asm.Convert_xDT(3, -1, false) } c.asm.Fn_CONVERT_TZ() @@ -418,7 +418,7 @@ func (b *builtinDate) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) if d == nil { return nil, nil } @@ -436,7 +436,7 @@ func (call *builtinDate) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, c.sqlmode.AllowZeroDate()) } c.asm.jumpDestination(skip) @@ -451,7 +451,7 @@ func (b *builtinDayOfMonth) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, true) if d == nil { return nil, nil } @@ -469,7 +469,7 @@ func (call *builtinDayOfMonth) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, true) } c.asm.Fn_DAYOFMONTH() c.asm.jumpDestination(skip) @@ -484,7 +484,7 @@ func (b *builtinDayOfWeek) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) if d == nil || d.isZero() { return nil, nil } @@ -502,7 +502,7 @@ func (call *builtinDayOfWeek) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } c.asm.Fn_DAYOFWEEK() c.asm.jumpDestination(skip) @@ -517,7 +517,7 @@ func (b *builtinDayOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) if d == nil || d.isZero() { return nil, nil } @@ -535,7 +535,7 @@ func (call *builtinDayOfYear) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } c.asm.Fn_DAYOFYEAR() c.asm.jumpDestination(skip) @@ -610,7 +610,7 @@ func (b *builtinFromUnixtime) eval(env *ExpressionEnv) (eval, error) { t = t.In(tz) } - dt := newEvalDateTime(datetime.NewDateTimeFromStd(t), prec) + dt := newEvalDateTime(datetime.NewDateTimeFromStd(t), prec, env.sqlmode.AllowZeroDate()) if len(b.Arguments) == 1 { return dt, nil @@ -761,7 +761,7 @@ func (b *builtinMakedate) eval(env *ExpressionEnv) (eval, error) { if t.IsZero() { return nil, nil } - return newEvalDate(datetime.NewDateTimeFromStd(t).Date), nil + return newEvalDate(datetime.NewDateTimeFromStd(t).Date, env.sqlmode.AllowZeroDate()), nil } func (call *builtinMakedate) compile(c *compiler) (ctype, error) { @@ -1102,7 +1102,7 @@ func (b *builtinMonth) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, true) if d == nil { return nil, nil } @@ -1120,7 +1120,7 @@ func (call *builtinMonth) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, true) } c.asm.Fn_MONTH() c.asm.jumpDestination(skip) @@ -1135,7 +1135,7 @@ func (b *builtinMonthName) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) if d == nil { return nil, nil } @@ -1158,7 +1158,7 @@ func (call *builtinMonthName) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, c.sqlmode.AllowZeroDate()) } col := typedCoercionCollation(sqltypes.VarChar, call.collate) c.asm.Fn_MONTHNAME(col) @@ -1174,7 +1174,7 @@ func (b *builtinQuarter) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, true) if d == nil { return nil, nil } @@ -1192,7 +1192,7 @@ func (call *builtinQuarter) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, true) } c.asm.Fn_QUARTER() c.asm.jumpDestination(skip) @@ -1270,7 +1270,7 @@ func dateTimeUnixTimestamp(env *ExpressionEnv, date eval) evalNumeric { case *evalTemporal: dt = e.toDateTime(int(e.prec), env.now) default: - dt = evalToDateTime(date, -1, env.now) + dt = evalToDateTime(date, -1, env.now, env.sqlmode.AllowZeroDate()) if dt == nil || dt.isZero() { var prec int32 switch d := date.(type) { @@ -1351,7 +1351,7 @@ func (call *builtinUnixTimestamp) compile(c *compiler) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } if lit, ok := call.Arguments[0].(*Literal); ok { - if dt := evalToDateTime(lit.inner, -1, time.Now()); dt != nil { + if dt := evalToDateTime(lit.inner, -1, time.Now(), c.sqlmode.AllowZeroDate()); dt != nil { if dt.prec == 0 { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } @@ -1373,7 +1373,7 @@ func (b *builtinWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) if d == nil || d.isZero() { return nil, nil } @@ -1406,7 +1406,7 @@ func (call *builtinWeek) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } if len(call.Arguments) == 1 { @@ -1433,7 +1433,7 @@ func (b *builtinWeekDay) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) if d == nil || d.isZero() { return nil, nil } @@ -1451,7 +1451,7 @@ func (call *builtinWeekDay) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } c.asm.Fn_WEEKDAY() @@ -1467,7 +1467,7 @@ func (b *builtinWeekOfYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) if d == nil || d.isZero() { return nil, nil } @@ -1487,7 +1487,7 @@ func (call *builtinWeekOfYear) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } c.asm.Fn_WEEKOFYEAR() @@ -1503,7 +1503,7 @@ func (b *builtinYear) eval(env *ExpressionEnv) (eval, error) { if date == nil { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, true) if d == nil { return nil, nil } @@ -1522,7 +1522,7 @@ func (call *builtinYear) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD(1) + c.asm.Convert_xD(1, true) } c.asm.Fn_YEAR() @@ -1539,7 +1539,7 @@ func (b *builtinYearWeek) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - d := evalToDate(date, env.now) + d := evalToDate(date, env.now, env.sqlmode.AllowZeroDate()) if d == nil || d.isZero() { return nil, nil } @@ -1572,7 +1572,7 @@ func (call *builtinYearWeek) compile(c *compiler) (ctype, error) { switch arg.Type { case sqltypes.Date, sqltypes.Datetime: default: - c.asm.Convert_xD_nz(1) + c.asm.Convert_xD(1, false) } if len(call.Arguments) == 1 { @@ -1624,7 +1624,7 @@ func (call *builtinDateMath) eval(env *ExpressionEnv) (eval, error) { return tmp.addInterval(interval, collations.Unknown, env.now), nil } - if tmp := evalToTemporal(date); tmp != nil { + if tmp := evalToTemporal(date, env.sqlmode.AllowZeroDate()); tmp != nil { return tmp.addInterval(interval, call.collate, env.now), nil } diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index 649dc7b5583..44c02b2a5a5 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -31,6 +31,7 @@ import ( "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/config" "vitess.io/vitess/go/mysql/format" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" @@ -215,6 +216,10 @@ func (vc *vcursor) TimeZone() *time.Location { return time.Local } +func (vc *vcursor) SQLMode() string { + return config.DefaultSQLMode +} + func initTimezoneData(t *testing.T, conn *mysql.Conn) { // We load the timezone information into MySQL. The evalengine assumes // our backend MySQL is configured with the timezone information as well diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 7993854db36..ea38f116de2 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -569,6 +569,7 @@ type Config struct { Collation collations.ID NoConstantFolding bool NoCompilation bool + SQLMode SQLMode } func Translate(e sqlparser.Expr, cfg *Config) (Expr, error) { @@ -603,7 +604,7 @@ func Translate(e sqlparser.Expr, cfg *Config) (Expr, error) { } if len(ast.untyped) == 0 && !cfg.NoCompilation { - comp := compiler{collation: cfg.Collation} + comp := compiler{collation: cfg.Collation, sqlmode: cfg.SQLMode} return comp.compile(expr) } @@ -626,9 +627,9 @@ type typedExpr struct { err error } -func (typed *typedExpr) compile(expr IR, collation collations.ID) (*CompiledExpr, error) { +func (typed *typedExpr) compile(expr IR, collation collations.ID, sqlmode SQLMode) (*CompiledExpr, error) { typed.once.Do(func() { - comp := compiler{collation: collation, dynamicTypes: typed.types} + comp := compiler{collation: collation, dynamicTypes: typed.types, sqlmode: sqlmode} typed.compiled, typed.err = comp.compile(expr) }) return typed.compiled, typed.err @@ -695,7 +696,7 @@ func (u *UntypedExpr) Compile(env *ExpressionEnv) (*CompiledExpr, error) { if err != nil { return nil, err } - return typed.compile(u.ir, u.collation) + return typed.compile(u.ir, u.collation, env.sqlmode) } func (u *UntypedExpr) typeof(env *ExpressionEnv) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/weights.go b/go/vt/vtgate/evalengine/weights.go index fa7fa7e11a6..2a9d6c9f93e 100644 --- a/go/vt/vtgate/evalengine/weights.go +++ b/go/vt/vtgate/evalengine/weights.go @@ -41,11 +41,11 @@ import ( // externally communicates with the `WEIGHT_STRING` function, so that we // can also use this to order / sort other types like Float and Decimal // as well. -func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int) ([]byte, bool, error) { +func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, sqlmode SQLMode) ([]byte, bool, error) { // We optimize here for the case where we already have the desired type. // Otherwise, we fall back to the general evalengine conversion logic. if v.Type() != coerceTo { - return fallbackWeightString(dst, v, coerceTo, col, length, precision) + return fallbackWeightString(dst, v, coerceTo, col, length, precision, sqlmode) } switch { @@ -117,12 +117,12 @@ func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col coll } return j.WeightString(dst), false, nil default: - return fallbackWeightString(dst, v, coerceTo, col, length, precision) + return fallbackWeightString(dst, v, coerceTo, col, length, precision, sqlmode) } } -func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int) ([]byte, bool, error) { - e, err := valueToEvalCast(v, coerceTo, col) +func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int, sqlmode SQLMode) ([]byte, bool, error) { + e, err := valueToEvalCast(v, coerceTo, col, sqlmode) if err != nil { return dst, false, err } diff --git a/go/vt/vtgate/evalengine/weights_test.go b/go/vt/vtgate/evalengine/weights_test.go index 0dee4c72d03..7e43315f7df 100644 --- a/go/vt/vtgate/evalengine/weights_test.go +++ b/go/vt/vtgate/evalengine/weights_test.go @@ -136,7 +136,7 @@ func TestWeightStrings(t *testing.T) { items := make([]item, 0, Length) for i := 0; i < Length; i++ { v := tc.gen() - w, _, err := WeightString(nil, v, typ, tc.col, tc.len, tc.prec) + w, _, err := WeightString(nil, v, typ, tc.col, tc.len, tc.prec, 0) require.NoError(t, err) items = append(items, item{value: v, weight: string(w)}) @@ -156,9 +156,9 @@ func TestWeightStrings(t *testing.T) { a := items[i] b := items[i+1] - v1, err := valueToEvalCast(a.value, typ, tc.col) + v1, err := valueToEvalCast(a.value, typ, tc.col, 0) require.NoError(t, err) - v2, err := valueToEvalCast(b.value, typ, tc.col) + v2, err := valueToEvalCast(b.value, typ, tc.col, 0) require.NoError(t, err) cmp, err := evalCompareNullSafe(v1, v2) diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index 9ec4bb0dc03..db678d56354 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -27,10 +27,9 @@ import ( "github.com/google/uuid" - "vitess.io/vitess/go/mysql/sqlerror" - "vitess.io/vitess/go/vt/sysvars" - "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/config" + "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/callerid" "vitess.io/vitess/go/vt/discovery" @@ -44,6 +43,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/srvtopo" + "vitess.io/vitess/go/vt/sysvars" "vitess.io/vitess/go/vt/topo" topoprotopb "vitess.io/vitess/go/vt/topo/topoproto" "vitess.io/vitess/go/vt/topotools" @@ -211,6 +211,12 @@ func (vc *vcursorImpl) TimeZone() *time.Location { return vc.safeSession.TimeZone() } +func (vc *vcursorImpl) SQLMode() string { + // TODO: Implement return the current sql_mode. + // This is currently hardcoded to the default in MySQL 8.0. + return config.DefaultSQLMode +} + // MaxMemoryRows returns the maxMemoryRows flag value. func (vc *vcursorImpl) MaxMemoryRows() int { return maxMemoryRows diff --git a/go/vt/vttablet/tabletmanager/vreplication/framework_test.go b/go/vt/vttablet/tabletmanager/vreplication/framework_test.go index 576ce4c22a8..ee1a1dbc06c 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/framework_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/framework_test.go @@ -29,6 +29,7 @@ import ( "time" "vitess.io/vitess/go/mysql/replication" + "vitess.io/vitess/go/vt/dbconnpool" "vitess.io/vitess/go/vt/vttablet" "vitess.io/vitess/go/test/utils" @@ -225,6 +226,15 @@ func execStatements(t *testing.T, queries []string) { } } +func execConnStatements(t *testing.T, conn *dbconnpool.DBConnection, queries []string) { + t.Helper() + for _, query := range queries { + if _, err := conn.ExecuteFetch(query, 10000, false); err != nil { + t.Fatalf("ExecuteFetch(%v) failed: %v", query, err) + } + } +} + //-------------------------------------- // Topos and tablets diff --git a/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go b/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go index b960635ff11..0e35036321f 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go @@ -1676,22 +1676,26 @@ func TestCopyTablesWithInvalidDates(t *testing.T) { func testCopyTablesWithInvalidDates(t *testing.T) { defer deleteTablet(addTablet(100)) - execStatements(t, []string{ - "create table src1(id int, dt date, primary key(id))", - fmt.Sprintf("create table %s.dst1(id int, dt date, primary key(id))", vrepldb), - "insert into src1 values(1, '2020-01-12'), (2, '0000-00-00');", - }) + conn, err := env.Mysqld.GetDbaConnection(context.Background()) + require.NoError(t, err) // default mysql flavor allows invalid dates: so disallow explicitly for this test - if err := env.Mysqld.ExecuteSuperQuery(context.Background(), "SET @@global.sql_mode=REPLACE(REPLACE(@@session.sql_mode, 'NO_ZERO_DATE', ''), 'NO_ZERO_IN_DATE', '')"); err != nil { + if _, err := conn.ExecuteFetch("SET @@session.sql_mode=REPLACE(REPLACE(@@session.sql_mode, 'NO_ZERO_DATE', ''), 'NO_ZERO_IN_DATE', '')", 0, false); err != nil { fmt.Fprintf(os.Stderr, "%v", err) } defer func() { - if err := env.Mysqld.ExecuteSuperQuery(context.Background(), "SET @@global.sql_mode=REPLACE(@@global.sql_mode, ',NO_ZERO_DATE,NO_ZERO_IN_DATE','')"); err != nil { + if _, err := conn.ExecuteFetch("SET @@session.sql_mode=REPLACE(@@session.sql_mode, ',NO_ZERO_DATE,NO_ZERO_IN_DATE','')", 0, false); err != nil { fmt.Fprintf(os.Stderr, "%v", err) } }() - defer execStatements(t, []string{ + + execConnStatements(t, conn, []string{ + "create table src1(id int, dt date, primary key(id))", + fmt.Sprintf("create table %s.dst1(id int, dt date, primary key(id))", vrepldb), + "insert into src1 values(1, '2020-01-12'), (2, '0000-00-00');", + }) + + defer execConnStatements(t, conn, []string{ "drop table src1", fmt.Sprintf("drop table %s.dst1", vrepldb), })