diff --git a/go/mysql/endtoend/client_test.go b/go/mysql/endtoend/client_test.go index ce01c57369d..964a8702471 100644 --- a/go/mysql/endtoend/client_test.go +++ b/go/mysql/endtoend/client_test.go @@ -298,7 +298,28 @@ func TestReplicationStatus(t *testing.T) { status, err := conn.ShowReplicationStatus() assert.Equal(t, mysql.ErrNotReplica, err, "Got unexpected result for ShowReplicationStatus: %v %v", status, err) +} + +func TestReplicationStatusWithMysqlHang(t *testing.T) { + params := connParams + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + conn, err := mysql.Connect(ctx, &params) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + err = cluster.SimulateMySQLHang() + require.NoError(t, err) + + defer cluster.StopSimulateMySQLHang() + status, err := conn.ShowReplicationStatusWithContext(ctx) + assert.Equal(t, ctx.Err().Error(), "context deadline exceeded") + assert.Equal(t, ctx.Err(), err, "Got unexpected result for ShowReplicationStatus: %v %v", status, err) + assert.True(t, conn.IsClosed()) } func TestSessionTrackGTIDs(t *testing.T) { diff --git a/go/mysql/endtoend/main_test.go b/go/mysql/endtoend/main_test.go index 466735c02e4..c641332b565 100644 --- a/go/mysql/endtoend/main_test.go +++ b/go/mysql/endtoend/main_test.go @@ -40,6 +40,7 @@ import ( var ( connParams mysql.ConnParams + cluster vttest.LocalCluster ) // assertSQLError makes sure we get the right error. 
@@ -200,8 +201,17 @@ ssl-key=%v/server-key.pem OnlyMySQL: true, ExtraMyCnf: []string{extraMyCnf, maxPacketMyCnf}, } - cluster := vttest.LocalCluster{ + + env, err := vttest.NewLocalTestEnv(0) + if err != nil { + fmt.Fprintf(os.Stderr, "%v", err) + return 1 + } + env.EnableToxiproxy = true + + cluster = vttest.LocalCluster{ Config: cfg, + Env: env, } if err := cluster.Setup(); err != nil { fmt.Fprintf(os.Stderr, "could not launch mysql: %v\n", err) diff --git a/go/mysql/flavor.go b/go/mysql/flavor.go index f732b1ccb88..24de3d3c9a5 100644 --- a/go/mysql/flavor.go +++ b/go/mysql/flavor.go @@ -401,7 +401,35 @@ func resultToMap(qr *sqltypes.Result) (map[string]string, error) { // ShowReplicationStatus executes the right command to fetch replication status, // and returns a parsed Position with other fields. func (c *Conn) ShowReplicationStatus() (replication.ReplicationStatus, error) { - return c.flavor.status(c) + return c.ShowReplicationStatusWithContext(context.TODO()) +} + +func (c *Conn) ShowReplicationStatusWithContext(ctx context.Context) (replication.ReplicationStatus, error) { + result := make(chan replication.ReplicationStatus, 1) + errors := make(chan error, 1) + + go func() { + res, err := c.flavor.status(c) + if err != nil { + errors <- err + } else { + result <- res + } + }() + + for { + select { + case <-ctx.Done(): + c.Close() + return replication.ReplicationStatus{}, ctx.Err() + + case err := <-errors: + return replication.ReplicationStatus{}, err + + case res := <-result: + return res, nil + } + } } // ShowPrimaryStatus executes the right SHOW BINARY LOG STATUS command, diff --git a/go/pools/smartconnpool/pool.go b/go/pools/smartconnpool/pool.go index d49032f34a1..8640910fa17 100644 --- a/go/pools/smartconnpool/pool.go +++ b/go/pools/smartconnpool/pool.go @@ -376,7 +376,10 @@ func (pool *ConnPool[C]) put(conn *Pooled[C]) { if conn == nil { var err error - conn, err = pool.connNew(context.Background()) + // TODO: Do we really want to wait for up to a 
second here? + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + conn, err = pool.connNew(ctx) if err != nil { pool.closedConn() return diff --git a/go/vt/mysqlctl/fakemysqldaemon.go b/go/vt/mysqlctl/fakemysqldaemon.go index f6447eda549..35fb7359fcf 100644 --- a/go/vt/mysqlctl/fakemysqldaemon.go +++ b/go/vt/mysqlctl/fakemysqldaemon.go @@ -320,6 +320,10 @@ func (fmd *FakeMysqlDaemon) ReplicationStatus(ctx context.Context) (replication. }, nil } +func (fmd *FakeMysqlDaemon) ReplicationStatusWithContext(ctx context.Context) (replication.ReplicationStatus, error) { + return fmd.ReplicationStatus(ctx) +} + // PrimaryStatus is part of the MysqlDaemon interface. func (fmd *FakeMysqlDaemon) PrimaryStatus(ctx context.Context) (replication.PrimaryStatus, error) { if fmd.PrimaryStatusError != nil { diff --git a/go/vt/mysqlctl/query.go b/go/vt/mysqlctl/query.go index 7a1816e3cfb..154ed132062 100644 --- a/go/vt/mysqlctl/query.go +++ b/go/vt/mysqlctl/query.go @@ -35,20 +35,42 @@ func getPoolReconnect(ctx context.Context, pool *dbconnpool.ConnectionPool) (*db if err != nil { return conn, err } - // Run a test query to see if this connection is still good. - if _, err := conn.Conn.ExecuteFetch("SELECT 1", 1, false); err != nil { + + errChan := make(chan error, 1) + resultChan := make(chan *sqltypes.Result, 1) + + go func() { + result, err := conn.Conn.ExecuteFetch("SELECT 1", 1, false) + if err != nil { + errChan <- err + } else { + resultChan <- result + } + }() + + select { + case <-ctx.Done(): + conn.Close() + conn.Recycle() + return nil, ctx.Err() + + case err := <-errChan: // If we get a connection error, try to reconnect. 
if sqlErr, ok := err.(*sqlerror.SQLError); ok && (sqlErr.Number() == sqlerror.CRServerGone || sqlErr.Number() == sqlerror.CRServerLost) { if err := conn.Conn.Reconnect(ctx); err != nil { conn.Recycle() return nil, err } + return conn, nil } + conn.Recycle() return nil, err + + case <-resultChan: + return conn, nil } - return conn, nil } // ExecuteSuperQuery allows the user to execute a query as a super user. diff --git a/go/vt/mysqlctl/replication.go b/go/vt/mysqlctl/replication.go index c94df03e7cf..07e0e3b5fad 100644 --- a/go/vt/mysqlctl/replication.go +++ b/go/vt/mysqlctl/replication.go @@ -381,9 +381,9 @@ func (mysqld *Mysqld) ReplicationStatus(ctx context.Context) (replication.Replic if err != nil { return replication.ReplicationStatus{}, err } - defer conn.Recycle() - return conn.Conn.ShowReplicationStatus() + defer conn.Recycle() + return conn.Conn.ShowReplicationStatusWithContext(ctx) } // PrimaryStatus returns the primary replication statuses diff --git a/go/vt/vttablet/tabletserver/repltracker/reader.go b/go/vt/vttablet/tabletserver/repltracker/reader.go index b50e5e4b2c7..6b798d99ca1 100644 --- a/go/vt/vttablet/tabletserver/repltracker/reader.go +++ b/go/vt/vttablet/tabletserver/repltracker/reader.go @@ -141,7 +141,7 @@ func (r *heartbeatReader) Status() (time.Duration, error) { func (r *heartbeatReader) readHeartbeat() { defer r.env.LogError() - ctx, cancel := context.WithDeadline(context.Background(), r.now().Add(r.interval)) + ctx, cancel := context.WithTimeout(context.Background(), r.interval) defer cancel() res, err := r.fetchMostRecentHeartbeat(ctx) @@ -149,6 +149,7 @@ func (r *heartbeatReader) readHeartbeat() { r.recordError(vterrors.Wrap(err, "failed to read most recent heartbeat")) return } + ts, err := parseHeartbeatResult(res) if err != nil { r.recordError(vterrors.Wrap(err, "failed to parse heartbeat result"))