Skip to content

Commit

Permalink
GODRIVER-2577 Retry heartbeat on timeout to prevent pool cleanup in F…
Browse files Browse the repository at this point in the history
…AAS pause. (#1133)
  • Loading branch information
qingyang-hu authored and benjirewis committed Dec 6, 2022
1 parent eda6752 commit a1412b7
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 17 deletions.
70 changes: 53 additions & 17 deletions x/mongo/driver/topology/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ func (s *Server) update() {
}
}

timeoutCnt := 0
for {
// Check if the server is disconnecting. Even if waitForNextCheck has already read from the done channel, we
// can safely read from it again because Disconnect closes the channel.
Expand All @@ -545,18 +546,42 @@ func (s *Server) update() {
continue
}

// Must hold the processErrorLock while updating the server description and clearing the
// pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
// pool.ready() calls from concurrent server description updates.
s.processErrorLock.Lock()
s.updateDescription(desc)
if err := desc.LastError; err != nil {
// Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
// because the monitoring routine only runs for non-load balanced deployments in which servers don't return
// IDs.
s.pool.clear(err, nil)
if isShortcut := func() bool {
// Must hold the processErrorLock while updating the server description and clearing the
// pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
// pool.ready() calls from concurrent server description updates.
s.processErrorLock.Lock()
defer s.processErrorLock.Unlock()

s.updateDescription(desc)
// Retry after the first timeout before clearing the pool in case of a FAAS pause as
// described in GODRIVER-2577.
if err := unwrapConnectionError(desc.LastError); err != nil && timeoutCnt < 1 {
if err == context.Canceled || err == context.DeadlineExceeded {
timeoutCnt++
// We want to immediately retry on timeout error. Continue to next loop.
return true
}
if err, ok := err.(net.Error); ok && err.Timeout() {
timeoutCnt++
// We want to immediately retry on timeout error. Continue to next loop.
return true
}
}
if err := desc.LastError; err != nil {
// Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
// because the monitoring routine only runs for non-load balanced deployments in which servers don't return
// IDs.
s.pool.clear(err, nil)
}
// We're either not handling a timeout error, or we just handled the 2nd consecutive
// timeout error. In either case, reset the timeout count to 0 and return false to
// continue the normal check process.
timeoutCnt = 0
return false
}(); isShortcut {
continue
}
s.processErrorLock.Unlock()

// If the server supports streaming or we're already streaming, we want to move to streaming the next response
// without waiting. If the server has transitioned to Unknown from a network error, we want to do another
Expand Down Expand Up @@ -707,19 +732,31 @@ func (s *Server) check() (description.Server, error) {
var err error
var durationNanos int64

// Create a new connection if this is the first check, the connection was closed after an error during the previous
// check, or the previous check was cancelled.
start := time.Now()
if s.conn == nil || s.conn.closed() || s.checkWasCancelled() {
// Create a new connection if this is the first check, the connection was closed after an error during the previous
// check, or the previous check was cancelled.
isNilConn := s.conn == nil
if !isNilConn {
s.publishServerHeartbeatStartedEvent(s.conn.ID(), false)
}
// Create a new connection and add its handshake RTT as a sample.
err = s.setupHeartbeatConnection()
durationNanos = time.Since(start).Nanoseconds()
if err == nil {
// Use the description from the connection handshake as the value for this check.
s.rttMonitor.addSample(s.conn.helloRTT)
descPtr = &s.conn.desc
if !isNilConn {
s.publishServerHeartbeatSucceededEvent(s.conn.ID(), durationNanos, s.conn.desc, false)
}
} else {
err = unwrapConnectionError(err)
if !isNilConn {
s.publishServerHeartbeatFailedEvent(s.conn.ID(), durationNanos, err, false)
}
}
}

if descPtr == nil && err == nil {
} else {
// An existing connection is being used. Use the server description properties to execute the right heartbeat.

// Wrap conn in a type that implements driver.StreamerConnection.
Expand All @@ -729,7 +766,6 @@ func (s *Server) check() (description.Server, error) {
streamable := previousDescription.TopologyVersion != nil

s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable)
start := time.Now()
switch {
case s.conn.getCurrentlyStreaming():
// The connection is already in a streaming state, so we stream the next response.
Expand Down
142 changes: 142 additions & 0 deletions x/mongo/driver/topology/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ package topology

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"net"
"os"
"runtime"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -49,6 +53,144 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n
return cnc, nil
}

// errorQueue is a mutex-guarded FIFO of errors used to script the I/O
// failures a test connection should report. head peeks at the front entry;
// dequeue removes it.
type errorQueue struct {
	errors []error
	mutex  sync.Mutex
}

// head returns the error at the front of the queue without removing it, or
// nil when the queue is empty (or the front entry itself is nil).
func (eq *errorQueue) head() error {
	eq.mutex.Lock()
	defer eq.mutex.Unlock()

	var front error
	if len(eq.errors) != 0 {
		front = eq.errors[0]
	}
	return front
}

// dequeue removes the front entry and reports whether anything was removed
// (false means the queue was already empty).
func (eq *errorQueue) dequeue() bool {
	eq.mutex.Lock()
	defer eq.mutex.Unlock()

	if len(eq.errors) == 0 {
		return false
	}
	eq.errors = eq.errors[1:]
	return true
}

// timeoutConn wraps a net.Conn and injects scripted errors into Read and
// Write: before each call the front of errors is consulted (peeked, not
// removed) and returned instead of performing real I/O when non-nil.
type timeoutConn struct {
	net.Conn
	errors *errorQueue
}

// Read returns the next scripted error, if any; otherwise it delegates to
// the wrapped net.Conn.
func (c *timeoutConn) Read(b []byte) (int, error) {
	if err := c.errors.head(); err != nil {
		return 0, err
	}
	return c.Conn.Read(b)
}

// Write returns the next scripted error, if any; otherwise it delegates to
// the wrapped net.Conn.
func (c *timeoutConn) Write(b []byte) (int, error) {
	if err := c.errors.head(); err != nil {
		return 0, err
	}
	return c.Conn.Write(b)
}

// timeoutDialer wraps a Dialer and hands out timeoutConn connections that
// report the scripted errors in errors from their Read/Write calls.
type timeoutDialer struct {
	Dialer
	errors *errorQueue
}

// DialContext dials through the embedded Dialer and wraps the resulting
// connection in a timeoutConn so tests can inject scripted I/O errors.
//
// If the MONGO_GO_DRIVER_CA_FILE environment variable is set, the connection
// is additionally wrapped in a TLS client trusting that CA (with hostname
// verification disabled) so the test can talk to TLS-enabled deployments.
func (d *timeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	c, err := d.Dialer.DialContext(ctx, network, address)
	// Fail fast on a dial error rather than falling through: the original
	// code could pass a nil net.Conn to tls.Client and return a wrapped
	// connection together with a non-nil error.
	if err != nil {
		return nil, err
	}

	if caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE"); len(caFile) > 0 {
		pem, err := ioutil.ReadFile(caFile)
		if err != nil {
			return nil, err
		}

		ca := x509.NewCertPool()
		if !ca.AppendCertsFromPEM(pem) {
			return nil, errors.New("unable to load CA file")
		}

		// InsecureSkipVerify skips hostname verification; the custom CA
		// pool still authenticates the server certificate chain.
		config := &tls.Config{
			InsecureSkipVerify: true,
			RootCAs:            ca,
		}
		c = tls.Client(c, config)
	}
	return &timeoutConn{c, d.errors}, nil
}

// TestServerHeartbeatTimeout tests the heartbeat timeout retry behavior for
// GODRIVER-2577: a single heartbeat timeout must be retried without clearing
// the connection pool, while two consecutive timeouts must clear it.
func TestServerHeartbeatTimeout(t *testing.T) {
	// A net.Error whose Timeout() reports true, standing in for a network
	// read/write timeout during a heartbeat.
	networkTimeoutError := &net.DNSError{
		IsTimeout: true,
	}

	testCases := []struct {
		desc string
		// ioErrors is the script consumed one entry per heartbeat start; a
		// nil entry lets that heartbeat's I/O proceed normally.
		ioErrors          []error
		expectPoolCleared bool
	}{
		{
			desc:              "one single timeout should not clear the pool",
			ioErrors:          []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
			expectPoolCleared: false,
		},
		{
			desc:              "continuous timeouts should clear the pool",
			ioErrors:          []error{nil, networkTimeoutError, networkTimeoutError, nil},
			expectPoolCleared: true,
		},
	}
	for _, tc := range testCases {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()

			// wg is released once every scripted error has been consumed,
			// i.e. after the final heartbeat in the script has started.
			var wg sync.WaitGroup
			wg.Add(1)

			errors := &errorQueue{errors: tc.ioErrors}
			tpm := monitor.NewTestPoolMonitor()
			server := NewServer(
				address.Address("localhost:27017"),
				primitive.NewObjectID(),
				WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor {
					return tpm.PoolMonitor
				}),
				// Route all connections through a timeoutDialer so the
				// heartbeat connection sees the scripted I/O errors.
				WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
					return append(opts,
						WithDialer(func(d Dialer) Dialer {
							var dialer net.Dialer
							return &timeoutDialer{&dialer, errors}
						}))
				}),
				WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor {
					return &event.ServerMonitor{
						// Advance the error script on each heartbeat start;
						// when the script is exhausted, unblock the test.
						ServerHeartbeatStarted: func(e *event.ServerHeartbeatStartedEvent) {
							if !errors.dequeue() {
								wg.Done()
							}
						},
					}
				}),
				// Short interval keeps the test fast.
				WithHeartbeatInterval(func(time.Duration) time.Duration {
					return 200 * time.Millisecond
				}),
			)
			require.NoError(t, server.Connect(nil))
			wg.Wait()
			assert.Equal(t, tc.expectPoolCleared, tpm.IsPoolCleared(), "expected pool cleared to be %v but was %v", tc.expectPoolCleared, tpm.IsPoolCleared())
		})
	}
}

// TestServerConnectionTimeout tests how different timeout errors are handled during connection
// creation and server handshake.
func TestServerConnectionTimeout(t *testing.T) {
Expand Down

0 comments on commit a1412b7

Please sign in to comment.