diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 09f0c42942e..d4ad5ad9dac 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -376,11 +376,11 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R } key := strings.ToLower(query) db.mu.Lock() - defer db.mu.Unlock() db.queryCalled[key]++ db.querylog = append(db.querylog, key) // Check if we should close the connection and provoke errno 2013. if db.shouldClose.Load() { + defer db.mu.Unlock() c.Close() // log error @@ -394,6 +394,8 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R // The driver may send this at connection time, and we don't want it to // interfere. if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") { + defer db.mu.Unlock() + // log error if err := callback(&sqltypes.Result{}); err != nil { log.Errorf("callback failed : %v", err) @@ -403,12 +405,14 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R // check if we should reject it. if err, ok := db.rejectedData[key]; ok { + db.mu.Unlock() return err } // Check explicit queries from AddQuery(). result, ok := db.data[key] if ok { + db.mu.Unlock() if f := result.BeforeFunc; f != nil { f() } @@ -419,12 +423,9 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R for _, pat := range db.patternData { if pat.expr.MatchString(query) { userCallback, ok := db.queryPatternUserCallback[pat.expr] + db.mu.Unlock() if ok { - // Since the user call back can be indefinitely stuck, we shouldn't hold the lock indefinitely. - // This is only test code, so no actual cause for concern. - db.mu.Unlock() userCallback(query) - db.mu.Lock() } if pat.err != "" { return fmt.Errorf(pat.err) @@ -433,6 +434,8 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R } } + defer db.mu.Unlock() + if db.neverFail.Load() { return callback(&sqltypes.Result{}) } diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index bedbdc66c0a..a41fcc7b7ec 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -40,6 +40,8 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) +const defaultKillTimeout = 5 * time.Second + // Conn is a db connection for tabletserver. // It performs automatic reconnects as needed. // Its Execute function has a timeout that can kill @@ -52,11 +54,13 @@ type Conn struct { env tabletenv.Env dbaPool *dbconnpool.ConnectionPool stats *tabletenv.Stats - current atomic.Value + current atomic.Pointer[string] // err will be set if a query is killed through a Kill. errmu sync.Mutex err error + + killTimeout time.Duration } // NewConnection creates a new DBConn. It triggers a CheckMySQL if creation fails. @@ -71,12 +75,12 @@ func newPooledConn(ctx context.Context, pool *Pool, appParams dbconfigs.Connecto return nil, err } db := &Conn{ - conn: c, - env: pool.env, - stats: pool.env.Stats(), - dbaPool: pool.dbaPool, + conn: c, + env: pool.env, + stats: pool.env.Stats(), + dbaPool: pool.dbaPool, + killTimeout: defaultKillTimeout, } - db.current.Store("") return db, nil } @@ -87,12 +91,12 @@ func NewConn(ctx context.Context, params dbconfigs.Connector, dbaPool *dbconnpoo return nil, err } dbconn := &Conn{ - conn: c, - dbaPool: dbaPool, - stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")), - env: env, + conn: c, + dbaPool: dbaPool, + stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")), + env: env, + killTimeout: defaultKillTimeout, } - dbconn.current.Store("") if setting == nil { return dbconn, nil } @@ -153,8 +157,8 @@ func (dbc *Conn) Exec(ctx context.Context, query string, maxrows int, wantfields } func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { - dbc.current.Store(query) - defer dbc.current.Store("") + dbc.current.Store(&query) + defer dbc.current.Store(nil) // Check if the context is already past its deadline before // trying to execute the query. @@ -162,19 +166,33 @@ func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfi return nil, fmt.Errorf("%v before execution started", err) } - defer dbc.stats.MySQLTimings.Record("Exec", time.Now()) - - done, wg := dbc.setDeadline(ctx) - qr, err := dbc.conn.ExecuteFetch(query, maxrows, wantfields) + now := time.Now() + defer dbc.stats.MySQLTimings.Record("Exec", now) - if done != nil { - close(done) - wg.Wait() + type execResult struct { + result *sqltypes.Result + err error } - if dbcerr := dbc.Err(); dbcerr != nil { - return nil, dbcerr + + ch := make(chan execResult) + go func() { + result, err := dbc.conn.ExecuteFetch(query, maxrows, wantfields) + ch <- execResult{result, err} + }() + + select { + case <-ctx.Done(): + killCtx, cancel := context.WithTimeout(context.Background(), dbc.killTimeout) + defer cancel() + + _ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now)) + return nil, dbc.Err() + case r := <-ch: + if dbcErr := dbc.Err(); dbcErr != nil { + return nil, dbcErr + } + return r.result, r.err } - return qr, err } // ExecOnce executes the specified query, but does not retry on connection errors. @@ -250,22 +268,30 @@ func (dbc *Conn) Stream(ctx context.Context, query string, callback func(*sqltyp } func (dbc *Conn) streamOnce(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int) error { - defer dbc.stats.MySQLTimings.Record("ExecStream", time.Now()) + dbc.current.Store(&query) + defer dbc.current.Store(nil) - dbc.current.Store(query) - defer dbc.current.Store("") + now := time.Now() + defer dbc.stats.MySQLTimings.Record("ExecStream", now) - done, wg := dbc.setDeadline(ctx) - err := dbc.conn.ExecuteStreamFetch(query, callback, alloc, streamBufferSize) + ch := make(chan error) + go func() { + ch <- dbc.conn.ExecuteStreamFetch(query, callback, alloc, streamBufferSize) + }() - if done != nil { - close(done) - wg.Wait() - } - if dbcerr := dbc.Err(); dbcerr != nil { - return dbcerr + select { + case <-ctx.Done(): + killCtx, cancel := context.WithTimeout(context.Background(), dbc.killTimeout) + defer cancel() + + _ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now)) + return dbc.Err() + case err := <-ch: + if dbcErr := dbc.Err(); dbcErr != nil { + return dbcErr + } + return err } - return err } // StreamOnce executes the query and streams the results. But, does not retry on connection errors. @@ -363,10 +389,19 @@ func (dbc *Conn) IsClosed() bool { return dbc.conn.IsClosed() } -// Kill kills the currently executing query both on MySQL side +// Kill wraps KillWithContext using context.Background. +func (dbc *Conn) Kill(reason string, elapsed time.Duration) error { + return dbc.KillWithContext(context.Background(), reason, elapsed) +} + +// KillWithContext kills the currently executing query both on MySQL side // and on the connection side. If no query is executing, it's a no-op. // Kill will also not kill a query more than once. -func (dbc *Conn) Kill(reason string, elapsed time.Duration) error { +func (dbc *Conn) KillWithContext(ctx context.Context, reason string, elapsed time.Duration) error { + if cause := context.Cause(ctx); cause != nil { + return cause + } + dbc.stats.KillCounters.Add("Queries", 1) log.Infof("Due to %s, elapsed time: %v, killing query ID %v %s", reason, elapsed, dbc.conn.ID(), dbc.CurrentForLogging()) @@ -377,25 +412,43 @@ func (dbc *Conn) Kill(reason string, elapsed time.Duration) error { dbc.conn.Close() // Server side action. Kill the session. - killConn, err := dbc.dbaPool.Get(context.TODO()) + killConn, err := dbc.dbaPool.Get(ctx) if err != nil { log.Warningf("Failed to get conn from dba pool: %v", err) return err } defer killConn.Recycle() + + ch := make(chan error) sql := fmt.Sprintf("kill %d", dbc.conn.ID()) - _, err = killConn.Conn.ExecuteFetch(sql, 10000, false) - if err != nil { - log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), - dbc.CurrentForLogging(), err) - return err + go func() { + _, err := killConn.Conn.ExecuteFetch(sql, -1, false) + ch <- err + }() + + select { + case <-ctx.Done(): + killConn.Close() + + dbc.stats.InternalErrors.Add("HungQuery", 1) + log.Warningf("Query may be hung: %s", dbc.CurrentForLogging()) + + return context.Cause(ctx) + case err := <-ch: + if err != nil { + log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), dbc.CurrentForLogging(), err) + return err + } + return nil } - return nil } // Current returns the currently executing query. func (dbc *Conn) Current() string { - return dbc.current.Load().(string) + if q := dbc.current.Load(); q != nil { + return *q + } + return "" } // ID returns the connection id. @@ -437,45 +490,6 @@ func (dbc *Conn) Reconnect(ctx context.Context) error { return nil } -// setDeadline starts a goroutine that will kill the currently executing query -// if the deadline is exceeded. It returns a channel and a waitgroup. After the -// query is done executing, the caller is required to close the done channel -// and wait for the waitgroup to make sure that the necessary cleanup is done. -func (dbc *Conn) setDeadline(ctx context.Context) (chan bool, *sync.WaitGroup) { - if ctx.Done() == nil { - return nil, nil - } - done := make(chan bool) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - startTime := time.Now() - select { - case <-ctx.Done(): - dbc.Kill(ctx.Err().Error(), time.Since(startTime)) - case <-done: - return - } - elapsed := time.Since(startTime) - - // Give 2x the elapsed time and some buffer as grace period - // for the query to get killed. - tmr2 := time.NewTimer(2*elapsed + 5*time.Second) - defer tmr2.Stop() - select { - case <-tmr2.C: - dbc.stats.InternalErrors.Add("HungQuery", 1) - log.Warningf("Query may be hung: %s", dbc.CurrentForLogging()) - case <-done: - return - } - <-done - log.Warningf("Hung query returned") - }() - return done, &wg -} - // CurrentForLogging applies transformations to the query making it suitable to log. // It applies sanitization rules based on tablet settings and limits the max length of // queries. diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go index d22b8f1c311..09a85a3e11a 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "strings" + "sync/atomic" "testing" "time" @@ -33,7 +34,9 @@ import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vtenv" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" ) @@ -297,6 +300,59 @@ func TestDBConnKill(t *testing.T) { } } +func TestDBKillWithContext(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + params := dbconfigs.New(db.ConnParams()) + connPool.Open(params, params, params) + defer connPool.Close() + dbConn, err := newPooledConn(context.Background(), connPool, params) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + query := fmt.Sprintf("kill %d", dbConn.ID()) + db.AddQuery(query, &sqltypes.Result{}) + db.SetBeforeFunc(query, func() { + // should take longer than our context deadline below. + time.Sleep(200 * time.Millisecond) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // KillWithContext should return context.DeadlineExceeded + err = dbConn.KillWithContext(ctx, "test kill", 0) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestDBKillWithContextDoneContext(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + params := dbconfigs.New(db.ConnParams()) + connPool.Open(params, params, params) + defer connPool.Close() + dbConn, err := newPooledConn(context.Background(), connPool, params) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + query := fmt.Sprintf("kill %d", dbConn.ID()) + db.AddRejectedQuery(query, errors.New("rejected")) + + contextErr := errors.New("context error") + ctx, cancel := context.WithCancelCause(context.Background()) + cancel(contextErr) // cancel the context immediately + + // KillWithContext should return the cancellation cause + err = dbConn.KillWithContext(ctx, "test kill", 0) + require.ErrorIs(t, err, contextErr) +} + // TestDBConnClose tests that an Exec returns immediately if a connection // is asynchronously killed (and closed) in the middle of an execution. func TestDBConnClose(t *testing.T) { @@ -531,3 +587,51 @@ func TestDBConnReApplySetting(t *testing.T) { db.VerifyAllExecutedOrFail() } + +func TestDBExecOnceKillTimeout(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + params := dbconfigs.New(db.ConnParams()) + connPool.Open(params, params, params) + defer connPool.Close() + dbConn, err := newPooledConn(context.Background(), connPool, params) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + // A very long running query that will be killed. + expectedQuery := "select 1" + var timestampQuery atomic.Int64 + db.AddQuery(expectedQuery, &sqltypes.Result{}) + db.SetBeforeFunc(expectedQuery, func() { + timestampQuery.Store(time.Now().UnixMicro()) + // should take longer than our context deadline below. + time.Sleep(1000 * time.Millisecond) + }) + + // We expect a kill-query to be fired, too. + // It should also run into a timeout. + var timestampKill atomic.Int64 + dbConn.killTimeout = 100 * time.Millisecond + db.AddQueryPatternWithCallback(`kill \d+`, &sqltypes.Result{}, func(string) { + timestampKill.Store(time.Now().UnixMicro()) + // should take longer than the configured kill timeout above. + time.Sleep(200 * time.Millisecond) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + result, err := dbConn.ExecOnce(ctx, "select 1", 1, false) + timeDone := time.Now() + + require.Error(t, err) + require.Equal(t, vtrpcpb.Code_CANCELED, vterrors.Code(err)) + require.Nil(t, result) + timeQuery := time.UnixMicro(timestampQuery.Load()) + timeKill := time.UnixMicro(timestampKill.Load()) + require.WithinDuration(t, timeQuery, timeKill, 150*time.Millisecond) + require.WithinDuration(t, timeKill, timeDone, 150*time.Millisecond) +} diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index df68c8b0a83..0374fb416a6 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -1122,6 +1122,20 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { db.SetBeforeFunc("update test_table set name_string = 'tx1' where pk = 1 and `name` = 1 limit 10001", func() { close(tx1Started) + + // Wait for other queries to be pending. + <-allQueriesPending + }) + + db.SetBeforeFunc("update test_table set name_string = 'tx2' where pk = 1 and `name` = 1 limit 10001", + func() { + // Wait for other queries to be pending. + <-allQueriesPending + }) + + db.SetBeforeFunc("update test_table set name_string = 'tx3' where pk = 1 and `name` = 1 limit 10001", + func() { + // Wait for other queries to be pending. <-allQueriesPending }) @@ -1190,6 +1204,8 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { // to allow more than connection attempt at a time. err := waitForTxSerializationPendingQueries(tsv, "test_table where pk = 1 and `name` = 1", 3) require.NoError(t, err) + + // Signal that all queries are pending now. close(allQueriesPending) wg.Wait()