Skip to content

Commit

Permalink
Backport feed fix
Browse files Browse the repository at this point in the history
Handle data race condition on initial connect.
  • Loading branch information
joshuacolvin0 committed Mar 29, 2022
1 parent dabe5d5 commit d994f7d
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 19 deletions.
43 changes: 29 additions & 14 deletions packages/arb-util/broadcastclient/broadcastclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"encoding/json"
"github.com/offchainlabs/arbitrum/packages/arb-util/arblog"
"io"
"math/big"
"net"
"strings"
Expand Down Expand Up @@ -76,12 +77,12 @@ func (bc *BroadcastClient) Connect(ctx context.Context) (chan broadcaster.Broadc
}

func (bc *BroadcastClient) ConnectWithChannel(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) error {
_, err := bc.connect(ctx, messageReceiver)
earlyFrameData, _, err := bc.connect(ctx, messageReceiver)
if err != nil {
return err
}

bc.startBackgroundReader(ctx, messageReceiver)
bc.startBackgroundReader(ctx, messageReceiver, earlyFrameData)

return nil
}
Expand All @@ -104,22 +105,34 @@ func (bc *BroadcastClient) ConnectInBackground(ctx context.Context, messageRecei
})()
}

func (bc *BroadcastClient) connect(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) (chan broadcaster.BroadcastFeedMessage, error) {
func (bc *BroadcastClient) connect(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) (io.Reader, chan broadcaster.BroadcastFeedMessage, error) {

if len(bc.websocketUrl) == 0 {
// Nothing to do
return nil, nil
return nil, nil, nil
}

logger.Info().Str("url", bc.websocketUrl).Msg("connecting to arbitrum inbox message broadcaster")
timeoutDialer := ws.Dialer{
Timeout: 10 * time.Second,
}

conn, _, _, err := timeoutDialer.Dial(ctx, bc.websocketUrl)
conn, br, _, err := timeoutDialer.Dial(ctx, bc.websocketUrl)
if err != nil {
logger.Warn().Err(err).Msg("broadcast client unable to connect")
return nil, errors.Wrap(err, "broadcast client unable to connect")
return nil, nil, errors.Wrap(err, "broadcast client unable to connect")
}

var earlyFrameData io.Reader
if br != nil {
// Depending on how long the client takes to read the response, there may be
// data after the WebSocket upgrade response in a single read from the socket,
// ie WebSocket frames sent by the server. If this happens, Dial returns
// a non-nil bufio.Reader so that data isn't lost. But beware, this buffered
// reader is still hooked up to the socket; trying to read past what had already
// been buffered will do a blocking read on the socket, so we have to wrap it
// in a LimitedReader.
earlyFrameData = io.LimitReader(br, int64(br.Buffered()))
}

bc.connMutex.Lock()
Expand All @@ -128,10 +141,10 @@ func (bc *BroadcastClient) connect(ctx context.Context, messageReceiver chan bro

logger.Info().Msg("Connected")

return messageReceiver, nil
return earlyFrameData, messageReceiver, nil
}

func (bc *BroadcastClient) startBackgroundReader(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) {
func (bc *BroadcastClient) startBackgroundReader(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage, earlyFrameData io.Reader) {
go func() {
for {
select {
Expand All @@ -140,7 +153,7 @@ func (bc *BroadcastClient) startBackgroundReader(ctx context.Context, messageRec
default:
}

msg, op, err := wsbroadcastserver.ReadData(ctx, bc.conn, bc.idleTimeout, ws.StateClientSide)
msg, op, err := wsbroadcastserver.ReadData(ctx, bc.conn, earlyFrameData, bc.idleTimeout, ws.StateClientSide)
if err != nil {
if bc.shuttingDown {
return
Expand All @@ -151,7 +164,7 @@ func (bc *BroadcastClient) startBackgroundReader(ctx context.Context, messageRec
logger.Error().Err(err).Str("feed", bc.websocketUrl).Int("opcode", int(op)).Msgf("error calling readData")
}
_ = bc.conn.Close()
bc.RetryConnect(ctx, messageReceiver)
earlyFrameData = bc.RetryConnect(ctx, messageReceiver)
continue
}

Expand Down Expand Up @@ -192,7 +205,7 @@ func (bc *BroadcastClient) GetRetryCount() int {
return bc.retryCount
}

func (bc *BroadcastClient) RetryConnect(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) {
func (bc *BroadcastClient) RetryConnect(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) io.Reader {
bc.retryMutex.Lock()
defer bc.retryMutex.Unlock()

Expand All @@ -202,21 +215,23 @@ func (bc *BroadcastClient) RetryConnect(ctx context.Context, messageReceiver cha
for !bc.shuttingDown {
select {
case <-ctx.Done():
return
return nil
case <-time.After(waitDuration):
}

bc.retryCount++
_, err := bc.connect(ctx, messageReceiver)
earlyFrameData, _, err := bc.connect(ctx, messageReceiver)
if err == nil {
bc.retrying = false
return
return earlyFrameData
}

if waitDuration < maxWaitDuration {
waitDuration += 500 * time.Millisecond
}
}

return nil
}

func (bc *BroadcastClient) Close() {
Expand Down
2 changes: 1 addition & 1 deletion packages/arb-util/wsbroadcastserver/clientconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (cc *ClientConnection) readRequest(ctx context.Context, timeout time.Durati

atomic.StoreInt64(&cc.lastHeardUnix, time.Now().Unix())

return ReadData(ctx, cc.conn, timeout, ws.StateServerSide)
return ReadData(ctx, cc.conn, nil, timeout, ws.StateServerSide)
}

func (cc *ClientConnection) Write(x interface{}) error {
Expand Down
45 changes: 41 additions & 4 deletions packages/arb-util/wsbroadcastserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,61 @@ package wsbroadcastserver

import (
"context"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"errors"
"io"
"io/ioutil"
"net"
"strings"
"time"

"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)

type chainedReader struct {
readers []io.Reader
}

func (cr *chainedReader) Read(b []byte) (n int, err error) {
for len(cr.readers) > 0 {
n, err = cr.readers[0].Read(b)
if errors.Is(err, io.EOF) {
cr.readers = cr.readers[1:]
if n == 0 {
continue // EOF and empty, skip to next
} else {
// The Read interface specifies some data can be returned along with an EOF.
if len(cr.readers) != 1 {
// If this isn't the last reader, return the data without the EOF since this
// may not be the end of all the readers.
return n, nil
} else {
return
}
}
}
break
}
return
}

func (cr *chainedReader) add(r io.Reader) *chainedReader {
if r != nil {
cr.readers = append(cr.readers, r)
}
return cr
}

func logError(err error, msg string) {
if !strings.Contains(err.Error(), "use of closed network connection") {
logger.Error().Err(err).Msg(msg)
}
}

func ReadData(ctx context.Context, conn net.Conn, idleTimeout time.Duration, state ws.State) ([]byte, ws.OpCode, error) {
func ReadData(ctx context.Context, conn net.Conn, earlyFrameData io.Reader, idleTimeout time.Duration, state ws.State) ([]byte, ws.OpCode, error) {
controlHandler := wsutil.ControlFrameHandler(conn, state)
reader := wsutil.Reader{
Source: conn,
Source: (&chainedReader{}).add(earlyFrameData).add(conn),
State: state,
CheckUTF8: true,
SkipHeaderCheck: false,
Expand Down

0 comments on commit d994f7d

Please sign in to comment.